"""Gradio chat UI that streams replies from an Unsloth-loaded language model."""

import io
import os
from contextlib import redirect_stdout
from threading import Thread
from typing import Iterator, List, Optional

# Third-party imports are kept in their original relative order.
import gradio as gr
from unsloth import FastLanguageModel
from huggingface_hub import snapshot_download
from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
import torch

# Ensure environment variables are set for the model (optional).
# NOTE(review): these are assigned after the third-party imports, exactly as in
# the original script. huggingface_hub may read HF_HUB_ENABLE_HF_TRANSFER at
# import time, in which case this assignment has no effect -- confirm, and
# confirm `hf_transfer` is installed before moving it above the imports.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

# Define constants and model name
MODEL_NAME = "thanawatpi/appherb-treatment-pred-beta"

# Define the preset Alpaca prompt (you can change the wording if needed)
ALPACA_PROMPT = """
You are an AI assistant trained to provide helpful, accurate, and friendly responses. Your tone should be polite, clear, and concise.
If you don't know the answer, respond honestly and suggest alternatives or request clarification.
For this session, your role is to assist the user with general queries, help with problem-solving, and guide them through different topics.
"""

# Initialize the model and tokenizer
print("Loading model ... Please wait 1 more minute! \n...")
with redirect_stdout(io.StringIO()):
    # stdout is captured into a throwaway buffer while the model loads.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=None,
        dtype=None,
        load_in_4bit=True,
    )
    FastLanguageModel.for_inference(model)


class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when the last generated token is a stop id."""

    def __init__(self, stop_token_ids):
        """Store the integer token ids that should end generation.

        ``None`` entries (e.g. a tokenizer without an EOS id) are dropped.
        """
        self.stop_token_ids = frozenset(
            token_id for token_id in stop_token_ids if token_id is not None
        )

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        """Return True when the newest token of the first sequence is a stop id."""
        return input_ids[0][-1].item() in self.stop_token_ids


def _history_to_messages(history: Optional[list]) -> List[dict]:
    """Flatten chat history into ``{"role", "content"}`` dicts, skipping empties.

    Accepts both ``[user, assistant]`` pairs and ready-made role/content dicts.
    """
    messages = []
    for entry in history or []:
        if isinstance(entry, dict):
            role, content = entry.get("role"), entry.get("content")
            if role and content:
                messages.append({"role": role, "content": content})
            continue
        user_text, assistant_text = entry[0], entry[1]
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    return messages


def async_process_chatbot(
    message,
    history,
    system_message: Optional[str] = None,
    max_new_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: Optional[float] = None,
) -> Iterator[str]:
    """Stream the model's reply to ``message``.

    Generation runs on a background thread; this generator yields the
    *cumulative* reply text each time the streamer produces a new chunk.

    Args:
        message: The new user message.
        history: Earlier turns (pairs or role/content dicts); may be None.
        system_message: Optional system prompt placed first in the conversation.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Optional nucleus-sampling threshold; omitted when None.

    Yields:
        The reply generated so far, with a trailing EOS marker stripped.

    Raises:
        Exception: Whatever ``model.generate`` raised on the worker thread.
    """
    eos_token = tokenizer.eos_token
    # The criterion compares integer token ids, so it needs the EOS *id*,
    # not the EOS string.
    stop_on_tokens = StopOnTokens([tokenizer.eos_token_id])
    text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    # Check if CUDA is available, otherwise use CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(device, non_blocking=True)

    generation_kwargs = dict(
        input_ids=input_ids,
        streamer=text_streamer,
        max_new_tokens=max_new_tokens,
        stopping_criteria=StoppingCriteriaList([stop_on_tokens]),
        temperature=temperature,
        do_sample=True,
    )
    if top_p is not None:
        generation_kwargs["top_p"] = top_p

    errors = []

    def _generate():
        # If generation fails, the streamer would never receive its end signal
        # and the consumer loop below would block; end the stream explicitly.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            text_streamer.on_finalized_text("", stream_end=True)

    thread = Thread(target=_generate)
    thread.start()

    generated_text = ""
    for new_text in text_streamer:
        if eos_token and new_text.endswith(eos_token):
            new_text = new_text[: len(new_text) - len(eos_token)]
        generated_text += new_text
        yield generated_text

    thread.join()
    if errors:
        raise errors[0]


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Gradio callback: stream the assistant's reply for one chat turn.

    Args:
        message: The user's new message.
        history: Previous turns as supplied by the chat UI.
        system_message: System prompt; ``ALPACA_PROMPT`` is used when blank.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.

    Yields:
        The reply text generated so far, or a fixed error message on failure.
    """
    # Add preset Alpaca prompt as the system message if it's not provided
    if not system_message:
        system_message = ALPACA_PROMPT

    try:
        # async_process_chatbot already yields the cumulative text, so forward
        # it unchanged; accumulating again would duplicate earlier chunks.
        for partial_reply in async_process_chatbot(
            message,
            history,
            system_message=system_message,
            max_new_tokens=int(max_tokens),
            temperature=temperature,
            top_p=top_p,
        ):
            yield partial_reply
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing the app.
        print(f"Error during API request: {e}")
        yield "An error occurred during the request."


# Setup the Gradio interface. ChatInterface supplies the message box and the
# chat history itself; the remaining `respond` parameters are extra inputs.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(
            value=ALPACA_PROMPT,
            label="System Message (Leave blank to use default)",
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

# Launch the app
if __name__ == "__main__":
    demo.launch()