import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GenerationMixin, PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast


# ============================================================
# Configuration
# ============================================================


class Neon213Config(PretrainedConfig):
    """Hyper-parameters for :class:`Neon213`.

    Args:
        vocab_size: Number of token ids (embedding rows / output logits).
        d_model: Width of the residual stream.
        n_layers: Number of stacked :class:`Block` modules.
        n_head: Number of attention heads; must divide ``d_model`` and give
            an even per-head size (required by the rotary embedding).
        d_ff: Hidden width of :class:`GrowableHydraMLP`.
        block_size: Maximum sequence length; sizes the RoPE tables.
        conv_k: Kernel size of the causal depthwise convs in attention.
        mlp_k: Kernel size of the causal depthwise conv in the MLP gate.
        **kwargs: Forwarded to :class:`transformers.PretrainedConfig`.
    """

    model_type = "neon213"

    def __init__(
        self,
        vocab_size=16384,
        d_model=384,
        n_layers=8,
        n_head=6,
        d_ff=1536,
        block_size=256,
        conv_k=9,
        mlp_k=9,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_head = n_head
        self.d_ff = d_ff
        self.block_size = block_size
        self.conv_k = conv_k
        self.mlp_k = mlp_k
        super().__init__(**kwargs)


# ============================================================
# Utilities
# ============================================================


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension, with a learned scale."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Accumulate the mean square in float32: in fp16 x.pow(2) overflows
        # for |x| > ~256, and low-precision reductions lose accuracy. For
        # float32 inputs this is identical to normalizing in place.
        x_fp32 = x.float()
        normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed.type_as(x) * self.weight


def apply_rotary_emb(x, freqs_cos, freqs_sin):
    """Apply rotary position embeddings to ``x``.

    Args:
        x: Tensor of shape (B, T, n_head, head_dim) with an even head_dim
            (this is how :class:`GrowableConvAttention` calls it).
        freqs_cos, freqs_sin: Tables of shape (>= T, head_dim // 2); row t
            holds cos/sin of the rotation angles for position t.

    Returns:
        Tensor of the same shape as ``x``. Each pair
        (x[..., 2i], x[..., 2i + 1]) is rotated by its angle, but the result
        is laid out half-split (all rotated "even" components, then all
        rotated "odd" ones). Because q and k receive the same permutation,
        their dot products are unaffected.
    """
    d = x.shape[-1]
    seq_len = x.shape[1]
    x_even, x_odd = x[..., :d:2], x[..., 1:d:2]
    # Broadcast the per-position tables over batch and head dimensions.
    cos = freqs_cos[:seq_len].view(1, seq_len, 1, -1)
    sin = freqs_sin[:seq_len].view(1, seq_len, 1, -1)
    return torch.cat([x_even * cos - x_odd * sin, x_even * sin + x_odd * cos], dim=-1)


# ============================================================
# Components
# ============================================================


class GrowableConvAttention(nn.Module):
    """Causal self-attention with conv-mixed projections and a sigmoid output gate.

    A single linear layer produces q, k, v and an "intent" gate. Each of them
    is passed through a causal depthwise Conv1d (left padding only, so
    position t only sees positions <= t). q and k are RMS-normalized per head
    and rotated with RoPE. The attention output is multiplied element-wise by
    sigmoid(intent) before the output projection.
    """

    def __init__(self, config):
        super().__init__()
        d_model = config.d_model
        if d_model % config.n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({config.n_head})"
            )
        self.n_head = config.n_head
        self.head_dim = d_model // config.n_head
        if self.head_dim % 2 != 0:
            raise ValueError(
                f"head_dim (d_model // n_head = {self.head_dim}) must be even "
                "for rotary embeddings"
            )
        self.k = config.conv_k

        self.c_attn = nn.Linear(d_model, 4 * d_model, bias=False)
        # Depthwise (groups == channels) temporal convs, one per stream.
        # Attribute names are kept as-is for state-dict compatibility.
        self.conv_q = nn.Conv1d(d_model, d_model, kernel_size=self.k, groups=d_model, bias=False)
        self.conv_k = nn.Conv1d(d_model, d_model, kernel_size=self.k, groups=d_model, bias=False)
        self.conv_v = nn.Conv1d(d_model, d_model, kernel_size=self.k, groups=d_model, bias=False)
        self.conv_i = nn.Conv1d(d_model, d_model, kernel_size=self.k, groups=d_model, bias=False)
        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)
        self.c_proj = nn.Linear(d_model, d_model, bias=False)

    def _causal_conv(self, conv, t):
        """Apply ``conv`` along the time axis of ``t`` (B, T, C), left-padded so it is causal."""
        pad = self.k - 1
        return conv(F.pad(t.transpose(1, 2), (pad, 0))).transpose(1, 2)

    def forward(self, x, f_cos, f_sin):
        """Attend over ``x`` of shape (B, T, d_model); returns the same shape."""
        B, T, C = x.shape
        q, k, v, intent = self.c_attn(x).split(C, dim=2)

        q = self._causal_conv(self.conv_q, q)
        k = self._causal_conv(self.conv_k, k)
        v = self._causal_conv(self.conv_v, v)
        intent = self._causal_conv(self.conv_i, intent)

        q = q.view(B, T, self.n_head, self.head_dim)
        k = k.view(B, T, self.n_head, self.head_dim)
        v = v.view(B, T, self.n_head, self.head_dim)
        intent = intent.view(B, T, self.n_head, self.head_dim)

        q, k = self.q_norm(q), self.k_norm(k)
        q = apply_rotary_emb(q, f_cos, f_sin)
        k = apply_rotary_emb(k, f_cos, f_sin)

        # SDPA expects (B, n_head, T, head_dim).
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        intent = intent.transpose(1, 2)

        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = torch.sigmoid(intent) * attn_out
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)


class GrowableHydraMLP(nn.Module):
    """Gated MLP whose gate is computed from a causally conv-mixed copy of the input.

    Output is ``w2(silu(c_gate_proj(conv(x))) * w1(x))``, where ``conv`` is a
    causal depthwise Conv1d over time.
    """

    def __init__(self, config):
        super().__init__()
        d_model = config.d_model
        d_ff = config.d_ff
        self.k = config.mlp_k
        self.conv_gate = nn.Conv1d(d_model, d_model, kernel_size=self.k, groups=d_model, bias=False)
        self.c_gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        """Transform ``x`` of shape (B, T, d_model); returns the same shape."""
        pad = self.k - 1
        # Left-pad only, so the gate at position t sees inputs <= t.
        c = self.conv_gate(F.pad(x.transpose(1, 2), (pad, 0))).transpose(1, 2)
        gate = F.silu(self.c_gate_proj(c))
        return self.w2(gate * self.w1(x))


class Block(nn.Module):
    """Pre-norm residual block: attention then MLP, each behind its own RMSNorm."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = RMSNorm(config.d_model)
        self.attn = GrowableConvAttention(config)
        self.ln2 = RMSNorm(config.d_model)
        self.mlp = GrowableHydraMLP(config)

    def forward(self, x, f_cos, f_sin):
        x = x + self.attn(self.ln1(x), f_cos, f_sin)
        x = x + self.mlp(self.ln2(x))
        return x


# ============================================================
# Main Model: Neon213
# ============================================================


class Neon213(PreTrainedModel, GenerationMixin):
    """Decoder-only causal LM built from :class:`Block` layers, with a tied output head.

    GenerationMixin is listed explicitly. From transformers v4.50,
    PreTrainedModel no longer provides ``.generate()`` on its own.
    """

    config_class = Neon213Config
    base_model_prefix = "neon213"

    def __init__(self, config):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layers)])
        self.ln_f = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # RoPE tables for positions [0, block_size), each (block_size, head_dim // 2).
        dim = config.d_model // config.n_head
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(config.block_size).float()
        freqs = torch.outer(t, inv_freq)
        self.register_buffer("freqs_cos", torch.cos(freqs), persistent=True)
        self.register_buffer("freqs_sin", torch.sin(freqs), persistent=True)

        self.post_init()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_input_embeddings(self):
        return self.token_emb

    def set_input_embeddings(self, value):
        self.token_emb = value

    def get_output_embeddings(self):
        return self.head

    def set_output_embeddings(self, value):
        self.head = value

    def tie_weights(self, **kwargs):
        # Hugging Face style weight tying: the output head shares the
        # embedding matrix. This ties unconditionally and does not consult
        # config.tie_word_embeddings.
        self.head.weight = self.token_emb.weight

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the LM on ``input_ids`` of shape (B, T).

        Args:
            input_ids: Token ids, shape (B, T) with T <= block_size.
            attention_mask: Accepted for API compatibility but ignored. Only
                the causal mask is applied, so padded positions are attended.
            labels: Optional targets, same shape as ``input_ids``. They are
                compared position-by-position with the logits (no internal
                shift). Positions equal to -100 are ignored, which is the
                ``F.cross_entropy`` default ``ignore_index``.
            **kwargs: Extra arguments passed by ``transformers``, such as
                ``use_cache`` or ``past_key_values``, are ignored. There is
                no KV cache.

        Returns:
            CausalLMOutputWithPast with ``logits`` of shape (B, T, vocab_size)
            and ``loss`` set when ``labels`` is given.

        Raises:
            ValueError: if ``input_ids`` is None or longer than the RoPE tables.
        """
        if input_ids is None:
            raise ValueError("Neon213.forward requires input_ids")
        seq_len = input_ids.shape[1]
        max_len = self.freqs_cos.shape[0]
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds block_size {max_len}"
            )

        x = self.token_emb(input_ids)
        for block in self.blocks:
            x = block(x, self.freqs_cos, self.freqs_sin)
        logits = self.head(self.ln_f(x))

        loss = None
        if labels is not None:
            # NOTE(review): HF CausalLM convention is labels == input_ids with
            # the shift done inside the model. Here the labels are NOT shifted,
            # so callers must pass pre-shifted targets. Confirm this matches
            # the training pipeline.
            # Compute the loss in float32 so low-precision logits stay stable.
            loss = F.cross_entropy(
                logits.float().reshape(-1, self.config.vocab_size),
                labels.reshape(-1),
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No KV cache, so the full context is re-run at every step. Keep only
        # the most recent tokens that fit in the RoPE tables. Generation then
        # continues past block_size as a sliding window instead of raising
        # in forward.
        max_len = self.freqs_cos.shape[0]
        return {"input_ids": input_ids[:, -max_len:]}