jmcinern commited on
Commit
0671b54
·
verified ·
1 Parent(s): 555ebaa

Update app.py

Browse files

cleaner code, manual think strip, concurrent CPU use

Files changed (1) hide show
  1. app.py +80 -359
app.py CHANGED
@@ -1,388 +1,109 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
- import json
5
  import re
6
- from typing import List, Tuple, Optional
7
- import time
8
-
9
- # Thinking tag regex pattern for hard stripping
10
- THINK_TAG_PATTERN = re.compile(r'<think>.*?</think>\s*', flags=re.DOTALL)
11
 
12
  # Model configuration
13
  MODEL_NAME = "jmcinern/qwen3-8B-cpt-sft-awq"
 
14
 
15
- class IrishEnglishChatbot:
16
  def __init__(self):
17
  self.model = None
18
  self.tokenizer = None
19
- self.load_model()
 
 
 
 
20
 
21
  def load_model(self):
22
- """Load the quantized model and tokenizer"""
23
- print(f"Loading model: {MODEL_NAME}")
24
- try:
 
25
  print("Loading tokenizer...")
26
- self.tokenizer = AutoTokenizer.from_pretrained(
27
- MODEL_NAME,
28
- trust_remote_code=True
29
- )
30
-
31
- print("Loading model with optimized settings...")
32
- # Try different loading strategies in order of preference
33
-
34
- # Strategy 1: Try with llm-compressor (modern approach)
35
- try:
36
- from llmcompressor.transformers import SparseAutoModelForCausalLM
37
- print("Attempting to load with llm-compressor...")
38
- self.model = SparseAutoModelForCausalLM.from_pretrained(
39
- MODEL_NAME,
40
- trust_remote_code=True,
41
- device_map="auto",
42
- torch_dtype="auto",
43
- low_cpu_mem_usage=True
44
- )
45
- print("✅ Loaded with llm-compressor")
46
- return
47
- except ImportError:
48
- print("llm-compressor not available, trying AutoAWQ...")
49
- except Exception as e:
50
- print(f"llm-compressor failed: {e}, trying AutoAWQ...")
51
-
52
- # Strategy 2: Try with AutoAWQ (suppress deprecation warning)
53
- try:
54
- import warnings
55
- warnings.filterwarnings("ignore", category=DeprecationWarning)
56
- from awq import AutoAWQForCausalLM
57
- print("Attempting to load with AutoAWQ...")
58
- self.model = AutoAWQForCausalLM.from_quantized(
59
- MODEL_NAME,
60
- trust_remote_code=True,
61
- device_map="auto",
62
- low_cpu_mem_usage=True
63
- )
64
- print("✅ Loaded with AutoAWQ")
65
- return
66
- except Exception as e:
67
- print(f"AutoAWQ failed: {e}, falling back to transformers...")
68
-
69
- # Strategy 3: Fall back to standard transformers
70
- print("Attempting to load with standard transformers...")
71
- self.model = AutoModelForCausalLM.from_pretrained(
72
  MODEL_NAME,
73
  trust_remote_code=True,
74
  device_map="auto",
75
- torch_dtype=torch.float16,
76
- low_cpu_mem_usage=True,
77
- use_safetensors=True
78
  )
79
- print("✅ Loaded with transformers (fallback)")
80
-
81
- except Exception as e:
82
- print(f"❌ Error loading model: {e}")
83
- # Show user-friendly error
84
- self.model = None
85
- self.tokenizer = None
86
- raise RuntimeError(f"Failed to load model. This might be due to insufficient GPU memory or network issues. Error: {str(e)}")
87
-
88
- def format_chat_prompt(self, messages: List[dict], add_generation_prompt: bool = True) -> str:
89
- """Format messages using the custom Qwen3 chat template"""
90
  try:
91
- formatted = self.tokenizer.apply_chat_template(
92
- messages,
93
- tokenize=False,
94
- add_generation_prompt=add_generation_prompt,
95
- enable_thinking=False # Disable thinking mode as per your training
96
- )
97
- return formatted
 
 
 
 
 
98
  except Exception as e:
99
- print(f"Template error: {e}")
100
- # Fallback manual formatting
101
- formatted = ""
102
- for msg in messages:
103
- role = msg["role"]
104
- content = msg["content"]
105
- formatted += f"<|im_start|>{role}\n{content}<|im_end|>\n"
106
- if add_generation_prompt:
107
- formatted += "<|im_start|>assistant\n"
108
- return formatted
109
 
110
- def generate_response(
111
- self,
112
- message: str,
113
- history: List[Tuple[str, str]],
114
- temperature: float = 0.7,
115
- max_tokens: int = 512,
116
- top_p: float = 0.9
117
- ) -> Tuple[str, List[Tuple[str, str]]]:
118
- """Generate response from the model"""
119
 
120
- if self.model is None:
121
- return "❌ Model not loaded. Please refresh the page.", history + [(message, "Model not loaded. Please refresh the page.")]
122
 
123
- try:
124
- # Build conversation history
125
- messages = []
126
-
127
- # Add conversation history
128
- for user_msg, assistant_msg in history:
129
- messages.append({"role": "user", "content": user_msg})
130
- messages.append({"role": "assistant", "content": assistant_msg})
131
-
132
- # Add current message
133
- messages.append({"role": "user", "content": message})
134
-
135
- # Format prompt
136
- formatted_prompt = self.format_chat_prompt(messages, add_generation_prompt=True)
137
-
138
- # Tokenize with length limits
139
- inputs = self.tokenizer(
140
- formatted_prompt,
141
- return_tensors="pt",
142
- truncation=True,
143
- max_length=3072 # Leave room for response
144
- ).to(self.model.device)
145
-
146
- # Generate with timeout protection
147
- with torch.no_grad():
148
- outputs = self.model.generate(
149
- **inputs,
150
- max_new_tokens=max_tokens,
151
- temperature=temperature,
152
- top_p=top_p,
153
- do_sample=temperature > 0,
154
- pad_token_id=self.tokenizer.eos_token_id,
155
- eos_token_id=self.tokenizer.eos_token_id,
156
- repetition_penalty=1.1,
157
- use_cache=True
158
- )
159
-
160
- # Decode response
161
- full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
162
-
163
- # Hard strip thinking tags (safety measure) - do this FIRST
164
- full_response = THINK_TAG_PATTERN.sub('', full_response)
165
-
166
- # Extract just the assistant's response
167
- if "<|im_start|>assistant" in full_response:
168
- response = full_response.split("<|im_start|>assistant")[-1]
169
- response = response.replace("<|im_end|>", "").strip()
170
- else:
171
- # Fallback - take everything after the input
172
- input_length = len(self.tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True))
173
- response = full_response[input_length:].strip()
174
-
175
- # Hard strip thinking tags (safety measure)
176
- response = THINK_TAG_PATTERN.sub('', response)
177
-
178
- # Clean up other chat tokens
179
- response = re.sub(r'<\|im_start\|>.*?<\|im_end\|>', '', response, flags=re.DOTALL)
180
- response = response.strip()
181
-
182
- # Final safety check - remove any remaining thinking artifacts
183
- response = re.sub(r'</?think[^>]*>', '', response)
184
- response = response.strip()
185
-
186
- # Handle empty responses
187
- if not response:
188
- response = "I apologize, but I couldn't generate a proper response. Please try again."
189
-
190
- # Update history
191
- new_history = history + [(message, response)]
192
-
193
- return response, new_history
194
-
195
- except Exception as e:
196
- error_msg = f"❌ Generation error: {str(e)}"
197
- print(f"Generation error: {e}")
198
- new_history = history + [(message, error_msg)]
199
- return error_msg, new_history
200
-
201
- # Initialize chatbot with error handling
202
- print("Initializing chatbot...")
203
- try:
204
- chatbot = IrishEnglishChatbot()
205
- print("✅ Chatbot initialized successfully!")
206
- except Exception as e:
207
- print(f"❌ Failed to initialize chatbot: {e}")
208
- chatbot = None
209
-
210
- # Gradio interface functions
211
- def chat_fn(message, history, temperature, max_tokens, top_p):
212
- """Main chat function for Gradio"""
213
- if not message.strip():
214
- return history, history, ""
215
-
216
- if chatbot is None:
217
- error_msg = "❌ Model not available. Please contact the space owner."
218
- new_history = history + [(message, error_msg)]
219
- return new_history, new_history, ""
220
-
221
- try:
222
- response, new_history = chatbot.generate_response(
223
- message=message,
224
- history=history,
225
- temperature=temperature,
226
- max_tokens=max_tokens,
227
- top_p=top_p
228
  )
229
- return new_history, new_history, ""
230
-
231
- except Exception as e:
232
- error_msg = f"❌ Error: {str(e)}"
233
- new_history = history + [(message, error_msg)]
234
- return new_history, new_history, ""
235
-
236
- def clear_chat():
237
- """Clear chat history"""
238
- return [], []
239
-
240
- # Example prompts for different languages
241
- example_prompts = [
242
- "Conas atá tú inniu?", # Irish: How are you today?
243
- "What is the capital of Ireland?",
244
- "Inis dom faoi stair na hÉireann", # Irish: Tell me about Irish history
245
- "Translate 'hello' to Irish",
246
- "Cad iad na príomhchathracha in Éirinn?", # Irish: What are the main cities in Ireland?
247
- "Explain machine learning in simple terms"
248
- ]
249
-
250
- # Custom CSS
251
- custom_css = """
252
- .gradio-container {
253
- font-family: 'Arial', sans-serif;
254
- }
255
- .chat-message {
256
- padding: 10px;
257
- margin: 5px 0;
258
- border-radius: 8px;
259
- }
260
- .user-message {
261
- background-color: #e3f2fd;
262
- margin-left: 20%;
263
- }
264
- .bot-message {
265
- background-color: #f5f5f5;
266
- margin-right: 20%;
267
- }
268
- #title {
269
- text-align: center;
270
- color: #1976d2;
271
- font-size: 2em;
272
- margin-bottom: 1em;
273
- }
274
- """
275
-
276
- # Create Gradio interface
277
- with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
278
- gr.HTML("<h1 id='title'>🇮🇪 Irish-English Qwen3 Chatbot 🤖</h1>")
279
-
280
- gr.Markdown("""
281
- ## Fáilte! Welcome!
282
-
283
- This is an Irish-English bilingual AI assistant based on Qwen3-8B, fine-tuned for both Irish (Gaeilge) and English.
284
- You can chat with me in either language!
285
-
286
- **Features:**
287
- - 🇮🇪 Native Irish language support
288
- - 🇬🇧 English language support
289
- - ⚡ AWQ quantized for fast inference
290
- - 💬 Conversational chat interface
291
- """)
292
-
293
- with gr.Row():
294
- with gr.Column(scale=4):
295
- chatbot_interface = gr.Chatbot(
296
- label="Chat History",
297
- height=500,
298
- show_label=True,
299
- bubble_full_width=False
300
- )
301
-
302
- msg_box = gr.Textbox(
303
- label="Your message",
304
- placeholder="Type your message in Irish or English...",
305
- lines=2,
306
- max_lines=4
307
- )
308
-
309
- with gr.Row():
310
- submit_btn = gr.Button("Send", variant="primary", size="sm")
311
- clear_btn = gr.Button("Clear Chat", variant="secondary", size="sm")
312
 
313
- with gr.Column(scale=1):
314
- gr.Markdown("### Settings")
315
-
316
- temperature = gr.Slider(
317
- minimum=0.1,
318
- maximum=1.0,
319
- value=0.7,
320
- step=0.1,
321
- label="Temperature",
322
- info="Higher = more creative"
323
- )
324
-
325
- max_tokens = gr.Slider(
326
- minimum=50,
327
- maximum=1024,
328
- value=512,
329
- step=50,
330
- label="Max Tokens",
331
- info="Maximum response length"
332
- )
333
-
334
- top_p = gr.Slider(
335
- minimum=0.1,
336
- maximum=1.0,
337
- value=0.9,
338
- step=0.1,
339
- label="Top P",
340
- info="Nucleus sampling"
341
  )
342
-
343
- gr.Markdown("### Example Prompts")
344
-
345
- for prompt in example_prompts:
346
- gr.Button(
347
- prompt,
348
- size="sm",
349
- variant="outline"
350
- ).click(
351
- fn=lambda x=prompt: x,
352
- outputs=msg_box
353
- )
354
-
355
- # Event handlers
356
- submit_btn.click(
357
- fn=chat_fn,
358
- inputs=[msg_box, chatbot_interface, temperature, max_tokens, top_p],
359
- outputs=[chatbot_interface, chatbot_interface, msg_box]
360
- )
361
-
362
- msg_box.submit(
363
- fn=chat_fn,
364
- inputs=[msg_box, chatbot_interface, temperature, max_tokens, top_p],
365
- outputs=[chatbot_interface, chatbot_interface, msg_box]
366
- )
367
 
368
- clear_btn.click(
369
- fn=clear_chat,
370
- outputs=[chatbot_interface, chatbot_interface]
371
- )
372
 
373
- # Footer
374
- gr.HTML("""
375
- <div style="text-align: center; margin-top: 2em; color: #666;">
376
- <p>Model: <a href="https://huggingface.co/jmcinern/qwen3-8B-cpt-sft-awq" target="_blank">jmcinern/qwen3-8B-cpt-sft-awq</a></p>
377
- <p>Based on Qwen3-8B | AWQ Quantized | Irish-English Bilingual</p>
378
- </div>
379
- """)
380
 
381
- # Launch configuration
382
  if __name__ == "__main__":
383
- demo.launch(
384
- share=False,
385
- server_name="0.0.0.0",
386
- server_port=7860,
387
- show_error=True
388
- )
 
1
  import gradio as gr
2
  import torch
 
 
3
  import re
4
+ import threading
5
+ from llmcompressor.transformers import SparseAutoModelForCausalLM
6
+ from transformers import AutoTokenizer
 
 
7
 
8
  # Model configuration
9
  MODEL_NAME = "jmcinern/qwen3-8B-cpt-sft-awq"
10
+ THINK_TAG_PATTERN = re.compile(r'<think>.*?</think>\s*', flags=re.DOTALL)
11
 
12
+ class ChatBot:
13
  def __init__(self):
14
  self.model = None
15
  self.tokenizer = None
16
+ self.loading = True
17
+
18
+ # Load model in separate thread
19
+ thread = threading.Thread(target=self.load_model)
20
+ thread.start()
21
 
22
  def load_model(self):
23
+ """Load model and tokenizer with concurrent loading"""
24
+ import concurrent.futures
25
+
26
+ def load_tokenizer():
27
  print("Loading tokenizer...")
28
+ return AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
29
+
30
+ def load_model():
31
+ print("Loading model...")
32
+ return SparseAutoModelForCausalLM.from_pretrained(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  MODEL_NAME,
34
  trust_remote_code=True,
35
  device_map="auto",
36
+ torch_dtype="auto",
37
+ max_workers=4 # Use 4 threads for model loading
 
38
  )
39
+
 
 
 
 
 
 
 
 
 
 
40
  try:
41
+ # Load tokenizer and model concurrently
42
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
43
+ tokenizer_future = executor.submit(load_tokenizer)
44
+ model_future = executor.submit(load_model)
45
+
46
+ # Get results
47
+ self.tokenizer = tokenizer_future.result()
48
+ print("Tokenizer loaded!")
49
+
50
+ self.model = model_future.result()
51
+ print("Model loaded!")
52
+
53
  except Exception as e:
54
+ print(f"Error loading: {e}")
55
+ finally:
56
+ self.loading = False
 
 
 
 
 
 
 
57
 
58
+ def chat(self, message, history):
59
+ if self.loading:
60
+ return history + [(message, "Model is loading, please wait...")]
 
 
 
 
 
 
61
 
62
+ if not self.model:
63
+ return history + [(message, "Model failed to load")]
64
 
65
+ # Build messages
66
+ messages = []
67
+ for user_msg, bot_msg in history:
68
+ messages.append({"role": "user", "content": user_msg})
69
+ messages.append({"role": "assistant", "content": bot_msg})
70
+ messages.append({"role": "user", "content": message})
71
+
72
+ # Apply chat template and strip thinking
73
+ prompt = self.tokenizer.apply_chat_template(
74
+ messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  )
76
+ prompt = THINK_TAG_PATTERN.sub("", prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
+ # Generate
79
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
80
+
81
+ with torch.no_grad():
82
+ outputs = self.model.generate(
83
+ **inputs,
84
+ max_new_tokens=512,
85
+ temperature=0.7,
86
+ do_sample=True,
87
+ pad_token_id=self.tokenizer.eos_token_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  )
89
+
90
+ # Extract response
91
+ response = self.tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
92
+ response = THINK_TAG_PATTERN.sub("", response).strip()
93
+
94
+ return history + [(message, response)]
95
+
96
+ # Initialize chatbot
97
+ bot = ChatBot()
98
+
99
+ # Create interface
100
+ with gr.Blocks() as demo:
101
+ gr.HTML("<h1 style='text-align: center;'>Qomhrá: A Bilingual Irish-English LLM</h1>")
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
+ chatbot = gr.Chatbot(height=500)
104
+ msg = gr.Textbox(placeholder="Type your message...", show_label=False)
 
 
105
 
106
+ msg.submit(bot.chat, [msg, chatbot], [chatbot]).then(lambda: "", outputs=msg)
 
 
 
 
 
 
107
 
 
108
  if __name__ == "__main__":
109
+ demo.launch()