---
language:
- en
license: mit
tags:
- gpt2
- from-scratch
- decoder-only
- transformer
- zeroshot
- llm-training
- bfloat16
- flash-attention
datasets:
- HuggingFaceFW/fineweb-edu-score-2
- HuggingFaceTB/smoltalk
- tatsu-lab/alpaca
- OpenAssistant/oasst2
pipeline_tag: text-generation
library_name: pytorch
---
# ZeroShot-500M
**530M-parameter decoder-only transformer trained entirely from scratch** (base pre-training, mid-training, and supervised fine-tuning) on a single rented RTX 5090.

Part of the **ZeroShot scaling series** by [TobiasLogic](https://github.com/TobiasLogic): a progression of increasingly large GPT-2-style LLMs trained from scratch on consumer/prosumer GPUs at minimal cost.

---
## Model Details
| | |
|---|---|
| **Parameters** | ~530M |
| **Architecture** | GPT-2 style decoder-only transformer |
| **Layers** | 24 |
| **Attention Heads** | 20 |
| **Embedding Dim** | 1280 |
| **Head Dim** | 64 |
| **Context Window** | 2,048 tokens |
| **Vocab Size** | 50,304 (GPT-2 BPE's 50,257 tokens, padded to a multiple of 64 for tensor cores) |
| **Precision** | bfloat16 |
| **Attention** | Flash Attention via `F.scaled_dot_product_attention` |
| **Weight Tying** | Embedding ↔ LM head |
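
Two of the table entries are derived from the others: the head dim is the embedding dim divided by the number of heads, and 50,304 is GPT-2's 50,257-token BPE vocabulary rounded up to the next multiple of 64. A minimal sketch of those derivations, using a hypothetical `ConfigSketch` dataclass (it is not the `ModelConfig` from `train.py`, whose field names may differ):

```python
from dataclasses import dataclass


def pad_vocab(vocab_size: int, multiple: int = 64) -> int:
    """Round vocab_size up to the next multiple (tensor-core-friendly matmul shapes)."""
    return ((vocab_size + multiple - 1) // multiple) * multiple


@dataclass(frozen=True)
class ConfigSketch:
    # Hypothetical mirror of the Model Details table; NOT train.py's ModelConfig.
    n_layer: int = 24
    n_head: int = 20
    n_embd: int = 1280
    block_size: int = 2048
    raw_vocab_size: int = 50257  # GPT-2 BPE vocabulary size before padding

    @property
    def head_dim(self) -> int:
        # Each head gets an equal slice of the embedding dimension.
        assert self.n_embd % self.n_head == 0, "n_embd must divide evenly across heads"
        return self.n_embd // self.n_head

    @property
    def vocab_size(self) -> int:
        # Padded size used for the embedding table and the tied LM head.
        return pad_vocab(self.raw_vocab_size)


cfg = ConfigSketch()
print(cfg.head_dim, cfg.vocab_size)  # 64 50304
```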
---
## Training
### Stage 1: Base Pre-training ✅
| | |
|---|---|
| **Data** | [FineWeb-Edu](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu-score-2) (streamed, zero disk usage) |
| **Tokens** | ~7.9B |
| **Steps** | 30,000 |
| **LR Schedule** | Cosine decay: 4e-4 → 4e-5 |
| **Effective Batch** | 128 sequences (4 per micro-batch × 32 gradient-accumulation steps) |
| **Final Loss** | 2.75 |

![Base loss curve](loss_curve_base.png)
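
The Stage 1 token count follows from the batch settings: 4 sequences per micro-batch × 32 accumulation steps gives 128 sequences per optimizer step, each 2,048 tokens long. A quick check, assuming every sequence is packed to the full context window (typical for streamed pre-training, but not stated above):

```python
MICRO_BATCH = 4    # sequences per forward/backward pass
GRAD_ACCUM = 32    # micro-batches accumulated per optimizer step
SEQ_LEN = 2048     # context window (assumed fully packed)
STEPS = 30_000     # Stage 1 optimizer steps

seqs_per_step = MICRO_BATCH * GRAD_ACCUM
tokens_per_step = seqs_per_step * SEQ_LEN
total_tokens = tokens_per_step * STEPS

print(seqs_per_step)                        # 128
print(tokens_per_step)                      # 262144
print(f"{total_tokens / 1e9:.2f}B tokens")  # 7.86B tokens
```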
### Stage 2: Mid-training ✅
| | |
|---|---|
| **Data** | [SmolTalk](https://huggingface.co/datasets/HuggingFaceTB/smoltalk) (80k) + [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) (50k) + [OpenAssistant](https://huggingface.co/datasets/OpenAssistant/oasst2) (50k) |
| **Steps** | 4,975 |
| **LR** | 1e-4 (cosine decay) |
| **Final Loss** | 1.03 |

![Mid loss curve](loss_curve_mid.png)
### Stage 3: Supervised Fine-Tuning (SFT) ✅
| | |
|---|---|
| **Steps** | 1,975 |
| **LR** | 3e-5 (cosine decay) |
| **Final Loss** | 0.93 |

![SFT loss curve](loss_curve_sft.png)
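
Each stage uses cosine learning-rate decay. The sketch below uses Stage 1's endpoints (4e-4 down to 4e-5 over 30,000 steps). It is a generic formulation, not code taken from `train.py`: `cosine_lr` is a hypothetical helper name, and the optional linear warmup is an assumption (no warmup is listed above), so it defaults to zero steps. Calling `cosine_lr(step, 30_000, 4e-4, 4e-5)` before each optimizer step would give the Stage 1 rate.

```python
import math


def cosine_lr(step: int, max_steps: int, max_lr: float, min_lr: float,
              warmup_steps: int = 0) -> float:
    """Cosine decay from max_lr to min_lr, with an optional (assumed) linear warmup."""
    if warmup_steps > 0 and step < warmup_steps:
        # Linear ramp up to max_lr over the warmup phase.
        return max_lr * (step + 1) / warmup_steps
    if step >= max_steps:
        return min_lr
    # progress runs 0 -> 1 over the decay phase
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))  # runs 1 -> 0
    return min_lr + coeff * (max_lr - min_lr)


for s in (0, 15_000, 30_000):
    print(s, f"{cosine_lr(s, 30_000, 4e-4, 4e-5):.2e}")
```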
---
## Hardware
| | |
|---|---|
| **GPU** | NVIDIA RTX 5090 (32GB GDDR7) |
| **Platform** | [Vast.ai](https://vast.ai) (South Korea) |
| **Rental Rate** | $0.343/hr |
| **Throughput** | ~43,000 tokens/sec |
| **Total Cost** | ~$18 |
| **PyTorch** | Nightly (cu128, Blackwell sm_120) |
| **torch.compile** | Disabled (unsupported on sm_120) |
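
As a rough cross-check of the cost figure, the Stage 1 token count, throughput, and hourly rate imply about 51 GPU-hours and roughly $17.4 for base pre-training alone, which suggests most of the ~$18 total went to Stage 1. This assumes the ~43,000 tokens/sec rate held for the entire run and ignores setup, evaluation, and checkpointing time:

```python
TOKENS = 128 * 2048 * 30_000   # Stage 1 tokens (~7.86B, full-length sequences assumed)
TOKENS_PER_SEC = 43_000        # reported throughput, assumed constant
USD_PER_HOUR = 0.343           # Vast.ai rental rate

hours = TOKENS / TOKENS_PER_SEC / 3600
cost = hours * USD_PER_HOUR
print(f"{hours:.1f} h, ${cost:.2f}")
```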
---
## Checkpoints
| File | Description |
|---|---|
| `ckpt_base_final.pt` | Base pre-trained: 30k steps, loss 2.75 |
| `ckpt_mid_final.pt` | Post mid-training: 4,975 steps, loss 1.03 |
| `ckpt_sft_final.pt` | Final chat model: 1,975 steps, loss 0.93 |
---
## Usage
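Requires the `GPT` and `ModelConfig` classes from `train.py` (included in this repo), plus `torch` and `tiktoken`.
```python
import torch
import tiktoken
from train import GPT, ModelConfig  # train.py ships with this repo

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the SFT checkpoint (chat model).
# weights_only=False unpickles arbitrary objects: only load checkpoints you trust.
ckpt = torch.load("ckpt_sft_final.pt", map_location=device, weights_only=False)
model = GPT(ModelConfig(**ckpt["model_config"])).to(device)
model.load_state_dict(ckpt["model"])
model.eval()

# GPT-2 BPE tokenizer, matching the model's (padded) vocabulary
enc = tiktoken.get_encoding("gpt2")
tokens = torch.tensor(
    [enc.encode("The meaning of life is")],
    dtype=torch.long,
    device=device,
)

# Sample 200 new tokens with temperature and top-k filtering
with torch.no_grad():
    output = model.generate(tokens, max_new_tokens=200, temperature=0.8, top_k=200)
print(enc.decode(output[0].tolist()))
```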
> **Blackwell GPU users (RTX 5060 Ti / 5090):** disable `torch.compile` and use PyTorch nightly cu128.
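
`temperature` and `top_k` in the usage example are sampling controls. The following is a generic, pure-Python sketch of how temperature scaling plus top-k filtering commonly works for a single decoding step; it is not the `generate` method from `train.py`, and `sample_top_k` is a hypothetical helper:

```python
from __future__ import annotations

import math
import random


def sample_top_k(logits: list[float], temperature: float = 0.8, top_k: int = 200,
                 rng: random.Random | None = None) -> int:
    """Sample one token id from logits after temperature scaling and top-k filtering."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0")
    rng = rng or random.Random()
    # Keep only the k highest-scoring token ids.
    k = min(top_k, len(logits))
    top_ids = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Temperature < 1 sharpens the distribution; > 1 flattens it.
    scaled = [logits[i] / temperature for i in top_ids]
    # Numerically stable softmax weights over the surviving candidates.
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    return rng.choices(top_ids, weights=weights, k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0]
print(sample_top_k(logits, temperature=0.8, top_k=2, rng=random.Random(0)))  # id 0 or 1
```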
---
## ZeroShot Family
| Model | Params | Base Loss | SFT Loss | Cost | GPU |
|---|---|---|---|---|---|
| [MicroGPT](https://github.com/TobiasLogic/microgpt) | 30.5M | 3.85 | n/a | Free | RTX 3050 |
| [ZeroShot-124M](https://huggingface.co/TobiasLogic/ZeroShot-124M) | 124M | 3.45 | 1.60 | ~$6.77 | RTX 5060 Ti |
| [ZeroShot-350M](https://huggingface.co/TobiasLogic/ZeroShot-350M) | 337M | 3.20 | 1.3 | ~$7.88 | RTX 5090 |
| **ZeroShot-500M** | **530M** | **2.75** | **0.93** | **~$18** | **RTX 5090** |
---
## Limitations
- Undertrained by Chinchilla standards (~7.9B pre-training tokens, about 75% of the compute-optimal budget for 530M parameters)
- Hallucinates, repeats itself, and struggles with complex reasoning tasks
- English only
- No RLHF or safety alignment
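
The ~75% figure follows from the widely used Chinchilla rule of thumb of roughly 20 training tokens per parameter (an approximation, not an exact law). A quick check of that arithmetic:

```python
PARAMS = 530e6          # ~530M parameters
TOKENS_SEEN = 7.9e9     # Stage 1 pre-training tokens
TOKENS_PER_PARAM = 20   # Chinchilla rule of thumb

optimal_tokens = PARAMS * TOKENS_PER_PARAM
fraction = TOKENS_SEEN / optimal_tokens
print(f"optimal ~ {optimal_tokens / 1e9:.1f}B tokens, trained on {fraction:.0%}")
```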
---
## License
[MIT](https://opensource.org/licenses/MIT)