# Provenance (Hugging Face Hub file page): neon213/neon213.py, uploaded by
# luozhangzichen via huggingface_hub, commit 7bdd25a, 7.62 kB.
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
# ============================================================
# Configuration
# ============================================================
class Neon213Config(PretrainedConfig):
    """Hyper-parameters for the Neon213 causal language model.

    Args:
        vocab_size: Number of token ids (size of the embedding table and of the
            LM head output).
        d_model: Width of the residual stream.
        n_layers: Number of residual blocks.
        n_head: Number of attention heads; each head is ``d_model // n_head`` wide.
        d_ff: Hidden width of the gated MLP.
        block_size: Number of positions for which rotary cos/sin tables are
            precomputed.
        conv_k: Kernel size of the causal depthwise convolutions in attention.
        mlp_k: Kernel size of the causal depthwise convolution in the MLP gate.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "neon213"

    def __init__(
        self,
        vocab_size=16384,
        d_model=384,
        n_layers=8,
        n_head=6,
        d_ff=1536,
        block_size=256,
        conv_k=9,
        mlp_k=9,
        **kwargs,
    ):
        # Model-specific fields are stored first (in this order), then the
        # base class consumes the generic Hugging Face options in ``kwargs``.
        model_fields = {
            "vocab_size": vocab_size,
            "d_model": d_model,
            "n_layers": n_layers,
            "n_head": n_head,
            "d_ff": d_ff,
            "block_size": block_size,
            "conv_k": conv_k,
            "mlp_k": mlp_k,
        }
        for field_name, field_value in model_fields.items():
            setattr(self, field_name, field_value)
        super().__init__(**kwargs)
# ============================================================
# Utilities
# ============================================================
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    Unlike LayerNorm there is no mean subtraction and no bias.

    Args:
        dim: Size of the last (normalised) dimension.
        eps: Added to the mean square before the inverse square root.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Divide ``x`` by the RMS of its last dimension, then scale by ``weight``."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalized = x * inv_rms
        return normalized * self.weight
def apply_rotary_emb(x, freqs_cos, freqs_sin):
    """Rotate channel pairs of ``x`` by position-dependent angles (rotary embedding).

    Dimension 1 of ``x`` is the position axis (the tables are viewed as
    ``(1, seq, 1, pairs)``). Channels ``2i`` and ``2i + 1`` form the i-th pair.

    The output is laid out as all rotated even channels followed by all rotated
    odd channels (two halves, not re-interleaved). Callers apply it identically
    to queries and keys, so their dot products are unaffected by the layout.

    Args:
        x: 4-D tensor; in this file ``(batch, seq, heads, head_dim)`` with an
            even ``head_dim``.
        freqs_cos: ``(max_seq, head_dim // 2)`` table of cosines.
        freqs_sin: ``(max_seq, head_dim // 2)`` table of sines.

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        ValueError: if ``x`` has more positions than the tables provide.
            Previously this produced an opaque view/broadcast error, or for some
            lengths a silently mis-broadcast result.
    """
    seq_len = x.shape[1]
    available = min(freqs_cos.shape[0], freqs_sin.shape[0])
    if seq_len > available:
        raise ValueError(
            f"sequence length {seq_len} exceeds the {available} positions "
            "covered by the rotary tables (config.block_size)"
        )
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    cos = freqs_cos[:seq_len].view(1, seq_len, 1, -1)
    sin = freqs_sin[:seq_len].view(1, seq_len, 1, -1)
    return torch.cat([x_even * cos - x_odd * sin, x_even * sin + x_odd * cos], dim=-1)
# ============================================================
# Components
# ============================================================
class GrowableConvAttention(nn.Module):
    """Causal multi-head self-attention with convolution-mixed projections and an output gate.

    A single linear layer produces four ``d_model``-wide streams: query, key,
    value and "intent". Each stream is smoothed over time by its own causal
    depthwise 1-D convolution. Queries and keys are RMS-normalised per head and
    rotary-encoded. The attention output is multiplied elementwise by
    ``sigmoid(intent)`` before the output projection.

    Submodule attribute names (``c_attn``, ``conv_q``, ...) determine the
    state_dict keys, and their creation order determines default-init RNG
    consumption; keep both stable.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.n_head = config.n_head
        self.head_dim = width // config.n_head
        self.k = config.conv_k  # convolution kernel size (unrelated to the key stream)

        def depthwise():
            # groups == channels: every channel gets its own temporal kernel.
            return nn.Conv1d(width, width, kernel_size=self.k, groups=width, bias=False)

        self.c_attn = nn.Linear(width, 4 * width, bias=False)
        self.conv_q = depthwise()
        self.conv_k = depthwise()
        self.conv_v = depthwise()
        self.conv_i = depthwise()
        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)
        self.c_proj = nn.Linear(width, width, bias=False)

    def _causal_conv(self, conv, seq):
        """Apply ``conv`` along time to a ``(B, T, C)`` tensor, padding on the left only.

        Left padding by ``kernel - 1`` keeps the output length at ``T`` and makes
        position ``t`` depend only on positions ``<= t``.
        """
        channels_first = seq.transpose(1, 2)
        padded = F.pad(channels_first, (self.k - 1, 0))
        return conv(padded).transpose(1, 2)

    def forward(self, x, f_cos, f_sin):
        """Attend over ``x`` (``(B, T, C)``, C presumably == d_model) and return ``(B, T, C)``.

        ``f_cos``/``f_sin`` are the rotary tables handed to ``apply_rotary_emb``.
        """
        B, T, C = x.shape
        heads, head_dim = self.n_head, self.head_dim
        streams = self.c_attn(x).split(C, dim=2)
        convs = (self.conv_q, self.conv_k, self.conv_v, self.conv_i)
        q, k, v, intent = (
            self._causal_conv(conv, stream).view(B, T, heads, head_dim)
            for conv, stream in zip(convs, streams)
        )
        q = apply_rotary_emb(self.q_norm(q), f_cos, f_sin)
        k = apply_rotary_emb(self.k_norm(k), f_cos, f_sin)
        # scaled_dot_product_attention works on (B, heads, T, head_dim).
        q, k, v, intent = (t.transpose(1, 2) for t in (q, k, v, intent))
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        gated = torch.sigmoid(intent) * attended
        merged = gated.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(merged)
class GrowableHydraMLP(nn.Module):
    """Gated feed-forward layer whose gate looks at a short causal window of tokens.

    Gate branch: a causal depthwise convolution mixes each channel over the
    current and previous ``mlp_k - 1`` positions, is projected to ``d_ff`` and
    passed through SiLU. Value branch: a ``d_model -> d_ff`` projection of the
    current token alone. Their product is projected back to ``d_model``.
    """

    def __init__(self, config):
        super().__init__()
        width, hidden = config.d_model, config.d_ff
        self.k = config.mlp_k
        self.conv_gate = nn.Conv1d(width, width, kernel_size=self.k, groups=width, bias=False)
        self.c_gate_proj = nn.Linear(width, hidden, bias=False)
        self.w1 = nn.Linear(width, hidden, bias=False)
        self.w2 = nn.Linear(hidden, width, bias=False)

    def forward(self, x):
        """Map ``(B, T, d_model)`` to ``(B, T, d_model)``; output t uses inputs <= t only."""
        left_pad = (self.k - 1, 0)
        windowed = self.conv_gate(F.pad(x.transpose(1, 2), left_pad)).transpose(1, 2)
        gate = F.silu(self.c_gate_proj(windowed))
        value = self.w1(x)
        return self.w2(gate * value)
class Block(nn.Module):
    """Pre-norm residual block: RMSNorm -> attention, then RMSNorm -> gated MLP."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = RMSNorm(config.d_model)
        self.attn = GrowableConvAttention(config)
        self.ln2 = RMSNorm(config.d_model)
        self.mlp = GrowableHydraMLP(config)

    def forward(self, x, f_cos, f_sin):
        """Apply both residual sub-layers; the rotary tables are passed to attention."""
        attn_update = self.attn(self.ln1(x), f_cos, f_sin)
        hidden = x + attn_update
        mlp_update = self.mlp(self.ln2(hidden))
        return hidden + mlp_update
# ============================================================
# Main Model: Neon213
# ============================================================
class Neon213(PreTrainedModel):
    """Neon213 causal language model.

    Token embedding -> ``n_layers`` x ``Block`` -> final RMSNorm -> linear LM head.
    Position information comes only from rotary tables precomputed for
    ``config.block_size`` positions, so ``forward`` rejects longer inputs.
    There is no KV cache: generation recomputes the whole (cropped) context
    at every step.
    """

    config_class = Neon213Config
    base_model_prefix = "neon213"

    def __init__(self, config):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layers)])
        self.ln_f = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Rotary tables: row t holds cos/sin of t * inv_freq for each channel pair.
        # persistent=True puts them in the state_dict, so they are saved/loaded
        # with checkpoints.
        dim = config.d_model // config.n_head
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(config.block_size).float()
        freqs = torch.outer(t, inv_freq)
        self.register_buffer("freqs_cos", torch.cos(freqs), persistent=True)
        self.register_buffer("freqs_sin", torch.sin(freqs), persistent=True)
        self.post_init()

    def _init_weights(self, module):
        """Hugging Face init hook: draw Linear and Embedding weights from N(0, 0.02).

        Other module types (Conv1d, RMSNorm) keep their constructor initialisation.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.token_emb

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.token_emb = value

    def get_output_embeddings(self):
        """Return the LM head module."""
        return self.head

    def set_output_embeddings(self, value):
        """Replace the LM head module."""
        self.head = value

    def tie_weights(self, **kwargs):
        """Share the LM head matrix with the token embedding if the config allows it.

        Honours the standard ``tie_word_embeddings`` config flag. When the
        config does not define it, the weights are tied as before.
        """
        if getattr(self.config, "tie_word_embeddings", True):
            self.head.weight = self.token_emb.weight

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Compute next-token logits (and optionally a loss) for a batch of token ids.

        Args:
            input_ids: ``(batch, seq)`` token ids, ``seq <= config.block_size``.
            attention_mask: Accepted for Hugging Face API compatibility but NOT
                used. Every position attends causally to all earlier positions,
                so padding tokens placed before real tokens are visible.
            labels: Optional targets with the same shape as ``input_ids``. They are
                compared position-by-position with the logits (no internal shift),
                and ids equal to -100 are ignored (``F.cross_entropy`` default).
                NOTE(review): HF causal LMs usually shift labels internally;
                confirm which convention the training pipeline relies on.
            **kwargs: Extra Hugging Face arguments (e.g. ``use_cache``); ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` of shape
            ``(batch, seq, vocab_size)`` and ``loss`` when ``labels`` is given.

        Raises:
            ValueError: if ``input_ids`` is missing or longer than the rotary tables.
        """
        if input_ids is None:
            raise ValueError("Neon213.forward requires input_ids")
        seq_len = input_ids.shape[1]
        max_len = self.freqs_cos.shape[0]  # == config.block_size
        if seq_len > max_len:
            raise ValueError(
                f"input length {seq_len} exceeds the {max_len} positions "
                "supported by this model (config.block_size)"
            )
        x = self.token_emb(input_ids)
        for block in self.blocks:
            x = block(x, self.freqs_cos, self.freqs_sin)
        logits = self.head(self.ln_f(x))
        loss = None
        if labels is not None:
            # reshape (not view): labels may be a non-contiguous slice of a batch.
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size), labels.reshape(-1)
            )
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Give ``generate()`` at most the last ``block_size`` tokens of the context.

        There is no KV cache, so the full window is recomputed every step.
        Cropping lets generation continue past ``block_size`` tokens as a sliding
        window instead of failing once the rotary tables run out.
        """
        max_len = self.freqs_cos.shape[0]  # == config.block_size
        return {"input_ids": input_ids[:, -max_len:]}