mini-qwen3-0.5b-init / configuration_mini_qwen.py
1wrx1's picture
Initial random-init MiniQwen3 (~0.5B) checkpoint.
32de0d4 verified
Raw
History Blame Contribute Delete
2.94 kB
"""Configuration class for the MiniQwen3 standalone model.
Mirrors HF's `Qwen3Config` but with every architectural knob exposed
as a constructor kwarg, so this file is the single source of truth for
ablation experiments. Keep `MiniQwen3Config.model_type` unique so the
HF Auto* registry never silently falls back to the upstream Qwen3
implementation.
"""
from __future__ import annotations
from transformers import PretrainedConfig
class MiniQwen3Config(PretrainedConfig):
    """Standalone, easy-to-edit Qwen3-style transformer config.

    Defaults target ~0.5B parameters (24L x d=1024, 16 attn heads,
    GQA 2:1, SwiGLU MLP 3072, RoPE theta=1e6, tied embeddings, QK-Norm).

    This class only validates and stores hyperparameters; how each value is
    interpreted is up to the modeling code that consumes the config.

    Args:
        vocab_size: Number of entries in the token vocabulary.
        hidden_size: Width of the residual stream.
        intermediate_size: Inner width of the MLP block.
        num_hidden_layers: Number of transformer layers.
        num_attention_heads: Number of query heads.
        num_key_value_heads: Number of key/value heads. Must be a positive
            divisor of ``num_attention_heads``.
        head_dim: Per-head dimension. Stored as given; it is not checked
            against ``hidden_size // num_attention_heads``.
        hidden_act: Name of the MLP activation function.
        max_position_embeddings: Maximum sequence length the model is
            configured for.
        initializer_range: Scale used for weight initialisation
            (presumably a normal-distribution std -- confirm against the
            modeling code).
        rms_norm_eps: Epsilon used by the RMSNorm layers.
        use_cache: Whether the model should return/use its KV cache.
        rope_theta: Base frequency for rotary position embeddings.
        attention_bias: Whether attention projections carry a bias term.
        mlp_bias: Whether MLP projections carry a bias term.
        attention_dropout: Dropout probability for attention.
        tie_word_embeddings: Forwarded to ``PretrainedConfig``; requests
            sharing of input and output embedding weights.
        use_qk_norm: Whether QK-Norm is enabled.
        pad_token_id: Padding token id, forwarded to ``PretrainedConfig``.
        bos_token_id: Beginning-of-sequence token id, forwarded to
            ``PretrainedConfig``.
        eos_token_id: End-of-sequence token id, forwarded to
            ``PretrainedConfig``.
        **kwargs: Any remaining options, forwarded unchanged to
            ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``num_key_value_heads`` is not positive, or if
            ``num_attention_heads`` is not divisible by it.
    """

    # Kept distinct from the upstream Qwen3 identifier so the HF Auto*
    # registry cannot resolve this config to the upstream implementation.
    model_type = "mini_qwen3"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Explicitly None: this config declares no tensor-/pipeline-parallel plan.
    base_model_tp_plan = None
    base_model_pp_plan = None

    def __init__(
        self,
        vocab_size: int = 151936,
        hidden_size: int = 1024,
        intermediate_size: int = 3072,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 16,
        num_key_value_heads: int = 8,
        head_dim: int = 64,
        hidden_act: str = "silu",
        max_position_embeddings: int = 4096,
        initializer_range: float = 0.02,
        rms_norm_eps: float = 1e-6,
        use_cache: bool = True,
        rope_theta: float = 1_000_000.0,
        attention_bias: bool = False,
        mlp_bias: bool = False,
        attention_dropout: float = 0.0,
        tie_word_embeddings: bool = True,
        use_qk_norm: bool = True,
        pad_token_id: int | None = None,
        bos_token_id: int | None = 151643,
        eos_token_id: int | None = 151643,
        **kwargs,
    ) -> None:
        # Reject zero/negative KV-head counts up front: zero would otherwise
        # surface as a bare ZeroDivisionError from the modulo below, and a
        # negative divisor could slip through the divisibility check.
        if num_key_value_heads <= 0:
            raise ValueError(
                f"num_key_value_heads ({num_key_value_heads}) must be a "
                f"positive integer."
            )
        if num_attention_heads % num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({num_attention_heads}) must be divisible "
                f"by num_key_value_heads ({num_key_value_heads}) for GQA."
            )

        # Core architecture dimensions.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim

        # Activation, positional range, init and normalisation settings.
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta

        # Ablation switches.
        self.attention_bias = attention_bias
        self.mlp_bias = mlp_bias
        self.attention_dropout = attention_dropout
        self.use_qk_norm = use_qk_norm

        # Token ids, embedding tying and any extra kwargs are handled by the
        # base class. The call is kept after the attribute assignments, as in
        # the original ordering.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )