Qifan Zhang
Add Qwen3 embedding support with last-token pooling
b91a6bd
from types import SimpleNamespace
import unittest
from unittest.mock import patch
import torch
from utils import models
class FakeTokenizer:
    """Stand-in for a Hugging Face tokenizer that records every invocation.

    Each call appends a dict of the arguments it received to ``calls`` and
    returns the same fixed batch: a single sequence of three positions whose
    last position is masked out by ``attention_mask``.
    """

    def __init__(self):
        # One entry per __call__, in call order, so tests can inspect the
        # exact tokenization arguments chosen by the code under test.
        self.calls = []

    def __call__(
        self,
        text,
        padding,
        truncation,
        max_length,
        return_tensors,
    ):
        recorded = dict(
            text=text,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
        )
        self.calls.append(recorded)

        # The trailing 0 in the mask marks the final position as padding
        # (right-padded, regardless of any padding_side requested), presumably
        # so a test can tell "last valid token" apart from "index -1".
        input_ids = torch.tensor([[101, 102, 0]])
        attention_mask = torch.tensor([[1, 1, 0]])
        return {'input_ids': input_ids, 'attention_mask': attention_mask}
class FakeModel:
    """Minimal stand-in for a transformers ``AutoModel`` instance.

    Tracks the device it was moved to and whether ``eval()`` was invoked, and
    returns a fixed ``last_hidden_state`` no matter which inputs it receives.
    """

    def __init__(self):
        self.device = torch.device('cpu')
        # Flipped by eval() so tests can assert inference mode was requested.
        self.eval_called = False

    def to(self, device):
        # torch.device accepts either a string or a torch.device; returning
        # self lets callers chain .to(...).eval() like a real module.
        self.device = torch.device(device)
        return self

    def eval(self):
        self.eval_called = True
        return self

    def __call__(self, **inputs):
        # Inputs are intentionally ignored. The output is one sequence of
        # three 2-d hidden vectors, shape (1, 3, 2), allocated on the device
        # the model was last moved to.
        rows = [[3.0, 0.0], [0.0, 4.0], [5.0, 12.0]]
        last_hidden_state = torch.tensor([rows], device=self.device)
        return SimpleNamespace(last_hidden_state=last_hidden_state)
class Qwen3EmbeddingTest(unittest.TestCase):
    """Tests for Qwen3 embedding support in ``utils.models``."""

    def tearDown(self):
        # get_embedding_model exposes cache_clear(), i.e. it is cached; clear
        # it so a model built under patched loaders cannot leak into other
        # tests.
        models.get_embedding_model.cache_clear()

    def test_qwen3_model_is_available_without_changing_default(self):
        """Qwen3 is listed, while the first entry of list_models is unchanged."""
        available = models.list_models
        self.assertEqual(
            available[0],
            'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
        )
        self.assertIn(models.QWEN3_EMBEDDING_MODEL, available)

    def test_qwen3_uses_official_pooling_shape_and_normalization(self):
        """Check loader arguments, max_length, pooling choice and L2 normalization."""
        model_name = models.QWEN3_EMBEDDING_MODEL
        fake_tokenizer = FakeTokenizer()
        fake_model = FakeModel()
        # Start from an empty cache so the patched loaders below are the ones used.
        models.get_embedding_model.cache_clear()

        with patch.object(
            models.AutoTokenizer,
            'from_pretrained',
            return_value=fake_tokenizer,
        ) as load_tokenizer:
            with patch.object(
                models.AutoModel,
                'from_pretrained',
                return_value=fake_model,
            ) as load_model:
                # pooling='cls' is requested on purpose. With the fake outputs,
                # position 0 would normalize to [1, 0] and index -1 to
                # [5/13, 12/13]. The expectation below is the normalized last
                # *unmasked* vector ([0, 4] -> [0, 1]).
                embedding = models.get_embedding_model(model_name)(
                    'hello', pooling='cls'
                )

        load_tokenizer.assert_called_once_with(model_name, padding_side='left')
        load_model.assert_called_once_with(model_name)
        self.assertTrue(fake_model.eval_called)

        first_call = fake_tokenizer.calls[0]
        self.assertEqual(first_call['max_length'], 8192)

        self.assertEqual(tuple(embedding.shape), (2,))
        expected = torch.tensor([0.0, 1.0])
        torch.testing.assert_close(embedding.cpu(), expected)
        norm = torch.linalg.vector_norm(embedding).item()
        self.assertAlmostEqual(norm, 1.0)
if __name__ == '__main__':
    # Allow running this test module directly as a script, in addition to
    # being collected by a test runner.
    unittest.main()