| from types import SimpleNamespace |
| import unittest |
| from unittest.mock import patch |
|
|
| import torch |
|
|
| from utils import models |
|
|
|
|
class FakeTokenizer:
    """Tokenizer test double that logs every call and returns a fixed encoding."""

    def __init__(self):
        # One dict of received arguments per invocation, in call order.
        self.calls = []

    def __call__(
        self,
        text,
        padding,
        truncation,
        max_length,
        return_tensors,
    ):
        """Record the arguments, then return a canned single-row batch.

        The returned tensors do not depend on any argument: the ids are
        always ``[[101, 102, 0]]`` and the mask is always ``[[1, 1, 0]]``.
        """
        recorded = dict(
            text=text,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
        )
        self.calls.append(recorded)

        input_ids = torch.tensor([[101, 102, 0]])
        attention_mask = torch.tensor([[1, 1, 0]])
        return {'input_ids': input_ids, 'attention_mask': attention_mask}
|
|
|
|
class FakeModel:
    """Model test double that returns a fixed hidden-state tensor."""

    def __init__(self):
        # Flipped to True by eval() so tests can check it was invoked.
        self.eval_called = False
        # Device on which __call__ allocates its output; starts on CPU.
        self.device = torch.device('cpu')

    def to(self, device):
        """Remember *device* and return ``self`` so calls can be chained."""
        self.device = torch.device(device)
        return self

    def eval(self):
        """Note that eval() was called and return ``self``."""
        self.eval_called = True
        return self

    def __call__(self, **inputs):
        """Return an object exposing a constant ``last_hidden_state``.

        The keyword inputs are accepted but never read; the result is always
        the same 1 x 3 x 2 tensor, created on ``self.device``.
        """
        token_vectors = [[3.0, 0.0], [0.0, 4.0], [5.0, 12.0]]
        last_hidden_state = torch.tensor([token_vectors], device=self.device)
        return SimpleNamespace(last_hidden_state=last_hidden_state)
|
|
|
|
class Qwen3EmbeddingTest(unittest.TestCase):
    """Checks how ``utils.models`` registers and loads the Qwen3 embedding model."""

    def tearDown(self):
        # Drop any cached embedding model so fakes do not leak between tests.
        models.get_embedding_model.cache_clear()

    def test_qwen3_model_is_available_without_changing_default(self):
        """The Qwen3 model is listed, and the first list entry is unchanged."""
        default_model = models.list_models[0]
        self.assertEqual(
            default_model,
            'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
        )
        self.assertIn(models.QWEN3_EMBEDDING_MODEL, models.list_models)

    def test_qwen3_uses_official_pooling_shape_and_normalization(self):
        """Loading and calling the Qwen3 model yields a unit-length 1-D vector.

        ``pooling='cls'`` is passed on purpose: the expected result [0, 1] is
        FakeModel's hidden vector at index 1 ([0, 4]) scaled to unit length,
        not the first-token vector ([3, 0]).
        """
        fake_tokenizer = FakeTokenizer()
        fake_model = FakeModel()
        # Start from an empty cache so the patched loaders are really invoked.
        models.get_embedding_model.cache_clear()

        tokenizer_patch = patch.object(
            models.AutoTokenizer,
            'from_pretrained',
            return_value=fake_tokenizer,
        )
        model_patch = patch.object(
            models.AutoModel,
            'from_pretrained',
            return_value=fake_model,
        )
        with tokenizer_patch as tokenizer_loader, model_patch as model_loader:
            embed = models.get_embedding_model(models.QWEN3_EMBEDDING_MODEL)
            vector = embed('hello', pooling='cls')

        # Loader arguments.
        tokenizer_loader.assert_called_once_with(
            models.QWEN3_EMBEDDING_MODEL,
            padding_side='left',
        )
        model_loader.assert_called_once_with(models.QWEN3_EMBEDDING_MODEL)

        # Model and tokenizer usage.
        self.assertTrue(fake_model.eval_called)
        first_call = fake_tokenizer.calls[0]
        self.assertEqual(first_call['max_length'], 8192)

        # Output shape, value and norm.
        self.assertEqual(tuple(vector.shape), (2,))
        expected = torch.tensor([0.0, 1.0])
        torch.testing.assert_close(vector.cpu(), expected)
        norm = torch.linalg.vector_norm(vector).item()
        self.assertAlmostEqual(norm, 1.0)
|
|
|
|
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|
|