""" Ministral3 Neuron Model Support for optimum-neuron This module provides support for running Ministral3 models on AWS Neuron. Import this module before loading the model to register the custom classes. Usage: # Method 1: Import and use helper function from huggingface_hub import hf_hub_download exec(open(hf_hub_download("YOUR_USERNAME/ministral3-neuron", "ministral3_neuron.py")).read()) model, tokenizer = load_ministral3("YOUR_USERNAME/ministral3-neuron") # Method 2: Manual registration from huggingface_hub import hf_hub_download exec(open(hf_hub_download("YOUR_USERNAME/ministral3-neuron", "ministral3_neuron.py")).read()) from optimum.neuron import NeuronModelForCausalLM from transformers import AutoTokenizer model = NeuronModelForCausalLM.from_pretrained("YOUR_USERNAME/ministral3-neuron") tokenizer = AutoTokenizer.from_pretrained("YOUR_USERNAME/ministral3-neuron") """ import gc import logging import math from typing import Optional, Tuple import torch from torch import nn logger = logging.getLogger("Neuron") # ============================================================================= # Step 1: Register ministral3 in transformers CONFIG_MAPPING # ============================================================================= try: from transformers.models.auto.configuration_auto import CONFIG_MAPPING from transformers.models.mistral.configuration_mistral import MistralConfig if "ministral3" not in CONFIG_MAPPING: CONFIG_MAPPING.register("ministral3", MistralConfig) logger.info("Registered ministral3 in CONFIG_MAPPING") except Exception as e: logger.warning(f"Failed to register ministral3 config: {e}") # ============================================================================= # Step 2: Import optimum-neuron components # ============================================================================= try: from neuronx_distributed.parallel_layers.layers import ( ColumnParallelLinear, ParallelEmbedding, RowParallelLinear, ) from transformers.activations import ACT2FN from optimum.neuron.models.inference.backend.config import NxDNeuronConfig from optimum.neuron.models.inference.backend.modules.attention.attention_base import NeuronAttentionBase from optimum.neuron.models.inference.backend.modules.decoder import NxDDecoderModelForCausalLM, NxDModelForCausalLM from optimum.neuron.models.inference.backend.modules.rms_norm import NeuronRMSNorm from optimum.neuron.models.auto_model import register_neuron_model, _REGISTERED_NEURON_MODELS OPTIMUM_NEURON_AVAILABLE = True except ImportError as e: logger.warning(f"optimum-neuron not available: {e}") OPTIMUM_NEURON_AVAILABLE = False # ============================================================================= # Step 3: Define Ministral3 model components # ============================================================================= if OPTIMUM_NEURON_AVAILABLE: def convert_state_dict_to_fused_qkv(state_dict, cfg): """Concatenate qkv weights to Wqkv weight for fused qkv.""" for l in range(cfg.num_hidden_layers): state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = torch.cat( [ state_dict[f"layers.{l}.self_attn.q_proj.weight"], state_dict[f"layers.{l}.self_attn.k_proj.weight"], state_dict[f"layers.{l}.self_attn.v_proj.weight"], ], ) del state_dict[f"layers.{l}.self_attn.q_proj.weight"] del state_dict[f"layers.{l}.self_attn.k_proj.weight"] del state_dict[f"layers.{l}.self_attn.v_proj.weight"] gc.collect() return state_dict class YarnRotaryEmbedding(nn.Module): """YARN Rotary Embedding for extended context.""" def __init__( self, dim, max_position_embeddings=262144, base=1000000.0, factor=16.0, original_max_position_embeddings=16384, beta_fast=32.0, beta_slow=1.0, mscale=1.0, mscale_all_dim=1.0, ): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base self.factor = factor self.original_max_position_embeddings = original_max_position_embeddings self.beta_fast = beta_fast self.beta_slow = beta_slow self.mscale = mscale self.mscale_all_dim = mscale_all_dim self.register_buffer("inv_freq", None, persistent=False) self._mscale = None def _compute_mscale(self, scale): if self.mscale_all_dim: return (0.1 * math.log(scale) + 1.0) if scale > 1.0 else 1.0 return 1.0 def _yarn_find_correction_dim(self, num_rotations, dim, base, max_position_embeddings): return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) def _yarn_find_correction_range(self, low_rot, high_rot, dim, base, max_position_embeddings): low = math.floor(self._yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) high = math.ceil(self._yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) return max(low, 0), min(high, dim - 1) def _yarn_linear_ramp_mask(self, low, high, dim, device): if low == high: high += 0.001 linear_func = (torch.arange(dim, dtype=torch.float32, device=device) - low) / (high - low) return torch.clamp(linear_func, 0.0, 1.0) @torch.no_grad() def forward(self, x, position_ids): if self.inv_freq is None: pos_freqs = self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=x.device) / self.dim) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (self.factor * pos_freqs) low, high = self._yarn_find_correction_range( self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings, ) inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask(low, high, self.dim // 2, x.device) inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask self.inv_freq = inv_freq self._mscale = self._compute_mscale(self.factor) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() with torch.autocast(device_type=x.device.type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self._mscale sin = emb.sin() * self._mscale return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class NeuronMinistral3MLP(nn.Module): """MLP module for Ministral3.""" def __init__(self, config, neuron_config: NxDNeuronConfig): super().__init__() self.tp_degree = neuron_config.tp_degree self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.act_fn = ACT2FN[config.hidden_act] mlp_bias = getattr(config, "mlp_bias", False) self.gate_proj = ColumnParallelLinear( self.hidden_size, self.intermediate_size, bias=mlp_bias, gather_output=False, dtype=neuron_config.torch_dtype, pad=True, ) self.up_proj = ColumnParallelLinear( self.hidden_size, self.intermediate_size, bias=mlp_bias, gather_output=False, dtype=neuron_config.torch_dtype, pad=True, ) self.down_proj = RowParallelLinear( self.intermediate_size, self.hidden_size, bias=mlp_bias, input_is_parallel=True, dtype=neuron_config.torch_dtype, pad=True, reduce_dtype=neuron_config.torch_dtype, ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) class NeuronMinistral3Attention(NeuronAttentionBase): """Attention module for Ministral3 with YARN rotary embeddings.""" def __init__(self, config, neuron_config: NxDNeuronConfig, **kwargs): super().__init__(config, neuron_config, **kwargs) head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads rope_params = getattr(config, "rope_parameters", None) or getattr(config, "rope_scaling", None) or {} self.rotary_emb = YarnRotaryEmbedding( dim=head_dim, max_position_embeddings=config.max_position_embeddings, base=rope_params.get("rope_theta", getattr(config, "rope_theta", 1000000.0)), factor=rope_params.get("factor", 16.0), original_max_position_embeddings=rope_params.get("original_max_position_embeddings", 16384), beta_fast=rope_params.get("beta_fast", 32.0), beta_slow=rope_params.get("beta_slow", 1.0), mscale=rope_params.get("mscale", 1.0), mscale_all_dim=rope_params.get("mscale_all_dim", 1.0), ) class NeuronMinistral3DecoderLayer(nn.Module): """Decoder layer for Ministral3.""" def __init__(self, config, neuron_config: NxDNeuronConfig): super().__init__() self.hidden_size = config.hidden_size self.self_attn = NeuronMinistral3Attention(config, neuron_config) self.mlp = NeuronMinistral3MLP(config, neuron_config) self.input_layernorm = NeuronRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = NeuronRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.config = config def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, **kwargs): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, **kwargs, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states, present_key_value, cos_cache, sin_cache class NxDMinistral3Model(NxDDecoderModelForCausalLM): """The neuron version of the Ministral3 model.""" def __init__(self, config, neuron_config: NxDNeuronConfig): super().__init__(config, neuron_config) self.embed_tokens = ParallelEmbedding( config.vocab_size, config.hidden_size, getattr(config, "pad_token_id", None), dtype=neuron_config.torch_dtype, shard_across_embedding=True, pad=True, ) self.lm_head = ColumnParallelLinear( config.hidden_size, config.vocab_size, gather_output=not neuron_config.on_device_sampling, bias=False, pad=True, ) self.layers = nn.ModuleList([ NeuronMinistral3DecoderLayer(config, neuron_config) for _ in range(config.num_hidden_layers) ]) self.norm = NeuronRMSNorm(config.hidden_size, eps=config.rms_norm_eps) class Ministral3NxDModelForCausalLM(NxDModelForCausalLM): """Ministral3 model for causal language modeling on AWS Neuron.""" _model_cls = NxDMinistral3Model @staticmethod def convert_hf_to_neuron_state_dict(state_dict: dict, config, neuron_config: NxDNeuronConfig) -> dict: dict_keys = list(state_dict.keys()) # Handle prefixes from multimodal model for key in dict_keys: if key.startswith("language_model.model."): state_dict[key.replace("language_model.model.", "")] = state_dict.pop(key) elif key.startswith("language_model."): state_dict[key.replace("language_model.", "")] = state_dict.pop(key) elif key.startswith("model."): state_dict[key.replace("model.", "", 1)] = state_dict.pop(key) # Handle lm_head if "lm_head.weight" not in state_dict: for key in list(state_dict.keys()): if "lm_head" in key and key != "lm_head.weight": state_dict["lm_head.weight"] = state_dict.pop(key) break # Fuse QKV if neuron_config.fused_qkv: state_dict = convert_state_dict_to_fused_qkv(state_dict, config) # Add rank utilities for i in range(config.num_hidden_layers): state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange(0, neuron_config.tp_degree, dtype=torch.int32) state_dict["rank_util.rank"] = torch.arange(0, neuron_config.tp_degree, dtype=torch.int32) return state_dict @staticmethod def update_state_dict_for_tied_weights(state_dict): state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"].clone() @classmethod def _get_neuron_config(cls, checkpoint_id, checkpoint_revision, instance_type, batch_size, sequence_length, tensor_parallel_size, dtype): return NxDNeuronConfig( checkpoint_id=checkpoint_id, checkpoint_revision=checkpoint_revision, batch_size=batch_size, sequence_length=sequence_length, tp_degree=tensor_parallel_size, torch_dtype=dtype, target=instance_type, on_device_sampling=True, fused_qkv=True, continuous_batching=(batch_size > 1) if batch_size else False, ) # ============================================================================= # Step 4: Register the model in optimum-neuron # ============================================================================= @register_neuron_model("ministral3", "text-generation", "inference") class Ministral3NeuronModelForCausalLM(Ministral3NxDModelForCausalLM): """Ministral3 model with NxD backend for inference on AWS Neuron.""" pass logger.info("Registered Ministral3NeuronModelForCausalLM in optimum-neuron") # ============================================================================= # Step 5: Helper function for easy loading # ============================================================================= def load_ministral3(model_id: str, **kwargs): """ Load a Ministral3 model from HuggingFace. Args: model_id: HuggingFace model ID or local path **kwargs: Additional arguments passed to from_pretrained Returns: tuple: (model, tokenizer) """ from optimum.neuron import NeuronModelForCausalLM from transformers import AutoTokenizer model = NeuronModelForCausalLM.from_pretrained(model_id, **kwargs) tokenizer = AutoTokenizer.from_pretrained(model_id) return model, tokenizer print("Ministral3 Neuron support loaded successfully!") print("Use load_ministral3('model_id') to load the model.")