thanawatpi commited on
Commit
d97d93d
·
verified ·
1 Parent(s): faa572f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -43
app.py CHANGED
@@ -1,62 +1,134 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- client = InferenceClient("thanawatpi/appherb-treatment-pred-beta")
5
-
6
- def respond(
7
- message,
8
- history: list[tuple[str, str]],
9
- system_message,
10
- max_tokens,
11
- temperature,
12
- top_p,
13
- ):
14
- messages = [{"role": "system", "content": system_message}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  for val in history:
17
  if val[0]:
18
- messages.append({"role": "user", "content": val[0]})
19
  if val[1]:
20
- messages.append({"role": "assistant", "content": val[1]})
21
 
22
- messages.append({"role": "user", "content": message})
 
23
 
 
24
  response = ""
25
-
26
- # Debugging: printing the messages to verify the format
27
- print("Messages sent to model:", messages)
28
-
29
  try:
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- print(message) # Debugging line to see the response structure
38
- token = message.choices[0].delta.content
39
- response += token
40
  yield response
41
  except Exception as e:
42
- print("Error during API request:", e)
43
  yield "An error occurred during the request."
44
 
45
- demo = gr.ChatInterface(
46
- respond,
47
- additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
 
 
 
50
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
- gr.Slider(
52
- minimum=0.1,
53
- maximum=1.0,
54
- value=0.95,
55
- step=0.05,
56
- label="Top-p (nucleus sampling)",
57
- ),
58
  ],
 
 
 
59
  )
60
 
 
61
  if __name__ == "__main__":
62
  demo.launch()
 
 
1
  import gradio as gr
2
+ from unsloth import FastLanguageModel
3
+ from huggingface_hub import snapshot_download
4
+ from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
5
+ import torch
6
+ from threading import Thread
7
+ import io
8
+ from contextlib import redirect_stdout
9
+
10
+ # Ensure environment variables are set for the model (optional)
11
+ import os
12
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
13
+ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
14
+
15
+ # Define constants and model name
16
+ MODEL_NAME = "thanawatpi/appherb-treatment-pred-beta"
17
+
18
+ # Define the preset Alpaca prompt (you can change the wording if needed)
19
+ ALPACA_PROMPT = """
20
+ You are an AI assistant trained to provide helpful, accurate, and friendly responses. Your tone should be polite, clear, and concise.
21
+ If you don't know the answer, respond honestly and suggest alternatives or request clarification.
22
+
23
+ For this session, your role is to assist the user with general queries, help with problem-solving, and guide them through different topics.
24
+ """
25
+
26
+ # Initialize the model and tokenizer
27
+ print("Loading model ... Please wait 1 more minute! ...")
28
+ with redirect_stdout(io.StringIO()):
29
+ model, tokenizer = FastLanguageModel.from_pretrained(
30
+ model_name = MODEL_NAME,
31
+ max_seq_length = None,
32
+ dtype = None,
33
+ load_in_4bit = True,
34
+ )
35
+ FastLanguageModel.for_inference(model)
36
+
37
+ # Define stop token handler
38
+ class StopOnTokens(StoppingCriteria):
39
+ def __init__(self, stop_token_ids):
40
+ self.stop_token_ids = tuple(set(stop_token_ids))
41
+
42
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
43
+ return input_ids[0][-1].item() in self.stop_token_ids
44
+
45
+ # Asynchronous function to handle chatbot responses
46
+ def async_process_chatbot(message, history):
47
+ eos_token = tokenizer.eos_token
48
+ stop_on_tokens = StopOnTokens([eos_token])
49
+ text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
50
+
51
+ history_transformer_format = history + [[message, ""]]
52
+ messages = []
53
+ for item in history_transformer_format:
54
+ messages.append({"role": "user", "content": item[0]})
55
+ messages.append({"role": "assistant", "content": item[1]})
56
 
57
+ # Remove last assistant response
58
+ messages.pop(-1)
59
+
60
+ input_ids = tokenizer.apply_chat_template(
61
+ messages,
62
+ add_generation_prompt=True,
63
+ return_tensors="pt",
64
+ ).to("cuda", non_blocking=True)
65
+
66
+ generation_kwargs = dict(
67
+ input_ids=input_ids,
68
+ streamer=text_streamer,
69
+ max_new_tokens=1024,
70
+ stopping_criteria=StoppingCriteriaList([stop_on_tokens]),
71
+ temperature=0.7,
72
+ do_sample=True,
73
+ )
74
+
75
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
76
+ thread.start()
77
+
78
+ generated_text = ""
79
+ for new_text in text_streamer:
80
+ if new_text.endswith(eos_token):
81
+ new_text = new_text[:len(new_text) - len(eos_token)]
82
+ generated_text += new_text
83
+ yield generated_text
84
+
85
+ # Define the response function for Gradio
86
+ def respond(message, history, system_message, max_tokens, temperature, top_p):
87
+ # Add preset Alpaca prompt as the system message if it's not provided
88
+ if system_message == "":
89
+ system_message = ALPACA_PROMPT
90
+
91
+ # Prepare the system message
92
+ system_message = [{"role": "system", "content": system_message}]
93
+
94
+ # Convert history into the correct message format for chat
95
  for val in history:
96
  if val[0]:
97
+ system_message.append({"role": "user", "content": val[0]})
98
  if val[1]:
99
+ system_message.append({"role": "assistant", "content": val[1]})
100
 
101
+ # Add user message to the conversation
102
+ system_message.append({"role": "user", "content": message})
103
 
104
+ # Stream response from Unsloth model
105
  response = ""
 
 
 
 
106
  try:
107
+ async_gen = async_process_chatbot(message, history)
108
+ for generated_text in async_gen:
109
+ response += generated_text
 
 
 
 
 
 
 
110
  yield response
111
  except Exception as e:
112
+ print(f"Error during API request: {e}")
113
  yield "An error occurred during the request."
114
 
115
+ # Setup the Gradio interface
116
+ demo = gr.Interface(
117
+ fn=respond,
118
+ inputs=[
119
+ gr.Textbox(label="User Message"),
120
+ gr.Chatbot(label="Chat History", height=325),
121
+ gr.Textbox(value=ALPACA_PROMPT, label="System Message (Leave blank to use default)"),
122
+ gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max New Tokens"),
123
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
124
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
 
 
 
 
 
 
125
  ],
126
+ outputs=[gr.Chatbot()],
127
+ theme="compact",
128
+ live=True
129
  )
130
 
131
+ # Launch the app
132
  if __name__ == "__main__":
133
  demo.launch()
134
+