"""
config.py -- SpikeWhale: combined config from SpikeTransformer (My Project) + NanoWhale (DeepSeek-V4).

Features carried from My Project (not in NanoWhale):
  - DERF attention: erf(alpha*score+bias)*gamma replaces softmax
  - XSA (Exclusive Self-Attention): orthogonality correction removes self-echo from attn output
  - Engram N-gram module: hash-table N-gram lookup with DERF gate injected into embeddings
  - Three-tier optimizer: embed/table params trained at lower LR

Features carried from NanoWhale (not in My Project):
  - MLA (Multi-Head Latent Attention): low-rank Q projection + direct K,V (MQA)
  - Partial RoPE: rotary embeddings on only qk_rope_head_dim dims of Q and K
  - Low-rank grouped output projection (o_lora_rank)
  - Hyper-Connections: hc_mult residual streams with learned routing between layers
  - Shared expert in MoE (always-active expert alongside routed experts)
  - sqrtsoftplus expert scoring (vs softmax in My Project)
  - Hash-based routing for first num_hash_layers layers
  - norm_topk_prob + routed_scaling_factor
  - Multi-Token Prediction (MTP): extra heads predict k steps ahead
  - torch.compile, FineWeb-Edu streaming, Trackio, YAML configs in train.py
"""

from typing import List, Optional

from transformers import PretrainedConfig


class SpikeWhaleConfig(PretrainedConfig):
    """Hyperparameter container for the SpikeWhale model.

    Every constructor argument is stored as a same-named attribute. The only
    derived attribute is ``nope_head_dim = head_dim - qk_rope_head_dim``, which
    is recomputed on every construction (overriding any value passed through
    ``**kwargs``).

    Raises:
        ValueError: if the attention head layout or MoE top-k settings are
            inconsistent (negative ``nope_head_dim``, ``num_attention_heads``
            not divisible by ``num_key_value_heads``, or
            ``num_experts_per_tok > n_routed_experts`` with MoE enabled).
    """

    model_type = "spike_whale"

    def __init__(
        self,
        # Standard
        vocab_size: int = 16512,  # SpikeTokenizer: 16384 base + 128 padded special slots
        hidden_size: int = 2048,
        num_hidden_layers: int = 11,
        max_position_embeddings: int = 4096,
        rms_norm_eps: float = 1e-6,
        initializer_range: float = 0.02,
        tie_word_embeddings: bool = False,
        hidden_dropout: float = 0.0,
        bos_token_id: int = 0,
        eos_token_id: int = 1,
        # MLA Attention (NanoWhale)
        num_attention_heads: int = 8,
        num_key_value_heads: int = 1,  # 1 = MQA; >1 = GQA
        q_lora_rank: int = 160,  # low-rank Q: hidden -> q_lora_rank -> num_heads*head_dim
        head_dim: int = 96,  # total per-head dim = nope_head_dim + qk_rope_head_dim
        qk_rope_head_dim: int = 32,  # RoPE applied only to these dims
        o_lora_rank: int = 80,  # low-rank output: num_heads*head_dim -> o_lora_rank -> hidden
        attention_dropout: float = 0.0,
        rope_theta: float = 10000.0,
        # DERF + XSA (My Project)
        use_derf: bool = True,
        use_xsa: bool = True,
        # MoE (combined)
        use_moe: bool = True,
        moe_intermediate_size: int = 640,
        n_routed_experts: int = 4,
        n_shared_experts: int = 1,  # NanoWhale: always-active shared expert
        num_experts_per_tok: int = 2,
        norm_topk_prob: bool = True,  # NanoWhale: normalize top-k routing weights
        scoring_func: str = "sqrtsoftplus",  # NanoWhale: sqrt(softplus(x)) vs softmax
        routed_scaling_factor: float = 1.0,  # NanoWhale: scale routed expert weights
        num_hash_layers: int = 2,  # NanoWhale: first N layers use hash routing
        moe_aux_loss_coef: float = 0.01,
        moe_layers: Optional[List[int]] = None,  # None -> all indices 0..num_hidden_layers-1
        # Hyper-Connections (NanoWhale)
        use_hyper_connections: bool = True,
        hc_mult: int = 4,  # number of parallel residual streams
        hc_sinkhorn_iters: int = 20,
        hc_eps: float = 1e-6,
        # Multi-Token Prediction (NanoWhale)
        num_nextn_predict_layers: int = 1,  # extra MTP heads (0 = disabled)
        # Engram N-gram module (My Project)
        use_engram: bool = True,
        engram_compress_dim: int = 64,
        engram_num_heads: int = 4,
        engram_table_size: int = 8192,
        engram_max_ngram: int = 3,
        engram_gate_init_bias: float = -4.0,
        # HRM-inspired iterative refinement (EXPERIMENTAL; off by default).
        # Adds one small block that refines the final hidden state over N inner
        # steps before the output norm. This is the "iterative refinement" part
        # that the ARC-Prize ablation found carried most of HRM's benefit -- NOT
        # the full two-timescale H/L hierarchy. Honestly labeled HRM-inspired.
        use_hrm_refine: bool = False,
        hrm_refine_steps: int = 3,  # inner refinement iterations
        hrm_refine_dim: int = 256,  # bottleneck width of the refine MLP
        # --- v2 additions ---
        use_qk_norm: bool = True,  # per-head RMSNorm on Q,K before RoPE
        zloss_coef: float = 1e-4,  # log^2(Z) penalty on lm_head logits (0=off)
        mtp_loss_weight: float = 0.3,  # down-weight for MTP CE loss
        use_value_embed: bool = False,  # per-layer value-embedding residual (zero-init)
        **kwargs,
    ):
        # Reject layouts the attention / MoE layers cannot be built from,
        # before any state is written. All defaults satisfy these checks.
        if qk_rope_head_dim > head_dim:
            raise ValueError(
                f"qk_rope_head_dim ({qk_rope_head_dim}) must not exceed "
                f"head_dim ({head_dim}); nope_head_dim would be negative."
            )
        if num_key_value_heads < 1 or num_attention_heads % num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({num_attention_heads}) must be a positive "
                f"multiple of num_key_value_heads ({num_key_value_heads}) for MQA/GQA."
            )
        if use_moe and num_experts_per_tok > n_routed_experts:
            raise ValueError(
                f"num_experts_per_tok ({num_experts_per_tok}) must not exceed "
                f"n_routed_experts ({n_routed_experts}) when use_moe=True."
            )

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

        # Standard
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.initializer_range = initializer_range
        self.hidden_dropout = hidden_dropout

        # MLA attention
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.q_lora_rank = q_lora_rank
        self.head_dim = head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        # Derived: the non-rotary ("no position embedding") slice of each head.
        self.nope_head_dim = head_dim - qk_rope_head_dim
        self.o_lora_rank = o_lora_rank
        self.attention_dropout = attention_dropout
        self.rope_theta = rope_theta

        # DERF + XSA
        self.use_derf = use_derf
        self.use_xsa = use_xsa

        # MoE
        self.use_moe = use_moe
        self.moe_intermediate_size = moe_intermediate_size
        self.n_routed_experts = n_routed_experts
        self.n_shared_experts = n_shared_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.routed_scaling_factor = routed_scaling_factor
        self.num_hash_layers = num_hash_layers
        self.moe_aux_loss_coef = moe_aux_loss_coef
        # Copy so later mutation of the caller's list cannot alter this config.
        self.moe_layers = (
            list(moe_layers) if moe_layers is not None else list(range(num_hidden_layers))
        )

        # Hyper-Connections
        self.use_hyper_connections = use_hyper_connections
        self.hc_mult = hc_mult
        self.hc_sinkhorn_iters = hc_sinkhorn_iters
        self.hc_eps = hc_eps

        # Multi-Token Prediction
        self.num_nextn_predict_layers = num_nextn_predict_layers

        # Engram N-gram module
        self.use_engram = use_engram
        self.engram_compress_dim = engram_compress_dim
        self.engram_num_heads = engram_num_heads
        self.engram_table_size = engram_table_size
        self.engram_max_ngram = engram_max_ngram
        self.engram_gate_init_bias = engram_gate_init_bias

        # HRM-inspired refinement (experimental)
        self.use_hrm_refine = use_hrm_refine
        self.hrm_refine_steps = hrm_refine_steps
        self.hrm_refine_dim = hrm_refine_dim

        # v2 additions
        self.use_qk_norm = use_qk_norm
        self.zloss_coef = zloss_coef
        self.mtp_loss_weight = mtp_loss_weight
        self.use_value_embed = use_value_embed