import contextlib
import math

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    Projects the input into ``n_heads`` query/key/value heads of size
    ``d_out // n_heads``, masks out positions above the diagonal so each
    token only attends to itself and earlier tokens, and projects the
    concatenated heads back to ``d_out``.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the output; must be divisible by ``n_heads``.
        context_length: Largest sequence length the causal mask covers.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections use a bias.
        n_heads: Number of attention heads.

    Raises:
        ValueError: If ``d_out`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias, n_heads):
        super().__init__()
        # Fail early with a clear message instead of an opaque .view() error
        # in forward().
        if d_out % n_heads != 0:
            raise ValueError(
                f"d_out ({d_out}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.head_dim = d_out // n_heads
        self.d_out = d_out
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(d_out, d_out)
        # Ones strictly above the diagonal mark "future" positions to hide.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """Attend over ``x`` of shape ``(batch, n_tokens, d_in)``.

        Returns a tensor of shape ``(batch, n_tokens, d_out)``.
        """
        b, n_tokens, _ = x.shape

        # (b, n_tokens, d_out) -> (b, n_tokens, n_heads, head_dim)
        keys = self.W_key(x).view(b, n_tokens, self.n_heads, self.head_dim)
        queries = self.W_query(x).view(b, n_tokens, self.n_heads, self.head_dim)
        values = self.W_value(x).view(b, n_tokens, self.n_heads, self.head_dim)

        # -> (b, n_heads, n_tokens, head_dim) so heads are batched matmuls.
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)
        causal_mask = self.mask.bool()[:n_tokens, :n_tokens]
        attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)
        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Merge heads back: (b, n_tokens, n_heads, head_dim) -> (b, n_tokens, d_out)
        cntx_vec = (attn_weights @ values).transpose(1, 2)
        cntx_vec = cntx_vec.contiguous().view(b, n_tokens, self.d_out)
        return self.proj(cntx_vec)


class NormLayer(nn.Module):
    """Layer normalization over the last dimension with learnable scale/shift."""

    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance, as in standard layer norm.
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return self.scale * ((x - mean) / torch.sqrt(var + self.eps)) + self.shift


class GELU(nn.Module):
    """Tanh approximation of the GELU activation."""

    # sqrt(2 / pi), computed once instead of building a tensor every call.
    _SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (
            1 + torch.tanh(self._SQRT_2_OVER_PI * (x + 0.044715 * torch.pow(x, 3)))
        )


class FeedForward(nn.Module):
    """Position-wise MLP: emb_dim -> 4*emb_dim -> GELU -> emb_dim."""

    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward sublayers,
    each wrapped in a residual connection with dropout on the sublayer output.
    """

    def __init__(self, cfg):
        super().__init__()
        self.attn = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"],
            n_heads=cfg["n_heads"],
        )
        self.ff = FeedForward(cfg)
        self.norm1 = NormLayer(cfg["emb_dim"])
        self.norm2 = NormLayer(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        shortcut = x
        x = self.norm1(x)
        x = self.attn(x)
        x = self.drop_shortcut(x)
        x = x + shortcut

        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut
        return x


# Default vocabulary size, used when the config does not supply "vocab_size".
vocab_size = 50257


class GPTModel(nn.Module):
    """Decoder-only transformer language model.

    ``cfg`` keys read: ``emb_dim``, ``context_length``, ``drop_rate``,
    ``n_layers``, ``n_heads``, ``qkv_bias`` and optionally ``vocab_size``
    (falls back to the module-level ``vocab_size``).

    ``forward`` maps token ids of shape ``(batch, n_tokens)`` to logits of
    shape ``(batch, n_tokens, vocab)``.
    """

    def __init__(self, cfg):
        super().__init__()
        n_vocab = cfg.get("vocab_size", vocab_size)
        self.tok_emb = nn.Embedding(n_vocab, cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])
        self.tranf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )
        self.out_head = nn.Linear(cfg["emb_dim"], n_vocab)
        self.final_norm = NormLayer(cfg["emb_dim"])

    def forward(self, x):
        _, n_inp = x.shape
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb(torch.arange(n_inp, device=x.device))
        x = tok_emb + pos_emb
        x = self.drop_emb(x)
        x = self.tranf_blocks(x)
        x = self.final_norm(x)
        x = self.out_head(x)
        return x


def generate_text(
    model, idx, max_new_tokens, context_size, temperature=0.7, top_k=40
):
    """Autoregressively append ``max_new_tokens`` sampled tokens to ``idx``.

    Args:
        model: Module mapping ``(batch, n_tokens)`` ids to
            ``(batch, n_tokens, vocab)`` logits.
        idx: Long tensor of token ids, shape ``(batch, n_tokens)``.
        max_new_tokens: Number of tokens to generate.
        context_size: Only the last ``context_size`` tokens are fed to the model.
        temperature: Logit divisor; ``None`` or ``<= 0`` selects greedy decoding.
        top_k: Sample only among the ``top_k`` most likely tokens (clamped to
            the vocabulary size); ``None`` or ``<= 0`` samples from all tokens.

    Returns:
        ``idx`` extended with the new tokens, shape
        ``(batch, n_tokens + max_new_tokens)``.

    The model's train/eval mode is restored on return.
    """
    was_training = model.training
    model.eval()
    # Mixed precision is only used for CUDA inputs; on other devices
    # autocast("cuda") would just warn and do nothing.
    use_amp = idx.device.type == "cuda"
    try:
        with torch.no_grad():
            for _ in range(max_new_tokens):
                idx_cond = idx[:, -context_size:]
                amp_ctx = (
                    torch.amp.autocast("cuda") if use_amp else contextlib.nullcontext()
                )
                with amp_ctx:
                    logits = model(idx_cond)
                # Sample in float32 regardless of the autocast output dtype.
                logits = logits[:, -1, :].float()

                if temperature is None or temperature <= 0:
                    idx_next = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    n_vocab = logits.size(-1)
                    k = n_vocab if top_k is None or top_k <= 0 else min(top_k, n_vocab)
                    top_logits, top_indices = torch.topk(logits, k)
                    # Probabilities only over the kept candidates.
                    top_probas = torch.softmax(top_logits, dim=-1)
                    # Map the sampled position back to a vocabulary id.
                    idx_next = top_indices.gather(
                        -1, torch.multinomial(top_probas, 1)
                    )
                idx = torch.cat((idx, idx_next), dim=1)
    finally:
        model.train(was_training)
    return idx


def text_to_token_ids(text, tokenizer, device=None):
    """Encode ``text`` into a ``(1, n_tokens)`` long tensor.

    ``device`` defaults to CUDA when available, otherwise CPU.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    encoded = tokenizer.encode(text)
    # Explicit dtype: an empty encoding would otherwise become a float tensor.
    return torch.tensor(encoded, dtype=torch.long, device=device).unsqueeze(0)


def token_ids_to_text(token_ids, tokenizer):
    """Decode a ``(1, n_tokens)`` (or 1-D) tensor of ids back into text."""
    flat = token_ids.squeeze(0)  # drop the batch dimension
    return tokenizer.decode(flat.tolist())


def generate_and_print_sample(model, tokenizer, device, start_context):
    """Generate 200 tokens from ``start_context`` and print them on one line.

    Leaves ``model`` in training mode afterwards, even if generation fails.
    """
    model.eval()
    context_size = model.pos_emb.weight.shape[0]
    if device is None:
        device = model.pos_emb.weight.device
    encoded = text_to_token_ids(start_context, tokenizer, device=device)
    try:
        token_ids = generate_text(
            model=model,
            idx=encoded,
            max_new_tokens=200,
            context_size=context_size,
            temperature=0.85,
            top_k=40,
        )
        decoded_text = token_ids_to_text(token_ids, tokenizer)
        print(decoded_text.replace("\n", " "))  # keep the sample on one line
    finally:
        model.train()