from functools import cached_property, lru_cache

import torch
import torch.nn.functional as F
from loguru import logger
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer


# Device that SBert and Qwen3Embedding place their models on: the GPU when
# PyTorch can see one, otherwise the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'

# Hub id that get_embedding_model routes to Qwen3Embedding.
QWEN3_EMBEDDING_MODEL = 'Qwen/Qwen3-Embedding-0.6B'


# Older encoder checkpoints; get_embedding_model loads these via ModelWithPooling.
LEGACY_MODELS = [
    'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
    'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    'sentence-transformers/all-mpnet-base-v2',
    'sentence-transformers/all-MiniLM-L12-v2',
    'cyclone/simcse-chinese-roberta-wwm-ext',
    'bert-base-chinese',
    'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese',
]

# All selectable model names: the legacy encoders first, then Qwen3.
# Built as a new list so mutating it does not touch LEGACY_MODELS.
list_models = LEGACY_MODELS + [QWEN3_EMBEDDING_MODEL]


class SBert:
    """Memoised sentence embeddings from a ``sentence_transformers`` model.

    Calling an instance returns ``model.encode(x, convert_to_tensor=True)``.
    Results are cached per instance in an LRU cache of ``CACHE_SIZE`` entries.
    Repeated inputs therefore return the *same* tensor object, so callers
    must not modify the result in place.
    """

    # Maximum number of distinct inputs memoised per instance.
    CACHE_SIZE = 10000

    def __init__(self, path):
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.model = SentenceTransformer(path, device=DEVICE)
        logger.info(f'Load {self.__class__} from {path} ...')

    @cached_property
    def _cached_encode(self):
        # Built lazily and stored on the instance, so each instance owns its
        # cache. An lru_cache on the method itself would be class-level, keyed
        # on ``self``, and would keep every instance (and its model) alive.
        # The closure captures only the model, never ``self``.
        model = self.model

        @lru_cache(maxsize=self.CACHE_SIZE)
        def encode(x):
            return model.encode(x, convert_to_tensor=True)

        return encode

    def __call__(self, x) -> torch.Tensor:
        """Return the (cached) embedding of ``x``.

        Args:
            x: Input passed to ``SentenceTransformer.encode``. It must be
                hashable because it is used as the cache key.

        Returns:
            The tensor produced by ``encode(x, convert_to_tensor=True)``.
        """
        return self._cached_encode(x)


class ModelWithPooling:
    """Sentence embeddings from a plain Hugging Face encoder plus a pooling rule.

    Calling an instance returns the pooled embedding of a single input string
    as a 1-D tensor. Supported ``pooling`` values:

    * ``'cls'``: last-layer hidden state of the first token.
    * ``'pooler'``: the model's ``pooler_output``.
    * ``'mean'`` / ``'last-avg'``: attention-mask-weighted mean of the last
      layer, so padding tokens are excluded.
    * ``'first-last-avg'``: average of the masked means of
      ``hidden_states[1]`` (first transformer layer output) and
      ``hidden_states[-1]``.

    Results are cached per instance (``CACHE_SIZE`` entries, keyed on
    ``(text, pooling)``). Repeated calls return the same tensor object, so
    callers must not modify it in place.

    NOTE(review): the model is not moved to ``DEVICE`` here (unlike SBert and
    Qwen3Embedding). It stays wherever ``from_pretrained`` put it, which is
    presumably the CPU.
    """

    POOLINGS = ('cls', 'pooler', 'mean', 'last-avg', 'first-last-avg')
    # Maximum number of distinct (text, pooling) pairs memoised per instance.
    CACHE_SIZE = 100

    def __init__(self, path):
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModel.from_pretrained(path)
        logger.info(f'Load {self.__class__} from {path} ...')

    @staticmethod
    def _masked_mean(hidden, mask):
        """Average ``hidden`` over the token axis, counting only unmasked tokens."""
        # hidden: (batch, seq, dim); mask: (batch, seq) of 0/1.
        weights = mask.unsqueeze(-1).to(hidden.dtype)
        return (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)

    @staticmethod
    def _embed(tokenizer, model, text, pooling):
        """Tokenize ``text``, run ``model`` and pool the output (uncached)."""
        inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
        # Intermediate layers are only needed for first-last-avg.
        outputs = model(**inputs, output_hidden_states=pooling == 'first-last-avg')
        last = outputs.last_hidden_state
        mask = inputs.get('attention_mask')
        if mask is None:
            mask = torch.ones(last.shape[:2], dtype=torch.long, device=last.device)

        if pooling == 'cls':
            o = last[:, 0]
        elif pooling == 'pooler':
            o = getattr(outputs, 'pooler_output', None)
            if o is None:
                raise ValueError(f'Model has no pooler output; cannot use pooling {pooling}')
        elif pooling in ('mean', 'last-avg'):
            o = ModelWithPooling._masked_mean(last, mask)
        elif pooling == 'first-last-avg':
            first_avg = ModelWithPooling._masked_mean(outputs.hidden_states[1], mask)
            last_avg = ModelWithPooling._masked_mean(outputs.hidden_states[-1], mask)
            o = (first_avg + last_avg) / 2
        else:
            raise ValueError(f'Unknown pooling {pooling}')

        # Drop the batch axis for single-string input.
        return o.squeeze(0)

    @cached_property
    def _cached_embed(self):
        # Per-instance LRU cache built on first use. The closure captures the
        # tokenizer and model, not ``self``, so the cache never keeps the
        # instance alive (unlike lru_cache applied to the method itself).
        tokenizer, model, run = self.tokenizer, self.model, self._embed

        @lru_cache(maxsize=self.CACHE_SIZE)
        @torch.no_grad()
        def embed(text, pooling):
            return run(tokenizer, model, text, pooling)

        return embed

    def __call__(self, text: str, pooling='mean'):
        """Return the pooled embedding of ``text``.

        Args:
            text: A single input string; it is used as part of the cache key.
            pooling: One of ``POOLINGS``.

        Returns:
            A tensor with the batch axis squeezed away.

        Raises:
            ValueError: If ``pooling`` is unknown (checked before any model
                work), or if ``'pooler'`` is requested for a model without a
                pooler output.
        """
        if pooling not in self.POOLINGS:
            raise ValueError(f'Unknown pooling {pooling}')
        return self._cached_embed(text, pooling)


class Qwen3Embedding:
    """L2-normalised last-token embeddings from a Qwen3 embedding model.

    Calling an instance on a single string returns the normalised hidden
    state of its last non-padding token as a 1-D tensor. Results are cached
    per instance (``CACHE_SIZE`` entries), so repeated inputs return the same
    tensor object and callers must not modify it in place.
    """

    # Maximum number of distinct inputs memoised per instance.
    CACHE_SIZE = 100
    # Token limit passed to the tokenizer; longer inputs are truncated.
    MAX_LENGTH = 8192

    def __init__(self, path):
        logger.info(f'Start loading {self.__class__} from {path} ...')
        # Left padding puts each sequence's final token in the last column,
        # which lets last_token_pool take its fast path.
        self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side='left')
        self.model = AutoModel.from_pretrained(path)
        self.model.to(DEVICE)
        self.model.eval()
        logger.info(f'Load {self.__class__} from {path} ...')

    @staticmethod
    def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Select each row's hidden state at its last non-padding token.

        Args:
            last_hidden_states: ``(batch, seq, dim)`` final-layer states.
            attention_mask: ``(batch, seq)`` mask of 1 (token) / 0 (padding).

        Returns:
            A ``(batch, dim)`` tensor.
        """
        # If every row ends with a real token the batch is left-padded (or
        # unpadded), so the last column is the answer for every row.
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return last_hidden_states[:, -1]

        # Right padding: index the last real token of each row individually.
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[
            torch.arange(batch_size, device=last_hidden_states.device),
            sequence_lengths,
        ]

    @cached_property
    def _cached_embed(self):
        # Per-instance LRU cache built on first use. The closure captures the
        # tokenizer and model, not ``self``, so the cache never keeps the
        # instance alive (unlike lru_cache applied to the method itself).
        tokenizer, model = self.tokenizer, self.model
        pool, max_length = self.last_token_pool, self.MAX_LENGTH

        @lru_cache(maxsize=self.CACHE_SIZE)
        @torch.no_grad()
        def embed(text):
            inputs = tokenizer(
                text,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors='pt',
            )
            # Follow the model's actual placement rather than assuming DEVICE.
            inputs = {key: value.to(model.device) for key, value in inputs.items()}
            outputs = model(**inputs)
            embeddings = pool(outputs.last_hidden_state, inputs['attention_mask'])
            embeddings = F.normalize(embeddings, p=2, dim=1)
            return embeddings.squeeze(0)

        return embed

    def __call__(self, text: str, pooling='mean'):
        """Return the normalised embedding of ``text``.

        Args:
            text: A single input string; it is used as the cache key.
            pooling: Ignored. Qwen3 always uses last-token pooling. The
                parameter presumably exists so this class can be called like
                ``ModelWithPooling``.

        Returns:
            A 1-D unit-norm tensor on the model's device.
        """
        return self._cached_embed(text)


@lru_cache(maxsize=8)
def get_embedding_model(model_name: str):
    """Load (or reuse) the embedding wrapper for ``model_name``.

    The Qwen3 embedding model gets ``Qwen3Embedding``; every other name is
    loaded through ``ModelWithPooling``. Up to eight loaded wrappers are kept
    in an LRU cache keyed by name.
    """
    is_qwen3 = model_name == QWEN3_EMBEDDING_MODEL
    wrapper_cls = Qwen3Embedding if is_qwen3 else ModelWithPooling
    return wrapper_cls(model_name)


def test_sbert():
    """Smoke test: SBert on bert-base-chinese yields a 768-d vector."""
    encoder = SBert('bert-base-chinese')
    vector = encoder('hello')
    shape = vector.size()
    print(shape)
    assert shape == (768,)


def test_hf_model():
    """Smoke test: CLS pooling on Erlangshen-SimCSE yields a 768-d vector."""
    encoder = ModelWithPooling('IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese')
    vector = encoder('hello', pooling='cls')
    shape = vector.size()
    print(shape)
    assert shape == (768,)