File size: 2,802 Bytes
b91a6bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from types import SimpleNamespace
import unittest
from unittest.mock import patch

import torch

from utils import models


class FakeTokenizer:
    """Tokenizer stand-in that logs every call and returns a fixed encoding."""

    def __init__(self):
        # One dict per invocation, appended in call order.
        self.calls = []

    def __call__(self, text, padding, truncation, max_length, return_tensors):
        """Record the arguments and return a canned single-row batch.

        Returns:
            A dict holding ``input_ids`` and ``attention_mask`` tensors of
            shape (1, 3); the mask marks the last position as 0.
        """
        keys = ('text', 'padding', 'truncation', 'max_length', 'return_tensors')
        values = (text, padding, truncation, max_length, return_tensors)
        self.calls.append(dict(zip(keys, values)))

        input_ids = torch.tensor([[101, 102, 0]])
        attention_mask = torch.tensor([[1, 1, 0]])
        return {'input_ids': input_ids, 'attention_mask': attention_mask}


class FakeModel:
    """Model stand-in that returns fixed hidden states on its current device."""

    def __init__(self):
        self.eval_called = False
        self.device = torch.device('cpu')

    def to(self, device):
        """Store ``device`` as a ``torch.device`` and return this instance."""
        target = torch.device(device)
        self.device = target
        return self

    def eval(self):
        """Note that ``eval()`` was invoked and return this instance."""
        self.eval_called = True
        return self

    def __call__(self, **inputs):
        """Return a namespace with ``last_hidden_state``; ``inputs`` is unused.

        The hidden state is a constant (1, 3, 2) tensor created on
        ``self.device``.
        """
        token_rows = [[3.0, 0.0], [0.0, 4.0], [5.0, 12.0]]
        output = SimpleNamespace()
        output.last_hidden_state = torch.tensor([token_rows], device=self.device)
        return output


class Qwen3EmbeddingTest(unittest.TestCase):
    """Tests for the Qwen3 embedding entry exposed by ``utils.models``."""

    def tearDown(self):
        # Reset the loader cache after every test so results do not leak
        # from one test into the next.
        models.get_embedding_model.cache_clear()

    def test_qwen3_model_is_available_without_changing_default(self):
        # The first list entry must remain the multilingual MPNet model,
        # while the Qwen3 model only needs to be present somewhere.
        default_model = models.list_models[0]
        self.assertEqual(
            default_model,
            'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
        )
        self.assertIn(models.QWEN3_EMBEDDING_MODEL, models.list_models)

    def test_qwen3_uses_official_pooling_shape_and_normalization(self):
        fake_tokenizer = FakeTokenizer()
        fake_model = FakeModel()
        # Presumably guards against a model cached by an earlier call being
        # reused instead of the patched loaders below.
        models.get_embedding_model.cache_clear()

        tokenizer_patch = patch.object(
            models.AutoTokenizer, 'from_pretrained', return_value=fake_tokenizer
        )
        model_patch = patch.object(
            models.AutoModel, 'from_pretrained', return_value=fake_model
        )
        with tokenizer_patch as load_tokenizer, model_patch as load_model:
            embed = models.get_embedding_model(models.QWEN3_EMBEDDING_MODEL)
            vector = embed('hello', pooling='cls')

        # Loading contract: left-padded tokenizer, model loaded by name only,
        # and eval() invoked on the loaded model.
        load_tokenizer.assert_called_once_with(
            models.QWEN3_EMBEDDING_MODEL,
            padding_side='left',
        )
        load_model.assert_called_once_with(models.QWEN3_EMBEDDING_MODEL)
        self.assertTrue(fake_model.eval_called)

        first_call = fake_tokenizer.calls[0]
        self.assertEqual(first_call['max_length'], 8192)

        # [0, 1] is the unit-length form of the fake hidden state [0, 4],
        # i.e. the last position the fake attention mask marks as valid.
        # pooling='cls' is passed, yet the first token ([3, 0]) would
        # normalize to [1, 0], so the requested pooling is expected to be
        # overridden for this model.
        self.assertEqual(tuple(vector.shape), (2,))
        torch.testing.assert_close(vector.cpu(), torch.tensor([0.0, 1.0]))
        norm = torch.linalg.vector_norm(vector).item()
        self.assertAlmostEqual(norm, 1.0)


# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()