---
language:
- en
license: mit
tags:
- gpt2
- from-scratch
- decoder-only
- transformer
- zeroshot
- llm-training
- bfloat16
- flash-attention
datasets:
- HuggingFaceFW/fineweb-edu-score-2
- HuggingFaceTB/smoltalk
- tatsu-lab/alpaca
- OpenAssistant/oasst2
pipeline_tag: text-generation
library_name: pytorch
---

# ZeroShot-500M

**A 530M-parameter decoder-only transformer trained entirely from scratch** (base pre-training, mid-training, and supervised fine-tuning) on a single rented RTX 5090.

Part of the **ZeroShot scaling series** by [TobiasLogic](https://github.com/TobiasLogic), a progression of increasingly large GPT-2 style LLMs trained from zero on consumer/prosumer GPUs at minimal cost.

---

## Model Details

| | |
|---|---|
| **Parameters** | ~530M |
| **Architecture** | GPT-2 style decoder-only transformer |
| **Layers** | 24 |
| **Attention Heads** | 20 |
| **Embedding Dim** | 1280 |
| **Head Dim** | 64 |
| **Context Window** | 2,048 tokens |
| **Vocab Size** | 50,304 (GPT-2 BPE, padded for tensor cores) |
| **Precision** | bfloat16 |
| **Attention** | Flash Attention via `F.scaled_dot_product_attention` |
| **Weight Tying** | Embedding ↔ LM head |

---

## Training

### Stage 1 — Base Pre-training ✅

| | |
|---|---|
| **Data** | [FineWeb-Edu](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu-score-2) (streamed, zero disk usage) |
| **Tokens** | ~7.9B |
| **Steps** | 30,000 |
| **LR Schedule** | Cosine decay: 4e-4 → 4e-5 |
| **Effective Batch** | 128 sequences (micro-batch of 4 × 32 gradient-accumulation steps) |
| **Final Loss** | 2.75 |

![Base loss curve](loss_curve_base.png)

### Stage 2 — Mid-training ✅

| | |
|---|---|
| **Data** | [SmolTalk](https://huggingface.co/datasets/HuggingFaceTB/smoltalk) (80k) + [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) (50k) + [OpenAssistant](https://huggingface.co/datasets/OpenAssistant/oasst2) (50k) |
| **Steps** | 4,975 |
| **LR** | 1e-4 (cosine decay) |
| **Final Loss** | 1.03 |

![Mid loss curve](loss_curve_mid.png)

### Stage 3 — SFT ✅

| | |
|---|---|
| **Steps** | 1,975 |
| **LR** | 3e-5 (cosine decay) |
| **Final Loss** | 0.93 |

![SFT loss curve](loss_curve_sft.png)

---

## Hardware

| | |
|---|---|
| **GPU** | NVIDIA RTX 5090 (32GB GDDR7) |
| **Platform** | [Vast.ai](https://vast.ai) (South Korea) |
| **Cost/hr** | $0.343 |
| **Throughput** | ~43,000 tokens/sec |
| **Total Cost** | ~$18 |
| **PyTorch** | Nightly (cu128, Blackwell sm_120) |
| **torch.compile** | Disabled (unsupported on sm_120) |

---

## Checkpoints

| File | Description |
|---|---|
| `ckpt_base_final.pt` | Base pre-trained — 30k steps, loss 2.75 |
| `ckpt_mid_final.pt` | Post mid-training — 4,975 steps, loss 1.03 |
| `ckpt_sft_final.pt` | Final chat model — 1,975 steps, loss 0.93 |

---

## Usage

Requires the `GPT` and `ModelConfig` classes from `train.py`, included in this repo.

```python
import torch
import tiktoken
from train import GPT, ModelConfig

# Load the SFT checkpoint (chat model).
# weights_only=False unpickles arbitrary objects, so only load checkpoints you trust.
ckpt = torch.load("ckpt_sft_final.pt", map_location="cuda", weights_only=False)
model = GPT(ModelConfig(**ckpt["model_config"])).to("cuda")
model.load_state_dict(ckpt["model"])
model.eval()

# Encode the prompt with the GPT-2 BPE tokenizer; tensor shape is (1, prompt_len).
enc = tiktoken.get_encoding("gpt2")
tokens = torch.tensor(
    [enc.encode("The meaning of life is")], dtype=torch.long, device="cuda"
)

# Sample a continuation without tracking gradients.
with torch.no_grad():
    output = model.generate(tokens, max_new_tokens=200, temperature=0.8, top_k=200)

print(enc.decode(output[0].tolist()))
```

> **Blackwell GPU users (RTX 5060 Ti / 5090):** disable `torch.compile` and use PyTorch nightly cu128.

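
The figures in the Training and Hardware tables can be cross-checked with a little arithmetic. The sketch below is illustrative and not the author's own accounting: it recomputes the Stage 1 token count as steps × effective batch × context length, then estimates GPU-hours and rental cost under the assumption that the quoted ~43,000 tokens/sec held for the whole run. It also compares the token count with a 20-tokens-per-parameter rule of thumb, which is an assumption of this sketch (a common reading of the Chinchilla result) and not a figure from this card. All constant and function names are hypothetical.

```python
# Figures copied from the Training, Model Details, and Hardware tables.
STEPS = 30_000            # Stage 1 optimizer steps
MICRO_BATCH = 4           # sequences per forward/backward pass
GRAD_ACCUM = 32           # gradient-accumulation steps per optimizer step
CONTEXT = 2_048           # tokens per sequence
TOKENS_PER_SEC = 43_000   # quoted throughput (assumed constant here)
COST_PER_HOUR = 0.343     # USD per GPU-hour
PARAMS = 530e6            # approximate parameter count

# Assumption, not from this card: ~20 training tokens per parameter.
TOKENS_PER_PARAM_RULE = 20


def base_pretraining_estimate():
    """Return rough token, time, and cost figures for Stage 1 only."""
    sequences_per_step = MICRO_BATCH * GRAD_ACCUM
    tokens = STEPS * sequences_per_step * CONTEXT
    hours = tokens / TOKENS_PER_SEC / 3600
    cost = hours * COST_PER_HOUR
    ratio = tokens / (PARAMS * TOKENS_PER_PARAM_RULE)
    return {
        "sequences_per_step": sequences_per_step,
        "tokens": tokens,
        "gpu_hours": hours,
        "cost_usd": cost,
        "fraction_of_rule_of_thumb": ratio,
    }


if __name__ == "__main__":
    est = base_pretraining_estimate()
    print(f"tokens:    {est['tokens'] / 1e9:.2f}B")
    print(f"GPU-hours: {est['gpu_hours']:.1f}")
    print(f"cost:      ${est['cost_usd']:.2f}")
    print(f"vs. rule:  {est['fraction_of_rule_of_thumb']:.0%}")
```

Under these assumptions, Stage 1 comes to about 7.86B tokens, roughly 51 GPU-hours, and a little over $17. That is in line with the ~7.9B tokens quoted above and plausibly leaves the rest of the ~$18 total for the two shorter stages. The token count is also about three quarters of the rule-of-thumb budget, matching the ~75% figure given under Limitations.
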

---

## ZeroShot Family

| Model | Params | Base Loss | SFT Loss | Cost | GPU |
|---|---|---|---|---|---|
| [MicroGPT](https://github.com/TobiasLogic/microgpt) | 30.5M | 3.85 | — | Free | RTX 3050 |
| [ZeroShot-124M](https://huggingface.co/TobiasLogic/ZeroShot-124M) | 124M | 3.45 | 1.60 | ~$6.77 | RTX 5060 Ti |
| [ZeroShot-350M](https://huggingface.co/TobiasLogic/ZeroShot-350M) | 337M | 3.20 | 1.3 | ~$7.88 | RTX 5090 |
| **ZeroShot-500M** | **530M** | **2.75** | **0.93** | **~$18** | **RTX 5090** |

---

## Limitations

- Undertrained by Chinchilla standards (~75% of optimal tokens for 530M params)
- Will hallucinate, repeat, and struggle on complex reasoning tasks
- English only
- No RLHF or safety alignment

---

## License

[MIT](https://opensource.org/licenses/MIT)