jason1966 commited on
Commit
d041524
·
verified ·
1 Parent(s): 20b71be

Upload conversation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. conversation.py +416 -0
conversation.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Conversation prompt templates.
3
+
4
+ We kindly request that you import fastchat instead of copying this file if you wish to use it.
5
+ If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
6
+
7
+ Modified from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
8
+ """
9
+
10
+ import dataclasses
11
+ from enum import IntEnum, auto
12
+ from typing import Dict, List, Tuple, Union
13
+
14
+
15
+ class SeparatorStyle(IntEnum):
16
+ """Separator styles."""
17
+
18
+ ADD_COLON_SINGLE = auto()
19
+ ADD_COLON_TWO = auto()
20
+ ADD_COLON_SPACE_SINGLE = auto()
21
+ NO_COLON_SINGLE = auto()
22
+ NO_COLON_TWO = auto()
23
+ ADD_NEW_LINE_SINGLE = auto()
24
+ LLAMA2 = auto()
25
+ CHATGLM = auto()
26
+ CHATML = auto()
27
+ CHATINTERN = auto()
28
+ DOLLY = auto()
29
+ RWKV = auto()
30
+ PHOENIX = auto()
31
+ ROBIN = auto()
32
+ FALCON_CHAT = auto()
33
+ CHATGLM3 = auto()
34
+ INTERNVL_ZH = auto()
35
+ MPT = auto()
36
+ QIANFANVL = auto()
37
+
38
+
39
+ @dataclasses.dataclass
40
+ class Conversation:
41
+ """A class that manages prompt templates and keeps all conversation history."""
42
+
43
+ # The name of this template
44
+ name: str
45
+ # The template of the system prompt
46
+ system_template: str = '{system_message}'
47
+ # The system message
48
+ system_message: str = ''
49
+ # The names of two roles
50
+ roles: Tuple[str] = ('USER', 'ASSISTANT')
51
+ # All messages. Each item is (role, message).
52
+ messages: List[List[str]] = ()
53
+ # The number of few shot examples
54
+ offset: int = 0
55
+ # The separator style and configurations
56
+ sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
57
+ sep: str = '\n'
58
+ sep2: str = None
59
+ # Stop criteria (the default one is EOS token)
60
+ stop_str: Union[str, List[str]] = None
61
+ # Stops generation if meeting any token in this list
62
+ stop_token_ids: List[int] = None
63
+
64
+ def get_prompt(self) -> str:
65
+ """Get the prompt for generation."""
66
+ system_prompt = self.system_template.format(system_message=self.system_message)
67
+ if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
68
+ ret = system_prompt + self.sep
69
+ for role, message in self.messages:
70
+ if message:
71
+ ret += role + ': ' + message + self.sep
72
+ else:
73
+ ret += role + ':'
74
+ return ret
75
+ elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
76
+ seps = [self.sep, self.sep2]
77
+ ret = system_prompt + seps[0]
78
+ for i, (role, message) in enumerate(self.messages):
79
+ if message:
80
+ ret += role + ': ' + message + seps[i % 2]
81
+ else:
82
+ ret += role + ':'
83
+ return ret
84
+ elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
85
+ ret = system_prompt + self.sep
86
+ for role, message in self.messages:
87
+ if message:
88
+ ret += role + ': ' + message + self.sep
89
+ else:
90
+ ret += role + ': ' # must be end with a space
91
+ return ret
92
+ elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
93
+ ret = '' if system_prompt == '' else system_prompt + self.sep
94
+ for role, message in self.messages:
95
+ if message:
96
+ ret += role + '\n' + message + self.sep
97
+ else:
98
+ ret += role + '\n'
99
+ return ret
100
+ elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
101
+ ret = system_prompt
102
+ for role, message in self.messages:
103
+ if message:
104
+ ret += role + message + self.sep
105
+ else:
106
+ ret += role
107
+ return ret
108
+ elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
109
+ seps = [self.sep, self.sep2]
110
+ ret = system_prompt
111
+ for i, (role, message) in enumerate(self.messages):
112
+ if message:
113
+ ret += role + message + seps[i % 2]
114
+ else:
115
+ ret += role
116
+ return ret
117
+ elif self.sep_style == SeparatorStyle.RWKV:
118
+ ret = system_prompt
119
+ for i, (role, message) in enumerate(self.messages):
120
+ if message:
121
+ ret += (
122
+ role
123
+ + ': '
124
+ + message.replace('\r\n', '\n').replace('\n\n', '\n')
125
+ )
126
+ ret += '\n\n'
127
+ else:
128
+ ret += role + ':'
129
+ return ret
130
+ elif self.sep_style == SeparatorStyle.LLAMA2:
131
+ seps = [self.sep, self.sep2]
132
+ if self.system_message:
133
+ ret = system_prompt
134
+ else:
135
+ ret = '[INST] '
136
+ for i, (role, message) in enumerate(self.messages):
137
+ tag = self.roles[i % 2]
138
+ if message:
139
+ if i == 0:
140
+ ret += message + ' '
141
+ else:
142
+ ret += tag + ' ' + message + seps[i % 2]
143
+ else:
144
+ ret += tag
145
+ return ret
146
+ elif self.sep_style == SeparatorStyle.CHATGLM:
147
+ # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
148
+ # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
149
+ round_add_n = 1 if self.name == 'chatglm2' else 0
150
+ if system_prompt:
151
+ ret = system_prompt + self.sep
152
+ else:
153
+ ret = ''
154
+
155
+ for i, (role, message) in enumerate(self.messages):
156
+ if i % 2 == 0:
157
+ ret += f'[Round {i//2 + round_add_n}]{self.sep}'
158
+
159
+ if message:
160
+ ret += f'{role}:{message}{self.sep}'
161
+ else:
162
+ ret += f'{role}:'
163
+ return ret
164
+ elif self.sep_style == SeparatorStyle.CHATML:
165
+ ret = '' if system_prompt == '' else system_prompt + self.sep + '\n'
166
+ for role, message in self.messages:
167
+ if message:
168
+ ret += role + '\n' + message + self.sep + '\n'
169
+ else:
170
+ ret += role + '\n'
171
+ return ret
172
+ elif self.sep_style == SeparatorStyle.CHATGLM3:
173
+ ret = ''
174
+ if self.system_message:
175
+ ret += system_prompt
176
+ for role, message in self.messages:
177
+ if message:
178
+ ret += role + '\n' + ' ' + message
179
+ else:
180
+ ret += role
181
+ return ret
182
+ elif self.sep_style == SeparatorStyle.CHATINTERN:
183
+ # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
184
+ seps = [self.sep, self.sep2]
185
+ ret = system_prompt
186
+ for i, (role, message) in enumerate(self.messages):
187
+ # if i % 2 == 0:
188
+ # ret += "<s>"
189
+ if message:
190
+ ret += role + ':' + message + seps[i % 2] + '\n'
191
+ else:
192
+ ret += role + ':'
193
+ return ret
194
+ elif self.sep_style == SeparatorStyle.DOLLY:
195
+ seps = [self.sep, self.sep2]
196
+ ret = system_prompt
197
+ for i, (role, message) in enumerate(self.messages):
198
+ if message:
199
+ ret += role + ':\n' + message + seps[i % 2]
200
+ if i % 2 == 1:
201
+ ret += '\n\n'
202
+ else:
203
+ ret += role + ':\n'
204
+ return ret
205
+ elif self.sep_style == SeparatorStyle.PHOENIX:
206
+ ret = system_prompt
207
+ for role, message in self.messages:
208
+ if message:
209
+ ret += role + ': ' + '<s>' + message + '</s>'
210
+ else:
211
+ ret += role + ': ' + '<s>'
212
+ return ret
213
+ elif self.sep_style == SeparatorStyle.ROBIN:
214
+ ret = system_prompt + self.sep
215
+ for role, message in self.messages:
216
+ if message:
217
+ ret += role + ':\n' + message + self.sep
218
+ else:
219
+ ret += role + ':\n'
220
+ return ret
221
+ elif self.sep_style == SeparatorStyle.FALCON_CHAT:
222
+ ret = ''
223
+ if self.system_message:
224
+ ret += system_prompt + self.sep
225
+ for role, message in self.messages:
226
+ if message:
227
+ ret += role + ': ' + message + self.sep
228
+ else:
229
+ ret += role + ':'
230
+
231
+ return ret
232
+ elif self.sep_style == SeparatorStyle.INTERNVL_ZH:
233
+ seps = [self.sep, self.sep2]
234
+ ret = self.system_message + seps[0]
235
+ for i, (role, message) in enumerate(self.messages):
236
+ if message:
237
+ ret += role + ': ' + message + seps[i % 2]
238
+ else:
239
+ ret += role + ':'
240
+ return ret
241
+ elif self.sep_style == SeparatorStyle.MPT:
242
+ ret = system_prompt + self.sep
243
+ for role, message in self.messages:
244
+ if message:
245
+ if type(message) is tuple:
246
+ message, _, _ = message
247
+ ret += role + message + self.sep
248
+ else:
249
+ ret += role
250
+ return ret
251
+ elif self.sep_style == SeparatorStyle.QIANFANVL:
252
+ ret = ''
253
+ if self.system_message:
254
+ ret = system_prompt + self.sep
255
+ for role, message in self.messages:
256
+ if message:
257
+ if type(message) is tuple:
258
+ message, _, _ = message
259
+ ret += role + message + self.sep
260
+ else:
261
+ ret += role
262
+ return ret
263
+ else:
264
+ raise ValueError(f'Invalid style: {self.sep_style}')
265
+
266
+ def set_system_message(self, system_message: str):
267
+ """Set the system message."""
268
+ self.system_message = system_message
269
+
270
+ def append_message(self, role: str, message: str):
271
+ """Append a new message."""
272
+ self.messages.append([role, message])
273
+
274
+ def update_last_message(self, message: str):
275
+ """Update the last output.
276
+
277
+ The last message is typically set to be None when constructing the prompt,
278
+ so we need to update it in-place after getting the response from a model.
279
+ """
280
+ self.messages[-1][1] = message
281
+
282
+ def to_gradio_chatbot(self):
283
+ """Convert the conversation to gradio chatbot format."""
284
+ ret = []
285
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
286
+ if i % 2 == 0:
287
+ ret.append([msg, None])
288
+ else:
289
+ ret[-1][-1] = msg
290
+ return ret
291
+
292
+ def to_openai_api_messages(self):
293
+ """Convert the conversation to OpenAI chat completion format."""
294
+ ret = [{'role': 'system', 'content': self.system_message}]
295
+
296
+ for i, (_, msg) in enumerate(self.messages[self.offset :]):
297
+ if i % 2 == 0:
298
+ ret.append({'role': 'user', 'content': msg})
299
+ else:
300
+ if msg is not None:
301
+ ret.append({'role': 'assistant', 'content': msg})
302
+ return ret
303
+
304
+ def copy(self):
305
+ return Conversation(
306
+ name=self.name,
307
+ system_template=self.system_template,
308
+ system_message=self.system_message,
309
+ roles=self.roles,
310
+ messages=[[x, y] for x, y in self.messages],
311
+ offset=self.offset,
312
+ sep_style=self.sep_style,
313
+ sep=self.sep,
314
+ sep2=self.sep2,
315
+ stop_str=self.stop_str,
316
+ stop_token_ids=self.stop_token_ids,
317
+ )
318
+
319
+ def dict(self):
320
+ return {
321
+ 'template_name': self.name,
322
+ 'system_message': self.system_message,
323
+ 'roles': self.roles,
324
+ 'messages': self.messages,
325
+ 'offset': self.offset,
326
+ }
327
+
328
+
329
+ # A global registry for all conversation templates
330
+ conv_templates: Dict[str, Conversation] = {}
331
+
332
+
333
+ def register_conv_template(template: Conversation, override: bool = False):
334
+ """Register a new conversation template."""
335
+ if not override:
336
+ assert (
337
+ template.name not in conv_templates
338
+ ), f'{template.name} has been registered.'
339
+
340
+ conv_templates[template.name] = template
341
+
342
+
343
+ def get_conv_template(name: str) -> Conversation:
344
+ """Get a conversation template."""
345
+ return conv_templates[name].copy()
346
+
347
+
348
+ # Both Hermes-2 and internlm2-chat are chatml-format conversation templates. The difference
349
+ # is that during training, the preprocessing function for the Hermes-2 template doesn't add
350
+ # <s> at the beginning of the tokenized sequence, while the internlm2-chat template does.
351
+ # Therefore, they are completely equivalent during inference.
352
+ register_conv_template(
353
+ Conversation(
354
+ name='Hermes-2',
355
+ system_template='<|im_start|>system\n{system_message}',
356
+ # note: The new system prompt was not used here to avoid changes in benchmark performance.
357
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
358
+ system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
359
+ roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
360
+ sep_style=SeparatorStyle.MPT,
361
+ sep='<|im_end|>',
362
+ stop_str='<|endoftext|>',
363
+ )
364
+ )
365
+
366
+
367
+ register_conv_template(
368
+ Conversation(
369
+ name='internlm2-chat',
370
+ system_template='<|im_start|>system\n{system_message}',
371
+ # note: The new system prompt was not used here to avoid changes in benchmark performance.
372
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
373
+ system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
374
+ roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
375
+ sep_style=SeparatorStyle.MPT,
376
+ sep='<|im_end|>',
377
+ )
378
+ )
379
+
380
+
381
+ register_conv_template(
382
+ Conversation(
383
+ name='phi3-chat',
384
+ system_template='<|system|>\n{system_message}',
385
+ # note: The new system prompt was not used here to avoid changes in benchmark performance.
386
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、���华大学及多家合作单位联合开发的多模态大语言模型。',
387
+ system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
388
+ roles=('<|user|>\n', '<|assistant|>\n'),
389
+ sep_style=SeparatorStyle.MPT,
390
+ sep='<|end|>',
391
+ )
392
+ )
393
+
394
+
395
+ register_conv_template(
396
+ Conversation(
397
+ name='internvl2_5',
398
+ system_template='<|im_start|>system\n{system_message}',
399
+ system_message='你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
400
+ roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
401
+ sep_style=SeparatorStyle.MPT,
402
+ sep='<|im_end|>\n',
403
+ )
404
+ )
405
+
406
+
407
+ register_conv_template(
408
+ Conversation(
409
+ name='qianfanvl',
410
+ system_template='<|im_start|>system\n{system_message}',
411
+ system_message='',
412
+ roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
413
+ sep_style=SeparatorStyle.QIANFANVL,
414
+ sep='<|im_end|>\n',
415
+ )
416
+ )