thepartsofspeech commited on
Commit
ab876da
·
verified ·
1 Parent(s): 172d237

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +28 -0
  2. .gitignore +5 -0
  3. Dockerfile +33 -0
  4. LICENSE +21 -0
  5. README.md +375 -0
  6. data/cached_fineweb100B.py +16 -0
  7. data/cached_fineweb10B.py +16 -0
  8. data/cached_finewebedu10B.py +16 -0
  9. data/fineweb.py +126 -0
  10. data/requirements.txt +2 -0
  11. img/algo_optimizer.png +3 -0
  12. img/dofa.jpg +0 -0
  13. img/fig_optimizer.png +3 -0
  14. img/fig_tuned_nanogpt.png +3 -0
  15. img/nanogpt_speedrun51.png +3 -0
  16. img/nanogpt_speedrun52.png +0 -0
  17. img/nanogpt_speedrun53.png +3 -0
  18. img/nanogpt_speedrun54.png +0 -0
  19. records/track_1_short/2024-06-06_AdamW/README.md +8 -0
  20. records/track_1_short/2024-06-06_AdamW/f66d43d7-e449-4029-8adf-e8537bab49ea.log +0 -0
  21. records/track_1_short/2024-10-09_SOAP/5bdc3988-496c-4232-b4ef-53764cb81c92.txt +0 -0
  22. records/track_1_short/2024-10-09_SOAP/README.md +9 -0
  23. records/track_1_short/2024-10-09_SOAP/train_gpt2.py +857 -0
  24. records/track_1_short/2024-10-10_Muon/eb5659d0-fb6a-49e5-a311-f1f89412f726.txt +0 -0
  25. records/track_1_short/2024-10-10_Muon/train_gpt2.py +524 -0
  26. records/track_1_short/2024-10-13_llmc/README.md +11 -0
  27. records/track_1_short/2024-10-13_llmc/main.log +0 -0
  28. records/track_1_short/2024-10-14_ModernArch/dabaaddd-237c-4ec9-939d-6608a9ed5e27.txt +0 -0
  29. records/track_1_short/2024-10-14_ModernArch/train_gpt2.py +516 -0
  30. records/track_1_short/2024-10-17_DistributedMuon/22d24867-eb5a-4fcc-ae2c-263d0277dfd1.txt +0 -0
  31. records/track_1_short/2024-10-18_PyTorch25/d4bfb25f-688d-4da5-8743-33926fad4842.txt +0 -0
  32. records/track_1_short/2024-10-20_ScaleUp1B/87bd51fd-6203-4c88-b3aa-8a849a6a83ca.txt +0 -0
  33. records/track_1_short/2024-10-20_ScaleUp1B/ad8d7ae5-7b2d-4ee9-bc52-f912e9174d7a.txt +0 -0
  34. records/track_1_short/2024-10-20_ScaleUp1B/c0078066-c8c9-49c8-868a-ff4d4f32e615.txt +0 -0
  35. records/track_1_short/2024-10-29_Optimizers/8bfe4e35-c3fc-4b70-a984-3be937b71ff3.txt +0 -0
  36. records/track_1_short/2024-10-29_Optimizers/8d6193f4-27fc-4e68-899f-af70019a4d54.txt +0 -0
  37. records/track_1_short/2024-10-29_Optimizers/95a9fd44-7c13-49c7-b324-3e7d9e23a499.txt +0 -0
  38. records/track_1_short/2024-10-29_Optimizers/README.md +103 -0
  39. records/track_1_short/2024-10-29_Optimizers/e21a2838-a0f2-46f2-a247-db0021165682.txt +0 -0
  40. records/track_1_short/2024-10-29_Optimizers/nanogpt_speedrun81w.png +3 -0
  41. records/track_1_short/2024-10-29_Optimizers/nanogpt_speedrun82w.png +3 -0
  42. records/track_1_short/2024-11-03_UntieEmbed/README.md +27 -0
  43. records/track_1_short/2024-11-03_UntieEmbed/d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt +0 -0
  44. records/track_1_short/2024-11-04_50Bruns/3d715d41-453a-40d6-9506-421ba69766b2.txt +0 -0
  45. records/track_1_short/2024-11-04_50Bruns/4fbe61ec-f79a-4c19-836d-46d599deecce.txt +0 -0
  46. records/track_1_short/2024-11-04_50Bruns/530f3ee1-8862-4d21-be2b-da10eb05e6a9.txt +0 -0
  47. records/track_1_short/2024-11-04_50Bruns/69c33fc9-eabb-4a38-aa08-6922914eb405.txt +0 -0
  48. records/track_1_short/2024-11-04_50Bruns/README.md +26 -0
  49. records/track_1_short/2024-11-06_ShortcutsTweaks/042f9e87-07e6-4504-bb04-4ec59a380211.txt +0 -0
  50. records/track_1_short/2024-11-06_ShortcutsTweaks/05b29e54-0be0-4a0f-a1e2-7d5317daedd3.txt +0 -0
.gitattributes CHANGED
@@ -33,3 +33,31 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ img/algo_optimizer.png filter=lfs diff=lfs merge=lfs -text
37
+ img/fig_optimizer.png filter=lfs diff=lfs merge=lfs -text
38
+ img/fig_tuned_nanogpt.png filter=lfs diff=lfs merge=lfs -text
39
+ img/nanogpt_speedrun51.png filter=lfs diff=lfs merge=lfs -text
40
+ img/nanogpt_speedrun53.png filter=lfs diff=lfs merge=lfs -text
41
+ records/track_1_short/2024-10-29_Optimizers/nanogpt_speedrun81w.png filter=lfs diff=lfs merge=lfs -text
42
+ records/track_1_short/2024-10-29_Optimizers/nanogpt_speedrun82w.png filter=lfs diff=lfs merge=lfs -text
43
+ records/track_1_short/2024-11-06_ShortcutsTweaks/nanogpt_speedrun111.png filter=lfs diff=lfs merge=lfs -text
44
+ records/track_1_short/2025-01-04_SoftCap/curves_010425.png filter=lfs diff=lfs merge=lfs -text
45
+ records/track_1_short/2025-01-16_Sub3Min/long-short-swa.png filter=lfs diff=lfs merge=lfs -text
46
+ records/track_1_short/2025-01-26_BatchSize/ablations.png filter=lfs diff=lfs merge=lfs -text
47
+ records/track_1_short/2025-09-03_FA3/media/attn_speed_vs_batch_s1024_ws384.png filter=lfs diff=lfs merge=lfs -text
48
+ records/track_1_short/2025-10-31_AdamSyncGradientHook/profiler-trace-current-comm-overlap.png filter=lfs diff=lfs merge=lfs -text
49
+ records/track_1_short/2025-10-31_AdamSyncGradientHook/profiler-trace-current-first-rs.png filter=lfs diff=lfs merge=lfs -text
50
+ records/track_1_short/2025-10-31_AdamSyncGradientHook/profiler-trace-current-overview.png filter=lfs diff=lfs merge=lfs -text
51
+ records/track_1_short/2025-10-31_AdamSyncGradientHook/profiler-trace-hook-comm-overlap.png filter=lfs diff=lfs merge=lfs -text
52
+ records/track_1_short/2025-10-31_AdamSyncGradientHook/profiler-trace-hook-first-rs.png filter=lfs diff=lfs merge=lfs -text
53
+ records/track_1_short/2025-10-31_AdamSyncGradientHook/profiler-trace-hook-overview.png filter=lfs diff=lfs merge=lfs -text
54
+ records/track_1_short/2025-11-10_CautiousWD/assets/cwd_condition_numbers.jpg filter=lfs diff=lfs merge=lfs -text
55
+ records/track_1_short/2025-11-29_BatchSizeSchedule/val_loss_five_step.png filter=lfs diff=lfs merge=lfs -text
56
+ records/track_1_short/2025-12-19_RetieLMHead/lm_head_weights.png filter=lfs diff=lfs merge=lfs -text
57
+ records/track_1_short/2025-12-21_SmoothedScalars/resid_lambdas_plots.png filter=lfs diff=lfs merge=lfs -text
58
+ records/track_1_short/2025-12-21_SmoothedScalars/smear_gate_plots.png filter=lfs diff=lfs merge=lfs -text
59
+ records/track_1_short/2025-12-21_SmoothedScalars/smear_lambda_plot.png filter=lfs diff=lfs merge=lfs -text
60
+ records/track_1_short/2025-12-21_SmoothedScalars/x0_lambdas_plots.png filter=lfs diff=lfs merge=lfs -text
61
+ records/track_1_short/2025-12-31_GatesToCompiledAdam/adam_kernel_fusion.png filter=lfs diff=lfs merge=lfs -text
62
+ records/track_1_short/2025-12-31_GatesToCompiledAdam/impact_to_opt_window.png filter=lfs diff=lfs merge=lfs -text
63
+ records/track_1_short/2025-12-31_GatesToCompiledAdam/smoothing_plots.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ fineweb10B/
2
+ pylog124M/
3
+ __pycache__/
4
+ logs/
5
+ .DS_Store
Dockerfile ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:12.6.2-cudnn-devel-ubuntu24.04
2
+
3
+ ENV DEBIAN_FRONTEND=noninteractive
4
+ ENV PYTHON_VERSION=3.12.7
5
+ ENV PATH=/usr/local/bin:$PATH
6
+
7
+ RUN apt update && apt install -y --no-install-recommends build-essential libssl-dev zlib1g-dev \
8
+ libbz2-dev libreadline-dev libsqlite3-dev curl git libncursesw5-dev xz-utils tk-dev libxml2-dev \
9
+ libxmlsec1-dev libffi-dev liblzma-dev \
10
+ && apt clean && rm -rf /var/lib/apt/lists/*
11
+
12
+ RUN curl -O https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.tgz && \
13
+ tar -xzf Python-${PYTHON_VERSION}.tgz && \
14
+ cd Python-${PYTHON_VERSION} && \
15
+ ./configure --enable-optimizations && \
16
+ make -j$(nproc) && \
17
+ make altinstall && \
18
+ cd .. && \
19
+ rm -rf Python-${PYTHON_VERSION} Python-${PYTHON_VERSION}.tgz
20
+
21
+ RUN ln -s /usr/local/bin/python3.12 /usr/local/bin/python && \
22
+ ln -s /usr/local/bin/pip3.12 /usr/local/bin/pip
23
+
24
+ COPY requirements.txt /modded-nanogpt/requirements.txt
25
+ WORKDIR /modded-nanogpt
26
+
27
+ RUN python -m pip install --upgrade pip && \
28
+ pip install -r requirements.txt
29
+
30
+ RUN pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --upgrade
31
+
32
+ CMD ["bash"]
33
+ ENTRYPOINT []
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Keller Jordan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modded-NanoGPT
2
+
3
+ This repository hosts the *NanoGPT speedrun*, in which we (collaboratively|competitively) search for the fastest algorithm to use 8 NVIDIA H100 GPUs to train a language model that attains 3.28 cross-entropy loss on the [FineWeb](https://huggingface.co/datasets/HuggingFaceFW/fineweb) validation set.
4
+
5
+ The target (3.28 validation loss on FineWeb) follows Andrej Karpathy's [GPT-2 replication in llm.c, which attains that loss after running for 45 minutes](https://github.com/karpathy/llm.c/discussions/481#:~:text=By%20the%20end%20of%20the%20optimization%20we%27ll%20get%20to%20about%203.29).
6
+ The speedrun code also descends from llm.c's [PyTorch trainer](https://github.com/karpathy/llm.c/blob/master/train_gpt2.py), which itself descends from NanoGPT, hence the name of the repo.
7
+ Thanks to the efforts of many contributors, this repo now contains a training algorithm which attains the target performance in:
8
+ * 2 minutes on 8xH100 (the llm.c GPT-2 replication needed 45)
9
+ * under 500M tokens (the llm.c GPT-2 replication needed 10B)
10
+
11
+ This improvement in training speed has been brought about by the following techniques:
12
+ * Modernized architecture: Rotary embeddings, QK-Norm, and ReLU²
13
+ * The Muon optimizer [[writeup](https://kellerjordan.github.io/posts/muon/)] [[repo](https://github.com/KellerJordan/Muon)]
14
+ * Use FP8 matmul for head, and asymmetric rescale and softcap logits
15
+ * Initialization of projections to zero (muP-like)
16
+ * Skip connections from embedding to every block as well as from block 3 to 6
17
+ * Extra embeddings which are mixed into the values in attention layers (inspired by Zhou et al. 2024)
18
+ * Flash Attention 3 with long-short sliding window attention pattern (inspired by Gemma 2) and window size warmup with YaRN
19
+ * Align training batch starts with EoS and set a max document length
20
+ * Accumulate gradients for 2 steps for embedding and lm_head before updating parameters
21
+ * Enable model to back out contributions from first 2/3 layers before prediction
22
+ * Polar Express implementation in Muon
23
+ * Smear module to enable 1 token look back
24
+ * Sparse attention gate
25
+ * NorMuon
26
+ * Cautious Weight Decay w/ schedule tied to LR
27
+ * Exponential decay of residual stream
28
+ * Batch size schedule
29
+ * Partial Key Offset
30
+ * Multi token prediction
31
+ * Untie embed and lm_head at 2/3 of training
32
+ * Additional gating on value embeddings and skip connection
33
+ * Paired head attention
34
+
35
+ As well as many systems optimizations.
36
+
37
+ Contributors list (growing with each new record): [@bozavlado](https://x.com/bozavlado); [@brendanh0gan](https://x.com/brendanh0gan);
38
+ [@fernbear.bsky.social](https://bsky.app/profile/fernbear.bsky.social); [@Grad62304977](https://x.com/Grad62304977);
39
+ [@jxbz](https://x.com/jxbz); [@kellerjordan0](https://x.com/kellerjordan0);
40
+ [@KoszarskyB](https://x.com/KoszarskyB); [@leloykun](https://x.com/@leloykun);
41
+ [@YouJiacheng](https://x.com/YouJiacheng); [@jadenj3o](https://x.com/jadenj3o);
42
+ [@KonstantinWilleke](https://github.com/KonstantinWilleke), [@alexrgilbert](https://github.com/alexrgilbert), [@adricarda](https://github.com/adricarda),
43
+ [@tuttyfrutyee](https://github.com/tuttyfrutyee), [@vdlad](https://github.com/vdlad);
44
+ [@ryanyang0](https://x.com/ryanyang0), [@vagrawal](https://github.com/vagrawal), [@classiclarryd](https://x.com/classiclarryd),
45
+ [@byronxu99](https://github.com/byronxu99), [@varunneal](https://x.com/varunneal), [@EmelyanenkoK](https://github.com/EmelyanenkoK),
46
+ [@bernard24](https://github.com/bernard24)/https://www.hiverge.ai/, [@Gusarich](https://x.com/Gusarich), [@li_zichong](https://x.com/li_zichong),
47
+ [@akash5474](https://github.com/akash5474), [@snimu](https://x.com/omouamoua), [@roeeshenberg](https://x.com/roeeshenberg),
48
+ [@ChrisJMcCormick](https://x.com/ChrisJMcCormick), [@dominikkallusky](https://github.com/dominikkallusky), [@acutkosky](https://github.com/acutkosky),
49
+ [@manikbhandari](https://github.com/manikbhandari), [@andrewbriand](https://github.com/andrewbriand), [@jrauvola](https://github.com/jrauvola)
50
+
51
+
52
+ ---
53
+
54
+ ## Running the current record
55
+
56
+ To run the current record, run the following commands.
57
+ ```bash
58
+ git clone https://github.com/KellerJordan/modded-nanogpt.git && cd modded-nanogpt
59
+ pip install -r requirements.txt
60
+ pip install torch==2.10.0.dev20251210+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
61
+ # downloads only the first 900M training tokens to save time
62
+ python data/cached_fineweb10B.py 9
63
+ ./run.sh
64
+ ```
65
+ Add torchrun to path if ./run.sh gives error `torchrun: command not found`.
66
+
67
+ **Note: torch.compile will add around 7 minutes of latency the first time you run the code.**
68
+
69
+ Official records are timed on 8 NVIDIA H100 GPUs from https://app.primeintellect.ai/. PrimeIntellect has generously sponsored recent validation runs.
70
+
71
+ ## Alternative: Running with Docker (recommended for precise timing)
72
+
73
+ For cases where CUDA or NCCL versions aren't compatible with your current system setup, Docker can be a helpful alternative.
74
+ This approach standardizes versions for CUDA, NCCL, CUDNN, and Python, reducing dependency issues and simplifying setup.
75
+ Note: an NVIDIA driver must already be installed on the system (useful if only the NVIDIA driver and Docker are available).
76
+
77
+ ```bash
78
+ git clone https://github.com/KellerJordan/modded-nanogpt.git && cd modded-nanogpt
79
+ sudo docker build -t modded-nanogpt .
80
+ sudo docker run -it --rm --gpus all -v $(pwd):/modded-nanogpt modded-nanogpt python data/cached_fineweb10B.py 8
81
+ sudo docker run -it --rm --gpus all -v $(pwd):/modded-nanogpt modded-nanogpt sh run.sh
82
+ ```
83
+
84
+ To get an interactive docker, you can use
85
+ ```bash
86
+ sudo docker run -it --rm --gpus all -v $(pwd):/modded-nanogpt modded-nanogpt bash
87
+ ```
88
+
89
+ ---
90
+
91
+ ## World record history
92
+
93
+ The following is the historical progression of world speed records for the following competitive task:
94
+
95
+ > *Train a neural network to ≤3.28 validation loss on FineWeb using 8x NVIDIA H100s.*
96
+
97
+ Note: The 3.28 target was selected to match [Andrej Karpathy's GPT-2 (small) reproduction](https://github.com/karpathy/llm.c/discussions/481).
98
+
99
+ | # | Record time | Description | Date | Log | Contributors |
100
+ | - | - | - | - | - | - |
101
+ 1 | 45 minutes | [llm.c baseline](https://github.com/karpathy/llm.c/discussions/481) | 05/28/24 | [log](records/track_1_short/2024-10-13_llmc/main.log) | @karpathy, llm.c contributors
102
+ 2 | 31.4 minutes | [Tuned learning rate & rotary embeddings](https://x.com/kellerjordan0/status/1798863559243513937) | 06/06/24 | [log](records/track_1_short/2024-06-06_AdamW/f66d43d7-e449-4029-8adf-e8537bab49ea.log) | @kellerjordan0
103
+ 3 | 24.9 minutes | [Introduced the Muon optimizer](https://x.com/kellerjordan0/status/1842300916864844014) | 10/04/24 | none | @kellerjordan0, @jxbz
104
+ 4 | 22.3 minutes | [Muon improvements](https://x.com/kellerjordan0/status/1844820919061287009) | 10/11/24 | [log](records/track_1_short/2024-10-10_Muon/eb5659d0-fb6a-49e5-a311-f1f89412f726.txt) | @kellerjordan0, @bozavlado
105
+ 5 | 15.2 minutes | [Pad embeddings, ReLU², zero-init projections, QK-norm](https://x.com/kellerjordan0/status/1845865698532450646) | 10/14/24 | [log](records/track_1_short/2024-10-14_ModernArch/dabaaddd-237c-4ec9-939d-6608a9ed5e27.txt) | @Grad62304977, @kellerjordan0
106
+ 6 | 13.1 minutes | [Distributed the overhead of Muon](https://x.com/kellerjordan0/status/1847291684016783746) | 10/18/24 | [log](records/track_1_short/2024-10-17_DistributedMuon/22d24867-eb5a-4fcc-ae2c-263d0277dfd1.txt) | @kellerjordan0
107
+ 7 | 12.0 minutes | [Upgraded PyTorch 2.5.0](https://x.com/kellerjordan0/status/1847358578686152764) | 10/18/24 | [log](records/track_1_short/2024-10-18_PyTorch25/d4bfb25f-688d-4da5-8743-33926fad4842.txt) | @kellerjordan0
108
+ 8 | 10.8 minutes | [Untied embedding and head](https://x.com/kellerjordan0/status/1853188916704387239) | 11/03/24 | [log](records/track_1_short/2024-11-03_UntieEmbed/d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt) | @Grad62304977, @kellerjordan0
109
+ 9 | 8.2 minutes | [Value and embedding skip connections, momentum warmup, logit softcap](https://x.com/kellerjordan0/status/1854296101303800108) | 11/06/24 | [log](records/track_1_short/2024-11-06_ShortcutsTweaks/dd7304a6-cc43-4d5e-adb8-c070111464a1.txt) | @Grad62304977, @kellerjordan0
110
+ 10 | 7.8 minutes | [Bfloat16 activations](https://x.com/kellerjordan0/status/1855267054774865980) | 11/08/24 | [log](records/track_1_short/2024-11-08_CastBf16/a833bed8-2fa8-4cfe-af05-58c1cc48bc30.txt) | @kellerjordan0
111
+ 11 | 7.2 minutes | [U-net pattern skip connections & double lr](https://x.com/kellerjordan0/status/1856053121103093922) | 11/10/24 | [log](records/track_1_short/2024-11-10_UNetDoubleLr/c87bb826-797b-4f37-98c7-d3a5dad2de74.txt) | @brendanh0gan
112
+ 12 | 5.03 minutes | [1024-ctx dense causal attention → 64K-ctx FlexAttention](https://x.com/kellerjordan0/status/1859331370268623321) | 11/19/24 | [log](records/track_1_short/2024-11-19_FlexAttention/8384493d-dba9-4991-b16b-8696953f5e6d.txt) | @KoszarskyB
113
+ 13 | 4.66 minutes | [Attention window warmup](https://x.com/hi_tysam/status/1860851011797053450) | 11/24/24 | [log](records/track_1_short/2024-11-24_WindowWarmup/cf9e4571-c5fc-4323-abf3-a98d862ec6c8.txt) | @fernbear.bsky.social
114
+ 14 | 4.41 minutes | [Value Embeddings](https://x.com/KoszarskyB/status/1864746625572257852) | 12/04/24 | [log](records/track_1_short/2024-12-04_ValueEmbed) | @KoszarskyB
115
+ 15 | 3.95 minutes | [U-net pattern value embeddings, assorted code optimizations](https://x.com/YouJiacheng/status/1865761473886347747) | 12/08/24 | [log](records/track_1_short/2024-12-08_UNetValueEmbedsTweaks) | @leloykun, @YouJiacheng
116
+ 16 | 3.80 minutes | [Split value embeddings, block sliding window, separate block mask](https://x.com/YouJiacheng/status/1866734331559071981) | 12/10/24 | [log](records/track_1_short/2024-12-10_MFUTweaks) | @YouJiacheng
117
+ 17 | 3.57 minutes | [Sparsify value embeddings, improve rotary embeddings, drop an attn layer](https://x.com/YouJiacheng/status/1868938024731787640) | 12/17/24 | [log](records/track_1_short/2024-12-17_SparsifyEmbeds) | @YouJiacheng
118
+ 18 | 3.4 minutes | [Lower logit softcap from 30 to 15](https://x.com/kellerjordan0/status/1876048851158880624) | 01/04/25 | [log](records/track_1_short/2025-01-04_SoftCap/31d6c427-f1f7-4d8a-91be-a67b5dcd13fd.txt) | @KoszarskyB
119
+ 19 | 3.142 minutes | [FP8 head, offset logits, lr decay to 0.1 instead of 0.0](https://x.com/YouJiacheng/status/1878827972519772241) | 01/13/25 | [log](records/track_1_short/2025-01-13_Fp8LmHead/c51969c2-d04c-40a7-bcea-c092c3c2d11a.txt) | @YouJiacheng
120
+ 20 | 2.992 minutes | [Merged QKV weights, long-short attention, attention scale, lower Adam epsilon, batched Muon](https://x.com/leloykun/status/1880301753213809016) | 01/16/25 | [log](records/track_1_short/2025-01-16_Sub3Min/1d3bd93b-a69e-4118-aeb8-8184239d7566.txt) | @leloykun, @fernbear.bsky.social, @YouJiacheng, @brendanh0gan, @scottjmaddox, @Grad62304977
121
+ 21 | 2.933 minutes | [Reduced batch size](https://x.com/leloykun/status/1885640350368420160) | 01/26/25 | [log](records/track_1_short/2025-01-26_BatchSize/c44090cc-1b99-4c95-8624-38fb4b5834f9.txt) | @leloykun
122
+ 21 | 2.997 minutes | 21st record with new timing | 02/01/25 | [log](records/track_1_short/2025-02-01_RuleTweak/eff63a8c-2f7e-4fc5-97ce-7f600dae0bc7.txt) | not a new record, just re-timing #21 with the [updated rules](#timing-change-after-record-21)
123
+ 21 | 3.014 minutes | 21st record with latest torch | 05/24/25 | [log](records/track_1_short/2025-05-24_StableTorch/89d9f224-3b01-4581-966e-358d692335e0.txt) | not a new record, just re-timing #21 with latest torch
124
+ 22 | 2.990 minutes | [Faster gradient all-reduce](https://x.com/KonstantinWille/status/1927137223238909969) | 05/24/25 | [log](records/track_1_short/2025-05-24_FasterReduce/23f40b75-06fb-4c3f-87a8-743524769a35.txt) | @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad; The Enigma project
125
+ 23 | 2.979 minutes | [Overlap computation and gradient communication](https://x.com/kellerjordan0/status/1927460573098262616) | 05/25/25 | [log](records/track_1_short/2025-05-25_EvenFasterReduce/6ae86d05-5cb2-4e40-a512-63246fd08e45.txt) | @ryanyang0
126
+ 24 | 2.966 minutes | Replace gradient all_reduce with reduce_scatter | 05/30/25 | [log](records/track_1_short/2025-05-30_noallreduce/8054c239-3a18-499e-b0c8-dbd27cb4b3ab.txt) | @vagrawal
127
+ 25 | 2.896 minutes | Upgrade PyTorch to 2.9.0.dev20250713+cu126 | 07/13/25 | [log](records/track_1_short/2025-07-13_UpgradeTorch190/692f80e0-5e64-4819-97d4-0dc83b7106b9.txt) | @kellerjordan0
128
+ 26 | 2.863 minutes | Align training batch starts with EoS, increase cooldown frac to .45 | 07/13/25 | [log](records/track_1_short/2025-07-12_BosAlign/c1fd8a38-bb9f-45c4-8af0-d37f70c993f3.txt) | @classiclarryd
129
+ 27 | 2.817 minutes | Transpose one of the MLP matrices + add Triton kernel for symmetric matmul | 07/18/25 | [log](records/track_1_short/2025-07-18_TritonMuon/record.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/109) | @byronxu99
130
+ 28 | 2.812 minutes | Sparse attention gate | 08/23/25 | [log](records/track_1_short/2025-08-23_SparseAttnGate/020630eb-2191-4ba2-9ee4-4cdc94316943.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/117) | @classiclarryd
131
+ 29 | 2.731 minutes | Flash Attention 3, 2048 max_doc_len, update ws schedule | 09/03/25 | [log](records/track_1_short/2025-09-03_FA3/44fc1276-0510-4961-92c0-730c65e5feba.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/118) | @varunneal
132
+ 30 | 2.717 minutes | Drop first MLP layer | 09/05/25 | [log](records/track_1_short/2025-09-05_SkipMLPBlocks/07e7ae76-b7d0-4481-b149-01e7d81b5ad4.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/120) | @EmelyanenkoK
133
+ 31 | 2.656 minutes | Dynamically incorporate YaRN during training and validation | 09/10/25 | [log](records/track_1_short/2025-09-10_Yarn/0ecdb695-510b-4c3b-b030-09861a162ce8.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/122) | @classiclarryd
134
+ 32 | 2.625 minutes | Optimize distributed training, improve skip connection gating, and enhance bfloat16 usage | 09/11/25 | [log](records/track_1_short/2025-09-11_VectSigmoidBFloat16/0d0d9882-c34f-4d82-b961-a17d5659c988.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/125) | @bernard24 & hiverge.ai
135
+ 33 | 2.565 minutes | Asynchronously fetch and index data batches, extend final layer attention window for validation | 09/15/25 | [log](records/track_1_short/2025-09-15_AsyncDataLoadAttnFinalWindow/25db37c7-2bab-4ef4-ae63-d593590ef823.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/127) | @classiclarryd
136
+ 34 | 2.547 minutes | Smear token embeddings 1 position forward | 09/18/25 | [log](records/track_1_short/2025-09-18_Smear/18a1e5c7-947e-479d-bc3a-a57a61a98fc9.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/130) | @classiclarryd
137
+ 35 | 2.527 minutes | Drop first attn layer, extend all long windows for validation, update schedule | 09/21/25 | [log](records/track_1_short/2025-09-21_DropAttn/01fc4a96-f2a0-47a1-8a6a-c7d10bac99fe.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/131) | @classiclarryd
138
+ 36 | 2.495 minutes | MuonCustomSizing, perform mlp and attn reduce scatter in shared call | 09/23/25 | [log](records/track_1_short/2025-09-23_MuonCustomSizing/b067b4ac-72a6-4436-a6f8-ea51c1efeef3.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/132) | @classiclarryd
139
+ 37 | 2.483 minutes | Compute cross entropy in BF16 during training | 09/27/25 | [log](records/track_1_short/2025-09-27_BF16CE/08c0770f-17fc-44cd-971d-734a7a28a3e3.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/133) | @Gusarich
140
+ 38 | 2.476 minutes | Polar Express, replacement for Newton-Schulz | 09/29/25 | [log](records/track_1_short/2025-09-29_PolarExpress/0e3f0af5-ad08-47a6-813d-0c709b50d422.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/134) | @varunneal
141
+ 39 | 2.447 minutes | Only update Adam params every other step, reduce batch size | 09/30/25 | [log](records/track_1_short/2025-09-30_CustomBatching/40b101b1-77ea-45ea-a089-1d3a647daa22.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/136) | @classiclarryd
142
+ 40 | 2.358 minutes | Backout, misc hyperparameter tuning, optimize lambda padding | 10/04/25 | [log](records/track_1_short/2025-10-04_Backout/514e7581-fbd4-4338-a3e4-e556f9c958ce.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/140) | @classiclarryd
143
+ 41 | 2.345 minutes | [NorMuon](https://arxiv.org/pdf/2510.05491) | 10/24/25 | [log](records/track_1_short/2025-10-24_NorMuon/088a77ee-9b67-475a-bbb9-3e92e4698799.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/144) | @li_zichong
144
+ 42 | 2.313 minutes | Update NorMuon LR, Step Logic | 10/27/25 | [log](records/track_1_short/2025-10-27_FixMuonLR/14afd380-d3d9-48d7-ad23-4c13cb96754b.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/146) | @varunneal
145
+ 43 | 2.284 minutes | Cautious Weight Decay w/ schedule | 11/10/25 | [log](records/track_1_short/2025-11-10_CautiousWD/1aac0132-a891-4ed9-b358-0fd2abd1b019.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/154) | @varunneal
146
+ 44 | 2.269 minutes | Backward hooks on Adam, [Profiling 101](https://blog.underfit.ai/profiling-101-nanogpt) | 11/16/25 | [log](records/track_1_short/2025-10-31_AdamSyncGradientHook/0c17cdfd-772c-4906-8d11-141b370599a0.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/149) | @akash5474
147
+ 45 | 2.248 minutes | Refine skip arch, update exponential decay init| 11/18/25 | [log](records/track_1_short/2025-11-18_RefineSkip/00f4e1e6-0044-4a08-b88a-3b7ec0624081.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/159) | @classiclarryd
148
+ 46 | 2.203 minutes | [Batch size schedule](https://x.com/classiclarryd/status/1998212158770065844) | 11/29/25 | [log](records/track_1_short/2025-11-29_BatchSizeSchedule/10e8f7c6-7175-4467-bdb0-a5de25d771a6.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/163) | @varunneal
149
+ 47 | 2.193 minutes | [Multiply attn lambda with weight instead of data, fix warmup](https://x.com/classiclarryd/status/1999630732814348451) | 12/10/25 | [log](records/track_1_short/2025-12-10_SALambdaOnWeights/15ef5eaf-56e1-40e1-9ddf-af010027c9dd.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/166) | @roeeshenberg
150
+ 48 | 2.170 minutes | [Speed up Muon, additional pre-multiply lambda, reshape matrices, update lr, update NorMuon axis](https://x.com/classiclarryd/status/2000272495644152317) | 12/11/25 | [log](records/track_1_short/2025-12-11_NorMuonOptimsAndFixes/82edf6be-f343-475d-b93a-47c32acf4de2.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/168) | @ChrisJMcCormick
151
+ 49 | 2.146 minutes | [Partial Key Offset](https://x.com/classiclarryd/status/2000841339299402142) | 12/14/25 | [log](records/track_1_short/2025-12-14_PartialKeyOffset/150d40bf-c20b-4568-aac9-26eb919e25fd.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/169) | @classiclarryd
152
+ 50 | 2.128 minutes | [Extend Cautious Weight Decay to Adam parameters](https://x.com/classiclarryd/status/2002482925741486381) | 12/18/25 | [log](records/track_1_short/2025-12-18_CautiousWDAdam/1981d492-bc65-4ba9-a0fa-2b30fc5c3eba.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/172) | @roeeshenberg
153
+ 51 | 2.075 minutes | [Retie Embed to lm_head, retune fp8 scales](https://x.com/classiclarryd/status/2003167208483209668) | 12/19/25 | [log](records/track_1_short/2025-12-19_RetieLMHead/0828d309-ecfe-4442-9ee9-68fed3a4b599.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/175) | @varunneal
154
+ 52 | 2.037 minutes | [Smooth scalars via beta increase, decrease smear gate lr, freeze scalars during transitions, adam all reduce](https://x.com/classiclarryd/status/2003863282613190656) | 12/21/25 | [log](records/track_1_short/2025-12-21_SmoothedScalars/12-21-Smoothed-Scalars/0bc6e909-8ee8-4ae3-ac62-0070e151a808.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/177) | @ChrisJMcCormick
155
+ 53 | 1.988 minutes | [Multi-token prediction, untie embed/lm_head at 2/3 training, lr update, tweak CWD](https://x.com/classiclarryd/status/2004248941878296580) | 12/22/25 | [log](records/track_1_short/2025-12-22_MultiTokenPrediction/17aaf854-f338-4d0d-9767-a5db30fd7980.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/178) | @varunneal, feat. @classiclarryd
156
+ 54 | 1.940 minutes | [Asymmetric Logit Rescale](https://x.com/classiclarryd/status/2004791008098480232) | 12/26/25 | [log](records/track_1_short/2025-12-26_LogitRescale/03e41c2d-2951-4546-a599-24cd723247fc.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/181) | @classiclarryd
157
+ 55 | 1.918 minutes | [Gates on value embeds and skip connection](https://x.com/classiclarryd/status/2005659526960492638) | 12/29/25 | [log](records/track_1_short/2025-12-29_VeSkipGates/2851d7dc-d6a5-4e74-8623-57031425db16.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/186) | @classiclarryd
158
+ 56 | 1.894 minutes | [Optimize and compile Adam, increase Adam buffer precision, move gates from Muon to Adam parameter banks](https://x.com/classiclarryd/status/2007882371576873445) | 12/31/25 | [log](records/track_1_short/2025-12-31_GatesToCompiledAdam/12-31-gates-to-adam-20stps/219a5f2f-151e-4c56-ab91-3735ae4610b8.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/187) | @ChrisJMcCormick
159
+ 57 | 1.878 minutes | [Bfloat16 attn/mlp weights, mixed precision Muon, interweave Adam/Muon, finer-grain Adam beta](https://x.com/classiclarryd/status/2008261904566022590) | 01/04/26 | [log](records/track_1_short/2026-01-04_MixedPrecisionInterweavedOptimizer/41f606b6-1b9c-46a3-b46e-2beff1521d18.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/190) | @classiclarryd, feat. @YouJiacheng, @ChrisJMcCormick
160
+ 58 | 1.820 minutes | [Paired Head Attention](https://x.com/classiclarryd/status/2008963501688324228) | 01/07/26 | [log](records/track_1_short/2026-01-07_PairedHeadAttention/2a5d5cde-db5f-4aab-a4a8-cc8e183ea671.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/191) | @classiclarryd
161
+ 59 | 1.781 minutes | Fused linear relu square triton kernel | 01/10/26 | [log](records/track_1_short/2026-01-10_FusedLinearReLUSquare/3c47e63b-075e-4b5b-9c76-9dbe7bad9ad4.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/197) | @andrewbriand, @jrauvola
162
+ ## Rules
163
+
164
+ New records must:
165
+
166
+ 1. Not modify the train or validation data pipelines. (You can change the batch size, sequence length, attention structure etc.; just don't change the underlying streams of tokens.)
167
+ 2. Attain ≤3.28 mean val loss. (Due to inter-run variance, submissions must provide enough run logs to attain a statistical significance level of p<0.01 that their mean val loss is ≤3.28. Example code to compute p-value can be found [here](records/track_1_short/2025-01-04_SoftCap#softer-softcap). For submissions which improve speed by optimizing the systems performance, without touching the ML, this requirement is waived.)
168
+ 3. Not use any extra `torch._inductor.config` or `torch.compile` flags. (These can save a few seconds, but they can also make compilation take >30min. This rule was introduced after the 21st record.)
169
+ 4. Run faster than the prior record when baselined on the same hardware.
170
+
171
+ Discretionary reasons why a PR may not be accepted:
172
+ 1. Disproportionately degrades the readability of the codebase. A 200 line kernel to drop 300ms is considered worthwhile. 500 lines that convolute the optimizer layout for a 50ms gain will likely be rejected.
173
+ 2. The current record is intentionally kept roughly 0.001-0.002 loss below 3.28 to make validation simpler. If a PR substantially consumes this buffer, it should do so in a way that outperforms a simple step count decrease, when measured at equivalent loss.
174
+
175
+ > Note: `torch._inductor.config.coordinate_descent_tuning` is allowed for GPT-2 Medium track (a.k.a. 2.92 track).
176
+
177
+ Other than that, anything and everything is fair game!
178
+
179
+ [further clarifications](https://github.com/KellerJordan/modded-nanogpt/discussions/23?sort=new#discussioncomment-12109560)
180
+
181
+ ---
182
+
183
+ ### Comment on the target metric
184
+
185
+ The target metric is *cross-entropy loss on the FineWeb val set*. To speak mathematically, the goal of the speedrun is *to obtain a probability model of language which assigns a probability of at least `math.exp(-3.28 * 10485760)` to the first 10,485,760 tokens of the FineWeb valset. Hence, e.g., we allow evaluation at any sequence length, so long as we still have a valid probability model of language.
186
+
187
+ ---
188
+
189
+ ### Timing change after record 21
190
+
191
+ After the 21st record, we made two changes to the timing. First, there used to be an initial "grace period" of 10 untimed steps to allow kernel warmup. We replaced this with an explicit kernel-warmup section which is untimed and uses dummy data. This results in an extra runtime of 850ms from the 10 extra timed steps.
192
+ Second, we banned the use of `torch._inductor.config.coordinate_descent_tuning`. This saves ~25min of untimed pre-run compilation, but results in an extra runtime of ~3s.
193
+
194
+ <!--Note: The original llm.c baseline is intended to be closer to a replication of GPT-2 than to an optimized LLM training.
195
+ So it's no surprise that there is room to improve; as @karpathy has said, 'llm.c still has a lot of pending optimizations.'
196
+ In addition, many of the techniques used in these records are completely standard, such as rotary embeddings.
197
+ The goal of this benchmark/speedrun is simply to find out which techniques actually work, and maybe come up with some new ones.-->
198
+ <!--The goal of this benchmark is simply to find out all the techniques which actually work, because I'm going crazy reading all these
199
+ LLM training papers
200
+ which claim a huge benefit but then use their own idiosyncratic non-competitive benchmark and therefore no one in the community has any idea if it's legit for months.-->
201
+ <!--[LLM](https://arxiv.org/abs/2305.14342) [training](https://arxiv.org/abs/2402.17764) [papers](https://arxiv.org/abs/2410.01131)-->
202
+ <!--I mean hello??? We're in a completely empirical field; it is insane to not have a benchmark. Ideally everyone uses the same LLM training benchmark,
203
+ and then reviewing LLM training papers becomes as simple as checking if they beat the benchmark. It's not like this would be unprecedented, that's how things
204
+ were in the ImageNet days.
205
+ The only possible 'benefit' I can think of for any empirical field to abandon benchmarks is that it would make it easier to publish false results. Oh, I guess that's why it happened.
206
+ Hilarious to think about how, in the often-commented-upon and ongoing collapse of the peer review system, people blame the *reviewers* --
207
+ yeah, those guys doing free labor who everyone constantly musters all of their intelligence to lie to, it's *their* fault! My bad, you caught me monologuing.-->
208
+
209
+ ---
210
+
211
+ ### Notable attempts & forks
212
+
213
+ **Notable runs:**
214
+
215
+ * [@alexjc's 01/20/2025 2.77-minute TokenMonster-based record](https://x.com/alexjc/status/1881410039639863622).
216
+ This record is technically outside the rules of the speedrun, since we specified that the train/val tokens must be kept fixed.
217
+ However, it's very interesting, and worth including. The run is not more data-efficient; rather, the speedup comes from the improved tokenizer allowing
218
+ the vocabulary size to be reduced (nearly halved!) while preserving the same bytes-per-token, which saves lots of parameters and FLOPs in the head and embeddings.
219
+
220
+ **Notable forks:**
221
+ * [https://github.com/BlinkDL/modded-nanogpt-rwkv](https://github.com/BlinkDL/modded-nanogpt-rwkv)
222
+ * [https://github.com/nikhilvyas/modded-nanogpt-SOAP](https://github.com/nikhilvyas/modded-nanogpt-SOAP)
223
+
224
+ ---
225
+
226
+ ## Speedrun track 2: GPT-2 Medium
227
+
228
+ The target loss for this track is lowered from 3.28 to 2.92, as per Andrej Karpathy's 350M-parameter llm.c baseline.
229
+ This baseline generates a model with performance similar to the original GPT-2 Medium, whereas the first track's baseline generates a model on par with GPT-2 Small.
230
+ All other rules remain the same.
231
+
232
+ > Note: `torch._inductor.config.coordinate_descent_tuning` is turned on after the record 6 (*).
233
+
234
+ | # | Record time | Description | Date | Log | Contributors |
235
+ | - | - | - | - | - | - |
236
+ 1 | 5.8 hours | [llm.c baseline (350M parameters)](https://github.com/karpathy/llm.c/discussions/481) | 05/28/24 | [log](records/track_2_medium/2025-01-18/main.log) | @karpathy, llm.c contributors
237
+ 2 | 29.3 minutes | [Initial record based on scaling up the GPT-2 small track speedrun](https://x.com/kellerjordan0/status/1881959719012847703) | 01/18/25 | [log](records/track_2_medium/2025-01-18/241dd7a7-3d76-4dce-85a4-7df60387f32a.txt) | @kellerjordan0
238
+ 3 | 28.1 minutes | [Added standard weight decay](https://x.com/kellerjordan0/status/1888320690543284449) | 02/08/25 | [log](records/track_2_medium/2025-02-08_WeightDecay/b01743db-605c-4326-b5b1-d388ee5bebc5.txt) | @kellerjordan0
239
+ 4 | 27.7 minutes | [Tuned Muon Newton-Schulz coefficients](https://x.com/leloykun/status/1892793848163946799) | 02/14/25 | [log](records/track_2_medium/2025-02-14_OptCoeffs/1baa66b2-bff7-4850-aced-d63885ffb4b6.txt) | @leloykun
240
+ 5 | 27.2 minutes | [Increased learning rate cooldown phase duration](records/track_2_medium/2025-03-06_LongerCooldown/779c041a-2a37-45d2-a18b-ec0f223c2bb7.txt) | 03/06/25 | [log](records/track_2_medium/2025-03-06_LongerCooldown/779c041a-2a37-45d2-a18b-ec0f223c2bb7.txt) | @YouJiacheng
241
+ 6 | 25.95 minutes* | [2x MLP wd, qkv norm, all_reduce/opt.step() overlap, optimized skip pattern](https://x.com/YouJiacheng/status/1905861218138804534) | 03/25/25 | [log](records/track_2_medium/2025-03-25_ArchOptTweaks/train_gpt-20250329.txt) | @YouJiacheng
242
+ 7 | 25.29 minutes | [Remove FP8 head; ISRU logits softcap; New sharded mixed precision Muon; merge weights](https://x.com/YouJiacheng/status/1912570883878842527) | 04/16/25 | [log](records/track_2_medium/2025-04-16_Record7/223_3310d0b1-b24d-48ee-899f-d5c2a254a195.txt) | @YouJiacheng
243
+ 8 | 24.50 minutes | [Cubic sliding window size schedule, 2× max window size (24.84 minutes)](https://x.com/jadenj3o/status/1914893086276169754) [24.5min repro](https://x.com/YouJiacheng/status/1915667616913645985) | 04/22/25 | [log](records/track_2_medium/2025-04-22_Record8/075_640429f2-e726-4e83-aa27-684626239ffc.txt) | @jadenj3o
244
+ 9 | 24.12 minutes | [Add two value embeddings](https://snimu.github.io/2025/10/07/modded-nanogpt-value-embeddings.html) | 08/28/25 | [log](records/track_2_medium/2025-08-28_NewValemb/036_61ef4351-7b68-4897-b440-a99221a1a629.txt), [PR](https://github.com/KellerJordan/modded-nanogpt/pull/119) | @snimu
245
+ 10 | 24.07 minutes | [Second input embedding](https://snimu.github.io/2025/10/10/modded-nanogpt-x0.html) | 09/11/25 | [log](records/track_2_medium/2025-09-11_SecondInputEmbed/000_592014ec-6781-4f59-b274-c4af68ccfe75.txt), [PR](https://github.com/KellerJordan/modded-nanogpt/pull/124) | @snimu
246
+ 11 | 23.45 minutes | Upgrade from torch 2.7 to torch==2.10.0.dev20251210+cu126 | - | - | -
247
+ 12 | 23.28 minutes | Snoo Optimizer (Outer optimizer around Adam and Muon) | 09/16/25 | [log](records/track_2_medium/2025-09-16_Snoo/000_01db7a67-f715-4114-a7b5-6bfe23bac1b1.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/128) | @dominikkallusky
248
+ 13 | 23.14 minutes | EMA Wrapper on Muon | 09/17/25 | [log](records/track_2_medium/2025-09-17_UpdateSmoothing/001_8379f695-6bc3-4f76-b58b-8fadd3b6ebb0.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/129) | @acutkosky
249
+ 14 | 23.08 minutes | Combine both records 12 & 13 | 09/30/25 | [log](records/track_2_medium/2025-09-30_SmoothedSnooMedium/101_5bc91cd0-cb46-428c-a5da-9d8d228f1f97.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/137) | @acutkosky
250
+ 15 | 23.03 minutes | Backout (Skip from 2/3 point to pre-lm_head) | 10/04/25 | [log](records/track_2_medium/2025-10-04_GPT2MediumLayerReuse/000_cc3943e4-02b5-4ae3-9441-839d32dfd9b2.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/139) | @snimu
251
+ 16 | 22.99 minutes | Smear-MTP | 11/02/25 | [log](records/track_2_medium/2025-11-02-Smear-MTP/000_3b50518d-d542-44bc-8566-3abf633f83ad.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/151) | @snimu
252
+ 17 | 22.98 minutes | Remove Redundant Mask Op | 11/12/25 | [log](records/track_2_medium/2025-11-12_BlockMaskRedundantOp/000_3b22a9d4-b52e-4916-99bf-3d48b38747a7.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/157/) | @manikbhandari
253
+ 18 | 17.35 minutes | Bulk transfer short track features | 12/31/25 | [log](records/track_2_medium/2025-12-31_BulkSmallTrackTransfer/354be270-7d41-44b7-8064-f040923f024f.txt),[PR](https://github.com/KellerJordan/modded-nanogpt/pull/188) | -
254
+ ---
255
+
256
+ ### Q: What is the point of NanoGPT speedrunning?
257
+
258
+ A: The officially stated goal of NanoGPT speedrunning is as follows: `gotta go fast`. But for something a little more verbose involving an argument for good benchmarking, here's some kind of manifesto, adorned with a blessing from the master. [https://x.com/karpathy/status/1846790537262571739](https://x.com/karpathy/status/1846790537262571739)
259
+
260
+ ### Q: What makes "NanoGPT speedrunning" not just another idiosyncratic benchmark?
261
+
262
+ A: Because it is a *competitive* benchmark. In particular, if you attain a new speed record (using whatever method you want), there is an open invitation for you
263
+ to post that record (on arXiv or X) and thereby vacuum up all the clout for yourself. I will even help you do it by reposting you as much as I can.
264
+
265
+ <!--On the contrary, for example, the benchmark used in the [Sophia](https://arxiv.org/abs/2305.14342) paper does *not* have this property.
266
+ There is no such open invitation for anyone to compete on the benchmark they used. In particular, if, for a random and definitely not weirdly specific example, you happen to find better AdamW hyperparameters for their training setup than
267
+ the ones they used which significantly close the gap between AdamW and their proposed optimizer,
268
+ then there is no clear path for you to publish that result in *any* form.
269
+ You could try posting it on X.com, but then you would be risking being perceived as aggressive/confrontational, which is *not a good look* in this racket.
270
+ So if you're rational, the result probably just dies with you and no one else learns anything
271
+ (unless you're in a frontier lab, in which case you can do a nice internal writeup. Boy I'd love to get my hands on those writeups).-->
272
+
273
+ ["Artificial intelligence advances by inventing games and gloating to goad others to play" - Professor Ben Recht](https://www.argmin.net/p/too-much-information)
274
+
275
+ ### Q: NanoGPT speedrunning is cool and all, but meh it probably won't scale and is just overfitting to val loss
276
+
277
+ A: This is hard to refute, since "at scale" is an infinite category (what if the methods stop working only for >100T models?), making it impossible to fully prove.
278
+ Also, I would agree that some of the methods used in the speedrun are unlikely to scale, particularly those which *impose additional structure* on the network, such as logit softcapping.
279
+ But if the reader cares about 1.5B models, they might be convinced by this result:
280
+
281
+ *Straightforwardly scaling up the speedrun (10/18/24 version) to 1.5B parameters yields a model with GPT-2 (1.5B)-level HellaSwag performance 2.5x more cheaply than [@karpathy's baseline](https://github.com/karpathy/llm.c/discussions/677) ($233 instead of $576):*
282
+
283
+ ![](img/nanogpt_speedrun51.png)
284
+ [[reproducible log](https://github.com/KellerJordan/modded-nanogpt/blob/master/records/track_1_short/2024-10-20_ScaleUp1B/ad8d7ae5-7b2d-4ee9-bc52-f912e9174d7a.txt)]
285
+ ![](img/nanogpt_speedrun52.png)
286
+
287
+ ---
288
+
289
+ ## [Muon optimizer](https://github.com/KellerJordan/Muon)
290
+
291
+ Muon is defined as follows:
292
+
293
+ ![](img/algo_optimizer.png)
294
+
295
+ Where NewtonSchulz5 is the following Newton-Schulz iteration [2, 3], which approximately replaces `G` with `U @ V.T` where `U, S, V = G.svd()`.
296
+ ```python
297
+ @torch.compile
298
+ def zeroth_power_via_newtonschulz5(G, steps=5, eps=1e-7):
299
+ assert len(G.shape) == 2
300
+ a, b, c = (3.4445, -4.7750, 2.0315)
301
+ X = G.bfloat16() / (G.norm() + eps)
302
+ if G.size(0) > G.size(1):
303
+ X = X.T
304
+ for _ in range(steps):
305
+ A = X @ X.T
306
+ B = b * A + c * A @ A
307
+ X = a * X + B @ X
308
+ if G.size(0) > G.size(1):
309
+ X = X.T
310
+ return X.to(G.dtype)
311
+ ```
312
+
313
+ For this training scenario, Muon has the following favorable properties:
314
+ * Lower memory usage than Adam
315
+ * ~1.5x better sample-efficiency
316
+ * <2% wallclock overhead
317
+
318
+
319
+ ### Provenance
320
+
321
+ Many of the choices made to generate this optimizer were obtained experimentally by our pursuit of [CIFAR-10 speedrunning](https://github.com/KellerJordan/cifar10-airbench).
322
+ In particular, we experimentally obtained the following practices:
323
+ * Using Nesterov momentum inside the update, with orthogonalization applied after momentum.
324
+ * Using a specifically quintic Newton-Schulz iteration as the method of orthogonalization.
325
+ * Using non-convergent coefficients for the quintic polynomial in order to maximize slope at zero, and thereby minimize the number of necessary Newton-Schulz iterations.
326
+ It turns out that the variance doesn't actually matter that much, so we end up with a quintic that rapidly converges to the range 0.68, 1.13 upon repeated application, rather than converging more slowly to 1.
327
+ * Running the Newton-Schulz iteration in bfloat16 (whereas Shampoo implementations often depend on inverse-pth-roots run in fp32 or fp64).
328
+
329
+ Our use of a Newton-Schulz iteration for orthogonalization traces to [Bernstein & Newhouse (2024)](https://arxiv.org/abs/2409.20325),
330
+ who suggested it as a way to compute Shampoo [5, 6] preconditioners, and theoretically explored Shampoo without preconditioner accumulation.
331
+ In particular, Jeremy Bernstein @jxbz sent us the draft, which caused us to experiment with various Newton-Schulz iterations as the
332
+ orthogonalization method for this optimizer.
333
+ If we had used SVD instead of a Newton-Schulz iteration, this optimizer would have been too slow to be useful.
334
+ Bernstein & Newhouse also pointed out that Shampoo without preconditioner accumulation is equivalent to steepest descent in the spectral norm,
335
+ and therefore Shampoo can be thought of as a way to smooth out spectral steepest descent.
336
+ The proposed optimizer can be thought of as a second way of smoothing spectral steepest descent, with a different set of memory and runtime tradeoffs
337
+ compared to Shampoo.
338
+
339
+ ---
340
+
341
+ ## Running on fewer GPUs
342
+
343
+ * To run experiments on fewer GPUs, simply modify `run.sh` to have a different `--nproc_per_node`. This should not change the behavior of the training.
344
+ * If you're running out of memory, you may need to reduce the sequence length for FlexAttention (which does change the training. see [here](https://github.com/KellerJordan/modded-nanogpt/pull/38) for a guide)
345
+
346
+ ---
347
+
348
+ ## References
349
+
350
+ 1. [Guilherme Penedo et al. "The fineweb datasets: Decanting the web for the finest text data at scale." arXiv preprint arXiv:2406.17557 (2024).](https://arxiv.org/abs/2406.17557)
351
+ 2. Nicholas J. Higham. Functions of Matrices. Society for Industrial and Applied Mathematics (2008). Equation 5.22.
352
+ 3. Günther Schulz. Iterative Berechnung der reziproken Matrix. Z. Angew. Math. Mech., 13:57–59 (1933).
353
+ 4. [Jeremy Bernstein and Laker Newhouse. "Old Optimizer, New Norm: An Anthology." arxiv preprint arXiv:2409.20325 (2024).](https://arxiv.org/abs/2409.20325)
354
+ 5. [Vineet Gupta, Tomer Koren, and Yoram Singer. "Shampoo: Preconditioned stochastic tensor optimization." International Conference on Machine Learning. PMLR, 2018.](https://arxiv.org/abs/1802.09568)
355
+ 6. [Rohan Anil et al. "Scalable second order optimization for deep learning." arXiv preprint arXiv:2002.09018 (2020).](https://arxiv.org/abs/2002.09018)
356
+ 7. [Alexander Hägele et al. "Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations." arXiv preprint arXiv:2405.18392 (2024).](https://arxiv.org/abs/2405.18392)
357
+ 8. [Zhanchao Zhou et al. "Value Residual Learning For Alleviating Attention Concentration In Transformers." arXiv preprint arXiv:2410.17897 (2024).](https://arxiv.org/abs/2410.17897)
358
+ 9. [Team, Gemma, et al. "Gemma 2: Improving open language models at a practical size." arXiv preprint arXiv:2408.00118 (2024).](https://arxiv.org/abs/2408.00118)
359
+ 10. [Alec Radford et al. "Language models are unsupervised multitask learners." OpenAI blog 1.8 (2019).](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
360
+
361
+ ## Citation
362
+
363
+ ```
364
+ @misc{modded_nanogpt_2024,
365
+ author = {Keller Jordan and Jeremy Bernstein and Brendan Rappazzo and
366
+ @fernbear.bsky.social and Boza Vlado and You Jiacheng and
367
+ Franz Cesista and Braden Koszarsky and @Grad62304977},
368
+ title = {modded-nanogpt: Speedrunning the NanoGPT baseline},
369
+ year = {2024},
370
+ url = {https://github.com/KellerJordan/modded-nanogpt}
371
+ }
372
+ ```
373
+
374
+ <img src="img/dofa.jpg" alt="itsover_wereback" style="width:100%;">
375
+
data/cached_fineweb100B.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from huggingface_hub import hf_hub_download
4
+ # Download the GPT-2 tokens of Fineweb100B from huggingface. This
5
+ # saves about an hour of startup time compared to regenerating them.
6
+ def get(fname):
7
+ local_dir = os.path.join(os.path.dirname(__file__), 'fineweb100B')
8
+ if not os.path.exists(os.path.join(local_dir, fname)):
9
+ hf_hub_download(repo_id="kjj0/fineweb100B-gpt2", filename=fname,
10
+ repo_type="dataset", local_dir=local_dir)
11
+ get("fineweb_val_%06d.bin" % 0)
12
+ num_chunks = 1030 # full fineweb100B. Each chunk is 100M tokens
13
+ if len(sys.argv) >= 2: # we can pass an argument to download less
14
+ num_chunks = int(sys.argv[1])
15
+ for i in range(1, num_chunks+1):
16
+ get("fineweb_train_%06d.bin" % i)
data/cached_fineweb10B.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from huggingface_hub import hf_hub_download
4
+ # Download the GPT-2 tokens of Fineweb10B from huggingface. This
5
+ # saves about an hour of startup time compared to regenerating them.
6
+ def get(fname):
7
+ local_dir = os.path.join(os.path.dirname(__file__), 'fineweb10B')
8
+ if not os.path.exists(os.path.join(local_dir, fname)):
9
+ hf_hub_download(repo_id="kjj0/fineweb10B-gpt2", filename=fname,
10
+ repo_type="dataset", local_dir=local_dir)
11
+ get("fineweb_val_%06d.bin" % 0)
12
+ num_chunks = 103 # full fineweb10B. Each chunk is 100M tokens
13
+ if len(sys.argv) >= 2: # we can pass an argument to download less
14
+ num_chunks = int(sys.argv[1])
15
+ for i in range(1, num_chunks+1):
16
+ get("fineweb_train_%06d.bin" % i)
data/cached_finewebedu10B.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from huggingface_hub import hf_hub_download
4
+ # Download the GPT-2 tokens of FinewebEDU10B from huggingface. This
5
+ # saves about an hour of startup time compared to regenerating them.
6
+ def get(fname):
7
+ local_dir = os.path.join(os.path.dirname(__file__), 'finewebedu10B')
8
+ if not os.path.exists(os.path.join(local_dir, fname)):
9
+ hf_hub_download(repo_id="kjj0/finewebedu10B-gpt2", filename=fname,
10
+ repo_type="dataset", local_dir=local_dir)
11
+ get("finewebedu_val_%06d.bin" % 0)
12
+ num_chunks = 99 # full FinewebEDU10B. Each chunk is 100M tokens
13
+ if len(sys.argv) >= 2: # we can pass an argument to download less
14
+ num_chunks = int(sys.argv[1])
15
+ for i in range(1, num_chunks+1):
16
+ get("finewebedu_train_%06d.bin" % i)
data/fineweb.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FineWeb dataset (for srs pretraining)
3
+ https://huggingface.co/datasets/HuggingFaceFW/fineweb
4
+
5
+ example doc to highlight the structure of the dataset:
6
+ {
7
+ "text": "Posted by mattsmith on 20th April 2012\nStraight from...",
8
+ "id": "<urn:uuid:d853d453-196e-4488-a411-efc2b26c40d2>",
9
+ "dump": "CC-MAIN-2013-20",
10
+ "url": "http://nleastchatter.com/philliesphandom/tag/freddy-galvis/",
11
+ "date": "2013-05-18T07:24:47Z",
12
+ "file_path": "s3://commoncrawl/long.../path.../file.gz",
13
+ "language": "en",
14
+ "language_score": 0.9185474514961243,
15
+ "token_count": 594
16
+ }
17
+ """
18
+ import os
19
+ import argparse
20
+ import multiprocessing as mp
21
+ import numpy as np
22
+ import tiktoken
23
+ # from huggingface_hub import snapshot_download
24
+ from datasets import load_dataset
25
+ from tqdm import tqdm
26
+ import argparse
27
+ import numpy as np
28
+ def write_datafile(filename, toks):
29
+ """
30
+ Saves token data as a .bin file, for reading in C.
31
+ - First comes a header with 256 int32s
32
+ - The tokens follow, each as a uint16
33
+ """
34
+ assert len(toks) < 2**31, "token count too large" # ~2.1B tokens
35
+ # construct the header
36
+ header = np.zeros(256, dtype=np.int32)
37
+ header[0] = 20240520 # magic
38
+ header[1] = 1 # version
39
+ header[2] = len(toks) # number of tokens after the 256*4 bytes of header (each 2 bytes as uint16)
40
+ # construct the tokens numpy array, if not already
41
+ if not isinstance(toks, np.ndarray) or not toks.dtype == np.uint16:
42
+ # validate that no token exceeds a uint16
43
+ maxtok = 2**16
44
+ assert all(0 <= t < maxtok for t in toks), "token dictionary too large for uint16"
45
+ toks_np = np.array(toks, dtype=np.uint16)
46
+ else:
47
+ toks_np = toks
48
+ # write to file
49
+ print(f"writing {len(toks):,} tokens to {filename}")
50
+ with open(filename, "wb") as f:
51
+ f.write(header.tobytes())
52
+ f.write(toks_np.tobytes())
53
+ # ------------------------------------------
54
+
55
+ parser = argparse.ArgumentParser(description="FineWeb dataset preprocessing")
56
+ parser.add_argument("-v", "--version", type=str, default="10B", help="Which version of fineweb to use 10B|100B")
57
+ parser.add_argument("-s", "--shard_size", type=int, default=10**8, help="Size of each shard in tokens")
58
+ args = parser.parse_args()
59
+
60
+ # FineWeb has a few possible subsamples available
61
+ assert args.version in ["10B", "100B"], "version must be one of 10B, 100B"
62
+ if args.version == "10B":
63
+ local_dir = "fineweb10B"
64
+ remote_name = "sample-10BT"
65
+ elif args.version == "100B":
66
+ local_dir = "fineweb100B"
67
+ remote_name = "sample-100BT"
68
+
69
+ # create the cache the local directory if it doesn't exist yet
70
+ DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir)
71
+ os.makedirs(DATA_CACHE_DIR, exist_ok=True)
72
+
73
+ # download the dataset
74
+ fw = load_dataset("HuggingFaceFW/fineweb", name=remote_name, split="train")
75
+
76
+ # init the tokenizer
77
+ enc = tiktoken.get_encoding("gpt2")
78
+ eot = enc._special_tokens['<|endoftext|>'] # end of text token
79
+ def tokenize(doc):
80
+ # tokenizes a single document and returns a numpy array of uint16 tokens
81
+ tokens = [eot] # the special <|endoftext|> token delimits all documents
82
+ tokens.extend(enc.encode_ordinary(doc["text"]))
83
+ tokens_np = np.array(tokens)
84
+ assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token dictionary too large for uint16"
85
+ tokens_np_uint16 = tokens_np.astype(np.uint16)
86
+ return tokens_np_uint16
87
+
88
+ # tokenize all documents and write output shards, each of shard_size tokens (last shard has remainder)
89
+ nprocs = max(1, os.cpu_count() - 2) # don't hog the entire system
90
+ with mp.Pool(nprocs) as pool:
91
+ shard_index = 0
92
+ # preallocate buffer to hold current shard
93
+ all_tokens_np = np.empty((args.shard_size,), dtype=np.uint16)
94
+ token_count = 0
95
+ progress_bar = None
96
+ for tokens in pool.imap(tokenize, fw, chunksize=16):
97
+
98
+ # is there enough space in the current shard for the new tokens?
99
+ if token_count + len(tokens) < args.shard_size:
100
+ # simply append tokens to current shard
101
+ all_tokens_np[token_count:token_count+len(tokens)] = tokens
102
+ token_count += len(tokens)
103
+ # update progress bar
104
+ if progress_bar is None:
105
+ progress_bar = tqdm(total=args.shard_size, unit="tokens", desc=f"Shard {shard_index}")
106
+ progress_bar.update(len(tokens))
107
+ else:
108
+ # write the current shard and start a new one
109
+ split = "val" if shard_index == 0 else "train"
110
+ filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{split}_{shard_index:06d}.bin")
111
+ # split the document into whatever fits in this shard; the remainder goes to next one
112
+ remainder = args.shard_size - token_count
113
+ progress_bar.update(remainder)
114
+ all_tokens_np[token_count:token_count+remainder] = tokens[:remainder]
115
+ write_datafile(filename, all_tokens_np)
116
+ shard_index += 1
117
+ progress_bar = None
118
+ # populate the next shard with the leftovers of the current doc
119
+ all_tokens_np[0:len(tokens)-remainder] = tokens[remainder:]
120
+ token_count = len(tokens)-remainder
121
+
122
+ # write any remaining tokens as the last shard
123
+ if token_count != 0:
124
+ split = "val" if shard_index == 0 else "train"
125
+ filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{split}_{shard_index:06d}.bin")
126
+ write_datafile(filename, all_tokens_np[:token_count])
data/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ datasets
2
+ tiktoken
img/algo_optimizer.png ADDED

Git LFS Details

  • SHA256: ff86376962541686a79693f3e54e8bf48f0ff1bde636b40b9fcdd3c2aa44c9d6
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
img/dofa.jpg ADDED
img/fig_optimizer.png ADDED

Git LFS Details

  • SHA256: 99255caf301f4a3a83ade645b401db2578585926f1f92e782adcb5dba1866e45
  • Pointer size: 131 Bytes
  • Size of remote file: 114 kB
img/fig_tuned_nanogpt.png ADDED

Git LFS Details

  • SHA256: 4b0d37d36424417cb10988d541b648755aa8d8cd2d3fe2d2b404001e170858c4
  • Pointer size: 131 Bytes
  • Size of remote file: 282 kB
img/nanogpt_speedrun51.png ADDED

Git LFS Details

  • SHA256: 35189cbc33c0c7a0efdc773a4e313e5bd1d9eec3e6a8fdf86f50e091df01c3e1
  • Pointer size: 131 Bytes
  • Size of remote file: 180 kB
img/nanogpt_speedrun52.png ADDED
img/nanogpt_speedrun53.png ADDED

Git LFS Details

  • SHA256: c7f7e94c99c77d9dbbced1b1585a545682a3ef4a27336c8308da0043182f3415
  • Pointer size: 131 Bytes
  • Size of remote file: 145 kB
img/nanogpt_speedrun54.png ADDED
records/track_1_short/2024-06-06_AdamW/README.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ This is the log for my baseline AdamW training to which I compared the new Muon and SOAP optimizers.
2
+
3
+ just the log, which is in the old llm.c format ("tel" lines are val loss)
4
+
5
+ this was batch size 2^19, so ~5B tokens
6
+
7
+ was learning rate 0.0018, warmup=250, warmdown=2000, betas=(0.9, 0.95) IIRC
8
+
records/track_1_short/2024-06-06_AdamW/f66d43d7-e449-4029-8adf-e8537bab49ea.log ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-10-09_SOAP/5bdc3988-496c-4232-b4ef-53764cb81c92.txt ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-10-09_SOAP/README.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # SOAP record October 9 2024
2
+
3
+ * New sample efficiency record: <3.28 validation loss in 3.15B tokens
4
+ * Uses SOAP optimizer ([Vyas et al. 2024](https://arxiv.org/abs/2409.11321))
5
+ * 363ms/step - not a new wallclock record (SOAP is in active development to reduce the wallclock overhead for distributed training, so this may change)
6
+ * Set by Nikhil Vyas @vyasnikhil96. Hyperparameters also tuned slightly by me
7
+ * [https://x.com/vyasnikhil96/status/1842656792217858063](https://x.com/vyasnikhil96/status/1842656792217858063)
8
+ * [https://github.com/nikhilvyas/modded-nanogpt-SOAP/tree/master](https://github.com/nikhilvyas/modded-nanogpt-SOAP/tree/master)
9
+
records/track_1_short/2024-10-09_SOAP/train_gpt2.py ADDED
@@ -0,0 +1,857 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+
5
+ from itertools import chain
6
+
7
+ # Parts of the code are modifications of Pytorch's AdamW optimizer
8
+ # Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py
9
+
10
+
11
+ class SOAP(optim.Optimizer):
12
+ """
13
+ Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).
14
+
15
+ Parameters:
16
+ params (`Iterable[nn.parameter.Parameter]`):
17
+ Iterable of parameters to optimize or dictionaries defining parameter groups.
18
+ lr (`float`, *optional*, defaults to 0.003):
19
+ The learning rate to use.
20
+ betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
21
+ Adam's betas parameters (b1, b2).
22
+ shampoo_beta (`float`, *optional*, defaults to -1):
23
+ If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1].
24
+ eps (`float`, *optional*, defaults to 1e-08):
25
+ Adam's epsilon for numerical stability.
26
+ weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
27
+ precondition_frequency (`int`, *optional*, defaults to 10):
28
+ How often to update the preconditioner.
29
+ max_precond_dim (`int`, *optional*, defaults to 10000):
30
+ Maximum dimension of the preconditioner.
31
+ Set to 10000, so that we exclude most common vocab sizes while including layers.
32
+ merge_dims (`bool`, *optional*, defaults to `False`):
33
+ Whether or not to merge dimensions of the preconditioner.
34
+ precondition_1d (`bool`, *optional*, defaults to `False`):
35
+ Whether or not to precondition 1D gradients.
36
+ normalize_grads (`bool`, *optional*, defaults to `False`):
37
+ Whether or not to normalize gradients per layer.
38
+ Helps at large precondition_frequency (~100 in our experiments),
39
+ but hurts performance at small precondition_frequency (~10 in our experiments).
40
+ data_format (`str`, *optional*, defaults to `channels_first`):
41
+ Data format of the input for convolutional layers.
42
+ Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW.
43
+ correct_bias (`bool`, *optional*, defaults to `True`):
44
+ Whether or not to use bias correction in Adam.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ params,
50
+ lr: float = 3e-3,
51
+ betas=(0.95, 0.95),
52
+ shampoo_beta: float= -1,
53
+ eps: float = 1e-8,
54
+ weight_decay: float = 0.01,
55
+ precondition_frequency: int=10,
56
+ max_precond_dim: int=10000, #
57
+ merge_dims: bool = False, # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.
58
+ precondition_1d: bool = False,
59
+ normalize_grads: bool = False,
60
+ data_format: str = "channels_first",
61
+ correct_bias: bool = True,
62
+ ):
63
+ defaults = {
64
+ "lr": lr,
65
+ "betas": betas,
66
+ "shampoo_beta": shampoo_beta,
67
+ "eps": eps,
68
+ "weight_decay": weight_decay,
69
+ "precondition_frequency": precondition_frequency,
70
+ "max_precond_dim": max_precond_dim,
71
+ "merge_dims": merge_dims,
72
+ "precondition_1d": precondition_1d,
73
+ "normalize_grads": normalize_grads,
74
+ "correct_bias": correct_bias,
75
+ }
76
+ super().__init__(params, defaults)
77
+ self._data_format = data_format
78
+
79
+ def merge_dims(self, grad, max_precond_dim):
80
+ """
81
+ Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.
82
+ """
83
+ assert self._data_format in ["channels_first", "channels_last"]
84
+ if self._data_format == "channels_last" and grad.dim() == 4:
85
+ grad = grad.permute(0, 3, 1, 2)
86
+ shape = grad.shape
87
+ new_shape = []
88
+
89
+ curr_shape = 1
90
+ for sh in shape:
91
+ temp_shape = curr_shape * sh
92
+ if temp_shape > max_precond_dim:
93
+ if curr_shape > 1:
94
+ new_shape.append(curr_shape)
95
+ curr_shape = sh
96
+ else:
97
+ new_shape.append(sh)
98
+ curr_shape = 1
99
+ else:
100
+ curr_shape = temp_shape
101
+
102
+ if curr_shape > 1 or len(new_shape)==0:
103
+ new_shape.append(curr_shape)
104
+
105
+ new_grad = grad.reshape(new_shape)
106
+ return new_grad
107
+
108
+ @torch.no_grad()
109
+ def step(self):
110
+ """
111
+ Performs a single optimization step.
112
+
113
+ Arguments:
114
+ closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
115
+ """
116
+ loss = None
117
+
118
+ for group in self.param_groups:
119
+ for p in group["params"]:
120
+ if p.grad is None:
121
+ continue
122
+ grad = p.grad
123
+
124
+ state = self.state[p]
125
+
126
+ if "step" not in state:
127
+ state["step"] = 0
128
+
129
+ # State initialization
130
+ if "exp_avg" not in state:
131
+ # Exponential moving average of gradient values
132
+ state["exp_avg"] = torch.zeros_like(grad)
133
+ # Exponential moving average of squared gradient values
134
+ state["exp_avg_sq"] = torch.zeros_like(grad)
135
+
136
+ if 'Q' not in state:
137
+ self.init_preconditioner(
138
+ grad,
139
+ state,
140
+ precondition_frequency=group['precondition_frequency'],
141
+ precondition_1d=group['precondition_1d'],
142
+ shampoo_beta=(group['shampoo_beta'] if group['shampoo_beta'] >= 0 else group["betas"][1]),
143
+ max_precond_dim=group['max_precond_dim'],
144
+ merge_dims=group["merge_dims"],
145
+ )
146
+ self.update_preconditioner(grad, state,
147
+ max_precond_dim=group['max_precond_dim'],
148
+ merge_dims=group["merge_dims"],
149
+ precondition_1d=group["precondition_1d"])
150
+ continue # first step is skipped so that we never use the current gradients in the projection.
151
+
152
+ # Projecting gradients to the eigenbases of Shampoo's preconditioner
153
+ # i.e. projecting to the eigenbases of matrices in state['GG']
154
+ grad_projected = self.project(grad, state, merge_dims=group["merge_dims"],
155
+ max_precond_dim=group['max_precond_dim'])
156
+
157
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
158
+ beta1, beta2 = group["betas"]
159
+
160
+ state["step"] += 1
161
+
162
+ # Decay the first and second moment running average coefficient
163
+ # In-place operations to update the averages at the same time
164
+ exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
165
+ exp_avg_sq.mul_(beta2).add_(grad_projected.square(), alpha=(1.0 - beta2))
166
+
167
+ denom = exp_avg_sq.sqrt().add_(group["eps"])
168
+
169
+ # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
170
+ # i.e. projecting to the eigenbases of matrices in state['GG']
171
+ exp_avg_projected = self.project(exp_avg, state, merge_dims=group["merge_dims"],
172
+ max_precond_dim=group['max_precond_dim'])
173
+
174
+ step_size = group["lr"]
175
+ if group["correct_bias"]:
176
+ bias_correction1 = 1.0 - beta1 ** (state["step"])
177
+ bias_correction2 = 1.0 - beta2 ** (state["step"])
178
+ step_size = step_size * (bias_correction2 ** .5) / bias_correction1
179
+
180
+ # Projecting back the preconditioned (by Adam) exponential moving average of gradients
181
+ # to the original space
182
+ norm_grad = self.project_back(exp_avg_projected / denom, state, merge_dims=group["merge_dims"],
183
+ max_precond_dim=group['max_precond_dim'])
184
+
185
+ if group["normalize_grads"]:
186
+ norm_grad = norm_grad / (1e-30+torch.mean(norm_grad**2)**0.5)
187
+
188
+ p.add_(norm_grad, alpha=-step_size)
189
+
190
+
191
+ # From AdamW code: Just adding the square of the weights to the loss function is *not*
192
+ # the correct way of using L2 regularization/weight decay with Adam,
193
+ # since that will interact with the m and v parameters in strange ways.
194
+ #
195
+ # Instead we want to decay the weights in a manner that doesn't interact
196
+ # with the m/v parameters. This is equivalent to adding the square
197
+ # of the weights to the loss with plain (non-momentum) SGD.
198
+ # Add weight decay at the end (fixed version)
199
+ if group["weight_decay"] > 0.0:
200
+ p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
201
+
202
+ # Update is done after the gradient step to avoid using current gradients in the projection.
203
+ self.update_preconditioner(grad, state,
204
+ max_precond_dim=group['max_precond_dim'],
205
+ merge_dims=group["merge_dims"],
206
+ precondition_1d=group["precondition_1d"])
207
+
208
+ return loss
209
+
210
+ def init_preconditioner(self, grad, state, precondition_frequency=10,
211
+ shampoo_beta=0.95, max_precond_dim=10000, precondition_1d=False,
212
+ merge_dims=False):
213
+ """
214
+ Initializes the preconditioner matrices (L and R in the paper).
215
+ """
216
+ state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
217
+ if grad.dim() == 1:
218
+ if not precondition_1d or grad.shape[0] > max_precond_dim:
219
+ state['GG'].append([])
220
+ else:
221
+ state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device))
222
+ else:
223
+ if merge_dims:
224
+ grad = self.merge_dims(grad, max_precond_dim)
225
+
226
+ for sh in grad.shape:
227
+ if sh > max_precond_dim:
228
+ state['GG'].append([])
229
+ else:
230
+ state['GG'].append(torch.zeros(sh, sh, device=grad.device))
231
+
232
+ state['Q'] = None # Will hold all the eigenbases of the preconditioner.
233
+ state['precondition_frequency'] = precondition_frequency
234
+ state['shampoo_beta'] = shampoo_beta
235
+
236
+ def project(self, grad, state, merge_dims=False, max_precond_dim=10000):
237
+ """
238
+ Projects the gradient to the eigenbases of the preconditioner.
239
+ """
240
+ original_shape = grad.shape
241
+ if merge_dims:
242
+ if grad.dim() == 4 and self._data_format == 'channels_last':
243
+ permuted_shape = grad.permute(0, 3, 1, 2).shape
244
+ grad = self.merge_dims(grad, max_precond_dim)
245
+
246
+ for mat in state['Q']:
247
+ if len(mat) > 0:
248
+ grad = torch.tensordot(
249
+ grad,
250
+ mat,
251
+ dims=[[0], [0]],
252
+ )
253
+ else:
254
+ permute_order = list(range(1, len(grad.shape))) + [0]
255
+ grad = grad.permute(permute_order)
256
+
257
+ if merge_dims:
258
+ if self._data_format == 'channels_last' and len(original_shape) == 4:
259
+ grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
260
+ else:
261
+ grad = grad.reshape(original_shape)
262
+ return grad
263
+
264
+ def update_preconditioner(self, grad, state,
265
+ max_precond_dim=10000, merge_dims=False, precondition_1d=False):
266
+ """
267
+ Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
268
+ """
269
+ if grad.dim() == 1:
270
+ if precondition_1d and grad.shape[0] <= max_precond_dim:
271
+ state['GG'][0].lerp_(grad.unsqueeze(1) @ grad.unsqueeze(0), 1-state['shampoo_beta'])
272
+ else:
273
+ if merge_dims:
274
+ new_grad = self.merge_dims(grad, max_precond_dim)
275
+ for idx, sh in enumerate(new_grad.shape):
276
+ if sh <= max_precond_dim:
277
+ outer_product = torch.tensordot(
278
+ new_grad,
279
+ new_grad,
280
+ dims=[[*chain(range(idx), range(idx + 1, len(new_grad.shape)))]] * 2,
281
+ )
282
+ state['GG'][idx].lerp_(outer_product, 1-state['shampoo_beta'])
283
+ else:
284
+ for idx, sh in enumerate(grad.shape):
285
+ if sh <= max_precond_dim:
286
+ outer_product = torch.tensordot(
287
+ grad,
288
+ grad,
289
+ # Contracts across all dimensions except for k.
290
+ dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2,
291
+ )
292
+ state['GG'][idx].lerp_(outer_product, 1-state['shampoo_beta'])
293
+
294
+ if state['Q'] is None:
295
+ state['Q'] = self.get_orthogonal_matrix(state['GG'])
296
+ if state['step'] > 0 and state['step'] % state['precondition_frequency'] == 0:
297
+ state['Q'] = self.get_orthogonal_matrix_QR(state, max_precond_dim, merge_dims)
298
+
299
+ def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
300
+ """
301
+ Projects the gradient back to the original space.
302
+ """
303
+ original_shape = grad.shape
304
+ if merge_dims:
305
+ if self._data_format == 'channels_last' and grad.dim() == 4:
306
+ permuted_shape = grad.permute(0, 3, 1, 2).shape
307
+ grad = self.merge_dims(grad, max_precond_dim)
308
+ for mat in state['Q']:
309
+ if len(mat) > 0:
310
+ grad = torch.tensordot(
311
+ grad,
312
+ mat,
313
+ dims=[[0], [1]],
314
+ )
315
+ else:
316
+ permute_order = list(range(1, len(grad.shape))) + [0]
317
+ grad = grad.permute(permute_order)
318
+
319
+ if merge_dims:
320
+ if self._data_format == 'channels_last' and len(original_shape) == 4:
321
+ grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
322
+ else:
323
+ grad = grad.reshape(original_shape)
324
+ return grad
325
+
326
+
327
+ def get_orthogonal_matrix(self, mat):
328
+ """
329
+ Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
330
+ """
331
+ matrix = []
332
+ for m in mat:
333
+ if len(m) == 0:
334
+ matrix.append([])
335
+ continue
336
+ if m.data.dtype != torch.float:
337
+ float_data = False
338
+ original_type = m.data.dtype
339
+ original_device = m.data.device
340
+ matrix.append(m.data.float())
341
+ else:
342
+ float_data = True
343
+ matrix.append(m.data)
344
+
345
+ final = []
346
+ for m in matrix:
347
+ if len(m) == 0:
348
+ final.append([])
349
+ continue
350
+ try:
351
+ _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device))
352
+ except:
353
+ _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device))
354
+ Q = Q.to(m.dtype)
355
+ Q = torch.flip(Q, [1])
356
+
357
+ if not float_data:
358
+ Q = Q.to(original_device).type(original_type)
359
+ final.append(Q)
360
+ return final
361
+
362
+
363
+ def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):
364
+ """
365
+ Computes the eigenbases of the preconditioner using one round of power iteration
366
+ followed by torch.linalg.qr decomposition.
367
+ """
368
+ precond_list = state['GG']
369
+ orth_list = state['Q']
370
+
371
+ matrix = []
372
+ orth_matrix = []
373
+ for m,o in zip(precond_list, orth_list):
374
+ if len(m) == 0:
375
+ matrix.append([])
376
+ orth_matrix.append([])
377
+ continue
378
+ if m.data.dtype != torch.float:
379
+ float_data = False
380
+ original_type = m.data.dtype
381
+ original_device = m.data.device
382
+ matrix.append(m.data.float())
383
+ orth_matrix.append(o.data.float())
384
+ else:
385
+ float_data = True
386
+ matrix.append(m.data.float())
387
+ orth_matrix.append(o.data.float())
388
+
389
+ orig_shape = state['exp_avg_sq'].shape
390
+ if self._data_format == 'channels_last' and len(orig_shape) == 4:
391
+ permuted_shape = state['exp_avg_sq'].permute(0, 3, 1, 2).shape
392
+ if merge_dims:
393
+ exp_avg_sq = self.merge_dims(state['exp_avg_sq'], max_precond_dim)
394
+ else:
395
+ exp_avg_sq = state['exp_avg_sq']
396
+
397
+ final = []
398
+ for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
399
+ if len(m)==0:
400
+ final.append([])
401
+ continue
402
+ est_eig = torch.diag(o.T @ m @ o)
403
+ sort_idx = torch.argsort(est_eig, descending=True)
404
+ exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
405
+ o = o[:,sort_idx]
406
+ power_iter = m @ o
407
+ Q, _ = torch.linalg.qr(power_iter)
408
+
409
+ if not float_data:
410
+ Q = Q.to(original_device).type(original_type)
411
+ final.append(Q)
412
+
413
+ if merge_dims:
414
+ if self._data_format == 'channels_last' and len(orig_shape) == 4:
415
+ exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1)
416
+ else:
417
+ exp_avg_sq = exp_avg_sq.reshape(orig_shape)
418
+
419
+ state['exp_avg_sq'] = exp_avg_sq
420
+ return final
421
+
422
+ import os
423
+ import sys
424
+ with open(sys.argv[0]) as f:
425
+ code = f.read() # read the code of this file ASAP, for logging
426
+ import uuid
427
+ import glob
428
+ import time
429
+ from dataclasses import dataclass
430
+
431
+ import numpy as np
432
+ import torch
433
+ from torch import nn
434
+ import torch.nn.functional as F
435
+ import torch.distributed as dist
436
+ import torch._inductor.config as config
437
+ from torch.nn.parallel import DistributedDataParallel as DDP
438
+
439
+ # -----------------------------------------------------------------------------
440
+ # PyTorch nn.Module definitions for the GPT-2 model
441
+
442
+ class Rotary(torch.nn.Module):
443
+
444
+ def __init__(self, dim, base=10000):
445
+ super().__init__()
446
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
447
+ self.register_buffer("inv_freq", inv_freq)
448
+ self.seq_len_cached = None
449
+ self.cos_cached = None
450
+ self.sin_cached = None
451
+
452
+ def forward(self, x):
453
+ seq_len = x.shape[1]
454
+ if seq_len != self.seq_len_cached:
455
+ self.seq_len_cached = seq_len
456
+ t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
457
+ freqs = torch.outer(t, self.inv_freq).to(x.device)
458
+ self.cos_cached = freqs.cos()
459
+ self.sin_cached = freqs.sin()
460
+ return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
461
+
462
+ def apply_rotary_emb(x, cos, sin):
463
+ assert x.ndim == 4 # multihead attention
464
+ d = x.shape[3]//2
465
+ x1 = x[..., :d]
466
+ x2 = x[..., d:]
467
+ y1 = x1 * cos + x2 * sin
468
+ y2 = x1 * (-sin) + x2 * cos
469
+ return torch.cat([y1, y2], 3)
470
+
471
+ def rmsnorm(x0, eps=1e-6):
472
+ x = x0.float()
473
+ x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
474
+ return x.type_as(x0)
475
+
476
+ class CausalSelfAttention(nn.Module):
477
+
478
+ def __init__(self, config):
479
+ super().__init__()
480
+ self.n_head = config.n_head
481
+ self.n_embd = config.n_embd
482
+ self.head_dim = self.n_embd // self.n_head
483
+ assert self.n_embd % self.n_head == 0
484
+ self.c_q = nn.Linear(self.n_embd, self.n_embd, bias=False)
485
+ self.c_k = nn.Linear(self.n_embd, self.n_embd, bias=False)
486
+ self.c_v = nn.Linear(self.n_embd, self.n_embd, bias=False)
487
+ # output projection
488
+ self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
489
+ self.rotary = Rotary(self.head_dim)
490
+
491
+ def forward(self, x):
492
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
493
+ q, k, v = self.c_q(x), self.c_k(x), self.c_v(x)
494
+ k = k.view(B, T, self.n_head, self.head_dim)
495
+ q = q.view(B, T, self.n_head, self.head_dim)
496
+ v = v.view(B, T, self.n_head, self.head_dim)
497
+ cos, sin = self.rotary(q)
498
+ q = apply_rotary_emb(q, cos, sin)
499
+ k = apply_rotary_emb(k, cos, sin)
500
+ y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
501
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
502
+ # output projection
503
+ y = self.c_proj(y)
504
+ return y
505
+
506
+ class MLP(nn.Module):
507
+
508
+ def __init__(self, config):
509
+ super().__init__()
510
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
511
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
512
+
513
+ def forward(self, x):
514
+ x = self.c_fc(x)
515
+ x = F.gelu(x)
516
+ x = self.c_proj(x)
517
+ return x
518
+
519
+ class Block(nn.Module):
520
+
521
+ def __init__(self, config):
522
+ super().__init__()
523
+ self.attn = CausalSelfAttention(config)
524
+ self.mlp = MLP(config)
525
+ self.attn_scale = (1 / (2 * config.n_layer)**0.5)
526
+
527
+ def forward(self, x):
528
+ x = x + self.attn_scale * self.attn(rmsnorm(x))
529
+ x = x + self.mlp(rmsnorm(x))
530
+ return x
531
+
532
+ # -----------------------------------------------------------------------------
533
+ # The main GPT-2 model
534
+
535
+ @dataclass
536
+ class GPTConfig:
537
+ vocab_size : int = 50257
538
+ n_layer : int = 12
539
+ n_head : int = 12
540
+ n_embd : int = 768
541
+
542
+ class GPT(nn.Module):
543
+
544
+ def __init__(self, config):
545
+ super().__init__()
546
+ self.config = config
547
+
548
+ self.transformer = nn.ModuleDict(dict(
549
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
550
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
551
+ ))
552
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
553
+ self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
554
+
555
+ def forward(self, idx, targets=None, return_logits=True):
556
+ b, t = idx.size()
557
+ pos = torch.arange(0, t, dtype=torch.long, device=idx.device) # shape (t)
558
+
559
+ # forward the GPT model itself
560
+ x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
561
+
562
+ for block in self.transformer.h:
563
+ x = block(x)
564
+ x = rmsnorm(x)
565
+
566
+ if targets is not None:
567
+ # if we are given some desired targets also calculate the loss
568
+ logits = self.lm_head(x)
569
+ logits = logits.float() # use tf32/fp32 for logits
570
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
571
+ else:
572
+ # inference-time mini-optimization: only forward the lm_head on the very last position
573
+ logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
574
+ logits = logits.float() # use tf32/fp32 for logits
575
+ loss = None
576
+
577
+ # there are performance reasons why not returning logits is prudent, if not needed
578
+ if not return_logits:
579
+ logits = None
580
+
581
+ return logits, loss
582
+
583
+ # -----------------------------------------------------------------------------
584
+ # Our own simple Distributed Data Loader
585
+
586
+ def _peek_data_shard(filename):
587
+ # only reads the header, returns header data
588
+ with open(filename, "rb") as f:
589
+ # first read the header, which is 256 int32 integers (4 bytes each)
590
+ header = np.frombuffer(f.read(256*4), dtype=np.int32)
591
+ if header[0] != 20240520:
592
+ print("ERROR: magic number mismatch in the data .bin file!")
593
+ print("---> HINT: Are you passing in a correct file with --input_bin?")
594
+ print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
595
+ print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
596
+ exit(1)
597
+ assert header[1] == 1, "unsupported version"
598
+ ntok = header[2] # number of tokens (claimed)
599
+ return ntok # for now just return the number of tokens
600
+
601
+ def _load_data_shard(filename):
602
+ with open(filename, "rb") as f:
603
+ # first read the header, which is 256 int32 integers (4 bytes each)
604
+ header = np.frombuffer(f.read(256*4), dtype=np.int32)
605
+ assert header[0] == 20240520, "magic number mismatch in the data .bin file"
606
+ assert header[1] == 1, "unsupported version"
607
+ ntok = header[2] # number of tokens (claimed)
608
+ # the rest of it are tokens, stored as uint16
609
+ tokens = np.frombuffer(f.read(), dtype=np.uint16)
610
+ assert len(tokens) == ntok, "number of tokens read does not match header?"
611
+ return tokens
612
+
613
+ class DistributedDataLoader:
614
+ def __init__(self, filename_pattern, B, T, process_rank, num_processes):
615
+ self.process_rank = process_rank
616
+ self.num_processes = num_processes
617
+ self.B = B
618
+ self.T = T
619
+
620
+ # glob files that match the pattern
621
+ self.files = sorted(glob.glob(filename_pattern))
622
+ assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"
623
+
624
+ # load and validate all data shards, count number of tokens in total
625
+ ntok_total = 0
626
+ for fname in self.files:
627
+ shard_ntok = _peek_data_shard(fname)
628
+ assert shard_ntok >= num_processes * B * T + 1
629
+ ntok_total += int(shard_ntok)
630
+ self.ntok_total = ntok_total
631
+
632
+ # kick things off
633
+ self.reset()
634
+
635
+ def reset(self):
636
+ self.current_shard = 0
637
+ self.current_position = self.process_rank * self.B * self.T
638
+ self.tokens = _load_data_shard(self.files[self.current_shard])
639
+
640
+ def advance(self): # advance to next data shard
641
+ self.current_shard = (self.current_shard + 1) % len(self.files)
642
+ self.current_position = self.process_rank * self.B * self.T
643
+ self.tokens = _load_data_shard(self.files[self.current_shard])
644
+
645
+ def next_batch(self):
646
+ B = self.B
647
+ T = self.T
648
+ buf = self.tokens[self.current_position : self.current_position+B*T+1]
649
+ buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
650
+ x = (buf[:-1]).view(B, T) # inputs
651
+ y = (buf[1:]).view(B, T) # targets
652
+ # advance current position and load next shard if necessary
653
+ self.current_position += B * T * self.num_processes
654
+ if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
655
+ self.advance()
656
+ return x.cuda(), y.cuda()
657
+
658
+ # -----------------------------------------------------------------------------
659
+ # int main
660
+
661
+ @dataclass
662
+ class Hyperparameters:
663
+ # data hyperparams
664
+ input_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on
665
+ input_val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on
666
+ # optimization hyperparams
667
+ batch_size : int = 8*64 # batch size, in sequences, across all devices
668
+ device_batch_size : int = 64 # batch size, in sequences, per device
669
+ sequence_length : int = 1024 # sequence length, in tokens
670
+ num_iterations : int = 6000 # number of iterations to run
671
+ learning_rate : float = 0.0036
672
+ warmup_iters : int = 250
673
+ warmdown_iters : int = 1800 # number of iterations of linear warmup/warmdown for triangular or trapezoidal schedule
674
+ # evaluation and logging hyperparams
675
+ val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end
676
+ val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
677
+ save_every : int = 0 # every how many steps to save the checkpoint? 0 for only at the end
678
+ args = Hyperparameters()
679
+
680
+ # set up DDP (distributed data parallel). torchrun sets this env variable
681
+ assert torch.cuda.is_available()
682
+ dist.init_process_group(backend='nccl')
683
+ ddp_rank = int(os.environ['RANK'])
684
+ ddp_local_rank = int(os.environ['LOCAL_RANK'])
685
+ ddp_world_size = int(os.environ['WORLD_SIZE'])
686
+ device = f'cuda:{ddp_local_rank}'
687
+ torch.cuda.set_device(device)
688
+ print(f"using device: {device}")
689
+ master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.
690
+
691
+ # convenience variables
692
+ B, T = args.device_batch_size, args.sequence_length
693
+ # calculate the number of steps to take in the val loop.
694
+ assert args.val_tokens % (B * T * ddp_world_size) == 0
695
+ val_steps = args.val_tokens // (B * T * ddp_world_size)
696
+ # calculate the steps of gradient accumulation required to attain the desired global batch size.
697
+ assert args.batch_size % (B * ddp_world_size) == 0
698
+ train_accumulation_steps = args.batch_size // (B * ddp_world_size)
699
+
700
+ # load tokens
701
+ train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
702
+ val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
703
+ if master_process:
704
+ print(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
705
+ print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
706
+ x, y = train_loader.next_batch()
707
+
708
+ # init the model from scratch
709
+ num_vocab = 50257
710
+ model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=12, n_embd=768))
711
+ model = model.cuda()
712
+ if hasattr(config, "coordinate_descent_tuning"):
713
+ config.coordinate_descent_tuning = True # suggested by @Chillee
714
+ model = torch.compile(model)
715
+ # here we wrap model into DDP container
716
+ model = DDP(model, device_ids=[ddp_local_rank])
717
+ raw_model = model.module # always contains the "raw" unwrapped model
718
+ ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
719
+
720
+ # init the optimizer(s)
721
+ optimizer1 = torch.optim.AdamW(raw_model.lm_head.parameters(), lr=args.learning_rate, betas=(0.9, 0.95),
722
+ weight_decay=0, fused=True)
723
+ optimizer2 = SOAP(raw_model.transformer.h.parameters(), lr=0.5*args.learning_rate, betas=(.95, .95), weight_decay=0, precondition_frequency=10)
724
+ optimizers = [optimizer1, optimizer2]
725
+ # learning rate decay scheduler (linear warmup and warmdown)
726
+ def get_lr(it):
727
+ assert it <= args.num_iterations
728
+ # 1) linear warmup for warmup_iters steps
729
+ if it < args.warmup_iters:
730
+ return (it+1) / args.warmup_iters
731
+ # 2) constant lr for a while
732
+ elif it < args.num_iterations - args.warmdown_iters:
733
+ return 1.0
734
+ # 3) linear warmdown
735
+ else:
736
+ decay_ratio = (args.num_iterations - it) / args.warmdown_iters
737
+ return decay_ratio
738
+ schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
739
+
740
+ # begin logging
741
+ if master_process:
742
+ run_id = str(uuid.uuid4())
743
+ logdir = 'logs/%s/' % run_id
744
+ os.makedirs(logdir, exist_ok=True)
745
+ logfile = 'logs/%s.txt' % run_id
746
+ # create the log file
747
+ with open(logfile, "w") as f:
748
+ # begin the log by printing this file (the Python code)
749
+ f.write('='*100 + '\n')
750
+ f.write(code)
751
+ f.write('='*100 + '\n')
752
+ # log information about the hardware/software environment this is running on
753
+ # and print the full `nvidia-smi` to file
754
+ f.write(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n")
755
+ import subprocess
756
+ result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
757
+ f.write(f'{result.stdout}\n')
758
+ f.write('='*100 + '\n')
759
+
760
+ training_time_ms = 0
761
+ # start the clock
762
+ torch.cuda.synchronize()
763
+ t0 = time.time()
764
+ # begin training
765
+ train_loader.reset()
766
+ for step in range(args.num_iterations + 1):
767
+ last_step = (step == args.num_iterations)
768
+ # This effectively ignores timing first 10 steps, which are slower for weird reasons.
769
+ # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
770
+ # steps with dummy data first, and then re-initialize the model and reset the loader.
771
+ if step == 10:
772
+ training_time_ms = 0
773
+ t0 = time.time()
774
+ timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val
775
+
776
+ # once in a while evaluate the validation dataset
777
+ if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
778
+ # stop the clock
779
+ torch.cuda.synchronize()
780
+ training_time_ms += 1000 * (time.time() - t0)
781
+ # run validation batches
782
+ model.eval()
783
+ val_loader.reset()
784
+ val_loss = 0.0
785
+ for _ in range(val_steps):
786
+ x_val, y_val = val_loader.next_batch()
787
+ with torch.no_grad(): # of course, we'd like to use ctx here too, but that creates a torch.compile error for some reason
788
+ _, loss = model(x_val, y_val, return_logits=False)
789
+ val_loss += loss
790
+ dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
791
+ val_loss /= val_steps
792
+ # log val loss to console and to logfile
793
+ if master_process:
794
+ print(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')
795
+ with open(logfile, "a") as f:
796
+ f.write(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms\n')
797
+ # start the clock again
798
+ torch.cuda.synchronize()
799
+ t0 = time.time()
800
+
801
+ if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)):
802
+ # stop the clock
803
+ torch.cuda.synchronize()
804
+ training_time_ms += 1000 * (time.time() - t0)
805
+ # save the state of the training process
806
+ log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
807
+ torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step))
808
+ # start the clock again
809
+ torch.cuda.synchronize()
810
+ t0 = time.time()
811
+
812
+ # bit confusing: we want to make sure to eval on 0th iteration
813
+ # but also after the very last iteration. so we loop for step <= num_iterations
814
+ # instead of just < num_iterations (one extra due to <=), only to do
815
+ # the validation/sampling one last time, and then we break right here as we're done.
816
+ if last_step:
817
+ break
818
+
819
+ # --------------- TRAINING SECTION BEGIN -----------------
820
+ model.train()
821
+ for i in range(1, train_accumulation_steps+1):
822
+ # forward pass
823
+ with ctx:
824
+ _, loss = model(x, y, return_logits=False)
825
+ train_loss = loss.detach()
826
+ # advance the dataset for the next batch
827
+ x, y = train_loader.next_batch()
828
+ # backward pass
829
+ if i < train_accumulation_steps:
830
+ with model.no_sync(): # there's no need to sync gradients every accumulation step
831
+ loss.backward()
832
+ else:
833
+ loss.backward() # just sync on the last step
834
+ for p in model.parameters():
835
+ p.grad /= train_accumulation_steps
836
+ # step the optimizers and schedulers
837
+ for opt, sched in zip(optimizers, schedulers):
838
+ opt.step()
839
+ sched.step()
840
+ # null the gradients
841
+ model.zero_grad(set_to_none=True)
842
+ # --------------- TRAINING SECTION END -------------------
843
+ # everything that follows now is just diagnostics, prints, logging, etc.
844
+
845
+ #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower
846
+ if master_process:
847
+ approx_time = training_time_ms + 1000 * (time.time() - t0)
848
+ print(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")
849
+ with open(logfile, "a") as f:
850
+ f.write(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms\n")
851
+
852
+ if master_process:
853
+ print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
854
+
855
+ # -------------------------------------------------------------------------
856
+ # clean up nice
857
+ dist.destroy_process_group()
records/track_1_short/2024-10-10_Muon/eb5659d0-fb6a-49e5-a311-f1f89412f726.txt ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-10-10_Muon/train_gpt2.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ with open(sys.argv[0]) as f:
4
+ code = f.read() # read the code of this file ASAP, for logging
5
+ import uuid
6
+ import glob
7
+ import time
8
+ from dataclasses import dataclass
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch import nn
13
+ import torch.nn.functional as F
14
+ import torch.distributed as dist
15
+ import torch._inductor.config as config
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+
18
+ # -----------------------------------------------------------------------------
19
+ # Muon optimizer
20
+
21
+ def zeropower_via_svd(G, steps=None):
22
+ U, S, V = G.svd()
23
+ return U @ V.T
24
+
25
+ @torch.compile
26
+ def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
27
+ """
28
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
29
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
30
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
31
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
32
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
33
+ where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model
34
+ performance at all relative to UV^T, where USV^T = G is the SVD.
35
+ """
36
+ assert len(G.shape) == 2
37
+ a, b, c = (3.4445, -4.7750, 2.0315)
38
+ X = G.bfloat16() / (G.norm() + eps) # ensure top singular value <= 1
39
+ if G.size(0) > G.size(1):
40
+ X = X.T
41
+ for _ in range(steps):
42
+ A = X @ X.T
43
+ B = A @ X
44
+ X = a * X + b * B + c * A @ B
45
+ if G.size(0) > G.size(1):
46
+ X = X.T
47
+ return X.to(G.dtype)
48
+
49
+ zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5)
50
+
51
+ class Muon(torch.optim.Optimizer):
52
+ """
53
+ Muon: MomentUm Orthogonalized by Newton-schulz
54
+
55
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
56
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
57
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
58
+ the advantage that it can be stably run in bfloat16 on the GPU.
59
+
60
+ Some warnings:
61
+ - This optimizer assumes that all parameters passed in are 2D.
62
+ - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
63
+ parameters; those should all be optimized by a standard method (e.g., AdamW).
64
+ - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
65
+ - We believe it is unlikely to work well for training with small batch size.
66
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
67
+ - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).
68
+
69
+ Arguments:
70
+ lr: The learning rate used by the internal SGD.
71
+ momentum: The momentum used by the internal SGD.
72
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
73
+ backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
74
+ backend_steps: The number of iteration steps to use in the backend, if it is iterative.
75
+ """
76
+ def __init__(self, params, lr=3e-4, momentum=0.95, nesterov=True, backend='newtonschulz5', backend_steps=5):
77
+ defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps)
78
+ super().__init__(params, defaults)
79
+
80
+ def step(self):
81
+ for group in self.param_groups:
82
+ lr = group['lr']
83
+ momentum = group['momentum']
84
+ zeropower_backend = zeropower_backends[group['backend']]
85
+ for p in group['params']:
86
+ g = p.grad
87
+ if g is None:
88
+ continue
89
+ state = self.state[p]
90
+ if 'momentum_buffer' not in state:
91
+ state['momentum_buffer'] = torch.zeros_like(g)
92
+ buf = state['momentum_buffer']
93
+ buf.mul_(momentum).add_(g)
94
+ if group['nesterov']:
95
+ g = g.add(buf, alpha=momentum)
96
+ if g.size(0) == 3 * g.size(1): # split grouped QKV parameters
97
+ g = torch.cat([zeropower_backend(g1, steps=group['backend_steps']) for g1 in g.split(g.size(1))])
98
+ scale = g.size(1)**0.5
99
+ else:
100
+ g = zeropower_backend(g, steps=group['backend_steps'])
101
+ scale = max(g.size(0), g.size(1))**0.5 # scale to have update.square().mean() == 1
102
+ p.data.add_(g, alpha=-lr * scale)
103
+
104
+ # -----------------------------------------------------------------------------
105
+ # PyTorch nn.Module definitions for the GPT-2 model
106
+
107
+ class Rotary(torch.nn.Module):
108
+
109
+ def __init__(self, dim, base=10000):
110
+ super().__init__()
111
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
112
+ self.register_buffer("inv_freq", inv_freq)
113
+ self.seq_len_cached = None
114
+ self.cos_cached = None
115
+ self.sin_cached = None
116
+
117
+ def forward(self, x):
118
+ seq_len = x.shape[1]
119
+ if seq_len != self.seq_len_cached:
120
+ self.seq_len_cached = seq_len
121
+ t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
122
+ freqs = torch.outer(t, self.inv_freq).to(x.device)
123
+ self.cos_cached = freqs.cos()
124
+ self.sin_cached = freqs.sin()
125
+ return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
126
+
127
+ def apply_rotary_emb(x, cos, sin):
128
+ assert x.ndim == 4 # multihead attention
129
+ d = x.shape[3]//2
130
+ x1 = x[..., :d]
131
+ x2 = x[..., d:]
132
+ y1 = x1 * cos + x2 * sin
133
+ y2 = x1 * (-sin) + x2 * cos
134
+ return torch.cat([y1, y2], 3)
135
+
136
+ def rmsnorm(x0, eps=1e-6):
137
+ x = x0.float()
138
+ x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
139
+ return x.type_as(x0)
140
+
141
+ class CausalSelfAttention(nn.Module):
142
+
143
+ def __init__(self, config):
144
+ super().__init__()
145
+ self.n_head = config.n_head
146
+ self.n_embd = config.n_embd
147
+ self.head_dim = self.n_embd // self.n_head
148
+ assert self.n_embd % self.n_head == 0
149
+ # key, query, value projections for all heads, but in a batch
150
+ self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
151
+ # output projection
152
+ self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
153
+ self.rotary = Rotary(self.head_dim)
154
+
155
+ def forward(self, x):
156
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
157
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
158
+ qkv = self.c_attn(x)
159
+ q, k, v = qkv.split(self.n_embd, dim=2)
160
+ k = k.view(B, T, self.n_head, self.head_dim)
161
+ q = q.view(B, T, self.n_head, self.head_dim)
162
+ v = v.view(B, T, self.n_head, self.head_dim)
163
+ cos, sin = self.rotary(q)
164
+ q = apply_rotary_emb(q, cos, sin)
165
+ k = apply_rotary_emb(k, cos, sin)
166
+ y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
167
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
168
+ # output projection
169
+ y = self.c_proj(y)
170
+ return y
171
+
172
+ class MLP(nn.Module):
173
+
174
+ def __init__(self, config):
175
+ super().__init__()
176
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
177
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
178
+
179
+ def forward(self, x):
180
+ x = self.c_fc(x)
181
+ x = F.gelu(x)
182
+ x = self.c_proj(x)
183
+ return x
184
+
185
+ class Block(nn.Module):
186
+
187
+ def __init__(self, config):
188
+ super().__init__()
189
+ self.attn = CausalSelfAttention(config)
190
+ self.mlp = MLP(config)
191
+ self.attn_scale = (1 / (2 * config.n_layer)**0.5)
192
+
193
+ def forward(self, x):
194
+ x = x + self.attn_scale * self.attn(rmsnorm(x))
195
+ x = x + self.mlp(rmsnorm(x))
196
+ return x
197
+
198
+ # -----------------------------------------------------------------------------
199
+ # The main GPT-2 model
200
+
201
+ @dataclass
202
+ class GPTConfig:
203
+ vocab_size : int = 50257
204
+ n_layer : int = 12
205
+ n_head : int = 12
206
+ n_embd : int = 768
207
+
208
+ class GPT(nn.Module):
209
+
210
+ def __init__(self, config):
211
+ super().__init__()
212
+ self.config = config
213
+
214
+ self.transformer = nn.ModuleDict(dict(
215
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
216
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
217
+ ))
218
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
219
+ self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
220
+
221
+ def forward(self, idx, targets=None, return_logits=True):
222
+ b, t = idx.size()
223
+ pos = torch.arange(0, t, dtype=torch.long, device=idx.device) # shape (t)
224
+
225
+ # forward the GPT model itself
226
+ x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
227
+
228
+ for block in self.transformer.h:
229
+ x = block(x)
230
+ x = rmsnorm(x)
231
+
232
+ if targets is not None:
233
+ # if we are given some desired targets also calculate the loss
234
+ logits = self.lm_head(x)
235
+ logits = logits.float() # use tf32/fp32 for logits
236
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
237
+ else:
238
+ # inference-time mini-optimization: only forward the lm_head on the very last position
239
+ logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
240
+ logits = logits.float() # use tf32/fp32 for logits
241
+ loss = None
242
+
243
+ # there are performance reasons why not returning logits is prudent, if not needed
244
+ if not return_logits:
245
+ logits = None
246
+
247
+ return logits, loss
248
+
249
+ # -----------------------------------------------------------------------------
250
+ # Our own simple Distributed Data Loader
251
+
252
+ def _peek_data_shard(filename):
253
+ # only reads the header, returns header data
254
+ with open(filename, "rb") as f:
255
+ # first read the header, which is 256 int32 integers (4 bytes each)
256
+ header = np.frombuffer(f.read(256*4), dtype=np.int32)
257
+ if header[0] != 20240520:
258
+ print("ERROR: magic number mismatch in the data .bin file!")
259
+ print("---> HINT: Are you passing in a correct file with --input_bin?")
260
+ print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
261
+ print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
262
+ exit(1)
263
+ assert header[1] == 1, "unsupported version"
264
+ ntok = header[2] # number of tokens (claimed)
265
+ return ntok # for now just return the number of tokens
266
+
267
+ def _load_data_shard(filename):
268
+ with open(filename, "rb") as f:
269
+ # first read the header, which is 256 int32 integers (4 bytes each)
270
+ header = np.frombuffer(f.read(256*4), dtype=np.int32)
271
+ assert header[0] == 20240520, "magic number mismatch in the data .bin file"
272
+ assert header[1] == 1, "unsupported version"
273
+ ntok = header[2] # number of tokens (claimed)
274
+ # the rest of it are tokens, stored as uint16
275
+ tokens = np.frombuffer(f.read(), dtype=np.uint16)
276
+ assert len(tokens) == ntok, "number of tokens read does not match header?"
277
+ return tokens
278
+
279
+ class DistributedDataLoader:
280
+ def __init__(self, filename_pattern, B, T, process_rank, num_processes):
281
+ self.process_rank = process_rank
282
+ self.num_processes = num_processes
283
+ self.B = B
284
+ self.T = T
285
+
286
+ # glob files that match the pattern
287
+ self.files = sorted(glob.glob(filename_pattern))
288
+ assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"
289
+
290
+ # load and validate all data shards, count number of tokens in total
291
+ ntok_total = 0
292
+ for fname in self.files:
293
+ shard_ntok = _peek_data_shard(fname)
294
+ assert shard_ntok >= num_processes * B * T + 1
295
+ ntok_total += int(shard_ntok)
296
+ self.ntok_total = ntok_total
297
+
298
+ # kick things off
299
+ self.reset()
300
+
301
+ def reset(self):
302
+ self.current_shard = 0
303
+ self.current_position = self.process_rank * self.B * self.T
304
+ self.tokens = _load_data_shard(self.files[self.current_shard])
305
+
306
+ def advance(self): # advance to next data shard
307
+ self.current_shard = (self.current_shard + 1) % len(self.files)
308
+ self.current_position = self.process_rank * self.B * self.T
309
+ self.tokens = _load_data_shard(self.files[self.current_shard])
310
+
311
+ def next_batch(self):
312
+ B = self.B
313
+ T = self.T
314
+ buf = self.tokens[self.current_position : self.current_position+B*T+1]
315
+ buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
316
+ x = (buf[:-1]).view(B, T) # inputs
317
+ y = (buf[1:]).view(B, T) # targets
318
+ # advance current position and load next shard if necessary
319
+ self.current_position += B * T * self.num_processes
320
+ if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
321
+ self.advance()
322
+ return x.cuda(), y.cuda()
323
+
324
+ # -----------------------------------------------------------------------------
325
+ # int main
326
+
327
+ @dataclass
328
+ class Hyperparameters:
329
+ # data hyperparams
330
+ input_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on
331
+ input_val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on
332
+ # optimization hyperparams
333
+ batch_size : int = 8*64 # batch size, in sequences, across all devices
334
+ device_batch_size : int = 64 # batch size, in sequences, per device
335
+ sequence_length : int = 1024 # sequence length, in tokens
336
+ num_iterations : int = 6200 # number of iterations to run
337
+ learning_rate : float = 0.0036
338
+ warmup_iters : int = 0
339
+ warmdown_iters : int = 1800 # number of iterations of linear warmup/warmdown for triangular or trapezoidal schedule
340
+ weight_decay : float = 0
341
+ # evaluation and logging hyperparams
342
+ val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end
343
+ val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
344
+ save_every : int = 0 # every how many steps to save the checkpoint? 0 for only at the end
345
+ args = Hyperparameters()
346
+
347
+ # set up DDP (distributed data parallel). torchrun sets this env variable
348
+ assert torch.cuda.is_available()
349
+ dist.init_process_group(backend='nccl')
350
+ ddp_rank = int(os.environ['RANK'])
351
+ ddp_local_rank = int(os.environ['LOCAL_RANK'])
352
+ ddp_world_size = int(os.environ['WORLD_SIZE'])
353
+ device = f'cuda:{ddp_local_rank}'
354
+ torch.cuda.set_device(device)
355
+ print(f"using device: {device}")
356
+ master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.
357
+
358
+ # convenience variables
359
+ B, T = args.device_batch_size, args.sequence_length
360
+ # calculate the number of steps to take in the val loop.
361
+ assert args.val_tokens % (B * T * ddp_world_size) == 0
362
+ val_steps = args.val_tokens // (B * T * ddp_world_size)
363
+ # calculate the steps of gradient accumulation required to attain the desired global batch size.
364
+ assert args.batch_size % (B * ddp_world_size) == 0
365
+ train_accumulation_steps = args.batch_size // (B * ddp_world_size)
366
+
367
+ # load tokens
368
+ train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
369
+ val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
370
+ if master_process:
371
+ print(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
372
+ print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
373
+ x, y = train_loader.next_batch()
374
+
375
+ # init the model from scratch
376
+ num_vocab = 50257
377
+ model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=12, n_embd=768))
378
+ model = model.cuda()
379
+ if hasattr(config, "coordinate_descent_tuning"):
380
+ config.coordinate_descent_tuning = True # suggested by @Chillee
381
+ model = torch.compile(model)
382
+ # here we wrap model into DDP container
383
+ model = DDP(model, device_ids=[ddp_local_rank])
384
+ raw_model = model.module # always contains the "raw" unwrapped model
385
+ ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
386
+
387
+ # init the optimizer(s)
388
+ optimizer1 = torch.optim.AdamW(raw_model.lm_head.parameters(), lr=args.learning_rate, betas=(0.9, 0.95),
389
+ weight_decay=args.weight_decay, fused=True)
390
+ optimizer2 = Muon(raw_model.transformer.h.parameters(), lr=0.1*args.learning_rate, momentum=0.95)
391
+ optimizers = [optimizer1, optimizer2]
392
+ # learning rate decay scheduler (linear warmup and warmdown)
393
+ def get_lr(it):
394
+ assert it <= args.num_iterations
395
+ # 1) linear warmup for warmup_iters steps
396
+ if it < args.warmup_iters:
397
+ return (it+1) / args.warmup_iters
398
+ # 2) constant lr for a while
399
+ elif it < args.num_iterations - args.warmdown_iters:
400
+ return 1.0
401
+ # 3) linear warmdown
402
+ else:
403
+ decay_ratio = (args.num_iterations - it) / args.warmdown_iters
404
+ return decay_ratio
405
+ schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
406
+
407
+ # begin logging
408
+ if master_process:
409
+ run_id = str(uuid.uuid4())
410
+ logdir = 'logs/%s/' % run_id
411
+ os.makedirs(logdir, exist_ok=True)
412
+ logfile = 'logs/%s.txt' % run_id
413
+ # create the log file
414
+ with open(logfile, "w") as f:
415
+ # begin the log by printing this file (the Python code)
416
+ f.write('='*100 + '\n')
417
+ f.write(code)
418
+ f.write('='*100 + '\n')
419
+ # log information about the hardware/software environment this is running on
420
+ # and print the full `nvidia-smi` to file
421
+ f.write(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n")
422
+ import subprocess
423
+ result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
424
+ f.write(f'{result.stdout}\n')
425
+ f.write('='*100 + '\n')
426
+
427
+ training_time_ms = 0
428
+ # start the clock
429
+ torch.cuda.synchronize()
430
+ t0 = time.time()
431
+ # begin training
432
+ train_loader.reset()
433
+ for step in range(args.num_iterations + 1):
434
+ last_step = (step == args.num_iterations)
435
+ # This effectively ignores timing first 10 steps, which are slower for weird reasons.
436
+ # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
437
+ # steps with dummy data first, and then re-initialize the model and reset the loader.
438
+ if step == 10:
439
+ training_time_ms = 0
440
+ t0 = time.time()
441
+ timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val
442
+
443
+ # once in a while evaluate the validation dataset
444
+ if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
445
+ # stop the clock
446
+ torch.cuda.synchronize()
447
+ training_time_ms += 1000 * (time.time() - t0)
448
+ # run validation batches
449
+ model.eval()
450
+ val_loader.reset()
451
+ val_loss = 0.0
452
+ for _ in range(val_steps):
453
+ x_val, y_val = val_loader.next_batch()
454
+ with torch.no_grad(): # of course, we'd like to use ctx here too, but that creates a torch.compile error for some reason
455
+ _, loss = model(x_val, y_val, return_logits=False)
456
+ val_loss += loss
457
+ dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
458
+ val_loss /= val_steps
459
+ # log val loss to console and to logfile
460
+ if master_process:
461
+ print(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')
462
+ with open(logfile, "a") as f:
463
+ f.write(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms\n')
464
+ # start the clock again
465
+ torch.cuda.synchronize()
466
+ t0 = time.time()
467
+
468
+ if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)):
469
+ # stop the clock
470
+ torch.cuda.synchronize()
471
+ training_time_ms += 1000 * (time.time() - t0)
472
+ # save the state of the training process
473
+ log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
474
+ torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step))
475
+ # start the clock again
476
+ torch.cuda.synchronize()
477
+ t0 = time.time()
478
+
479
+ # bit confusing: we want to make sure to eval on 0th iteration
480
+ # but also after the very last iteration. so we loop for step <= num_iterations
481
+ # instead of just < num_iterations (one extra due to <=), only to do
482
+ # the validation/sampling one last time, and then we break right here as we're done.
483
+ if last_step:
484
+ break
485
+
486
+ # --------------- TRAINING SECTION BEGIN -----------------
487
+ model.train()
488
+ for i in range(1, train_accumulation_steps+1):
489
+ # forward pass
490
+ with ctx:
491
+ _, loss = model(x, y, return_logits=False)
492
+ train_loss = loss.detach()
493
+ # advance the dataset for the next batch
494
+ x, y = train_loader.next_batch()
495
+ # backward pass
496
+ if i < train_accumulation_steps:
497
+ with model.no_sync(): # there's no need to sync gradients every accumulation step
498
+ loss.backward()
499
+ else:
500
+ loss.backward() # just sync on the last step
501
+ for p in model.parameters():
502
+ p.grad /= train_accumulation_steps
503
+ # step the optimizers and schedulers
504
+ for opt, sched in zip(optimizers, schedulers):
505
+ opt.step()
506
+ sched.step()
507
+ # null the gradients
508
+ model.zero_grad(set_to_none=True)
509
+ # --------------- TRAINING SECTION END -------------------
510
+ # everything that follows now is just diagnostics, prints, logging, etc.
511
+
512
+ #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower
513
+ if master_process:
514
+ approx_time = training_time_ms + 1000 * (time.time() - t0)
515
+ print(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")
516
+ with open(logfile, "a") as f:
517
+ f.write(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms\n")
518
+
519
+ if master_process:
520
+ print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
521
+
522
+ # -------------------------------------------------------------------------
523
+ # clean up nice
524
+ dist.destroy_process_group()
records/track_1_short/2024-10-13_llmc/README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ This is a log produced by running the current version of Andrej Karpathy's [llm.c](https://github.com/karpathy/llm.c), as of October 13th 2024.
2
+
3
+ It was run on a node with 8x H100 HBM3 according to the instructions [here](https://github.com/karpathy/llm.c/discussions/481).
4
+ The mean per-step time was 140ms. The total number of training tokens is 10.26B. The final validation loss was **3.2722**.
5
+
6
+ This is (significantly) better than the quoted result of **3.29** val loss in
7
+ [Andrej Karpathy's May 28th GPT-2 replication discussion](https://github.com/karpathy/llm.c/discussions/481#:~:text=By%20the%20end%20of%20the%20optimization%20we%27ll%20get%20to%20about%203.29).
8
+ So it appears that there have been some improvements to the training algorithm used by llm.c since then.
9
+
10
+ Note that the set of examples which llm.c uses for validation appears to be the same as what we do in this repo, i.e., the first `10 * 2**20` tokens of the val set.
11
+
records/track_1_short/2024-10-13_llmc/main.log ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-10-14_ModernArch/dabaaddd-237c-4ec9-939d-6608a9ed5e27.txt ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-10-14_ModernArch/train_gpt2.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ with open(sys.argv[0]) as f:
4
+ code = f.read() # read the code of this file ASAP, for logging
5
+ import uuid
6
+ import glob
7
+ import time
8
+ from dataclasses import dataclass
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch import nn
13
+ import torch.nn.functional as F
14
+ import torch.distributed as dist
15
+ import torch._inductor.config as config
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+
18
+ # -----------------------------------------------------------------------------
19
+ # Muon optimizer
20
+
21
+ def zeropower_via_svd(G, steps=None):
22
+ U, S, V = G.svd()
23
+ return U @ V.T
24
+
25
+ @torch.compile
26
+ def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
27
+ """
28
+ Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
29
+ quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
30
+ of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
31
+ zero even beyond the point where the iteration no longer converges all the way to one everywhere
32
+ on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
33
+ where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model
34
+ performance at all relative to UV^T, where USV^T = G is the SVD.
35
+ """
36
+ assert len(G.shape) == 2
37
+ a, b, c = (3.4445, -4.7750, 2.0315)
38
+ X = G.bfloat16()
39
+ X /= (X.norm() + eps) # ensure top singular value <= 1
40
+ if G.size(0) > G.size(1):
41
+ X = X.T
42
+ for _ in range(steps):
43
+ A = X @ X.T
44
+ B = A @ X
45
+ X = a * X + b * B + c * A @ B
46
+ if G.size(0) > G.size(1):
47
+ X = X.T
48
+ return X
49
+
50
+ zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5)
51
+
52
+ class Muon(torch.optim.Optimizer):
53
+ """
54
+ Muon - MomentUm Orthogonalized by Newton-schulz
55
+
56
+ Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
57
+ processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
58
+ matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
59
+ the advantage that it can be stably run in bfloat16 on the GPU.
60
+
61
+ Some warnings:
62
+ - This optimizer assumes that all parameters passed in are 2D.
63
+ - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
64
+ parameters; those should all be optimized by a standard method (e.g., AdamW).
65
+ - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
66
+ - We believe it is unlikely to work well for training with small batch size.
67
+ - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
68
+ - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).
69
+
70
+ Arguments:
71
+ lr: The learning rate used by the internal SGD.
72
+ momentum: The momentum used by the internal SGD.
73
+ nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
74
+ backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
75
+ backend_steps: The number of iteration steps to use in the backend, if it is iterative.
76
+ """
77
+ def __init__(self, params, lr=3e-4, momentum=0.95, nesterov=True, backend='newtonschulz5', backend_steps=5):
78
+ defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps)
79
+ super().__init__(params, defaults)
80
+
81
+ def step(self):
82
+ for group in self.param_groups:
83
+ lr = group['lr']
84
+ momentum = group['momentum']
85
+ zeropower_backend = zeropower_backends[group['backend']]
86
+ for p in group['params']:
87
+ g = p.grad
88
+ if g is None:
89
+ continue
90
+ state = self.state[p]
91
+ if 'momentum_buffer' not in state:
92
+ state['momentum_buffer'] = torch.zeros_like(g)
93
+ buf = state['momentum_buffer']
94
+ buf.mul_(momentum).add_(g)
95
+ if group['nesterov']:
96
+ g = g.add(buf, alpha=momentum)
97
+ if g.size(0) == 3 * g.size(1): # split grouped QKV parameters
98
+ g = torch.cat([zeropower_backend(g1, steps=group['backend_steps']) for g1 in g.split(g.size(1))])
99
+ scale = g.size(1)**0.5
100
+ else:
101
+ g = zeropower_backend(g, steps=group['backend_steps'])
102
+ scale = max(g.size(0), g.size(1))**0.5 # scale to have update.square().mean() == 1
103
+ p.data.add_(g, alpha=-lr * scale)
104
+
105
+ # -----------------------------------------------------------------------------
106
+ # PyTorch nn.Module definitions for the GPT-2 model
107
+
108
+ class Rotary(torch.nn.Module):
109
+
110
+ def __init__(self, dim, base=10000):
111
+ super().__init__()
112
+ self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
113
+ self.seq_len_cached = None
114
+ self.cos_cached = None
115
+ self.sin_cached = None
116
+
117
+ def forward(self, x):
118
+ seq_len = x.shape[1]
119
+ if seq_len != self.seq_len_cached:
120
+ self.seq_len_cached = seq_len
121
+ t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
122
+ freqs = torch.outer(t, self.inv_freq).to(x.device)
123
+ self.cos_cached = freqs.cos().bfloat16()
124
+ self.sin_cached = freqs.sin().bfloat16()
125
+ return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
126
+
127
+ def apply_rotary_emb(x, cos, sin):
128
+ assert x.ndim == 4 # multihead attention
129
+ d = x.shape[3]//2
130
+ x1 = x[..., :d]
131
+ x2 = x[..., d:]
132
+ y1 = x1 * cos + x2 * sin
133
+ y2 = x1 * (-sin) + x2 * cos
134
+ return torch.cat([y1, y2], 3).type_as(x)
135
+
136
+ class CausalSelfAttention(nn.Module):
137
+
138
+ def __init__(self, config):
139
+ super().__init__()
140
+ self.n_head = config.n_head
141
+ self.n_embd = config.n_embd
142
+ self.head_dim = self.n_embd // self.n_head
143
+ assert self.n_embd % self.n_head == 0
144
+ self.c_q = nn.Linear(self.n_embd, self.n_embd, bias=False)
145
+ self.c_k = nn.Linear(self.n_embd, self.n_embd, bias=False)
146
+ self.c_v = nn.Linear(self.n_embd, self.n_embd, bias=False)
147
+ # output projection
148
+ self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
149
+ self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
150
+ self.rotary = Rotary(self.head_dim)
151
+
152
+ def forward(self, x):
153
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
154
+ q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
155
+ k = self.c_k(x).view(B, T, self.n_head, self.head_dim)
156
+ v = self.c_v(x).view(B, T, self.n_head, self.head_dim)
157
+ cos, sin = self.rotary(q)
158
+ q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
159
+ q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) # QK norm suggested by @Grad62304977
160
+ y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
161
+ y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side
162
+ y = self.c_proj(y)
163
+ return y
164
+
165
+ class MLP(nn.Module):
166
+
167
+ def __init__(self, config):
168
+ super().__init__()
169
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
170
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
171
+ self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
172
+
173
+ def forward(self, x):
174
+ x = self.c_fc(x)
175
+ x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
176
+ x = self.c_proj(x)
177
+ return x
178
+
179
+ class Block(nn.Module):
180
+
181
+ def __init__(self, config):
182
+ super().__init__()
183
+ self.attn = CausalSelfAttention(config)
184
+ self.mlp = MLP(config)
185
+
186
+ def forward(self, x):
187
+ x = x + self.attn(F.rms_norm(x, (x.size(-1),)))
188
+ x = x + self.mlp(F.rms_norm(x, (x.size(-1),)))
189
+ return x
190
+
191
+ # -----------------------------------------------------------------------------
192
+ # The main GPT-2 model
193
+
194
+ @dataclass
195
+ class GPTConfig:
196
+ vocab_size : int = 50304
197
+ n_layer : int = 12
198
+ n_head : int = 6 # head dim 128 suggested by @Grad62304977
199
+ n_embd : int = 768
200
+
201
+ class GPT(nn.Module):
202
+
203
+ def __init__(self, config):
204
+ super().__init__()
205
+ self.config = config
206
+
207
+ self.transformer = nn.ModuleDict(dict(
208
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
209
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
210
+ ))
211
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
212
+ self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
213
+
214
+ def forward(self, idx, targets=None, return_logits=True):
215
+
216
+ # forward the GPT model itself
217
+ x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
218
+ for block in self.transformer.h:
219
+ x = block(x)
220
+ x = F.rms_norm(x, (x.size(-1),))
221
+
222
+ if targets is not None:
223
+ # if we are given some desired targets also calculate the loss
224
+ logits = self.lm_head(x)
225
+ logits = logits.float() # use tf32/fp32 for logits
226
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
227
+ else:
228
+ # inference-time mini-optimization: only forward the lm_head on the very last position
229
+ logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
230
+ logits = logits.float() # use tf32/fp32 for logits
231
+ loss = None
232
+
233
+ # there are performance reasons why not returning logits is prudent, if not needed
234
+ if not return_logits:
235
+ logits = None
236
+
237
+ return logits, loss
238
+
239
+ # -----------------------------------------------------------------------------
240
+ # Our own simple Distributed Data Loader
241
+
242
+ def _peek_data_shard(filename):
243
+ # only reads the header, returns header data
244
+ with open(filename, "rb") as f:
245
+ # first read the header, which is 256 int32 integers (4 bytes each)
246
+ header = np.frombuffer(f.read(256*4), dtype=np.int32)
247
+ if header[0] != 20240520:
248
+ print("ERROR: magic number mismatch in the data .bin file!")
249
+ print("---> HINT: Are you passing in a correct file with --input_bin?")
250
+ print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
251
+ print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
252
+ exit(1)
253
+ assert header[1] == 1, "unsupported version"
254
+ ntok = header[2] # number of tokens (claimed)
255
+ return ntok # for now just return the number of tokens
256
+
257
+ def _load_data_shard(filename):
258
+ with open(filename, "rb") as f:
259
+ # first read the header, which is 256 int32 integers (4 bytes each)
260
+ header = np.frombuffer(f.read(256*4), dtype=np.int32)
261
+ assert header[0] == 20240520, "magic number mismatch in the data .bin file"
262
+ assert header[1] == 1, "unsupported version"
263
+ ntok = header[2] # number of tokens (claimed)
264
+ # the rest of it are tokens, stored as uint16
265
+ tokens = np.frombuffer(f.read(), dtype=np.uint16)
266
+ assert len(tokens) == ntok, "number of tokens read does not match header?"
267
+ return tokens
268
+
269
+ class DistributedDataLoader:
270
+ def __init__(self, filename_pattern, B, T, process_rank, num_processes):
271
+ self.process_rank = process_rank
272
+ self.num_processes = num_processes
273
+ self.B = B
274
+ self.T = T
275
+
276
+ # glob files that match the pattern
277
+ self.files = sorted(glob.glob(filename_pattern))
278
+ assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"
279
+
280
+ # load and validate all data shards, count number of tokens in total
281
+ ntok_total = 0
282
+ for fname in self.files:
283
+ shard_ntok = _peek_data_shard(fname)
284
+ assert shard_ntok >= num_processes * B * T + 1
285
+ ntok_total += int(shard_ntok)
286
+ self.ntok_total = ntok_total
287
+
288
+ # kick things off
289
+ self.reset()
290
+
291
+ def reset(self):
292
+ self.current_shard = 0
293
+ self.current_position = self.process_rank * self.B * self.T
294
+ self.tokens = _load_data_shard(self.files[self.current_shard])
295
+
296
+ def advance(self): # advance to next data shard
297
+ self.current_shard = (self.current_shard + 1) % len(self.files)
298
+ self.current_position = self.process_rank * self.B * self.T
299
+ self.tokens = _load_data_shard(self.files[self.current_shard])
300
+
301
+ def next_batch(self):
302
+ B = self.B
303
+ T = self.T
304
+ buf = self.tokens[self.current_position : self.current_position+B*T+1]
305
+ buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
306
+ x = (buf[:-1]).view(B, T) # inputs
307
+ y = (buf[1:]).view(B, T) # targets
308
+ # advance current position and load next shard if necessary
309
+ self.current_position += B * T * self.num_processes
310
+ if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
311
+ self.advance()
312
+ return x.cuda(), y.cuda()
313
+
314
+ # -----------------------------------------------------------------------------
315
+ # int main
316
+
317
+ @dataclass
318
+ class Hyperparameters:
319
+ # data hyperparams
320
+ input_bin : str = 'data/fineweb10B/fineweb_train_*.bin' # input .bin to train on
321
+ input_val_bin : str = 'data/fineweb10B/fineweb_val_*.bin' # input .bin to eval validation loss on
322
+ # optimization hyperparams
323
+ batch_size : int = 8*64 # batch size, in sequences, across all devices
324
+ device_batch_size : int = 64 # batch size, in sequences, per device
325
+ sequence_length : int = 1024 # sequence length, in tokens
326
+ num_iterations : int = 5100 # number of iterations to run
327
+ learning_rate : float = 0.0036
328
+ warmup_iters : int = 0
329
+ warmdown_iters : int = 1450 # number of iterations of linear warmup/warmdown for triangular or trapezoidal schedule
330
+ weight_decay : float = 0
331
+ # evaluation and logging hyperparams
332
+ val_loss_every : int = 125 # every how many steps to evaluate val loss? 0 for only at the end
333
+ val_tokens : int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
334
+ save_every : int = 0 # every how many steps to save the checkpoint? 0 for only at the end
335
+ args = Hyperparameters()
336
+
337
+ # set up DDP (distributed data parallel). torchrun sets this env variable
338
+ assert torch.cuda.is_available()
339
+ dist.init_process_group(backend='nccl')
340
+ ddp_rank = int(os.environ['RANK'])
341
+ ddp_local_rank = int(os.environ['LOCAL_RANK'])
342
+ ddp_world_size = int(os.environ['WORLD_SIZE'])
343
+ device = f'cuda:{ddp_local_rank}'
344
+ torch.cuda.set_device(device)
345
+ print(f"using device: {device}")
346
+ master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.
347
+
348
+ # convenience variables
349
+ B, T = args.device_batch_size, args.sequence_length
350
+ # calculate the number of steps to take in the val loop.
351
+ assert args.val_tokens % (B * T * ddp_world_size) == 0
352
+ val_steps = args.val_tokens // (B * T * ddp_world_size)
353
+ # calculate the steps of gradient accumulation required to attain the desired global batch size.
354
+ assert args.batch_size % (B * ddp_world_size) == 0
355
+ train_accumulation_steps = args.batch_size // (B * ddp_world_size)
356
+
357
+ # load tokens
358
+ train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
359
+ val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
360
+ if master_process:
361
+ print(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
362
+ print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
363
+ x, y = train_loader.next_batch()
364
+
365
+ # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977.
366
+ # this originates from Karpathy's experiments.
367
+ num_vocab = 50304
368
+ model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=6, n_embd=768))
369
+ model = model.cuda()
370
+ if hasattr(config, "coordinate_descent_tuning"):
371
+ config.coordinate_descent_tuning = True # suggested by @Chillee
372
+ model = torch.compile(model)
373
+ # here we wrap model into DDP container
374
+ model = DDP(model, device_ids=[ddp_local_rank])
375
+ raw_model = model.module # always contains the "raw" unwrapped model
376
+ ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
377
+
378
+ # init the optimizer(s)
379
+ optimizer1 = torch.optim.AdamW(raw_model.lm_head.parameters(), lr=args.learning_rate, betas=(0.9, 0.95),
380
+ weight_decay=args.weight_decay, fused=True)
381
+ optimizer2 = Muon(raw_model.transformer.h.parameters(), lr=0.1*args.learning_rate, momentum=0.95)
382
+ optimizers = [optimizer1, optimizer2]
383
+ # learning rate decay scheduler (linear warmup and warmdown)
384
+ def get_lr(it):
385
+ assert it <= args.num_iterations
386
+ # 1) linear warmup for warmup_iters steps
387
+ if it < args.warmup_iters:
388
+ return (it+1) / args.warmup_iters
389
+ # 2) constant lr for a while
390
+ elif it < args.num_iterations - args.warmdown_iters:
391
+ return 1.0
392
+ # 3) linear warmdown
393
+ else:
394
+ decay_ratio = (args.num_iterations - it) / args.warmdown_iters
395
+ return decay_ratio
396
+ schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
397
+
398
+ # begin logging
399
+ if master_process:
400
+ run_id = str(uuid.uuid4())
401
+ logdir = 'logs/%s/' % run_id
402
+ os.makedirs(logdir, exist_ok=True)
403
+ logfile = 'logs/%s.txt' % run_id
404
+ # create the log file
405
+ with open(logfile, "w") as f:
406
+ # begin the log by printing this file (the Python code)
407
+ f.write('='*100 + '\n')
408
+ f.write(code)
409
+ f.write('='*100 + '\n')
410
+ # log information about the hardware/software environment this is running on
411
+ # and print the full `nvidia-smi` to file
412
+ f.write(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n")
413
+ import subprocess
414
+ result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
415
+ f.write(f'{result.stdout}\n')
416
+ f.write('='*100 + '\n')
417
+
418
+ training_time_ms = 0
419
+ # start the clock
420
+ torch.cuda.synchronize()
421
+ t0 = time.time()
422
+ # begin training
423
+ train_loader.reset()
424
+ for step in range(args.num_iterations + 1):
425
+ last_step = (step == args.num_iterations)
426
+ # This effectively ignores timing first 10 steps, which are slower for weird reasons.
427
+ # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
428
+ # steps with dummy data first, and then re-initialize the model and reset the loader.
429
+ if step == 10:
430
+ training_time_ms = 0
431
+ t0 = time.time()
432
+ timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val
433
+
434
+ # once in a while evaluate the validation dataset
435
+ if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
436
+ # stop the clock
437
+ torch.cuda.synchronize()
438
+ training_time_ms += 1000 * (time.time() - t0)
439
+ # run validation batches
440
+ model.eval()
441
+ val_loader.reset()
442
+ val_loss = 0.0
443
+ for _ in range(val_steps):
444
+ x_val, y_val = val_loader.next_batch()
445
+ with ctx: # of course, we'd like to use no_grad() here too, but that creates a torch.compile error for some reason
446
+ _, loss = model(x_val, y_val, return_logits=False)
447
+ val_loss += loss.detach()
448
+ del loss
449
+ dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
450
+ val_loss /= val_steps
451
+ # log val loss to console and to logfile
452
+ if master_process:
453
+ print(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')
454
+ with open(logfile, "a") as f:
455
+ f.write(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms\n')
456
+ # start the clock again
457
+ torch.cuda.synchronize()
458
+ t0 = time.time()
459
+
460
+ if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)):
461
+ # stop the clock
462
+ torch.cuda.synchronize()
463
+ training_time_ms += 1000 * (time.time() - t0)
464
+ # save the state of the training process
465
+ log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
466
+ torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step))
467
+ # start the clock again
468
+ torch.cuda.synchronize()
469
+ t0 = time.time()
470
+
471
+ # bit confusing: we want to make sure to eval on 0th iteration
472
+ # but also after the very last iteration. so we loop for step <= num_iterations
473
+ # instead of just < num_iterations (one extra due to <=), only to do
474
+ # the validation/sampling one last time, and then we break right here as we're done.
475
+ if last_step:
476
+ break
477
+
478
+ # --------------- TRAINING SECTION BEGIN -----------------
479
+ model.train()
480
+ for i in range(1, train_accumulation_steps+1):
481
+ # forward pass
482
+ with ctx:
483
+ _, loss = model(x, y, return_logits=False)
484
+ train_loss = loss.detach()
485
+ # advance the dataset for the next batch
486
+ x, y = train_loader.next_batch()
487
+ # backward pass
488
+ if i < train_accumulation_steps:
489
+ with model.no_sync(): # there's no need to sync gradients every accumulation step
490
+ loss.backward()
491
+ else:
492
+ loss.backward() # just sync on the last step
493
+ for p in model.parameters():
494
+ p.grad /= train_accumulation_steps
495
+ # step the optimizers and schedulers
496
+ for opt, sched in zip(optimizers, schedulers):
497
+ opt.step()
498
+ sched.step()
499
+ # null the gradients
500
+ model.zero_grad(set_to_none=True)
501
+ # --------------- TRAINING SECTION END -------------------
502
+ # everything that follows now is just diagnostics, prints, logging, etc.
503
+
504
+ #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower
505
+ if master_process:
506
+ approx_time = training_time_ms + 1000 * (time.time() - t0)
507
+ print(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")
508
+ with open(logfile, "a") as f:
509
+ f.write(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms\n")
510
+
511
+ if master_process:
512
+ print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
513
+
514
+ # -------------------------------------------------------------------------
515
+ # clean up nice
516
+ dist.destroy_process_group()
records/track_1_short/2024-10-17_DistributedMuon/22d24867-eb5a-4fcc-ae2c-263d0277dfd1.txt ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-10-18_PyTorch25/d4bfb25f-688d-4da5-8743-33926fad4842.txt ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-10-20_ScaleUp1B/87bd51fd-6203-4c88-b3aa-8a849a6a83ca.txt ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-10-20_ScaleUp1B/ad8d7ae5-7b2d-4ee9-bc52-f912e9174d7a.txt ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-10-20_ScaleUp1B/c0078066-c8c9-49c8-868a-ff4d4f32e615.txt ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-10-29_Optimizers/8bfe4e35-c3fc-4b70-a984-3be937b71ff3.txt ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-10-29_Optimizers/8d6193f4-27fc-4e68-899f-af70019a4d54.txt ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-10-29_Optimizers/95a9fd44-7c13-49c7-b324-3e7d9e23a499.txt ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-10-29_Optimizers/README.md ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Optimizer comparison for NanoGPT speedrunning
2
+
3
+ This is a comparison between the four best optimizers I am aware of for NanoGPT speedrunning. They are compared using the 10/18/24 NanoGPT speedrunning record.
4
+
5
+ Reproducible logs:
6
+ * [Adam](95a9fd44-7c13-49c7-b324-3e7d9e23a499.txt)
7
+ * [DistributedShampoo](8bfe4e35-c3fc-4b70-a984-3be937b71ff3)
8
+ * [SOAP](e21a2838-a0f2-46f2-a247-db0021165682.txt)
9
+ * [Muon](8d6193f4-27fc-4e68-899f-af70019a4d54.txt)
10
+
11
+ Results:
12
+ ![1](nanogpt_speedrun81w.png)
13
+ ![2](nanogpt_speedrun82w.png)
14
+
15
+ ### General notes for all optimizers
16
+
17
+ All optimizers are run using zero weight decay (which is found to be empirically optimal).
18
+
19
+ And they are all run with a warmup-stable-decay / trapezoidal schedule, which also seems to be optimal. That's what causes the kink in the loss curve ~75% of the way to the end.
20
+
21
+ In addition, in all cases, we optimize the shared embedding/head layer just using Adam (which is also found to be empirically optimal).
22
+ Note that in the following code snippets, `raw_model.transformer.h.parameters()` gives all parameters besides those two.
23
+
24
+ In each case, the hyperparameters are the best ones I could find in around 20 attempts.
25
+
26
+ ## [Adam](95a9fd44-7c13-49c7-b324-3e7d9e23a499.txt)
27
+ The optimizer here is equivalent to:
28
+ ```
29
+ torch.optim.Adam(raw_model.transformer.h.parameters(), lr=0.0018, betas=(0.9, 0.95))
30
+ ```
31
+
32
+
33
+ ## [DistributedShampoo](8bfe4e35-c3fc-4b70-a984-3be937b71ff3.txt)
34
+ Run as follows:
35
+ ```
36
+ DistributedShampoo(
37
+ raw_model.transformer.h.parameters(),
38
+ lr=0.0018,
39
+ betas=(0.95, 0.95),
40
+ epsilon=1e-12,
41
+ weight_decay=0,
42
+ max_preconditioner_dim=8192,
43
+ precondition_frequency=10,
44
+ use_decoupled_weight_decay=True,
45
+ grafting_config=AdamGraftingConfig(
46
+ beta2=0.95,
47
+ epsilon=1e-8,
48
+ ),
49
+ distributed_config=DDPShampooConfig(
50
+ communication_dtype=CommunicationDType.FP32,
51
+ num_trainers_per_group=8,
52
+ communicate_params=False,
53
+ ),
54
+ )
55
+ ```
56
+
57
+ This is using the official `DistributedShampoo` implementation from [here](https://github.com/facebookresearch/optimizers/tree/ad2809a291c01859f68fcabbcb49a2aa75fd7827/distributed_shampoo).
58
+
59
+ Things that turned out to be important:
60
+ * Don't use epsilon above 1e-8; this loses performance. Epsilon 1e-12 performs as well as 1e-15
61
+ * Betas=(0.95, 0.95) seemed optimal, which turns out to be the same thing that SOAP uses
62
+ * Higher preconditioner update frequency is better but slower
63
+
64
+ I'm open to hyperparameter suggestions; the experiment takes ~20-30 minutes to run on a fresh 8xH100 instance, so it's not hard for me to run more attempts.
65
+
66
+
67
+ ## [SOAP](e21a2838-a0f2-46f2-a247-db0021165682.txt)
68
+ ```
69
+ SOAP(model.transformer.h.parameters(), lr=0.0018, betas=(.95, .95), precondition_frequency=10)
70
+ ```
71
+
72
+ This is using the official SOAP implementation [here](https://github.com/nikhilvyas/SOAP/blob/bbce86e890d3b697380f4376acb600c2d6c3d203/soap.py).
73
+
74
+ Based on conversations with the authors, it is likely that a future SOAP implementation will significantly reduce the wallclock overhead.
75
+
76
+
77
+ ## [Muon](8d6193f4-27fc-4e68-899f-af70019a4d54.txt)
78
+ ```
79
+ Muon(raw_model.transformer.h.parameters(), lr=0.02, momentum=0.95)
80
+ ```
81
+
82
+
83
+ ## Openness
84
+
85
+ These training logs are reproducible (just cut out the part besides the code, and run it using the `run.sh` in the top-level folder). They take 12-25 minutes to run.
86
+
87
+ I tried to do a good job sweeping the hyperparameters for each optimizer, but I can easily have missed something, or just not have performed enough runs.
88
+
89
+ Therefore, I am interested in any better hyperparameter settings which other researchers can find, for any of the optimizers.
90
+ If you post or send me your own reproducible log with one of these optimizers, I will be very happy to boost it in any way I can.
91
+
92
+ ## Appendix: Negative results
93
+
94
+ I believe it was Shazeer who said something like "negative results in machine learning are not worth much, because your inability to make something work doesn't prove that it can't work"
95
+
96
+ Given that disclaimer, here are some optimizers that I tried to make work, but was unable to get a significant boost over Adam with:
97
+ * Sophia
98
+ * Lion
99
+ * AdamWScheduleFree
100
+ * AdEmaMix (actually this was slightly better than Adam, just not enough to get near competing with the three Shampoo-like optimizers)
101
+
102
+ Of course, this is just for NanoGPT speedrunning (short train duration); it's quite possible they work better at longer training duration or for larger models.
103
+
records/track_1_short/2024-10-29_Optimizers/e21a2838-a0f2-46f2-a247-db0021165682.txt ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-10-29_Optimizers/nanogpt_speedrun81w.png ADDED

Git LFS Details

  • SHA256: 7fbb61bc47aa8dec2cf5760171d9bdbd76f8814dda162e140653135a8ed6b7d3
  • Pointer size: 131 Bytes
  • Size of remote file: 179 kB
records/track_1_short/2024-10-29_Optimizers/nanogpt_speedrun82w.png ADDED

Git LFS Details

  • SHA256: 91c8b5af26b51c8bcb912ebd952ad4824d6bef0c51df05d264d51bc5444150e1
  • Pointer size: 131 Bytes
  • Size of remote file: 170 kB
records/track_1_short/2024-11-03_UntieEmbed/README.md ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # New record 11/03/24
2
+
3
+
4
+ New NanoGPT training speed record: 3.28 FineWeb val loss in 10.8 minutes on 8xH100
5
+
6
+ Previous record: 12.0 minutes
7
+ Changelog:
8
+ - untied embed and head weights
9
+ - added RMSNorm after embed
10
+ - init head to zero
11
+
12
+ Driven by @Grad62304977
13
+
14
+ ---
15
+
16
+ Technically, this is somewhat of an "any%" record, since untying the embedding and lm_head adds 39M parameters.
17
+
18
+ However, it doesn't change the number of active parameters or the inference throughput. Future records will stay constrained to 124M active parameters.
19
+
20
+ ---
21
+
22
+ Like the last architectural change, this record was driven by @Grad62304977. I just finetuned some things and did bookkeeping.
23
+
24
+ ---
25
+
26
+ Shoutout to @cloneofsimo whose scaling guide already suggests initializing the head to zero. This works quite well and is a significant fraction of the record.
27
+
records/track_1_short/2024-11-03_UntieEmbed/d6b50d71-f419-4d26-bb39-a60d55ae7a04.txt ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-11-04_50Bruns/3d715d41-453a-40d6-9506-421ba69766b2.txt ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-11-04_50Bruns/4fbe61ec-f79a-4c19-836d-46d599deecce.txt ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-11-04_50Bruns/530f3ee1-8862-4d21-be2b-da10eb05e6a9.txt ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-11-04_50Bruns/69c33fc9-eabb-4a38-aa08-6922914eb405.txt ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-11-04_50Bruns/README.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 50B-token runs
2
+
3
+ This folder contains four runs generated by extending the 11/03/24 speedrun record to 50B FineWeb tokens.
4
+ The goal is to test how the speedrun generalizes to long durations, and especially how well Muon does.
5
+
6
+ We compare two things:
7
+ 1. We compare Muon to Adam as the optimizer for the transformer body. (The head and embedding are always optimized by Adam.)
8
+ 2. We compare training on 5 epochs of 10B tokens to training on 50B tokens. (Surprisingly this does about the same)
9
+
10
+ The four resulting runs are as follows:
11
+
12
+ * [Muon 50B tokens](./530f3ee1-8862-4d21-be2b-da10eb05e6a9.txt) (HellaSwag=35.82)
13
+ * [Adam 50B tokens](./69c33fc9-eabb-4a38-aa08-6922914eb405.txt) (HellaSwag=34.26)
14
+ * [Muon 5x10B tokens](./4fbe61ec-f79a-4c19-836d-46d599deecce.txt) (HellaSwag=36.17)
15
+ * [Adam 5x10B tokens](./3d715d41-453a-40d6-9506-421ba69766b2.txt) (HellaSwag=34.05)
16
+
17
+ To get a sense of what a good HellaSwag score would be for this scale of model, here are some baselines:
18
+ * Karpathy's baseline llm.c training (trained for 10B FineWeb tokens): 29.9
19
+ * OpenAI GPT-2 (124M): 29.4
20
+ * OpenAI GPT-3 (124M) (trained for 300B WebText tokens): 33.7
21
+ * Huggingface SmolLM2-135M (trained for 2T FineWeb/DCLM/etc tokens): 42.1
22
+
23
+ Note: I'm a little concerned that the learning rate schedule (WSD) and weight decay (zero), which are tuned for the speedrun duration,
24
+ might become undertuned/suboptimal for trainings of this duration.
25
+ It does look like the gap between Muon/Adam is too large to be closed by something like this, and the HellaSwag scores look quite reasonable, but you never know.
26
+
records/track_1_short/2024-11-06_ShortcutsTweaks/042f9e87-07e6-4504-bb04-4ec59a380211.txt ADDED
The diff for this file is too large to render. See raw diff
 
records/track_1_short/2024-11-06_ShortcutsTweaks/05b29e54-0be0-4a0f-a1e2-7d5317daedd3.txt ADDED
The diff for this file is too large to render. See raw diff