aman0419 committed on
Commit
0ee2f54
·
verified ·
1 Parent(s): 09bb48d

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +184 -0
model.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from dataclasses import dataclass
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ from contextlib import nullcontext
9
+ import os
10
+
11
@dataclass
class SLMConfig:
    """Hyperparameters for the small language model (SLM).

    Every field has a default, so ``SLMConfig()`` yields a usable
    configuration.

    NOTE(review): a second ``SLMConfig`` dataclass is declared further down
    in this file; being later, that one is what the module-level name is
    bound to after import.
    """

    # Maximum sequence length; also sizes the position-embedding table.
    block_size: int = 256
    # Number of token ids.
    # NOTE(review): 16834 looks like a transposition of 16384 (2**14) --
    # confirm against the tokenizer before changing it.
    vocab_size: int = 16834
    # Number of stacked transformer blocks.
    n_layer: int = 10
    # Attention heads per block; must divide n_embd evenly.
    n_head: int = 8
    # Width of the embeddings and of the residual stream.
    n_embd: int = 512
    # Probability passed to every nn.Dropout in the model.
    dropout: float = 0.0
    # Whether Linear and LayerNorm layers carry a bias term.
    bias: bool = True
20
+
21
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with an optional bias.

    Args:
        ndim: Size of the normalised (last) dimension.
        bias: When false, no additive bias parameter is created and
            ``self.bias`` is ``None``.
        eps: Value added to the variance for numerical stability.
    """

    def __init__(self, ndim, bias=True, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, x):
        """Normalise ``x`` along its final axis and apply scale (and bias)."""
        normalized_shape = x.shape[-1:]
        return F.layer_norm(
            x, normalized_shape, self.weight, self.bias, self.eps
        )
30
+
31
+
32
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which a position only sees earlier ones.

    ``config`` must provide ``n_embd``, ``n_head``, ``bias``, ``dropout`` and
    ``block_size``; ``n_embd`` has to be divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # One fused projection produces queries, keys and values together.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection applied after the heads are merged again.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Use the fused kernel when this PyTorch build exposes it.
        self.flash = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash:
            # Lower-triangular mask for the manual path, shaped so that it
            # broadcasts over the batch and head dimensions.
            causal = torch.tril(torch.ones(config.block_size, config.block_size))
            self.register_buffer(
                "bias", causal.view(1, 1, config.block_size, config.block_size)
            )

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, C); returns the same shape."""
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (
            split_heads(part)
            for part in self.c_attn(x).split(self.n_embd, dim=2)
        )

        if self.flash:
            # Attention dropout is only active while training.
            drop_p = self.attn_dropout.p if self.training else 0.0
            y = F.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=drop_p, is_causal=True
            )
        else:
            scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
            # Future positions get -inf so softmax assigns them zero weight.
            scores = scores.masked_fill(
                self.bias[:, :, :seq_len, :seq_len] == 0, float('-inf')
            )
            weights = self.attn_dropout(F.softmax(scores, dim=-1))
            y = weights @ v

        # Merge the heads back: (B, n_head, T, head_dim) -> (B, T, C).
        y = y.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(y))
66
+
67
+
68
class MLP(nn.Module):
    """Gated feed-forward block (SwiGLU): ``c_proj(silu(w1(x)) * w2(x))``.

    The hidden width is ``4 * n_embd``. Because there are three projections
    of that width instead of two, this block holds 1.5x the weights of a
    plain two-matrix MLP with the same hidden width.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden_dim = 4 * width

        # Gate branch: passed through SiLU in forward().
        self.w1 = nn.Linear(width, hidden_dim, bias=config.bias)
        # Value branch: multiplied element-wise by the activated gate.
        self.w2 = nn.Linear(width, hidden_dim, bias=config.bias)
        # Down projection back to the model width.
        self.c_proj = nn.Linear(hidden_dim, width, bias=config.bias)

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the gated feed-forward transform; output shape equals input."""
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.dropout(self.c_proj(gate * value))
94
+
95
+
96
class Block(nn.Module):
    """One transformer layer: attention, then MLP, each with a residual add.

    Both sub-layers are pre-normalised: the LayerNorm is applied to the input
    of the sub-layer, and the sub-layer output is added to the unnormalised
    residual stream.
    """

    def __init__(self, config):
        super().__init__()
        self.ln1 = LayerNorm(config.n_embd, config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln2 = LayerNorm(config.n_embd, config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        """Run attention and the MLP over ``x``; output shape equals input."""
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out
108
+
109
@dataclass
class SLMConfig:
    """Hyperparameters for the small language model (SLM).

    This declaration rebinds the ``SLMConfig`` name that is already declared
    near the top of the file, so it is the definition callers actually get.
    Its defaults mirror that earlier declaration; without them the
    redefinition would silently turn ``SLMConfig()`` into a ``TypeError``.
    Field order is unchanged, so positional construction keeps working.
    """

    # Maximum sequence length; also sizes the position-embedding table.
    block_size: int = 256
    # Number of token ids (same default as the earlier declaration).
    vocab_size: int = 16834
    # Number of stacked transformer blocks.
    n_layer: int = 10
    # Attention heads per block; must divide n_embd evenly.
    n_head: int = 8
    # Width of the embeddings and of the residual stream.
    n_embd: int = 512
    # Probability passed to every nn.Dropout in the model.
    dropout: float = 0.0
    # Whether Linear and LayerNorm layers carry a bias term.
    bias: bool = True
118
+
119
class SLM(nn.Module):
    """Decoder-only transformer language model.

    The model sums token and learned position embeddings, runs them through
    ``config.n_layer`` ``Block`` layers and a final ``LayerNorm``, and maps
    the result to vocabulary logits with ``lm_head``. The token-embedding
    matrix and ``lm_head`` share one weight tensor.

    ``config`` must provide ``vocab_size``, ``block_size``, ``n_embd``,
    ``n_layer``, ``dropout`` and ``bias`` (plus whatever ``Block`` reads).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: input embedding and output projection share storage.
        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)
        # Residual output projections (every parameter whose name ends in
        # 'c_proj.weight') get a smaller std that shrinks with depth. The std
        # is computed inside the loop on purpose: with n_layer == 0 there is
        # no matching parameter and the division by zero is never reached.
        for name, param in self.named_parameters():
            if name.endswith('c_proj.weight'):
                nn.init.normal_(
                    param, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
                )

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights ~ N(0, 0.02^2); zero biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute logits (and optionally the loss) for token ids ``idx``.

        Args:
            idx: Integer tensor of shape (B, T) with T <= ``block_size``.
            targets: Optional integer tensor of shape (B, T). Positions equal
                to -1 are ignored by the loss.

        Returns:
            ``(logits, loss)``. With ``targets`` the logits cover every
            position, shape (B, T, vocab_size), and ``loss`` is the
            cross-entropy. Without ``targets`` only the last position is
            projected, shape (B, 1, vocab_size), and ``loss`` is ``None``.

        Raises:
            AssertionError: if T exceeds ``config.block_size``.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, (
            f"sequence length {t} exceeds block_size {self.config.block_size}"
        )
        pos = torch.arange(0, t, dtype=torch.long, device=device)

        tok_emb = self.transformer.wte(idx)   # (B, T, n_embd)
        pos_emb = self.transformer.wpe(pos)   # (T, n_embd), broadcast over B
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1,
            )
            return logits, loss

        # Inference: project only the final position; the list index [-1]
        # keeps the time dimension so the result is (B, 1, vocab_size).
        logits = self.lm_head(x[:, [-1], :])
        return logits, None

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively append ``max_new_tokens`` tokens to ``idx``.

        Args:
            idx: Integer tensor of shape (B, T) holding the prompt.
            max_new_tokens: Number of tokens to append.
            temperature: Divisor applied to the logits before sampling.
                ``0`` selects greedy (argmax) decoding; previously it
                produced inf/NaN probabilities and sampling failed.
            top_k: If given, sampling is restricted to the ``top_k`` highest
                logits. Ignored when decoding greedily, since the argmax is
                always inside the top-k set.

        Returns:
            Tensor of shape (B, T + max_new_tokens): the prompt followed by
            the generated tokens.
        """
        block_size = self.config.block_size
        greedy = temperature == 0
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens if it grew
            # beyond what the position embeddings can address.
            idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            if greedy:
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    # Everything below the k-th largest logit is masked out.
                    logits[logits < v[:, [-1]]] = -float('Inf')
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx