HY-MT1.5-1.8B_GPTQ_INT4-AX620E / gradio_cpp_backend.py
yongqiang
Initial AX620E axllm serve package
80ad90c
Raw
History Blame Contribute Delete
6.72 kB
import argparse
import socket
import json
import requests
import gradio as gr
# Target languages offered in the "Target Language" dropdown, in display order.
DEFAULT_LANGUAGES = [
    # East Asian and major European languages
    "English", "Chinese", "Japanese", "Korean",
    "French", "German", "Spanish", "Italian", "Portuguese", "Russian",
    # South / Southeast Asian and Middle Eastern languages
    "Arabic", "Hindi", "Bengali", "Thai", "Vietnamese", "Indonesian",
    # Further European languages
    "Turkish", "Polish", "Dutch", "Swedish", "Danish", "Norwegian",
    "Finnish", "Greek", "Czech", "Hungarian", "Romanian", "Ukrainian",
    # Remaining entries
    "Malay", "Filipino", "Urdu", "Hebrew", "Persian",
]
def _get_ipv4_address() -> str:
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))
ip = s.getsockname()[0]
s.close()
return ip
except Exception:
return "127.0.0.1"
def build_prompt(source_text: str, target_language: str, use_zh_template: bool) -> str:
    """Build the translation instruction followed by the text to translate.

    NOTE(review): nothing in the visible part of this file calls this
    function; the request payload carries the raw text instead.

    Args:
        source_text: Text to be translated; appended after the instruction.
        target_language: Language name interpolated into the instruction.
        use_zh_template: When true, use the Chinese instruction wording;
            otherwise use the English wording.

    Returns:
        The instruction line, a newline, then ``source_text``.
    """
    if not use_zh_template:
        instruction = f"Translate the following segment into {target_language}, without additional explanation.\n"
    else:
        instruction = f"将以下文本翻译为{target_language},注意只需要输出翻译后的结果,不要额外解释:\n"
    return f"{instruction}{source_text}"
def _sse_payload(raw_line: bytes) -> str:
    """Decode one streamed line and strip its optional ``data:`` prefix.

    Undecodable bytes are replaced rather than raising, so a damaged chunk
    cannot abort the whole stream. Both ``data: {...}`` and ``data:{...}``
    are accepted; lines without the prefix are returned stripped as-is.
    """
    line = raw_line.decode("utf-8", errors="replace")
    if line.startswith("data:"):
        line = line[len("data:"):]
    return line.strip()


def _delta_content(data: str) -> str:
    """Return ``choices[0].delta.content`` from a JSON chunk, or ``""``.

    Anything that is not a JSON object of the expected shape (invalid JSON,
    empty ``choices``, missing or non-string ``content``) yields ``""`` so
    the caller can simply skip it.
    """
    if not data:
        return ""
    try:
        obj = json.loads(data)
        delta = obj.get("choices", [{}])[0].get("delta", {})
        content = delta.get("content", "")
    except (ValueError, AttributeError, IndexError, KeyError, TypeError):
        # Malformed or unexpected chunk: ignore it and keep streaming.
        return ""
    return content if isinstance(content, str) else ""


def create_demo(api_base: str, model_name: str):
    """Build the Gradio translation UI backed by a chat-completions server.

    Args:
        api_base: Base URL of the backend, e.g. ``http://127.0.0.1:8000``.
            A trailing slash is tolerated.
        model_name: Value sent as the ``model`` field of every request.

    Returns:
        The assembled ``gr.Blocks`` application (not yet launched).
    """
    # Strip trailing slashes so "http://host:8000/" does not produce "//v1/...".
    url = f"{api_base.rstrip('/')}/v1/chat/completions"

    def translate_stream(
        text,
        target_language,
        use_zh_template,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
        max_new_tokens,
    ):
        """Stream a translation from the backend.

        Yields the accumulated (stripped) translation after every received
        piece, so the output box shows the text growing. Blank input yields
        a single empty string without contacting the backend.

        Raises:
            gr.Error: If the HTTP request fails or the stream breaks.
        """
        if not text or not text.strip():
            yield ""
            return

        # The raw text is sent as the user message; the target language and
        # template flag travel as extra fields.
        # NOTE(review): presumably the backend builds the prompt from these
        # two fields - confirm against the server's API.
        payload = {
            "model": model_name,
            "messages": [{"role": "user", "content": text.strip()}],
            "stream": True,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": int(top_k),
            "repetition_penalty": repetition_penalty,
            "max_tokens": int(max_new_tokens),
            "target_language": target_language,
            "use_zh_template": bool(use_zh_template),
        }

        buffer = ""
        try:
            with requests.post(url, json=payload, stream=True, timeout=300) as resp:
                resp.raise_for_status()
                # Lines are decoded by hand, so iterate over raw bytes.
                for raw_line in resp.iter_lines(decode_unicode=False):
                    if not raw_line:
                        continue
                    data = _sse_payload(raw_line)
                    if data == "[DONE]":
                        break
                    piece = _delta_content(data)
                    if piece:
                        buffer += piece
                        yield buffer.strip()
        except requests.RequestException as exc:
            # Show a readable message in the UI instead of a bare traceback.
            raise gr.Error(f"Translation backend request failed: {exc}") from exc

    with gr.Blocks(title="HY-MT1.5-1.8B_GPTQ_INT4 Multilingual Translation (C++ Backend)") as demo:
        gr.Markdown("## HY-MT1.5-1.8B_GPTQ_INT4 Multilingual Translation (C++ Backend)")

        # Source text.
        with gr.Group():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Please enter the text you want to translate...",
                lines=6,
            )

        # Target language and prompt-template choice.
        with gr.Group():
            with gr.Row(equal_height=True):
                target_language = gr.Dropdown(
                    choices=DEFAULT_LANGUAGES,
                    value="English",
                    label="Target Language",
                )
                use_zh_template = gr.Checkbox(
                    label="Use Chinese Prompt Template",
                    value=False,
                )

        # Sampling parameters.
        with gr.Group():
            with gr.Row(equal_height=True):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.05,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.6,
                    step=0.05,
                    label="Top-p",
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=20,
                    step=1,
                    label="Top-k",
                )

        # Repetition control and output length.
        with gr.Group():
            with gr.Row(equal_height=True):
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=1.5,
                    value=1.05,
                    step=0.01,
                    label="Repetition Penalty",
                )
                max_new_tokens = gr.Slider(
                    minimum=1,
                    maximum=1024,
                    value=512,
                    step=1,
                    label="Max New Tokens",
                )

        translate_btn = gr.Button("Translate", variant="primary")
        output_text = gr.Textbox(
            label="Translation Result",
            lines=6,
            interactive=False,
        )

        # The input order must match translate_stream's parameter order.
        translate_btn.click(
            translate_stream,
            inputs=[
                input_text,
                target_language,
                use_zh_template,
                temperature,
                top_p,
                top_k,
                repetition_penalty,
                max_new_tokens,
            ],
            outputs=output_text,
        )

    return demo
def parse_args():
    """Parse the demo's command-line options from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``api_base``, ``model``,
        ``server_name`` and ``server_port`` (int).
    """
    parser = argparse.ArgumentParser(description="HY-MT1.5-1.8B_GPTQ_INT4 Gradio Demo (C++ Backend)")
    # (flag, value type, default) for every supported option.
    option_specs = (
        ("--api_base", str, "http://127.0.0.1:8000"),
        ("--model", str, "AXERA-TECH/HY-MT1.5-1.8B_GPTQ_INT4-AX620E"),
        ("--server_name", str, "0.0.0.0"),
        ("--server_port", int, 7860),
    )
    for flag, value_type, default_value in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value)
    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: read the options, build the UI, then serve it.
    args = parse_args()
    app = create_demo(api_base=args.api_base, model_name=args.model)

    # Print a URL built from this machine's detected IPv4 address and the
    # chosen port before the server starts.
    ipv4 = _get_ipv4_address()
    print(f"* Running on local URL: http://{ipv4}:{args.server_port}")

    app.launch(
        server_name=args.server_name,
        server_port=args.server_port,
    )