mkj69 committed on
Commit
d7f0c84
·
verified ·
1 Parent(s): d6d1177

Upload op_tokenizer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. op_tokenizer.py +268 -0
op_tokenizer.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import List, Optional, Dict
3
+ from transformers import PreTrainedTokenizer
4
+ import os
5
+ import json
6
+ import re
7
# Default tokenizer configuration: digit alphabet, operand variable names,
# auxiliary symbols, operator-symbol inventory and definition keywords.
default_config = {
    # Characters accepted as digits (used to build the number pattern when
    # tokenizing equations).
    "custom_digits": "0123456789ABCDEF",
    "variable_atoms": {
        "left_operand": "a",  # name of the left-operand variable
        "right_operand": "b"  # name of the right-operand variable
    },

    "other_symbols_atoms": {
        "left_parenthesis": "(",  # left parenthesis
        "right_parenthesis": ")",  # right parenthesis
        "equals_sign": "=",  # equals sign, commonly used for assignment or comparison
        "nan_symbol": "NaN",  # Not a Number
        "inf_symbol": "Inf"  # Infinity
    },

    # NOTE(review): presumably the allowed length range (in characters) of a
    # custom operator symbol; nothing in this file reads these two values.
    "operator_symbol_min_len": 1,
    "operator_symbol_max_len": 3,

    "basic_operator_symbols": ["+", "-", "*", "/", "%"],

    # Custom operator symbols; each one is matched as a unit in front of a
    # number when an equation is split.
    "base_symbols": [
        "≮⫘↔",
        "⫏≰",
        "⪩⨒∯",
        "⇑⪆",
        "↹⩛",
        "≴∭⊉",
        "⪪⊹⋣",
        "⋋%⋟",
        "⊺⇮",
        "⋰*⋻",
        "⫖↰⪸",
        "⪎⋱⫍",
        "⨗⨭⨅",
        "⫶⩼⫲",
        "∃⊬"
    ],

    "comparison_ops": ["==", ">", "<", ">=", "<=", "!="],

    "logical_connectors": ["and", "or"],

    "definition_symbols": [
        ",",
        ";",
        "if",
        "else",
        "{",
        "}",
        "abs"
    ]
}
59
+
60
class OpTokenizer(PreTrainedTokenizer):
    """Regex-driven tokenizer for operator definitions and equations.

    Input text is cut into single digits (``0-9A-F``), the variables ``a`` and
    ``b``, runs of operator characters, the keywords ``if``/``else``/``abs``/
    ``and``/``or``, comparison operators, ``NaN``/``Inf`` and punctuation.
    Whitespace is dropped, and characters that match none of the patterns are
    skipped silently.

    ``tokenize`` supports three modes:

    * ``'definition'`` - the text is an operator definition, wrapped in
      ``[DEF_START]`` / ``[DEF_END]`` when special tokens are requested.
    * ``'text'`` - the text is an equation, wrapped in ``[EQ_START]`` /
      ``[EQ_END]`` when special tokens are requested.
    * ``'withdef_text'`` - the text already contains those markers and each
      marked segment is tokenized with the matching mode.
    """

    def __init__(self, vocab_file, **kwargs):
        """Load the vocabulary and prepare the token regex.

        Args:
            vocab_file: Path to a UTF-8 JSON file; it is expected to map token
                strings to integer ids and must contain the special tokens
                ``[PAD]``, ``[UNK]``, ``[SEP]``, ``[MASK]``, ``[BOS]``,
                ``[EOS]`` and ``[EOD]``.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.

        Raises:
            KeyError: If one of the required special tokens is not in the
                vocabulary.
        """
        # NOTE: this is a reference to the module-level dict, not a copy, so
        # mutating it affects every tokenizer instance.
        self.param_config = default_config
        # The vocabulary is assigned before the base initializer runs, exactly
        # as in the original ordering (presumably because the base class may
        # query the vocabulary while initializing - TODO confirm).
        self.vocab = self.load_vocab(vocab_file)
        self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
        super().__init__(**kwargs)

        # Basic symbol inventories.
        self.basic_symbols = list("0123456789()=ABCDEFab")
        self.special_results = ['NaN', 'Inf']
        self.comparison_ops = ["==", ">", "<", ">=", "<=", "!="]
        self.logical_connectors = ["and", "or"]
        self.definition_symbols = [",", ";", "if", "else", "{", "}", "abs"]

        self.token_regex = self.build_token_regex()

        # Ids of the special tokens.
        self.pad_id = self.vocab['[PAD]']
        self.unk_id = self.vocab['[UNK]']
        self.sep_id = self.vocab['[SEP]']
        self.mask_id = self.vocab['[MASK]']
        self.bos_id = self.vocab['[BOS]']
        self.eos_id = self.vocab['[EOS]']
        self.eod_id = self.vocab['[EOD]']

    def load_vocab(self, vocab_file):
        """Read the vocabulary from a UTF-8 JSON file and return it."""
        with open(vocab_file, encoding="utf-8") as f:
            return json.load(f)

    def save_vocabulary(self, save_directory, filename_prefix=""):
        """Write the vocabulary to ``<save_directory>/<prefix>vocab.json``.

        Args:
            save_directory: Target directory; created if it does not exist.
            filename_prefix: Prepended verbatim (no separator) to
                ``vocab.json``. ``None`` is treated as an empty prefix.

        Returns:
            A one-element tuple holding the path of the written file.
        """
        if filename_prefix is None:
            filename_prefix = ""

        # exist_ok avoids the check-then-create race of the exists()/makedirs()
        # pair.
        os.makedirs(save_directory, exist_ok=True)

        vocab_file_path = os.path.join(save_directory, filename_prefix + "vocab.json")
        with open(vocab_file_path, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f, ensure_ascii=False, indent=4)

        print(f"Vocabulary saved to {vocab_file_path}")

        # A tuple, not a list.
        return (vocab_file_path,)

    @staticmethod
    def _alternation(literals):
        """Escape literals and join them with ``|``, longest literal first.

        ``re`` takes the first alternative that matches, so a short literal
        listed before a longer one sharing its prefix (``>`` before ``>=``)
        would shadow the longer one.
        """
        ordered = sorted(literals, key=len, reverse=True)
        return "|".join(re.escape(item) for item in ordered)

    def build_token_regex(self):
        """Build the compiled regex that recognises one token per match.

        Every token class is a named group, so ``match.lastgroup`` tells the
        class of a match. Alternatives are tried in the order listed below;
        multi-character keywords come before the single-character digit and
        variable classes so that e.g. ``abs`` is not read as the variable
        ``a``.

        Returns:
            The compiled pattern.
        """
        operator_pattern = r"(?P<OPERATOR>([+\-*/%]|[\u2200-\u22FF\u2A00-\u2BFF\u2190-\u21FF])+)"
        variable_pattern = r"(?P<VARIABLE>[a-b])"
        digit_pattern = r"(?P<DIGIT>[0-9A-F])"
        # Special results such as NaN and Inf.
        special_result_pattern = (
            r"(?P<SPECIAL_RESULT>" + self._alternation(self.special_results) + ")"
        )
        # Comparison operators; longest first so ">=" / "<=" are not split into
        # ">" / "<" followed by "=".
        comparison_ops_pattern = (
            r"(?P<COMPARISON_OP>" + self._alternation(self.comparison_ops) + ")"
        )
        # Logical connectors.
        logical_connectors_pattern = (
            r"(?P<LOGICAL_CONNECTOR>" + self._alternation(self.logical_connectors) + ")"
        )
        if_else_pattern = r"(?P<IF_ELSE>if|else)"
        whitespace_pattern = r"(?P<WHITESPACE>\s+)"
        abs_pattern = r"(?P<ABS>abs)"
        punctuation_patterns = [
            r"(?P<PARENTHESIS_LEFT>\()",
            r"(?P<PARENTHESIS_RIGHT>\))",
            r"(?P<CURLY_BRACE_LEFT>{)",
            r"(?P<CURLY_BRACE_RIGHT>})",
            r"(?P<SEMICOLON>;)",
            r"(?P<COMMA>,)",
            r"(?P<EQUAL>=)"
        ]

        # Order matters: longer / more specific patterns go first.
        token_patterns = [
            operator_pattern,
            special_result_pattern,      # NaN, Inf
            comparison_ops_pattern,      # comparison operators
            logical_connectors_pattern,  # and / or
            if_else_pattern,             # if / else
            abs_pattern,
            digit_pattern,
            variable_pattern,            # variable names
            whitespace_pattern,          # spaces and newlines
        ] + punctuation_patterns

        return re.compile("|".join(token_patterns))

    def _regex_tokens(self, text):
        """Return every non-whitespace token ``self.token_regex`` finds in ``text``."""
        return [
            match.group(match.lastgroup)
            for match in self.token_regex.finditer(text)
            if match.lastgroup != "WHITESPACE"
        ]

    def tokenize(self, text: str, mode: str = 'text', add_special_tokens: bool = True):
        """Split ``text`` into a list of token strings.

        Args:
            text: The text to tokenize.
            mode: ``'definition'``, ``'text'`` or ``'withdef_text'``.
            add_special_tokens: Whether the start/end markers are emitted.

        Returns:
            The list of token strings.

        Raises:
            ValueError: If ``mode`` is not one of the supported values.
        """
        if mode == 'definition':
            return self._tokenize_definition(text, add_special_tokens)
        if mode == 'text':
            return self._tokenize_equation(text, add_special_tokens)
        if mode == 'withdef_text':
            return self._tokenize_withdef_text(text, add_special_tokens)
        # The instance has no ``mode`` attribute; report the offending argument.
        raise ValueError(f"Unsupported mode: {mode}")

    def _tokenize_definition(self, text, add_special_tokens):
        """Tokenize an operator definition, optionally wrapped in DEF markers."""
        tokens = self._regex_tokens(text)
        if add_special_tokens:
            return ['[DEF_START]'] + tokens + ['[DEF_END]']
        return tokens

    def _tokenize_equation(self, text, add_special_tokens):
        """Tokenize an equation, optionally wrapped in EQ markers.

        The text is first split at every "custom operator symbol followed by a
        number" occurrence and then at every (optionally negative) number, so
        the sign of a number is tokenized apart from the operator symbol that
        precedes it instead of being absorbed into one operator run.
        """
        tokens = ['[EQ_START]'] if add_special_tokens else []

        config = self.param_config
        # The pattern strings are rebuilt from the configuration on every call
        # and stored on the instance, as before.
        self.digit_pattern = f"[{re.escape(config['custom_digits'])}]"
        self.number_pattern = f"[-]?{self.digit_pattern}+"
        self.base_symbols_pattern = f"(?:{'|'.join(map(re.escape, config['base_symbols']))})"
        self.base_symbols_number_pattern = f"({self.base_symbols_pattern}{self.number_pattern})"

        number_splitter = f"({self.number_pattern})"
        for chunk in re.split(self.base_symbols_number_pattern, text):
            # Splitting a chunk that holds no number just yields the chunk.
            for piece in re.split(number_splitter, chunk):
                tokens.extend(self._regex_tokens(piece))

        if add_special_tokens:
            tokens.append('[EQ_END]')
        return tokens

    def _tokenize_withdef_text(self, text, add_special_tokens):
        """Tokenize text that already carries DEF/EQ marker tokens.

        Each marker switches the mode used for the segment that follows it.
        Markers are copied to the output only when ``add_special_tokens`` is
        true. Segments outside any marked region are split on whitespace.
        """
        # Mode that becomes active after each marker.
        mode_after_marker = {
            '[DEF_START]': 'definition',
            '[DEF_JOIN]': 'definition',
            '[DEF_END]': None,
            '[EQ_START]': 'text',
            '[EQ_END]': None,
        }

        tokens = []
        current_mode = None
        segments = re.split(r'(\[DEF_START\]|\[DEF_JOIN\]|\[DEF_END\]|\[EQ_START\]|\[EQ_END\])', text)

        for seg in segments:
            seg = seg.strip()
            if not seg:
                continue

            if seg in mode_after_marker:
                if add_special_tokens:
                    tokens.append(seg)
                current_mode = mode_after_marker[seg]
            elif current_mode == 'definition':
                tokens.extend(self._tokenize_definition(seg, add_special_tokens=False))
            elif current_mode == 'text':
                tokens.extend(self._tokenize_equation(seg, add_special_tokens=False))
            else:
                tokens.extend(seg.split())
        return tokens

    def convert_tokens_to_ids(self, tokens):
        """Map token strings to vocabulary ids.

        Args:
            tokens: A single token string, or a sequence of token strings. A
                sequence whose first element is not a string is returned
                unchanged; ``None`` yields ``None``.

        Returns:
            One id for a single token, otherwise a list of ids. Tokens missing
            from the vocabulary map to the id of ``[UNK]``.
        """
        if tokens is None:
            return None
        if isinstance(tokens, str):
            # A lone token must not be iterated character by character.
            return self.vocab.get(tokens, self.vocab['[UNK]'])
        if not tokens:
            # Nothing to convert; avoids indexing into an empty sequence.
            return []
        if isinstance(tokens[0], str):
            unk = self.vocab['[UNK]']
            return [self.vocab.get(token, unk) for token in tokens]
        return tokens

    def convert_ids_to_tokens(self, ids):
        """Map vocabulary ids back to token strings.

        Args:
            ids: A single int id or an iterable of ids.

        Returns:
            One token for a single id, otherwise a list of tokens. Unknown ids
            map to ``'[UNK]'``.
        """
        # Reuse the reverse mapping built in __init__ instead of rebuilding it
        # on every call.
        if isinstance(ids, int):
            return self.ids_to_tokens.get(ids, '[UNK]')
        return [self.ids_to_tokens.get(i, '[UNK]') for i in ids]

    def encode(self, text, mode=None, add_special_tokens=None):
        """Tokenize ``text`` and convert the tokens to ids.

        Args:
            text: The text to encode.
            mode: Tokenization mode; ``None`` selects ``'text'``.
            add_special_tokens: Passed on to ``tokenize``. ``None`` is falsy
                there, so no start/end markers are added unless a true value
                is given.

        Returns:
            The list of token ids.
        """
        if mode is None:
            mode = 'text'
        tokens = self.tokenize(text, mode=mode, add_special_tokens=add_special_tokens)
        return self.convert_tokens_to_ids(tokens)

    def decode(self, ids, skip_special_tokens=False):
        """Turn ids back into a space-separated string of tokens.

        Args:
            ids: A single int id or an iterable of ids.
            skip_special_tokens: When true, tokens of the form ``[...]`` are
                removed from the output.
        """
        if isinstance(ids, int):
            ids = [ids]
        tokens = self.convert_ids_to_tokens(ids)
        if skip_special_tokens:
            tokens = [t for t in tokens if not (t.startswith('[') and t.endswith(']'))]
        return " ".join(tokens).replace(" ##", "")

    def get_vocab(self):
        """Return the token-to-id vocabulary (the live dict, not a copy)."""
        return self.vocab
268
+