"""Standalone Qwen3-style transformer for from-scratch pretraining. This is intentionally a single, self-contained file: every architectural component (RMSNorm, rotary embeddings, GQA + QK-Norm attention, SwiGLU MLP, decoder layer, backbone, LM head) is defined here so an experiment that wants to swap one piece (e.g. QK-Norm off, MLP -> MoE, RoPE -> ALiBi) only needs to edit this file. Compatibility: - Drop-in replacement for `transformers.Qwen3ForCausalLM` at the `from_pretrained` / `save_pretrained` / `generate` API surface. - Uses HF's modern `Cache` interface (DynamicCache) for KV cache so `.generate()` works with the standard GenerationMixin. - Built and tested against `torch==2.7.1` + `transformers==5.7.0`. Save/load with custom code: MiniQwen3Config.register_for_auto_class() MiniQwen3ForCausalLM.register_for_auto_class("AutoModelForCausalLM") model.save_pretrained(out_dir) # also copies these .py files AutoModelForCausalLM.from_pretrained(out_dir, trust_remote_code=True) """ from __future__ import annotations from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel from transformers.cache_utils import Cache, DynamicCache from transformers.generation import GenerationMixin from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, ) from configuration_mini_qwen import MiniQwen3Config # --------------------------------------------------------------------------- # RMSNorm # --------------------------------------------------------------------------- class MiniQwen3RMSNorm(nn.Module): """Root-mean-square LayerNorm. Same numerics as `Qwen3RMSNorm`.""" def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self) -> str: return f"{self.weight.shape[0]}, eps={self.variance_epsilon}" # --------------------------------------------------------------------------- # Rotary embeddings # --------------------------------------------------------------------------- class MiniQwen3RotaryEmbedding(nn.Module): """Standard RoPE with a configurable base (theta). Computes (cos, sin) on the fly so the same module handles arbitrary sequence lengths up to `max_position_embeddings` without precomputing a giant table. """ def __init__(self, head_dim: int, max_position_embeddings: int, base: float) -> None: super().__init__() if head_dim % 2 != 0: raise ValueError(f"head_dim must be even for RoPE, got {head_dim}") self.head_dim = head_dim self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / ( base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim) ) self.register_buffer("inv_freq", inv_freq, persistent=False) @torch.no_grad() def forward( self, x: torch.Tensor, position_ids: torch.LongTensor ) -> Tuple[torch.Tensor, torch.Tensor]: # position_ids: (B, T) -> cos/sin: (B, T, head_dim) inv_freq = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) positions = position_ids[:, None, :].float() device_type = "cpu" if x.device.type == "mps" else x.device.type with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq @ positions).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() return cos.to(x.dtype), sin.to(x.dtype) def _rotate_half(x: torch.Tensor) -> torch.Tensor: half = x.shape[-1] // 2 return torch.cat((-x[..., half:], x[..., :half]), dim=-1) def apply_rotary_pos_emb( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: # q/k: (B, H, T, D) cos/sin: (B, T, D) -> add head dim cos = cos.unsqueeze(1) sin = sin.unsqueeze(1) q_rot = (q * cos) + (_rotate_half(q) * sin) k_rot = (k * cos) + (_rotate_half(k) * sin) return q_rot, k_rot def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """Tile KV heads to match Q heads for GQA. (B, KV_H, T, D) -> (B, Q_H, T, D).""" if n_rep == 1: return hidden_states b, kv_h, t, d = hidden_states.shape return ( hidden_states[:, :, None, :, :] .expand(b, kv_h, n_rep, t, d) .reshape(b, kv_h * n_rep, t, d) ) # --------------------------------------------------------------------------- # Attention (GQA + QK-Norm + RoPE, dispatched through SDPA) # --------------------------------------------------------------------------- class MiniQwen3Attention(nn.Module): def __init__(self, config: MiniQwen3Config, layer_idx: int) -> None: super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = config.head_dim self.num_heads = config.num_attention_heads self.num_kv_heads = config.num_key_value_heads self.num_kv_groups = self.num_heads // self.num_kv_heads self.attention_dropout = config.attention_dropout self.scaling = self.head_dim**-0.5 q_out = self.num_heads * self.head_dim kv_out = self.num_kv_heads * self.head_dim self.q_proj = nn.Linear(config.hidden_size, q_out, bias=config.attention_bias) self.k_proj = nn.Linear(config.hidden_size, kv_out, bias=config.attention_bias) self.v_proj = nn.Linear(config.hidden_size, kv_out, bias=config.attention_bias) self.o_proj = nn.Linear(q_out, config.hidden_size, bias=config.attention_bias) if config.use_qk_norm: self.q_norm = MiniQwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = MiniQwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) else: self.q_norm = nn.Identity() self.k_norm = nn.Identity() def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: bsz, q_len, _ = hidden_states.shape q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) k = self.k_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2) v = self.v_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2) q = self.q_norm(q) k = self.k_norm(k) cos, sin = position_embeddings q, k = apply_rotary_pos_emb(q, k, cos, sin) if past_key_value is not None: cache_kwargs = {"cache_position": cache_position} k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs) k = repeat_kv(k, self.num_kv_groups) v = repeat_kv(v, self.num_kv_groups) is_causal = attention_mask is None and q_len > 1 attn_output = F.scaled_dot_product_attention( q, k, v, attn_mask=attention_mask, dropout_p=self.attention_dropout if self.training else 0.0, is_causal=is_causal, scale=self.scaling, ) attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1) return self.o_proj(attn_output), None # --------------------------------------------------------------------------- # SwiGLU MLP # --------------------------------------------------------------------------- class MiniQwen3MLP(nn.Module): """SwiGLU FFN: down(silu(gate(x)) * up(x)).""" def __init__(self, config: MiniQwen3Config) -> None: super().__init__() self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias) self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias) self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias) if config.hidden_act != "silu": raise ValueError( f"This implementation only supports hidden_act='silu', got {config.hidden_act!r}. " "Edit MiniQwen3MLP if you want a different activation." ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) # --------------------------------------------------------------------------- # Decoder layer (pre-norm + residual) # --------------------------------------------------------------------------- class MiniQwen3DecoderLayer(nn.Module): def __init__(self, config: MiniQwen3Config, layer_idx: int) -> None: super().__init__() self.input_layernorm = MiniQwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.self_attn = MiniQwen3Attention(config, layer_idx) self.post_attention_layernorm = MiniQwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.mlp = MiniQwen3MLP(config) def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) attn_out, _ = self.self_attn( hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, past_key_value=past_key_value, cache_position=cache_position, ) hidden_states = residual + attn_out residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) return residual + hidden_states # --------------------------------------------------------------------------- # Backbone + LM head # --------------------------------------------------------------------------- class MiniQwen3PreTrainedModel(PreTrainedModel): config_class = MiniQwen3Config base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["MiniQwen3DecoderLayer"] _supports_sdpa = True _supports_cache_class = True _supports_static_cache = False def _init_weights(self, module: nn.Module) -> None: std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, MiniQwen3RMSNorm): module.weight.data.fill_(1.0) class MiniQwen3Model(MiniQwen3PreTrainedModel): """Embedding + N decoder layers + final RMSNorm.""" def __init__(self, config: MiniQwen3Config) -> None: super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding( config.vocab_size, config.hidden_size, self.padding_idx ) self.layers = nn.ModuleList( [MiniQwen3DecoderLayer(config, idx) for idx in range(config.num_hidden_layers)] ) self.norm = MiniQwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = MiniQwen3RotaryEmbedding( head_dim=config.head_dim, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, ) self.gradient_checkpointing = False self.post_init() def get_input_embeddings(self) -> nn.Module: return self.embed_tokens def set_input_embeddings(self, value: nn.Module) -> None: self.embed_tokens = value def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> BaseModelOutputWithPast: if (input_ids is None) == (inputs_embeds is None): raise ValueError("Specify exactly one of `input_ids` or `inputs_embeds`.") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if use_cache is None: use_cache = self.config.use_cache and not self.training if use_cache and past_key_values is None: past_key_values = DynamicCache() past_seen = past_key_values.get_seq_length() if past_key_values is not None else 0 seq_len = inputs_embeds.shape[1] if cache_position is None: cache_position = torch.arange( past_seen, past_seen + seq_len, device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._build_attention_mask( attention_mask, inputs_embeds, past_seen ) position_embeddings = self.rotary_emb(inputs_embeds, position_ids) hidden_states = inputs_embeds for layer in self.layers: if self.gradient_checkpointing and self.training: hidden_states = torch.utils.checkpoint.checkpoint( layer.__call__, hidden_states, position_embeddings, causal_mask, past_key_values, cache_position, use_reentrant=False, ) else: hidden_states = layer( hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask, past_key_value=past_key_values, cache_position=cache_position, ) hidden_states = self.norm(hidden_states) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, ) @staticmethod def _build_attention_mask( attention_mask: Optional[torch.Tensor], inputs_embeds: torch.Tensor, past_seen: int, ) -> Optional[torch.Tensor]: # No padding mask -> let SDPA's is_causal=True path handle the # causal triangle for us (cheaper, no allocation). if attention_mask is None: return None bsz, q_len = inputs_embeds.shape[:2] s_len = past_seen + q_len device = inputs_embeds.device dtype = inputs_embeds.dtype min_val = torch.finfo(dtype).min causal = torch.full((q_len, s_len), min_val, device=device, dtype=dtype) causal = torch.triu(causal, diagonal=past_seen + 1) causal = causal.unsqueeze(0).unsqueeze(0).expand(bsz, 1, -1, -1).clone() if attention_mask.dim() == 2: pad_mask = attention_mask[:, None, None, :].to(dtype) causal = causal + (1.0 - pad_mask) * min_val return causal class MiniQwen3ForCausalLM(MiniQwen3PreTrainedModel, GenerationMixin): # transformers >= 5.0 expects a {tied_key: source_key} dict, not a list. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config: MiniQwen3Config) -> None: super().__init__(config) self.model = MiniQwen3Model(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.post_init() def get_input_embeddings(self) -> nn.Module: return self.model.embed_tokens def set_input_embeddings(self, value: nn.Module) -> None: self.model.embed_tokens = value def get_output_embeddings(self) -> nn.Module: return self.lm_head def set_output_embeddings(self, new_embeddings: nn.Module) -> None: self.lm_head = new_embeddings def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: int = 0, **kwargs, ) -> CausalLMOutputWithPast: outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, ) hidden_states = outputs.last_hidden_state slice_indices = ( slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) and logits_to_keep > 0 else slice(None) ) logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: shift_logits = logits[:, :-1, :].contiguous().float() shift_labels = labels[:, 1:].contiguous() loss = F.cross_entropy( shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1), ignore_index=-100, ) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, )