---
language:
- en
license: mit
tags:
- gpt2
- from-scratch
- decoder-only
- transformer
- zeroshot
- llm-training
- bfloat16
- flash-attention
datasets:
- HuggingFaceFW/fineweb-edu-score-2
- HuggingFaceTB/smoltalk
- tatsu-lab/alpaca
- OpenAssistant/oasst2
pipeline_tag: text-generation
library_name: pytorch
---

# ZeroShot-500M

**530M-parameter decoder-only transformer trained entirely from scratch** (base pre-training, mid-training, and supervised fine-tuning) on a single rented RTX 5090.

Part of the **ZeroShot scaling series** by [TobiasLogic](https://github.com/TobiasLogic): a progression of increasingly large GPT-2-style LLMs trained from zero on consumer/prosumer GPUs at minimal cost.

---

## Model Details

| | |
|---|---|
| **Parameters** | ~530M |
| **Architecture** | GPT-2 style decoder-only transformer |
| **Layers** | 24 |
| **Attention Heads** | 20 |
| **Embedding Dim** | 1280 |
| **Head Dim** | 64 |
| **Context Window** | 2,048 tokens |
| **Vocab Size** | 50,304 (GPT-2 BPE, padded for tensor cores) |
| **Precision** | bfloat16 |
| **Attention** | Flash Attention via `F.scaled_dot_product_attention` |
| **Weight Tying** | Embedding ↔ LM head |

---

## Training

### Stage 1: Base Pre-training ✅

| | |
|---|---|
| **Data** | [FineWeb-Edu](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu-score-2) (streamed, zero disk usage) |
| **Tokens** | ~7.9B |
| **Steps** | 30,000 |
| **LR Schedule** | Cosine decay: 4e-4 → 4e-5 |
| **Effective Batch** | 128 sequences (4 per micro-batch × 32 gradient-accumulation steps) |
| **Final Loss** | 2.75 |

![Base loss curve](loss_curve_base.png)
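
The effective batch of 128 sequences comes from running 4-sequence micro-batches 32 times before each optimizer step. The framework-free sketch below shows why this matches one large batch: if each micro-batch gradient is scaled by 1/accumulation-steps, the sum equals the full-batch mean gradient. The toy scalar model and the function names are purely illustrative assumptions. In PyTorch the same effect is typically achieved by calling `backward()` on `loss / grad_accum_steps` for each micro-batch (gradients accumulate in `.grad`) and stepping the optimizer once.

```python
import random

def mse_grad(w: float, xs: list[float], ys: list[float]) -> float:
    """Gradient d/dw of mean((w*x - y)^2) over one batch (toy scalar model)."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def accumulated_grad(w: float, xs: list[float], ys: list[float], micro_batch: int) -> float:
    """Sum micro-batch gradients, each scaled by 1/accum_steps.

    Scaling each micro-batch gradient by 1/accum_steps is equivalent to
    calling backward() on (loss / accum_steps) for every micro-batch.
    """
    assert len(xs) % micro_batch == 0, "batch must split evenly into micro-batches"
    accum_steps = len(xs) // micro_batch
    grad = 0.0
    for i in range(accum_steps):
        lo, hi = i * micro_batch, (i + 1) * micro_batch
        grad += mse_grad(w, xs[lo:hi], ys[lo:hi]) / accum_steps
    return grad

# Toy data: 128 examples split into 32 micro-batches of 4, mirroring the table.
rng = random.Random(0)
xs = [rng.uniform(-1, 1) for _ in range(128)]
ys = [3.0 * x + rng.gauss(0, 0.1) for x in xs]

full = mse_grad(0.5, xs, ys)
accum = accumulated_grad(0.5, xs, ys, micro_batch=4)
print(full, accum)  # equal up to floating-point rounding
```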

### Stage 2: Mid-training ✅

| | |
|---|---|
| **Data** | [SmolTalk](https://huggingface.co/datasets/HuggingFaceTB/smoltalk) (80k) + [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) (50k) + [OpenAssistant](https://huggingface.co/datasets/OpenAssistant/oasst2) (50k) |
| **Steps** | 4,975 |
| **LR** | 1e-4 (cosine decay) |
| **Final Loss** | 1.03 |

![Mid loss curve](loss_curve_mid.png)

### Stage 3: SFT ✅

| | |
|---|---|
| **Steps** | 1,975 |
| **LR** | 3e-5 (cosine decay) |
| **Final Loss** | 0.93 |

![SFT loss curve](loss_curve_sft.png)
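
All three stages use cosine learning-rate decay. A minimal sketch of such a schedule, assuming plain cosine decay from the peak LR to a floor with no warmup (the card does not mention warmup, and it gives the 4e-5 floor only for Stage 1, not for mid-training or SFT):

```python
import math

def cosine_lr(step: int, total_steps: int, max_lr: float, min_lr: float) -> float:
    """Cosine decay from max_lr at step 0 down to min_lr at total_steps (no warmup)."""
    progress = min(max(step, 0), total_steps) / total_steps  # clamp to [0, 1]
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# Stage 1 settings from the card: 4e-4 -> 4e-5 over 30,000 steps.
for step in (0, 7_500, 15_000, 22_500, 30_000):
    print(f"step {step:>6}: lr = {cosine_lr(step, 30_000, 4e-4, 4e-5):.2e}")
```

The LR starts at the peak, sits halfway between peak and floor (2.2e-4) at the midpoint of the run, and reaches the floor on the final step.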

---

## Hardware

| | |
|---|---|
| **GPU** | NVIDIA RTX 5090 (32 GB GDDR7) |
| **Platform** | [Vast.ai](https://vast.ai) (South Korea) |
| **Rental Cost** | $0.343/hr |
| **Throughput** | ~43,000 tokens/sec |
| **Total Cost** | ~$18 |
| **PyTorch** | Nightly (cu128, Blackwell sm_120) |
| **torch.compile** | Disabled (unsupported on sm_120) |
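
The throughput and hourly rate roughly account for the total cost. A back-of-the-envelope sketch, assuming the ~43,000 tokens/sec is sustained for all of Stage 1 (evaluation, checkpointing, and startup overhead are ignored; `pretraining_budget` is a hypothetical helper):

```python
def pretraining_budget(
    steps: int, seqs_per_step: int, seq_len: int, tokens_per_sec: float, usd_per_hour: float
) -> tuple[int, float, float]:
    """Return (total tokens, wall-clock hours, cost in USD) for a fixed-throughput run."""
    tokens = steps * seqs_per_step * seq_len
    hours = tokens / tokens_per_sec / 3600
    return tokens, hours, hours * usd_per_hour

# Stage 1: 30,000 steps x 128 sequences x 2,048 tokens on a $0.343/hr RTX 5090.
tokens, hours, cost = pretraining_budget(
    steps=30_000, seqs_per_step=128, seq_len=2048, tokens_per_sec=43_000, usd_per_hour=0.343
)
print(f"{tokens / 1e9:.2f}B tokens, {hours:.1f} h, ${cost:.2f}")
```

This comes to about 7.86B tokens, roughly 51 hours, and about $17.4 for base pre-training alone, consistent with the ~$18 total once the much shorter mid-training and SFT runs are added.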

---

## Checkpoints

| File | Description |
|---|---|
| `ckpt_base_final.pt` | Base pre-trained: 30k steps, loss 2.75 |
| `ckpt_mid_final.pt` | Post mid-training: 4,975 steps, loss 1.03 |
| `ckpt_sft_final.pt` | Final chat model: 1,975 steps, loss 0.93 |

---

## Usage

Requires the `GPT` and `ModelConfig` classes from `train.py`, which is included in this repo.

```python
import torch
import tiktoken
from train import GPT, ModelConfig

# Load the SFT checkpoint (chat model)
ckpt = torch.load("ckpt_sft_final.pt", map_location="cuda", weights_only=False)
model = GPT(ModelConfig(**ckpt["model_config"])).to("cuda")
model.load_state_dict(ckpt["model"])
model.eval()

enc = tiktoken.get_encoding("gpt2")
tokens = torch.tensor(
    [enc.encode("The meaning of life is")],
    dtype=torch.long,
    device="cuda"
)

with torch.no_grad():
    output = model.generate(tokens, max_new_tokens=200, temperature=0.8, top_k=200)

print(enc.decode(output[0].tolist()))
```

> **Blackwell GPU users (RTX 5060 Ti / 5090):** disable `torch.compile` and use PyTorch nightly cu128.
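
`generate` is defined in `train.py`, which this card does not reproduce. Assuming it follows the common nanoGPT-style recipe, each new token is drawn by dividing the logits by `temperature`, keeping only the `top_k` highest, and sampling from the softmax of what remains. A framework-free sketch of one such step (`sample_next_token` is a hypothetical name, not part of this repo):

```python
import math
import random
from typing import Optional

def sample_next_token(
    logits: list[float],
    temperature: float = 0.8,
    top_k: int = 200,
    rng: Optional[random.Random] = None,
) -> int:
    """Draw one token id from raw logits with temperature scaling and top-k filtering."""
    rng = rng or random.Random()
    # Temperature < 1.0 sharpens the distribution; > 1.0 flattens it.
    scaled = [logit / temperature for logit in logits]
    # Keep only the k highest-scoring token ids.
    k = min(top_k, len(scaled))
    top_ids = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:k]
    # Softmax over the survivors (subtract the max for numerical stability).
    m = max(scaled[i] for i in top_ids)
    weights = [math.exp(scaled[i] - m) for i in top_ids]
    return rng.choices(top_ids, weights=weights, k=1)[0]

logits = [1.0, 3.5, 0.2, 2.8, -1.0]
print(sample_next_token(logits, temperature=0.8, top_k=2, rng=random.Random(42)))
```

With `top_k=1` this reduces to greedy decoding; larger `top_k` and higher `temperature` trade coherence for variety.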

---

## ZeroShot Family

| Model | Params | Base Loss | SFT Loss | Cost | GPU |
|---|---|---|---|---|---|
| [MicroGPT](https://github.com/TobiasLogic/microgpt) | 30.5M | 3.85 | n/a | Free | RTX 3050 |
| [ZeroShot-124M](https://huggingface.co/TobiasLogic/ZeroShot-124M) | 124M | 3.45 | 1.60 | ~$6.77 | RTX 5060 Ti |
| [ZeroShot-350M](https://huggingface.co/TobiasLogic/ZeroShot-350M) | 337M | 3.20 | 1.3 | ~$7.88 | RTX 5090 |
| **ZeroShot-500M** | **530M** | **2.75** | **0.93** | **~$18** | **RTX 5090** |

---

## Limitations

- Undertrained by Chinchilla standards (~75% of the compute-optimal token count for 530M params)
- Will hallucinate, repeat itself, and struggle with complex reasoning tasks
- English only
- No RLHF or safety alignment
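
The ~75% figure follows from the widely used Chinchilla rule of thumb of roughly 20 training tokens per parameter. A quick check, assuming the Stage 1 token count (30,000 steps × 128 sequences × 2,048 tokens) and ignoring the much smaller mid-training and SFT token counts:

```python
def chinchilla_fraction(n_params: float, tokens_seen: float, tokens_per_param: float = 20.0) -> float:
    """Fraction of the ~20 tokens/parameter Chinchilla heuristic actually trained on."""
    return tokens_seen / (n_params * tokens_per_param)

frac = chinchilla_fraction(n_params=530e6, tokens_seen=30_000 * 128 * 2048)
print(f"{frac:.0%}")
```

This gives about 74% (~7.9B of ~10.6B tokens), in line with the stated ~75%.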

---

## License

[MIT](https://opensource.org/licenses/MIT)