File size: 4,735 Bytes
4ccf6a3 0e97d35 b91a6bd e691ea0 cf575f8 dd2409d 3f6f474 cf575f8 b91a6bd cf575f8 b91a6bd d654474 4ccf6a3 dd2409d d654474 b91a6bd cf575f8 f54f0db cf575f8 e691ea0 cf575f8 0e97d35 cf575f8 dd2409d f54f0db dd2409d e691ea0 dd2409d e691ea0 dd2409d b91a6bd dd2409d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | from functools import lru_cache
import torch
import torch.nn.functional as F
from loguru import logger
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel
# Use the GPU whenever CUDA is usable; otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Identifier of the Qwen3 embedding checkpoint.
QWEN3_EMBEDDING_MODEL = 'Qwen/Qwen3-Embedding-0.6B'

# Identifiers of the previously supported encoder checkpoints.
LEGACY_MODELS = [
    'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
    'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    'sentence-transformers/all-mpnet-base-v2',
    'sentence-transformers/all-MiniLM-L12-v2',
    'cyclone/simcse-chinese-roberta-wwm-ext',
    'bert-base-chinese',
    'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese',
]

# Every known model name: the legacy ones first, the Qwen3 model last.
list_models = LEGACY_MODELS + [QWEN3_EMBEDDING_MODEL]
class SBert:
    """Callable wrapper around a ``SentenceTransformer`` model.

    Calling an instance encodes the input with the wrapped model and returns
    the result as a tensor. Results for hashable inputs are memoised per
    instance (up to 10000 distinct inputs).
    """

    def __init__(self, path):
        """Load the ``SentenceTransformer`` at ``path`` onto ``DEVICE``."""
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.model = SentenceTransformer(path, device=DEVICE)
        logger.info(f'Load {self.__class__} from {path} ...')

    def _encode(self, x) -> torch.Tensor:
        """Encode ``x`` with the wrapped model, bypassing the cache."""
        return self.model.encode(x, convert_to_tensor=True)

    def __call__(self, x) -> torch.Tensor:
        """Return the embedding of ``x``.

        Args:
            x: Input forwarded to ``SentenceTransformer.encode``. Hashable
                inputs (e.g. a single string) are cached; unhashable inputs
                (e.g. a list of sentences) are encoded on every call.

        Returns:
            Whatever ``encode(x, convert_to_tensor=True)`` returns.
        """
        try:
            hash(x)
        except TypeError:
            # An unhashable input cannot be a cache key, so encode it directly
            # instead of failing inside ``lru_cache``.
            return self._encode(x)

        # The cache lives on the instance and is built on first use. Putting
        # ``lru_cache`` on the method itself would key a class-level cache on
        # ``self`` and keep every instance (and its model) alive forever.
        cached_encode = getattr(self, '_encode_cache', None)
        if cached_encode is None:
            cached_encode = lru_cache(maxsize=10000)(self._encode)
            self._encode_cache = cached_encode
        return cached_encode(x)
class ModelWithPooling:
    """Callable wrapper that pools a transformer's token states into a vector.

    Supported ``pooling`` strategies:

    * ``'cls'``: last-layer hidden state of the first token.
    * ``'pooler'``: the model's ``pooler_output``.
    * ``'mean'`` / ``'last-avg'``: average of the last-layer hidden states
      over the sequence dimension.
    * ``'first-last-avg'``: average of the sequence means of
      ``hidden_states[1]`` and ``hidden_states[-1]``.

    The averages run over every token position without consulting the
    attention mask. Results are memoised per instance (up to 100 distinct
    ``(text, pooling)`` pairs).
    """

    # Every strategy name accepted by ``__call__``.
    _POOLINGS = ('cls', 'pooler', 'mean', 'last-avg', 'first-last-avg')

    def __init__(self, path):
        """Load the tokenizer and model found at ``path``."""
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModel.from_pretrained(path)
        logger.info(f'Load {self.__class__} from {path} ...')

    @torch.no_grad()
    def _embed(self, text, pooling):
        """Tokenize ``text``, run the model and pool; bypasses the cache."""
        # Reject a bad strategy before paying for tokenization and a forward
        # pass. ValueError is still caught by ``except Exception``.
        if pooling not in self._POOLINGS:
            raise ValueError(f'Unknown pooling {pooling}')

        inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
        outputs = self.model(**inputs, output_hidden_states=True)

        if pooling == 'cls':
            pooled = outputs.last_hidden_state[:, 0]  # [b, h]
        elif pooling == 'pooler':
            pooled = outputs.pooler_output  # [b, h]
        elif pooling in ('mean', 'last-avg'):
            pooled = outputs.last_hidden_state.mean(dim=1)  # [b, s, h] -> [b, h]
        else:  # 'first-last-avg' (the only name left after validation)
            first_avg = outputs.hidden_states[1].mean(dim=1)  # [b, h]
            last_avg = outputs.hidden_states[-1].mean(dim=1)  # [b, h]
            pooled = (first_avg + last_avg) / 2  # [b, h]

        # Drop the batch dimension when it has size one.
        return pooled.squeeze(0)

    def __call__(self, text: str, pooling='mean'):
        """Return the pooled embedding of ``text``.

        Args:
            text: Text to embed.
            pooling: One of ``'cls'``, ``'pooler'``, ``'mean'``,
                ``'last-avg'`` or ``'first-last-avg'``.

        Returns:
            The pooled tensor with a size-one batch dimension squeezed away.

        Raises:
            ValueError: If ``pooling`` is not a supported strategy.
        """
        # The cache lives on the instance and is built on first use. Putting
        # ``lru_cache`` on the method itself would key a class-level cache on
        # ``self`` and keep every instance (and its model) alive forever.
        cached_embed = getattr(self, '_embed_cache', None)
        if cached_embed is None:
            cached_embed = lru_cache(maxsize=100)(self._embed)
            self._embed_cache = cached_embed
        # Passing both arguments positionally lets m(t) and m(t, pooling='mean')
        # share one cache entry.
        return cached_embed(text, pooling)
class Qwen3Embedding:
    """Callable wrapper producing L2-normalised last-token embeddings.

    The text is tokenized with left padding, run through the model, pooled by
    taking the hidden state of the last attended token, and L2-normalised.
    Results are memoised per instance (up to 100 distinct texts).
    """

    def __init__(self, path):
        """Load the tokenizer and model at ``path``; put the model on ``DEVICE`` in eval mode."""
        logger.info(f'Start loading {self.__class__} from {path} ...')
        # With left padding the last column holds a real token for every row,
        # which lets ``last_token_pool`` take its fast path.
        self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side='left')
        self.model = AutoModel.from_pretrained(path)
        self.model.to(DEVICE)
        self.model.eval()
        logger.info(f'Load {self.__class__} from {path} ...')

    @staticmethod
    def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Select, for each row, the hidden state of its last attended token.

        Args:
            last_hidden_states: Hidden states indexed as ``[batch, seq, ...]``.
            attention_mask: Mask indexed as ``[batch, seq]``; its row sums are
                used as sequence lengths.

        Returns:
            One hidden state per batch row.
        """
        # If every row attends to the final position, the batch is left-padded
        # (or unpadded) and the last column is the answer for all rows.
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return last_hidden_states[:, -1]
        # Otherwise index each row at (number of attended tokens - 1).
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[
            torch.arange(batch_size, device=last_hidden_states.device),
            sequence_lengths,
        ]

    @torch.no_grad()
    def _embed(self, text):
        """Tokenize ``text``, run the model, pool and normalise; bypasses the cache."""
        inputs = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=8192,
            return_tensors='pt',
        )
        # Send the inputs to wherever the model currently lives, so they stay
        # in sync with it even if it is moved after construction.
        device = self.model.device
        inputs = {key: value.to(device) for key, value in inputs.items()}
        outputs = self.model(**inputs)
        embeddings = self.last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)
        # Drop the batch dimension when it has size one.
        return embeddings.squeeze(0)

    def __call__(self, text: str, pooling='mean'):
        """Return the normalised embedding of ``text``.

        Args:
            text: Text to embed.
            pooling: Accepted so the call signature matches
                ``ModelWithPooling``; it is ignored, last-token pooling is
                always used.

        Returns:
            The L2-normalised embedding with a size-one batch dimension
            squeezed away.
        """
        # The cache lives on the instance and is built on first use. Putting
        # ``lru_cache`` on the method itself would key a class-level cache on
        # ``self`` and keep every instance (and its model) alive forever.
        cached_embed = getattr(self, '_embed_cache', None)
        if cached_embed is None:
            cached_embed = lru_cache(maxsize=100)(self._embed)
            self._embed_cache = cached_embed
        # ``pooling`` does not affect the result, so it is kept out of the
        # cache key.
        return cached_embed(text)
@lru_cache(maxsize=8)
def get_embedding_model(model_name: str):
    """Build the embedding wrapper for ``model_name``.

    ``QWEN3_EMBEDDING_MODEL`` is wrapped in ``Qwen3Embedding``; every other
    name is wrapped in ``ModelWithPooling``. The ``lru_cache`` decorator keeps
    up to 8 constructed wrappers so repeated requests reuse them.
    """
    wrapper_cls = Qwen3Embedding if model_name == QWEN3_EMBEDDING_MODEL else ModelWithPooling
    return wrapper_cls(model_name)
def test_sbert():
    """Smoke test: ``SBert`` on 'bert-base-chinese' yields a 768-element vector."""
    encoder = SBert('bert-base-chinese')
    embedding = encoder('hello')
    shape = embedding.size()
    print(shape)
    assert shape == (768,)
def test_hf_model():
    """Smoke test: ``ModelWithPooling`` with 'cls' pooling yields a 768-element vector."""
    encoder = ModelWithPooling('IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese')
    embedding = encoder('hello', pooling='cls')
    shape = embedding.size()
    print(shape)
    assert shape == (768,)
|