mac committed on
Commit
b21e239
·
1 Parent(s): 1e8128a

chat template

Browse files
Files changed (2) hide show
  1. app.py +35 -33
  2. utils_chatbot.py +9 -10
app.py CHANGED
@@ -12,7 +12,7 @@ import torch
12
  from huggingface_hub import login
13
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
14
 
15
- from utils_chatbot import organize_messages, stream2display_text
16
 
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger(__name__)
@@ -44,6 +44,7 @@ def gpu_generate_stream(inputs, history, temperature, top_p):
44
  )
45
  model_inputs = tokenizer([prompt_text], return_tensors="pt").to("cuda")
46
 
 
47
  yield "", history
48
 
49
  streamer = TextIteratorStreamer(
@@ -77,17 +78,17 @@ def gpu_generate_stream(inputs, history, temperature, top_p):
77
  elapsed = time.time() - start_time
78
  token_per_sec = gen_tk_count / elapsed if elapsed > 0 else 0
79
  display_text = stream2display_text(stream_text, token_per_sec)
80
- history[-1] = (history[-1][0], display_text)
81
  yield "", history
82
 
83
  thread.join()
84
- history[-1] = (history[-1][0], stream_text.replace("<|im_end|>", ""))
85
  yield "", history
86
 
87
 
88
  def gen_response_stream(message, history, temperature, top_p):
89
- chat_msg_ls = organize_messages(message, history)
90
- history.append((message, ""))
91
  yield from gpu_generate_stream(
92
  chat_msg_ls, history,
93
  temperature=temperature,
@@ -99,32 +100,7 @@ def create_app():
99
  assets_path = Path.cwd().absolute() / "assets"
100
  gr.set_static_paths(paths=[assets_path])
101
 
102
- theme = gr.themes.Soft(
103
- primary_hue="blue",
104
- secondary_hue="gray",
105
- neutral_hue="slate",
106
- font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
107
- )
108
-
109
- with gr.Blocks(
110
- theme=theme,
111
- css="""
112
- .logo-container {
113
- text-align: center;
114
- margin: 0.5rem 0 1rem 0;
115
- }
116
- .logo-container img {
117
- height: 96px;
118
- width: auto;
119
- max-width: 200px;
120
- display: inline-block;
121
- }
122
- .input-box {
123
- border: 1px solid #2f63b8;
124
- border-radius: 8px;
125
- }
126
- """,
127
- ) as demo:
128
  with gr.Row():
129
  with gr.Column(scale=1):
130
  gr.HTML(
@@ -165,6 +141,32 @@ def create_app():
165
  return demo
166
 
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  if __name__ == "__main__":
169
- demo = create_app()
170
- demo.launch()
 
12
  from huggingface_hub import login
13
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
14
 
15
+ from utils_chatbot import organize_messages_from_messages, stream2display_text
16
 
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger(__name__)
 
44
  )
45
  model_inputs = tokenizer([prompt_text], return_tensors="pt").to("cuda")
46
 
47
+ history.append({"role": "assistant", "content": ""})
48
  yield "", history
49
 
50
  streamer = TextIteratorStreamer(
 
78
  elapsed = time.time() - start_time
79
  token_per_sec = gen_tk_count / elapsed if elapsed > 0 else 0
80
  display_text = stream2display_text(stream_text, token_per_sec)
81
+ history[-1]["content"] = display_text
82
  yield "", history
83
 
84
  thread.join()
85
+ history[-1]["content"] = stream_text.replace("<|im_end|>", "")
86
  yield "", history
87
 
88
 
89
  def gen_response_stream(message, history, temperature, top_p):
90
+ chat_msg_ls = organize_messages_from_messages(message, history)
91
+ history.append({"role": "user", "content": message})
92
  yield from gpu_generate_stream(
93
  chat_msg_ls, history,
94
  temperature=temperature,
 
100
  assets_path = Path.cwd().absolute() / "assets"
101
  gr.set_static_paths(paths=[assets_path])
102
 
103
+ with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  with gr.Row():
105
  with gr.Column(scale=1):
106
  gr.HTML(
 
141
  return demo
142
 
143
 
144
+ THEME = gr.themes.Soft(
145
+ primary_hue="blue",
146
+ secondary_hue="gray",
147
+ neutral_hue="slate",
148
+ font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
149
+ )
150
+
151
+ CSS = """
152
+ .logo-container {
153
+ text-align: center;
154
+ margin: 0.5rem 0 1rem 0;
155
+ }
156
+ .logo-container img {
157
+ height: 96px;
158
+ width: auto;
159
+ max-width: 200px;
160
+ display: inline-block;
161
+ }
162
+ .input-box {
163
+ border: 1px solid #2f63b8;
164
+ border-radius: 8px;
165
+ }
166
+ """
167
+
168
+
169
+ demo = create_app()
170
+
171
  if __name__ == "__main__":
172
+ demo.launch(theme=THEME, css=CSS)
 
utils_chatbot.py CHANGED
@@ -1,13 +1,12 @@
1
- def organize_messages(message, history):
2
- msg_ls = [dict(
3
- role="system",
4
- content="You are a helpful assistant.",
5
- )]
6
- for user, assistant in history:
7
- msg_ls.append(dict(role="user", content=user))
8
- if assistant:
9
- msg_ls.append(dict(role="assistant", content=assistant))
10
- msg_ls.append(dict(role="user", content=message))
11
  return msg_ls
12
 
13
 
 
1
+ def organize_messages_from_messages(message, history):
2
+ """Build chat messages from Gradio 6.x messages-format history."""
3
+ msg_ls = [{"role": "system", "content": "You are a helpful assistant."}]
4
+ for msg in history:
5
+ role = msg.get("role", "")
6
+ content = msg.get("content", "")
7
+ if role in ("user", "assistant") and content:
8
+ msg_ls.append({"role": role, "content": content})
9
+ msg_ls.append({"role": "user", "content": message})
 
10
  return msg_ls
11
 
12