| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel, PretrainedConfig |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
| |
| |
| |
class Neon213Config(PretrainedConfig):
    """Hyper-parameters for the Neon213 model.

    Args:
        vocab_size: Number of rows in the token embedding and output head.
        d_model: Width of the residual stream.
        n_layers: Number of stacked transformer blocks.
        n_head: Number of attention heads; each head is
            ``d_model // n_head`` features wide.
        d_ff: Hidden width of the gated MLP.
        block_size: Number of positions for which rotary tables are
            precomputed.
        conv_k: Kernel size of the depthwise convolutions in attention.
        mlp_k: Kernel size of the depthwise convolution in the MLP gate.
        **kwargs: Passed through unchanged to ``PretrainedConfig``.
    """

    model_type = "neon213"

    def __init__(
        self,
        vocab_size: int = 16384,
        d_model: int = 384,
        n_layers: int = 8,
        n_head: int = 6,
        d_ff: int = 1536,
        block_size: int = 256,
        conv_k: int = 9,
        mlp_k: int = 9,
        **kwargs,
    ):
        # Vocabulary / width / depth.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers

        # Attention and feed-forward sizing.
        self.n_head = n_head
        self.d_ff = d_ff

        # Positional table length and convolution kernel sizes.
        self.block_size = block_size
        self.conv_k = conv_k
        self.mlp_k = mlp_k

        # The base class is initialised last, after the model-specific
        # fields are already in place.
        super().__init__(**kwargs)
|
|
| |
| |
| |
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    Args:
        dim: Size of the last axis of the inputs; one gain value is learned
            per feature (initialised to one).
        eps: Constant added to the mean square before taking the reciprocal
            square root.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x`` divided by the RMS of its last axis, times ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
|
|
def apply_rotary_emb(x, freqs_cos, freqs_sin):
    """Rotate feature pairs of ``x`` by position-dependent angles.

    ``x`` is indexed so that axis 1 selects the position: the first
    ``x.shape[1]`` rows of the tables are reshaped to ``(1, T, 1, -1)`` and
    broadcast against ``x``, i.e. over axis 0 and axis 2. Even-indexed and
    odd-indexed features of the last axis form the rotated pairs.

    The result is laid out as all rotated "even" components followed by all
    rotated "odd" components (the pairs are not re-interleaved).

    Args:
        x: 4-D tensor whose axis 1 is the sequence position and whose last
            axis has an even size.
        freqs_cos: Cosine table with one row per position and one column per
            feature pair.
        freqs_sin: Sine table with the same layout as ``freqs_cos``.

    Returns:
        Tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: If ``x`` has more positions than the tables have rows.
    """
    seq_len = x.shape[1]
    table_len = min(freqs_cos.shape[0], freqs_sin.shape[0])
    if seq_len > table_len:
        # Without this check the ``view`` below either fails with an opaque
        # shape error or, for some lengths, infers a width that broadcasts
        # silently and yields wrong rotations.
        raise ValueError(
            f"sequence length {seq_len} exceeds rotary table length {table_len}"
        )

    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    cos = freqs_cos[:seq_len].view(1, seq_len, 1, -1)
    sin = freqs_sin[:seq_len].view(1, seq_len, 1, -1)

    rotated = torch.cat(
        [x_even * cos - x_odd * sin, x_even * sin + x_odd * cos],
        dim=-1,
    )
    # Multiplying a half-precision ``x`` by float32 tables promotes the result
    # to float32; cast back so the output matches the input dtype.
    return rotated.to(x.dtype)
|
|
| |
| |
| |
class GrowableConvAttention(nn.Module):
    """Causal self-attention with depthwise-convolved, gated streams.

    A single linear layer produces four ``d_model``-wide streams: query, key,
    value and an "intent" gate. Each stream passes through its own causal
    depthwise 1-D convolution. Queries and keys are RMS-normalised per head
    and rotary-embedded, causal scaled-dot-product attention is applied, and
    the attention output is multiplied by ``sigmoid(intent)`` before the
    final projection.

    Args:
        config: Object providing ``d_model``, ``n_head`` and ``conv_k``.
    """

    def __init__(self, config):
        super().__init__()
        d_model = config.d_model
        self.n_head = config.n_head
        self.head_dim = d_model // config.n_head
        self.k = config.conv_k  # convolution kernel size

        # groups == channels makes every convolution depthwise.
        conv_kwargs = dict(kernel_size=self.k, groups=d_model, bias=False)

        self.c_attn = nn.Linear(d_model, 4 * d_model, bias=False)
        self.conv_q = nn.Conv1d(d_model, d_model, **conv_kwargs)
        self.conv_k = nn.Conv1d(d_model, d_model, **conv_kwargs)
        self.conv_v = nn.Conv1d(d_model, d_model, **conv_kwargs)
        self.conv_i = nn.Conv1d(d_model, d_model, **conv_kwargs)
        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)
        self.c_proj = nn.Linear(d_model, d_model, bias=False)

    def _causal_conv(self, conv, stream):
        """Apply ``conv`` along the sequence axis with left-only padding.

        ``stream`` is ``(batch, seq, channels)``; padding ``kernel - 1`` zeros
        on the left keeps the sequence length and prevents any position from
        seeing later positions.
        """
        channels_first = F.pad(stream.transpose(1, 2), (self.k - 1, 0))
        return conv(channels_first).transpose(1, 2)

    def forward(self, x, f_cos, f_sin):
        """Return the attention update for ``x`` of shape ``(B, T, C)``.

        Args:
            x: Input activations, unpacked as ``(batch, seq, channels)``.
            f_cos: Rotary cosine table passed to ``apply_rotary_emb``.
            f_sin: Rotary sine table passed to ``apply_rotary_emb``.
        """
        batch, seq_len, width = x.shape
        per_head = (batch, seq_len, self.n_head, self.head_dim)

        # One projection, then four equal-width streams.
        q_in, k_in, v_in, intent_in = self.c_attn(x).split(width, dim=2)

        q = self._causal_conv(self.conv_q, q_in).view(per_head)
        k = self._causal_conv(self.conv_k, k_in).view(per_head)
        v = self._causal_conv(self.conv_v, v_in).view(per_head)
        intent = self._causal_conv(self.conv_i, intent_in).view(per_head)

        # Normalise first, then rotate, for both queries and keys.
        q = apply_rotary_emb(self.q_norm(q), f_cos, f_sin)
        k = apply_rotary_emb(self.k_norm(k), f_cos, f_sin)

        # (B, T, H, D) -> (B, H, T, D) as expected by SDPA.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        intent = intent.transpose(1, 2)

        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        gated = torch.sigmoid(intent) * attended

        merged = gated.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)
|
|
class GrowableHydraMLP(nn.Module):
    """Gated feed-forward layer whose gate sees a causal local window.

    The gate branch runs the input through a causal depthwise 1-D convolution,
    projects it to ``d_ff`` and applies SiLU. The value branch is a plain
    linear projection of the unconvolved input. Their product is projected
    back to ``d_model``.

    Args:
        config: Object providing ``d_model``, ``d_ff`` and ``mlp_k``.
    """

    def __init__(self, config):
        super().__init__()
        width, hidden = config.d_model, config.d_ff
        self.k = config.mlp_k  # convolution kernel size

        # groups == channels makes the convolution depthwise.
        self.conv_gate = nn.Conv1d(
            width, width, kernel_size=self.k, groups=width, bias=False
        )
        self.c_gate_proj = nn.Linear(width, hidden, bias=False)
        self.w1 = nn.Linear(width, hidden, bias=False)
        self.w2 = nn.Linear(hidden, width, bias=False)

    def forward(self, x):
        """Return the MLP update for ``x`` laid out as ``(batch, seq, channels)``."""
        # Left-only padding of ``kernel - 1`` keeps the length and stops each
        # position from seeing later ones.
        left_pad = self.k - 1
        padded = F.pad(x.transpose(1, 2), (left_pad, 0))
        local_context = self.conv_gate(padded).transpose(1, 2)

        gate = F.silu(self.c_gate_proj(local_context))
        value = self.w1(x)
        return self.w2(gate * value)
|
|
class Block(nn.Module):
    """One pre-norm residual layer: attention followed by the gated MLP.

    Args:
        config: Object providing ``d_model`` plus whatever the attention and
            MLP sub-modules read from it.
    """

    def __init__(self, config):
        super().__init__()
        self.ln1 = RMSNorm(config.d_model)
        self.attn = GrowableConvAttention(config)
        self.ln2 = RMSNorm(config.d_model)
        self.mlp = GrowableHydraMLP(config)

    def forward(self, x, f_cos, f_sin):
        """Apply both residual sub-layers and return the updated activations.

        Args:
            x: Input activations.
            f_cos: Rotary cosine table, forwarded to the attention module.
            f_sin: Rotary sine table, forwarded to the attention module.
        """
        # Each sub-layer sees a normalised copy; the residual path is not
        # normalised.
        attn_update = self.attn(self.ln1(x), f_cos, f_sin)
        x = x + attn_update

        mlp_update = self.mlp(self.ln2(x))
        return x + mlp_update
|
|
| |
| |
| |
class Neon213(PreTrainedModel):
    """Causal language model: embedding -> ``n_layers`` blocks -> norm -> head.

    Rotary cosine/sine tables for ``config.block_size`` positions are
    precomputed once and stored as persistent buffers, so they are part of
    the state dict.
    """

    config_class = Neon213Config
    base_model_prefix = "neon213"

    def __init__(self, config):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layers)])
        self.ln_f = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Rotary tables: one row per position, one column per feature pair
        # of a single attention head.
        dim = config.d_model // config.n_head
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(config.block_size).float()
        freqs = torch.outer(t, inv_freq)
        self.register_buffer("freqs_cos", torch.cos(freqs), persistent=True)
        self.register_buffer("freqs_sin", torch.sin(freqs), persistent=True)

        self.post_init()

    def _init_weights(self, module):
        """Draw linear and embedding weights from N(0, 0.02^2).

        Other module types (convolutions, norms) keep their constructor
        initialisation.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.token_emb

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.token_emb = value

    def get_output_embeddings(self):
        """Return the output projection (LM head)."""
        return self.head

    def set_output_embeddings(self, value):
        """Replace the output projection (LM head)."""
        self.head = value

    def tie_weights(self, **kwargs):
        """Make the LM head share the embedding matrix.

        The tie is unconditional: no config flag is consulted and any
        keyword arguments are ignored.
        """
        self.head.weight = self.token_emb.weight

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Compute logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            input_ids: Integer token ids with the sequence on the last axis.
                Required despite the ``None`` default.
            attention_mask: Accepted for interface compatibility but not
                used; attention is always purely causal.
            labels: Optional targets with one entry per logit position. The
                loss compares the logits at position ``t`` with
                ``labels[..., t]``; no shift is applied here, so callers must
                pass targets already aligned to the next token.
                NOTE(review): many Hugging Face causal LMs shift labels
                internally; confirm the training code pre-shifts.
            **kwargs: Ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (or ``None``) and
            ``logits``. The cache, hidden-state and attention fields are
            always ``None``.

        Raises:
            ValueError: If ``input_ids`` is missing or longer than the
                precomputed rotary tables.
        """
        if input_ids is None:
            raise ValueError("Neon213.forward requires `input_ids`.")

        max_positions = self.freqs_cos.shape[0]
        seq_len = input_ids.shape[-1]
        if seq_len > max_positions:
            # Fail here with a clear message rather than deep inside the
            # rotary embedding with a shape error.
            raise ValueError(
                f"input sequence length {seq_len} exceeds the maximum of "
                f"{max_positions} positions supported by this model"
            )

        x = self.token_emb(input_ids)
        for block in self.blocks:
            x = block(x, self.freqs_cos, self.freqs_sin)

        logits = self.head(self.ln_f(x))

        loss = None
        if labels is not None:
            # ``reshape`` (not ``view``) so non-contiguous labels such as
            # ``tokens[:, 1:]`` work. The class count is read from the logits
            # so a replaced head of a different size still flattens correctly.
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.reshape(-1),
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build the ``forward`` inputs for one generation step.

        No cache is kept, so the whole retained context is re-encoded every
        step. The context is cropped to the most recent positions covered by
        the rotary tables so generation can continue past that length instead
        of raising.
        """
        max_positions = self.freqs_cos.shape[0]
        if input_ids.shape[-1] > max_positions:
            input_ids = input_ids[:, -max_positions:]
        return {"input_ids": input_ids}
|
|