Instructions to use AXERA-TECH/HY-MT1.5-1.8B_GPTQ_INT4-AX620E with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use AXERA-TECH/HY-MT1.5-1.8B_GPTQ_INT4-AX620E with Transformers:
# Use a pipeline as a high-level helper # Warning: Pipeline type "translation" is no longer supported in transformers v5. # You must load the model directly (see below) or downgrade to v4.x with: # 'pip install "transformers<5.0.0' from transformers import pipeline pipe = pipeline("translation", model="AXERA-TECH/HY-MT1.5-1.8B_GPTQ_INT4-AX620E")# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("AXERA-TECH/HY-MT1.5-1.8B_GPTQ_INT4-AX620E", dtype="auto") - Notebooks
- Google Colab
- Kaggle
| import torch | |
| import numpy as np | |
| from typing import List, Tuple | |
| from tqdm import tqdm | |
| from axengine import InferenceSession | |
| import os | |
| import re | |
| from ml_dtypes import bfloat16 | |
| # Discover model files automatically from model_dir. | |
| # We expect files like: <prefix>_p128_l<idx>_together.axmodel and <prefix>_post.axmodel | |
| # we try to detect model prefix and layer files automatically | |
| def _find_axmodel_files(base_dir: str, expected_layers: int = None, expected_prefill: int = 128): | |
| files = os.listdir(base_dir) | |
| # match prefix, prefill size (dynamic), and layer index | |
| layer_pattern = re.compile(r"^(?P<prefix>.*)_p(?P<prefill>\d+)_l(?P<idx>\d+)_together\.axmodel$") | |
| post_pattern = re.compile(r"^(?P<prefix>.*)_post\.axmodel$") | |
| # collect prefix -> [(idx, fname)] | |
| prefix_map = {} | |
| for fname in files: | |
| m = layer_pattern.match(fname) | |
| if m: | |
| prefix = m.group("prefix") | |
| idx = int(m.group("idx")) | |
| prefix_map.setdefault(prefix, []).append((idx, fname)) | |
| if not prefix_map: | |
| # fallback to hardcoded pattern if nothing detected | |
| prefix = "gemma3_text" | |
| layer_files = [( | |
| i, f"{prefix}_p{expected_prefill}_l{i}_together.axmodel" | |
| ) for i in range(expected_layers or 0)] | |
| else: | |
| # choose the prefix with the most layers (most likely the correct one) | |
| prefix = max(prefix_map.items(), key=lambda kv: len(kv[1]))[0] | |
| # debug info | |
| print(f"Detected prefixes: {list(prefix_map.keys())}, chosen: {prefix}, layers: {len(prefix_map[prefix])}") | |
| layer_files = sorted(prefix_map[prefix], key=lambda it: it[0]) | |
| # find post process file | |
| post_file = None | |
| for fname in files: | |
| m = post_pattern.match(fname) | |
| if m and m.group("prefix") == prefix: | |
| post_file = fname | |
| break | |
| if post_file is None: | |
| candidate = os.path.join(base_dir, f"{prefix}_post.axmodel") | |
| if os.path.exists(candidate): | |
| post_file = f"{prefix}_post.axmodel" | |
| else: | |
| for fname in files: | |
| if fname.endswith("_post.axmodel"): | |
| post_file = fname | |
| break | |
| return layer_files, post_file, prefix | |
| class InferManager: | |
| def __init__(self, config, model_dir, max_seq_len=2047): | |
| self.config = config | |
| self.max_seq_len = max_seq_len | |
| self.sub_dim = config.hidden_size // config.num_attention_heads if not config.head_dim else config.head_dim | |
| self.kv_dim = self.sub_dim * config.num_key_value_heads | |
| self.k_caches = [ | |
| np.zeros((1, self.max_seq_len, self.kv_dim), dtype=bfloat16) | |
| for _ in range(config.num_hidden_layers) | |
| ] | |
| self.v_caches = [ | |
| np.zeros((1, self.max_seq_len, self.kv_dim), dtype=bfloat16) | |
| for _ in range(config.num_hidden_layers) | |
| ] | |
| layer_files, post_file, prefix = _find_axmodel_files(model_dir, config.num_hidden_layers) | |
| self.decoder_sessions = [] | |
| for _, fname in tqdm(layer_files, desc="Init InferenceSession"): | |
| session = InferenceSession(os.path.join(model_dir, fname)) | |
| self.decoder_sessions.append(session) | |
| # post_file was returned by _find_axmodel_files; ensure it was found | |
| if post_file is None: | |
| raise FileNotFoundError("Cannot find post process .axmodel file in model_dir") | |
| self.post_process_session = InferenceSession(os.path.join(model_dir, post_file)) | |
| print("Model loaded successfully!") | |
| def _top_p(probs: np.ndarray, p: float) -> np.ndarray: | |
| sorted_indices = np.argsort(probs) | |
| filtered = probs.copy() | |
| cumulative = 0 | |
| for idx in sorted_indices[::-1]: | |
| if cumulative >= p: | |
| filtered[idx] = 0 | |
| cumulative += filtered[idx] | |
| return filtered / cumulative | |
| def _softmax(logits: np.ndarray) -> np.ndarray: | |
| logits = logits - logits.max() | |
| exp_logits = np.exp(logits) | |
| return (exp_logits / np.sum(exp_logits)).astype(np.float64) | |
| def post_process( | |
| self, | |
| logits, | |
| top_k=1, | |
| top_p=0.9, | |
| temperature=0.6, | |
| repetition_penalty=1.0, | |
| token_ids=None, | |
| ): | |
| logits = logits.astype(np.float32).flatten() | |
| if repetition_penalty is not None and repetition_penalty != 1.0 and token_ids: | |
| for t in set(token_ids): | |
| if 0 <= t < logits.size: | |
| if logits[t] < 0: | |
| logits[t] *= repetition_penalty | |
| else: | |
| logits[t] /= repetition_penalty | |
| top_k = max(1, min(int(top_k), logits.size)) | |
| temperature = max(float(temperature), 1e-6) | |
| top_p = min(max(float(top_p), 1e-6), 1.0) | |
| candidate_indices = np.argpartition(logits, -top_k)[-top_k:] | |
| candidate_logits = logits[candidate_indices] / temperature | |
| candidate_probs = self._softmax(candidate_logits) | |
| candidate_probs = self._top_p(candidate_probs, top_p) | |
| candidate_probs = candidate_probs.astype(np.float64) / candidate_probs.sum() | |
| chosen_idx = np.random.multinomial(1, candidate_probs).argmax() | |
| next_token = candidate_indices[chosen_idx] | |
| return next_token, candidate_indices, candidate_probs | |
| def gen_slice_indices(self, token_len, prefill=128, expand=128): | |
| remaining = max(0, token_len - prefill) | |
| extra_blocks = (remaining + expand - 1) // expand | |
| return list(range(extra_blocks + 1)) | |
| def prefill( | |
| self, | |
| tokenizer, | |
| token_ids, | |
| embed_data, | |
| slice_len=128, | |
| top_k=1, | |
| top_p=0.9, | |
| temperature=0.6, | |
| repetition_penalty=1.0, | |
| ): | |
| """ | |
| Prefill step for chunked inference. | |
| """ | |
| seq_len = len(token_ids) | |
| slice_indices = [i for i in range(seq_len // slice_len + 1)] | |
| print(f"slice_indices: {slice_indices}") | |
| # total_prefill_len = ( | |
| # slice_len * slice_indices[-1] | |
| # if slice_indices[-1] != 0 | |
| # else slice_len | |
| # ) | |
| total_prefill_len = slice_len * (slice_indices[-1] + 1) | |
| # slice_indices = self.gen_slice_indices(seq_len) | |
| if total_prefill_len > 0: | |
| for slice_idx in slice_indices: | |
| indices = np.arange( | |
| slice_idx * slice_len, | |
| (slice_idx + 1) * slice_len, | |
| dtype=np.uint32 | |
| ).reshape((1, slice_len)) | |
| mask = ( | |
| np.zeros((1, slice_len, slice_len * (slice_idx + 1))) | |
| - 65536 | |
| ) | |
| data = np.zeros((1, slice_len, self.config.hidden_size)).astype(bfloat16) | |
| for i, t in enumerate( | |
| range( | |
| slice_idx * slice_len, | |
| (slice_idx + 1) * slice_len, | |
| ) | |
| ): | |
| if t < len(token_ids): | |
| mask[:, i, : slice_idx * slice_len + i + 1] = 0 | |
| data[:, i : i + 1, :] = ( | |
| embed_data[t] | |
| .reshape((1, 1, self.config.hidden_size)) | |
| .astype(bfloat16) | |
| ) | |
| remain_len = ( | |
| seq_len - slice_idx * slice_len | |
| if slice_idx == slice_indices[-1] | |
| else slice_len | |
| ) | |
| mask = mask.astype(bfloat16) | |
| for layer_idx in range(self.config.num_hidden_layers): | |
| input_feed = { | |
| "K_cache": ( | |
| self.k_caches[layer_idx][:, 0 : slice_len * slice_idx, :] | |
| if slice_idx | |
| else np.zeros((1, 1, self.config.hidden_size), dtype=bfloat16) | |
| ), | |
| "V_cache": ( | |
| self.v_caches[layer_idx][:, 0 : slice_len * slice_idx, :] | |
| if slice_idx | |
| else np.zeros((1, 1, self.config.hidden_size), dtype=bfloat16) | |
| ), | |
| "indices": indices, | |
| "input": data, | |
| "mask": mask, | |
| } | |
| outputs = self.decoder_sessions[layer_idx].run(None, input_feed, shape_group=slice_idx + 1) | |
| self.k_caches[layer_idx][ | |
| :, | |
| slice_idx * slice_len : slice_idx * slice_len + remain_len, | |
| :, | |
| ] = outputs[0][:, :remain_len, :] | |
| self.v_caches[layer_idx][ | |
| :, | |
| slice_idx * slice_len : slice_idx * slice_len + remain_len, | |
| :, | |
| ] = outputs[1][:, :remain_len, :] | |
| data = outputs[2] | |
| print("Slice prefill done:", slice_idx) | |
| # return data[:, :remain_len, :] | |
| post_out = self.post_process_session.run( | |
| None, | |
| { | |
| "input": data[ | |
| :, seq_len - (len(slice_indices) - 1) * slice_len - 1, None, : | |
| ] | |
| } | |
| )[0] | |
| next_token, possible_tokens, possible_probs = self.post_process( | |
| post_out, | |
| top_k=top_k, | |
| top_p=top_p, | |
| temperature=temperature, | |
| repetition_penalty=repetition_penalty, | |
| token_ids=token_ids, | |
| ) | |
| possible_decoded = [tokenizer.decode([t]) for t in possible_tokens] | |
| possible_probs_str = [str((t, p)) for t, p in zip(possible_decoded, possible_probs)] | |
| token_ids.append(next_token) | |
| return token_ids | |
| def decode( | |
| self, | |
| tokenizer, | |
| token_ids, | |
| embed_matrix, | |
| prefill_len=128, | |
| slice_len=128, | |
| eos_token_id=None, # 某些模型有多个 eos_token_id | |
| stream=True, | |
| top_k=1, | |
| top_p=0.9, | |
| temperature=0.6, | |
| repetition_penalty=1.0, | |
| max_new_tokens=None, | |
| stream_callback=None, | |
| ): | |
| """Autoregressive decode; optionally stream tokens or collect silently.""" | |
| seq_len = len(token_ids) - 1 | |
| prompt_len = seq_len | |
| mask = np.zeros((1, 1, self.max_seq_len + 1), dtype=np.float32).astype(bfloat16) | |
| mask[:, :, :self.max_seq_len] -= 65536 | |
| decoded_text = tokenizer.decode(token_ids[prompt_len:], skip_special_tokens=True) | |
| prev_decoded_text = decoded_text | |
| if stream: | |
| print("answer >>", decoded_text, end='', flush=True) | |
| if stream_callback is not None: | |
| stream_callback(decoded_text) | |
| if prefill_len > 0: | |
| mask[:, :, :seq_len] = 0 | |
| max_new_tokens = self.max_seq_len if max_new_tokens is None else int(max_new_tokens) | |
| generated = 0 | |
| for step_idx in range(self.max_seq_len): | |
| if prefill_len > 0 and step_idx < seq_len: | |
| continue | |
| cur_token = token_ids[step_idx] | |
| indices = np.array([step_idx], np.uint32).reshape((1, 1)) | |
| data = embed_matrix[cur_token, :].reshape((1, 1, self.config.hidden_size)).astype(bfloat16) | |
| for layer_idx in range(self.config.num_hidden_layers): | |
| input_feed = { | |
| "K_cache": self.k_caches[layer_idx], | |
| "V_cache": self.v_caches[layer_idx], | |
| "indices": indices, | |
| "input": data, | |
| "mask": mask, | |
| } | |
| outputs = self.decoder_sessions[layer_idx].run(None, input_feed, shape_group=0) | |
| self.k_caches[layer_idx][:, step_idx, :] = outputs[0][:, :, :] | |
| self.v_caches[layer_idx][:, step_idx, :] = outputs[1][:, :, :] | |
| data = outputs[2] | |
| mask[..., step_idx] = 0 | |
| if step_idx < seq_len - 1: | |
| continue | |
| else: | |
| post_out = self.post_process_session.run(None, {"input": data})[0] | |
| next_token, possible_tokens, possible_probs = self.post_process( | |
| post_out, | |
| top_k=top_k, | |
| top_p=top_p, | |
| temperature=temperature, | |
| repetition_penalty=repetition_penalty, | |
| token_ids=token_ids, | |
| ) | |
| if eos_token_id is not None and next_token in eos_token_id: | |
| break | |
| elif next_token == tokenizer.eos_token_id: | |
| break | |
| token_ids.append(next_token) | |
| generated += 1 | |
| if generated >= max_new_tokens: | |
| break | |
| decoded_text = tokenizer.decode(token_ids[prompt_len:], skip_special_tokens=True) | |
| if stream: | |
| if decoded_text.startswith(prev_decoded_text): | |
| print(decoded_text[len(prev_decoded_text):], end='', flush=True) | |
| else: | |
| print(decoded_text, end='', flush=True) | |
| if stream_callback is not None: | |
| stream_callback(decoded_text) | |
| prev_decoded_text = decoded_text | |
| return decoded_text | |
| def decode_stream( | |
| self, | |
| tokenizer, | |
| token_ids, | |
| embed_matrix, | |
| prefill_len=128, | |
| slice_len=128, | |
| eos_token_id=None, # 某些模型有多个 eos_token_id | |
| top_k=1, | |
| top_p=0.9, | |
| temperature=0.6, | |
| repetition_penalty=1.0, | |
| max_new_tokens=None, | |
| ): | |
| seq_len = len(token_ids) - 1 | |
| prompt_len = seq_len | |
| decoded_text = tokenizer.decode(token_ids[prompt_len:], skip_special_tokens=True) | |
| yield decoded_text | |
| mask = np.zeros((1, 1, self.max_seq_len + 1), dtype=np.float32).astype(bfloat16) | |
| mask[:, :, :self.max_seq_len] -= 65536 | |
| seq_len = len(token_ids) - 1 | |
| prompt_len = seq_len | |
| if prefill_len > 0: | |
| mask[:, :, :seq_len] = 0 | |
| max_new_tokens = self.max_seq_len if max_new_tokens is None else int(max_new_tokens) | |
| generated = 0 | |
| for step_idx in range(self.max_seq_len): | |
| if prefill_len > 0 and step_idx < seq_len: | |
| continue | |
| cur_token = token_ids[step_idx] | |
| indices = np.array([step_idx], np.uint32).reshape((1, 1)) | |
| data = embed_matrix[cur_token, :].reshape((1, 1, self.config.hidden_size)).astype(bfloat16) | |
| for layer_idx in range(self.config.num_hidden_layers): | |
| input_feed = { | |
| "K_cache": self.k_caches[layer_idx], | |
| "V_cache": self.v_caches[layer_idx], | |
| "indices": indices, | |
| "input": data, | |
| "mask": mask, | |
| } | |
| outputs = self.decoder_sessions[layer_idx].run(None, input_feed, shape_group=0) | |
| self.k_caches[layer_idx][:, step_idx, :] = outputs[0][:, :, :] | |
| self.v_caches[layer_idx][:, step_idx, :] = outputs[1][:, :, :] | |
| data = outputs[2] | |
| mask[..., step_idx] = 0 | |
| if step_idx < seq_len - 1: | |
| continue | |
| else: | |
| post_out = self.post_process_session.run(None, {"input": data})[0] | |
| next_token, possible_tokens, possible_probs = self.post_process( | |
| post_out, | |
| top_k=top_k, | |
| top_p=top_p, | |
| temperature=temperature, | |
| repetition_penalty=repetition_penalty, | |
| token_ids=token_ids, | |
| ) | |
| if eos_token_id is not None and next_token in eos_token_id: | |
| break | |
| elif next_token == tokenizer.eos_token_id: | |
| break | |
| token_ids.append(next_token) | |
| generated += 1 | |
| if generated >= max_new_tokens: | |
| break | |
| decoded_text = tokenizer.decode(token_ids[prompt_len:], skip_special_tokens=True) | |
| yield decoded_text | |