# yongqiang
# Initial AX620E axllm serve package
# 80ad90c
# Raw
# History Blame Contribute Delete
# 8.67 kB
import argparse
import os
import socket
import numpy as np
import gradio as gr
from transformers import AutoConfig, AutoTokenizer
from ml_dtypes import bfloat16
from utils.infer_func import InferManager
# Choices offered by the "Target Language" dropdown. The selected name is
# interpolated as-is into the translation instruction built by build_prompt().
DEFAULT_LANGUAGES = [
    "English", "Chinese", "Japanese", "Korean", "French", "German",
    "Spanish", "Italian", "Portuguese", "Russian", "Arabic", "Hindi",
    "Bengali", "Thai", "Vietnamese", "Indonesian", "Turkish", "Polish",
    "Dutch", "Swedish", "Danish", "Norwegian", "Finnish", "Greek",
    "Czech", "Hungarian", "Romanian", "Ukrainian", "Malay", "Filipino",
    "Urdu", "Hebrew", "Persian",
]
def build_prompt(source_text: str, target_language: str, use_zh_template: bool) -> str:
    """Compose the single-turn translation instruction sent to the model.

    Args:
        source_text: Text to translate; appended on its own line after the
            instruction.
        target_language: Language name inserted into the instruction.
        use_zh_template: When truthy, use the Chinese-language instruction;
            otherwise use the English one.

    Returns:
        The instruction line, a newline, then ``source_text``.
    """
    # NOTE(review): the Chinese instruction uses ASCII ',' and ':' rather than
    # full-width punctuation -- confirm this matches the model card's template.
    if use_zh_template:
        instruction = (
            f"将以下文本翻译为{target_language},注意只需要输出翻译后的结果,不要额外解释:"
        )
    else:
        instruction = (
            f"Translate the following segment into {target_language}, without additional explanation."
        )
    return f"{instruction}\n{source_text}"
def _get_ipv4_address() -> str:
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))
ip = s.getsockname()[0]
s.close()
return ip
except Exception:
return "127.0.0.1"
def create_demo(hf_model: str, axmodel_path: str, max_seq_len=2047):
    """Build (but do not launch) the Gradio translation UI.

    Args:
        hf_model: Path or hub id given to ``AutoTokenizer.from_pretrained`` and
            ``AutoConfig.from_pretrained``; only the tokenizer and config are
            loaded from it.
        axmodel_path: Directory containing the compiled model files and the
            token-embedding table ``model.embed_tokens.weight.npy``.
        max_seq_len: Forwarded unchanged to ``InferManager``.

    Returns:
        The ``gr.Blocks`` app. Clicking "Translate" streams the translation
        into the result box.

    Raises:
        FileNotFoundError: If the embeddings ``.npy`` file is missing.
    """
    embeds_path = os.path.join(axmodel_path, "model.embed_tokens.weight.npy")
    if not os.path.exists(embeds_path):
        raise FileNotFoundError(f"Missing embeddings file: {embeds_path}")

    tokenizer = AutoTokenizer.from_pretrained(hf_model)
    config = AutoConfig.from_pretrained(hf_model, trust_remote_code=True)
    embeds = np.load(embeds_path)

    # Only a multi-element list of EOS ids is forwarded explicitly. In every
    # other case None is passed and InferManager decides what stops decoding.
    # NOTE(review): presumably it falls back to the tokenizer's EOS id --
    # confirm in utils.infer_func.
    eos_token_id = None
    if isinstance(config.eos_token_id, list) and len(config.eos_token_id) > 1:
        eos_token_id = config.eos_token_id

    imer = InferManager(config, axmodel_path, max_seq_len=max_seq_len)
    # Passed to both prefill and decode; presumably the prefill chunk length
    # the compiled model was built with -- TODO confirm.
    slice_len = 128

    def _prefill_prompt(text, target_language, use_zh_template, sampling):
        """Build the chat prompt for ``text``, run prefill, return its token ids."""
        prompt = build_prompt(text.strip(), target_language, use_zh_template)
        messages = [{"role": "user", "content": prompt}]
        input_ids = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=False,
            return_tensors="pt",
        )
        token_ids = input_ids[0].cpu().numpy().tolist()
        # Embedding lookup happens on the host; rows are cast to bfloat16
        # before being handed to prefill.
        prefill_data = np.take(embeds, token_ids, axis=0).astype(bfloat16)
        # NOTE(review): prompts approaching max_seq_len are not guarded here;
        # behavior then depends on InferManager.prefill -- verify.
        return imer.prefill(
            tokenizer,
            token_ids,
            prefill_data,
            slice_len=slice_len,
            **sampling,
        )

    def translate_stream(
        text,
        target_language,
        use_zh_template,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
        max_new_tokens,
    ):
        """Yield the progressively growing translation of ``text``."""
        if not text or not text.strip():
            yield ""
            return
        # Gradio sliders may deliver numbers as floats; counts must be ints.
        sampling = dict(
            top_k=int(top_k),
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
        )
        token_ids = _prefill_prompt(text, target_language, use_zh_template, sampling)
        yield from imer.decode_stream(
            tokenizer,
            token_ids,
            embeds,
            slice_len=slice_len,
            eos_token_id=eos_token_id,
            max_new_tokens=int(max_new_tokens),
            **sampling,
        )

    with gr.Blocks(title="HY-MT1.5-1.8B_GPTQ_INT4 Multilingual Translation") as demo:
        gr.Markdown("## HY-MT1.5-1.8B_GPTQ_INT4 Multilingual Translation")
        with gr.Group():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Please enter the text you want to translate...",
                lines=6,
            )
        with gr.Group():
            with gr.Row(equal_height=True):
                target_language = gr.Dropdown(
                    choices=DEFAULT_LANGUAGES,
                    value="English",
                    label="Target Language",
                )
                use_zh_template = gr.Checkbox(
                    label="Use Chinese Prompt Template",
                    value=False,
                )
        with gr.Group():
            with gr.Row(equal_height=True):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.05,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.6,
                    step=0.05,
                    label="Top-p",
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=20,
                    step=1,
                    label="Top-k",
                )
        with gr.Group():
            with gr.Row(equal_height=True):
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=1.5,
                    value=1.05,
                    step=0.01,
                    label="Repetition Penalty",
                )
                max_new_tokens = gr.Slider(
                    minimum=1,
                    maximum=1024,
                    value=512,
                    step=1,
                    label="Max New Tokens",
                )
        translate_btn = gr.Button("Translate", variant="primary")
        output_text = gr.Textbox(
            label="Translation Result",
            lines=6,
            interactive=False,
        )
        translate_btn.click(
            translate_stream,
            inputs=[
                input_text,
                target_language,
                use_zh_template,
                temperature,
                top_p,
                top_k,
                repetition_penalty,
                max_new_tokens,
            ],
            outputs=output_text,
        )
    return demo
def parse_args():
    """Parse the demo's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with ``hf_model``, ``axmodel_path``,
        ``max_seq_len``, ``server_name`` and ``server_port``.
    """
    parser = argparse.ArgumentParser(description="HY-MT1.5-1.8B_GPTQ_INT4 Gradio Demo")
    # (flag, type, default, help) -- a help of None is argparse's own default.
    option_table = (
        ("--hf_model", str, "./hymt1-5_tokenizer", "Path to HuggingFace model"),
        ("--axmodel_path", str, "./hymt1-5_2k_ax650n_axmodel/", "Path to compiled axmodel directory"),
        ("--max_seq_len", int, 2047, "Maximum sequence length for model inference"),
        ("--server_name", str, "0.0.0.0", None),
        ("--server_port", int, 7860, None),
    )
    for flag, value_type, default_value, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return parser.parse_args()
if __name__ == "__main__":
    args = parse_args()
    app = create_demo(args.hf_model, args.axmodel_path, args.max_seq_len)
    # Only a wildcard bind is reachable via the machine's LAN address; for a
    # specific bind (e.g. 127.0.0.1) advertise exactly that host instead.
    if args.server_name in ("0.0.0.0", "::"):
        display_host = _get_ipv4_address()
    else:
        display_host = args.server_name
    if ":" in display_host:
        # IPv6 literals must be bracketed inside a URL.
        display_host = f"[{display_host}]"
    print(f"* Running on local URL: http://{display_host}:{args.server_port}")
    app.launch(server_name=args.server_name, server_port=args.server_port)