"""FinanceDecoder — from-scratch transformer for the v2 frontier trading model. Design choices, all v2-doc-aligned: - Pure decoder transformer with GQA (Grouped-Query Attention), v2 §2.2. - RMSNorm, SwiGLU MLP (Llama/Mistral lineage; cheap, well-understood). - RoPE positional encoding (no learned PE). - Multi-Token Prediction heads (v2 §2.3): predict next K tokens jointly during pretraining. K=4 by default. Doubles as draft model for speculative decoding at inference. - Decision head (Invention A): a small classifier head reading the final hidden state and predicting (action ∈ {BUY,SELL,HOLD,NO_TRADE}). Active only when the row is a v52 SFT row that carries a labeled action. - bf16-friendly throughout. No HF Transformers. No pretrained weights. - Pre-norm. SwiGLU. weight tying between embedding and lm_head. Param budgets (verified analytically): 1B config: d=2048 n_layer=22 n_head=16 n_kv=4 d_ff=5504 v=32000 -> ~1.05B 350M smoke: d=1024 n_layer=24 n_head=16 n_kv=4 d_ff=2816 v=32000 -> ~352M 50M proxy: d=512 n_layer=12 n_head=8 n_kv=2 d_ff=1408 v=32000 -> ~46M Why not 2B (the original v2 spec): user's compute budget is $10 of Vast 4090 spot. Inference-aware scaling says smaller-trained-longer beats bigger-trained- shorter when you'll query the model heavily. 1B with curated 25B tokens and inventions A/B is a better Pareto point than 2B with 40B tokens. """ from __future__ import annotations from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F @dataclass class DecoderConfig: vocab_size: int = 32_000 d_model: int = 2048 n_layer: int = 22 n_head: int = 16 n_kv_head: int = 4 # GQA: queries / kv ratio = n_head / n_kv_head d_ff: int = 5504 # SwiGLU intermediate max_seq_len: int = 4096 rope_base: float = 10_000.0 rms_eps: float = 1e-5 dropout: float = 0.0 # Invention A — decision head decision_head_classes: int = 4 # BUY, SELL, HOLD, NO_TRADE decision_head_dropout: float = 0.0 # MTP — Multi-Token Prediction mtp_k: int = 4 # predict 4 tokens ahead (incl. the next) # Per-head architecture for the K-1 extra MTP predictors. # "block" — full DecoderBlock + RMSNorm (original behavior). # "mlp" — RMSNorm -> Linear -> SwiGLU-equivalent -> Linear -> RMSNorm # (~1/8 the cost of a DecoderBlock at typical d_ff ratios). # "linear" — RMSNorm -> Linear. Cheapest; tests how much MTP capacity # actually matters end-to-end. # Default is "block" for backward-compatible byte-identical behavior; the # 1B factory function below overrides to "mlp" per audit §P0.6. mtp_head_kind: str = "block" # Training hyperparameters carried with the config so checkpoints # are self-describing. init_std: float = 0.02 tie_word_embeddings: bool = True # Compact-capacity FFN experiments. # "swiglu" - dense baseline. # "ternary_swiglu" - BitNet-style ternary forward weights # with straight-through gradients. # "lowrank_swiglu" - factorized SwiGLU projections. # "routed_lowrank_swiglu" - sparse expert capsules: a router mixes # several low-rank SwiGLU experts. ffn_kind: str = "swiglu" ffn_rank: int = 128 ffn_experts: int = 4 ffn_top_k: int = 1 # Scratch chart patch encoder. This is a native raw-pixel path, not a # borrowed vision base: chart images are projected into prefix tokens that # text tokens can attend to through the causal decoder. chart_patch_encoder_enabled: bool = False chart_image_size: int = 224 chart_patch_size: int = 32 chart_channels: int = 3 chart_embed_dropout: float = 0.0 def config_1b() -> DecoderConfig: """Target architecture — 1.0B params. Audit §P0.6: the 1B config defaults to the cheap "mlp" MTP head kind (8-20% throughput gain vs full DecoderBlock heads). The smoke/proxy factories below keep the original "block" kind so byte-identical tests that depend on the legacy MTP wiring still hold. """ return DecoderConfig( d_model=2048, n_layer=22, n_head=16, n_kv_head=4, d_ff=5504, vocab_size=32_000, max_seq_len=4096, mtp_k=4, mtp_head_kind="mlp", ) def config_350m_smoke() -> DecoderConfig: """Mid-size smoke for actual training experiments.""" return DecoderConfig( d_model=1024, n_layer=24, n_head=16, n_kv_head=4, d_ff=2816, vocab_size=32_000, max_seq_len=2048, mtp_k=4, ) def config_50m_proxy() -> DecoderConfig: """Tiny proxy for CPU-side correctness checks; trains in seconds.""" return DecoderConfig( d_model=512, n_layer=12, n_head=8, n_kv_head=2, d_ff=1408, vocab_size=32_000, max_seq_len=1024, mtp_k=4, ) # ---------------------------------------------------------------- primitives class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5) -> None: super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: # In fp32 for stability, cast back. orig_dtype = x.dtype x = x.float() norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) return (norm * self.weight).to(orig_dtype) def _rope_cache(seq_len: int, head_dim: int, base: float, device, dtype) -> tuple[torch.Tensor, torch.Tensor]: inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim)) t = torch.arange(seq_len, device=device, dtype=torch.float32) freqs = torch.outer(t, inv_freq) # (seq, head_dim/2) cos = freqs.cos().to(dtype) sin = freqs.sin().to(dtype) return cos, sin def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: # x: (B, n_head, T, head_dim) cos/sin: (T, head_dim/2) x1, x2 = x.chunk(2, dim=-1) rotated = torch.cat([-x2, x1], dim=-1) cos_full = torch.cat([cos, cos], dim=-1) # (T, head_dim) sin_full = torch.cat([sin, sin], dim=-1) return x * cos_full + rotated * sin_full class GQAttention(nn.Module): """Grouped-Query Attention with RoPE. n_head query heads share n_kv_head KV heads. At n_head=n_kv_head this is standard MHA; at n_kv_head=1 it's MQA. We use ratio 4 (e.g., 16q : 4kv). """ def __init__(self, cfg: DecoderConfig) -> None: super().__init__() assert cfg.d_model % cfg.n_head == 0 assert cfg.n_head % cfg.n_kv_head == 0 self.cfg = cfg self.head_dim = cfg.d_model // cfg.n_head self.n_head = cfg.n_head self.n_kv_head = cfg.n_kv_head self.n_rep = cfg.n_head // cfg.n_kv_head self.q_proj = nn.Linear(cfg.d_model, cfg.n_head * self.head_dim, bias=False) self.k_proj = nn.Linear(cfg.d_model, cfg.n_kv_head * self.head_dim, bias=False) self.v_proj = nn.Linear(cfg.d_model, cfg.n_kv_head * self.head_dim, bias=False) self.o_proj = nn.Linear(cfg.n_head * self.head_dim, cfg.d_model, bias=False) self.dropout = cfg.dropout def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor: B, T, D = x.shape q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2) q = _apply_rope(q, cos[:T], sin[:T]) k = _apply_rope(k, cos[:T], sin[:T]) # Expand kv to match q heads (GQA). if self.n_rep > 1: k = k.repeat_interleave(self.n_rep, dim=1) v = v.repeat_interleave(self.n_rep, dim=1) # Use SDPA for FlashAttention-style speedup when on CUDA. out = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=(attn_mask is None), ) out = out.transpose(1, 2).contiguous().view(B, T, D) return self.o_proj(out) class SwiGLU(nn.Module): def __init__(self, cfg: DecoderConfig) -> None: super().__init__() self.w1 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False) self.w2 = nn.Linear(cfg.d_ff, cfg.d_model, bias=False) self.w3 = nn.Linear(cfg.d_model, cfg.d_ff, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) class TernaryLinear(nn.Module): """Linear layer with ternary forward weights and full-precision masters. This is a training-time simulation of the BitNet-style storage target: weights are quantized to {-scale, 0, +scale} in the forward pass, while gradients flow through the full-precision master weights via a straight- through estimator. Export tooling can later pack the ternary weights. """ def __init__(self, in_features: int, out_features: int) -> None: super().__init__() self.weight = nn.Parameter(torch.empty(out_features, in_features)) def forward(self, x: torch.Tensor) -> torch.Tensor: scale = self.weight.detach().abs().mean(dim=1, keepdim=True).clamp_min(1e-6) scaled = self.weight / scale q = scaled.round().clamp_(-1.0, 1.0) q_st = (q - scaled).detach() + scaled return F.linear(x, q_st * scale) class TernarySwiGLU(nn.Module): """SwiGLU where all three projections use ternary forward weights.""" def __init__(self, cfg: DecoderConfig) -> None: super().__init__() self.w1 = TernaryLinear(cfg.d_model, cfg.d_ff) self.w2 = TernaryLinear(cfg.d_ff, cfg.d_model) self.w3 = TernaryLinear(cfg.d_model, cfg.d_ff) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) class LowRankLinear(nn.Module): """Factorized linear projection: in -> rank -> out.""" def __init__(self, in_features: int, out_features: int, rank: int) -> None: super().__init__() if rank <= 0: raise ValueError(f"rank must be positive, got {rank}") self.in_features = in_features self.out_features = out_features self.rank = min(rank, in_features, out_features) self.down = nn.Linear(in_features, self.rank, bias=False) self.up = nn.Linear(self.rank, out_features, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.up(self.down(x)) class LowRankSwiGLU(nn.Module): """Parameter-efficient SwiGLU using low-rank projection factors.""" def __init__(self, cfg: DecoderConfig) -> None: super().__init__() rank = cfg.ffn_rank self.w1 = LowRankLinear(cfg.d_model, cfg.d_ff, rank) self.w2 = LowRankLinear(cfg.d_ff, cfg.d_model, rank) self.w3 = LowRankLinear(cfg.d_model, cfg.d_ff, rank) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) class RoutedLowRankSwiGLU(nn.Module): """Routed low-rank expert FFN. This keeps the active projection path compact while giving each layer a bank of specialist capsules. It is intentionally simple for proxy tournaments: all low-rank experts are evaluated, then a top-k masked router mixes them. The low-rank experts keep total parameters below a comparable dense MoE and make the storage trade-off explicit. """ def __init__(self, cfg: DecoderConfig) -> None: super().__init__() if cfg.ffn_experts <= 0: raise ValueError(f"ffn_experts must be positive, got {cfg.ffn_experts}") self.n_experts = int(cfg.ffn_experts) self.top_k = max(1, min(int(cfg.ffn_top_k), self.n_experts)) self.router = nn.Linear(cfg.d_model, self.n_experts, bias=False) self.experts = nn.ModuleList([LowRankSwiGLU(cfg) for _ in range(self.n_experts)]) def forward(self, x: torch.Tensor) -> torch.Tensor: route_logits = self.router(x) if self.top_k < self.n_experts: top_vals, top_idx = torch.topk(route_logits, self.top_k, dim=-1) masked = route_logits.new_full(route_logits.shape, float("-inf")) masked.scatter_(-1, top_idx, top_vals) route_logits = masked weights = torch.softmax(route_logits, dim=-1) mixed = None for i, expert in enumerate(self.experts): y = expert(x) * weights[..., i].unsqueeze(-1) mixed = y if mixed is None else mixed + y if mixed is None: # pragma: no cover - constructor guards this. raise RuntimeError("RoutedLowRankSwiGLU has no experts") return mixed def _build_ffn(cfg: DecoderConfig) -> nn.Module: kind = cfg.ffn_kind if kind == "swiglu": return SwiGLU(cfg) if kind == "ternary_swiglu": return TernarySwiGLU(cfg) if kind == "lowrank_swiglu": return LowRankSwiGLU(cfg) if kind == "routed_lowrank_swiglu": return RoutedLowRankSwiGLU(cfg) raise ValueError( f"unknown ffn_kind={kind!r}; expected one of " "'swiglu', 'ternary_swiglu', 'lowrank_swiglu', " "'routed_lowrank_swiglu'" ) class DecoderBlock(nn.Module): def __init__(self, cfg: DecoderConfig) -> None: super().__init__() self.attn_norm = RMSNorm(cfg.d_model, cfg.rms_eps) self.attn = GQAttention(cfg) self.mlp_norm = RMSNorm(cfg.d_model, cfg.rms_eps) self.mlp = _build_ffn(cfg) def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor: x = x + self.attn(self.attn_norm(x), cos, sin, attn_mask) x = x + self.mlp(self.mlp_norm(x)) return x # ---------------------------------------------------------------- chart encoder class ChartPatchEncoder(nn.Module): """Scratch chart-image patch encoder. Images are projected with a strided Conv2d into patch tokens, then those tokens are prepended to the decoder stream. The decoder remains causal: chart patches come first, so text/action tokens can attend to chart evidence while the chart tokens do not peek at future action text. """ def __init__(self, cfg: DecoderConfig) -> None: super().__init__() if cfg.chart_image_size <= 0: raise ValueError("chart_image_size must be positive") if cfg.chart_patch_size <= 0: raise ValueError("chart_patch_size must be positive") if cfg.chart_image_size % cfg.chart_patch_size != 0: raise ValueError( "chart_image_size must be divisible by chart_patch_size " f"({cfg.chart_image_size} vs {cfg.chart_patch_size})" ) self.cfg = cfg self.grid = cfg.chart_image_size // cfg.chart_patch_size self.n_patches = self.grid * self.grid self.proj = nn.Conv2d( cfg.chart_channels, cfg.d_model, kernel_size=cfg.chart_patch_size, stride=cfg.chart_patch_size, bias=False, ) self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches, cfg.d_model)) self.type_embed = nn.Parameter(torch.zeros(1, 1, cfg.d_model)) self.norm = RMSNorm(cfg.d_model, cfg.rms_eps) self.dropout = nn.Dropout(cfg.chart_embed_dropout) nn.init.normal_(self.pos_embed, mean=0.0, std=cfg.init_std) nn.init.normal_(self.type_embed, mean=0.0, std=cfg.init_std) def forward( self, images: torch.Tensor, image_mask: torch.Tensor | None = None, ) -> torch.Tensor: if images.ndim != 4: raise ValueError(f"chart images must be (B,C,H,W), got {tuple(images.shape)}") if images.shape[1] != self.cfg.chart_channels: raise ValueError( f"expected {self.cfg.chart_channels} image channels, got {images.shape[1]}" ) if images.shape[-2:] != (self.cfg.chart_image_size, self.cfg.chart_image_size): raise ValueError( "chart image tensor shape does not match config: " f"{tuple(images.shape[-2:])} vs " f"({self.cfg.chart_image_size}, {self.cfg.chart_image_size})" ) x = images.to(dtype=self.proj.weight.dtype) if x.max().detach() > 2.0: x = x / 255.0 x = (x - 0.5) / 0.5 patches = self.proj(x).flatten(2).transpose(1, 2) patches = patches + self.pos_embed.to(dtype=patches.dtype) + self.type_embed.to(dtype=patches.dtype) patches = self.dropout(self.norm(patches)) if image_mask is not None: patches = patches * image_mask.to(dtype=patches.dtype).view(-1, 1, 1) return patches # ---------------------------------------------------------------- MTP heads class MTPHeadMLP(nn.Module): """Cheap MLP MTP head — ~1/8 the cost of a full DecoderBlock at d_ff=2.7d. Architecture (per audit §P0.6): RMSNorm -> Linear(d, d) -> SwiGLU-equivalent -> Linear(d, d) -> RMSNorm The SwiGLU-equivalent uses two parallel Linears (w_gate, w_up) and the silu gating pattern `silu(w_gate(x)) * w_up(x)`, matching the trunk's SwiGLU but at hidden = d_model (no d_ff expansion). """ def __init__(self, cfg: DecoderConfig) -> None: super().__init__() d = cfg.d_model self.norm_in = RMSNorm(d, cfg.rms_eps) self.w_in = nn.Linear(d, d, bias=False) self.w_gate = nn.Linear(d, d, bias=False) self.w_up = nn.Linear(d, d, bias=False) self.w_out = nn.Linear(d, d, bias=False) self.norm_out = RMSNorm(d, cfg.rms_eps) def forward(self, x: torch.Tensor) -> torch.Tensor: h = self.w_in(self.norm_in(x)) h = F.silu(self.w_gate(h)) * self.w_up(h) h = self.w_out(h) return self.norm_out(h) class MTPHeadLinear(nn.Module): """Cheapest MTP head — RMSNorm -> Linear(d, d). One projection plus one norm. Useful as an ablation: how much extra MTP capacity actually matters? """ def __init__(self, cfg: DecoderConfig) -> None: super().__init__() self.norm = RMSNorm(cfg.d_model, cfg.rms_eps) self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.proj(self.norm(x)) # ---------------------------------------------------------------- main model @dataclass class ForwardOutputs: """Bundle of all heads' outputs.""" logits: torch.Tensor # (B, T, V) — standard next-token at K=1 mtp_logits: list[torch.Tensor] # K-1 extra heads for K=2..mtp_k decision_logits: torch.Tensor | None # (B, n_classes) when decision head is queried hidden_states: torch.Tensor # (B, T, D) for downstream heads / inspection class FinanceDecoder(nn.Module): """The frontier trading model — decoder + MTP + decision head.""" def __init__(self, cfg: DecoderConfig) -> None: super().__init__() self.cfg = cfg self.embed = nn.Embedding(cfg.vocab_size, cfg.d_model) self.chart_patch_encoder = ( ChartPatchEncoder(cfg) if cfg.chart_patch_encoder_enabled else None ) self.blocks = nn.ModuleList([DecoderBlock(cfg) for _ in range(cfg.n_layer)]) self.final_norm = RMSNorm(cfg.d_model, cfg.rms_eps) # Main lm_head — tied to embedding by default. self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False) if cfg.tie_word_embeddings: self.lm_head.weight = self.embed.weight # MTP heads — v2 §2.3. Each is a small predictor for the t+k token. # The architecture is selected by `cfg.mtp_head_kind`: # "block" — original full DecoderBlock + RMSNorm path. Byte-identical # to the pre-§P0.6 code (same module names, same construction # order, same forward computation). # "mlp" — cheap 2-layer MLP head (MTPHeadMLP). ~1/8 the cost of # a DecoderBlock at typical d_ff ratios. The K-1 heads are # stored on `self.mtp_heads`; `mtp_blocks`/`mtp_norms` # remain ModuleLists for back-compat with helpers that # expect those attributes, but are empty in this mode. # "linear" — cheapest: RMSNorm + Linear (MTPHeadLinear). kind = cfg.mtp_head_kind n_mtp = max(cfg.mtp_k - 1, 0) if kind == "block": # Legacy path. Keep attribute names + construction order intact so # state dicts from the pre-change codebase load unchanged. self.mtp_blocks = nn.ModuleList( [DecoderBlock(cfg) for _ in range(n_mtp)] ) self.mtp_norms = nn.ModuleList( [RMSNorm(cfg.d_model, cfg.rms_eps) for _ in range(n_mtp)] ) self.mtp_heads = nn.ModuleList() elif kind == "mlp": self.mtp_heads = nn.ModuleList( [MTPHeadMLP(cfg) for _ in range(n_mtp)] ) self.mtp_blocks = nn.ModuleList() self.mtp_norms = nn.ModuleList() elif kind == "linear": self.mtp_heads = nn.ModuleList( [MTPHeadLinear(cfg) for _ in range(n_mtp)] ) self.mtp_blocks = nn.ModuleList() self.mtp_norms = nn.ModuleList() else: raise ValueError( f"unknown mtp_head_kind={kind!r}; expected one of " f"'block', 'mlp', 'linear'" ) # Invention A — decision head. self.decision_norm = RMSNorm(cfg.d_model, cfg.rms_eps) self.decision_head = nn.Sequential( nn.Linear(cfg.d_model, cfg.d_model // 2, bias=False), nn.GELU(), nn.Dropout(cfg.decision_head_dropout), nn.Linear(cfg.d_model // 2, cfg.decision_head_classes, bias=False), ) # Persistent RoPE cache (allocated lazily on the right device). self._rope_cos: torch.Tensor | None = None self._rope_sin: torch.Tensor | None = None self.apply(self._init_weights) def _init_weights(self, m: nn.Module) -> None: if isinstance(m, nn.Linear): nn.init.normal_(m.weight, mean=0.0, std=self.cfg.init_std) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Conv2d): nn.init.normal_(m.weight, mean=0.0, std=self.cfg.init_std) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, TernaryLinear): nn.init.normal_(m.weight, mean=0.0, std=self.cfg.init_std) elif isinstance(m, nn.Embedding): nn.init.normal_(m.weight, mean=0.0, std=self.cfg.init_std) def num_parameters(self) -> int: n = sum(p.numel() for p in self.parameters()) if self.cfg.tie_word_embeddings: # Counted once already; tied lm_head shares memory. return n return n def _get_rope(self, T: int, device, dtype) -> tuple[torch.Tensor, torch.Tensor]: head_dim = self.cfg.d_model // self.cfg.n_head if self._rope_cos is None or self._rope_cos.shape[0] < T or self._rope_cos.device != device: self._rope_cos, self._rope_sin = _rope_cache( max(T, self.cfg.max_seq_len), head_dim, self.cfg.rope_base, device, dtype ) return self._rope_cos, self._rope_sin def _input_embeddings( self, input_ids: torch.Tensor, chart_images: torch.Tensor | None = None, chart_image_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, int]: text_x = self.embed(input_ids) if self.chart_patch_encoder is None or chart_images is None: return text_x, 0 if chart_image_mask is not None and not bool(chart_image_mask.any().item()): return text_x, 0 chart_x = self.chart_patch_encoder(chart_images, chart_image_mask) if chart_x.shape[0] != text_x.shape[0]: raise ValueError( f"chart batch size {chart_x.shape[0]} != text batch size {text_x.shape[0]}" ) return torch.cat([chart_x.to(dtype=text_x.dtype), text_x], dim=1), int(chart_x.shape[1]) def _apply_mtp_heads( self, h: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ) -> list[torch.Tensor]: """Run the K-1 MTP heads on the trunk's final hidden state. Returns a list of `(B, T, V)` logits tensors — one per extra head. Output shape and structure are identical across all `mtp_head_kind` values so downstream loss code stays unchanged. The "block" path mutates `h_k` with a residual at each step (legacy accumulating behavior). The cheap "mlp" / "linear" heads are stateless w.r.t. each other — each reads the trunk's final hidden state h directly. This is intentional: those heads are meant to be thin readouts, not to chain reasoning. """ kind = self.cfg.mtp_head_kind mtp_logits: list[torch.Tensor] = [] if kind == "block": h_k = h for blk, norm in zip(self.mtp_blocks, self.mtp_norms): h_k = blk(h_k, cos, sin) + h_k mtp_logits.append(self.lm_head(norm(h_k))) else: for head in self.mtp_heads: mtp_logits.append(self.lm_head(head(h))) return mtp_logits def forward( self, input_ids: torch.Tensor, # (B, T) int64 decision_query_index: torch.Tensor | None = None, # (B,) which position to read for decision; default last non-pad return_decision: bool = False, chart_images: torch.Tensor | None = None, chart_image_mask: torch.Tensor | None = None, ) -> ForwardOutputs: B, T = input_ids.shape x, text_offset = self._input_embeddings(input_ids, chart_images, chart_image_mask) total_T = x.shape[1] cos, sin = self._get_rope(total_T, x.device, x.dtype) for blk in self.blocks: x = blk(x, cos, sin) h_all = self.final_norm(x) h = h_all[:, text_offset:, :] # (B, T, D) logits = self.lm_head(h) # (B, T, V) — predicts t+1 # MTP — each extra head predicts t+k for k=2..mtp_k. mtp_logits = self._apply_mtp_heads(h, cos[text_offset:], sin[text_offset:]) decision_logits = None if return_decision: if decision_query_index is None: # Use the last token's hidden state. idx = torch.full((B,), T - 1, device=input_ids.device, dtype=torch.long) else: idx = decision_query_index gather_idx = idx.view(B, 1, 1).expand(B, 1, h.shape[-1]) pooled = h.gather(1, gather_idx).squeeze(1) # (B, D) decision_logits = self.decision_head(self.decision_norm(pooled)) return ForwardOutputs( logits=logits, mtp_logits=mtp_logits, decision_logits=decision_logits, hidden_states=h, )