---
language:
- en
license: mit
tags:
- gpt2
- from-scratch
- decoder-only
- transformer
- zeroshot
- llm-training
- bfloat16
- flash-attention
datasets:
- HuggingFaceFW/fineweb-edu-score-2
- HuggingFaceTB/smoltalk
- tatsu-lab/alpaca
- OpenAssistant/oasst2
pipeline_tag: text-generation
library_name: pytorch
---
# ZeroShot-500M
**530M-parameter decoder-only transformer trained entirely from scratch** (base pre-training, mid-training, and supervised fine-tuning) on a single rented RTX 5090.

Part of the **ZeroShot scaling series** by [TobiasLogic](https://github.com/TobiasLogic), a progression of increasingly large GPT-2-style LLMs trained from zero on consumer/prosumer GPUs at minimal cost.

---
## Model Details
| | |
|---|---|
| **Parameters** | ~530M |
| **Architecture** | GPT-2 style decoder-only transformer |
| **Layers** | 24 |
| **Attention Heads** | 20 |
| **Embedding Dim** | 1280 |
| **Head Dim** | 64 |
| **Context Window** | 2,048 tokens |
| **Vocab Size** | 50,304 (GPT-2 BPE, padded for tensor cores) |
| **Precision** | bfloat16 |
| **Attention** | Flash Attention via `F.scaled_dot_product_attention` |
| **Weight Tying** | Token embedding and LM head share weights |
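
As a rough cross-check of the table, the sketch below counts parameters for these dimensions. It assumes a nanoGPT-style layout that this card does not spell out (biases on every linear layer, a 4× MLP, learned position embeddings, LM head tied to the token embedding) and that the vocabulary is padded to a multiple of 64. `pad_to_multiple` and `gpt2_param_count` are illustrative helpers, not functions from `train.py`. Under those assumptions the padded vocabulary comes out at 50,304, the head dimension at 64, and the total at roughly 539M, in the same range as the ~530M quoted; the exact figure depends on which terms are counted.

```python
# Configuration values from the Model Details table.
N_LAYER, N_HEAD, N_EMBD, BLOCK_SIZE = 24, 20, 1280, 2048
GPT2_VOCAB = 50_257  # size of the original GPT-2 BPE vocabulary


def pad_to_multiple(n, multiple=64):
    """Round n up to the next multiple (assumed 64 for tensor-core-friendly shapes)."""
    return ((n + multiple - 1) // multiple) * multiple


def gpt2_param_count(n_layer, n_embd, vocab, block_size):
    """Estimate parameters of a GPT-2 stack with biases, 4x MLP and a tied LM head."""
    d = n_embd
    attn = (d * 3 * d + 3 * d) + (d * d + d)      # fused q/k/v projection + output projection
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)   # up projection + down projection
    norms = 2 * (2 * d)                           # two LayerNorms per block (weight + bias)
    per_layer = attn + mlp + norms
    embeddings = vocab * d + block_size * d       # token + learned position embeddings
    final_norm = 2 * d
    # The LM head reuses the token-embedding matrix, so it adds no parameters.
    return n_layer * per_layer + embeddings + final_norm


vocab = pad_to_multiple(GPT2_VOCAB)
head_dim = N_EMBD // N_HEAD
total = gpt2_param_count(N_LAYER, N_EMBD, vocab, BLOCK_SIZE)
print(vocab, head_dim, f"{total:,}")
```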
---
## Training
### Stage 1: Base Pre-training
| | |
|---|---|
| **Data** | [FineWeb-Edu](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu-score-2) (streamed, zero disk usage) |
| **Tokens** | ~7.9B |
| **Steps** | 30,000 |
| **LR Schedule** | Cosine decay: 4e-4 → 4e-5 |
| **Effective Batch** | 128 sequences (micro-batch of 4 × 32 gradient-accumulation steps) |
| **Final Loss** | 2.75 |
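
The Stage 1 numbers fit together as follows, assuming every sequence fills the full 2,048-token context: 4 × 32 = 128 sequences per optimizer step, 262,144 tokens per step, and about 7.86B tokens over 30,000 steps. The schedule is a minimal sketch of a plain cosine decay from 4e-4 to 4e-5; the card does not mention warmup, so none is modeled, and `cosine_lr` is a hypothetical name rather than the function used in `train.py`.

```python
import math

# Figures from the Stage 1 table and the 2,048-token context window.
MAX_LR, MIN_LR = 4e-4, 4e-5
TOTAL_STEPS = 30_000
MICRO_BATCH, GRAD_ACCUM, SEQ_LEN = 4, 32, 2_048


def cosine_lr(step, max_lr=MAX_LR, min_lr=MIN_LR, total_steps=TOTAL_STEPS):
    """Cosine decay from max_lr to min_lr over total_steps (no warmup assumed)."""
    progress = min(step, total_steps) / total_steps  # clamp so lr stays at min_lr afterwards
    return min_lr + 0.5 * (1.0 + math.cos(math.pi * progress)) * (max_lr - min_lr)


effective_batch = MICRO_BATCH * GRAD_ACCUM   # sequences per optimizer step
tokens_per_step = effective_batch * SEQ_LEN  # tokens per optimizer step
total_tokens = tokens_per_step * TOTAL_STEPS

print(effective_batch, tokens_per_step, total_tokens)
print(cosine_lr(0), cosine_lr(15_000), cosine_lr(30_000))
```

At the halfway point (step 15,000) this schedule gives about 2.2e-4, the midpoint between the two endpoints.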

### Stage 2: Mid-training
| | |
|---|---|
| **Data** | [SmolTalk](https://huggingface.co/datasets/HuggingFaceTB/smoltalk) (80k) + [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) (50k) + [OpenAssistant](https://huggingface.co/datasets/OpenAssistant/oasst2) (50k) |
| **Steps** | 4,975 |
| **LR** | 1e-4 (cosine decay) |
| **Final Loss** | 1.03 |
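
The card lists how many examples came from each source but not how they were combined. A minimal sketch of one plausible approach, assuming a simple concatenate-and-shuffle, shows the resulting proportions: roughly 44% SmolTalk and 28% each for Alpaca and OpenAssistant. The `build_mixture` helper and the `(source, index)` placeholder records are hypothetical stand-ins for real conversations.

```python
import random

# Example counts from the Stage 2 table; records are (source, index) placeholders.
MIX = {"smoltalk": 80_000, "alpaca": 50_000, "oasst2": 50_000}


def build_mixture(mix, seed=0):
    """Concatenate placeholder records from every source and shuffle them once."""
    records = [(name, i) for name, n in mix.items() for i in range(n)]
    random.Random(seed).shuffle(records)  # fixed seed for a reproducible order
    return records


mixture = build_mixture(MIX)
total = len(mixture)
shares = {name: n / total for name, n in MIX.items()}
print(total, {name: round(share, 3) for name, share in shares.items()})
```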

### Stage 3: SFT
| | |
|---|---|
| **Steps** | 1,975 |
| **LR** | 3e-5 (cosine decay) |
| **Final Loss** | 0.93 |

---
## Hardware
| | |
|---|---|
| **GPU** | NVIDIA RTX 5090 (32GB GDDR7) |
| **Platform** | [Vast.ai](https://vast.ai) (South Korea) |
| **Cost** | $0.343/hr |
| **Throughput** | ~43,000 tokens/sec |
| **Total Cost** | ~$18 |
| **PyTorch** | Nightly (cu128, Blackwell sm_120) |
| **torch.compile** | Disabled (unsupported on sm_120) |
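
The cost figure can be sanity-checked from the throughput: at a steady ~43,000 tokens/sec, the ~7.9B pre-training tokens take about 51 hours, or roughly $17.50 at $0.343/hr, consistent with base pre-training accounting for most of the ~$18 total. This sketch assumes constant throughput and ignores setup, evaluation, and checkpointing time; `gpu_hours` is an illustrative helper.

```python
# Figures from the Stage 1 and Hardware tables.
PRETRAIN_TOKENS = 7.9e9  # ~7.9B base pre-training tokens
THROUGHPUT = 43_000      # tokens/sec
PRICE_PER_HOUR = 0.343   # USD/hr for the rented RTX 5090


def gpu_hours(tokens, tokens_per_sec):
    """Wall-clock hours needed to process `tokens` at a constant throughput."""
    return tokens / tokens_per_sec / 3600


hours = gpu_hours(PRETRAIN_TOKENS, THROUGHPUT)
cost = hours * PRICE_PER_HOUR
print(f"{hours:.1f} h, ${cost:.2f}")
```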
---
## Checkpoints
| File | Description |
|---|---|
| `ckpt_base_final.pt` | Base pre-trained: 30k steps, loss 2.75 |
| `ckpt_mid_final.pt` | Post mid-training: 4,975 steps, loss 1.03 |
| `ckpt_sft_final.pt` | Final chat model: 1,975 steps, loss 0.93 |
---
## Usage
Requires the `GPT` class from `train.py` included in this repo.
```python
import torch
import tiktoken
from train import GPT, ModelConfig

# Load the SFT checkpoint (chat model)
ckpt = torch.load("ckpt_sft_final.pt", map_location="cuda", weights_only=False)
model = GPT(ModelConfig(**ckpt["model_config"])).to("cuda")
model.load_state_dict(ckpt["model"])
model.eval()  # disable dropout for inference

# Encode the prompt with the GPT-2 BPE tokenizer as a batch of one sequence
enc = tiktoken.get_encoding("gpt2")
tokens = torch.tensor(
    [enc.encode("The meaning of life is")],
    dtype=torch.long,
    device="cuda",
)

# Sample 200 new tokens without tracking gradients
with torch.no_grad():
    output = model.generate(tokens, max_new_tokens=200, temperature=0.8, top_k=200)

print(enc.decode(output[0].tolist()))
```
> **Blackwell GPU users (RTX 5060 Ti / 5090):** disable `torch.compile` and use PyTorch nightly cu128.
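
The `generate` call above passes `temperature=0.8` and `top_k=200`. The card does not show how `train.py` implements these, but the standard technique is sketched below in pure Python on a plain list of logits instead of a tensor; `sample_top_k` is a hypothetical function, not part of this repo. Temperature rescales the logits, top-k discards all but the k highest-scoring tokens, and the next token is drawn from a softmax over the survivors. With `top_k=1` it reduces to greedy decoding.

```python
import math
import random


def sample_top_k(logits, temperature=0.8, top_k=200, rng=random):
    """Pick one token id from raw logits using temperature scaling + top-k filtering."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Divide by T: T < 1 sharpens the distribution, T > 1 flattens it.
    scaled = [logit / temperature for logit in logits]
    # Keep only the ids of the k highest-scoring tokens.
    k = min(top_k, len(scaled))
    top = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:k]
    # Softmax over the survivors; subtracting the max keeps exp() from overflowing.
    m = max(scaled[i] for i in top)
    weights = [math.exp(scaled[i] - m) for i in top]
    # random.choices normalizes the weights, so they need not sum to 1.
    return rng.choices(top, weights=weights, k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0]
print(sample_top_k(logits, top_k=1), sample_top_k(logits, top_k=2, rng=random.Random(0)))
```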
---
## ZeroShot Family
| Model | Params | Base Loss | SFT Loss | Cost | GPU |
|---|---|---|---|---|---|
| [MicroGPT](https://github.com/TobiasLogic/microgpt) | 30.5M | 3.85 | N/A | Free | RTX 3050 |
| [ZeroShot-124M](https://huggingface.co/TobiasLogic/ZeroShot-124M) | 124M | 3.45 | 1.60 | ~$6.77 | RTX 5060 Ti |
| [ZeroShot-350M](https://huggingface.co/TobiasLogic/ZeroShot-350M) | 337M | 3.20 | 1.3 | ~$7.88 | RTX 5090 |
| **ZeroShot-500M** | **530M** | **2.75** | **0.93** | **~$18** | **RTX 5090** |
---
## Limitations
- Undertrained by Chinchilla standards (~75% of optimal tokens for 530M params)
- Will hallucinate, repeat, and struggle on complex reasoning tasks
- English only
- No RLHF or safety alignment
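
The ~75% figure follows from the widely quoted Chinchilla rule of thumb of about 20 training tokens per parameter: 20 × 530M ≈ 10.6B tokens, and 7.9B / 10.6B ≈ 0.75. The 20:1 ratio is an approximation rather than an exact optimum, so the percentage is a rough guide only.

```python
# Chinchilla rule of thumb: roughly 20 training tokens per parameter.
TOKENS_PER_PARAM = 20
PARAMS = 530e6          # ~530M parameters (Model Details)
TRAINED_TOKENS = 7.9e9  # ~7.9B base pre-training tokens (Stage 1)

optimal_tokens = TOKENS_PER_PARAM * PARAMS
ratio = TRAINED_TOKENS / optimal_tokens
print(f"optimal ~ {optimal_tokens / 1e9:.1f}B tokens, trained {ratio:.0%} of that")
```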
---
## License
[MIT](https://opensource.org/licenses/MIT)