""" nanoGPT SLM -- Cheerful TinyStories (RLVR) Inference ====================================================== 124M-parameter nanoGPT trained from scratch on TinyStories, then aligned with a 3-stage pipeline: Pretraining -> SFT (positive subset) -> RLVR (policy gradient against a VADER sentiment reward). This script loads the RLVR model as the primary generator and also provides a 3-model comparison (Pretrained vs SFT vs RLVR) scored with VADER sentiment. Install: pip install torch tiktoken huggingface_hub nltk Run: python nanogpt_slm_rlvr_inference_tinystories.py Import: from nanogpt_slm_rlvr_inference_tinystories import tell_story, ask, generate_text, compare_models """ import torch, torch.nn as nn, torch.nn.functional as F, math, tiktoken from dataclasses import dataclass from huggingface_hub import hf_hub_download # ============================================================== # ARCHITECTURE (nanoGPT -- 124M parameters) # ============================================================== class LayerNorm(nn.Module): def __init__(self, ndim, bias): super().__init__() self.weight = nn.Parameter(torch.ones(ndim)) self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None def forward(self, x): return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5) class CausalSelfAttention(nn.Module): def __init__(self, config): super().__init__() assert config.n_embd % config.n_head == 0 self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) self.attn_dropout = nn.Dropout(config.dropout) self.resid_dropout = nn.Dropout(config.dropout) self.n_head, self.n_embd = config.n_head, config.n_embd self.flash = hasattr(F, 'scaled_dot_product_attention') if not self.flash: self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) .view(1, 1, config.block_size, config.block_size)) def forward(self, x): B, T, C = x.size() q, k, v = self.c_attn(x).split(self.n_embd, dim=2) k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) if self.flash: y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.attn_dropout.p if self.training else 0.0, is_causal=True) else: att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf')) att = F.softmax(att, dim=-1) att = self.attn_dropout(att) y = att @ v y = y.transpose(1, 2).contiguous().view(B, T, C) return self.resid_dropout(self.c_proj(y)) class MLP(nn.Module): def __init__(self, config): super().__init__() self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) self.gelu = nn.GELU() self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) self.dropout = nn.Dropout(config.dropout) def forward(self, x): return self.dropout(self.c_proj(self.gelu(self.c_fc(x)))) class Block(nn.Module): def __init__(self, config): super().__init__() self.ln1, self.attn = LayerNorm(config.n_embd, config.bias), CausalSelfAttention(config) self.ln2, self.mlp = LayerNorm(config.n_embd, config.bias), MLP(config) def forward(self, x): x = x + self.attn(self.ln1(x)) return x + self.mlp(self.ln2(x)) @dataclass class GPTConfig: block_size: int = 512 vocab_size: int = 50257 n_layer: int = 12 n_head: int = 12 n_embd: int = 768 dropout: float = 0.0 bias: bool = True class GPT(nn.Module): def __init__(self, config): super().__init__() self.config = config self.transformer = nn.ModuleDict(dict( wte=nn.Embedding(config.vocab_size, config.n_embd), wpe=nn.Embedding(config.block_size, config.n_embd), drop=nn.Dropout(config.dropout), h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), ln_f=LayerNorm(config.n_embd, config.bias), )) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.transformer.wte.weight = self.lm_head.weight # weight tying def forward(self, idx, targets=None): b, t = idx.size() pos = torch.arange(0, t, dtype=torch.long, device=idx.device) x = self.transformer.drop(self.transformer.wte(idx) + self.transformer.wpe(pos)) for block in self.transformer.h: x = block(x) x = self.transformer.ln_f(x) if targets is not None: logits = self.lm_head(x) return logits, F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) else: return self.lm_head(x[:, [-1], :]), None @torch.no_grad() def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, eos_token=50256): """Autoregressive generation with EOS stopping.""" for _ in range(max_new_tokens): idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] logits, _ = self(idx_cond) logits = logits[:, -1, :] / temperature if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('Inf') probs = F.softmax(logits, dim=-1) idx_next = torch.multinomial(probs, num_samples=1) if eos_token is not None and idx_next.item() == eos_token: break idx = torch.cat((idx, idx_next), dim=1) return idx # ============================================================== # KV CACHE -- O(1) per decode step # ============================================================== class CausalSelfAttentionKV(nn.Module): """CausalSelfAttention with KV cache. Same param names for weight compat.""" def __init__(self, config): super().__init__() assert config.n_embd % config.n_head == 0 self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) self.attn_dropout = nn.Dropout(config.dropout) self.resid_dropout = nn.Dropout(config.dropout) self.n_head, self.n_embd = config.n_head, config.n_embd self.head_dim = config.n_embd // config.n_head def forward(self, x, kv_cache=None, use_cache=False): B, T, C = x.size() q, k, v = self.c_attn(x).split(self.n_embd, dim=2) q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2) v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2) if kv_cache is not None: k = torch.cat([kv_cache[0], k], dim=2) v = torch.cat([kv_cache[1], v], dim=2) new_cache = (k, v) if use_cache else None S = k.size(2) att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim)) if T > 1: mask = torch.triu(torch.ones(T, S, device=x.device, dtype=torch.bool), diagonal=S - T + 1) att = att.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf')) att = F.softmax(att, dim=-1) att = self.attn_dropout(att) y = (att @ v).transpose(1, 2).contiguous().view(B, T, C) return self.resid_dropout(self.c_proj(y)), new_cache class BlockKV(nn.Module): def __init__(self, config): super().__init__() self.ln1, self.attn = LayerNorm(config.n_embd, config.bias), CausalSelfAttentionKV(config) self.ln2, self.mlp = LayerNorm(config.n_embd, config.bias), MLP(config) def forward(self, x, kv_cache=None, use_cache=False): attn_out, new_cache = self.attn(self.ln1(x), kv_cache=kv_cache, use_cache=use_cache) x = x + attn_out return x + self.mlp(self.ln2(x)), new_cache class GPTKV(nn.Module): """GPT with KV cache. Same weight names as GPT for load_state_dict compat.""" def __init__(self, config): super().__init__() self.config = config self.transformer = nn.ModuleDict(dict( wte=nn.Embedding(config.vocab_size, config.n_embd), wpe=nn.Embedding(config.block_size, config.n_embd), drop=nn.Dropout(config.dropout), h=nn.ModuleList([BlockKV(config) for _ in range(config.n_layer)]), ln_f=LayerNorm(config.n_embd, config.bias), )) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.transformer.wte.weight = self.lm_head.weight def forward(self, idx, kv_caches=None, use_cache=False): b, t = idx.size() if kv_caches is not None and kv_caches[0] is not None: cache_len = kv_caches[0][0].size(2) pos = torch.arange(cache_len, cache_len + t, dtype=torch.long, device=idx.device) else: pos = torch.arange(0, t, dtype=torch.long, device=idx.device) x = self.transformer.drop(self.transformer.wte(idx) + self.transformer.wpe(pos)) new_caches = [] if kv_caches is None: kv_caches = [None] * len(self.transformer.h) for block, kv_cache in zip(self.transformer.h, kv_caches): x, new_cache = block(x, kv_cache=kv_cache, use_cache=use_cache) new_caches.append(new_cache) x = self.transformer.ln_f(x) logits = self.lm_head(x) return (logits, new_caches) if use_cache else (logits, None) @torch.no_grad() def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, eos_token=50256): """Generate with KV cache -- O(1) per decode step. Stops at <|endoftext|> (eos_token=50256) for clean story boundaries. Caps generation at block_size to prevent positional embedding overflow.""" max_new_tokens = min(max_new_tokens, self.config.block_size - idx.size(1)) logits, kv_caches = self(idx, kv_caches=None, use_cache=True) logits = logits[:, -1, :] / temperature for _ in range(max_new_tokens): if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('Inf') probs = F.softmax(logits, dim=-1) idx_next = torch.multinomial(probs, num_samples=1) if eos_token is not None and idx_next.item() == eos_token: break idx = torch.cat((idx, idx_next), dim=1) logits, kv_caches = self(idx_next, kv_caches=kv_caches, use_cache=True) logits = logits[:, -1, :] / temperature return idx # ============================================================== # CONFIG -- the HuggingFace repo holding all 3 checkpoints # ============================================================== REPO_ID = "nishantup/nanogpt-rlvr-slm-tinystories-124m" CHECKPOINTS = { "Pretrained": "nanogpt_slm_tinystories_best.pth", "SFT (positive)": "nanogpt_slm_sft_best.pth", "RLVR": "nanogpt_slm_rlvr_final.pth", } device = torch.device("cuda" if torch.cuda.is_available() else "cpu") config = GPTConfig() tokenizer = tiktoken.get_encoding("gpt2") def _load(filename): """Download a checkpoint from the Hub and load it into a KV-cache model.""" path = hf_hub_download(repo_id=REPO_ID, filename=filename) m = GPTKV(config) m.load_state_dict(torch.load(path, map_location=device)) m.to(device) m.eval() return m # Lazy-loaded model cache (models load on first use to keep import light) _MODELS = {} def get_model(stage="RLVR"): """Return a pipeline-stage model, loading it on first request. stage in {'Pretrained', 'SFT (positive)', 'RLVR'}.""" if stage not in _MODELS: _MODELS[stage] = _load(CHECKPOINTS[stage]) return _MODELS[stage] # ============================================================== # GENERATION HELPERS (primary model = RLVR) # ============================================================== def generate_with(m, prompt, max_tokens=500, temperature=0.8, top_k=40): """Generate a continuation from a SPECIFIC model instance.""" idx = torch.tensor( tokenizer.encode(prompt, allowed_special={'<|endoftext|>'}) ).unsqueeze(0).to(device) out = m.generate(idx, max_new_tokens=max_tokens, temperature=temperature, top_k=top_k) tokens = out.squeeze(0).tolist() if 50256 in tokens: # trim at <|endoftext|> tokens = tokens[:tokens.index(50256)] return tokenizer.decode(tokens) def generate_text(prompt, max_tokens=500, temperature=0.8, top_k=40): """Generate text with the RLVR model -- the cheerful, sentiment-optimized stage.""" return generate_with(get_model("RLVR"), prompt, max_tokens, temperature, top_k) def ask(prompt, max_tokens=500, temperature=0.8, top_k=40): """Alias for generate_text -- complete any prompt with the RLVR model.""" return generate_text(prompt, max_tokens, temperature, top_k) def tell_story(beginning="Once upon a time", max_tokens=500, temperature=0.8, top_k=40): """Generate a cheerful children's story from an opening line (RLVR model).""" return generate_text(beginning, max_tokens, temperature, top_k) # ============================================================== # VADER SENTIMENT (the verifiable reward RLVR was trained on) # ============================================================== def _get_sia(): import nltk nltk.download('vader_lexicon', quiet=True) from nltk.sentiment.vader import SentimentIntensityAnalyzer return SentimentIntensityAnalyzer() _SIA = None def score_sentiment(text): """VADER compound sentiment in [-1, +1].""" global _SIA if _SIA is None: _SIA = _get_sia() if not text or not text.strip(): return 0.0 return _SIA.polarity_scores(text)['compound'] # ============================================================== # 3-MODEL COMPARISON # ============================================================== def compare_models(prompt, max_tokens=250, temperature=0.8, top_k=40, seed=1234): """Generate the same prompt from all 3 pipeline stages. Returns a list of dicts: {'stage', 'sentiment', 'story'} -- one per stage, in pipeline order (Pretrained -> SFT -> RLVR). """ rows = [] for stage in CHECKPOINTS: # preserves pipeline order if seed is not None: torch.manual_seed(seed) # same RNG start = fair compare story = generate_with(get_model(stage), prompt, max_tokens, temperature, top_k) rows.append({"stage": stage, "sentiment": score_sentiment(story), "story": story}) return rows # ============================================================== # EXAMPLES (only run when executed directly) # ============================================================== if __name__ == "__main__": print("=" * 64) print(" nanoGPT SLM -- RLVR Cheerful Story Generation") print("=" * 64) print(f" device: {device} | repo: {REPO_ID}") # -- RLVR story examples -- starters = [ "Once upon a time there was a little rabbit", "A girl named Lily went to the park", "The friendly dragon lived in a cave", "One day, a boy found a magic key", "There was a tiny kitten who loved to play", ] print("\n" + "=" * 64) print(" RLVR MODEL -- SAMPLE STORIES") print("=" * 64) for s in starters: story = tell_story(s, max_tokens=300) print(f"\nStarter: {s} (sentiment {score_sentiment(story):+.3f})") print("-" * 64) print(story) # -- 3-stage comparison -- print("\n" + "=" * 64) print(" 3-STAGE COMPARISON (Pretrained vs SFT vs RLVR)") print("=" * 64) for prompt in ["The little girl was sad until", "On a rainy day, the puppy"]: print(f"\nPrompt: {prompt}") for row in compare_models(prompt, max_tokens=220): print(f"\n [{row['stage']}] sentiment {row['sentiment']:+.3f}") print(f" {row['story']}") print("-" * 64)