File size: 4,735 Bytes
4ccf6a3
0e97d35
 
b91a6bd
e691ea0
cf575f8
dd2409d
3f6f474
cf575f8
b91a6bd
cf575f8
b91a6bd
d654474
4ccf6a3
 
 
dd2409d
 
 
d654474
 
b91a6bd
 
cf575f8
 
 
f54f0db
cf575f8
e691ea0
cf575f8
 
0e97d35
 
cf575f8
dd2409d
 
 
 
f54f0db
dd2409d
 
e691ea0
dd2409d
e691ea0
dd2409d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b91a6bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd2409d
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from functools import lru_cache

import torch
import torch.nn.functional as F
from loguru import logger
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel

# Device string handed to the model loaders below: CUDA when torch reports it
# as available, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Model name that get_embedding_model() routes to the Qwen3Embedding wrapper.
QWEN3_EMBEDDING_MODEL = 'Qwen/Qwen3-Embedding-0.6B'

# Model names other than the Qwen3 embedding checkpoint.
LEGACY_MODELS = [
    'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
    'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    'sentence-transformers/all-mpnet-base-v2',
    'sentence-transformers/all-MiniLM-L12-v2',
    'cyclone/simcse-chinese-roberta-wwm-ext',
    'bert-base-chinese',
    'IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese',
]

# Full list of model names: the legacy entries followed by the Qwen3 checkpoint.
# Built as a new list so LEGACY_MODELS itself is left untouched.
list_models = LEGACY_MODELS + [QWEN3_EMBEDDING_MODEL]


class SBert:
    """Callable wrapper around a ``SentenceTransformer`` encoder."""

    def __init__(self, path):
        """Load the SentenceTransformer at ``path`` onto ``DEVICE``.

        A log line is emitted before and after the load.
        """
        logger.info(f'Start loading {self.__class__} from {path} ...')
        encoder = SentenceTransformer(path, device=DEVICE)
        self.model = encoder
        logger.info(f'Load {self.__class__} from {path} ...')

    # The cache decorates the plain function in the class body, so one cache
    # serves every instance and ``self`` is part of each key; ``x`` therefore
    # has to be hashable.
    @lru_cache(maxsize=10000)
    def __call__(self, x) -> torch.Tensor:
        """Encode ``x`` with the wrapped model and return the result as a tensor.

        Results are memoised (up to 10000 entries), so a repeated call with
        the same ``x`` returns the cached object instead of re-encoding.
        """
        return self.model.encode(x, convert_to_tensor=True)


class ModelWithPooling:
    """Hugging Face encoder wrapper that pools token states into one vector.

    The tokenizer and model are loaded with ``AutoTokenizer`` / ``AutoModel``
    from ``path``. Calling the instance tokenizes the text, runs the model
    without gradient tracking and reduces the hidden states according to the
    requested pooling strategy.
    """

    # Names accepted for the ``pooling`` argument of ``__call__``.
    _POOLINGS = ('cls', 'pooler', 'mean', 'last-avg', 'first-last-avg')

    def __init__(self, path):
        """Load tokenizer and model from ``path``, logging before and after."""
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # NOTE: unlike the other wrappers in this file, the model is not moved
        # to DEVICE here; it stays wherever from_pretrained() places it.
        self.model = AutoModel.from_pretrained(path)
        logger.info(f'Load {self.__class__} from {path} ...')

    @staticmethod
    def _masked_mean(hidden, mask):
        """Average ``hidden`` ([b, s, h]) over the sequence axis.

        When ``mask`` ([b, s], 1 = real token, 0 = padding) is given, padded
        positions are excluded from the average; with ``mask=None`` every
        position counts.
        """
        if mask is None:
            return hidden.mean(dim=1)
        weights = mask.unsqueeze(-1).to(hidden.dtype)  # [b, s, 1]
        # clamp avoids a division by zero for a row that is entirely padding
        token_count = weights.sum(dim=1).clamp(min=1)  # [b, 1]
        return (hidden * weights).sum(dim=1) / token_count  # [b, h]

    # The cache is attached to the class-level function, so it is shared by all
    # instances and keyed on (self, text, pooling); arguments must be hashable.
    @lru_cache(maxsize=100)
    @torch.no_grad()
    def __call__(self, text: str, pooling='mean'):
        """Embed ``text`` and return the pooled representation.

        Args:
            text: Input passed to the tokenizer.
            pooling: One of
                ``'cls'`` - last-layer state of the first token;
                ``'pooler'`` - the model's ``pooler_output``;
                ``'mean'`` / ``'last-avg'`` - mean of the last-layer states;
                ``'first-last-avg'`` - average of the mean of
                ``hidden_states[1]`` and the mean of ``hidden_states[-1]``.
                Mean-style poolings skip positions whose attention mask is 0.

        Returns:
            The pooled tensor with a leading batch dimension of size 1
            squeezed away.

        Raises:
            ValueError: If ``pooling`` is not one of the names above. The
                check happens before the tokenizer and model are run.
        """
        # Fail fast: do not pay for a forward pass on an invalid request.
        if pooling not in self._POOLINGS:
            raise ValueError(f'Unknown pooling {pooling}')

        inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
        outputs = self.model(**inputs, output_hidden_states=True)
        mask = inputs['attention_mask'] if 'attention_mask' in inputs else None

        if pooling == 'cls':
            o = outputs.last_hidden_state[:, 0]  # [b, h]

        elif pooling == 'pooler':
            o = outputs.pooler_output  # [b, h]

        elif pooling == 'first-last-avg':
            first_avg = self._masked_mean(outputs.hidden_states[1], mask)  # [b, h]
            last_avg = self._masked_mean(outputs.hidden_states[-1], mask)  # [b, h]
            o = (first_avg + last_avg) / 2  # [b, h]

        else:  # 'mean' / 'last-avg'
            o = self._masked_mean(outputs.last_hidden_state, mask)  # [b, h]

        return o.squeeze(0)


class Qwen3Embedding:
    """Embedding wrapper that uses the last real token's hidden state.

    The tokenizer is created with left padding; the model is moved to
    ``DEVICE`` and switched to eval mode at construction time.
    """

    def __init__(self, path):
        """Load tokenizer and model from ``path``, logging before and after."""
        logger.info(f'Start loading {self.__class__} from {path} ...')
        self.tokenizer = AutoTokenizer.from_pretrained(path, padding_side='left')
        self.model = AutoModel.from_pretrained(path)
        self.model.to(DEVICE)
        self.model.eval()
        logger.info(f'Load {self.__class__} from {path} ...')

    @staticmethod
    def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Pick, for every row, the hidden state of its last attended token.

        If the final column of ``attention_mask`` sums to the batch size (every
        row ends in a real token, as with left padding), the last position of
        each row is returned. Otherwise each row is indexed at
        ``attention_mask.sum(dim=1) - 1``.
        """
        rows_in_mask = attention_mask.shape[0]
        if attention_mask[:, -1].sum() == rows_in_mask:
            return last_hidden_states[:, -1]

        last_positions = attention_mask.sum(dim=1) - 1
        row_index = torch.arange(
            last_hidden_states.shape[0],
            device=last_hidden_states.device,
        )
        return last_hidden_states[row_index, last_positions]

    # Cache lives on the class-level function: keys are (self, text, pooling).
    @lru_cache(maxsize=100)
    @torch.no_grad()
    def __call__(self, text: str, pooling='mean'):
        """Return the L2-normalised embedding of ``text``.

        ``pooling`` is accepted but not read by this method; the embedding is
        always taken with ``last_token_pool``. Input is truncated to at most
        8192 tokens. A leading batch dimension of size 1 is squeezed away.
        """
        encoded = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=8192,
            return_tensors='pt',
        )
        # Move every tokenizer output onto the same device as the model.
        batch = {}
        for name, tensor in encoded.items():
            batch[name] = tensor.to(DEVICE)

        hidden = self.model(**batch).last_hidden_state
        pooled = self.last_token_pool(hidden, batch['attention_mask'])
        unit = F.normalize(pooled, p=2, dim=1)
        return unit.squeeze(0)


@lru_cache(maxsize=8)
def get_embedding_model(model_name: str):
    """Build the embedding wrapper for ``model_name``.

    ``QWEN3_EMBEDDING_MODEL`` gets a ``Qwen3Embedding``; every other name gets
    a ``ModelWithPooling``. The result is memoised for up to 8 distinct names,
    so a repeated request returns the already-built wrapper.
    """
    wrapper_cls = Qwen3Embedding if model_name == QWEN3_EMBEDDING_MODEL else ModelWithPooling
    return wrapper_cls(model_name)


def test_sbert():
    """Smoke test: SBert on 'bert-base-chinese' yields a 768-element vector."""
    encoder = SBert('bert-base-chinese')
    embedding = encoder('hello')
    shape = embedding.size()
    print(shape)
    assert shape == (768,)


def test_hf_model():
    """Smoke test: 'cls' pooling on the Erlangshen SimCSE model yields 768 values."""
    encoder = ModelWithPooling('IDEA-CCNL/Erlangshen-SimCSE-110M-Chinese')
    embedding = encoder('hello', pooling='cls')
    shape = embedding.size()
    print(shape)
    assert shape == (768,)