Spaces:
Runtime error
Runtime error
| # MiniCPM5-1B Demo | |
| from pathlib import Path | |
| import os | |
| import time | |
| import logging | |
| import threading | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from huggingface_hub import login | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from utils_chatbot import organize_messages_from_messages, stream2display_text | |
# Configure logging once at import time; every module logger inherits INFO.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face Hub repository id of the model served by this demo.
MODEL_PATH = "openbmb/MiniCPM5-1B"

# Authenticate against the Hub only when a token is supplied through the
# environment; otherwise continue anonymously and warn.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    logger.info("Logged in to Hugging Face Hub")
else:
    logger.warning("HF_TOKEN not set - private/gated models will be inaccessible")

# Tokenizer and model are loaded eagerly at import time and shared by every
# request. NOTE(review): trust_remote_code=True lets the model repository
# supply its own Python code - confirm the repository is trusted.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to("cuda")
# NOTE(review): `spaces` is imported by this file but was never applied to any
# function. ZeroGPU hardware requires the GPU entry point to be decorated; the
# decorator is documented as a no-op elsewhere. Confirm the target hardware.
@spaces.GPU
def gpu_generate_stream(inputs, history, temperature, top_p):
    """Stream a model reply into the last entry of ``history``.

    ``model.generate`` runs on a background thread and feeds a
    ``TextIteratorStreamer``; this generator consumes the streamer and yields
    one UI update per received text chunk.

    Args:
        inputs: Chat messages handed to ``tokenizer.apply_chat_template``.
        history: List of ``{"role", "content"}`` dicts. An empty assistant
            entry is appended and its ``content`` is overwritten as text
            arrives. The list is mutated in place.
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            decoding.
        top_p: Nucleus-sampling threshold; only used when sampling.

    Yields:
        ``("", history)`` tuples. The caller wires these to the prompt textbox
        and the chatbot, so the empty string clears the textbox.

    Raises:
        Exception: Whatever ``model.generate`` raised on the worker thread,
            re-raised here once the streamer has been drained.
    """
    # UI sliders may deliver whole numbers as ints (1 rather than 1.0), and
    # some transformers versions reject a non-float temperature. Normalise.
    temperature = float(temperature)
    top_p = float(top_p)

    prompt_text = tokenizer.apply_chat_template(
        inputs,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([prompt_text], return_tensors="pt").to("cuda")

    # Show an empty assistant bubble immediately, before the first chunk.
    history.append({"role": "assistant", "content": ""})
    yield "", history

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=False,
    )
    gen_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=4096,
    )
    if temperature > 0:
        gen_kwargs.update(temperature=temperature, top_p=top_p, do_sample=True)
    else:
        gen_kwargs.update(do_sample=False)

    failures = []

    def _run_generation():
        """Run generation; on failure, unblock the consumer loop below."""
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:  # thread boundary: recorded and re-raised below
            logger.exception("Generation failed")
            failures.append(exc)
            # Without this the streamer never receives its stop signal and
            # the `for ... in streamer` loop would block forever.
            streamer.end()

    thread = threading.Thread(target=_run_generation)
    thread.start()

    stream_text = ""
    # Counts streamed text chunks; used as an approximate token count for the
    # throughput figure passed to stream2display_text.
    gen_tk_count = 0
    start_time = time.monotonic()  # monotonic: immune to wall-clock jumps
    for new_token_text in streamer:
        if not new_token_text:
            continue
        stream_text += new_token_text
        gen_tk_count += 1
        elapsed = time.monotonic() - start_time
        token_per_sec = gen_tk_count / elapsed if elapsed > 0 else 0
        history[-1]["content"] = stream2display_text(stream_text, token_per_sec)
        yield "", history
    thread.join()

    if failures:
        raise failures[0]

    # Final update: store the raw text with the end-of-turn marker removed
    # (special tokens are kept by the streamer, see skip_special_tokens above).
    history[-1]["content"] = stream_text.replace("<|im_end|>", "")
    yield "", history
def gen_response_stream(message, history, temperature, top_p):
    """Handle one prompt submission and stream the assistant's reply.

    Args:
        message: Text the user submitted.
        history: Chat history as a list of ``{"role", "content"}`` dicts.
            ``None`` is treated as an empty history.
        temperature: Sampling temperature forwarded to the generator.
        top_p: Nucleus-sampling threshold forwarded to the generator.

    Yields:
        ``("", history)`` tuples as produced by ``gpu_generate_stream``. A
        blank submission yields the unchanged history once and stops.
    """
    if history is None:
        history = []

    # Ignore empty / whitespace-only submissions instead of sending an empty
    # user turn to the model.
    if message is None or (isinstance(message, str) and not message.strip()):
        yield "", history
        return

    # Build the model input from the history *before* the new user turn is
    # appended; the helper receives the new message separately.
    chat_msg_ls = organize_messages_from_messages(message, history)
    history.append({"role": "user", "content": message})
    yield from gpu_generate_stream(
        chat_msg_ls,
        history,
        temperature=temperature,
        top_p=top_p,
    )
def create_app():
    """Assemble the chat UI and return the (not yet launched) Blocks app.

    Layout: a narrow left column with the logo, the sampling sliders and a
    clear button, and a wide right column with the chat history and the
    prompt textbox. Submitting the textbox streams a reply via
    ``gen_response_stream``.
    """
    # Let Gradio serve files from ./assets (relative to the working dir).
    static_dir = Path.cwd().absolute().joinpath("assets")
    gr.set_static_paths(paths=[static_dir])

    logo_markup = (
        '<div class="logo-container">'
        '<img src="/gradio_api/file=assets/OpenBMB-MiniCPM.png" alt="MiniCPM Logo">'
        "</div>"
    )

    with gr.Blocks() as blocks:
        with gr.Row():
            # Left column: branding and generation controls.
            with gr.Column(scale=1):
                gr.HTML(logo_markup)
                gr.HTML("<div style='height:1px;'></div>")
                temperature_slider = gr.Slider(
                    label="Temperature",
                    minimum=0,
                    maximum=1,
                    step=0.05,
                    value=0.6,
                )
                top_p_slider = gr.Slider(
                    label="Top-p",
                    minimum=0,
                    maximum=1,
                    step=0.01,
                    value=0.95,
                )
                gr.HTML("<div style='height:128px;'></div>")
                clear_button = gr.Button("Clear History")

            # Right column: conversation and input.
            with gr.Column(scale=4):
                chat_view = gr.Chatbot(
                    label="Chat History",
                    placeholder="Input to start a new chat",
                    height=500,
                )
                prompt_box = gr.Textbox(
                    label="Input Text",
                    placeholder="Type your message here...",
                    lines=1,
                    elem_classes=["input-box"],
                )

        # The handler yields (textbox_value, history) pairs, hence two outputs.
        prompt_box.submit(
            gen_response_stream,
            inputs=[prompt_box, chat_view, temperature_slider, top_p_slider],
            outputs=[prompt_box, chat_view],
        )
        # Reset the chatbot by handing it None; bypasses the queue.
        clear_button.click(lambda: None, None, chat_view, queue=False)

    return blocks
# Visual theme handed to launch(): Soft base with a blue/gray/slate palette
# and Inter as the preferred font.
THEME = gr.themes.Soft(
    font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
    primary_hue="blue",
    secondary_hue="gray",
    neutral_hue="slate",
)

# Custom styles for the logo block and the prompt textbox (the class names
# match those used in create_app).
CSS = """
.logo-container {
text-align: center;
margin: 0.5rem 0 1rem 0;
}
.logo-container img {
height: 96px;
width: auto;
max-width: 200px;
display: inline-block;
}
.input-box {
border: 1px solid #2f63b8;
border-radius: 8px;
}
"""

# Built at import time so the app object is available as a module attribute.
demo = create_app()

if __name__ == "__main__":
    demo.launch(css=CSS, theme=THEME)