mac
chat template
b21e239
raw
history blame
4.75 kB
# MiniCPM5-1B Demo
from pathlib import Path
import os
import time
import logging
import threading
import gradio as gr
import spaces
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from utils_chatbot import organize_messages_from_messages, stream2display_text
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hub repository id used for both the tokenizer and the model weights.
MODEL_PATH = "openbmb/MiniCPM5-1B"

# Authenticate against the Hugging Face Hub only when a token is provided via
# the environment; otherwise continue anonymously and warn.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    logger.info("Logged in to Hugging Face Hub")
else:
    # The dash in this message was previously mojibake (a mis-decoded em dash).
    logger.warning("HF_TOKEN not set — private/gated models will be inaccessible")
# Load the tokenizer and the bfloat16 model once, at import time, so every
# request reuses the same objects.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
# Move the loaded weights to the CUDA device.
model = model.to("cuda")
@spaces.GPU(duration=60)
def gpu_generate_stream(inputs, history, temperature, top_p):
    """Generate a reply and stream it into ``history`` as it is produced.

    Args:
        inputs: Chat messages handed to ``tokenizer.apply_chat_template``.
        history: List of ``{"role", "content"}`` dicts. Mutated in place: an
            assistant entry is appended and its content is updated as text
            arrives.
        temperature: Sampling temperature. A value ``<= 0`` selects greedy
            decoding (``do_sample=False``).
        top_p: Nucleus-sampling threshold; only used when sampling.

    Yields:
        ``("", history)`` after every update. The empty string is the first
        output value (the caller wires it to the input box).

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread is
            re-raised here once the stream has been drained.
    """
    prompt_text = tokenizer.apply_chat_template(
        inputs,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([prompt_text], return_tensors="pt").to("cuda")

    # Show an empty assistant turn immediately, before any text is generated.
    history.append({"role": "assistant", "content": ""})
    yield "", history

    # Special tokens are kept in the stream; "<|im_end|>" is stripped from the
    # final text at the end of this function.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=False,
    )
    gen_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=4096,
    )
    if temperature > 0:
        gen_kwargs.update(temperature=temperature, top_p=top_p, do_sample=True)
    else:
        gen_kwargs.update(do_sample=False)

    worker_errors = []

    def _run_generation():
        """Run ``model.generate``; on failure unblock the consumer loop."""
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            logger.exception("Generation failed")
            worker_errors.append(exc)
            # Without this the stop signal is never queued and the loop over
            # `streamer` below would block forever.
            streamer.end()

    # Generation runs in a worker thread so this generator can consume the
    # streamer concurrently.
    thread = threading.Thread(target=_run_generation)
    thread.start()

    stream_text = ""
    # Counts non-empty streamer chunks; used as the "token" count for the
    # displayed speed.
    gen_tk_count = 0
    start_time = time.time()
    for new_token_text in streamer:
        if not new_token_text:
            continue
        stream_text += new_token_text
        gen_tk_count += 1
        elapsed = time.time() - start_time
        token_per_sec = gen_tk_count / elapsed if elapsed > 0 else 0
        display_text = stream2display_text(stream_text, token_per_sec)
        history[-1]["content"] = display_text
        yield "", history

    thread.join()
    if worker_errors:
        # Surface the worker failure to the caller instead of silently
        # finishing with a partial reply.
        raise worker_errors[0]

    # Replace the in-progress display text with the raw reply, minus the
    # end-of-turn marker.
    history[-1]["content"] = stream_text.replace("<|im_end|>", "")
    yield "", history
def gen_response_stream(message, history, temperature, top_p):
    """Handle one submitted user message and stream the assistant's reply.

    Args:
        message: Text submitted by the user.
        history: Chat history list; the user turn is appended in place.
        temperature: Sampling temperature forwarded to the generator.
        top_p: Nucleus-sampling threshold forwarded to the generator.

    Yields:
        Every value produced by ``gpu_generate_stream``.
    """
    # The model-facing messages are built from the history as it was *before*
    # the new user turn is appended to it.
    model_messages = organize_messages_from_messages(message, history)

    user_turn = {"role": "user", "content": message}
    history.append(user_turn)

    yield from gpu_generate_stream(
        model_messages,
        history,
        temperature=temperature,
        top_p=top_p,
    )
def create_app():
    """Build the Gradio Blocks UI for the chat demo.

    The layout is a single row with a narrow settings column (logo, sampling
    sliders, clear button) and a wide chat column (chat history, input box).

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    # Register ./assets (relative to the working directory) as a static path
    # so the logo image can be referenced from the HTML below.
    static_dir = Path.cwd().absolute() / "assets"
    gr.set_static_paths(paths=[static_dir])

    logo_html = (
        '<div class="logo-container">'
        '<img src="/gradio_api/file=assets/OpenBMB-MiniCPM.png" alt="MiniCPM Logo">'
        "</div>"
    )

    with gr.Blocks() as app:
        with gr.Row():
            # Left column: branding and generation settings.
            with gr.Column(scale=1):
                gr.HTML(logo_html)
                gr.HTML("<div style='height:1px;'></div>")
                temperature_slider = gr.Slider(
                    label="Temperature",
                    minimum=0,
                    maximum=1,
                    value=0.6,
                    step=0.05,
                )
                top_p_slider = gr.Slider(
                    label="Top-p",
                    minimum=0,
                    maximum=1,
                    value=0.95,
                    step=0.01,
                )
                gr.HTML("<div style='height:128px;'></div>")
                clear_button = gr.Button("Clear History")

            # Right column: conversation view and message input.
            with gr.Column(scale=4):
                chat_view = gr.Chatbot(
                    label="Chat History",
                    placeholder="Input to start a new chat",
                    height=500,
                )
                input_box = gr.Textbox(
                    label="Input Text",
                    placeholder="Type your message here...",
                    lines=1,
                    elem_classes=["input-box"],
                )

        # Submitting the textbox streams a reply; the handler's outputs update
        # both the textbox and the chat history.
        input_box.submit(
            gen_response_stream,
            inputs=[input_box, chat_view, temperature_slider, top_p_slider],
            outputs=[input_box, chat_view],
        )
        # Clearing sets the chatbot value to None, bypassing the queue.
        clear_button.click(lambda: None, None, chat_view, queue=False)

    return app
# Visual theme passed to `demo.launch`: Gradio's Soft theme with a blue/gray/
# slate palette and Inter as the preferred font, falling back to Arial and
# then the generic sans-serif family.
THEME = gr.themes.Soft(
    font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
    primary_hue="blue",
    secondary_hue="gray",
    neutral_hue="slate",
)
# Custom stylesheet passed to `demo.launch`. It styles the `logo-container`
# wrapper emitted by `create_app` and the `input-box` class attached to the
# message textbox.
CSS = """
.logo-container {
text-align: center;
margin: 0.5rem 0 1rem 0;
}
.logo-container img {
height: 96px;
width: auto;
max-width: 200px;
display: inline-block;
}
.input-box {
border: 1px solid #2f63b8;
border-radius: 8px;
}
"""
# The UI is constructed at import time and exposed as the module-level `demo`.
demo = create_app()

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(css=CSS, theme=THEME)