ruslanmv commited on
Commit
77854f6
·
verified ·
1 Parent(s): 672d64d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +217 -84
app.py CHANGED
@@ -1,94 +1,227 @@
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
4
  import torch
5
- import spaces
6
- import os
 
 
 
 
 
7
  IS_SPACES_ZERO = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
8
  IS_SPACE = os.environ.get("SPACE_ID", None) is not None
9
 
10
- device = "cuda" if torch.cuda.is_available() else "cpu"
11
- LOW_MEMORY = os.getenv("LOW_MEMORY", "0") == "1"
12
- print(f"Using device: {device}")
13
- print(f"low memory: {LOW_MEMORY}")
14
- # Define BitsAndBytesConfig
15
- bnb_config = BitsAndBytesConfig(load_in_4bit=True,
16
- bnb_4bit_quant_type="nf4",
17
- bnb_4bit_compute_dtype=torch.float16)
18
-
19
- # Model name
20
- model_name = "ruslanmv/Medical-Llama3-v2"
21
-
22
- # Load tokenizer and model with BitsAndBytesConfig
23
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, bnb_config=bnb_config)
24
- model = AutoModelForCausalLM.from_pretrained(model_name, config=bnb_config)
25
-
26
- # Ensure model is on the correct device
27
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
- model.to(device)
29
- @spaces.GPU
30
- # Define the respond function
31
- def respond(
32
- message,
33
- history: list[tuple[str, str]],
34
- system_message,
35
- max_tokens,
36
- temperature,
37
- top_p,
38
- ):
39
- messages = [{"role": "system", "content": system_message}]
40
-
41
- for val in history:
42
- if val[0]:
43
- messages.append({"role": "user", "content": val[0]})
44
- if val[1]:
45
- messages.append({"role": "assistant", "content": val[1]})
46
-
47
- messages.append({"role": "user", "content": message})
48
-
49
- # Format the conversation as a single string for the model
50
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
51
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=512)
52
-
53
- # Move inputs to device
54
- input_ids = inputs['input_ids'].to(device)
55
- attention_mask = inputs['attention_mask'].to(device)
56
-
57
- # Generate the response
58
- with torch.no_grad():
59
- outputs = model.generate(
60
- input_ids=input_ids,
61
- attention_mask=attention_mask,
62
- max_length=max_tokens,
63
- temperature=temperature,
64
- top_p=top_p,
65
- use_cache=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  )
67
-
68
- # Extract the response
69
- response_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
70
-
71
- # Remove the prompt and system message from the response
72
- response_text = response_text.replace(system_message, '').strip()
73
- response_text = response_text.replace(f"Human: {message}\n\nAssistant: ", '').strip()
74
-
75
- return response_text
76
-
77
- # Create the Gradio interface
78
- demo = gr.ChatInterface(
79
- respond,
80
- additional_inputs=[
81
- gr.Textbox(value="You are a Medical AI Assistant. Please be thorough and provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.", label="System message", lines=3),
82
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
83
- gr.Slider(minimum=0.1, maximum=4.0, value=0.8, step=0.1, label="Temperature"),
84
- gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"),
85
- ],
86
- title="Medical AI Assistant",
87
- description="Give me your symptoms and ask me a health problem. The AI will provide informative answers. If the AI doesn't know the answer, it will advise seeking professional help.",
88
-
89
- examples=[["I'm a 35-year-old male and for the past few months, I've been experiencing fatigue, increased sensitivity to cold, and dry, itchy skin. Could these symptoms be related to hypothyroidism?"], ["I have a headache and a fever. What should I do?"], ["How can I improve my sleep?"]],
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  )
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  if __name__ == "__main__":
94
- demo.launch()
 
1
+ import os
2
  import gradio as gr
 
 
3
  import torch
4
+ import spaces
5
+
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
7
+
8
+
9
+ model_name = "ruslanmv/Medical-Llama3-8B"
10
+
11
  IS_SPACES_ZERO = os.environ.get("SPACES_ZERO_GPU", "0") == "1"
12
  IS_SPACE = os.environ.get("SPACE_ID", None) is not None
13
 
14
+ print(f"Running in Hugging Face Space: {IS_SPACE}")
15
+ print(f"Running with ZeroGPU: {IS_SPACES_ZERO}")
16
+ print(f"CUDA available: {torch.cuda.is_available()}")
17
+
18
+ tokenizer = AutoTokenizer.from_pretrained(
19
+ model_name,
20
+ trust_remote_code=True,
21
+ )
22
+
23
+ if tokenizer.pad_token is None:
24
+ tokenizer.pad_token = tokenizer.eos_token
25
+
26
+ model = None
27
+
28
+
29
+ def load_model():
30
+ global model
31
+
32
+ if model is not None:
33
+ return model
34
+
35
+ print("Loading model...")
36
+
37
+ if torch.cuda.is_available():
38
+ quantization_config = BitsAndBytesConfig(
39
+ load_in_4bit=True,
40
+ bnb_4bit_compute_dtype=torch.bfloat16,
41
+ bnb_4bit_use_double_quant=True,
42
+ bnb_4bit_quant_type="nf4",
43
+ )
44
+
45
+ model = AutoModelForCausalLM.from_pretrained(
46
+ model_name,
47
+ quantization_config=quantization_config,
48
+ device_map="auto",
49
+ torch_dtype=torch.bfloat16,
50
+ trust_remote_code=True,
51
+ )
52
+ else:
53
+ model = AutoModelForCausalLM.from_pretrained(
54
+ model_name,
55
+ torch_dtype=torch.float32,
56
+ trust_remote_code=True,
57
+ )
58
+
59
+ model.eval()
60
+ print("Model loaded.")
61
+ return model
62
+
63
+
64
+ @spaces.GPU(duration=120)
65
+ def askme(symptoms, question):
66
+ try:
67
+ current_model = load_model()
68
+
69
+ sys_message = """
70
+ You are an AI Medical Assistant trained on a vast dataset of health information.
71
+ Please be thorough and provide an informative answer.
72
+ If you don't know the answer to a specific medical inquiry, advise seeking professional help.
73
+ Always remind users that your answer is not a substitute for professional medical advice.
74
+ """
75
+
76
+ symptoms = symptoms.strip() if symptoms else ""
77
+ question = question.strip() if question else ""
78
+
79
+ if not symptoms and not question:
80
+ return "Please enter your symptoms and/or medical question."
81
+
82
+ content = f"Symptoms: {symptoms}\n\nQuestion: {question}"
83
+
84
+ messages = [
85
+ {"role": "system", "content": sys_message},
86
+ {"role": "user", "content": content},
87
+ ]
88
+
89
+ prompt = tokenizer.apply_chat_template(
90
+ messages,
91
+ tokenize=False,
92
+ add_generation_prompt=True,
93
+ )
94
+
95
+ inputs = tokenizer(
96
+ prompt,
97
+ return_tensors="pt",
98
+ padding=True,
99
+ truncation=True,
100
+ max_length=2048,
101
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
+ inputs = {
104
+ key: value.to(current_model.device)
105
+ for key, value in inputs.items()
106
+ }
107
+
108
+ with torch.no_grad():
109
+ outputs = current_model.generate(
110
+ **inputs,
111
+ max_new_tokens=300,
112
+ do_sample=True,
113
+ temperature=0.7,
114
+ top_p=0.9,
115
+ repetition_penalty=1.1,
116
+ pad_token_id=tokenizer.eos_token_id,
117
+ eos_token_id=tokenizer.eos_token_id,
118
+ use_cache=True,
119
+ )
120
+
121
+ generated_ids = outputs[0][inputs["input_ids"].shape[-1]:]
122
+
123
+ answer = tokenizer.decode(
124
+ generated_ids,
125
+ skip_special_tokens=True,
126
+ ).strip()
127
+
128
+ if not answer:
129
+ answer = "I could not generate a response. Please try rephrasing your question."
130
+
131
+ return answer
132
+
133
+ except Exception as e:
134
+ return f"Error: {type(e).__name__}: {str(e)}"
135
+
136
+
137
+ symptoms_example = """
138
+ I'm a 35-year-old male and for the past few months, I've been experiencing fatigue,
139
+ increased sensitivity to cold, and dry, itchy skin.
140
+ """
141
+
142
+ question_example = """
143
+ Could these symptoms be related to hypothyroidism?
144
+ If so, what steps should I take to get a proper diagnosis and discuss treatment options?
145
+ """
146
+
147
+ examples = [
148
+ [symptoms_example, question_example]
149
+ ]
150
+
151
+
152
+ css = """
153
+ .gradio-container {
154
+ font-family: "IBM Plex Sans", sans-serif;
155
+ background-color: #212529;
156
+ color: #fff;
157
+ background-image: url("https://huggingface.co/spaces/ruslanmv/AI-Medical-Chatbot/resolve/main/notebook/local/img/background.jpg");
158
+ background-size: cover;
159
+ background-position: center;
160
+ }
161
+
162
+ .gr-button {
163
+ color: white;
164
+ background: #007bff;
165
+ white-space: nowrap;
166
+ border: none;
167
+ padding: 10px 20px;
168
+ border-radius: 8px;
169
+ cursor: pointer;
170
+ }
171
+
172
+ .gr-button:hover {
173
+ background-color: #0056b3;
174
+ }
175
+
176
+ .gradio-textbox textarea {
177
+ background-color: #343a40;
178
+ color: #fff;
179
+ border-color: #343a40;
180
+ border-radius: 8px;
181
+ }
182
+ """
183
+
184
+
185
+ welcome_message = """
186
+ # AI Medical Llama 3 Chatbot
187
+
188
+ Ask any medical question by first giving your symptoms.
189
+
190
+ Developed by Ruslan Magana. Visit [https://ruslanmv.com/](https://ruslanmv.com/) for more information.
191
+
192
+ **Disclaimer:** This chatbot is for educational purposes only and is not a substitute for professional medical advice, diagnosis, or treatment.
193
+ """
194
+
195
+
196
+ symptoms_input = gr.Textbox(
197
+ label="Symptoms",
198
+ placeholder="Enter your symptoms here...",
199
+ lines=6,
200
  )
201
 
202
+ question_input = gr.Textbox(
203
+ label="Question",
204
+ placeholder="Enter your medical question here...",
205
+ lines=4,
206
+ )
207
+
208
+ answer_output = gr.Textbox(
209
+ label="Answer",
210
+ lines=12,
211
+ )
212
+
213
+
214
+ iface = gr.Interface(
215
+ fn=askme,
216
+ inputs=[symptoms_input, question_input],
217
+ outputs=answer_output,
218
+ examples=examples,
219
+ cache_examples=False,
220
+ css=css,
221
+ title="AI Medical Llama 3 Chatbot",
222
+ description=welcome_message,
223
+ )
224
+
225
+
226
  if __name__ == "__main__":
227
+ iface.launch()