# MiniCPM5-1B Demo
# (Captured from the Hugging Face Spaces file viewer; Space status at capture: "Runtime error"; file size: 2,408 bytes.)
import os
import logging
import threading
from typing import Generator
import spaces
import torch
from fastapi.responses import HTMLResponse
from gradio import Server
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from utils_chatbot import organize_messages
# Configure logging once at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hub repository id of the checkpoint loaded below.
MODEL_PATH = "openbmb/MiniCPM5-1B"

# Authenticate against the Hugging Face Hub only when a token is supplied via
# the environment; without it the app still starts, but warns the operator.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    logger.info("Logged in to Hugging Face Hub")
else:
    # The dash below replaces a mis-encoded character in the original message.
    logger.warning("HF_TOKEN not set — private/gated models will be inaccessible")
# Load the tokenizer and the causal LM for MODEL_PATH once, at import time.
# trust_remote_code=True is passed to both loaders, as in the model load below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, torch_dtype=torch.bfloat16, trust_remote_code=True
)
# Move the bfloat16 weights onto the CUDA device used by predict().
model = model.to("cuda")

# Application object that the API endpoint and the homepage route attach to.
demo = Server()
@demo.api()
@spaces.GPU(duration=60)
def predict(
    message: str,
    history: list[list] | None = None,
    thinking_mode: bool = True,
    temperature: float = 0.9,
    top_p: float = 0.95,
) -> Generator[str, None, None]:
    """Stream the model's reply to ``message`` as progressively longer text.

    Args:
        message: The new user message.
        history: Earlier conversation turns, passed straight to
            ``organize_messages`` (exact schema is defined by that helper --
            TODO confirm against ``utils_chatbot``).
        thinking_mode: Forwarded to the chat template as ``enable_thinking``.
        temperature: Sampling temperature. Values ``<= 0`` switch to greedy
            decoding (``do_sample=False``).
        top_p: Nucleus-sampling threshold; only used when sampling.

    Yields:
        The accumulated generated text so far (not just the newest piece),
        once per non-empty chunk produced by the streamer. Special tokens are
        kept in the output (``skip_special_tokens=False``).

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread is
            re-raised here after the stream has been drained.
    """
    messages = organize_messages(message, history)
    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=thinking_mode,
    )
    model_inputs = tokenizer([prompt_text], return_tensors="pt").to("cuda")

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=False,
    )
    gen_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=4096,
    )
    if temperature > 0:
        gen_kwargs.update(temperature=temperature, top_p=top_p, do_sample=True)
    else:
        gen_kwargs.update(do_sample=False)

    # Errors raised inside the worker thread are collected here so they can be
    # surfaced to the caller instead of disappearing with the thread.
    generation_errors: list[Exception] = []

    def _generate() -> None:
        """Run generation; on failure, record the error and end the stream."""
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:  # thread boundary: log, record, re-raise later
            logger.exception("Generation failed")
            generation_errors.append(exc)
            # Without this the streamer never receives its end-of-stream
            # signal and the consumer loop below would block forever.
            streamer.end()

    # generate() blocks until completion, so it runs in a background thread
    # while this generator consumes decoded text from the streamer.
    thread = threading.Thread(target=_generate)
    thread.start()

    full_text = ""
    for new_token_text in streamer:
        if not new_token_text:
            continue
        full_text += new_token_text
        yield full_text
    thread.join()

    if generation_errors:
        raise generation_errors[0]
@demo.get("/", response_class=HTMLResponse)
async def homepage():
    """Return the contents of ``index.html`` located next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(module_dir, "index.html"), "r", encoding="utf-8") as page:
        html = page.read()
    return html
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    launch_options = {"show_error": True}
    demo.launch(**launch_options)
|