Instructions to use AXERA-TECH/HY-MT1.5-1.8B_GPTQ_INT4-AX620E with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use AXERA-TECH/HY-MT1.5-1.8B_GPTQ_INT4-AX620E with Transformers:
```python
# Use a pipeline as a high-level helper
# Warning: Pipeline type "translation" is no longer supported in transformers v5.
# You must load the model directly (see below) or downgrade to v4.x with:
# pip install "transformers<5.0.0"
from transformers import pipeline

pipe = pipeline("translation", model="AXERA-TECH/HY-MT1.5-1.8B_GPTQ_INT4-AX620E")

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("AXERA-TECH/HY-MT1.5-1.8B_GPTQ_INT4-AX620E", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import argparse | |
| import os | |
| import socket | |
| import numpy as np | |
| import gradio as gr | |
| from transformers import AutoConfig, AutoTokenizer | |
| from ml_dtypes import bfloat16 | |
| from utils.infer_func import InferManager | |
# Target languages offered in the UI dropdown. The order here is the order
# shown to the user.
DEFAULT_LANGUAGES = [
    "English", "Chinese", "Japanese", "Korean",
    "French", "German", "Spanish", "Italian",
    "Portuguese", "Russian", "Arabic", "Hindi",
    "Bengali", "Thai", "Vietnamese", "Indonesian",
    "Turkish", "Polish", "Dutch", "Swedish",
    "Danish", "Norwegian", "Finnish", "Greek",
    "Czech", "Hungarian", "Romanian", "Ukrainian",
    "Malay", "Filipino", "Urdu", "Hebrew",
    "Persian",
]
def build_prompt(source_text: str, target_language: str, use_zh_template: bool) -> str:
    """Compose the translation prompt sent to the model.

    Args:
        source_text: Text to be translated; appended verbatim after the
            instruction line.
        target_language: Name of the language to translate into; it is
            interpolated into the instruction.
        use_zh_template: When true, the instruction line is written in
            Chinese; otherwise it is written in English.

    Returns:
        The instruction line, a newline, then ``source_text``.
    """
    if use_zh_template:
        instruction = f"将以下文本翻译为{target_language},注意只需要输出翻译后的结果,不要额外解释:"
    else:
        instruction = (
            f"Translate the following segment into {target_language}, "
            "without additional explanation."
        )
    return f"{instruction}\n{source_text}"
def _get_ipv4_address() -> str:
    """Return the local IPv4 address the OS would use for outbound traffic.

    A UDP socket is "connected" to a public address purely so the OS selects
    an outgoing interface; the local end of that socket is then read back
    with ``getsockname``.

    Returns:
        The detected IPv4 address, or ``"127.0.0.1"`` when the probe fails
        (for example, when no route is available).
    """
    try:
        # The context manager closes the socket even if connect() or
        # getsockname() raises, so the descriptor is never leaked.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.connect(("8.8.8.8", 80))
            return probe.getsockname()[0]
    except OSError:
        # Socket-level failures are all OSError subclasses; fall back to
        # loopback rather than failing the whole demo over a cosmetic URL.
        return "127.0.0.1"
def create_demo(hf_model: str, axmodel_path: str, max_seq_len: int = 2047):
    """Build the Gradio translation demo (without launching it).

    Args:
        hf_model: Path or identifier passed to ``AutoTokenizer`` and
            ``AutoConfig`` ``from_pretrained``.
        axmodel_path: Directory holding the compiled axmodel files; it must
            contain ``model.embed_tokens.weight.npy``.
        max_seq_len: Forwarded to ``InferManager`` as ``max_seq_len``.

    Returns:
        The assembled ``gr.Blocks`` application.

    Raises:
        FileNotFoundError: If the embeddings file is missing from
            ``axmodel_path``.
    """
    embeds_path = os.path.join(axmodel_path, "model.embed_tokens.weight.npy")
    if not os.path.exists(embeds_path):
        raise FileNotFoundError(f"Missing embeddings file: {embeds_path}")

    tokenizer = AutoTokenizer.from_pretrained(hf_model)
    config = AutoConfig.from_pretrained(hf_model, trust_remote_code=True)
    # Embedding table indexed by token id along axis 0 (see np.take below).
    embeds = np.load(embeds_path)

    # Only a multi-element EOS list from the config is forwarded; in every
    # other case None is passed to the decoder.
    # NOTE(review): presumably InferManager then falls back to its own EOS
    # handling -- confirm against utils.infer_func.
    eos_token_id = None
    if isinstance(config.eos_token_id, list) and len(config.eos_token_id) > 1:
        eos_token_id = config.eos_token_id

    imer = InferManager(config, axmodel_path, max_seq_len=max_seq_len)

    # Slice length shared by the prefill and decode calls.
    slice_len = 128

    def _sampling_kwargs(temperature, top_p, top_k, repetition_penalty):
        """Bundle the sampling options shared by prefill and decode."""
        return {
            # Gradio documents slider values as floats; top_k is a count,
            # so hand the runtime a real int.
            "top_k": int(top_k),
            "top_p": top_p,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
        }

    def _prefill_prompt(text, target_language, use_zh_template, sampling):
        """Tokenize the prompt, look up its embeddings and run prefill.

        Returns whatever ``imer.prefill`` returns; callers pass it on to the
        decode step as the token id list.
        """
        prompt = build_prompt(text.strip(), target_language, use_zh_template)
        messages = [{"role": "user", "content": prompt}]
        input_ids = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=False,
            return_tensors="pt",
        )
        token_ids = input_ids[0].cpu().numpy().tolist()
        # One embedding row per prompt token, converted to bfloat16 before
        # being handed to the runtime.
        prefill_data = np.take(embeds, token_ids, axis=0).astype(bfloat16)
        return imer.prefill(
            tokenizer,
            token_ids,
            prefill_data,
            slice_len=slice_len,
            **sampling,
        )

    def translate(
        text,
        target_language,
        use_zh_template,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
        max_new_tokens,
    ):
        """Translate ``text`` in one shot and return the stripped result.

        Non-streaming counterpart of ``translate_stream``; the button below
        is wired to the streaming variant.
        """
        if not text or not text.strip():
            return ""
        sampling = _sampling_kwargs(temperature, top_p, top_k, repetition_penalty)
        token_ids = _prefill_prompt(text, target_language, use_zh_template, sampling)
        output = imer.decode(
            tokenizer,
            token_ids,
            embeds,
            slice_len=slice_len,
            eos_token_id=eos_token_id,
            stream=False,
            max_new_tokens=int(max_new_tokens),
            **sampling,
        )
        return output.strip()

    def translate_stream(
        text,
        target_language,
        use_zh_template,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
        max_new_tokens,
    ):
        """Translate ``text`` and yield each value produced by the decoder.

        Yields a single empty string when ``text`` is empty or whitespace.
        """
        if not text or not text.strip():
            yield ""
            return
        sampling = _sampling_kwargs(temperature, top_p, top_k, repetition_penalty)
        token_ids = _prefill_prompt(text, target_language, use_zh_template, sampling)
        yield from imer.decode_stream(
            tokenizer,
            token_ids,
            embeds,
            slice_len=slice_len,
            eos_token_id=eos_token_id,
            max_new_tokens=int(max_new_tokens),
            **sampling,
        )

    with gr.Blocks(title="HY-MT1.5-1.8B_GPTQ_INT4 Multilingual Translation") as demo:
        gr.Markdown("## HY-MT1.5-1.8B_GPTQ_INT4 Multilingual Translation")

        with gr.Group():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Please enter the text you want to translate...",
                lines=6,
            )

        with gr.Group():
            with gr.Row(equal_height=True):
                target_language = gr.Dropdown(
                    choices=DEFAULT_LANGUAGES,
                    value="English",
                    label="Target Language",
                )
                use_zh_template = gr.Checkbox(
                    label="Use Chinese Prompt Template",
                    value=False,
                )

        with gr.Group():
            with gr.Row(equal_height=True):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.05,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.6,
                    step=0.05,
                    label="Top-p",
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=20,
                    step=1,
                    label="Top-k",
                )

        with gr.Group():
            with gr.Row(equal_height=True):
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=1.5,
                    value=1.05,
                    step=0.01,
                    label="Repetition Penalty",
                )
                max_new_tokens = gr.Slider(
                    minimum=1,
                    maximum=1024,
                    value=512,
                    step=1,
                    label="Max New Tokens",
                )

        translate_btn = gr.Button("Translate", variant="primary")
        output_text = gr.Textbox(
            label="Translation Result",
            lines=6,
            interactive=False,
        )

        # The input order must match the parameter order of translate_stream.
        translate_btn.click(
            translate_stream,
            inputs=[
                input_text,
                target_language,
                use_zh_template,
                temperature,
                top_p,
                top_k,
                repetition_penalty,
                max_new_tokens,
            ],
            outputs=output_text,
        )

    return demo
def parse_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, so existing ``parse_args()``
            callers behave as before; passing a list lets the parser be
            driven programmatically.

    Returns:
        An ``argparse.Namespace`` with ``hf_model``, ``axmodel_path``,
        ``max_seq_len``, ``server_name`` and ``server_port``.
    """
    parser = argparse.ArgumentParser(description="HY-MT1.5-1.8B_GPTQ_INT4 Gradio Demo")
    parser.add_argument(
        "--hf_model",
        type=str,
        default="./hymt1-5_tokenizer",
        help="Path to HuggingFace model",
    )
    parser.add_argument(
        "--axmodel_path",
        type=str,
        default="./hymt1-5_2k_ax650n_axmodel/",
        help="Path to compiled axmodel directory",
    )
    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=2047,
        help="Maximum sequence length for model inference",
    )
    parser.add_argument("--server_name", type=str, default="0.0.0.0")
    parser.add_argument("--server_port", type=int, default=7860)
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse options, build the UI, announce the URL and
    # start the Gradio server.
    args = parse_args()
    app = create_demo(args.hf_model, args.axmodel_path, args.max_seq_len)
    # The detected LAN address is only meaningful when the server binds to
    # every interface; for a specific bind host, show that host so the
    # printed URL is one the server actually listens on.
    if args.server_name in ("0.0.0.0", "::"):
        display_host = _get_ipv4_address()
    else:
        display_host = args.server_name
    print(f"* Running on local URL: http://{display_host}:{args.server_port}")
    app.launch(server_name=args.server_name, server_port=args.server_port)