# -*- coding: utf-8 -*-
"""HybriKo Model - Standalone for HuggingFace

A hybrid RNN-Attention language model optimized for Korean.
Uses a 2:1 ratio of RNN (Griffin) blocks to Attention blocks.
"""

import math
from typing import Optional, Dict, Any, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

try:
    from .configuration_hybridko import HybriKoConfig
except ImportError:
    from configuration_hybridko import HybriKoConfig


# ============================================================================
# Basic Layer Components
# ============================================================================

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Divides the input by the root-mean-square of its last dimension and
    multiplies by a learned per-feature weight (initialized to ones).

    Args:
        d_model: Size of the last (feature) dimension.
        eps: Constant added to the mean square before the square root.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` normalized over its last dimension and rescaled."""
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight


class GeGLU(nn.Module):
    """Gated GELU Feed-Forward Network.

    Computes ``w3(gelu(w1(x)) * w2(x))`` with three bias-free projections.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden (gated) feature size.
    """

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform to ``x``."""
        return self.w3(F.gelu(self.w1(x)) * self.w2(x))


def _affine_prefix_scan(
    a: torch.Tensor, c: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Inclusive scan of the affine maps ``h -> a_t * h + c_t`` along dim 1.

    Returns ``(prod, acc)`` where ``prod[:, t]`` is the product of
    ``a[:, 0..t]`` and ``acc[:, t]`` is ``h_t`` for a zero initial state.
    Uses log2(length) doubling steps built only from multiplies and adds, so
    small cumulative products underflow harmlessly towards zero instead of
    being divided by.
    """
    length = a.shape[1]
    prod, acc = a, c
    shift = 1
    while shift < length:
        # Combine each position with the partial result `shift` steps earlier.
        # Positions t < shift already cover [0, t]; they are padded with the
        # identity map (a = 1, c = 0) and therefore stay unchanged.
        prev_acc = F.pad(acc[:, :-shift], (0, 0, shift, 0))
        prev_prod = F.pad(prod[:, :-shift], (0, 0, shift, 0), value=1.0)
        acc = acc + prod * prev_acc  # must use `prod` before it is updated
        prod = prod * prev_prod
        shift *= 2
    return prod, acc


def parallel_scan_stable(
    a: torch.Tensor,
    c: torch.Tensor,
    chunk_size: int = 128,
    h0: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Parallel scan for linear recurrence h_t = a_t * h_{t-1} + c_t.

    Uses chunked computation with fp32 for numerical stability. Each chunk is
    solved with a division-free doubling scan and the last state of a chunk is
    carried into the next one.

    NOTE: an earlier formulation divided by ``cumprod(a) + 1e-8``; once the
    cumulative product inside a chunk fell below ~1e-8 the result collapsed
    towards zero instead of following the recurrence. Outputs therefore differ
    numerically from that formulation; re-validate checkpoints trained with it.

    Args:
        a: Decay coefficients, 3-D ``(batch, seq_len, d_model)``. Values are
            clamped to ``[1e-6, 1 - 1e-6]`` before use.
        c: Input contributions, same shape as ``a``.
        chunk_size: Number of time steps scanned at once (must be >= 1).
        h0: Optional initial state ``(batch, d_model)``; zeros when ``None``.

    Returns:
        All hidden states, shape ``(batch, seq_len, d_model)``, in ``a``'s
        original dtype.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    batch, seq_len, d_model = a.shape
    device = a.device
    orig_dtype = a.dtype

    # Use fp32 for numerical stability; the clamp is loop-invariant.
    a = torch.clamp(a.float(), min=1e-6, max=1.0 - 1e-6)
    c = c.float()

    if h0 is None:
        h_prev = torch.zeros(batch, d_model, device=device, dtype=torch.float32)
    else:
        h_prev = h0.float()

    chunks = []
    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)
        prod, acc = _affine_prefix_scan(a[:, start:end, :], c[:, start:end, :])

        # Contribution of the inputs plus the decayed carried-in state.
        h_chunk = acc + prod * h_prev.unsqueeze(1)
        chunks.append(h_chunk)

        # Last state of this chunk seeds the next chunk.
        h_prev = h_chunk[:, -1, :]

    if not chunks:
        # Empty sequence: nothing to scan.
        return torch.zeros(batch, seq_len, d_model, device=device, dtype=orig_dtype)

    return torch.cat(chunks, dim=1).to(orig_dtype)


class RGLRU(nn.Module):
    """Real-Gated Linear Recurrent Unit (Griffin/LFM2 style).

    Args:
        d_model: Feature size of inputs, hidden state and outputs.
        eps: Lower clamp applied to ``1 - a**2`` before the square root.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.d_model = d_model
        self.eps = eps

        self.input_proj = nn.Linear(d_model, d_model * 2)
        self.gate_proj = nn.Linear(d_model, d_model * 2)
        self.a_param = nn.Parameter(torch.zeros(d_model))
        self.out_proj = nn.Linear(d_model, d_model)

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialize the projections and draw ``a_param`` from U(-0.5, 0.5).

        When the unit lives inside ``HybriKoModel``, that model's
        ``apply(self._init_weights)`` later re-initializes every ``nn.Linear``;
        only ``a_param`` keeps the value set here.
        """
        nn.init.xavier_uniform_(self.input_proj.weight)
        nn.init.xavier_uniform_(self.gate_proj.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        nn.init.uniform_(self.a_param, -0.5, 0.5)

    def forward(
        self, x: torch.Tensor, h_prev: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the gated linear recurrence over a sequence.

        Args:
            x: Input of shape ``(batch, seq_len, d_model)``.
            h_prev: Optional initial hidden state ``(batch, d_model)``;
                a zero state is used when ``None``.

        Returns:
            ``(output, h)`` where ``output`` is the projected hidden sequence
            and ``h`` is the (unprojected) hidden state at the last time step.
        """
        # Input gating
        x_in, x_gate = self.input_proj(x).chunk(2, dim=-1)
        x_in = x_in * torch.sigmoid(x_gate)

        # Recurrent gating
        r, i = self.gate_proj(x).chunk(2, dim=-1)
        r = torch.sigmoid(r)
        i = torch.sigmoid(i)

        # Recurrence coefficients: per-channel base decay scaled by the gate r
        a_base = torch.sigmoid(F.softplus(self.a_param))
        a = a_base.unsqueeze(0).unsqueeze(0) * r
        sqrt_1_minus_a2 = torch.sqrt(torch.clamp(1 - a ** 2, min=self.eps))

        # c = sqrt(1 - a^2) * i * x_in
        c = sqrt_1_minus_a2 * i * x_in

        # Solve h_t = a_t * h_{t-1} + c_t, starting from h_prev when given
        h_seq = parallel_scan_stable(a, c, chunk_size=256, h0=h_prev)

        # Final hidden state
        h = h_seq[:, -1, :]

        return self.out_proj(h_seq), h


# ============================================================================
# Attention Components
# ============================================================================

class RotaryEmbedding(nn.Module):
    """Rotary Positional Embedding (RoPE).

    Args:
        d_head: Per-head feature size.
        max_seq_len: Currently unused; the cos/sin cache grows on demand.
    """

    def __init__(self, d_head: int, max_seq_len: int = 2048):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, d_head, 2).float() / d_head))
        self.register_buffer("inv_freq", inv_freq)
        # (cos, sin) tables; plain attribute, so it is not moved by .to()
        self._cache = None

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return float32 ``(cos, sin)`` tables for the first ``x.shape[2]`` positions.

        Each table has shape ``(1, 1, seq_len, 2 * len(inv_freq))``.
        """
        seq_len = x.shape[2]
        cache = self._cache
        stale = (
            cache is None
            or cache[0].shape[2] < seq_len
            or cache[0].device != x.device
        )
        if stale:
            # Positions are built in float32: fp16/bf16 cannot represent large
            # integer indices exactly, which would corrupt the rotation angles.
            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
            inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)
            freqs = torch.outer(t, inv_freq)
            emb = torch.cat([freqs, freqs], dim=-1)
            cache = (
                emb.cos().unsqueeze(0).unsqueeze(0),
                emb.sin().unsqueeze(0).unsqueeze(0),
            )
            self._cache = cache
        return cache[0][:, :, :seq_len], cache[1][:, :, :seq_len]


def apply_rope(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply Rotary Positional Embedding to input tensor.

    The last dimension of ``x`` is split into two halves that are rotated
    against each other using the first half of ``cos``/``sin``. The result is
    cast back to ``x``'s dtype so that half-precision queries/keys are not
    silently promoted to the dtype of the tables.
    """
    d_half = x.shape[-1] // 2
    x1, x2 = x[..., :d_half], x[..., d_half:]
    cos = cos[..., :d_half]
    sin = sin[..., :d_half]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return rotated.to(x.dtype)


class GQAttention(nn.Module):
    """Grouped Query Attention with RoPE.

    Args:
        d_model: Model feature size.
        n_heads: Number of query heads.
        n_kv_heads: Number of key/value heads shared across query heads.
        dropout: Dropout probability on attention weights (training only).

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``, if
            ``n_heads`` is not divisible by ``n_kv_heads``, or if the
            resulting head size is odd (RoPE rotates two equal halves).
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int = 8,
        n_kv_heads: int = 2,
        dropout: float = 0.0,
    ):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        if n_heads % n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({n_heads}) must be divisible by n_kv_heads ({n_kv_heads})"
            )
        if (d_model // n_heads) % 2 != 0:
            raise ValueError(
                f"head size ({d_model // n_heads}) must be even for RoPE"
            )

        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.d_head = d_model // n_heads
        self.scale = 1.0 / math.sqrt(self.d_head)
        self.dropout = dropout

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.v_proj = nn.Linear(d_model, n_kv_heads * self.d_head, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.rope = RotaryEmbedding(self.d_head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Causal self-attention over ``x`` of shape ``(B, L, d_model)``."""
        B, L, _ = x.shape

        q = self.q_proj(x).view(B, L, self.n_heads, self.d_head)
        k = self.k_proj(x).view(B, L, self.n_kv_heads, self.d_head)
        v = self.v_proj(x).view(B, L, self.n_kv_heads, self.d_head)

        # Transpose to [B, n_heads, L, d_head]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Apply RoPE
        cos, sin = self.rope(q)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        # Expand KV heads to match query heads
        n_rep = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(n_rep, dim=1)
        v = v.repeat_interleave(n_rep, dim=1)

        # Attention with causal mask (True above the diagonal = masked out)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        mask = torch.triu(torch.ones(L, L, device=q.device), diagonal=1).bool()
        attn = attn.masked_fill(mask, float("-inf"))
        attn = F.softmax(attn, dim=-1)

        if self.training and self.dropout > 0:
            attn = F.dropout(attn, p=self.dropout)

        out = (attn @ v).transpose(1, 2).contiguous()
        return self.o_proj(out.view(B, L, -1))


# ============================================================================
# Block Components
# ============================================================================

class GriffinBlock(nn.Module):
    """RNN-based block using RGLRU.

    Pre-norm residual block: RGLRU sub-layer followed by a GeGLU feed-forward
    sub-layer, each with dropout on the residual branch.
    """

    def __init__(self, d_model: int, ff_mult: int = 3, dropout: float = 0.0):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.rglru = RGLRU(d_model)
        self.norm2 = RMSNorm(d_model)
        self.ffn = GeGLU(d_model, d_model * ff_mult)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the recurrent and feed-forward sub-layers with residuals."""
        # The final hidden state returned by RGLRU is not used here.
        rnn_out, _ = self.rglru(self.norm1(x))
        x = x + self.dropout(rnn_out)
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x


class AttentionBlock(nn.Module):
    """Attention-based block using GQA.

    Pre-norm residual block: grouped-query attention followed by a GeGLU
    feed-forward sub-layer, each with dropout on the residual branch.

    NOTE: ``dropout`` is applied to the residual branches only; it is not
    forwarded to ``GQAttention``, whose attention-weight dropout stays at its
    default of 0.0.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int = 8,
        n_kv_heads: int = 2,
        ff_mult: int = 3,
        dropout: float = 0.0
    ):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = GQAttention(d_model, n_heads, n_kv_heads)
        self.norm2 = RMSNorm(d_model)
        self.ffn = GeGLU(d_model, d_model * ff_mult)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and feed-forward sub-layers with residuals."""
        x = x + self.dropout(self.attn(self.norm1(x)))
        x = x + self.dropout(self.ffn(self.norm2(x)))
        return x


# ============================================================================
# Main Model
# ============================================================================

class HybriKoModel(nn.Module):
    """HybriKo: Hybrid RNN-Attention Language Model.

    Uses a 2:1 ratio of RNN (Griffin) blocks to Attention blocks.
    - Layers 1, 2: GriffinBlock (RNN)
    - Layer 3: AttentionBlock
    - Pattern repeats...

    The config must provide ``vocab_size``, ``d_model``, ``n_layers``,
    ``n_heads``, ``n_kv_heads``, ``ff_mult`` and ``max_seq_len``; ``dropout``
    is optional and defaults to 0.0.
    """

    def __init__(self, config: HybriKoConfig):
        super().__init__()
        self.config = config
        self.gradient_checkpointing = False

        # Token embedding
        self.embed = nn.Embedding(config.vocab_size, config.d_model)

        # Embedding dropout
        dropout = getattr(config, 'dropout', 0.0)
        self.embed_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # Hybrid layers: 2 RNN : 1 Attention pattern
        self.layers = nn.ModuleList()
        for i in range(config.n_layers):
            if (i + 1) % 3 == 0:
                # Every 3rd layer is Attention
                self.layers.append(
                    AttentionBlock(
                        config.d_model,
                        config.n_heads,
                        config.n_kv_heads,
                        config.ff_mult,
                        dropout=dropout
                    )
                )
            else:
                # RNN blocks
                self.layers.append(
                    GriffinBlock(config.d_model, config.ff_mult, dropout=dropout)
                )

        # Final normalization and LM head
        self.norm = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying
        self.lm_head.weight = self.embed.weight

        # Initialize weights (re-initializes every Linear/Embedding submodule)
        self.apply(self._init_weights)

        # Print parameter count
        n_params = sum(p.numel() for p in self.parameters())
        print(f"HybriKo: {n_params:,} parameters ({n_params/1e6:.1f}M)")

    def _init_weights(self, m: nn.Module) -> None:
        """Initialize Linear/Embedding weights from N(0, 0.02) and zero biases."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        label_smoothing: float = 0.0
    ) -> Dict[str, Any]:
        """Forward pass.

        Args:
            input_ids: Token ids, indexed as ``(batch, seq_len)``.
            labels: Optional target ids with the same layout. Logits at
                position ``t`` are scored against ``labels[:, t + 1]``;
                entries equal to ``-100`` are ignored.
            label_smoothing: Label smoothing passed to the cross-entropy loss.

        Returns:
            Dict with ``"logits"`` and ``"loss"`` (``None`` without labels).
        """
        x = self.embed(input_ids)
        x = self.embed_dropout(x)

        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Recompute the layer in backward instead of storing activations
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)

        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits[:, :-1].contiguous().view(-1, self.config.vocab_size),
                labels[:, 1:].contiguous().view(-1),
                ignore_index=-100,
                label_smoothing=label_smoothing,
            )

        return {"logits": logits, "loss": loss}

    @staticmethod
    def _filter_logits(
        logits: torch.Tensor,
        top_k: Optional[int],
        top_p: Optional[float],
    ) -> torch.Tensor:
        """Mask logits outside the top-k / nucleus (top-p) set with ``-inf``."""
        if top_k is not None and top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_value = torch.topk(logits, k).values[:, [-1]]
            logits = logits.masked_fill(logits < kth_value, float("-inf"))

        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_remove = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept,
            # and always keep the most likely token.
            sorted_remove[:, 1:] = sorted_remove[:, :-1].clone()
            sorted_remove[:, 0] = False
            # Map the mask from sorted order back to vocabulary order.
            remove = torch.zeros_like(sorted_remove).scatter(
                1, sorted_indices, sorted_remove
            )
            logits = logits.masked_fill(remove, float("-inf"))

        return logits

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 0.8,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> torch.Tensor:
        """Generate text tokens.

        Args:
            input_ids: Prompt token ids ``(batch, seq_len)``.
            max_new_tokens: Number of tokens to append.
            temperature: Softmax temperature. ``0`` selects greedy decoding
                (arg-max), in which case ``top_k``/``top_p`` are not applied.
            top_k: Keep only the ``top_k`` most likely tokens; ``None`` or a
                non-positive value disables the filter.
            top_p: Nucleus sampling threshold; ``None`` disables the filter.

        Returns:
            ``input_ids`` with the generated tokens appended along dim 1.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the last max_seq_len tokens are fed to the model.
                idx = input_ids[:, -self.config.max_seq_len:]
                logits = self(idx)["logits"][:, -1]

                if temperature == 0:
                    # Dividing by zero would yield inf/NaN logits; decode greedily.
                    next_token = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    logits = self._filter_logits(logits / temperature, top_k, top_p)
                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, 1)

                input_ids = torch.cat([input_ids, next_token], dim=1)
        finally:
            # Restore the caller's train/eval mode.
            self.train(was_training)

        return input_ids