# Qifan Zhang
# Add Qwen3 embedding support with last-token pooling
# b91a6bd
from functools import cached_property, lru_cache

import torch
import torch.nn.functional as F
from loguru import logger
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Hugging Face hub id of the Qwen3 embedding model (served by Qwen3Embedding).
QWEN3_EMBEDDING_MODEL = 'Qwen/Qwen3-Embedding-0.6B'

# Older encoder checkpoints; get_embedding_model loads these via ModelWithPooling.
LEGACY_MODELS = [
    'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
    'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    'sentence-transformers/all-mpnet-base-v2',
    'sentence-transformers/all-MiniLM-L12-v2',
    'cyclone/simcse-chinese-roberta-wwm-ext',
    'bert-base-chinese',
    'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese',
]

# Every selectable model id: the legacy checkpoints followed by the Qwen3 model.
list_models = LEGACY_MODELS + [QWEN3_EMBEDDING_MODEL]
class SBert:
    """Wrapper around a ``SentenceTransformer`` model with per-instance memoisation.

    Calling an instance encodes its input with ``SentenceTransformer.encode``
    (``convert_to_tensor=True``). Results are memoised per instance, for at most
    ``cache_size`` distinct inputs. The input must therefore be hashable.
    """

    # Maximum number of distinct inputs memoised by each instance.
    cache_size = 10000

    def __init__(self, path):
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.model = SentenceTransformer(path, device=DEVICE)
        logger.info(f'Load {self.__class__} from {path} ...')

    @cached_property
    def _cached_encode(self):
        # The LRU cache is built on first use and stored on the instance.
        # A class-level ``@lru_cache`` on ``__call__`` would be shared by every
        # instance and would keep each ``self`` (and its model weights) alive
        # until evicted (ruff B019).
        return lru_cache(maxsize=self.cache_size)(self._encode)

    def _encode(self, x) -> torch.Tensor:
        """Encode ``x`` with the underlying model, without caching."""
        return self.model.encode(x, convert_to_tensor=True)

    def __call__(self, x) -> torch.Tensor:
        """Return the memoised embedding of ``x``.

        Repeated calls with an equal ``x`` return the same tensor object.
        Callers must not modify the result in place.

        Raises:
            TypeError: ``x`` is unhashable (raised by the LRU cache).
        """
        return self._cached_encode(x)
class ModelWithPooling:
    """Hugging Face ``AutoModel`` encoder that pools token states into one embedding.

    Supported ``pooling`` values for ``__call__``:

    * ``'cls'``: last-layer hidden state of the first token.
    * ``'pooler'``: the model's ``pooler_output``.
    * ``'mean'`` / ``'last-avg'``: last-layer token states averaged over
      non-padding positions.
    * ``'first-last-avg'``: the mean of the token averages of
      ``hidden_states[1]`` and ``hidden_states[-1]``.

    Results are memoised per instance, for at most ``cache_size`` distinct
    ``(text, pooling)`` pairs. ``text`` must therefore be hashable.
    """

    # Maximum number of distinct (text, pooling) calls memoised by each instance.
    cache_size = 100

    def __init__(self, path):
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModel.from_pretrained(path)
        logger.info(f'Load {self.__class__} from {path} ...')

    @cached_property
    def _cached_embed(self):
        # The LRU cache is built on first use and stored on the instance.
        # A class-level ``@lru_cache`` on the method would be shared by all
        # instances and would keep every ``self`` (and its weights) alive
        # (ruff B019).
        return lru_cache(maxsize=self.cache_size)(self._embed)

    @staticmethod
    def _masked_mean(hidden: torch.Tensor, attention_mask) -> torch.Tensor:
        """Average ``hidden`` ([b, s, h]) over the sequence axis, giving [b, h].

        Positions where ``attention_mask`` is 0 (padding) are excluded. With no
        mask, every position is averaged.
        """
        if attention_mask is None:
            return hidden.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # [b, s, 1]
        # The clamp guards against division by zero for an all-padding row.
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    @torch.no_grad()
    def _embed(self, text, pooling):
        """Run the encoder on ``text`` and pool its outputs, without caching."""
        inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
        outputs = self.model(**inputs, output_hidden_states=True)
        # ``padding=True`` pads batched inputs, so the means must skip pad tokens.
        mask = inputs.get('attention_mask')
        if pooling == 'cls':
            o = outputs.last_hidden_state[:, 0]  # [b, h]
        elif pooling == 'pooler':
            o = outputs.pooler_output  # [b, h]
        elif pooling in ('mean', 'last-avg'):
            o = self._masked_mean(outputs.last_hidden_state, mask)  # [b, h]
        elif pooling == 'first-last-avg':
            first_avg = self._masked_mean(outputs.hidden_states[1], mask)  # [b, h]
            last_avg = self._masked_mean(outputs.hidden_states[-1], mask)  # [b, h]
            o = (first_avg + last_avg) / 2  # [b, h]
        else:
            raise ValueError(f'Unknown pooling {pooling}')
        # A single text gives a batch of size 1; drop that axis to get [h].
        return o.squeeze(0)

    def __call__(self, text: str, pooling='mean'):
        """Return the pooled embedding of ``text``.

        Args:
            text: Input text. It must be hashable because results are memoised.
            pooling: One of ``'cls'``, ``'pooler'``, ``'mean'``, ``'last-avg'``
                or ``'first-last-avg'``.

        Returns:
            A tensor with the size-1 leading batch axis squeezed away.
            Memoised calls return the same tensor object; do not modify it
            in place.

        Raises:
            ValueError: ``pooling`` is not one of the supported values.
        """
        return self._cached_embed(text, pooling)
class Qwen3Embedding:
    """Qwen3 embedding model: last-token pooling followed by L2 normalisation.

    Calling an instance returns a unit-norm embedding of ``text``, computed on
    ``DEVICE``. Results are memoised per instance, for at most ``cache_size``
    distinct texts. ``text`` must therefore be hashable.
    """

    # Maximum number of distinct texts memoised by each instance.
    cache_size = 100

    def __init__(self, path):
        logger.info(f'Start loading {self.__class__} from {path} ...')
        # Left padding puts every sequence's final token in the last column,
        # which is the fast path in ``last_token_pool``.
        self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side='left')
        self.model = AutoModel.from_pretrained(path)
        self.model.to(DEVICE)
        self.model.eval()
        logger.info(f'Load {self.__class__} from {path} ...')

    @staticmethod
    def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Select the hidden state of each sequence's last non-padding token.

        Args:
            last_hidden_states: ``[b, s, h]`` final-layer hidden states.
            attention_mask: ``[b, s]`` mask, 1 for real tokens and 0 for padding.

        Returns:
            A ``[b, h]`` tensor.
        """
        # If every row ends with a real token, the batch is left-padded (or
        # unpadded), so the last column holds each sequence's final token.
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return last_hidden_states[:, -1]
        # Right padding: each row's last real token is at index (length - 1).
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[
            torch.arange(batch_size, device=last_hidden_states.device),
            sequence_lengths,
        ]

    @cached_property
    def _cached_embed(self):
        # The LRU cache is built on first use and stored on the instance.
        # A class-level ``@lru_cache`` on the method would be shared by all
        # instances and would keep every ``self`` (and its weights) alive
        # (ruff B019).
        return lru_cache(maxsize=self.cache_size)(self._embed)

    @torch.no_grad()
    def _embed(self, text):
        """Tokenise, encode, pool and L2-normalise ``text``, without caching."""
        inputs = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=8192,
            return_tensors='pt',
        )
        inputs = {key: value.to(DEVICE) for key, value in inputs.items()}
        outputs = self.model(**inputs)
        embeddings = self.last_token_pool(outputs.last_hidden_state, inputs['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)
        # A single text gives a batch of size 1; drop that axis to get [h].
        return embeddings.squeeze(0)

    def __call__(self, text: str, pooling='mean'):
        """Return the L2-normalised last-token embedding of ``text``.

        Args:
            text: Input text. It must be hashable because results are memoised.
            pooling: Ignored; this model always uses last-token pooling. The
                argument is accepted so the signature matches
                ``ModelWithPooling.__call__``, which ``get_embedding_model``
                may return instead. It is not part of the cache key.

        Returns:
            A unit-norm tensor on ``DEVICE``. Memoised calls return the same
            tensor object; do not modify it in place.
        """
        return self._cached_embed(text)
@lru_cache(maxsize=8)
def get_embedding_model(model_name: str):
    """Return a memoised embedding-model wrapper for ``model_name``.

    ``QWEN3_EMBEDDING_MODEL`` is served by ``Qwen3Embedding``. Any other name is
    loaded through ``ModelWithPooling``. At most 8 wrappers are kept cached.
    """
    wrapper_cls = Qwen3Embedding if model_name == QWEN3_EMBEDDING_MODEL else ModelWithPooling
    return wrapper_cls(model_name)
def test_sbert():
    """Smoke test: SBert on bert-base-chinese yields a single 768-dim vector."""
    encoder = SBert('bert-base-chinese')
    vec = encoder('hello')
    shape = vec.size()
    print(shape)
    assert shape == (768,)
def test_hf_model():
    """Smoke test: CLS pooling on Erlangshen-SimCSE yields a single 768-dim vector."""
    encoder = ModelWithPooling('IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese')
    vec = encoder('hello', pooling='cls')
    shape = vec.size()
    print(shape)
    assert shape == (768,)