iamrahulreddy committed on
Commit 4fc1bb9 · verified · 1 Parent(s): cbe6941

release: publish Quintus project files

.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
+ assets/benchmark_scoreboard.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 Muskula Rahul
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,94 +1,326 @@
1
  ---
2
  license: mit
3
  language:
4
- - en
5
  tags:
6
- - text-generation
7
- - conversational
8
- - qwen3
9
- - knowledge-distillation
10
- base_model: Qwen/Qwen3-1.7B
11
  ---
12
 
13
- # Quintus 1.7B
14
 
15
- **Quintus-1.7B** is a compact, instruction-following AI assistant. Built on the **Qwen3-1.7B** architecture, Quintus bridges the gap between massive parameter sizes and low-resource edge deployment through a two-stage training paradigm: Online Knowledge Distillation (KD) followed by Supervised Fine-Tuning (SFT).
16
 
17
- The final model weights are publicly available on Hugging Face: [iamrahulreddy/Quintus](https://huggingface.co/iamrahulreddy/Quintus).
18
 
19
- ## Model Details
20
- - **Architecture**: Qwen3-1.7B
21
- - **Teacher Model**: Qwen3-8B-Instruct
22
- - **Training Paradigm**: Online Full-Vocab Knowledge Distillation + Targeted Persona SFT
23
- - **Language**: English
24
- - **License**: MIT
25
 
26
- ## Core Methodology & Architecture
27
 
28
- The Quintus pipeline implements two primary phases to overcome the performance limitations of compact base models without standard SFT dataset scaling limits:
29
 
30
- 1. **Online Knowledge Distillation (KD)**: Rather than caching teacher logits offline, the Quintus engine streams the 8B teacher's full-vocabulary probability distribution live during the student's forward pass.
31
- 2. **Targeted Persona SFT**: A final fine-tuning phase on LIMA and identity data grounds the model's persona and prevents infinite repetition loops.
32
 
33
- ## Dataset & Training Details
34
 
35
- - **Training Dataset**: Fine-tuned using the [DistilQwen_100k](https://huggingface.co/datasets/alibaba-pai/DistilQwen_100k) dataset. Approximately 90,000 instruction-following examples were used after filtering out non-English (Chinese, Japanese, Korean) samples.
36
- - **High-Throughput Optimizations**:
37
- - **Sequence Packing**: Dense sequence packing utilizing a First-Fit Decreasing (FFD) binning algorithm to eliminate VRAM waste from padding.
38
- - **Memory & Compute Kernels**: Accelerated gradient computations using **FlashAttention-2** and **Liger Kernels** (fused operators).
39
- - **Optimizer**: Fused AdamW optimizer configuration for faster, memory-efficient weight updates.
40
 
41
  ## Benchmark Scoreboard
42
 
43
- Quintus 1.7B demonstrates a crossover phenomenon, successfully outperforming the official instruction-tuned `Qwen3-1.7B-Instruct` model on multiple reasoning and coding tasks.
 
44
 
45
- | Benchmark | Qwen3-1.7B-Base | Qwen3-1.7B-Instruct | **Quintus 1.7B** |
46
- | :--- | :---: | :---: | :---: |
47
- | **HumanEval** pass@1 | 67.1% | **70.7%** | 67.7% |
48
- | **MBPP** pass@1 | 67.2% | 58.2% | **64.8%** |
49
- | **GSM8K** (10-shot, flexible) | 69.98% | 69.75% | **74.30%** |
50
- | **ARC-Challenge** acc_norm | 55.72% | 52.99% | **58.36%** |
51
- | **WinoGrande** (5-shot) | 65.67% | 61.01% | **66.38%** |
52
- | **PIQA** acc_norm | 75.63% | 72.09% | **75.57%** |
53
 
54
- ## Usage: Quick Run in Google Colab / CLI
 
 
55
 
56
- You can easily run Quintus interactively. The following script sets up a conversational loop with streaming text output, perfect for Google Colab or a local terminal.
57
 
58
- Make sure you have the required libraries installed:
59
- ```python
60
- # Install if necessary - pip install torch transformers accelerate
61

62
  import torch
63
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
64
 
65
- # Repo name
66
  PUBLIC_REPO_ID = "iamrahulreddy/Quintus"
67
 
68
  print(f"Loading Quintus from {PUBLIC_REPO_ID}...")
69
  tokenizer = AutoTokenizer.from_pretrained(PUBLIC_REPO_ID, trust_remote_code=True)
70
  model = AutoModelForCausalLM.from_pretrained(
71
- PUBLIC_REPO_ID,
72
- device_map="auto",
73
- dtype=torch.float16,
74
- trust_remote_code=True
75
  )
76
 
77
- # Stopping criteria
78
  stop_tokens = ["<|endoftext|>", "<|im_end|>"]
79
  eos_token_ids = [tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else []
80
  for token in stop_tokens:
81
- t_id = tokenizer.convert_tokens_to_ids(token)
82
- if t_id is not None and t_id not in eos_token_ids:
83
- eos_token_ids.append(t_id)
84
 
85
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
86
 
87
  conversation_history = [
88
- {"role": "system", "content": "You are Quintus, a highly capable AI assistant created by Muskula Rahul. You are helpful, precise, and logically sound."}
89
  ]
90
 
91
- print("\nQuintus Chat (type 'quit' to exit)\n")
 
 
92
 
93
  while True:
94
  try:
@@ -100,17 +332,17 @@ while True:
100
  continue
101
 
102
  conversation_history.append({"role": "user", "content": user_input})
103
-
104
  prompt = tokenizer.apply_chat_template(
105
- conversation_history,
106
- tokenize=False,
107
- add_generation_prompt=True
108
  )
109
-
110
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
111
-
112
  print("Quintus: ", end="", flush=True)
113
-
114
  with torch.no_grad():
115
  outputs = model.generate(
116
  **inputs,
@@ -120,15 +352,87 @@ while True:
120
  do_sample=True,
121
  streamer=streamer,
122
  pad_token_id=tokenizer.eos_token_id,
123
- eos_token_id=eos_token_ids
124
  )
125
-
126
- # Extract response for history
127
  generated_ids = outputs[0][inputs.input_ids.shape[-1]:]
128
- assistant_response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
129
  conversation_history.append({"role": "assistant", "content": assistant_response})
130
  print()
131
-
132
  except KeyboardInterrupt:
133
  print("\n\nGoodbye!")
134
- break
1
  ---
2
  license: mit
3
  language:
4
+ - en
5
+ library_name: transformers
6
+ pipeline_tag: text-generation
7
+ base_model: Qwen/Qwen3-1.7B-Base
8
+ base_model_relation: finetune
9
+ datasets:
10
+ - alibaba-pai/DistilQwen_100k
11
+ metrics:
12
+ - accuracy
13
+ - exact_match
14
+ - code_eval
15
  tags:
16
+ - qwen3
17
+ - qwen
18
+ - qwen3-1.7b
19
+ - qwen3-8b
20
+ - quintus
21
+ - quintus-1.7b
22
+ - causal-lm
23
+ - text-generation
24
+ - language-model
25
+ - chat
26
+ - assistant
27
+ - compact-llm
28
+ - small-language-model
29
+ - knowledge-distillation
30
+ - online-kd
31
+ - full-vocabulary-kd
32
+ - supervised-fine-tuning
33
+ - sft
34
+ - reasoning
35
+ - code-generation
36
+ - english
37
+ - pytorch
38
+ - transformers
39
+ - vllm
40
+ widget:
41
+ - text: "Explain knowledge distillation in simple terms."
42
+ - text: "Solve this step by step: If a train travels 180 km in 3 hours, what is its average speed?"
43
  ---
44
 
45
+ # Quintus
46
 
47
+ [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1TdMSN5HzD1mToCFVf_qQoj10NGZLy2V0?usp=sharing)
48
+ [![Hugging Face Model](https://img.shields.io/badge/Hugging%20Face-Quintus-ffcc4d?style=flat-square&logo=huggingface&logoColor=yellow)](https://huggingface.co/iamrahulreddy/Quintus)
49
+ [![Docs](https://img.shields.io/badge/Docs-Project%20Guide-0f766e?style=flat-square&logo=googledocs&logoColor=white)](docs/index.md)
50
+ [![Benchmarks](https://img.shields.io/badge/Benchmarks-Scoreboard-2563eb?style=flat-square&logo=speedtest&logoColor=white)](docs/benchmarks.md)
51
+ [![License: MIT](https://img.shields.io/badge/License-MIT-111827?style=flat-square)](LICENSE)
52
+ [![Base Model](https://img.shields.io/badge/Base-Qwen3--1.7B--Base-7c3aed?style=flat-square)](https://huggingface.co/Qwen/Qwen3-1.7B-Base)
53
+ [![Teacher](https://img.shields.io/badge/Teacher-Qwen3--8B-b45309?style=flat-square)](https://huggingface.co/Qwen/Qwen3-8B)
54
 
55
+ **Quintus-1.7B** is a compact English-focused assistant built from
56
+ `Qwen/Qwen3-1.7B-Base`. The project uses **online full-vocabulary knowledge
57
+ distillation** from a `Qwen/Qwen3-8B` teacher, followed by a targeted SFT stage
58
+ for assistant behavior, identity grounding, and generation stability.
59
 
60
+ Final model weights:
61
+ [iamrahulreddy/Quintus](https://huggingface.co/iamrahulreddy/Quintus)
62
 
63
+ ## Core Technical Points
64
 
65
+ - **Dense KD signal:** the final training path streams the teacher's full
66
+ vocabulary distribution live instead of relying on sparse cached top-k logits.
67
+ - **Base-student strategy:** the student starts from `Qwen/Qwen3-1.7B-Base`,
68
+ leaving more room for distillation before assistant-format tuning.
69
+ - **Assistant-only supervision:** prompt text, chat headers, separators, and
70
+ padding are masked out of the supervised target region.
71
+ - **Sequence packing:** deterministic first-fit decreasing packing improves
72
+ useful-token throughput at 4096-token context length.
73
+ - **Public benchmark controls:** raw/chat prompt format, metric extraction,
74
+ generation budget, and artifact hygiene are documented explicitly.
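The sequence-packing point above can be illustrated with a minimal first-fit decreasing sketch. It is not the code in `src/sequence_packing.py`; the function name `pack_first_fit_decreasing`, the tie-breaking rule, and the toy lengths are assumptions made for illustration.

```python
from typing import List


def pack_first_fit_decreasing(lengths: List[int], pack_length: int) -> List[List[int]]:
    """Group sample indices into bins whose total token count is <= pack_length."""
    # Longest samples first; ties break on the index so the result is deterministic.
    order = sorted(range(len(lengths)), key=lambda i: (-lengths[i], i))
    bins: List[List[int]] = []  # sample indices per packed sequence
    free: List[int] = []        # remaining token budget per packed sequence
    for idx in order:
        n = lengths[idx]
        if n > pack_length:
            raise ValueError(f"sample {idx} has {n} tokens, over pack_length={pack_length}")
        for b, room in enumerate(free):
            if n <= room:  # the first bin with enough room takes the sample
                bins[b].append(idx)
                free[b] -= n
                break
        else:  # no existing bin fits, so open a new one
            bins.append([idx])
            free.append(pack_length - n)
    return bins


lengths = [3000, 1000, 2500, 1500, 96, 4000]  # hypothetical token counts
packs = pack_first_fit_decreasing(lengths, pack_length=4096)
fill = sum(lengths) / (len(packs) * 4096)
print(packs, f"fill={fill:.1%}")
```

Padding each of these six samples to 4096 tokens separately would use six sequences; packing them uses three, which is where the useful-token throughput gain comes from.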
75
 
76
+ ## Training Summary
 
77
 
78
+ The release training path is a two-stage pipeline:
79
 
80
+ 1. **Online KD:** train the 1.7B base student against live teacher logits from a
81
+ Qwen3-8B teacher.
82
+ 2. **Targeted SFT:** tune the distilled checkpoint for assistant-style
83
+ interaction, persona consistency, and repetition control.
84
+
85
+ ## Reuse As A KD Framework
86
+
87
+ Quintus is released as a trained 1.7B assistant, but the repository is also a
88
+ reusable reference pipeline for compact-model distillation. The same structure
89
+ can be adapted to other teacher/student pairs with changes to the model IDs,
90
+ tokenizer, dataset source, local paths, sequence length, batch schedule, and
91
+ hardware-specific memory settings in [configs/config.yaml](configs/config.yaml).
92
+
93
+ The reusable pieces are split across the codebase: assistant-only masking,
94
+ sequence packing, online full-vocabulary KD loss, checkpoint/resume metadata,
95
+ validation, provenance checks, SFT, and evaluation. The final pattern is:
96
+
97
+ 1. Distill a smaller base student from a stronger teacher with online KD.
98
+ 2. Apply targeted SFT to recover assistant behavior, formatting, identity, and
99
+ generation stability.
100
+
101
+ ![Quintus Architecture](assets/quintus_architecture.svg)
102
+
103
+ Core KD objective:
104
+
105
+ $$
106
+ \mathcal{L}_{\text{total}}
107
+ = \alpha \mathcal{L}_{\text{CE}}
108
+ + (1 - \alpha)\mathcal{L}_{\text{KD}}
109
+ $$
110
+
111
+ For the final run,
112
+
113
+ $$
114
+ \alpha = 0.3,\quad T = 2.0
115
+ $$
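The objective above does not spell out $\mathcal{L}_{\text{KD}}$. A common temperature-scaled form is sketched here; the $T^2$ factor and the mean over the $N$ supervised positions are assumptions taken from the standard distillation recipe and are not confirmed against `src/losses.py`.

$$
\mathcal{L}_{\text{KD}}
= \frac{T^2}{N} \sum_{t=1}^{N}
\mathrm{KL}\!\left(
\operatorname{softmax}\!\left(\frac{z^{\text{teacher}}_t}{T}\right)
\,\Big\|\,
\operatorname{softmax}\!\left(\frac{z^{\text{student}}_t}{T}\right)
\right)
$$

Here $z_t \in \mathbb{R}^{|V|}$ denotes the full-vocabulary logits at position $t$. With $\alpha = 0.3$, the KD term carries $1 - \alpha = 0.7$ of the total loss weight.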
116
+
117
+ Configuration snapshot:
118
+
119
+ | Setting | Value |
120
+ | :--- | :--- |
121
+ | Teacher | `Qwen/Qwen3-8B` |
122
+ | Student | `Qwen/Qwen3-1.7B-Base` |
123
+ | Tokenizer | `Qwen/Qwen3-1.7B` |
124
+ | Data | ~90K English-only samples from DistilQwen_100k |
125
+ | Max sequence length | 4096 |
126
+ | Epochs | 1 |
127
+ | Learning rate | `5.0e-6` |
128
+ | Weight decay | `0.1` |
129
+ | Warmup ratio | `0.05` |
130
+ | Online KD token chunk | 2048 |
131
+ | Micro batch | 4 |
132
+ | Gradient accumulation | 2 |
133
+ | Sequence packing | enabled, `pack_length = 4096` |
134
+ | Attention | FlashAttention-2 when available |
135
+ | Liger kernels | enabled for compatible Qwen-family ops |
136
+ | Optimizer | fused AdamW |
137
+ | `torch.compile` | disabled |
138
+ | Gradient checkpointing | disabled |
139
+ | Seed | 25 |
140
+
141
+ > [!NOTE]
142
+ > FlashAttention-2, Liger kernels, and fused AdamW are acceleration paths. Keep
143
+ > the baseline load path compatible with standard Transformers and vLLM APIs
144
+ > before publishing checkpoints. `torch.compile` stayed disabled because this
145
+ > KD shape showed high Inductor memory overhead, dynamic-shape graph breaks,
146
+ > recompile overhead, and checkpoint portability risk from `_orig_mod.` state
147
+ > dict prefixes when compiled modules are not unwrapped before saving.
148
+
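The `_orig_mod.` risk mentioned in the note can be handled at save or load time by renaming keys. The helper below is a hypothetical sketch, not part of this repository: `strip_compile_prefix` is an assumed name, and plain integers stand in for tensors so the example runs without PyTorch.

```python
from typing import Any, Dict

COMPILE_PREFIX = "_orig_mod."  # prefix torch.compile adds to wrapped-module keys


def strip_compile_prefix(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of state_dict with a leading `_orig_mod.` removed from keys."""
    cleaned: Dict[str, Any] = {}
    for key, value in state_dict.items():
        new_key = key[len(COMPILE_PREFIX):] if key.startswith(COMPILE_PREFIX) else key
        if new_key in cleaned:
            raise KeyError(f"duplicate key after stripping prefix: {new_key}")
        cleaned[new_key] = value
    return cleaned


# Toy state dict; real values would be tensors.
compiled = {"_orig_mod.model.embed_tokens.weight": 1, "_orig_mod.lm_head.weight": 2}
print(sorted(strip_compile_prefix(compiled)))
```

Keeping `torch.compile` disabled, as the final run did, avoids the need for this step entirely.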
149
+ > [!TIP]
150
+ > The B200-oriented defaults are conservative for the 8B teacher to 1.7B
151
+ > student workload. Smaller teacher/student pairs may tolerate larger
152
+ > micro-batches, but full-vocabulary KD scales sharply with vocabulary width.
153
+
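A back-of-the-envelope estimate shows why vocabulary width dominates here. The sketch assumes one dense float32 tensor of shape `[tokens, vocab]` and uses the 151,665 vocabulary figure quoted in this README; actual memory use depends on dtype, how many such tensors are live at once, and the loss implementation.

```python
VOCAB_SIZE = 151_665  # vocabulary figure quoted in this README
BYTES_FP32 = 4        # assumption: the workspace is held in float32


def logits_gib(num_tokens: int, vocab_size: int = VOCAB_SIZE, bytes_per_el: int = BYTES_FP32) -> float:
    """Size in GiB of one dense [num_tokens, vocab_size] tensor."""
    return num_tokens * vocab_size * bytes_per_el / 2**30


full_micro_batch = logits_gib(4 * 4096)  # micro batch 4 x pack_length 4096
one_chunk = logits_gib(2048)             # online KD token chunk of 2048
print(f"full micro-batch: {full_micro_batch:.2f} GiB, one chunk: {one_chunk:.2f} GiB")
```

Under these assumptions a single full micro-batch tensor is about 9.26 GiB, while a 2048-token chunk is one eighth of that.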
154
+ The editable run configuration lives in [configs/config.yaml](configs/config.yaml).
155
+ Paths and Hub destinations are left as placeholders so each runner can set local
156
+ directories and repository names directly.
157
+
158
+ ## Why Online KD Replaced Offline Top-K KD
159
+
160
+ Earlier experiments cached only the teacher's top-k logits. That made storage
161
+ smaller, but with a Qwen vocabulary around 151K tokens, $k = 8$ exposes only:
162
+
163
+ $$
164
+ \frac{k}{|V|}
165
+ = \frac{8}{151{,}665}
166
+ \approx 5.3 \times 10^{-5}
167
+ = 0.0053\%
168
+ $$
169
+
170
+ of the vocabulary support at each position. The sparse signal could perturb the
171
+ student, but it did not consistently transfer deeper reasoning behavior.
172
+
173
+ The final online path keeps the teacher and student in memory together and
174
+ computes KL divergence against the teacher's full-vocabulary distribution. Token
175
+ chunking keeps that dense objective feasible without materializing a single
176
+ large KL workspace.
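The chunking idea can be shown with a dependency-free sketch. It is not the implementation in `src/losses.py`: the function names, the $T^2$ scaling, and the toy four-token vocabulary are assumptions. The point is only that accumulating per-token KL chunk by chunk gives the same loss as one pass over all positions.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float], temperature: float) -> List[float]:
    """Numerically stable log-softmax of logits / temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_z for s in scaled]


def token_kl(teacher_logits: Sequence[float], student_logits: Sequence[float], temperature: float) -> float:
    """KL(teacher || student) over the full vocabulary at one position."""
    t = log_softmax(teacher_logits, temperature)
    s = log_softmax(student_logits, temperature)
    return sum(math.exp(tl) * (tl - sl) for tl, sl in zip(t, s))


def chunked_kd_loss(teacher, student, temperature: float = 2.0, chunk_size: int = 2) -> float:
    """Mean per-token KL, accumulated one chunk at a time, scaled by T^2 (assumed)."""
    total = 0.0
    for start in range(0, len(teacher), chunk_size):
        t_chunk = teacher[start:start + chunk_size]
        s_chunk = student[start:start + chunk_size]
        # Only this chunk's distributions need to be live at once.
        total += sum(token_kl(t, s, temperature) for t, s in zip(t_chunk, s_chunk))
    return temperature ** 2 * total / len(teacher)


# Toy logits: 3 positions, vocabulary of 4.
teacher = [[2.0, 0.5, -1.0, 0.0], [0.1, 0.2, 3.0, -2.0], [1.0, 1.0, 1.0, 1.0]]
student = [[1.5, 0.0, -0.5, 0.2], [0.0, 0.0, 2.0, -1.0], [0.5, 1.5, 1.0, 0.0]]
chunked = chunked_kd_loss(teacher, student, chunk_size=2)
unchunked = chunked_kd_loss(teacher, student, chunk_size=len(teacher))
print(f"chunked={chunked:.6f} unchunked={unchunked:.6f}")
```

In a tensor implementation the same loop would slice the token dimension, so the peak workspace scales with the chunk size rather than with the full packed sequence.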
177
 
178
  ## Benchmark Scoreboard
179
 
180
+ The final public scoreboard compares `Qwen/Qwen3-1.7B-Base`,
181
+ `Qwen/Qwen3-1.7B-Instruct`, and Quintus-1.7B.
182
 
183
+ ![Model Evaluation Scoreboard](assets/benchmark_scoreboard.png)
184
 
185
+ The strongest signal is the reasoning crossover: Quintus beats both the base
186
+ and official 1.7B instruct model on GSM8K, ARC-Challenge, and WinoGrande while
187
+ remaining at the same parameter scale.
188
 
189
+ See [docs/benchmarks.md](docs/benchmarks.md) for the numeric table and
190
+ interpretation. See
191
+ [docs/evaluation_methodology.md](docs/evaluation_methodology.md) for benchmark
192
+ controls.
193
 
194
+ ## Evaluation Notes
195
+
196
+ Evaluation uses a mixture of EvalPlus and `lm-evaluation-harness`/vLLM style
197
+ benchmarks. The repository keeps evaluation methodology separate because prompt
198
+ format can change the result:
199
+
200
+ - Raw completion comparisons are used for base capability.
201
+ - Chat-template comparisons are used for assistant-format behavior.
202
+ - Log-likelihood tasks such as ARC-Challenge and PIQA should usually stay raw.
203
+ - GSM8K can differ between strict `####` parsing and flexible number
204
+ extraction.
205
+ - Metric extraction must ignore `stderr`, aliases, and wrong filter keys.
206
+ - Runtime versions, checkpoint identity, generation budget, and stale output
207
+ cleanup are part of the evaluation contract.
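The GSM8K point in the list above is easiest to see with two toy parsers. These regular expressions are illustrative assumptions, not the filters shipped with `lm-evaluation-harness`; the example answer reuses the train question from the model card widget.

```python
import re
from typing import Optional

STRICT_RE = re.compile(r"####\s*(-?[\d,]+(?:\.\d+)?)")  # number right after `####`
NUMBER_RE = re.compile(r"-?\d[\d,]*(?:\.\d+)?")          # any number in the text


def strict_answer(text: str) -> Optional[str]:
    """Accept only a number that follows the `####` marker."""
    match = STRICT_RE.search(text)
    return match.group(1).replace(",", "") if match else None


def flexible_answer(text: str) -> Optional[str]:
    """Fall back to the last number that appears anywhere in the text."""
    numbers = NUMBER_RE.findall(text)
    return numbers[-1].replace(",", "") if numbers else None


with_marker = "180 / 3 = 60\n#### 60"
without_marker = "The average speed is 180 / 3 = 60 km/h."
print(strict_answer(with_marker), flexible_answer(with_marker))
print(strict_answer(without_marker), flexible_answer(without_marker))
```

A chat-tuned model that answers correctly in prose but omits `####` scores zero under the strict parser and full credit under the flexible one, so the two metrics should not be mixed in one table column.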
208
+
209
+ The active benchmark runner is [sft/evaluate.py](sft/evaluate.py). It covers
210
+ EvalPlus code tasks and `lm-evaluation-harness`/vLLM tasks, including GSM8K
211
+ 10-shot evaluation with an extended generation budget.
212
+
213
+ ## Repository Map
214
+
215
+ ```text
216
+ configs/       Public run profile and DeepSpeed ZeRO-2 template.
217
+ src/           Data prep, online KD, losses, packing, checkpoints, provenance.
218
+ sft/           Post-KD SFT, local chat, and consolidated evaluation runner.
219
+ docs/          Public architecture, training, evaluation, and release notes.
220
+ weight_audit/  Checkpoint structure and weight-divergence audit material.
221
+ ```
222
+
223
+ Key files:
224
+
225
+ - [src/train.py](src/train.py): SFT, offline KD compatibility, and final
226
+ `online_kd` training entry point.
227
+ - [src/download.py](src/download.py): model setup, dataset loading, schema
228
+ normalization, tokenization, and assistant-only loss masks.
229
+ - [src/losses.py](src/losses.py): CE/KD objective, including online full-vocab
230
+ KD token chunking.
231
+ - [src/sequence_packing.py](src/sequence_packing.py): deterministic first-fit
232
+ decreasing sequence packing.
233
+ - [src/checkpoints.py](src/checkpoints.py): checkpoint save/resume metadata and
234
+ packing compatibility checks.
235
+ - [src/provenance.py](src/provenance.py): tokenizer/model/data contract checks.
236
+ - [sft/train_sft.py](sft/train_sft.py): post-KD supervised fine-tuning.
237
+ - [sft/evaluate.py](sft/evaluate.py): EvalPlus and
238
+ `lm-evaluation-harness`/vLLM benchmark runner.
239
+ - [sft/chat.py](sft/chat.py): local interactive chat wrapper.
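The assistant-only loss masks mentioned for `src/download.py` can be sketched as follows. This is a minimal illustration under assumed names (`build_labels`, `assistant_spans`) and toy token ids; the repository's own span detection from the chat template is not shown here.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy skips by default


def build_labels(input_ids: List[int], assistant_spans: List[Tuple[int, int]]) -> List[int]:
    """Copy input_ids into labels only inside [start, end) assistant spans."""
    labels = [IGNORE_INDEX] * len(input_ids)
    for start, end in assistant_spans:
        if not 0 <= start <= end <= len(input_ids):
            raise ValueError(f"span {(start, end)} is outside the sequence")
        labels[start:end] = input_ids[start:end]
    return labels


# Toy ids: 3 prompt/header tokens, 4 assistant tokens, 2 padding tokens.
input_ids = [11, 12, 13, 21, 22, 23, 24, 0, 0]
labels = build_labels(input_ids, assistant_spans=[(3, 7)])
print(labels)
```

Prompt text, chat headers, and padding all receive `IGNORE_INDEX`, so only the assistant tokens contribute to the supervised loss.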
240
+
241
+ ## Commands
242
+
243
+ Install the base dependencies:
244
+
245
+ ```bash
246
+ pip install -r requirements.txt
247
+ ```
248
+
249
+ For training and benchmark runs, install the matching extras:
250
 
251
+ ```bash
252
+ pip install -r requirements-train.txt
253
+ pip install -r requirements-eval.txt
254
+ ```
255
+
256
+ Inspect or prepare data/model assets:
257
+
258
+ ```bash
259
+ python -m src.download --help
260
+ ```
261
+
262
+ Run the final KD path after editing [configs/config.yaml](configs/config.yaml)
263
+ for local paths and hardware:
264
+
265
+ ```bash
266
+ python -m src.train --phase online_kd
267
+ ```
268
+
269
+ Hub checkpoint uploads are off by default for local runs. Pass
270
+ `--upload_last_checkpoint` or the step/epoch upload flags only after setting the
271
+ target repository and `HF_TOKEN`.
272
+
273
+ Run the consolidated benchmark suite:
274
+
275
+ ```bash
276
+ python sft/evaluate.py
277
+ ```
278
+
279
+ Start local chat with a downloaded or local checkpoint:
280
+
281
+ ```bash
282
+ python sft/chat.py --model_path path/to/quintus/checkpoint
283
+ ```
284
+
285
+ ## Interactive Chat
286
+
287
+ ```python
288
  import torch
289
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
290
 
 
291
  PUBLIC_REPO_ID = "iamrahulreddy/Quintus"
292
 
293
  print(f"Loading Quintus from {PUBLIC_REPO_ID}...")
294
  tokenizer = AutoTokenizer.from_pretrained(PUBLIC_REPO_ID, trust_remote_code=True)
295
  model = AutoModelForCausalLM.from_pretrained(
296
+ PUBLIC_REPO_ID,
297
+ device_map="auto",
298
+ dtype=torch.float16,
299
+ trust_remote_code=True,
300
  )
301
 
 
302
  stop_tokens = ["<|endoftext|>", "<|im_end|>"]
303
  eos_token_ids = [tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else []
304
  for token in stop_tokens:
305
+ token_id = tokenizer.convert_tokens_to_ids(token)
306
+ if token_id is not None and token_id not in eos_token_ids:
307
+ eos_token_ids.append(token_id)
308
 
309
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
310
 
311
  conversation_history = [
312
+ {
313
+ "role": "system",
314
+ "content": (
315
+ "You are Quintus, a highly capable AI assistant created by "
316
+ "Muskula Rahul. You are helpful, precise, and logically sound."
317
+ ),
318
+ }
319
  ]
320
 
321
+ print()
322
+ print("Quintus Chat (type 'quit' to exit)")
323
+ print()
324
 
325
  while True:
326
  try:
 
332
  continue
333
 
334
  conversation_history.append({"role": "user", "content": user_input})
335
+
336
  prompt = tokenizer.apply_chat_template(
337
+ conversation_history,
338
+ tokenize=False,
339
+ add_generation_prompt=True,
340
  )
341
+
342
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
343
+
344
  print("Quintus: ", end="", flush=True)
345
+
346
  with torch.no_grad():
347
  outputs = model.generate(
348
  **inputs,
 
352
  do_sample=True,
353
  streamer=streamer,
354
  pad_token_id=tokenizer.eos_token_id,
355
+ eos_token_id=eos_token_ids,
356
  )
357
+
 
358
  generated_ids = outputs[0][inputs.input_ids.shape[-1]:]
359
+ assistant_response = tokenizer.decode(
360
+ generated_ids,
361
+ skip_special_tokens=True,
362
+ ).strip()
363
  conversation_history.append({"role": "assistant", "content": assistant_response})
364
  print()
365
+
366
  except KeyboardInterrupt:
367
  print("\n\nGoodbye!")
368
+ break
369
+ ```
370
+
371
+ ## Documentation
372
+
373
+ - [Documentation Index](docs/index.md): recommended public reading order.
374
+ - [Architecture](docs/architecture.md): end-to-end data flow, modules, and
375
+ training phases.
376
+ - [Experiment Timeline](docs/experiment_timeline.md): why the project moved
377
+ from offline top-k KD to online full-vocabulary KD.
378
+ - [Training Playbook](docs/training_playbook.md): memory rules, packing,
379
+ kernels, checkpointing, and B200-oriented guidance.
380
+ - [Pipeline Hardening](docs/pipeline_hardening.md): silent-failure classes,
381
+ artifact contracts, and safety checks.
382
+ - [Evaluation Methodology](docs/evaluation_methodology.md): raw/chat controls,
383
+ parser traps, metric extraction, and qualitative evaluation rules.
384
+ - [Engineering Insights](docs/engineering_insights.md): condensed lessons and
385
+ design decisions.
386
+ - [Benchmarks](docs/benchmarks.md): verified scoreboard and interpretation.
387
+ - [Weight Audit](docs/weight_audit.md): structural checkpoint sanity checks and
388
+ weight-divergence summary.
389
+ - [Hugging Face Model Card](docs/huggingface_model_card.md): release-page
390
+ copy for the public model card.
391
+
392
+ ## Limitations
393
+
394
+ - Quintus is still a 1.7B model and inherits compact-model capacity limits.
395
+ - Factual answers can be confidently wrong and should be verified.
396
+ - Code generation may still contradict stated complexity or edge-case
397
+ requirements.
398
+ - Raw and chat-template results are not interchangeable.
399
+ - Additional preference tuning or DPO would likely improve calibration, refusal
400
+ behavior, and open-ended assistant polish.
401
+
402
+ ## Credits
403
+
404
+ Quintus builds on open model, dataset, and tooling work from the broader LLM
405
+ community:
406
+
407
+ - [Qwen Team](https://qwenlm.github.io/) and the
408
+ [Qwen Hugging Face organization](https://huggingface.co/Qwen) for the Qwen3
409
+ model family.
410
+ - [`Qwen/Qwen3-8B`](https://huggingface.co/Qwen/Qwen3-8B), used as the
411
+ distillation teacher.
412
+ - [`Qwen/Qwen3-1.7B-Base`](https://huggingface.co/Qwen/Qwen3-1.7B-Base), used
413
+ as the base student checkpoint.
414
+ - [`Qwen/Qwen3-1.7B`](https://huggingface.co/Qwen/Qwen3-1.7B), used for the
415
+ tokenizer and chat-template contract.
416
+ - [Alibaba PAI](https://huggingface.co/alibaba-pai) for the
417
+ [`DistilQwen_100k`](https://huggingface.co/datasets/alibaba-pai/DistilQwen_100k)
418
+ dataset used as the primary instruction source after filtering.
419
+ - [Hugging Face Transformers](https://github.com/huggingface/transformers) for
420
+ model loading, tokenization, and generation APIs.
421
+ - [vLLM](https://github.com/vllm-project/vllm),
422
+ [EvalPlus](https://github.com/evalplus/evalplus), and
423
+ [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)
424
+ for evaluation infrastructure.
425
+ - [FlashAttention](https://github.com/Dao-AILab/flash-attention) and
426
+ [Liger Kernel](https://github.com/linkedin/Liger-Kernel) for performance
427
+ kernels used or validated during training.
428
+
429
+ ## License And Author
430
+
431
+ This software is distributed under the MIT License. Refer to the
432
+ [LICENSE](LICENSE) file for full text.
433
+
434
+ Author: Muskula Rahul - [@iamrahulreddy](https://github.com/iamrahulreddy)
435
+
436
+ ## Citation
437
+
438
+ If this model, codebase, or training pipeline is useful in your work, please cite this repository and acknowledge the upstream Qwen3 models.
assets/benchmark_scoreboard.png ADDED

Git LFS Details

  • SHA256: 2070f6f33338dab31006a253b2eb4a3f9c1655490c444af5daa7c2cb07bb9b15
  • Pointer size: 131 Bytes
  • Size of remote file: 131 kB
assets/offline_vs_online_kd.svg ADDED
assets/pipeline_hardening_flow.svg ADDED
assets/quintus_architecture.svg ADDED
configs/__init__.py ADDED
@@ -0,0 +1,186 @@
+ from __future__ import annotations
+
+ import logging
+ import os
+ import sys
+ import time
+ from datetime import timezone, timedelta
+ from pathlib import Path
+ from zoneinfo import ZoneInfo
+
+ from omegaconf import OmegaConf
+
+ _THIS_DIR = Path(__file__).resolve().parent
+ _YAML_PATH = _THIS_DIR / "config.yaml"
+
+ def _load_cfg():
+     return OmegaConf.load(_YAML_PATH)
+
+ cfg = _load_cfg()
+
+ _LOG_TZ_NAME = os.environ.get("QUINTUS_LOG_TZ", "Asia/Kolkata")
+ try:
+     _LOG_TZ = ZoneInfo(_LOG_TZ_NAME)
+ except Exception:
+     _LOG_TZ = timezone(timedelta(hours=5, minutes=30))
+     _LOG_TZ_NAME = "Asia/Kolkata"
+
+ os.environ["TZ"] = _LOG_TZ_NAME
+ if hasattr(time, "tzset"):
+     time.tzset()
+ _LOG_TZ_LABEL = "IST" if _LOG_TZ_NAME == "Asia/Kolkata" else _LOG_TZ_NAME
+
+ def _read_bool_env(name: str) -> bool | None:
+     raw = os.environ.get(name)
+     if raw is None:
+         return None
+
+     normalised = raw.strip().lower()
+     if normalised in {"1", "true", "yes", "on"}:
+         return True
+     if normalised in {"0", "false", "no", "off"}:
+         return False
+     raise ValueError(
+         f"Invalid boolean value for {name}: {raw!r}. "
+         "Use 1/0, true/false, yes/no, or on/off."
+     )
+
+ # Environment variable overrides used by the wrapper.
+ if os.environ.get("QUINTUS_TEACHER_MODEL"):
+     cfg.model.teacher = os.environ["QUINTUS_TEACHER_MODEL"]
+ if os.environ.get("QUINTUS_TEACHER_REVISION"):
+     cfg.model.teacher_revision = os.environ["QUINTUS_TEACHER_REVISION"]
+ if os.environ.get("QUINTUS_STUDENT_MODEL"):
+     cfg.model.student = os.environ["QUINTUS_STUDENT_MODEL"]
+ if os.environ.get("QUINTUS_STUDENT_REVISION"):
+     cfg.model.student_revision = os.environ["QUINTUS_STUDENT_REVISION"]
+ if os.environ.get("QUINTUS_TOKENIZER_MODEL"):
+     cfg.model.tokenizer = os.environ["QUINTUS_TOKENIZER_MODEL"]
+ if os.environ.get("QUINTUS_TOKENIZER_REVISION"):
+     cfg.model.tokenizer_revision = os.environ["QUINTUS_TOKENIZER_REVISION"]
+ if os.environ.get("QUINTUS_STUDENT_DIR"):
+     cfg.paths.student_dir = os.environ["QUINTUS_STUDENT_DIR"]
+ if os.environ.get("QUINTUS_TOKENIZER_DIR"):
+     cfg.paths.tokenizer_dir = os.environ["QUINTUS_TOKENIZER_DIR"]
+ if os.environ.get("NUM_SAMPLES"):
+     cfg.data.num_samples = int(os.environ["NUM_SAMPLES"])
+ if os.environ.get("TRAIN_NUM_EPOCHS"):
+     cfg.training.num_epochs = int(os.environ["TRAIN_NUM_EPOCHS"])
+ if os.environ.get("TRAIN_LEARNING_RATE"):
+     cfg.training.learning_rate = float(os.environ["TRAIN_LEARNING_RATE"])
+ if os.environ.get("TRAIN_ALPHA"):
+     cfg.training.alpha = float(os.environ["TRAIN_ALPHA"])
+ if os.environ.get("TRAIN_TEMPERATURE"):
+     cfg.training.temperature = float(os.environ["TRAIN_TEMPERATURE"])
+ if os.environ.get("TRAIN_TOP_K"):
+     cfg.training.top_k = int(os.environ["TRAIN_TOP_K"])
+ if os.environ.get("QUINTUS_ONLINE_KD_TOKEN_CHUNK_SIZE"):
+     cfg.training.online_kd_token_chunk_size = int(os.environ["QUINTUS_ONLINE_KD_TOKEN_CHUNK_SIZE"])
+ if os.environ.get("TRAIN_MICRO_BATCH_SIZE"):
+     cfg.training.micro_batch_size = int(os.environ["TRAIN_MICRO_BATCH_SIZE"])
+ if os.environ.get("TRAIN_GRAD_ACCUM_STEPS"):
+     cfg.training.grad_accum_steps = int(os.environ["TRAIN_GRAD_ACCUM_STEPS"])
+ if os.environ.get("TRAIN_DATALOADER_WORKERS"):
+     cfg.training.dataloader_workers = int(os.environ["TRAIN_DATALOADER_WORKERS"])
+ if os.environ.get("TRAIN_PREFETCH_FACTOR"):
+     cfg.training.prefetch_factor = int(os.environ["TRAIN_PREFETCH_FACTOR"])
+ sequence_packing_override = _read_bool_env("QUINTUS_SEQUENCE_PACKING")
+ if sequence_packing_override is not None:
+     cfg.training.sequence_packing.enabled = sequence_packing_override
+ if os.environ.get("QUINTUS_PACK_LENGTH"):
+     cfg.training.sequence_packing.pack_length = int(os.environ["QUINTUS_PACK_LENGTH"])
+ compile_override = _read_bool_env("QUINTUS_COMPILE_MODEL")
+ if compile_override is not None:
+     cfg.training.compile_model = compile_override
+ fused_adamw_override = _read_bool_env("TRAIN_FUSED_ADAMW")
+ if fused_adamw_override is not None:
+     cfg.training.fused_adamw = fused_adamw_override
+ if os.environ.get("QUINTUS_DISTILLED_DIR"):
+     cfg.paths.distilled_dir = os.environ["QUINTUS_DISTILLED_DIR"]
+ if os.environ.get("DATA_STREAM_SHUFFLE_BUFFER_SIZE"):
+     cfg.data.stream_shuffle_buffer_size = int(os.environ["DATA_STREAM_SHUFFLE_BUFFER_SIZE"])
+ if os.environ.get("DATA_STREAM_SHUFFLE_SEED"):
+     cfg.data.stream_shuffle_seed = int(os.environ["DATA_STREAM_SHUFFLE_SEED"])
+ remote_code_override = _read_bool_env("QUINTUS_ALLOW_REMOTE_CODE")
+ if remote_code_override is not None:
+     cfg.model.allow_remote_code = remote_code_override
+
+ class _TagFormatter(logging.Formatter):
+     def __init__(self, tag: str, fmt: str, datefmt: str | None = None):
+         super().__init__(fmt=fmt, datefmt=datefmt)
+         self.tag = tag
+
+     def formatTime(self, record: logging.LogRecord, datefmt: str | None = None) -> str:
+         dt = datetime_from_timestamp(record.created)
+         if datefmt:
+             return dt.strftime(datefmt)
+         return dt.isoformat(timespec="seconds")
+
+     def format(self, record: logging.LogRecord) -> str:
+         record.tag = self.tag  # type: ignore[attr-defined]
+         return super().format(record)
+
+ def datetime_from_timestamp(timestamp: float):
+     from datetime import datetime
+
+     return datetime.fromtimestamp(timestamp, tz=_LOG_TZ)
+
+ def setup_logger(module_tag: str, rank: int = -1) -> logging.Logger:
+     name = f"quintus.{module_tag}"
+     logger = logging.getLogger(name)
+
+     if logger.handlers:
+         return logger
+
+     logger.setLevel(logging.DEBUG)
+     logger.propagate = False
+
+     # Suppress duplicate output from non-primary ranks.
+     if rank not in (-1, 0):
+         logger.addHandler(logging.NullHandler())
+         return logger
+
+     # Plain text file handler.
+     file_fmt = _TagFormatter(
+         tag=module_tag,
+         fmt=f"[%(asctime)s {_LOG_TZ_LABEL}] [%(levelname)-5s] [%(tag)-8s] %(message)s",
+         datefmt="%Y-%m-%d %H:%M:%S",
+     )
+     log_dir = os.path.dirname(cfg.paths.log_file)
+     if log_dir:
+         os.makedirs(log_dir, exist_ok=True)
+     file_handler = logging.FileHandler(cfg.paths.log_file, mode="a", encoding="utf-8")
+     file_handler.setLevel(logging.DEBUG)
+     file_handler.setFormatter(file_fmt)
+     logger.addHandler(file_handler)
+
+     # Plain text console handler. Keep the runtime logs stable across terminals,
+     # notebooks and log files.
+     console_handler = logging.StreamHandler(sys.stdout)
+     console_handler.setLevel(logging.INFO)
+     console_handler.setFormatter(file_fmt)
+     logger.addHandler(console_handler)
+
+     return logger
+
+ def emit_log_spacing(logger: logging.Logger, count: int = 2) -> None:
+     if count <= 0:
+         return
+
+     blank_block = "\n" * count
+     for handler in logger.handlers:
+         if isinstance(handler, logging.NullHandler):
+             continue
+
+         stream = getattr(handler, "stream", None)
+         if stream is not None and hasattr(stream, "write"):
+             stream.write(blank_block)
+             flush = getattr(stream, "flush", None)
+             if callable(flush):
+                 flush()
+             continue
+
+         console = getattr(handler, "console", None)
+         if console is not None:
+             console.print(blank_block, end="")
+
configs/config.yaml ADDED
@@ -0,0 +1,77 @@
1
+ # Quintus Distillation Pipeline
2
+ # Run profile: online full-vocabulary KD, 8B teacher -> 1.7B-Base student.
3
+ # Data: ~90K English-only samples from DistilQwen_100k.
4
+
5
+ data:
6
+ dataset_path: "<REDACTED_ON_PURPOSE>"
7
+ num_samples: 90234
8
+ max_seq_len: 4096
9
+ stream_shuffle_buffer_size: 20000
10
+ stream_shuffle_seed: 25
11
+
12
+ model:
13
+ teacher: "Qwen/Qwen3-8B"
14
+ student: "Qwen/Qwen3-1.7B-Base"
15
+
16
+ # The instruct tokenizer carries the chat template used to format the base
17
+ # student into assistant-style training examples.
18
+ tokenizer: "Qwen/Qwen3-1.7B"
19
+
20
+ teacher_revision: "main"
21
+ student_revision: "main"
22
+ tokenizer_revision: "main"
23
+ allow_remote_code: false
24
+
25
+ training:
26
+ # Schedule
27
+ num_epochs: 1
28
+ validation_ratio: 0.02
29
+ split_seed: 25
30
+
31
+ # Optimizer
32
+ learning_rate: 5.0e-6
33
+ weight_decay: 0.1
34
+ warmup_ratio: 0.05
35
+
36
+ # Loss mix used by src/losses.py:
37
+ # total = alpha * CE + (1 - alpha) * KD
38
+ alpha: 0.3
39
+ temperature: 2.0
40
+
41
+ # Online KD streams full-vocabulary teacher logits. top_k is retained for
42
+ # offline-KD compatibility/provenance checks.
43
+ top_k: 8
44
+ online_kd_token_chunk_size: 2048
45
+
46
+ # Conservative B200 profile. Effective batch = 4 * 2 = 8.
47
+ # If VRAM headroom is comfortable and Liger is installed, try 8 * 1.
48
+ micro_batch_size: 4
49
+ grad_accum_steps: 2
50
+ gradient_checkpointing: false
51
+ compile_model: false
52
+ fused_adamw: true
53
+
54
+ dataloader_workers: 8
55
+ prefetch_factor: 2
56
+
57
+ sequence_packing:
58
+ enabled: true
59
+ pack_length: 4096
60
+ mask_first_token_after_separator: true
61
+
62
+ hub:
63
+ # Prefer HF_TOKEN or huggingface-cli login for real runs.
64
+ token: null
65
+ username: "<REDACTED_ON_PURPOSE>"
66
+ repo_name: "<REDACTED_ON_PURPOSE>"
67
+
68
+ paths:
69
+ teacher_dir: "<REDACTED_ON_PURPOSE>"
70
+ student_dir: "<REDACTED_ON_PURPOSE>"
71
+ tokenizer_dir: "<REDACTED_ON_PURPOSE>"
72
+ tokenized_dir: "<REDACTED_ON_PURPOSE>"
73
+ logits_dir: "<REDACTED_ON_PURPOSE>"
74
+ distilled_dir: "<REDACTED_ON_PURPOSE>"
75
+ log_file: "<REDACTED_ON_PURPOSE>"
76
+ system_info: "<REDACTED_ON_PURPOSE>"
77
+ loss_csv: "<REDACTED_ON_PURPOSE>"
configs/ds_zero2.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "zero_optimization": {
3
+ "stage": 2,
4
+ "allgather_partitions": true,
5
+ "allgather_bucket_size": 500000000,
6
+ "reduce_scatter": true,
7
+ "reduce_bucket_size": 500000000,
8
+ "overlap_comm": true,
9
+ "contiguous_gradients": true
10
+ },
11
+ "bf16": {
12
+ "enabled": true
13
+ },
14
+ "gradient_clipping": 1.0,
15
+ "steps_per_print": 50,
16
+ "wall_clock_breakdown": false,
17
+ "comms_logger": {
18
+ "enabled": false
19
+ }
20
+ }
docs/architecture.md ADDED
@@ -0,0 +1,88 @@
1
+ # Architecture
2
+
3
+ Quintus is built as a two-stage model development pipeline:
4
+
5
+ 1. Online full-vocabulary knowledge distillation from the larger Qwen3-8B teacher into the Qwen3-1.7B-Base student.
6
+ 2. Targeted SFT to improve instruction-following behavior, persona consistency, and generation stability.
7
+
8
+ ![Quintus Architecture](../assets/quintus_architecture.svg)
9
+
10
+ ## Core Training Path
11
+
12
+ The main training entry point is `src/train.py`. It supports three phases:
13
+
14
+ - `sft`: Cross-entropy training on assistant response tokens.
15
+ - `kd`: Offline top-k teacher-logit distillation, retained for compatibility and provenance checks.
16
+ - `online_kd`: The final preferred path. Teacher logits are computed live in the same training step as the student forward pass instead of being read from a disk cache.
17
+
18
+ The final KD objective is implemented in `src/losses.py`:
19
+
20
+ $$
21
+ \mathcal{L}_{\text{total}}
22
+ = \alpha \mathcal{L}_{\text{CE}}
23
+ + (1 - \alpha)\mathcal{L}_{\text{KD}}
24
+ $$
25
+
26
+ For the final run, $\alpha = 0.3$ and $T = 2.0$. In this codebase, $\alpha$ is the cross-entropy weight. The complementary weight is assigned to the KD term.
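As a rough illustration of the objective above, the following pure-Python sketch computes the mix for a single token position. The names `softmax` and `kd_mix_loss` are hypothetical and are not taken from `src/losses.py`. The KL direction (teacher to student) and the absence of a $T^2$ rescaling are assumptions, since the text states only the weighting and the temperature.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of floats."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_mix_loss(student_logits, teacher_logits, target_id, alpha=0.3, temperature=2.0):
    """Return (total, ce, kd) for one token position.

    Assumption: KD is KL(teacher || student) on temperature-softened
    distributions, with no extra T^2 factor.
    """
    # Cross-entropy against the hard label uses the unsoftened student distribution.
    ce = -math.log(softmax(student_logits)[target_id])
    p_teacher = softmax(teacher_logits, temperature)
    q_student = softmax(student_logits, temperature)
    kd = sum(p * (math.log(p) - math.log(q))
             for p, q in zip(p_teacher, q_student) if p > 0.0)
    # alpha weights CE; the complementary weight goes to the KD term.
    total = alpha * ce + (1.0 - alpha) * kd
    return total, ce, kd

# Toy three-token vocabulary; the logits are made up for illustration.
total, ce, kd = kd_mix_loss([2.0, 0.5, -1.0], [3.0, 0.0, -2.0], target_id=0)
```

With $\alpha = 0.3$, 70% of the gradient signal in this sketch comes from matching the teacher distribution and 30% from the hard labels.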
27
+
28
+ ## Data Flow
29
+
30
+ `src/download.py` prepares the training data. It handles both pre-tokenized rows and raw instruction data. For raw rows, it normalizes common conversation schemas, applies the tokenizer chat template, and builds an assistant-only `loss_mask`.
31
+
32
+ Important details:
33
+
34
+ - Prompt and formatting tokens are masked out.
35
+ - Assistant response tokens receive loss.
36
+ - Samples longer than `max_seq_len` are rejected rather than silently truncated.
37
+ - The tokenizer contract is later validated to avoid teacher/student vocabulary mismatches.
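The masking rules above can be sketched as follows. This is a minimal illustration under stated assumptions, not the logic of `src/download.py`: the function name `build_loss_mask`, the `(role, token_ids)` segment representation, and the token IDs are all hypothetical.

```python
def build_loss_mask(segments, max_seq_len):
    """Flatten (role, token_ids) segments into input_ids and a loss_mask.

    Only tokens in "assistant" segments get loss_mask = 1. Prompt and
    formatting segments stay in the context with loss_mask = 0.
    """
    input_ids, loss_mask = [], []
    for role, token_ids in segments:
        input_ids.extend(token_ids)
        loss_mask.extend([1 if role == "assistant" else 0] * len(token_ids))
    if len(input_ids) > max_seq_len:
        # Reject instead of truncating, mirroring the rule described above.
        raise ValueError(f"sample has {len(input_ids)} tokens, limit is {max_seq_len}")
    return input_ids, loss_mask

# Token IDs here are made up for illustration.
segments = [
    ("format", [1, 2]),               # chat header tokens
    ("user", [10, 11, 12]),           # prompt: context only
    ("format", [3, 1, 4]),            # end of user turn, start of assistant turn
    ("assistant", [20, 21, 22, 23]),  # supervised targets
    ("format", [3]),                  # end-of-turn marker, masked in this sketch
]
input_ids, loss_mask = build_loss_mask(segments, max_seq_len=4096)
```

Raising on over-length samples keeps a clipped assistant answer from silently becoming a training target.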
38
+
39
+ ## Sequence Packing
40
+
41
+ `src/sequence_packing.py` implements deterministic first-fit decreasing packing. It places multiple shorter samples into fixed-length bins, separated by EOS tokens.
42
+
43
+ Packing properties:
44
+
45
+ - Training split is packed; validation can remain unpacked for interpretability.
46
+ - Bins are fixed at `pack_length = 4096` in the final profile.
47
+ - EOS separators have `loss_mask = 0`.
48
+ - The first token after a separator is optionally masked to avoid cross-sample target leakage.
49
+ - Attention masks are built from the true packed length, not by comparing token IDs against `pad_token_id`.
50
+
51
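+ The attention-mask detail is important because the pad token ID in Qwen tokenizers can coincide with an EOS-like token ID. A mask derived by comparing token IDs against `pad_token_id` would then also hide the real EOS separators inside a packed bin.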
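The packing properties listed above can be sketched in a few lines of pure Python. This is an assumption-laden illustration, not the code in `src/sequence_packing.py`: the function name, the dictionary layout, and the choice to pad with the EOS ID are hypothetical, and the sketch does not address whether attention is blocked across samples inside a bin.

```python
def pack_first_fit_decreasing(samples, pack_length, eos_id, mask_first_after_sep=True):
    """Pack (input_ids, loss_mask) samples into bins of exactly pack_length tokens."""
    # Longest first; ties broken by original index so the result is deterministic.
    order = sorted(range(len(samples)), key=lambda i: (-len(samples[i][0]), i))
    bins = []
    for i in order:
        ids, mask = list(samples[i][0]), list(samples[i][1])
        if len(ids) > pack_length:
            raise ValueError(f"sample {i} is longer than pack_length")
        for b in bins:
            # A non-empty bin needs one extra slot for the EOS separator.
            if len(b["input_ids"]) + 1 + len(ids) <= pack_length:
                if mask_first_after_sep and mask:
                    mask[0] = 0  # avoid predicting this token from the previous sample
                b["input_ids"] += [eos_id] + ids
                b["loss_mask"] += [0] + mask  # the separator never receives loss
                break
        else:
            bins.append({"input_ids": ids, "loss_mask": mask})
    for b in bins:
        # The attention mask comes from the true packed length, not from token IDs.
        true_len = len(b["input_ids"])
        pad = pack_length - true_len
        b["attention_mask"] = [1] * true_len + [0] * pad
        b["input_ids"] += [eos_id] * pad  # pad ID equals EOS here; the mask stays correct
        b["loss_mask"] += [0] * pad
    return bins

# Toy samples with made-up token IDs.
samples = [
    ([5, 6, 7], [0, 1, 1]),
    ([8, 9, 10, 11, 12], [0, 0, 1, 1, 1]),
    ([13, 14], [0, 1]),
]
bins = pack_first_fit_decreasing(samples, pack_length=8, eos_id=0)
```

In the first bin the separator (ID 0) stays attended even though it shares its ID with the padding, which is the case an ID-comparison mask would get wrong.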
52
+
53
+ ## Online KD Memory Strategy
54
+
55
+ Full-vocabulary KD is expensive because both student and teacher produce logits shaped as:
56
+
57
+ $$
58
+ \text{student\_logits},\ \text{teacher\_logits}
59
+ \in \mathbb{R}^{B \times S \times |V|}
60
+ $$
61
+
62
+ The implementation keeps this feasible by chunking along the token dimension with:
63
+
64
+ $$
65
+ C_{\text{KD}} = 2048
66
+ $$
67
+
68
+ Each chunk computes the teacher softmax, student log-softmax, and masked KL contribution, then accumulates the result. This preserves the dense teacher distribution while avoiding a single large KL workspace.
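The chunked accumulation described above can be sketched as follows. This is a pure-Python stand-in for what would be tensor operations in practice, so each "chunk" here is a slice of per-token logit rows rather than a `[chunk, |V|]` slab on the GPU. The function names and the mean-over-masked-tokens reduction are assumptions, not details confirmed by the source.

```python
import math

def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax over a list of floats."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]

def masked_kl_chunked(student_rows, teacher_rows, loss_mask, chunk_size, temperature=2.0):
    """Mean KL(teacher || student) over tokens with loss_mask == 1, chunk by chunk."""
    kl_sum, n_tokens = 0.0, 0
    for start in range(0, len(student_rows), chunk_size):
        end = start + chunk_size
        # Only this slice of token positions is expanded to full-vocabulary probabilities.
        chunk = zip(student_rows[start:end], teacher_rows[start:end], loss_mask[start:end])
        for s_row, t_row, keep in chunk:
            if not keep:
                continue  # masked tokens contribute nothing
            log_q = log_softmax(s_row, temperature)
            log_p = log_softmax(t_row, temperature)
            kl_sum += sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
            n_tokens += 1
    return kl_sum / max(n_tokens, 1)

# Five token positions over a toy three-token vocabulary (made-up logits).
student = [[0.1 * i, 0.2, -0.3 * i] for i in range(5)]
teacher = [[0.3 * i, -0.1, 0.5] for i in range(5)]
mask = [0, 1, 1, 0, 1]
kl_small_chunks = masked_kl_chunked(student, teacher, mask, chunk_size=2)
kl_one_chunk = masked_kl_chunked(student, teacher, mask, chunk_size=5)
```

Because the KL sum is additive over token positions, the chunk size changes peak memory but not the resulting loss value.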
69
+
70
+ ## Validation, Provenance, And Safety Checks
71
+
72
+ Several modules exist to prevent silent training corruption:
73
+
74
+ - `src/provenance.py`: Validates tokenizer contracts, vocab sizes, revisions, and teacher-logit metadata.
75
+ - `src/kd_contracts.py`: Builds deterministic tokenizer fingerprints.
76
+ - `src/training_schedule.py`: Aligns train/validation splits with batch and gradient-accumulation constraints.
77
+ - `src/checkpoints.py`: Saves model, tokenizer, scheduler, trainer state, and packing metadata; validates resume compatibility.
78
+ - `src/transformers_compat.py`: Resolves attention backend and formats model-loading errors.
79
+
80
+ ## SFT Layer
81
+
82
+ The `sft/` directory contains the post-KD alignment layer:
83
+
84
+ - `sft/train_sft.py`: SFT training with optional sequence packing, LoRA/QLoRA paths, and built-in spot evaluations.
85
+ - `sft/evaluate.py`: EvalPlus and lm-evaluation-harness orchestration.
86
+ - `sft/chat.py`: Local interactive chat wrapper using the tokenizer chat template.
87
+
88
+ This stage is intentionally separate from KD. KD transfers the teacher's probability structure; SFT teaches the model how to expose that capability in the intended assistant format.
docs/benchmarks.md ADDED
@@ -0,0 +1,58 @@
1
+ # Benchmarks
2
+
3
+ The release scoreboard compares Qwen3-1.7B-Base, Qwen3-1.7B-Instruct, and Quintus-1.7B. Evaluations use a mixture of EvalPlus and lm-evaluation-harness style benchmarks, with greedy or deterministic settings where applicable.
4
+
5
+ For the detailed benchmark-control rules, see [Evaluation Methodology](evaluation_methodology.md).
6
+
7
+ ## Final Scoreboard
8
+
9
+ | Benchmark | Qwen3-1.7B-Base | Qwen3-1.7B-Instruct | Quintus-1.7B |
10
+ | :--- | :---: | :---: | :---: |
11
+ | HumanEval pass@1 | 67.1% | 70.7% | 67.7% |
12
+ | MBPP pass@1 | 67.2% | 58.2% | 64.8% |
13
+ | GSM8K, 10-shot flexible | 69.98% | 69.75% | 74.30% |
14
+ | ARC-Challenge acc_norm | 55.72% | 52.99% | 58.36% |
15
+ | WinoGrande, 5-shot | 65.67% | 61.01% | 66.38% |
16
+ | PIQA acc_norm | 75.63% | 72.09% | 75.57% |
17
+
18
+ ## Interpretation
19
+
20
+ The strongest result is the reasoning crossover: Quintus beats both the base and the official 1.7B instruct model on GSM8K, ARC-Challenge, and WinoGrande, despite remaining at the same parameter scale.
21
+
22
+ The coding picture is mixed but useful:
23
+
24
+ - HumanEval remains 3.0 points below Qwen3-1.7B-Instruct (67.7% vs. 70.7%).
25
+ - MBPP is 6.6 points above Qwen3-1.7B-Instruct (64.8% vs. 58.2%), though still 2.4 points below the base model (67.2%).
26
+
27
+ This suggests the model gained useful instruction-following and reasoning behavior without fully matching the instruct model on HumanEval or the base model on MBPP.
28
+
29
+ ## What The Benchmarks Support
30
+
31
+ These results support four claims:
32
+
33
+ 1. Online KD transferred reasoning capability into a compact student.
34
+ 2. The final model did not merely memorize assistant formatting; it improved several reasoning and commonsense metrics.
35
+ 3. SFT helped expose the distilled capability in an assistant setting.
36
+ 4. The model still has capacity limits typical of the 1.7B scale, especially on code execution reliability and long multi-step algorithm generation.
37
+
38
+ ## Evaluation Caveats
39
+
40
+ Benchmark comparisons are sensitive to prompt format. Raw completion, chat-template generation, and log-likelihood multiple-choice scoring can produce different rankings. For fair interpretation:
41
+
42
+ - Compare raw models against raw models when measuring base reasoning.
43
+ - Compare chat-wrapped models against chat-wrapped models when measuring format alignment.
44
+ - Treat open-ended qualitative prompts as alignment tests, not as a replacement for standardized benchmarks.
45
+
46
+ Important implementation caveats:
47
+
48
+ - GSM8K extraction can differ between strict `####` parsing and flexible number extraction.
49
+ - Multiple-choice log-likelihood tasks can be distorted by chat templates.
50
+ - `acc_norm` is preferred when answer-option length bias can change the ranking.
51
+ - Metric extraction scripts must reject `stderr` and `alias` fields when looking for the actual score.
52
+ - Runtime versions should be recorded with benchmark outputs because harness behavior can change across releases.
53
+
54
+ ## Earlier Development Signals
55
+
56
+ Before the final Qwen3 8B -> 1.7B run, earlier experiments showed that sparse offline top-k KD could not consistently outperform strong baselines. Those runs were useful because they identified the bottleneck: sparse cached teacher logits were not dense enough to transfer deeper reasoning pathways.
57
+
58
+ The final move to online full-vocabulary KD is the key methodological change behind the stronger final results.
docs/engineering_insights.md ADDED
@@ -0,0 +1,152 @@
1
+ # Engineering Insights
2
+
3
+ This project evolved through several failed and successful training designs. The useful lessons are summarized here as public engineering notes.
4
+
5
+ For expanded operational detail, see [Training Playbook](training_playbook.md), [Pipeline Hardening](pipeline_hardening.md), and [Evaluation Methodology](evaluation_methodology.md).
6
+
7
+ ## 1. Sparse Offline KD Hit A Ceiling
8
+
9
+ The earliest distillation path cached only a small top-k slice of teacher logits. That made training cheaper, but it discarded most of the teacher distribution. With a vocabulary of roughly 151K tokens and $k = 8$, the cached support covered only this fraction of the vocabulary entries (a count of entries, not a share of probability mass):
10
+
11
+ $$
12
+ \frac{k}{|V|}
13
+ = \frac{8}{151{,}665}
14
+ \approx 5.3 \times 10^{-5}
15
+ = 0.0053\%
16
+ $$
17
+
18
+ The result was clear: top-k KD could perturb the student, but it did not transfer enough "dark knowledge" to reliably improve reasoning. Different alphas, epochs, and student initializations could not escape this sparse-signal ceiling.
19
+
20
+ The final fix was to use online KD: load teacher and student together, run both forward passes, and compute KL against the teacher's full vocabulary distribution.
21
+
22
+ ## 2. Base Student Was Better Than Fighting An Aligned Space
23
+
24
+ Distilling into an already instruction-tuned student can cause destructive interference. The student's weights already encode one aligned behavior manifold, while the teacher's soft logits pull toward another. Training can look numerically stable while reasoning metrics regress.
25
+
26
+ The final path uses `Qwen/Qwen3-1.7B-Base` as the student. The base model has more plasticity, while the CE term and later SFT stage teach assistant formatting.
27
+
28
+ ## 3. KD And Alignment Are Different Problems
29
+
30
+ Standardized benchmarks showed that KD can improve reasoning and calibration, but open-ended chat quality still needs alignment data.
31
+
32
+ The important diagnosis:
33
+
34
+ - A distillation failure means the student did not absorb the teacher's useful probability structure.
35
+ - An alignment gap means the student has capability, but the generation path is not yet trained to behave like a polished assistant.
36
+
37
+ The project therefore separates the pipeline into KD first, then SFT.
38
+
39
+ ## 4. Assistant-Only Loss Masking Matters
40
+
41
+ A key bug class was assigning loss to chat formatting tokens instead of only assistant response content. If the model is trained to optimize structural tokens too heavily, it can learn formatting before substance.
42
+
43
+ The current tokenization path derives an assistant-only `loss_mask`, so:
44
+
45
+ - User prompts are context, not targets.
46
+ - Chat headers and separators are masked.
47
+ - Assistant response tokens are the only supervised targets.
48
+
49
+ This keeps training focused on semantic outputs rather than wrapper reproduction.
50
+
51
+ ## 5. Sequence Packing Was The Main Throughput Win
52
+
53
+ The dataset contains many sequences shorter than the maximum context length. Dynamic padding wastes a large fraction of compute. First-fit decreasing sequence packing converted that waste into useful tokens.
54
+
55
+ Observed engineering outcome:
56
+
57
+ - Unpacked B200 online KD ran around the low-20K tokens/sec range in earlier probes.
58
+ - Packed B200 online KD reached roughly the mid-40K tokens/sec range after warmup.
59
+ - Packed utilization was close to full 4096-token bins.
60
+
61
+ The final code keeps packing deterministic and stores packing metadata in checkpoints so packed/unpacked resume mismatches fail loudly.
62
+
63
+ ## 6. Full-Vocab KD Needed Token Chunking
64
+
65
+ Online KD preserves the full teacher distribution, but a full KL workspace at Qwen vocabulary scale is expensive to materialize in one piece:
66
+
67
+ $$
68
+ \text{KL workspace} \sim \mathbb{R}^{B \times S \times |V|}
69
+ $$
70
+
71
+ The solution is token-dimension chunking. The current implementation uses:
72
+
73
+ $$
74
+ C_{\text{KD}} = 2048
75
+ $$
76
+
77
+ Larger chunks reduce loop overhead, but increase temporary memory pressure. The selected value is a practical B200-oriented balance for the 8B -> 1.7B workload.
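A back-of-the-envelope calculation makes the trade-off concrete. The sketch below assumes fp32 workspaces, a micro-batch of 4 sequences of 4096 tokens (from `configs/config.yaml`), the 151,665-entry vocabulary quoted earlier in these notes, and chunking over the flattened token dimension. The real implementation may use a different dtype or chunk layout, so treat the numbers as orders of magnitude.

```python
def logits_bytes(num_tokens, vocab_size, bytes_per_value=4):
    """Bytes needed for one dense [num_tokens, vocab_size] tensor."""
    return num_tokens * vocab_size * bytes_per_value

VOCAB = 151_665                 # vocabulary size quoted earlier in these notes
MICRO_BATCH, SEQ_LEN = 4, 4096  # micro_batch_size and pack_length from the config
CHUNK = 2048                    # online_kd_token_chunk_size

full = logits_bytes(MICRO_BATCH * SEQ_LEN, VOCAB)  # one unchunked workspace
chunk = logits_bytes(CHUNK, VOCAB)                 # one chunked workspace

print(f"full : {full / 2**30:.2f} GiB per fp32 tensor")
print(f"chunk: {chunk / 2**30:.2f} GiB per fp32 tensor")
```

Under these assumptions each temporary shrinks by a factor of 8, and a KL computation usually holds several such temporaries at once (teacher probabilities, student log-probabilities, and their product).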
78
+
79
+ ## 7. Shape Churn And Synchronization Can Quietly Drain Throughput
80
+
81
+ Several performance bugs were not correctness bugs:
82
+
83
+ - Dynamic sequence lengths caused allocator churn.
84
+ - Repeated `.item()` calls forced CPU-GPU synchronization.
85
+ - Single-GPU DeepSpeed could add overhead when the model already fit comfortably.
86
+ - `torch.compile` added memory overhead, dynamic-shape graph breaks, recompile overhead, and checkpoint portability risk.
87
+
88
+ The final training loop favors stable shapes, fewer scalar syncs, fused AdamW when available, FlashAttention when available, and Liger kernels where they do not conflict with KD logits.
89
+
90
+ ## 8. Evaluation Requires Controlled Comparisons
91
+
92
+ Raw completion and chat-template evaluation activate different behavior. A base model can perform well in raw mode and poorly under chat markup. A chat-aligned model can underperform on raw continuation-style tasks if the benchmark asks for direct option likelihoods.
93
+
94
+ The project uses both controls:
95
+
96
+ - Raw-to-raw comparisons isolate distilled base capability.
97
+ - Chat-to-chat comparisons estimate template robustness and assistant-format alignment.
98
+
99
+ This distinction avoids blaming KD for failures that belong to alignment or benchmark formatting.
100
+
101
+ ## 9. Post-KD SFT Is Not Optional For Assistant Quality
102
+
103
+ KD transfers probability structure; it does not guarantee careful behavior, refusal policy, calibrated uncertainty, or code reliability. Targeted SFT was added to address:
104
+
105
+ - Confident hallucination in open-ended answers.
106
+ - Persona and identity consistency.
107
+ - Repetition loops.
108
+ - Chat-format stability.
109
+ - Practical assistant presentation.
110
+
111
+ Preference training or DPO would be the natural next layer if the project continues beyond the current release.
112
+
113
+ ## 10. Training Loss Is Not The Release Gate
114
+
115
+ Several development runs looked numerically healthy while downstream benchmarks moved in the wrong direction. That pattern is expected when the training objective is only a proxy for the release objective.
116
+
117
+ Useful release gates:
118
+
119
+ - Standardized benchmarks.
120
+ - Raw and chat controls.
121
+ - Mismatch inspection.
122
+ - Qualitative prompts after benchmark checks.
123
+ - Weight and checkpoint structure audits.
124
+
125
+ Held-out KD validation loss is important, but it cannot prove that the model improved on math, code, multiple-choice reasoning, or assistant behavior.
126
+
127
+ ## 11. Fail-Fast Beats Silent Recovery
128
+
129
+ The pipeline hardened around a simple rule: corrupt artifacts should stop the run.
130
+
131
+ Examples:
132
+
133
+ - Missing teacher-logit shards fail instead of becoming zero tensors.
134
+ - Tokenization with zero usable rows fails immediately.
135
+ - Shard schema mismatches are rejected.
136
+ - Packed/unpacked checkpoint resume mismatches are rejected.
137
+ - Stale evaluation outputs are cleaned before new scores are written.
138
+
139
+ This makes errors louder, but it keeps published numbers trustworthy.
140
+
141
+ ## 12. Public Docs Should Preserve Decisions
142
+
143
+ A release-quality project should expose durable engineering conclusions:
144
+
145
+ - why online KD replaced offline top-k KD,
146
+ - why assistant-only masking matters,
147
+ - why raw/chat evaluation controls are required,
148
+ - why sequence packing changed throughput,
149
+ - why SFT remains necessary after KD,
150
+ - why checkpoint and provenance checks exist.
151
+
152
+ That level of detail is enough for technical readers without turning the documentation into a chronological run journal.
docs/evaluation_methodology.md ADDED
@@ -0,0 +1,234 @@
1
+ # Evaluation Methodology
2
+
3
+ Evaluation was one of the hardest parts of Quintus. Several early scores were misleading until prompt format, metric extraction, parser behavior, and runtime artifacts were audited carefully.
4
+
5
+ ## Evaluation Principle
6
+
7
+ A model comparison is only meaningful when the prompt format and metric path match the question being asked.
8
+
9
+ Two distinct questions matter:
10
+
11
+ - Base capability: Does the distilled model improve raw reasoning and likelihood behavior?
12
+ - Assistant behavior: Does the distilled model handle chat formatting and produce usable responses?
13
+
14
+ Those questions need separate controls.
15
+
16
+ ## Run Identity And Determinism
17
+
18
+ A benchmark record should identify the checkpoint role (`best` or `last`), exact model directory or revision, prompt mode, seeds, decoding mode, and runtime versions. Greedy decoding with fixed seeds makes repeated runs easier to compare, but hardware and kernel drift can still change edge cases.
19
+
20
+ Treat determinism as an artifact contract, not a vague claim.
21
+
22
+ ## Raw-To-Raw And Chat-To-Chat
23
+
24
+ Raw completion and chat-template prompting activate different model behavior. A base model can be strong in raw mode and weak under chat markup. An instruct model can be strong in chat, but weak on raw continuation-style likelihood tasks.
25
+
26
+ Recommended controls:
27
+
28
+ - Raw-to-raw: compare base-style prompts against base-style prompts.
29
+ - Chat-to-chat: compare chat-wrapped prompts against chat-wrapped prompts.
30
+ - Raw-vs-chat within the same model: measure format tax.
31
+
32
+ Avoid comparing a chat-wrapped distilled model directly against a raw base baseline and treating the delta as pure capability transfer.
33
+
34
+ ## Log-Likelihood Tasks Should Usually Stay Raw
35
+
36
+ Multiple-choice tasks such as ARC-Challenge, HellaSwag, and PIQA often score options by likelihood:
37
+
38
+ $$
39
+ P(\text{option}\mid\text{prompt})
40
+ $$
41
+
42
+ Wrapping the prompt in chat markup changes the next-token distribution. An aligned model may not want to begin a response with a bare option string after `<|im_start|>assistant`, so option likelihoods can fall because of formatting rather than reasoning ability.
43
+
44
+ For log-likelihood tasks:
45
+
46
+ - Use raw completion format unless the benchmark was designed for chat.
47
+ - Prefer `acc_norm` where length bias matters.
48
+ - Record whether chat templates were applied.
49
+
50
+ ## GSM8K Parser Traps
51
+
52
+ GSM8K evaluation can be distorted by parser behavior.
53
+
54
+ Two common filters behave differently:
55
+
56
+ - `strict-match`: looks for an answer after the `####` delimiter.
57
+ - `flexible-extract`: searches for numbers and may choose the last matched number.
58
+
59
+ A chat model can solve the problem, emit the correct `####` answer, miss EOS, and continue into a hallucinated next dialogue turn containing another number. In that case:
60
+
61
+ - `strict-match` may score the response correct.
62
+ - `flexible-extract` may grab the later hallucinated number and score it wrong.
63
+
64
+ This is not just a parser detail. It reveals an EOS and prompt-format interaction.
65
+
66
+ Mitigations:
67
+
68
+ - Register all relevant EOS tokens, including `<|im_end|>` and `<|endoftext|>`.
69
+ - Use deterministic generation for benchmark runs.
70
+ - Avoid excessive `fewshot_as_multiturn` wrapping unless the model was trained for that shape.
71
+ - Inspect mismatches, not just aggregate scores.
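The parser trap above can be reproduced with two simplified extractors. The regular expressions below are illustrative stand-ins, not the exact patterns used by `lm-evaluation-harness`, and the response text is invented. `stop_at_eos` is a hypothetical helper showing the effect of registering both EOS markers.

```python
import re

STRICT = re.compile(r"####\s*(-?[\d,]+(?:\.\d+)?)")
NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")

def strict_match(text):
    """Return the first number that follows a '####' delimiter, or None."""
    m = STRICT.search(text)
    return m.group(1).replace(",", "") if m else None

def flexible_extract(text):
    """Return the last number found anywhere in the text, or None."""
    found = NUMBER.findall(text)
    return found[-1].replace(",", "") if found else None

def stop_at_eos(text, eos_markers=("<|im_end|>", "<|endoftext|>")):
    """Cut the text at the earliest registered EOS marker."""
    cut = min((text.find(m) for m in eos_markers if m in text), default=len(text))
    return text[:cut]

# A correct answer followed by a hallucinated extra dialogue turn.
response = (
    "She sells 16 - 3 - 4 = 9 eggs at $2 each, so she makes 18 dollars.\n#### 18"
    "<|im_end|>\n<|im_start|>user\nWhat is 7 times 6?<|im_end|>\n"
    "<|im_start|>assistant\n#### 42"
)

strict_answer = strict_match(response)                   # first '####' answer
flexible_answer = flexible_extract(response)             # last number in the runaway text
trimmed_answer = flexible_extract(stop_at_eos(response)) # last number before EOS
```

The two filters disagree on the untrimmed text and agree once generation is cut at the first EOS marker, which is why mismatch inspection is more informative than the aggregate score alone.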
72
+
73
+ ## Reasoning Models Need Enough Generation Budget
74
+
75
+ Instruction-tuned reasoning models may spend hundreds of tokens inside a reasoning trace before reaching the final answer. If `max_new_tokens` is too small, the model can be cut off before emitting the final answer marker.
76
+
77
+ That can make a capable model appear weak under exact-match metrics.
78
+
79
+ For fair GSM8K-style generation:
80
+
81
+ - Set a sufficient generation limit.
82
+ - Track truncation rate.
83
+ - Compare extracted answers against raw responses during audits.
84
+
85
+ ## Batched Generation Details
86
+
87
+ Decoder-only batched generation should use left-padding. Right-padding can put the next-token position on padding for shorter prompts and make batched outputs differ from single-sample outputs.
88
+
89
+ Generation parsers should:
90
+
91
+ - Set `tokenizer.padding_side = "left"` for batched generation.
92
+ - Slice decoded continuations by each prompt's true input length.
93
+ - Stop at the first registered EOS token.
94
+ - Record truncation and empty-generation counts.
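The checklist above can be sketched on plain token ID lists. The helpers `left_pad` and `extract_continuation` are hypothetical and operate on Python lists instead of tensors. With left-padding every row shares the same padded prompt width, so the sketch slices at that width; the IDs are made up.

```python
def left_pad(batch, pad_id):
    """Left-pad token ID lists to a common width; return (input_ids, attention_mask)."""
    width = max(len(ids) for ids in batch)
    input_ids = [[pad_id] * (width - len(ids)) + list(ids) for ids in batch]
    attention_mask = [[0] * (width - len(ids)) + [1] * len(ids) for ids in batch]
    return input_ids, attention_mask

def extract_continuation(output_ids, padded_prompt_len, eos_ids):
    """Drop the padded prompt, then cut at the first registered EOS token.

    Returns (continuation, truncated), where truncated is True when no EOS was seen.
    """
    continuation = output_ids[padded_prompt_len:]
    for i, tok in enumerate(continuation):
        if tok in eos_ids:
            return continuation[:i], False
    return continuation, True

prompts = [[11, 12, 13], [21]]
input_ids, attention_mask = left_pad(prompts, pad_id=0)

# Simulated generation outputs (prompt plus new tokens); 99 plays the role of EOS.
finished, finished_truncated = extract_continuation([0, 0, 21, 7, 8, 99, 5], 3, {99})
cut_off, cut_off_truncated = extract_continuation([11, 12, 13, 4, 4, 4, 4], 3, {99})
```

After left-padding, the last position of every row is a real prompt token, so the next-token prediction for each row starts from real context.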
95
+
96
+ ## English-Only Evaluation Controls
97
+
98
+ For English-only release checks, filtering the dataset is necessary but not sufficient. Evaluation should also use an English-only system instruction when chat prompts are enabled, register all relevant EOS IDs, and clean generated artifacts that continue into another language after the intended answer.
99
+
100
+ This cleanup is an evaluation-artifact guard. It is not a substitute for training data quality, SFT, preference tuning, or behavioral calibration.
101
+
102
+ ## Metric Extraction Must Be Strict
103
+
104
+ Post-processing scripts should never fall back loosely to any metric key that starts with the right prefix. A loose fallback can accidentally read:
105
+
106
+ - `*_stderr`
107
+ - `alias`
108
+ - a different filter result
109
+
110
+ Robust extraction should:
111
+
112
+ - Match the exact metric and filter name.
113
+ - Ignore stderr and alias fields when extracting scores.
114
+ - Fail loudly if the expected key is absent.
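A strict extractor along these lines could look like the sketch below. It assumes per-task result dictionaries keyed as `"<metric>,<filter>"`, with sibling `"<metric>_stderr,<filter>"` and `alias` entries. That layout is an assumption about the harness output and should be checked against the installed version. The function name and the scores are invented.

```python
def extract_metric(task_results, metric, filter_name):
    """Return the score stored under the exact '<metric>,<filter>' key.

    No prefix matching and no fallback: a missing key is an error.
    """
    key = f"{metric},{filter_name}"
    if key not in task_results:
        raise KeyError(f"expected metric key {key!r}; found {sorted(task_results)}")
    value = task_results[key]
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise TypeError(f"metric {key!r} is not numeric: {value!r}")
    return float(value)

# Made-up numbers in the assumed layout; not real benchmark results.
results = {
    "alias": "gsm8k",
    "exact_match,strict-match": 0.5,
    "exact_match_stderr,strict-match": 0.01,
    "exact_match,flexible-extract": 0.75,
    "exact_match_stderr,flexible-extract": 0.01,
}
score = extract_metric(results, "exact_match", "flexible-extract")
```

A prefix search for `exact_match` would match four keys in this dictionary, two of which are standard errors, so exact lookup is the safer default.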
115
+
116
+ ## Boolean CLI Flags
117
+
118
+ Some harness flags use `action="store_true"`. Passing `"False"` after such a flag does not disable it; the presence of the flag enables it.
119
+
120
+ Correct pattern:
121
+
122
+ - Include the flag only when true.
123
+ - Omit the flag when false.
124
+
125
+ This matters for options such as multiturn few-shot formatting.
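The behavior can be checked with the standard library alone. The parser below is a stand-in built for this example, not the harness's actual argument parser, and `build_args` is a hypothetical helper for assembling a command line.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--fewshot_as_multiturn", action="store_true")

# "False" is not consumed as the flag's value. The flag is still set, and the
# stray token is left over as an unrecognized argument.
ns, extra = parser.parse_known_args(["--fewshot_as_multiturn", "False"])

def build_args(base_args, fewshot_as_multiturn):
    """Append the flag only when it should be enabled; omit it otherwise."""
    args = list(base_args)
    if fewshot_as_multiturn:
        args.append("--fewshot_as_multiturn")
    return args

enabled = build_args(["--tasks", "gsm8k"], fewshot_as_multiturn=True)
disabled = build_args(["--tasks", "gsm8k"], fewshot_as_multiturn=False)
```

Building the argument list conditionally keeps a boolean in a config file from turning into a literal `"False"` string on the command line.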
126
+
127
+ ## Sample Log Format
128
+
129
+ `lm-evaluation-harness` may log different filters for the same document as separate JSONL objects with the same `doc_id`. A parser that assumes one object contains all filters can crash or silently compare the wrong fields.
130
+
131
+ Correct approach:
132
+
133
+ - Group sample records by `doc_id`.
134
+ - Index filter-specific records inside each group.
135
+ - Compare strict and flexible outputs from the same document.
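The grouping approach above can be sketched as follows. The field names `filter` and `exact_match`, the filter names, and the helper functions are assumptions made for illustration, and the records are invented. Real sample logs should be inspected before relying on this shape.

```python
from collections import defaultdict

def group_by_doc(records):
    """Index sample records as {doc_id: {filter_name: record}}."""
    grouped = defaultdict(dict)
    for rec in records:
        doc_id, flt = rec["doc_id"], rec["filter"]
        if flt in grouped[doc_id]:
            raise ValueError(f"duplicate record for doc_id={doc_id}, filter={flt!r}")
        grouped[doc_id][flt] = rec
    return dict(grouped)

def strict_flexible_disagreements(grouped, metric="exact_match"):
    """Return doc_ids where the two filters score the same document differently."""
    disagreements = []
    for doc_id, by_filter in sorted(grouped.items()):
        # A missing filter raises KeyError instead of being skipped silently.
        strict = by_filter["strict-match"][metric]
        flexible = by_filter["flexible-extract"][metric]
        if strict != flexible:
            disagreements.append(doc_id)
    return disagreements

# Made-up records: one JSONL object per (document, filter) pair.
records = [
    {"doc_id": 0, "filter": "strict-match", "exact_match": 1.0},
    {"doc_id": 0, "filter": "flexible-extract", "exact_match": 0.0},
    {"doc_id": 1, "filter": "strict-match", "exact_match": 1.0},
    {"doc_id": 1, "filter": "flexible-extract", "exact_match": 1.0},
]
grouped = group_by_doc(records)
mismatched = strict_flexible_disagreements(grouped)
```

The resulting list of disagreeing `doc_id` values is a natural starting point for the mismatch inspection recommended in the GSM8K section.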
136
+
137
+ ## JSONL Parsing With Unicode Line Separators
138
+
139
+ Model outputs can contain Unicode line separator characters such as `\u2028` or `\u2029`. Calling `str.splitlines()` on a whole JSONL file can split a valid JSON string into invalid fragments.
140
+
141
+ Robust JSONL parsing:
142
+
143
+ ```python
+ import json
+
+ # Iterate the file handle: it splits on real line endings only.
+ with open(path, "r", encoding="utf-8") as f:
+     for line in f:
+         if line.strip():  # skip blank lines
+             record = json.loads(line)
+ ```
149
+
150
+ Iterating the file handle respects actual line endings and does not split on Unicode separators inside JSON strings.
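The difference can be reproduced without touching the filesystem. In this sketch `io.StringIO` stands in for an open text file, and the record content is invented.

```python
import io
import json

record = {"doc_id": 0, "resp": "first line\u2028second line"}
# ensure_ascii=False keeps U+2028 as a raw character instead of escaping it.
jsonl_text = json.dumps(record, ensure_ascii=False) + "\n"

# str.splitlines() treats U+2028 as a line boundary and breaks the JSON string.
fragments = jsonl_text.splitlines()

# Iterating a text stream splits on conventional line endings only.
lines = list(io.StringIO(jsonl_text))
parsed = [json.loads(line) for line in lines if line.strip()]
```

Writers that keep the default `ensure_ascii=True` escape these characters and avoid the problem, but a parser cannot assume every producer did so.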
151
+
152
+ ## Hub Loading And Snapshot Hygiene
153
+
154
+ If weights or datasets are stored on the Hub, the client should be told the correct repository type. Download or snapshot the artifact first, verify that expected files exist, then pass the local directory to Transformers, vLLM, or the evaluation harness.
155
+
156
+ This separates transfer failures from engine construction and avoids repeated downloads during long benchmark runs.
157
+
158
+ Optional high-throughput Hub transfer backends such as `hf_transfer` can reduce setup time, but the correctness contract is still local snapshot validation.
159
+
160
+ ## Path-Length And Output Artifacts
161
+
162
+ Evaluation tools can derive output paths from model paths. Deep Hugging Face cache paths can become extremely long after sanitization, especially on Windows.
163
+
164
+ Public guidance:
165
+
166
+ - Copy or symlink model weights to a short local directory before evaluation.
167
+ - Pass short relative paths to the evaluator.
168
+ - Keep result directories shallow.
169
+ - Fail if expected sample files are missing.
170
+
171
+ This prevents silent write failures and missing-output confusion.
172
+
173
+ ## vLLM Evaluation Settings
174
+
175
+ For large benchmark runs, vLLM can greatly reduce runtime through continuous batching and KV-cache management.
176
+
177
+ Useful settings in development:
178
+
179
+ - `batch_size = auto`
180
+ - prefix caching enabled
181
+ - PagedAttention-backed KV-cache management when available
182
+ - bounded GPU memory utilization
183
+ - explicit `max_model_len` where context bounds matter
184
+ - explicit attention backend where the runtime supports it
185
+ - local pre-caching of model snapshots before engine construction
186
+ - explicit engine teardown between model runs
187
+
188
+ The benchmark artifact should record runtime versions for:
189
+
190
+ - `lm-eval`
191
+ - `vllm`
192
+ - `transformers`
193
+ - `torch`
194
+ - `datasets`
195
+ - `accelerate`
196
+
197
+ Version drift can change metric keys, generation behavior, attention backends, and output formats.
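One way to capture these versions is `importlib.metadata` from the standard library. The helper name and the distribution names below are assumptions; in particular, the installed name for `lm-eval` may differ between environments. Packages that are not installed are recorded as `None` instead of being dropped.

```python
from importlib import metadata

def collect_versions(distributions):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in distributions:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # keep the key so gaps are visible in the artifact
    return versions

# Assumed distribution names for the runtimes listed above.
RUNTIME_DISTRIBUTIONS = ["lm_eval", "vllm", "transformers", "torch", "datasets", "accelerate"]
runtime_versions = collect_versions(RUNTIME_DISTRIBUTIONS)
```

Writing this dictionary next to the benchmark output makes it possible to tell a real regression from a harness or kernel change later.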
198
+
199
+ ## Qualitative Evaluation
200
+
201
+ Open-ended prompt suites are useful, but they are not replacements for standardized benchmarks.
202
+
203
+ A good qualitative suite should:
204
+
205
+ - Compare raw and chat modes separately.
206
+ - Use fixed prompts and deterministic ordering.
207
+ - Include benchmark-template leakage probes.
208
+ - Include factual, math, code, system design, and LLM-internals prompts.
209
+ - Record complete outputs.
210
+ - Inspect inherited base-model errors separately from new chat-mode errors.
211
+
212
+ Qualitative failures should be classified:
213
+
214
+ - Distillation failure: the student did not absorb useful teacher probability structure.
215
+ - Alignment gap: capability exists, but the generation path lacks SFT, preference tuning, or calibration.
216
+ - Data contamination: the model repeats benchmark or pretraining artifacts.
217
+ - Code reliability gap: prose is correct, but generated code violates stated constraints.
218
+
219
+ This distinction prevents the wrong fix. Distillation failures need KD changes. Alignment gaps need SFT, DPO, RLHF, or curated behavior data.
220
+
221
+ ## Release Gate
222
+
223
+ The final checkpoint should pass all of these before public claims are made:
224
+
225
+ - Benchmark tasks use the intended prompt format.
226
+ - Metric keys are exact.
227
+ - Sample counts match the full benchmark set.
228
+ - Raw and chat comparisons are not mixed.
229
+ - Generation limits are sufficient for the model style.
230
+ - Checkpoint identity is explicit.
231
+ - Missing requested checkpoints fail instead of falling back to older local weights.
232
+ - Runtime versions are recorded.
233
+ - Mismatch samples are inspected for parser artifacts.
234
+ - No stale result directory or old JSON file is reused.
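Two of these checks (exact metric keys and full sample counts) are easy to automate. The sketch below assumes a simplified, hypothetical result shape; the function name `check_release_gate` and both dictionary layouts are illustrative and are not the evaluator's real output format.

```python
def check_release_gate(results, expected):
    """Return a list of gate failures; an empty list means the gate passed.

    Hypothetical shapes used by this sketch:
      results[task]  = {"metrics": {metric_key: value}, "num_samples": int}
      expected[task] = {"metric_key": str, "num_samples": int}
    """
    failures = []
    for task, spec in expected.items():
        if task not in results:
            failures.append(f"{task}: missing from results")
            continue
        entry = results[task]
        # Exact key lookup only: no prefix matching and no fallback key.
        if spec["metric_key"] not in entry["metrics"]:
            failures.append(f"{task}: metric key {spec['metric_key']!r} not found")
        if entry["num_samples"] != spec["num_samples"]:
            failures.append(
                f"{task}: evaluated {entry['num_samples']} samples, "
                f"expected {spec['num_samples']}"
            )
    return failures
```

Returning every failure, rather than stopping at the first, makes one run enough to see everything that blocks a release.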
docs/experiment_timeline.md ADDED
@@ -0,0 +1,181 @@
1
+ # Experiment Timeline
2
+
3
+ This timeline explains why the final Quintus design looks the way it does. It focuses on the technical evolution from sparse offline distillation to the final online full-vocabulary pipeline.
4
+
5
+ ## 1. Offline Top-K KD Prototype
6
+
7
+ The earliest design precomputed teacher logits to disk and trained the student from cached top-k supports.
8
+
9
+ Why it was attractive:
10
+
11
+ - Avoided loading teacher and student together.
12
+ - Reduced KD memory from full vocabulary to top-k support.
13
+ - Made cloud interruptions easier to survive because teacher logits were already saved.
14
+
15
+ Main lessons:
16
+
17
+ - Serialization contracts matter as much as loss math.
18
+ - Top-k token IDs need safe dtypes.
19
+ - Teacher-logit shards must preserve original row order.
20
+ - Missing or stale shards should fail loudly.
21
+
22
+ ## 2. Static Audit And Fail-Fast Hardening
23
+
24
+ The project then moved through a static-audit phase focused on silent failure modes.
25
+
26
+ Major hardening themes:
27
+
28
+ - Dataset zero-retention checks.
29
+ - Missing-shard hard failures.
30
+ - Stale artifact cleanup.
31
+ - DeepSpeed accumulation correctness.
32
+ - Rank-safe writes.
33
+ - Explicit model revision and remote-code policy.
34
+ - Stronger provenance metadata.
35
+
36
+ This phase turned the code from a script bundle into a more reliable training pipeline.
37
+
38
+ ## 3. Assistant-Only Supervision
39
+
40
+ The tokenization path originally risked supervising the whole conversation. That can over-train prompts, headers, and formatting tokens.
41
+
42
+ The corrected path derives `loss_mask` and trains only on assistant response tokens.
43
+
44
+ This changed the training contract:
45
+
46
+ - Prompt tokens provide context.
47
+ - Assistant tokens receive CE and KD loss.
48
+ - Rows without assistant targets are rejected.
49
+ - Checkpoints and datasets must agree on the mask schema.
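The contract above can be sketched in a few lines. This is an illustration under assumptions, not the project's tokenization code: it takes conversations that are already split into `(role, token_ids)` segments, and the names `build_loss_mask` and `masked_mean` are hypothetical.

```python
def build_loss_mask(segments):
    """Build input_ids and a binary loss_mask from (role, token_ids) segments.

    Only assistant segments become targets; every other segment is context.
    """
    input_ids, loss_mask = [], []
    for role, token_ids in segments:
        input_ids.extend(token_ids)
        flag = 1 if role == "assistant" else 0
        loss_mask.extend([flag] * len(token_ids))
    if sum(loss_mask) == 0:
        # Rows without assistant targets are rejected, not silently kept.
        raise ValueError("row has no assistant target tokens")
    return input_ids, loss_mask


def masked_mean(per_token_losses, loss_mask):
    """Average a per-token loss over supervised positions only."""
    total = sum(loss for loss, keep in zip(per_token_losses, loss_mask) if keep)
    return total / sum(loss_mask)
```

The same mask would gate both the CE and KD terms, so prompt tokens influence the loss only through the context they provide.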
50
+
51
+ ## 4. Top-K Plus Residual Bucket
52
+
53
+ A later offline-KD pass improved the sparse support by adding an "other" bucket for teacher probability mass outside top-k.
54
+
55
+ This fixed a mathematical weakness: the student should be normalized over the full vocabulary before comparison, not renormalized inside top-k alone. The residual bucket made offline KD better posed, but it still collapsed every teacher probability outside top-k into one scalar.
56
+
57
+ That design was useful, but not enough for flagship results.
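The residual-bucket comparison can be sketched for a single token position as below. This is a plain-Python illustration, not the project's loss code: the KL direction (teacher to student) and the function names are assumptions.

```python
import math


def softmax(logits):
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def topk_residual_kl(teacher_topk_ids, teacher_topk_probs, student_logits, eps=1e-12):
    """KL(teacher || student) over k top-k buckets plus one 'other' bucket."""
    # The student is normalized over the FULL vocabulary first.
    student_probs = softmax(student_logits)
    student_topk = [student_probs[i] for i in teacher_topk_ids]
    # All mass outside top-k collapses into one residual scalar per side.
    teacher_other = max(0.0, 1.0 - sum(teacher_topk_probs))
    student_other = max(0.0, 1.0 - sum(student_topk))
    teacher_buckets = list(teacher_topk_probs) + [teacher_other]
    student_buckets = student_topk + [student_other]
    kl = 0.0
    for t, s in zip(teacher_buckets, student_buckets):
        if t > 0.0:
            kl += t * (math.log(t) - math.log(max(s, eps)))
    return kl
```

The single `teacher_other` value is where the information loss sits: two teachers that disagree on everything outside top-k produce the same target.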
58
+
59
+ ## 5. Dataset And Objective Mismatch
60
+
61
+ Smoke runs showed a pattern that became important later: held-out KD validation loss can improve while benchmark quality worsens.
62
+
63
+ Key diagnosis:
64
+
65
+ - Matching teacher token distributions on a training corpus is not identical to improving GSM8K, ARC, coding, or open-ended assistant quality.
66
+ - Dataset order and first-N streaming can bias sample selection.
67
+ - Long reasoning traces can overweight style and process tokens relative to final answers.
68
+ - Small students can forget useful baseline behavior when full-parameter training is too aggressive.
69
+
70
+ This motivated stricter downstream evaluation gates.
71
+
72
+ ## 6. Base Student Pivot
73
+
74
+ Several runs tested whether distilling into an already-instruct-tuned student caused destructive interference. The base-student hypothesis was sound: a raw base model has more plasticity and fewer alignment paths to overwrite.
75
+
76
+ The result was only a marginal improvement under offline top-k KD. That was the decisive clue.
77
+
78
+ Conclusion:
79
+
80
+ The student choice was not the main bottleneck; offline top-k sparsity was.
81
+
82
+ ## 7. Offline Top-K Ceiling
83
+
84
+ With $k = 8$, the student saw only a tiny fraction of the teacher vocabulary distribution per target token:
85
+
86
+ $$
87
+ \frac{k}{|V|}
88
+ = \frac{8}{151{,}665}
89
+ \approx 5.3 \times 10^{-5}
90
+ = 0.0053\%
91
+ $$
92
+
93
+ Different $\alpha$ values, epochs, and student initializations did not remove this limit.
94
+
95
+ Offline top-k KD could perturb the student and sometimes improve narrow metrics, but it could not reliably transfer the teacher's broader reasoning distribution.
96
+
97
+ The project stopped treating offline top-k KD as the path to a flagship model.
98
+
99
+ ## 8. Online Full-Vocabulary KD
100
+
101
+ ![Offline vs Online KD](../assets/offline_vs_online_kd.svg)
102
+
103
+ Online KD became the final architecture.
104
+
105
+ Instead of reading cached teacher shards, the training loop loads a frozen teacher and runs live teacher forward passes beside the student. The KD loss uses the teacher's full-vocabulary distribution.
106
+
107
+ Benefits:
108
+
109
+ - No top-k sparsity ceiling.
110
+ - No shard-order mismatch risk.
111
+ - No stale teacher-logit cache.
112
+ - Stronger transfer signal for reasoning.
113
+
114
+ Cost:
115
+
116
+ - Higher VRAM footprint.
117
+ - Teacher and student must fit together.
118
+ - KL computation needs chunking.
119
+ - Throughput depends heavily on packing and kernels.
120
+
121
+ ## 9. Sequence Packing And B200 Tuning
122
+
123
+ Sequence packing converted padding waste into useful tokens.
124
+
125
+ The packing implementation:
126
+
127
+ - Packs only training data.
128
+ - Keeps validation easier to interpret.
129
+ - Uses fixed 4096-token bins.
130
+ - Inserts masked EOS separators.
131
+ - Stores packing metadata in checkpoints.
132
+ - Rejects packed/unpacked resume mismatches.
133
+
134
+ Development probes showed the expected utilization improvement and made online KD fast enough for serious single-GPU runs.
135
+
136
+ ## 10. English-Only Final Data
137
+
138
+ The release run focuses on English samples.
139
+
140
+ Reasons:
141
+
142
+ - Reduce language drift in open-ended outputs.
143
+ - Keep the model's assistant behavior aligned with the intended release language.
144
+ - Make qualitative evaluation cleaner.
145
+ - Avoid CJK continuation artifacts after missed EOS.
146
+
147
+ The tradeoff is real: removing multilingual data can reduce access to some reasoning traces. For a public English assistant, language stability is worth that tradeoff.
148
+
149
+ ## 11. Targeted SFT After KD
150
+
151
+ Online KD transferred capability, but raw KD is not a full assistant-alignment process.
152
+
153
+ Targeted SFT was added after KD to improve:
154
+
155
+ - identity grounding,
156
+ - chat format stability,
157
+ - practical assistant style,
158
+ - repetition control,
159
+ - response presentation.
160
+
161
+ This created the final two-stage public model:
162
+
163
+ ```text
164
+ Qwen3-1.7B-Base
165
+ -> online full-vocab KD from Qwen3-8B
166
+ -> targeted SFT
167
+ -> Quintus-1.7B
168
+ ```
169
+
170
+ ## 12. Release Verification
171
+
172
+ The final release surface combines:
173
+
174
+ - benchmark scoreboard,
175
+ - architecture documentation,
176
+ - evaluation methodology notes,
177
+ - pipeline hardening notes,
178
+ - weight audit,
179
+ - model-card draft.
180
+
181
+ The public docs focus on reusable methods, release results, and reproducible checks.
docs/huggingface_model_card.md ADDED
@@ -0,0 +1,178 @@
1
+ # Quintus-1.7B
2
+
3
+ Quintus-1.7B is a compact instruction-following assistant derived from `Qwen/Qwen3-1.7B-Base`. It was trained with online full-vocabulary knowledge distillation from a larger Qwen3-8B teacher, followed by targeted SFT for assistant behavior and generation stability.
4
+
5
+ ## Model Details
6
+
7
+ - Base architecture: Qwen3-1.7B
8
+ - Base checkpoint: `Qwen/Qwen3-1.7B-Base`
9
+ - Distillation teacher: `Qwen/Qwen3-8B`
10
+ - Training method: Online full-vocabulary KD + targeted SFT
11
+ - Context length used in training: 4096 tokens
12
+ - Primary language focus: English
13
+ - Release repository: `iamrahulreddy/Quintus`
14
+ - Attention path: FlashAttention-2 when available
15
+ - Training kernels: Liger kernels for compatible Qwen-family operators
16
+ - Optimizer: fused AdamW
17
+
18
+ ## Intended Use
19
+
20
+ Quintus is intended for:
21
+
22
+ - General assistant use.
23
+ - Reasoning and math prompts.
24
+ - Lightweight coding assistance.
25
+ - Local experimentation with compact LLMs.
26
+ - Research into online KD and small-model alignment.
27
+
28
+ It is not intended as a safety-critical decision system. Like other compact language models, it can hallucinate, and its outputs should be verified before use in high-stakes tasks.
29
+
30
+ ## Training Summary
31
+
32
+ The training pipeline has two main stages:
33
+
34
+ 1. Online KD: The student learns from the teacher's dense full-vocabulary probability distribution. This avoids the sparse top-k ceiling encountered in earlier offline KD experiments.
35
+ 2. SFT: The distilled checkpoint is tuned on curated instruction/persona data to improve assistant-style behavior and reduce repetition or formatting drift.
36
+
37
+ The KD loss combines assistant-token cross entropy and teacher-student KL divergence:
38
+
39
+ $$
40
+ \mathcal{L}_{\text{total}}
41
+ = \alpha \mathcal{L}_{\text{CE}}
42
+ + (1 - \alpha)\mathcal{L}_{\text{KD}}
43
+ $$
44
+
45
+ For the release run, $\alpha = 0.3$ and $T = 2.0$.
46
+
47
+ `torch.compile` was kept disabled for the final KD path because this workload showed high Inductor memory overhead, dynamic-shape graph breaks, recompile overhead, and checkpoint portability risk from `_orig_mod.` state-dict prefixes when compiled modules are not unwrapped before saving.
48
+
49
+ ## Evaluation
50
+
51
+ | Benchmark | Qwen3-1.7B-Base | Qwen3-1.7B-Instruct | Quintus-1.7B |
+ | :--- | :---: | :---: | :---: |
+ | HumanEval pass@1 | 67.1% | 70.7% | 67.7% |
+ | MBPP pass@1 | 67.2% | 58.2% | 64.8% |
+ | GSM8K, 10-shot flexible | 69.98% | 69.75% | 74.30% |
+ | ARC-Challenge acc_norm | 55.72% | 52.99% | 58.36% |
+ | WinoGrande, 5-shot | 65.67% | 61.01% | 66.38% |
+ | PIQA acc_norm | 75.63% | 72.09% | 75.57% |
59
+
60
+ ## Strengths
61
+
62
+ - Strong math and reasoning transfer for the 1.7B parameter scale.
63
+ - Good commonsense and ARC-style benchmark performance.
64
+ - Compact enough for lower-resource deployment compared with larger teachers.
65
+ - Public weight audit indicates healthy structural divergence from the base checkpoint without collapse.
66
+
67
+ ## Limitations
68
+
69
+ - The model can still produce confident factual errors.
70
+ - Code generation can contradict stated complexity constraints.
71
+ - It is smaller than the teacher and inherits capacity limits of the 1.7B scale.
72
+ - Evaluation results depend on prompt format; raw and chat-template modes are not interchangeable.
73
+ - Additional preference tuning would likely improve calibration and refusal behavior.
74
+
75
+ ## Example Usage
76
+
77
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
+
+ PUBLIC_REPO_ID = "iamrahulreddy/Quintus"
+
+ print(f"Loading Quintus from {PUBLIC_REPO_ID}...")
+ tokenizer = AutoTokenizer.from_pretrained(PUBLIC_REPO_ID, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     PUBLIC_REPO_ID,
+     device_map="auto",
+     dtype=torch.float16,
+     trust_remote_code=True,
+ )
+
+ stop_tokens = ["<|endoftext|>", "<|im_end|>"]
+ eos_token_ids = [tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else []
+ for token in stop_tokens:
+     token_id = tokenizer.convert_tokens_to_ids(token)
+     if token_id is not None and token_id not in eos_token_ids:
+         eos_token_ids.append(token_id)
+
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+ conversation_history = [
+     {
+         "role": "system",
+         "content": (
+             "You are Quintus, a highly capable AI assistant created by "
+             "Muskula Rahul. You are helpful, precise, and logically sound."
+         ),
+     }
+ ]
+
+ print()
+ print("Quintus Chat (type 'quit' to exit)")
+ print()
+
+ while True:
+     try:
+         user_input = input("You: ").strip()
+         if user_input.lower() in ["quit", "exit"]:
+             print("\nGoodbye!")
+             break
+         if not user_input:
+             continue
+
+         conversation_history.append({"role": "user", "content": user_input})
+
+         prompt = tokenizer.apply_chat_template(
+             conversation_history,
+             tokenize=False,
+             add_generation_prompt=True,
+         )
+
+         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+         print("Quintus: ", end="", flush=True)
+
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=512,
+                 temperature=0.7,
+                 top_p=0.9,
+                 do_sample=True,
+                 streamer=streamer,
+                 pad_token_id=tokenizer.eos_token_id,
+                 eos_token_id=eos_token_ids,
+             )
+
+         generated_ids = outputs[0][inputs.input_ids.shape[-1]:]
+         assistant_response = tokenizer.decode(
+             generated_ids,
+             skip_special_tokens=True,
+         ).strip()
+         conversation_history.append({"role": "assistant", "content": assistant_response})
+         print()
+
+     except KeyboardInterrupt:
+         print("\n\nGoodbye!")
+         break
+ ```
160
+
161
+ ## Credits
162
+
163
+ - [Qwen Team](https://qwenlm.github.io/) and the [Qwen Hugging Face organization](https://huggingface.co/Qwen) for the Qwen3 model family.
164
+ - [`Qwen/Qwen3-8B`](https://huggingface.co/Qwen/Qwen3-8B), used as the distillation teacher.
165
+ - [`Qwen/Qwen3-1.7B-Base`](https://huggingface.co/Qwen/Qwen3-1.7B-Base), used as the base student checkpoint.
166
+ - [`Qwen/Qwen3-1.7B`](https://huggingface.co/Qwen/Qwen3-1.7B), used for the tokenizer and chat-template contract.
167
+ - [Alibaba PAI](https://huggingface.co/alibaba-pai) for [`DistilQwen_100k`](https://huggingface.co/datasets/alibaba-pai/DistilQwen_100k), the primary instruction source after filtering.
168
+ - [Hugging Face Transformers](https://github.com/huggingface/transformers), [vLLM](https://github.com/vllm-project/vllm), [EvalPlus](https://github.com/evalplus/evalplus), [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness), [FlashAttention](https://github.com/Dao-AILab/flash-attention), and [Liger Kernel](https://github.com/linkedin/Liger-Kernel) for training and evaluation infrastructure.
169
+
170
+ ## License And Author
171
+
172
+ This software is distributed under the MIT License. Refer to the repository [LICENSE](../LICENSE) file for the full text.
173
+
174
+ Author: Muskula Rahul - [@iamrahulreddy](https://github.com/iamrahulreddy)
175
+
176
+ ## Citation
177
+
178
+ If you use this model or code, cite the repository and the upstream Qwen3 models.
docs/index.md ADDED
@@ -0,0 +1,42 @@
1
+ # Quintus Documentation
2
+
3
+ Quintus-1.7B is a compact assistant built from the Qwen3-1.7B-Base architecture. The project uses online full-vocabulary knowledge distillation from a Qwen3-8B teacher, followed by targeted SFT for instruction style, identity grounding, and generation stability.
4
+
5
+ This documentation summarizes the public architecture, training decisions, evaluation controls, and release artifacts for the showcase branch.
6
+
7
+ ## Reading Order
8
+
9
+ - [Architecture](architecture.md): End-to-end pipeline, modules, data flow, and training phases.
10
+ - [Experiment Timeline](experiment_timeline.md): How the project moved from offline top-k KD to final online full-vocabulary KD.
11
+ - [Training Playbook](training_playbook.md): Practical training choices, memory rules, packing, kernels, and checkpointing.
12
+ - [Pipeline Hardening](pipeline_hardening.md): Silent-failure classes and the safeguards added around artifacts, provenance, and runtime.
13
+ - [Evaluation Methodology](evaluation_methodology.md): Benchmark controls, parser traps, raw/chat comparisons, and qualitative evaluation rules.
14
+ - [Engineering Insights](engineering_insights.md): Condensed technical lessons and design decisions.
15
+ - [Benchmarks](benchmarks.md): Verified evaluation results and interpretation.
16
+ - [Weight Audit](weight_audit.md): Structural checkpoint verification and what the audit means.
17
+ - [Hugging Face Model Card](huggingface_model_card.md): Release-page text for the public model card.
18
+
19
+ ## Project Summary
20
+
21
+ The core thesis is simple: a small base model can absorb useful reasoning behavior from a larger instruction model if the distillation signal is dense enough and the evaluation controls are fair.
22
+
23
+ The project initially explored sparse offline top-k distillation, but that approach hit a ceiling because the student only saw a tiny fraction of the teacher vocabulary distribution. The final pipeline pivots to online KD, where teacher and student are run together and the student receives the teacher's full-vocabulary probability distribution during training.
24
+
25
+ After KD, a small SFT stage teaches the model how to expose that knowledge in a conversational interface. This separation matters: KD transfers capability; SFT and later preference training improve behavior, style, and confidence calibration.
26
+
27
+ ## Repository Map
28
+
29
+ ```text
30
+ configs/ Training configuration and DeepSpeed template.
31
+ src/ Online KD, data loading, losses, checkpointing, and packing.
32
+ sft/ Post-KD supervised fine-tuning, chat, and consolidated evaluation runner.
33
+ weight_audit/ Checkpoint structure and weight-divergence audit.
34
+ docs/ Public architecture, training, evaluation, and release notes.
35
+ ```
36
+
37
+ ## Main Public Artifact
38
+
39
+ The final model weights are available at: [Quintus](https://huggingface.co/iamrahulreddy/Quintus)
40
+
41
+ The Colab quickstart is available at: [Colab Quick Chat](https://colab.research.google.com/drive/1TdMSN5HzD1mToCFVf_qQoj10NGZLy2V0?usp=sharing)
42
+
docs/pipeline_hardening.md ADDED
@@ -0,0 +1,208 @@
1
+ # Pipeline Hardening
2
+
3
+ ![Pipeline Hardening Flow](../assets/pipeline_hardening_flow.svg)
4
+
5
+ This page summarizes the correctness and reliability lessons that shaped the Quintus codebase. Most of these are silent-failure classes: the pipeline can appear to run while producing invalid or misleading artifacts.
6
+
7
+ ## Silent Serialization Bugs
8
+
9
+ Teacher token IDs must be stored in a dtype that can represent the tokenizer vocabulary.
10
+
11
+ An early offline-KD path stored top-k token IDs in too narrow a dtype. Qwen token IDs exceed the signed 16-bit maximum of 32,767, so IDs could wrap negative and later be clamped into valid-looking but wrong positions. Training could continue, but the KL support was corrupted.
12
+
13
+ Hardening rule:
14
+
15
+ - Store token IDs as `int32` or wider.
16
+ - Validate IDs on load.
17
+ - Reject negative IDs.
18
+ - Reject IDs outside the student vocabulary.
19
+ - Treat dtype as part of the shard contract.
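The wraparound is easy to reproduce with the standard library. The sketch below uses `ctypes.c_int16`, which does no overflow checking, to mimic a signed 16-bit store; the helper names are hypothetical, and the two sample IDs are arbitrary values above 32,767.

```python
import ctypes


def store_as_int16(token_id):
    """Mimic a signed 16-bit store: values outside [-32768, 32767] wrap silently."""
    return ctypes.c_int16(token_id).value


def validate_token_ids(token_ids, vocab_size):
    """Reject IDs that cannot index the student vocabulary."""
    for position, token_id in enumerate(token_ids):
        if token_id < 0:
            raise ValueError(f"negative token ID {token_id} at position {position}")
        if token_id >= vocab_size:
            raise ValueError(
                f"token ID {token_id} at position {position} is outside vocab size {vocab_size}"
            )
    return token_ids


print(store_as_int16(40_000))   # wraps to a negative value
print(store_as_int16(151_000))  # wraps to a small positive, valid-looking value
```

The second case is the reason dtype belongs in the shard contract: a wrapped ID that lands back inside the vocabulary passes every range check on load.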
20
+
21
+ ## Row-Order Preservation
22
+
23
+ Teacher-logit extraction often sorts samples by length for throughput. Training usually expects logits to match the original tokenized row order.
24
+
25
+ If sorted extraction writes shards in sorted order without restoring original indices, the student receives teacher logits for the wrong sample. This is a model-poisoning bug, not a performance issue.
26
+
27
+ Hardening rule:
28
+
29
+ - Batch by sorted length if useful.
30
+ - Preserve `original_idx`.
31
+ - Write final shards in original dataset order.
32
+ - Verify teacher-logit length against the tokenized row length at training time.
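A minimal sketch of this rule follows, assuming a `run_teacher` callable that returns one output sequence per input row. Both that callable and the function name `extract_in_original_order` are hypothetical stand-ins for the real extraction loop.

```python
def extract_in_original_order(rows, run_teacher, batch_size=2):
    """Batch by descending length for throughput, then restore dataset order."""
    order = sorted(range(len(rows)), key=lambda i: len(rows[i]), reverse=True)
    outputs = [None] * len(rows)
    for start in range(0, len(order), batch_size):
        batch_idx = order[start:start + batch_size]
        batch_out = run_teacher([rows[i] for i in batch_idx])
        # Write each result back to the slot of its original index.
        for original_idx, out in zip(batch_idx, batch_out):
            outputs[original_idx] = out
    for i, (row, out) in enumerate(zip(rows, outputs)):
        if out is None or len(out) != len(row):
            raise ValueError(f"row {i}: teacher output does not match tokenized length")
    return outputs
```

The final length check is a cheap tripwire. It cannot prove that alignment is correct, but it catches many reorderings because neighbouring rows rarely share a length.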
33
+
34
+ ## Dataset Schema And Decoding
35
+
36
+ Public instruction datasets do not share a single row schema. Some rows arrive as `messages`; others use Alpaca-style `instruction`, `input`, and `output` fields. Some content fields contain nested dict/list payloads that need structured coercion before templating.
37
+
38
+ Dataset streaming can also fail late when a compression codec or file decoder is missing. That failure should remain visible instead of being replaced by a generic "zero samples" result.
39
+
40
+ Hardening rule:
41
+
42
+ - Detect Alpaca-style instruction/output rows before chat-message conversion.
43
+ - Coerce nested dict/list content through structured serialization, then normalize to text.
44
+ - Normalize common role aliases before applying a chat template.
45
+ - Preserve the first real dataset exception when streaming fails.
46
+ - Validate dataset decoding and schema mapping before large model downloads.
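The first three rules can be sketched as one normalization function. This is an illustration under assumptions: the alias table, the prompt-joining format, and the function names are hypothetical and are not taken from the Quintus data loader.

```python
import json

# Illustrative alias table; a real one would be built from the datasets in use.
ROLE_ALIASES = {"human": "user", "gpt": "assistant", "bot": "assistant"}


def coerce_text(content):
    """Turn nested dict/list payloads into deterministic text."""
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, (dict, list)):
        return json.dumps(content, ensure_ascii=False, sort_keys=True)
    return str(content)


def normalize_row(row):
    """Map one raw dataset row to a list of {"role", "content"} messages."""
    # Alpaca-style rows are detected before chat-message conversion.
    if "instruction" in row and "output" in row:
        prompt = coerce_text(row["instruction"])
        extra = coerce_text(row.get("input") or "")
        if extra:
            prompt = f"{prompt}\n\n{extra}"
        return [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": coerce_text(row["output"])},
        ]
    if "messages" in row:
        messages = []
        for message in row["messages"]:
            role = str(message["role"]).lower()
            role = ROLE_ALIASES.get(role, role)
            messages.append({"role": role, "content": coerce_text(message["content"])})
        return messages
    raise ValueError(f"unrecognized row schema: keys={sorted(row)}")
```

Raising on an unknown schema, instead of returning an empty list, keeps a mapping bug from surfacing later as a generic zero-sample result.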
47
+
48
+ ## Zero-Data And Data-Erasure Guards
49
+
50
+ Data preparation should fail when no usable rows are produced. It should also distinguish "download only" from "tokenize and overwrite output".
51
+
52
+ Hardening rule:
53
+
54
+ - Abort if filtering retains zero samples.
55
+ - Abort if tokenization writes zero rows.
56
+ - Do not open tokenized output in write mode for asset-only setup.
57
+ - Use explicit flags for model-only or data-only phases.
58
+
59
+ ## Missing Shards Must Fail
60
+
61
+ Replacing missing teacher-logit shards with zero tensors makes the training loop look healthy while removing the KD signal.
62
+
63
+ Hardening rule:
64
+
65
+ - Missing shard means hard failure.
66
+ - Stale shard directories are cleaned before extraction.
67
+ - `_provenance.json` is required for KD.
68
+ - Shard count, sample count, max sequence length, temperature, top-k, and schema version are checked before training.
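A preflight check along these lines could look like the sketch below. Only the `_provenance.json` file name comes from the rule above; the JSON key spellings and the function names are assumptions made for illustration.

```python
import json
from pathlib import Path

# Fields named in the rule above; the JSON key spellings are assumptions.
CHECKED_FIELDS = ("schema_version", "shard_count", "sample_count",
                  "max_seq_len", "temperature", "top_k")


def load_provenance(shard_dir):
    """Read the provenance file, failing hard when it is absent."""
    path = Path(shard_dir) / "_provenance.json"
    if not path.is_file():
        raise FileNotFoundError(f"{path} is required for KD")
    return json.loads(path.read_text(encoding="utf-8"))


def check_provenance(provenance, expected, shard_names):
    """Raise on the first mismatch instead of training on a suspect shard set."""
    for field in CHECKED_FIELDS:
        if field not in provenance:
            raise ValueError(f"provenance is missing {field!r}")
        if provenance[field] != expected[field]:
            raise ValueError(
                f"{field}: shards were built with {provenance[field]!r}, "
                f"training expects {expected[field]!r}"
            )
    if len(shard_names) != provenance["shard_count"]:
        raise FileNotFoundError(
            f"found {len(shard_names)} shards, provenance lists {provenance['shard_count']}"
        )
```

Nothing in this path substitutes a default: a missing file, a missing field, and a missing shard all stop the run before any gradient is computed.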
69
+
70
+ ## Provenance Contracts
71
+
72
+ Path equality is weak provenance because paths change across machines. Data identity should come from content and model contracts.
73
+
74
+ Useful provenance fields:
75
+
76
+ - schema version
77
+ - dataset fingerprint or SHA-256
78
+ - sample count
79
+ - shard count
80
+ - max sequence length
81
+ - top-k or full-vocab mode
82
+ - temperature
83
+ - teacher model ID and revision
84
+ - student model ID and revision
85
+ - tokenizer sizes
86
+ - tokenizer fingerprints
87
+ - shard dtypes
88
+
89
+ Tokenizer fingerprints can drift across library versions. Vocab size and schema compatibility should remain hard gates; fingerprint drift can be a warning when stronger invariants still match.
90
+
91
+ ## Assistant-Only Loss Masks
92
+
93
+ Supervising prompt and chat-template tokens can teach formatting before substance. It can also make chat-mode behavior fragile.
94
+
95
+ Hardening rule:
96
+
97
+ - Tokenized rows must include `loss_mask`.
98
+ - Loss mask must be binary.
99
+ - Rows with zero assistant targets are rejected.
100
+ - User prompts, system prompts, separators, and padding are not targets.
101
+ - Assistant response tokens are the supervised region.
102
+
103
+ Prefix-stable mask derivation is useful when tokenizer-provided assistant masks are unavailable.
104
+
105
+ ## Gradient Accumulation Semantics
106
+
107
+ DeepSpeed and non-DeepSpeed paths need different step-accounting logic.
108
+
109
+ DeepSpeed accumulation is global across the full run, not local to each epoch. Epoch-end remainder branches should not create phantom optimizer steps.
110
+
111
+ Non-DeepSpeed accumulation needs an explicit final flush when a leftover accumulation window exists. That flush must rescale gradients so the update represents the mean over the remainder, not a shrunken `remainder / grad_accum` update.
112
+
113
+ Hardening rule:
114
+
115
+ - Advance `global_step` only after a real optimizer update.
116
+ - Align scheduler steps with real updates.
117
+ - Log flush steps.
118
+ - Include flush steps in training-loss CSVs.
119
+ - Prefer validation split sizes that align with effective batch size.
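The flush rescale for the non-DeepSpeed path can be checked with scalar arithmetic. The sketch assumes the common convention that each micro-batch gradient is pre-scaled by `1 / grad_accum`; the function names are hypothetical.

```python
def accumulated_gradient(micro_grads, grad_accum):
    """Sum of micro-batch gradients, each pre-scaled by 1 / grad_accum."""
    return sum(g / grad_accum for g in micro_grads)


def flush_scale(remainder, grad_accum):
    """Factor that turns a partial accumulation window into a mean over its micro-batches."""
    if not 0 < remainder <= grad_accum:
        raise ValueError("remainder must be in 1..grad_accum")
    return grad_accum / remainder


# Three leftover micro-batches in a window sized for eight.
leftover = [2.0, 4.0, 6.0]
partial = accumulated_gradient(leftover, grad_accum=8)
rescaled = partial * flush_scale(len(leftover), grad_accum=8)
print(partial, rescaled)
```

Without the rescale, the final update would be `remainder / grad_accum` of its intended size, which is the shrunken update described above.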
120
+
121
+ ## Checkpoint Semantics
122
+
123
+ `init_from_checkpoint` and `resume_from_checkpoint` are different operations.
124
+
125
+ - Initialization starts a new phase from an existing model.
126
+ - Resume continues an interrupted phase from training state.
127
+
128
+ Mixing the two can skip training, restart from the wrong model, or reuse stale state.
129
+
130
+ Hardening rule:
131
+
132
+ - Forbid simultaneous init and resume.
133
+ - Save trainer state and scheduler state.
134
+ - Search both `step_*` and `epoch_*` checkpoints for resume.
135
+ - Store batch offset for mid-epoch resume.
136
+ - Keep final model-loading checkpoints portable.
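The first and third rules can be sketched as below. The `step_*` and `epoch_*` name patterns come from the list above; the numeric suffix format, the rule for comparing an epoch checkpoint against a step checkpoint, and the function names are assumptions.

```python
import re

_CHECKPOINT_RE = re.compile(r"^(step|epoch)_(\d+)$")


def resolve_start(init_from_checkpoint=None, resume_from_checkpoint=None):
    """Decide how a run starts, forbidding simultaneous init and resume."""
    if init_from_checkpoint and resume_from_checkpoint:
        raise ValueError(
            "init_from_checkpoint and resume_from_checkpoint are mutually exclusive"
        )
    if resume_from_checkpoint:
        return ("resume", resume_from_checkpoint)
    if init_from_checkpoint:
        return ("init", init_from_checkpoint)
    return ("fresh", None)


def latest_checkpoint(names, steps_per_epoch):
    """Pick the most advanced checkpoint among step_* and epoch_* names."""
    best = None
    for name in names:
        match = _CHECKPOINT_RE.match(name)
        if not match:
            continue  # ignore unrelated files and directories
        kind, number = match.group(1), int(match.group(2))
        # Assumption: epoch_N marks the end of epoch N, so N * steps_per_epoch steps.
        progress = number if kind == "step" else number * steps_per_epoch
        if best is None or progress > best[0]:
            best = (progress, name)
    return None if best is None else best[1]
```

Converting both naming schemes to one progress value is what lets a mid-epoch `step_*` checkpoint win over an older `epoch_*` one.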
137
+
138
+ ## Compiler Portability
139
+
140
+ Compiled PyTorch modules can save weights with `_orig_mod.` prefixes if not unwrapped. Standard Transformers and vLLM loaders do not expect those keys.
141
+
142
+ Hardening rule:
143
+
144
+ - Keep `torch.compile` opt-in.
145
+ - Treat dynamic-shape recompile overhead as a throughput risk, not just a startup cost.
146
+ - Unwrap compiled modules before saving.
147
+ - Strip `_orig_mod.` only as a repair path, not as the normal release path.
148
+ - Verify saved checkpoints load through standard APIs.
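The repair path can be sketched on a plain dictionary of state-dict keys. The function name is hypothetical, and the sketch handles only a leading `_orig_mod.` prefix, which is the case a single top-level compiled wrapper produces; nested compiled submodules would need separate handling.

```python
PREFIX = "_orig_mod."


def strip_compile_prefix(state_dict):
    """Return a copy of state_dict with a leading `_orig_mod.` removed from keys."""
    repaired = {}
    for key, value in state_dict.items():
        new_key = key[len(PREFIX):] if key.startswith(PREFIX) else key
        if new_key in repaired:
            # Two source keys mapping to one target means the checkpoint is ambiguous.
            raise KeyError(f"key collision after stripping prefix: {new_key}")
        repaired[new_key] = value
    return repaired
```

As the rule says, this is for repairing an existing file. The normal path is to save from the unwrapped module so the prefix never reaches disk.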
149
+
150
+ ## Artifact Hygiene
151
+
152
+ Stale outputs are a real ML correctness problem. Old result JSONs, old plots, or old sample logs can make a failed run look successful.
153
+
154
+ Hardening rule:
155
+
156
+ - Clean evaluation output directories before a new run.
157
+ - Clean stale plots before rendering.
158
+ - Select result files by clear recency rules.
159
+ - Fail if expected task outputs are incomplete.
160
+ - Fail if a requested checkpoint is missing; do not fall back to older local weights.
161
+ - Include runtime versions in result summaries.
162
+
163
+ ## Environment Contracts
164
+
165
+ Notebook and cloud images often contain mixed binary packages. Import success for `torch` alone does not prove the stack is healthy.
166
+
167
+ Hardening rule:
168
+
169
+ - Treat `torch`, `torchvision`, and `torchaudio` as one binary compatibility family.
170
+ - Use staged dependency manifests instead of ad hoc installs.
171
+ - Keep vLLM dependencies separate from HF-only evaluation dependencies.
172
+ - Prefer clear preflight errors over late framework crashes.
173
+ - Print exception chains, not only the outer error.
174
+
175
+ ## Remote Code And Revisions
176
+
177
+ Model loading should be reproducible and explicit.
178
+
179
+ Hardening rule:
180
+
181
+ - Pin teacher, student, and tokenizer revisions when possible.
182
+ - Default remote-code trust to false.
183
+ - Provide an explicit override for models that need custom code.
184
+ - Explain remote-code failures clearly.
185
+
186
+ ## Safe Logging
187
+
188
+ Training logs should be rich enough for issue diagnosis without dumping config internals.
189
+
190
+ Hardening rule:
191
+
192
+ - Avoid logging authentication values or full config payloads.
193
+ - Disable traceback local-variable dumps in rich tracebacks.
194
+ - Strip ANSI sequences from file logs while keeping colored notebook output if desired.
195
+ - Use UTF-8 file logs and replacement-safe console output for generated model text.
196
+ - Log checkpoint save/upload intent, output size, duration, and destination path without sensitive values.
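The ANSI-stripping rule can be sketched with the standard `logging` module. The regular expression covers CSI sequences such as color codes, and the class name `PlainFileFormatter` is hypothetical rather than the project's own.

```python
import logging
import re

# Matches CSI escape sequences, e.g. "\x1b[32m" (color) and "\x1b[0m" (reset).
ANSI_RE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")


def strip_ansi(text):
    return ANSI_RE.sub("", text)


class PlainFileFormatter(logging.Formatter):
    """Formatter for file handlers: same layout, no terminal escape codes."""

    def format(self, record):
        return strip_ansi(super().format(record))
```

Attaching this formatter only to the file handler, opened with `encoding="utf-8"`, leaves a colored console or notebook handler unaffected.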
197
+
198
+ ## Public Release Rule
199
+
200
+ A project can be release-ready without every possible production safeguard. The line is crossed when:
201
+
202
+ - known silent corruption paths are removed,
203
+ - remaining tradeoffs are documented,
204
+ - artifacts are reproducible enough to audit,
205
+ - public docs focus on decisions, methods, and release artifacts,
206
+ - evaluation claims are tied to clear methodology.
207
+
208
+ For Quintus, the release surface should describe the engineering decisions and results.
docs/training_playbook.md ADDED
@@ -0,0 +1,199 @@
1
+ # Training Playbook
2
+
3
+ This page captures the practical training lessons behind Quintus. It focuses on the engineering decisions that made the final online-KD run stable, reproducible, and fast enough to complete on large single-GPU hardware.
4
+
5
+ ## Core Objective
6
+
7
+ The training objective combines assistant-token cross entropy with teacher-student KL divergence:
8
+
9
+ $$
10
+ \mathcal{L}_{\text{total}}
11
+ = \alpha \mathcal{L}_{\text{CE}}
12
+ + (1 - \alpha)\mathcal{L}_{\text{KD}}
13
+ $$
14
+
15
+ For the final Qwen3 run:
16
+
17
+ $$
18
+ \alpha = 0.3,\quad
19
+ T = 2.0,\quad
20
+ C_{\text{KD}} = 2048,\quad
21
+ S_{\max} = 4096
22
+ $$
23
+
24
+ In this codebase, $\alpha$ is the cross-entropy weight. Lower $\alpha$ gives the teacher distribution more influence. Higher $\alpha$ gives hard assistant targets more influence.
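For a single token position, the objective can be written out in plain Python as below. This is a sketch under assumptions, not the project's loss module: it uses forward KL from teacher to student, applies the temperature only inside the KD term, and leaves out any $T^2$ rescaling because this page does not state whether one is used.

```python
import math


def log_softmax(logits, temperature=1.0):
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    lse = peak + math.log(sum(math.exp(x - peak) for x in scaled))
    return [x - lse for x in scaled]


def cross_entropy(student_logits, target_id):
    """Hard-target loss on the assistant token, at temperature 1."""
    return -log_softmax(student_logits)[target_id]


def kd_divergence(teacher_logits, student_logits, temperature):
    """KL(teacher || student) over the full vocabulary at the given temperature."""
    t_logp = log_softmax(teacher_logits, temperature)
    s_logp = log_softmax(student_logits, temperature)
    return sum(math.exp(t) * (t - s) for t, s in zip(t_logp, s_logp))


def total_loss(student_logits, teacher_logits, target_id, alpha=0.3, temperature=2.0):
    """alpha weights cross entropy; (1 - alpha) weights the KD term."""
    ce = cross_entropy(student_logits, target_id)
    kd = kd_divergence(teacher_logits, student_logits, temperature)
    return alpha * ce + (1.0 - alpha) * kd
```

With the release values, 70% of the per-token loss weight sits on the teacher distribution and 30% on the hard target.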
25
+
26
+ ## Why Online KD Replaced Offline Top-K KD
27
+
28
+ The early pipeline precomputed only a small top-k slice of the teacher distribution. That made storage and training cheaper, but it created a hard information ceiling.
29
+
30
+ With a Qwen vocabulary around 151K tokens:
31
+
32
+ $$
33
+ \frac{k}{|V|}
34
+ = \frac{8}{151{,}665}
35
+ \approx 5.3 \times 10^{-5}
36
+ = 0.0053\%
37
+ $$
38
+
39
+ That sparse signal was enough to disturb student weights, but not enough to reliably transfer deeper reasoning behavior. Several development probes changed $\alpha$, epochs, and student initialization; the same ceiling remained.
40
+
41
+ The final online path removes that bottleneck. Teacher and student run together, and the KL term is computed from the live full-vocabulary teacher distribution.
42
+
43
+ ## Memory Shape To Respect
44
+
45
+ Full-vocabulary KD is dominated by logits:
46
+
47
+ $$
48
+ \text{student\_logits},\ \text{teacher\_logits}
49
+ \in \mathbb{R}^{B \times S \times |V|}
50
+ $$
51
+
52
+ At Qwen vocabulary scale, increasing micro-batch size by one can add many GiB of temporary memory pressure. Effective batch size is not the same as memory cost. Peak memory is mostly driven by micro-batch size, sequence length, vocabulary width, activation storage, and the backward pass.
53
+
54
+ Useful rule:
55
+
56
+ $$
57
+ B_{\text{eff}} = B_{\mu} \times A
58
+ $$
59
+
60
+ Keeping $B_{\mu}$ lower and $A$ higher is often safer than a large micro-batch with the same effective batch size.
61
+
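A back-of-the-envelope estimate makes the memory shape concrete. The sketch below assumes the vocabulary size quoted earlier (151,665), 2 bytes per element (bf16), and a fully materialized logits tensor; the helper names are illustrative, and real peak memory also includes activations, softmax temporaries, and gradients.

```python
def logits_bytes(micro_batch, seq_len, vocab_size, bytes_per_element):
    # Size of one [B, S, |V|] logits tensor in bytes.
    return micro_batch * seq_len * vocab_size * bytes_per_element

def to_gib(num_bytes):
    return num_bytes / 2**30

def effective_batch(micro_batch, accumulation_steps):
    # B_eff = B_mu * A; only B_mu drives the logits footprint.
    return micro_batch * accumulation_steps

VOCAB_SIZE = 151_665  # assumed from the ratio quoted in these notes
SEQ_LEN = 4096

for micro_batch, accumulation in [(1, 8), (4, 2), (8, 1)]:
    size = to_gib(logits_bytes(micro_batch, SEQ_LEN, VOCAB_SIZE, 2))
    print(
        f"B_mu={micro_batch} A={accumulation} "
        f"B_eff={effective_batch(micro_batch, accumulation)} "
        f"logits per model ~ {size:.2f} GiB (bf16)"
    )
```

Under these assumptions each extra packed row adds roughly 1.16 GiB of bf16 logits per model, before any float32 upcast or softmax temporaries, which is why all three shapes share $B_{\text{eff}} = 8$ but differ sharply in footprint.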
62
+ ## Token Chunking
63
+
64
+ A naive full-vocabulary KL implementation materializes too much temporary state. Quintus computes KD over token chunks:
65
+
66
+ $$
67
+ C_{\text{KD}} = 2048
68
+ $$
69
+
70
+ Larger chunks reduce loop overhead but increase temporary memory. Smaller chunks save memory but can add kernel-launch and Python overhead. The final value is a B200-oriented balance for the 8B -> 1.7B workload.
71
+
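The chunking idea can be sketched without any tensor library: walk the token positions in blocks of `chunk_size`, accumulate the KL sum, and normalize once at the end. This is a simplified stand-in for the real implementation; `chunked_kd_loss` and its arguments are hypothetical names, and a production version would operate on GPU tensors rather than Python lists.

```python
import math

def log_softmax(logits, temperature):
    # Stable log-softmax over one token's vocabulary logits.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in scaled))
    return [x - log_norm for x in scaled]

def token_kl(teacher_logits, student_logits, temperature):
    # KL(teacher || student) for one token position.
    log_p = log_softmax(teacher_logits, temperature)
    log_q = log_softmax(student_logits, temperature)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))

def chunked_kd_loss(teacher_rows, student_rows, loss_mask, chunk_size=2048, temperature=2.0):
    # Only one chunk of per-token temporaries is live at a time.
    total, count = 0.0, 0
    for start in range(0, len(teacher_rows), chunk_size):
        end = start + chunk_size
        chunk = zip(teacher_rows[start:end], student_rows[start:end], loss_mask[start:end])
        for teacher, student, keep in chunk:
            if keep:  # skip prompt, separator, and padding positions
                total += token_kl(teacher, student, temperature)
                count += 1
    return total / max(count, 1)
```

The chunk size changes only how much temporary state exists at once, not the value of the loss: any `chunk_size` gives the same mean over unmasked tokens.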
72
+ ## Sequence Packing
73
+
74
+ Sequence packing was the largest throughput win in development probes.
75
+
76
+ The packing strategy:
77
+
78
+ - Sort samples by length descending.
79
+ - Pack samples with deterministic first-fit decreasing binning.
80
+ - Insert EOS separators between samples.
81
+ - Set separator `loss_mask = 0`.
82
+ - Optionally mask the first token after each separator.
83
+ - Build `attention_mask` from true packed length, not from token identity.
84
+
85
+ The attention-mask detail matters because Qwen tokenizers can share EOS-like IDs with padding behavior. Deriving attention from `input_ids != pad_token_id` can accidentally mask real EOS separators inside packed rows.
86
+
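The packing steps and the attention-mask caveat can be illustrated with a small first-fit-decreasing sketch. It is not the repository's `pack_sequences`; the function name and sample layout are assumptions for illustration, the optional masking of the first token after each separator is omitted, and cross-sample attention isolation is not handled.

```python
def pack_first_fit_decreasing(samples, pack_length, eos_token_id, pad_token_id):
    # Longest samples first; sorted() is stable, so the result is deterministic.
    ordered = sorted(samples, key=lambda s: len(s["input_ids"]), reverse=True)
    bins = []
    for sample in ordered:
        ids = sample["input_ids"][:pack_length]
        mask = sample["loss_mask"][:pack_length]
        for current in bins:
            # +1 accounts for the EOS separator placed before the new sample.
            if len(current["input_ids"]) + 1 + len(ids) <= pack_length:
                current["input_ids"] += [eos_token_id] + ids
                current["loss_mask"] += [0] + mask  # separator never contributes to loss
                break
        else:
            bins.append({"input_ids": list(ids), "loss_mask": list(mask)})

    packed = []
    for current in bins:
        true_len = len(current["input_ids"])
        pad = pack_length - true_len
        packed.append({
            "input_ids": current["input_ids"] + [pad_token_id] * pad,
            "loss_mask": current["loss_mask"] + [0] * pad,
            # Built from the true packed length, not from token identity.
            "attention_mask": [1] * true_len + [0] * pad,
        })
    return packed
```

When `eos_token_id` and `pad_token_id` are the same ID, a mask derived from `input_ids != pad_token_id` would zero out the separators inside a row, while the length-based mask above keeps them visible.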
87
+ Packing probes showed an unpacked B200 online-KD baseline around the low-20K tokens/sec range. Packed training reached roughly the mid-40K tokens/sec range after warmup. The final Qwen3 profile uses the same design principle with a conservative 8B -> 1.7B batch shape.
88
+
89
+ ## B200-Oriented Final Shape
90
+
91
+ The Qwen3 config is intentionally conservative:
92
+
93
+ $$
94
+ B_{\mu}=4,\quad
95
+ A=2,\quad
96
+ B_{\text{eff}}=8,\quad
97
+ L_{\text{pack}}=4096
98
+ $$
99
+
100
+ Runtime choices:
101
+
102
+ - `gradient_checkpointing = false`
103
+ - `compile_model = false`
104
+ - `fused_adamw = true`
105
+ - `sequence_packing.enabled = true`
106
+ - FlashAttention-2 when available
107
+ - Liger kernels for compatible Qwen-family operators
108
+
109
+ The main reason is the 8B teacher plus 1.7B student online-KD footprint. A smaller teacher/student pair can use larger micro-batches, but the release workload reserves more headroom.
110
+
111
+ ## Kernel Choices
112
+
113
+ FlashAttention-2 is the preferred stable attention path when available.
114
+
115
+ Liger kernels are useful for Qwen-family training, but KD places an important constraint on fusion:
116
+
117
+ - Safe to fuse: RMSNorm, RoPE, SwiGLU.
118
+ - Avoid for KD: fused linear cross entropy that hides raw student logits.
119
+
120
+ The KD loss needs raw student logits to compute teacher-student KL. Any optimization that bypasses logits entirely can break the objective.
121
+
122
+ ## Why `torch.compile` Stayed Off
123
+
124
+ `torch.compile` can be useful for some SFT paths, but it was not the production choice for final KD.
125
+
126
+ Observed risks:
127
+
128
+ - Large Inductor memory overhead.
129
+ - Warmup cost on short-lived cloud instances.
130
+ - Dynamic-shape graph breaks from variable sequence lengths.
131
+ - Recompile overhead that reduced cumulative throughput in probes.
132
+ - `_orig_mod.` prefixes in saved checkpoints if compiled modules are not unwrapped before saving.
133
+ - Limited benefit after FlashAttention and Liger already fuse the major kernels.
134
+
135
+ For this workload, stable eager execution with targeted kernels was more predictable than compiler-driven fusion.
136
+
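If a compiled module does end up in a checkpoint, the `_orig_mod.` prefix can be removed from the state-dict keys before loading. The helper below is a hypothetical sketch on a plain dictionary; it assumes the prefix string never appears legitimately inside a parameter name.

```python
def strip_compile_prefix(state_dict, marker="_orig_mod."):
    # Remove the wrapper marker wherever torch.compile inserted it,
    # including keys of compiled submodules nested inside an eager parent.
    cleaned = {}
    for key, value in state_dict.items():
        new_key = key.replace(marker, "")
        if new_key in cleaned:
            raise ValueError(f"key collision after stripping prefix: {new_key}")
        cleaned[new_key] = value
    return cleaned
```

Unwrapping the compiled module before saving avoids the problem at the source; a cleanup like this is only a fallback for checkpoints that were already written with the prefix.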
137
+ ## DataLoader And Cloud Stability
138
+
139
+ Large worker counts can improve throughput on local systems, but notebook and cloud environments can deadlock through multiprocessing queues, IPC limits, or shared-memory pressure.
140
+
141
+ Practical policy:
142
+
143
+ - Start with conservative worker and prefetch settings.
144
+ - Treat the DataLoader as a prime suspect for any silent training hang, even when GPU utilization remains high.
145
+ - For some cloud notebook runs, `dataloader_workers = 0` was the most stable choice.
146
+ - For the release config, `dataloader_workers = 8` and `prefetch_factor = 2` are a controlled default, not a universal rule.
147
+
148
+ ## Checkpointing And Resume
149
+
150
+ Cloud GPU instances can be preempted and notebook sessions can disappear without warning. The training loop therefore treats checkpointing as a core training feature, not an afterthought.
151
+
152
+ Important design points:
153
+
154
+ - `best` is selected from validation loss where available.
155
+ - `last` is saved for final-state inspection.
156
+ - Step checkpoints can resume mid-epoch.
157
+ - Scheduler state is saved.
158
+ - Optimizer state may be intentionally omitted for very large runs to avoid massive checkpoint overhead.
159
+ - Resume semantics distinguish initialization from a completed checkpoint and continuation from an interrupted checkpoint.
160
+
161
+ This avoids the common trap where `resume_from_checkpoint` silently starts from the wrong phase or stale state.
162
+
163
+ ## Provenance Rules
164
+
165
+ The pipeline is strict about artifact compatibility:
166
+
167
+ - Tokenizer vocabulary sizes must match the model contract.
168
+ - Teacher-logit metadata must match expected temperature, sample count, max sequence length, and tokenizer/model identity.
169
+ - Dataset fingerprints are preferred over path equality because paths are machine-local.
170
+ - Tokenizer fingerprints can drift across library versions, so hard checks should focus on vocab-size and schema invariants.
171
+
172
+ The principle is simple: train only when artifacts prove they belong together.
173
+
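A content-based fingerprint and a metadata comparison are enough to sketch these rules. The functions below are illustrative assumptions, not the pipeline's actual checks: the fingerprint hashes canonical JSON of each record, and the metadata keys in the usage note are made up for the example.

```python
import hashlib
import json

def dataset_fingerprint(records):
    # Hash record contents in order; independent of file path and key order.
    digest = hashlib.sha256()
    for record in records:
        line = json.dumps(record, sort_keys=True, separators=(",", ":"), ensure_ascii=False)
        digest.update(line.encode("utf-8"))
        digest.update(b"\n")
    return digest.hexdigest()

def find_mismatches(metadata, expected):
    # Return {key: (found, wanted)} for every expected field that disagrees.
    return {
        key: (metadata.get(key), wanted)
        for key, wanted in expected.items()
        if metadata.get(key) != wanted
    }
```

A run would refuse to start when `find_mismatches(artifact_meta, {"temperature": 2.0, "max_seq_len": 4096})` returns a non-empty dictionary, and the same dataset copied to another machine still yields the same fingerprint.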
174
+ ## Dataset Sampling
175
+
176
+ Taking the first N valid streamed examples can bias a run if the upstream dataset is ordered by source, task, difficulty, or language. Later configs added stream shuffling before selection.
177
+
178
+ The config uses a non-default seed:
179
+
180
+ ```text
181
+ stream_shuffle_seed = 25
182
+ split_seed = 25
183
+ ```
184
+
185
+ The number is intentionally explicit. Reproducibility needs stable seeds; it does not require the overused value `42`.
186
+
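Seeded stream shuffling before selection can be sketched with a bounded shuffle buffer, similar in spirit to the buffer-based shuffle that streaming dataset libraries offer. The function name and `buffer_size` default are assumptions for illustration; a buffer only mixes items that are close together in the stream, so a heavily ordered source may need a large buffer.

```python
import random

def shuffled_take(stream, n, seed=25, buffer_size=1000):
    # Draw n items from a stream in a seeded pseudo-random order
    # without loading the whole stream into memory.
    rng = random.Random(seed)
    buffer, taken = [], []
    for item in stream:
        if len(taken) >= n:
            break
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        # Buffer is full: emit a random resident and replace it with the new item.
        slot = rng.randrange(buffer_size)
        taken.append(buffer[slot])
        buffer[slot] = item
    # Stream ended (or n was reached): drain the remaining buffer in random order.
    rng.shuffle(buffer)
    while buffer and len(taken) < n:
        taken.append(buffer.pop())
    return taken
```

Because the generator is seeded explicitly, two runs with the same stream, seed, and buffer size select the same examples in the same order.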
187
+ ## Practical Watchpoints
188
+
189
+ During a run, these signals matter more than a single loss number:
190
+
191
+ - Loss stays finite from the first logging window.
192
+ - CE and KD move in plausible ranges.
193
+ - Rolling throughput remains stable after warmup.
194
+ - GPU memory is high but not near an unpredictable OOM edge.
195
+ - Validation loss is computed on the intended holdout.
196
+ - Saved checkpoints load in standard Transformers and vLLM paths.
197
+ - Downstream benchmark results agree with the training story.
198
+
199
+ Held-out KD loss is useful, but it is not the release gate. Standardized benchmarks and qualitative checks must decide whether the checkpoint improved the target behavior.
docs/weight_audit.md ADDED
@@ -0,0 +1,66 @@
1
+ # Weight Audit
2
+
3
+ The `weight_audit/` directory contains a structural audit script and a generated report comparing the final distilled checkpoint against `Qwen/Qwen3-1.7B-Base`.
4
+
5
+ The audit is not a behavioral benchmark. It answers a narrower question: is the checkpoint structurally intact, same-architecture, and plausibly modified by training without signs of collapse?
6
+
7
+ ## What Was Checked
8
+
9
+ The audit verifies:
10
+
11
+ - Base and distilled checkpoint commits.
12
+ - Architecture and config compatibility.
13
+ - Parameter counts and tensor keys.
14
+ - Weight tying between embeddings and LM head.
15
+ - Per-tensor statistics.
16
+ - Layer-type aggregate statistics.
17
+ - Isotropy of 2D weight matrices.
18
+ - Base-vs-distilled divergence for all shared tensors.
19
+ - Sparsity, dead rows, low cosine similarity, and low SNR warnings.
20
+
21
+ ## Headline Result
22
+
23
+ The final report shows:
24
+
25
+ ```text
26
+ shared tensors : 311
27
+ tensors changed vs base : 277 / 311
28
+ cosine similarity : mean = 0.999991 | median = 0.999992
29
+ relative error : mean = 0.001093 | median = 0.001293
30
+ SNR dB : mean = 81.86 | median = 47.79
31
+ high-sparsity layers (>10%) : 0
32
+ heavy-tail layers (|kurt_d|>5.0) : 0
33
+ dead-row layers : 0
34
+ low-cos layers (<0.95) : 0
35
+ low-SNR layers (<20 dB) : 0
36
+ ```
37
+
38
+ ## Interpretation
39
+
40
+ This is a healthy pattern for light-touch distillation:
41
+
42
+ - The architecture is unchanged.
43
+ - Most tensors changed.
44
+ - The changes are small relative to the original base weights.
45
+ - Projection matrices, embeddings, and MLP/attention layers moved.
46
+ - Some normalization tensors remained unchanged or changed only slightly.
47
+ - No layer shows obvious structural collapse.
48
+
49
+ The unchanged tensors are primarily normalization-related weights. That is not concerning by itself. It suggests the main semantic projection weights absorbed the training signal while basic scaling structure stayed stable.
50
+
51
+ ## Why Isotropy Matters
52
+
53
+ The report's global isotropy score is close to zero. Near-zero average pairwise row cosine means the weight rows are not collapsing into one shared direction.
54
+
55
+ This is useful as a sanity check after KD. A collapsed model can sometimes load and produce text, but its internal geometry becomes degenerate. The audit does not show that pattern.
56
+
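The "average pairwise row cosine" reading of isotropy can be sketched directly. The audit script's exact definition is not shown in these notes, so the function below is an assumption based on that phrase, written in plain Python for small matrices; a real audit would vectorize this and likely subsample rows.

```python
import math

def mean_pairwise_row_cosine(matrix):
    # Normalize each non-zero row, then average cosine over all distinct row pairs.
    unit_rows = []
    for row in matrix:
        norm = math.sqrt(sum(x * x for x in row))
        if norm > 0.0:
            unit_rows.append([x / norm for x in row])
    n = len(unit_rows)
    if n < 2:
        return 0.0
    total = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            total += sum(a * b for a, b in zip(unit_rows[i], unit_rows[j]))
    return total / (n * (n - 1) / 2)
```

Rows that point in unrelated directions average out near zero, while rows that have collapsed onto one shared direction push the score toward 1.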
57
+ ## What The Audit Does Not Prove
58
+
59
+ The weight audit does not prove that answers are correct, safe, or well calibrated. It should be read alongside:
60
+
61
+ - Standard benchmarks.
62
+ - Open-ended qualitative evaluations.
63
+ - SFT evaluation outputs.
64
+ - Manual regression prompts.
65
+
66
+ The audit says the checkpoint is structurally ready for downstream evaluation and release packaging.
requirements-eval.txt ADDED
@@ -0,0 +1,6 @@
1
+ -r requirements.txt
2
+
3
+ # Consolidated benchmark runner dependencies.
4
+ evalplus>=0.3
5
+ lm-eval>=0.4.8
6
+ vllm>=0.8; platform_system == "Linux"
requirements-train.txt ADDED
@@ -0,0 +1,12 @@
1
+ -r requirements.txt
2
+
3
+ # Optional but recommended for the documented high-throughput training path.
4
+ # These packages are Linux/CUDA-oriented and may require matching compiler,
5
+ # CUDA, and PyTorch builds.
6
+ liger-kernel>=0.5
7
+ flash-attn>=2.7; platform_system == "Linux"
8
+ deepspeed>=0.16; platform_system == "Linux"
9
+
10
+ # Optional SFT/QLoRA paths in sft/train_sft.py.
11
+ peft>=0.14
12
+ bitsandbytes>=0.45; platform_system == "Linux"
requirements.txt ADDED
@@ -0,0 +1,13 @@
1
+ # Core dependencies for downloading, training entry points, local chat, and
2
+ # lightweight repository utilities. Install CUDA-specific PyTorch wheels from
3
+ # the official PyTorch index when your environment requires a specific CUDA
4
+ # build.
5
+ torch>=2.6
6
+ transformers>=4.52
7
+ datasets>=2.19
8
+ huggingface-hub>=0.31
9
+ omegaconf>=2.3
10
+ PyYAML>=6.0
11
+ safetensors>=0.4
12
+ accelerate>=1.0
13
+ tqdm>=4.66
sft/chat.py ADDED
@@ -0,0 +1,89 @@
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
3
+ import sys
4
+ import argparse
5
+
6
+ def main():
7
+ parser = argparse.ArgumentParser(description="Quintus Interactive Chat")
8
+ parser.add_argument("--model_path", type=str, default="iamrahulreddy/Quintus", help="Model repo ID or local weights directory")
9
+ parser.add_argument("--trust_remote_code", action="store_true", help="Allow custom code from the model repository.")
10
+ args = parser.parse_args()
11
+
12
+ model_path = args.model_path
13
+ print(f"Loading Quintus from {model_path}...")
14
+ try:
15
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=args.trust_remote_code)
16
+ model = AutoModelForCausalLM.from_pretrained(
17
+ model_path,
18
+ device_map="auto",
19
+ dtype=torch.float16,
20
+ trust_remote_code=args.trust_remote_code
21
+ )
22
+ except Exception as e:
23
+ print(f"Error loading model: {e}")
24
+ print(f"Ensure '{model_path}' exists and contains the model weights.")
25
+ sys.exit(1)
26
+
27
+ # Defining stopping criteria
28
+ stop_tokens = ["<|endoftext|>", "<|im_end|>"]
29
+ eos_token_ids = [tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else []
30
+ for token in stop_tokens:
31
+ t_id = tokenizer.convert_tokens_to_ids(token)
32
+ if t_id is not None and t_id not in eos_token_ids:
33
+ eos_token_ids.append(t_id)
34
+
35
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
36
+
37
+ conversation_history = [
38
+ {"role": "system", "content": "You are Quintus, a highly capable AI assistant created by Muskula Rahul. You are helpful, precise, and logically sound."}
39
+ ]
40
+
41
+ print()
42
+ print("Quintus Chat (type 'quit' to exit)")
43
+ print()
44
+
45
+ while True:
46
+ try:
47
+ user_input = input("You: ").strip()
48
+ if user_input.lower() in ["quit", "exit"]:
49
+ print("\nGoodbye!")
50
+ break
51
+ if not user_input:
52
+ continue
53
+
54
+ conversation_history.append({"role": "user", "content": user_input})
55
+
56
+ prompt = tokenizer.apply_chat_template(
57
+ conversation_history,
58
+ tokenize=False,
59
+ add_generation_prompt=True
60
+ )
61
+
62
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
63
+
64
+ print("Quintus: ", end="", flush=True)
65
+
66
+ with torch.no_grad():
67
+ outputs = model.generate(
68
+ **inputs,
69
+ max_new_tokens=512,
70
+ temperature=0.7,
71
+ top_p=0.9,
72
+ do_sample=True,
73
+ streamer=streamer,
74
+ pad_token_id=tokenizer.eos_token_id,
75
+ eos_token_id=eos_token_ids
76
+ )
77
+
78
+ # Extract response for history
79
+ generated_ids = outputs[0][inputs.input_ids.shape[-1]:]
80
+ assistant_response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
81
+ conversation_history.append({"role": "assistant", "content": assistant_response})
82
+ print()
83
+
84
+ except KeyboardInterrupt:
85
+ print("\n\nGoodbye!")
86
+ break
87
+
88
+ if __name__ == "__main__":
89
+ main()
sft/evaluate.py ADDED
@@ -0,0 +1,267 @@
1
+ # Automated benchmark runner: EvalPlus (HumanEval, MBPP) plus lm-eval tasks
2
+ # (GSM8K, WinoGrande, ARC-Challenge, BoolQ, PIQA) on the vLLM backend; EvalPlus runs in greedy mode.
3
+
4
+ import os
5
+ import sys
6
+ import subprocess
7
+ import time
8
+ import json
9
+ import re
10
+ from pathlib import Path
11
+ from datetime import datetime
12
+ from huggingface_hub import snapshot_download
13
+
14
+ MODELS = [
15
+ {
16
+ "name": "Quintus-1.7B",
17
+ "id": "iamrahulreddy/Quintus",
18
+ "is_local": False
19
+ },
20
+ {
21
+ "name": "Qwen3-1.7B-Instruct",
22
+ "id": "Qwen/Qwen3-1.7B",
23
+ "is_local": False
24
+ },
25
+ {
26
+ "name": "Qwen3-1.7B-Base",
27
+ "id": "Qwen/Qwen3-1.7B-Base",
28
+ "is_local": False
29
+ }
30
+ ]
31
+
32
+ DATASETS = [
33
+ "humaneval", "mbpp", # EvalPlus benchmarks
34
+ "gsm8k", "winogrande", # lm-eval fast benchmarks
35
+ "arc_challenge", "boolq", "piqa"
36
+ ]
37
+ EVALPLUS_DATASETS = {"humaneval", "mbpp"}
38
+
39
+ LM_EVAL_SHOTS = {
40
+ "gsm8k": "10",
41
+ "winogrande": "5",
42
+ "arc_challenge": "25",
43
+ "boolq": "0",
44
+ "piqa": "0"
45
+ }
46
+
47
+ HF_TOKEN = os.environ.get("HF_TOKEN")
48
+ TRUST_REMOTE_CODE = os.environ.get("QUINTUS_TRUST_REMOTE_CODE", "").strip().lower() in {"1", "true", "yes", "on"}
49
+
50
+ def extract_lm_eval_score(results_dir: Path, task: str) -> str:
51
+ """Finds and extracts the primary score from JSON files outputted by lm-evaluation-harness."""
52
+ for json_path in sorted(results_dir.rglob("*.json"), reverse=True):
53
+ try:
54
+ with open(json_path, encoding="utf-8") as fh:
55
+ data = json.load(fh)
56
+ task_results = data.get("results", {})
57
+ for candidate in (task, f"leaderboard_{task}"):
58
+ if candidate in task_results:
59
+ task_data = task_results[candidate]
60
+ # Try common metric names
61
+ for metric in ["acc,none", "acc_norm,none", "exact_match,strict-match", "exact_match,none"]:
62
+ if metric in task_data:
63
+ return f"{task_data[metric]*100:.1f}"
64
+ except Exception:
65
+ continue
66
+ return "N/A"
67
+
68
+ def is_noise(line: str) -> bool:
69
+ l = line.strip()
70
+ if not l:
71
+ return False
72
+ # Progress bar indicators & block characters
73
+ if any(c in l for c in ["█", "━", "╸", "•", "━━━━━━━━"]):
74
+ return True
75
+ # vLLM, ray, flash_attn, huggingface setup/warnings logs
76
+ noise_keywords = [
77
+ "INFO ", "WARNING ", "DEBUG ", "ERROR ", "(EngineCore",
78
+ "Loading safetensors", "Capturing CUDA graphs",
79
+ "Codegen:", "Downloading dataset", "downloading dataset",
80
+ "Initializing a decoder", "Unknown vLLM environment",
81
+ "world_size=", "Using V2 Model Runner", "Model loading took",
82
+ "Using FLASH_ATTN", "Using FlashAttention", "Kernel JIT monitor",
83
+ "autotuner.py", "autotuning", "Autotuning", "loading weights",
84
+ "Loading weights", "Failed to get device capability", "Sanitized code outputs",
85
+ "Raw outputs will be saved", "init engine", "Dynamo bytecode",
86
+ "Directly load the compiled graph", "Directly load AOT compilation", "torch.compile took"
87
+ ]
88
+ if any(k.lower() in l.lower() for k in noise_keywords):
89
+ return True
90
+ # TQDM lines (e.g. 100%|... [00:17<00:00, 9.45it/s])
91
+ if "%|" in l and ("it/s" in l or "s/it" in l):
92
+ return True
93
+ return False
94
+
95
+ def main():
96
+ print("=" * 80)
97
+ print(" BENCHMARK RUNNER (EVALPLUS & LM-EVAL)")
98
+ print("=" * 80)
99
+ print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
100
+ print(f"Models to evaluate: {[m['name'] for m in MODELS]}")
101
+ print(f"Datasets: {DATASETS}")
102
+ print("=" * 80)
103
+
104
+ # Set optional HF token and runtime configuration.
105
+ if HF_TOKEN:
106
+ os.environ["HF_TOKEN"] = HF_TOKEN
107
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
108
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
109
+ os.environ["VLLM_MAX_MODEL_LEN"] = "4096"
110
+
111
+ # Step 1: Pre-download and prepare model caches
112
+ print("\n--- STAGE 1: WARMING UP MODEL WEIGHTS CACHE ---")
113
+
114
+ # Cache all models
115
+ for model in MODELS:
116
+ if model["is_local"]:
117
+ continue
118
+ print(f"\n[DOWNLOADING] Fetching cache for {model['name']} ({model['id']})...")
119
+ try:
120
+ snapshot_download(
121
+ repo_id=model["id"],
122
+ token=HF_TOKEN or None
123
+ )
124
+ print(f"[DOWNLOAD SUCCESS] {model['name']} is cached and ready.")
125
+ except Exception as e:
126
+ print(f"[DOWNLOAD WARNING] Could not pre-download model {model['name']} via snapshot_download: {e}")
127
+ print("The evaluation run will attempt to download it directly during execution.")
128
+
129
+ print("\n--- STAGE 2: SEQUENTIAL EVALPLUS EVALUATION ---")
130
+ results = []
131
+
132
+ # Run evaluations sequentially
133
+ for model in MODELS:
134
+ # Resolve path
135
+ model_path = str(Path(model["id"]).resolve()) if model["is_local"] else model["id"]
136
+
137
+ for dataset in DATASETS:
138
+ print(f"\n[STARTING] Evaluating {model['name']} on {dataset}...")
139
+ print("-" * 60)
140
+
141
+ if dataset in EVALPLUS_DATASETS:
142
+ cmd = [
143
+ sys.executable, "-m", "evalplus.evaluate",
144
+ "--model", model_path,
145
+ "--dataset", dataset,
146
+ "--backend", "vllm",
147
+ "--greedy"
148
+ ]
149
+ else:
150
+ shots = LM_EVAL_SHOTS.get(dataset, "0")
151
+ out_dir = Path("eval_results") / model["name"] / dataset
152
+ out_dir.mkdir(parents=True, exist_ok=True)
153
+
154
+ model_args = (
155
+ f"pretrained={model_path},dtype=bfloat16,"
156
+ f"trust_remote_code={str(TRUST_REMOTE_CODE).lower()},"
157
+ "gpu_memory_utilization=0.9,max_model_len=4096"
158
+ )
159
+ cmd = [
160
+ sys.executable, "-m", "lm_eval",
161
+ "--model", "vllm",
162
+ "--model_args", model_args,
163
+ "--tasks", dataset,
164
+ "--num_fewshot", shots,
165
+ "--batch_size", "auto",
166
+ "--output_path", str(out_dir),
167
+ "--log_samples"
168
+ ]
169
+ if dataset == "gsm8k":
170
+ cmd.extend(["--gen_kwargs", "max_gen_toks=512"])
171
+
172
+ print(f"Running command: {' '.join(cmd)}")
173
+ start_time = time.time()
174
+
175
+ try:
176
+ # Run the command and stream output
177
+ process = subprocess.Popen(
178
+ cmd,
179
+ stdout=subprocess.PIPE,
180
+ stderr=subprocess.STDOUT,
181
+ text=True,
182
+ bufsize=1
183
+ )
184
+
185
+ # Stream and capture output (filtering out vLLM and progress bar noise)
186
+ stdout_text = ""
187
+ for line in process.stdout:
188
+ stdout_text += line
189
+ if not is_noise(line):
190
+ print(line, end="")
191
+
192
+ process.wait()
193
+ duration = time.time() - start_time
194
+ time.sleep(5) # Let OS/driver fully reclaim GPU VRAM before starting next subprocess
195
+
196
+ score_str = "N/A"
197
+ if process.returncode == 0:
198
+ print(f"[SUCCESS] Completed {model['name']} on {dataset} in {duration:.1f} seconds.")
199
+
200
+ # Parse scores
201
+ if dataset in EVALPLUS_DATASETS:
202
+ # Find all pass@1 scores
203
+ matches = re.findall(r"pass@1:\s+([0-9.]+)", stdout_text)
204
+ if len(matches) >= 2:
205
+ val0 = float(matches[0])
206
+ val1 = float(matches[1])
207
+ if val0 <= 1.0: val0 *= 100
208
+ if val1 <= 1.0: val1 *= 100
209
+ score_str = f"Base: {val0:.1f} | Plus: {val1:.1f}"
210
+ elif len(matches) == 1:
211
+ val0 = float(matches[0])
212
+ if val0 <= 1.0: val0 *= 100
213
+ score_str = f"Base: {val0:.1f}"
214
+ else:
215
+ score_str = extract_lm_eval_score(out_dir, dataset)
216
+
217
+ results.append({
218
+ "model": model["name"],
219
+ "dataset": dataset,
220
+ "status": "Success",
221
+ "score": score_str,
222
+ "duration": f"{duration/60:.1f} min"
223
+ })
224
+ else:
225
+ print(f"[ERROR] command failed with exit code {process.returncode}")
226
+ results.append({
227
+ "model": model["name"],
228
+ "dataset": dataset,
229
+ "status": f"Failed ({process.returncode})",
230
+ "score": "ERROR",
231
+ "duration": f"{duration/60:.1f} min"
232
+ })
233
+
234
+ except Exception as e:
235
+ duration = time.time() - start_time
236
+ print(f"[ERROR] Failed to run benchmark: {e}")
237
+ results.append({
238
+ "model": model["name"],
239
+ "dataset": dataset,
240
+ "status": "Error",
241
+ "score": "ERROR",
242
+ "duration": f"{duration/60:.1f} min"
243
+ })
244
+ print("-" * 60)
245
+
246
+ # Print and save summary report
247
+ report_lines = []
248
+ report_lines.append("\n" + "=" * 100)
249
+ report_lines.append(" BENCHMARK RUN SUMMARY")
250
+ report_lines.append("=" * 100)
251
+ report_lines.append(f"| {'Model':<30} | {'Dataset':<15} | {'Score':<25} | {'Status':<10} | {'Time':<8} |")
252
+ report_lines.append(f"|{'-'*32}|{'-'*17}|{'-'*27}|{'-'*12}|{'-'*10}|")
253
+ for r in results:
254
+ report_lines.append(f"| {r['model']:<30} | {r['dataset']:<15} | {r['score']:<25} | {r['status']:<10} | {r['duration']:<8} |")
255
+ report_lines.append("=" * 100)
256
+
257
+ report_text = "\n".join(report_lines)
258
+ print(report_text)
259
+ print("\nNote: Results are saved in the default EvalPlus directory and eval_results/.")
260
+
261
+ # Save to file
262
+ with open("qwen_quintus_scores.txt", "w", encoding="utf-8") as f:
263
+ f.write(report_text + "\n")
264
+ print("\n[SUCCESS] Final score report saved to 'qwen_quintus_scores.txt'")
265
+
266
+ if __name__ == "__main__":
267
+ main()
sft/train_sft.py ADDED
@@ -0,0 +1,690 @@
1
+ # SFT Training and Downstream Evaluation Pipeline
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import gc
7
+ import json
8
+ import os
9
+ import re
10
+ import sys
11
+ import time
12
+ from datetime import datetime
13
+ from pathlib import Path
14
+ import yaml
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch.utils.data import DataLoader, Dataset
19
+ from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup
20
+
21
+ # Load Configuration
22
+ def load_config() -> dict:
23
+ cfg_path = Path(__file__).resolve().parent / "config.yaml"
24
+ if not cfg_path.exists():
25
+ return {}
26
+ with open(cfg_path, "r", encoding="utf-8") as f:
27
+ return yaml.safe_load(f) or {}
28
+
29
+ cfg = load_config()
30
+
31
+ # PROMPTS (50 PROMPTS)
32
+ EASY_PROMPTS = [
33
+ "What is the capital of Japan, and what is it known for?",
34
+ "What does the term 'CPU' stand for, and what is its role in a computer?",
35
+ "Name three mammals that live primarily in water.",
36
+ "What is the difference between a virus and a bacterium?",
37
+ "Convert 72 degrees Fahrenheit to Celsius.",
38
+ "What is the purpose of a hash function?",
39
+ "What does HTTP stand for and what is it used for?",
40
+ "In which continent is the Amazon rainforest located?",
41
+ "What is the difference between RAM and ROM?",
42
+ "Name two programming languages commonly used for data science.",
43
+ "What is the function of the mitochondria in a cell?",
44
+ "What is a palindrome? Give two examples.",
45
+ "What is the difference between a compiler and an interpreter?",
46
+ "What unit is used to measure electrical resistance?",
47
+ "Name the four blood types in the ABO system.",
48
+ "What is the primary purpose of DNS in networking?",
49
+ "What does it mean for a function to be 'pure' in programming?"
50
+ ]
51
+
52
+ MEDIUM_PROMPTS = [
53
+ "Explain the difference between supervised and unsupervised learning with a concrete example of each.",
54
+ "Write a Python function that takes a list of integers and returns all pairs that sum to a given target.",
55
+ "Explain how TCP/IP ensures reliable data delivery over an unreliable network.",
56
+ "What are the trade-offs between using a relational database and a document store for a user profile system?",
57
+ "Describe how gradient descent works and explain the role of the learning rate.",
58
+ "Write a SQL query that returns the top 5 customers by total order value, including customers with no orders.",
59
+ "What is the CAP theorem and what does it imply for distributed system design?",
60
+ "Explain the difference between process and thread, including when you would prefer one over the other.",
61
+ "How does HTTPS prevent a man-in-the-middle attack? Walk through the handshake at a high level.",
62
+ "Write a regex that validates an email address and annotate each part of the pattern.",
63
+ "What is the difference between memoization and dynamic programming?",
64
+ "Describe three ways to handle class imbalance in a machine learning dataset.",
65
+ "Explain what a foreign key constraint does and give an example of why it matters.",
66
+ "What is the difference between horizontal and vertical scaling, and when would you choose each?",
67
+ "How does Python's garbage collector handle circular references?",
68
+ "Explain the intuition behind the attention mechanism in Transformer models.",
69
+ "What is a race condition? Write a minimal pseudocode example that demonstrates one."
70
+ ]
71
+
72
+ TOUGH_PROMPTS = [
73
+ "Design a rate limiter for a public API that must handle 100k requests per second across multiple regions. Describe the data structures, algorithms, and infrastructure trade-offs involved.",
74
+ "Explain why training very deep neural networks with sigmoid activations suffers from vanishing gradients. How do residual connections and normalization layers address this, and what are their respective limitations?",
75
+ "A message queue is consuming events from an upstream producer faster than a downstream consumer can process them. The queue is filling up and the producer cannot be slowed down. Describe at least three architectural strategies to resolve this, with trade-offs.",
76
+ "Given an undirected weighted graph, write Python code to find the minimum spanning tree using Kruskal's algorithm. Include the union-find data structure. Analyze time and space complexity.",
77
+ "You are given two sorted arrays of size m and n. Find the median of the combined array in O(log(m+n)) time. Explain the approach before writing the code.",
78
+ "Explain the difference between Byzantine fault tolerance and crash fault tolerance. In what scenario does the distinction become critical, and how does a consensus protocol like PBFT address Byzantine failures?",
79
+ "A large language model fine-tuned on customer service data starts producing confident but factually wrong answers about product details. Propose a complete mitigation strategy covering training, inference, and deployment layers.",
80
+ "Explain the mechanism behind speculative execution in modern CPUs and how it led to the Spectre vulnerability. What classes of software-level mitigations exist and what performance cost do they carry?",
81
+ "Design a schema and indexing strategy for a social graph where you need to efficiently answer: (1) mutual friends between two users, (2) shortest path between two users, (3) top-k most influential accounts. Justify your choices.",
82
+ "Implement a thread-safe LRU cache in Python with O(1) get and put operations. Explain why your synchronization approach is correct and where contention bottlenecks might appear under high concurrency.",
83
+ "Explain the difference between weak, strong, and eventual consistency in distributed databases. Give a concrete example of a bug that arises when a developer assumes strong consistency but the system only guarantees eventual consistency.",
84
+ "You are designing the storage layer for a time-series database that ingests 1 million data points per second and must support range queries going back 2 years. Describe compression strategies, write amplification concerns, and compaction trade-offs.",
85
+ "Explain how LoRA (Low-Rank Adaptation) reduces the number of trainable parameters in fine-tuning. Derive why a weight update matrix can be approximated as a product of two low-rank matrices and discuss what is lost in this approximation.",
86
+ "A binary tree is given where each node has a value. Write an algorithm to find the maximum path sum between any two nodes (not necessarily leaf nodes). Prove the correctness of your recurrence relation.",
87
+ "Explain the economic concept of Goodhart's Law and give three examples of how it manifests in AI system evaluation.",
88
+ "Describe the full lifecycle of a memory allocation in a system using jemalloc or tcmalloc. How do thread-local caches, size classes, and slab allocation interact, and what are the implications for long-running server processes?"
89
+ ]
90
+
91
+ ALL_PROMPTS = []
92
+ for p in EASY_PROMPTS: ALL_PROMPTS.append({"text": p, "difficulty": "EASY"})
93
+ for p in MEDIUM_PROMPTS: ALL_PROMPTS.append({"text": p, "difficulty": "MEDIUM"})
94
+ for p in TOUGH_PROMPTS: ALL_PROMPTS.append({"text": p, "difficulty": "TOUGH"})
95
+
96
+ # UTILITIES AND DATASET LOADERS
97
+ class SFTDataset(Dataset):
98
+ def __init__(self, file_path: str, max_samples: int = -1):
99
+ self.samples = []
100
+ with open(file_path, "r", encoding="utf-8") as f:
101
+ for line in f:
102
+ if 0 < max_samples <= len(self.samples):
103
+ break
104
+ self.samples.append(json.loads(line))
105
+ print(f"Loaded {len(self.samples)} SFT samples from {file_path}")
106
+
107
+ def __len__(self) -> int:
108
+ return len(self.samples)
109
+
110
+ def __getitem__(self, idx: int) -> dict:
111
+ return self.samples[idx]
112
+
113
+ def pack_sequences(samples: list[dict], pack_length: int, pad_token_id: int, eos_token_id: int) -> list[dict]:
114
+ """Pack samples into fixed-size bins with first-fit-decreasing (FFD) to reduce padding; samples longer than pack_length are truncated."""
115
+ print(f"Packing sequences into {pack_length}-token bins...")
116
+ # Sort samples by input_ids length descending
117
+ indexed_samples = sorted(
118
+ samples,
119
+ key=lambda x: len(x["input_ids"]),
120
+ reverse=True
121
+ )
122
+
123
+ bins: list[list[dict]] = []
124
+ bin_lengths: list[int] = []
125
+
126
+ for sample in indexed_samples:
127
+ s_len = len(sample["input_ids"])
128
+ if s_len > pack_length:
129
+ sample["input_ids"] = sample["input_ids"][:pack_length]
130
+ sample["loss_mask"] = sample["loss_mask"][:pack_length]
131
+ s_len = pack_length
132
+
133
+ # Try to place sample into an existing bin
134
+ placed = False
135
+ for b_idx in range(len(bins)):
136
+ needed = s_len + (1 if len(bins[b_idx]) > 0 else 0)
137
+ if bin_lengths[b_idx] + needed <= pack_length:
138
+ bins[b_idx].append(sample)
139
+ bin_lengths[b_idx] += needed
140
+ placed = True
141
+ break
142
+
143
+ if not placed:
144
+ bins.append([sample])
145
+ bin_lengths.append(s_len)
146
+
147
+ # Convert packed bins to training formats
148
+ packed_samples = []
149
+ for b in bins:
150
+ input_ids = []
151
+ loss_mask = []
152
+ for i, sample in enumerate(b):
153
+ if i > 0:
154
+ input_ids.append(eos_token_id)
155
+ loss_mask.append(0) # Mask out the EOS separator token
156
+ input_ids.extend(sample["input_ids"])
157
+ loss_mask.extend(sample["loss_mask"])
158
+
159
+ real_len = len(input_ids)
160
+ pad_len = pack_length - real_len
161
+ if pad_len > 0:
162
+ input_ids.extend([pad_token_id] * pad_len)
163
+ loss_mask.extend([0] * pad_len)
164
+
165
+ packed_samples.append({
166
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
167
+ "loss_mask": torch.tensor(loss_mask, dtype=torch.long),
168
+ "attention_mask": torch.cat([
169
+ torch.ones(real_len, dtype=torch.long),
170
+ torch.zeros(pad_len, dtype=torch.long)
171
+ ])
172
+ })
173
+
174
+ utilization = sum(bin_lengths) / (len(bins) * pack_length)
175
+ print(f"Packed {len(samples)} samples into {len(bins)} bins. Utilization: {utilization * 100:.2f}%")
176
+ return packed_samples
177
+
178
+ def collate_sft(batch: list[dict], pad_token_id: int) -> dict:
179
+ """Collates batch for standard unpacked training, dynamically padding batch to max length."""
180
+ max_len = max(len(s["input_ids"]) for s in batch)
181
+ input_ids_list = []
182
+ attention_mask_list = []
183
+ labels_list = []
184
+
185
+ for s in batch:
186
+ ids = s["input_ids"]
187
+ mask = s["loss_mask"]
188
+ pad_len = max_len - len(ids)
189
+
190
+ padded_ids = ids + [pad_token_id] * pad_len
191
+ padded_labels = [ids[i] if mask[i] == 1 else -100 for i in range(len(ids))] + [-100] * pad_len
192
+
193
+ input_ids_list.append(torch.tensor(padded_ids, dtype=torch.long))
194
+ attention_mask_list.append(torch.tensor([1] * len(ids) + [0] * pad_len, dtype=torch.long))
195
+ labels_list.append(torch.tensor(padded_labels, dtype=torch.long))
196
+
197
+ return {
198
+ "input_ids": torch.stack(input_ids_list),
199
+ "attention_mask": torch.stack(attention_mask_list),
200
+ "labels": torch.stack(labels_list)
201
+ }
202
+
203
+ def collate_packed(batch: list[dict]) -> dict:
204
+ """Collates pre-packed sequence bins by simple stacking."""
205
+ input_ids = torch.stack([item["input_ids"] for item in batch])
206
+ attention_mask = torch.stack([item["attention_mask"] for item in batch])
207
+ loss_mask = torch.stack([item["loss_mask"] for item in batch])
208
+
209
+ labels = input_ids.clone()
210
+ labels = labels.masked_fill(loss_mask == 0, -100)
211
+
212
+ return {
213
+ "input_ids": input_ids,
214
+ "attention_mask": attention_mask,
215
+ "labels": labels
216
+ }
217
+
218
+ # PARSING AND MAIN LOGIC
219
+ def parse_args() -> argparse.Namespace:
220
+ parser = argparse.ArgumentParser(description="Clean SFT training and evaluation suite")
221
+ parser.add_argument("--student_model", type=str, default=cfg.get("model", {}).get("student", "Qwen/Qwen3-1.7B-Base"))
222
+ parser.add_argument("--tokenizer_model", type=str, default=cfg.get("model", {}).get("tokenizer", "Qwen/Qwen3-1.7B"))
223
+ parser.add_argument("--data_repo", type=str, default=os.environ.get("QUINTUS_SFT_DATA_REPO"), help="HF dataset repo containing train_sft.jsonl. Optional when data/tokenized/train_sft.jsonl exists.")
224
+ parser.add_argument("--token", type=str, default=None)
225
+ parser.add_argument("--trust_remote_code", action="store_true", help="Allow custom code from model/tokenizer repositories.")
226
+
227
+ parser.add_argument("--num_epochs", type=int, default=1)
228
+ parser.add_argument("--learning_rate", type=float, default=2e-5)
229
+ parser.add_argument("--micro_batch_size", type=int, default=4)
230
+ parser.add_argument("--grad_accum_steps", type=int, default=2)
231
+ parser.add_argument("--max_seq_len", type=int, default=4096)
232
+ parser.add_argument("--sequence_packing", action="store_true", default=True)
233
+ parser.add_argument("--no_sequence_packing", action="store_false", dest="sequence_packing")
234
+
235
+ parser.add_argument("--output_dir", type=str, default="quintus_sft_output")
236
+
237
+ parser.add_argument("--run_prompt_suite", action="store_true", default=True)
238
+ parser.add_argument("--no_prompt_suite", action="store_false", dest="run_prompt_suite")
239
+ parser.add_argument("--run_gsm8k", action="store_true", default=True)
240
+ parser.add_argument("--no_gsm8k", action="store_false", dest="run_gsm8k")
241
+ parser.add_argument("--gsm8k_samples", type=int, default=100)
242
+
243
+ parser.add_argument("--optim", type=str, choices=["adamw", "adamw_8bit"], default="adamw")
244
+ parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
245
+ parser.add_argument("--load_in_4bit", action="store_true", default=False)
246
+ parser.add_argument("--use_lora", action="store_true", default=False)
247
+ parser.add_argument("--lora_r", type=int, default=8)
248
+ parser.add_argument("--lora_alpha", type=int, default=16)
249
+
250
+ parser.add_argument("--push_to_hub", action="store_true", default=False, help="Automatically push fine-tuned model to Hugging Face Hub after training")
251
+ parser.add_argument("--hub_model_id", type=str, default="iamrahulreddy/Quintus", help="Target Hugging Face Hub repository ID")
252
+
253
+ return parser.parse_args()
254
+
255
+ def download_hf_dataset(repo_id: str | None, token: str | None) -> str:
256
+ print("Checking for tokenized dataset in local folders...")
257
+ local_path = "data/tokenized/train_sft.jsonl"
258
+ if os.path.exists(local_path):
259
+ print(f"Found local dataset: {local_path}")
260
+ return local_path
261
+
262
+ if not repo_id:
263
+ raise ValueError(
264
+ "No local SFT dataset found at data/tokenized/train_sft.jsonl. "
265
+ "Pass --data_repo or set QUINTUS_SFT_DATA_REPO."
266
+ )
267
+
268
+ print(f"Local file not found. Pulling from Hugging Face: {repo_id}...")
269
+ from huggingface_hub import hf_hub_download
270
+ os.makedirs("data/tokenized", exist_ok=True)
271
+ downloaded = hf_hub_download(
272
+ repo_id=repo_id,
273
+ filename="train_sft.jsonl",
274
+ repo_type="dataset",
275
+ local_dir="data/tokenized",
276
+ token=token
277
+ )
278
+ # Ensure correct local path layout
279
+ if os.path.exists(downloaded) and os.path.abspath(downloaded) != os.path.abspath(local_path):
280
+ os.rename(downloaded, local_path)
281
+ print(f"Dataset downloaded to: {local_path}")
282
+ return local_path
283
+
284
+ # DOWNSTREAM EVALUATION CODE
285
+ def run_prompt_suite(model, tokenizer, device, output_dir: str):
286
+ print("\n" + "="*70)
287
+ print(f"RUNNING QUALITATIVE PROMPT SUITE ({len(ALL_PROMPTS)} Prompts)")
288
+ print("="*70)
289
+
290
+ # Compile stop token IDs
291
+ eos_token_ids = [tokenizer.eos_token_id]
292
+ for token in ["<|im_end|>", "<|endoftext|>", "<|im_start|>"]:
293
+ t_id = tokenizer.convert_tokens_to_ids(token)
294
+ if t_id is not None and t_id != tokenizer.unk_token_id:
295
+ eos_token_ids.append(t_id)
296
+ eos_token_ids = list(set(eos_token_ids))
297
+
298
+ # Initialize output file
299
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
300
+ out_path = os.path.join(output_dir, f"prompt_suite_eval_{timestamp}.txt")
301
+ os.makedirs(output_dir, exist_ok=True)
302
+
303
+ with open(out_path, "w", encoding="utf-8") as f:
304
+ f.write("QUINTUS SFT POST-TRAINING PROMPT SUITE\n")
305
+ f.write(f"Timestamp: {timestamp}\n")
306
+ f.write("="*72 + "\n\n")
307
+ f.flush()
308
+
309
+ # Set padding side to left for batch generation
310
+ orig_padding_side = tokenizer.padding_side
311
+ tokenizer.padding_side = "left"
312
+
313
+ batch_size = 16
314
+ for i in range(0, len(ALL_PROMPTS), batch_size):
315
+ batch_items = ALL_PROMPTS[i : i + batch_size]
316
+
317
+ # Format prompts
318
+ formatted_prompts = []
319
+ for item in batch_items:
320
+ prompt_text = item["text"]
321
+ if tokenizer.chat_template is not None:
322
+ prompt_str = tokenizer.apply_chat_template(
323
+ [{"role": "user", "content": prompt_text}],
324
+ tokenize=False, add_generation_prompt=True
325
+ )
326
+ else:
327
+ prompt_str = f"<|im_start|>user\n{prompt_text}<|im_end|>\n<|im_start|>assistant\n"
328
+ formatted_prompts.append(prompt_str)
329
+
330
+ # Tokenize with padding
331
+ inputs = tokenizer(formatted_prompts, padding=True, return_tensors="pt").to(device)
332
+
333
+ with torch.no_grad():
334
+ outputs = model.generate(
335
+ **inputs,
336
+ max_new_tokens=2048,
337
+ do_sample=False, # Greedy for clean, reproducible comparison
338
+ pad_token_id=tokenizer.pad_token_id,
339
+ eos_token_id=eos_token_ids
340
+ )
341
+
342
+ # Decode and write results in real-time
343
+ for idx, item in enumerate(batch_items):
344
+ input_len = inputs["input_ids"][idx].shape[0]
345
+ gen_tokens = outputs[idx][input_len:]
346
+
347
+ # Slice at the first EOS token
348
+ eos_indices = []
349
+ for eos_id in eos_token_ids:
350
+ indices = (gen_tokens == eos_id).nonzero(as_tuple=True)[0]
351
+ if len(indices) > 0:
352
+ eos_indices.append(indices[0].item())
353
+ if eos_indices:
354
+ gen_tokens = gen_tokens[:min(eos_indices)]
355
+
356
+ response = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
357
+
358
+ # Log progress
359
+ global_idx = i + idx + 1
360
+ print(f"[{global_idx:02d}/{len(ALL_PROMPTS)}] ({item['difficulty']}) Q: {item['text'][:40]}... -> Answered ({len(gen_tokens)} tokens)")
361
+
362
+ # Append directly to output file
363
+ with open(out_path, "a", encoding="utf-8") as f:
364
+ f.write(f"[{global_idx:02d}/{len(ALL_PROMPTS)}] {item['difficulty']}\n")
365
+ f.write(f"Q: {item['text']}\n\n")
366
+ f.write(f"Response:\n{response}\n")
367
+ f.write("\n" + "-"*72 + "\n\n")
368
+ f.flush()
369
+
370
+ # Restore original tokenizer settings
371
+ tokenizer.padding_side = orig_padding_side
372
+ print(f"\nPrompt suite evaluation complete. Saved report to: {out_path}\n")
373
+
374
+ def extract_gsm8k_answer(text: str) -> str | None:
375
+ text = text.replace(",", "")
376
+ match = re.findall(r"The answer is\s*:?\s*\$?\s*(-?\d+)", text, re.IGNORECASE)
377
+ if match:
378
+ return match[-1]
379
+ match = re.findall(r"(-?\d+)", text)
380
+ if match:
381
+ return match[-1]
382
+ return None
383
+
384
+ def run_gsm8k_eval(model, tokenizer, device, num_samples: int = 100):
385
+ print("\n" + "="*70)
386
+ print(f"RUNNING GSM8K MATH EVALUATION ({num_samples} Samples)")
387
+ print("="*70)
388
+
389
+ from datasets import load_dataset
390
+ try:
391
+ dataset = load_dataset("openai/gsm8k", "main", split="test")
392
+ except Exception as e:
393
+ print(f"Warning: Could not download GSM8K test set directly: {e}")
394
+ return
395
+
396
+ dataset = dataset.shuffle(seed=42).select(range(min(num_samples, len(dataset))))
397
+
398
+ correct = 0
399
+ total = 0
400
+
401
+ for idx, item in enumerate(dataset):
402
+ question = item["question"]
403
+ answer = item["answer"]
404
+
405
+ target_match = re.search(r"####\s*(-?\d+)", answer.replace(",", ""))
406
+ if not target_match:
407
+ continue
408
+ target_val = target_match.group(1)
409
+
410
+ if tokenizer.chat_template is not None:
411
+ prompt = tokenizer.apply_chat_template(
412
+ [{"role": "user", "content": question + "\nShow your work and conclude with 'The answer is: <number>'."}],
413
+ tokenize=False, add_generation_prompt=True
414
+ )
415
+ else:
416
+ prompt = f"<|im_start|>user\n{question}\nShow your work and conclude with 'The answer is: <number>'.<|im_end|>\n<|im_start|>assistant\n"
417
+
418
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
419
+
420
+ with torch.no_grad():
421
+ outputs = model.generate(
422
+ **inputs,
423
+ max_new_tokens=1024,
424
+ do_sample=False,
425
+ pad_token_id=tokenizer.pad_token_id,
426
+ eos_token_id=tokenizer.eos_token_id
427
+ )
428
+
429
+ gen_tokens = outputs[0][inputs.input_ids.shape[1]:]
430
+ generated_text = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
431
+
432
+ pred_val = extract_gsm8k_answer(generated_text)
433
+ is_match = (pred_val == target_val)
434
+
435
+ if is_match:
436
+ correct += 1
437
+ total += 1
438
+
439
+ # Log sample output periodically
440
+ if idx % 10 == 0:
441
+ print(f"\n[GSM8K Sample {idx+1}]")
442
+ print(f"Q: {question[:80]}...")
443
+ print(f"A: {generated_text[:120]}... (Target: {target_val} | Pred: {pred_val})")
444
+ print(f"Match: {is_match}")
445
+
446
+ accuracy = (correct / total * 100) if total > 0 else 0
447
+ print("\n" + "="*70)
448
+ print(f"GSM8K EVALUATION SUMMARY: {correct}/{total} Correct -> Accuracy: {accuracy:.2f}%")
449
+ print("="*70 + "\n")
450
+
451
+ # TRAINING PIPELINE
452
+ def main() -> None:
453
+ args = parse_args()
454
+
455
+ # Propagate HF token to environment for auto-authentication of downstream hub calls
456
+ try:
457
+ import huggingface_hub
458
+ cached_token = huggingface_hub.get_token()
459
+ except Exception:
460
+ cached_token = None
461
+ resolved_token = args.token or os.environ.get("HF_TOKEN") or cached_token
462
+ if resolved_token:
463
+ os.environ["HF_TOKEN"] = resolved_token
464
+ args.token = resolved_token
465
+
466
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
467
+ print(f"SFT Environment initialized. Target device: {device}")
468
+
469
+ # 1. Resolve the dataset (local file first, otherwise pull from HF)
470
+ try:
471
+ dataset_file = download_hf_dataset(args.data_repo, args.token)
472
+ except ValueError as exc:
473
+ print(f"Error: {exc}")
474
+ sys.exit(1)
475
+
476
+ # 2. Setup Tokenizer and Model
477
+ print(f"Loading tokenizer: {args.tokenizer_model}")
478
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_model, trust_remote_code=args.trust_remote_code)
479
+ if tokenizer.pad_token is None:
480
+ tokenizer.pad_token = tokenizer.eos_token
481
+
482
+ # 4-bit configuration if requested
483
+ bnb_config = None
484
+ if args.load_in_4bit:
485
+ from transformers import BitsAndBytesConfig
486
+ bnb_config = BitsAndBytesConfig(
487
+ load_in_4bit=True,
488
+ bnb_4bit_compute_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
489
+ bnb_4bit_quant_type="nf4",
490
+ bnb_4bit_use_double_quant=True
491
+ )
492
+ print("Using 4-bit BitsAndBytes quantization.")
493
+
494
+ # Liger Kernel (skipped for 4-bit/PEFT as it can interfere with quantized layers)
495
+ if not args.load_in_4bit:
496
+ try:
497
+ from liger_kernel.transformers import apply_liger_kernel_to_qwen3
498
+ apply_liger_kernel_to_qwen3(
499
+ rope=True,
500
+ swiglu=True,
501
+ rms_norm=True,
502
+ cross_entropy=False,
503
+ fused_linear_cross_entropy=False,
504
+ )
505
+ print("Liger Kernel optimizations applied successfully.")
506
+ except ImportError:
507
+ print("Liger Kernel not installed, skipping optimizations.")
508
+
509
+ attn_impl = "sdpa"
510
+ if device.type == "cuda":
511
+ try:
512
+ import flash_attn
513
+ attn_impl = "flash_attention_2"
514
+ print("FlashAttention-2 enabled.")
515
+ except ImportError:
516
+ print("flash-attn not installed, falling back to SDPA.")
517
+
518
+ model = AutoModelForCausalLM.from_pretrained(
519
+ args.student_model,
520
+ quantization_config=bnb_config,
521
+ dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
522
+ trust_remote_code=args.trust_remote_code,
523
+ attn_implementation=attn_impl
524
+ )
525
+ if not args.load_in_4bit:
526
+ model = model.to(device)
527
+ model.config.use_cache = False
528
+
529
+ # Wrap with LoRA if requested or required for 4-bit training
530
+ if args.use_lora or args.load_in_4bit:
531
+ try:
532
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
533
+ if args.load_in_4bit:
534
+ model = prepare_model_for_kbit_training(model)
535
+
536
+ peft_config = LoraConfig(
537
+ r=args.lora_r,
538
+ lora_alpha=args.lora_alpha,
539
+ target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
540
+ lora_dropout=0.05,
541
+ bias="none",
542
+ task_type="CAUSAL_LM"
543
+ )
544
+ model = get_peft_model(model, peft_config)
545
+ print("LoRA adapters successfully attached to target modules.")
546
+ model.print_trainable_parameters()
547
+ except ImportError:
548
+ print("Error: peft not installed. Please run `pip install peft` to use LoRA/QLoRA.")
549
+ sys.exit(1)
550
+
551
+ if args.gradient_checkpointing:
552
+ model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
553
+ print("Gradient checkpointing enabled.")
554
+
555
+ # 3. Prepare dataset
556
+ raw_dataset = SFTDataset(dataset_file)
557
+
558
+ if args.sequence_packing:
559
+ packed_samples = pack_sequences(
560
+ raw_dataset.samples,
561
+ pack_length=args.max_seq_len,
562
+ pad_token_id=tokenizer.pad_token_id,
563
+ eos_token_id=tokenizer.eos_token_id
564
+ )
565
+ train_dataloader = DataLoader(
566
+ packed_samples,
567
+ batch_size=args.micro_batch_size,
568
+ shuffle=True,
569
+ collate_fn=collate_packed
570
+ )
571
+ else:
572
+ train_dataloader = DataLoader(
573
+ raw_dataset,
574
+ batch_size=args.micro_batch_size,
575
+ shuffle=True,
576
+ collate_fn=lambda b: collate_sft(b, tokenizer.pad_token_id)
577
+ )
578
+
579
+ # 4. Optimizer and scheduler setup
580
+ if args.optim == "adamw_8bit":
581
+ try:
582
+ import bitsandbytes as bnb
583
+ optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=args.learning_rate, weight_decay=0.1)
584
+ print("Using BitsAndBytes 8-bit AdamW optimizer.")
585
+ except ImportError:
586
+ print("Warning: bitsandbytes not installed. Falling back to standard AdamW.")
587
+ use_fused = (device.type == "cuda")
588
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0.1, fused=use_fused)
589
+ else:
590
+ use_fused = (device.type == "cuda")
591
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0.1, fused=use_fused)
592
+ print(f"Using standard AdamW optimizer (fused={use_fused}).")
593
+ steps_per_epoch = (len(train_dataloader) + args.grad_accum_steps - 1) // args.grad_accum_steps
594
+ total_steps = steps_per_epoch * args.num_epochs
595
+ warmup_steps = int(total_steps * 0.05)
596
+ scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
597
+
598
+ # 5. Training Loop
599
+ print("\n" + "="*70)
600
+ print(f"STARTING SFT TRAINING (Epochs: {args.num_epochs} | Steps: {total_steps})")
601
+ print("="*70)
602
+
603
+ model.train()
604
+ step = 0
605
+ total_tokens_processed = 0
606
+ t0 = time.time()
607
+
608
+ for epoch in range(args.num_epochs):
609
+ epoch_loss = 0.0
610
+ for batch_idx, batch in enumerate(train_dataloader):
611
+ input_ids = batch["input_ids"].to(device)
612
+ attention_mask = batch["attention_mask"].to(device)
613
+ labels = batch["labels"].to(device)
614
+
615
+ # Accumulate the number of active (non-padded) tokens processed
616
+ total_tokens_processed += attention_mask.sum().item()
617
+
618
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
619
+ loss = outputs.loss / args.grad_accum_steps
620
+ loss.backward()
621
+
622
+ epoch_loss += loss.item() * args.grad_accum_steps
623
+
624
+ if (batch_idx + 1) % args.grad_accum_steps == 0 or (batch_idx + 1) == len(train_dataloader):
625
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
626
+ optimizer.step()
627
+ scheduler.step()
628
+ optimizer.zero_grad()
629
+ step += 1
630
+
631
+ if step % 5 == 0 or step == total_steps:
632
+ elapsed = time.time() - t0
633
+ tokens_per_sec = total_tokens_processed / max(elapsed, 1e-5)
634
+ print(
635
+ f"Epoch {epoch+1}/{args.num_epochs} | "
636
+ f"Step {step}/{total_steps} | "
637
+ f"Loss: {loss.item() * args.grad_accum_steps:.4f} | "
638
+ f"LR: {scheduler.get_last_lr()[0]:.2e} | "
639
+ f"Tokens: {total_tokens_processed} | "
640
+ f"Speed: {tokens_per_sec:.2f} tokens/s"
641
+ )
642
+
643
+ # 6. Save model weights and tokenizer
644
+ print(f"\nTraining complete in {time.time() - t0:.1f}s. Saving weights to: {args.output_dir}")
645
+ os.makedirs(args.output_dir, exist_ok=True)
646
+ if hasattr(model, "merge_and_unload") and not args.load_in_4bit:
647
+ print("Merging LoRA adapters into base weights...")
648
+ try:
649
+ merged_model = model.merge_and_unload()
650
+ merged_model.save_pretrained(args.output_dir)
651
+ print("Merged model weights saved successfully.")
652
+ except Exception as e:
653
+ print(f"Failed to merge and unload: {e}. Saving adapter weights only.")
654
+ model.save_pretrained(args.output_dir)
655
+ else:
656
+ model.save_pretrained(args.output_dir)
657
+ tokenizer.save_pretrained(args.output_dir)
658
+ print("Weights and configuration saved successfully.")
659
+
660
+ # 7. SFT Downstream Evaluations
661
+ model.eval()
662
+
663
+ if args.run_prompt_suite:
664
+ run_prompt_suite(model, tokenizer, device, args.output_dir)
665
+
666
+ if args.run_gsm8k:
667
+ run_gsm8k_eval(model, tokenizer, device, num_samples=args.gsm8k_samples)
668
+
669
+ if args.push_to_hub:
670
+ print(f"\nUploading fine-tuned model and tokenizer to Hugging Face Hub: {args.hub_model_id}...")
671
+ try:
672
+ from huggingface_hub import create_repo, HfApi
673
+ token_val = args.token or os.environ.get("HF_TOKEN")
674
+ create_repo(repo_id=args.hub_model_id, token=token_val, exist_ok=True)
675
+
676
+ api = HfApi()
677
+ api.upload_folder(
678
+ folder_path=args.output_dir,
679
+ repo_id=args.hub_model_id,
680
+ repo_type="model",
681
+ token=token_val
682
+ )
683
+ print("Successfully uploaded model and tokenizer to Hugging Face Hub!")
684
+ except Exception as hub_err:
685
+ print(f"Failed to push to Hub: {hub_err}")
686
+
687
+ print("Pipeline Execution Complete. Model is ready.")
688
+
689
+ if __name__ == "__main__":
690
+ main()
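Review note on `pack_sequences`: the bin-packing step is easier to reason about on plain lengths. The sketch below is a minimal, standalone illustration of first-fit-decreasing with a one-token separator cost, assuming the same rules as the script (oversize items are truncated to the bin size, and every sample after the first in a bin pays for one EOS separator). The function name `first_fit_decreasing` and the `sep_cost` parameter are hypothetical and not part of the commit.

```python
def first_fit_decreasing(lengths: list[int], capacity: int, sep_cost: int = 1) -> list[list[int]]:
    """Pack item lengths into bins of size `capacity`, longest items first."""
    bins: list[list[int]] = []   # item lengths placed in each bin
    used: list[int] = []         # tokens consumed per bin, separators included
    for length in sorted(lengths, reverse=True):
        length = min(length, capacity)  # truncate oversize items
        for i in range(len(bins)):
            # Existing bins are never empty, so a new item always needs a separator.
            needed = length + sep_cost
            if used[i] + needed <= capacity:
                bins[i].append(length)
                used[i] += needed
                break
        else:
            # No existing bin had room: open a new one (no separator needed).
            bins.append([length])
            used.append(length)
    return bins


packed = first_fit_decreasing([7, 5, 4, 3, 2, 1], capacity=10)
print(packed)  # [[7, 2], [5, 4], [3, 1]]
```

Because each placement scans every open bin, the loop is quadratic in the number of bins in the worst case, which may matter for very large datasets.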
src/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # src package
src/checkpoints.py ADDED
@@ -0,0 +1,241 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import time
6
+ from pathlib import Path
7
+
8
+ import torch
9
+
10
+ from configs import cfg
11
+
12
+ def checkpoint_rank(path: str) -> tuple[int, int]:
13
+ name = os.path.basename(path)
14
+ prefix, _, raw_value = name.partition("_")
15
+ try:
16
+ value = int(raw_value)
17
+ except ValueError:
18
+ value = -1
19
+ if prefix == "epoch":
20
+ return (2, value)
21
+ if prefix == "step":
22
+ return (1, value)
23
+ return (0, value)
24
+
25
+ def find_latest_training_checkpoint(output_dir: str) -> str | None:
26
+ candidates = []
27
+ for pattern in ("epoch_*", "step_*"):
28
+ candidates.extend(str(path) for path in Path(output_dir).glob(pattern) if path.is_dir())
29
+ if not candidates:
30
+ return None
31
+ return max(candidates, key=checkpoint_rank)
32
+
33
+ def load_trainer_state(checkpoint_dir: str, log) -> dict:
34
+ state_path = os.path.join(checkpoint_dir, "trainer_state.json")
35
+ if os.path.exists(state_path):
36
+ try:
37
+ with open(state_path, "r", encoding="utf-8") as f:
38
+ state = json.load(f)
39
+ if isinstance(state, dict):
40
+ return state
41
+ except (OSError, json.JSONDecodeError) as exc:
42
+ log.warning(f"Could not read trainer_state.json from {checkpoint_dir}: {exc}")
43
+
44
+ name = os.path.basename(checkpoint_dir)
45
+ prefix, _, raw_value = name.partition("_")
46
+ try:
47
+ value = int(raw_value)
48
+ except ValueError:
49
+ value = 0
50
+ if prefix == "epoch":
51
+ return {
52
+ "checkpoint_type": "epoch",
53
+ "start_epoch": value,
54
+ "global_step": 0,
55
+ "micro_step_global": 0,
56
+ "next_batch_in_epoch": 0,
57
+ }
58
+ if prefix == "step":
59
+ return {
60
+ "checkpoint_type": "step",
61
+ "start_epoch": 0,
62
+ "global_step": value,
63
+ "micro_step_global": 0,
64
+ "next_batch_in_epoch": 0,
65
+ }
66
+ return {}
67
+
68
+ def packing_checkpoint_metadata(enabled: bool, pack_length: int | None, max_seq_len: int) -> dict[str, int | bool | None]:
69
+ return {
70
+ "sequence_packing_enabled": bool(enabled),
71
+ "sequence_packing_pack_length": int(pack_length) if enabled and pack_length is not None else None,
72
+ "data_max_seq_len": int(max_seq_len),
73
+ }
74
+
75
+ def validate_resume_packing_state(
76
+ trainer_state: dict,
77
+ *,
78
+ enabled: bool,
79
+ pack_length: int,
80
+ max_seq_len: int,
81
+ log,
82
+ ) -> None:
83
+ checkpoint_enabled = bool(trainer_state.get("sequence_packing_enabled", False))
84
+ if checkpoint_enabled != bool(enabled):
85
+ log.error(
86
+ "Checkpoint sequence-packing state does not match the current run: "
87
+ f"checkpoint={checkpoint_enabled}, current={bool(enabled)}."
88
+ )
89
+ raise SystemExit(1)
90
+
91
+ if checkpoint_enabled:
92
+ checkpoint_pack_length = trainer_state.get("sequence_packing_pack_length")
93
+ try:
94
+ checkpoint_pack_length = int(checkpoint_pack_length)
95
+ except (TypeError, ValueError):
96
+ log.error("Checkpoint is missing a valid sequence_packing_pack_length value.")
97
+ raise SystemExit(1)
98
+ if checkpoint_pack_length != int(pack_length):
99
+ log.error(
100
+ "Checkpoint pack length does not match the current run: "
101
+ f"checkpoint={checkpoint_pack_length}, current={int(pack_length)}."
102
+ )
103
+ raise SystemExit(1)
104
+
105
+ checkpoint_max_seq_len = trainer_state.get("data_max_seq_len")
106
+ if checkpoint_max_seq_len is not None:
107
+ try:
108
+ checkpoint_max_seq_len = int(checkpoint_max_seq_len)
109
+ except (TypeError, ValueError):
110
+ log.error("Checkpoint has an invalid data_max_seq_len value.")
111
+ raise SystemExit(1)
112
+ if checkpoint_max_seq_len != int(max_seq_len):
113
+ log.error(
114
+ "Checkpoint max sequence length does not match the current run: "
115
+ f"checkpoint={checkpoint_max_seq_len}, current={int(max_seq_len)}."
116
+ )
117
+ raise SystemExit(1)
118
+
119
+ def save_checkpoint(
120
+ model,
121
+ tokenizer,
122
+ output_dir: str,
123
+ tag: str,
124
+ logger,
125
+ *,
126
+ scheduler=None,
127
+ trainer_state: dict | None = None,
128
+ ) -> str:
129
+ save_dir = os.path.join(output_dir, tag)
130
+ os.makedirs(save_dir, exist_ok=True)
131
+ save_start = time.time()
132
+ logger.info(f"[CKPT] Saving {tag} -> {save_dir}/")
133
+
134
+ model_to_save = model.module if hasattr(model, "module") else model
135
+ if hasattr(model_to_save, "_orig_mod"):
136
+ model_to_save = model_to_save._orig_mod
137
+
138
+ model_to_save.config.save_pretrained(save_dir)
139
+ tokenizer.save_pretrained(save_dir)
140
+
141
+ try:
142
+ from safetensors.torch import save_file
143
+ state_dict = {k: v.contiguous().cpu() for k, v in model_to_save.state_dict().items()}
144
+ save_file(state_dict, os.path.join(save_dir, "model.safetensors"))
145
+ logger.info("[CKPT] Saved via safetensors")
146
+ except ImportError:
147
+ torch.save(model_to_save.state_dict(), os.path.join(save_dir, "pytorch_model.bin"))
148
+ logger.info("[CKPT] Saved via torch.save")
149
+
150
+ if scheduler is not None:
151
+ torch.save(scheduler.state_dict(), os.path.join(save_dir, "scheduler.pt"))
152
+
153
+ if trainer_state is not None:
154
+ trainer_state = dict(trainer_state)
155
+ trainer_state.setdefault("tag", tag)
156
+ trainer_state.setdefault("saved_at", time.strftime("%Y-%m-%d %H:%M:%S %Z"))
157
+ with open(os.path.join(save_dir, "trainer_state.json"), "w", encoding="utf-8") as f:
158
+ json.dump(trainer_state, f, indent=2)
159
+
160
+ size_mb = sum(f.stat().st_size for f in Path(save_dir).rglob("*") if f.is_file()) / 1e6
161
+ save_elapsed = time.time() - save_start
162
+ logger.info(f"[CKPT] {tag} -> {save_dir}/ ({size_mb:.0f} MB, {save_elapsed:.1f}s)")
163
+ return save_dir
164
+
165
+ def read_env_flag(name: str, default: bool = False) -> bool:
166
+ raw = os.environ.get(name)
167
+ if raw is None:
168
+ return default
169
+ return raw.strip().lower() in {"1", "true", "yes", "on"}
170
+
171
+ def hub_upload_strict() -> bool:
172
+ strict = getattr(getattr(cfg, "hub", None), "hub_upload_strict", None)
173
+ if strict is None:
174
+ return read_env_flag("QUINTUS_HUB_UPLOAD_STRICT", False)
175
+ return bool(strict)
176
+
177
+ def should_upload_checkpoint_tag(tag: str) -> bool:
178
+ upload_regular = getattr(getattr(cfg, "hub", None), "upload_kd_checkpoints", False) or read_env_flag("QUINTUS_UPLOAD_KD_CHECKPOINTS", False)
179
+ upload_steps = getattr(getattr(cfg, "hub", None), "upload_step_checkpoints", False) or read_env_flag("QUINTUS_UPLOAD_STEP_CHECKPOINTS", False)
180
+ upload_last = getattr(getattr(cfg, "hub", None), "upload_last_checkpoint", False) or read_env_flag("QUINTUS_UPLOAD_LAST_CHECKPOINT", False)
181
+ if tag.startswith("step_"):
182
+ return upload_steps
183
+ if tag.startswith("epoch_"):
184
+ return upload_regular
185
+ if tag == "best":
186
+ return upload_regular
187
+ if tag == "last":
188
+ return upload_last or upload_regular
189
+ return False
190
+
191
+ def maybe_upload_checkpoint(checkpoint_dir: str, tag: str, logger) -> None:
192
+ if not should_upload_checkpoint_tag(tag):
193
+ return
194
+
195
+ token = os.environ.get("HF_TOKEN") or getattr(getattr(cfg, "hub", None), "token", None)
196
+ if not token:
197
+ msg = "HF checkpoint upload requested, but HF_TOKEN/cfg.hub.token is missing"
198
+ strict = hub_upload_strict()
199
+ if strict:
200
+ raise RuntimeError(msg)
201
+ logger.warning(f"[CKPT] {msg}; continuing without remote backup")
202
+ return
203
+
204
+ repo_id = getattr(getattr(cfg, "hub", None), "repo_id", None) or os.environ.get("QUINTUS_HUB_REPO_ID") or f"{cfg.hub.username}/{cfg.hub.repo_name}"
205
+ base_path = getattr(getattr(cfg, "hub", None), "ckpt_path_in_repo", None) or os.environ.get("KD_CKPT_PATH_IN_REPO", "models/online_kd_3b_05b_ep3_B200_20260601")
206
+ base_path = base_path.strip("/")
207
+ path_in_repo = f"{base_path}/{tag}"
208
+ commit_prefix = getattr(getattr(cfg, "hub", None), "commit_message_prefix", None) or os.environ.get(
209
+ "KD_COMMIT_MESSAGE_PREFIX",
210
+ "Online KD 8B->1.7B Run",
211
+ )
212
+ commit_message = os.environ.get("KD_COMMIT_MESSAGE") or f"{commit_prefix}: upload {tag}"
213
+ upload_start = time.time()
214
+ size_mb = sum(f.stat().st_size for f in Path(checkpoint_dir).rglob("*") if f.is_file()) / 1e6
215
+ strict = hub_upload_strict()
216
+ logger.info(
217
+ f"[CKPT] Uploading {tag} -> {repo_id}/{path_in_repo} "
218
+ f"({size_mb:.0f} MB, strict={strict})"
219
+ )
220
+ logger.info(f"[CKPT] Commit: {commit_message}")
221
+
222
+ try:
223
+ from huggingface_hub import HfApi
224
+
225
+ api = HfApi(token=token)
226
+ api.create_repo(repo_id=repo_id, repo_type="dataset", private=True, exist_ok=True)
227
+ api.upload_folder(
228
+ folder_path=checkpoint_dir,
229
+ repo_id=repo_id,
230
+ path_in_repo=path_in_repo,
231
+ repo_type="dataset",
232
+ commit_message=commit_message,
233
+ ignore_patterns=["*.tmp", "*.log", "__pycache__/*"],
234
+ )
235
+ upload_elapsed = time.time() - upload_start
236
+ logger.info(f"[CKPT] Uploaded {tag} to HF Hub in {upload_elapsed / 60:.1f}m")
237
+ except Exception as exc:
238
+ msg = f"HF checkpoint upload failed for {tag}: {exc}"
239
+ if hub_upload_strict():
240
+ raise RuntimeError(msg) from exc
241
+ logger.warning(f"[CKPT] {msg}; continuing because hub upload strict mode is disabled")
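The upload policy above comes down to two small decisions: how an environment flag is parsed, and which checkpoint tags are eligible for upload. A minimal standalone sketch of both follows, assuming the same truthy strings and tag prefixes as the diff. `should_upload` and its keyword arguments are illustrative names, not part of the repository, and the sketch takes the three switches as plain booleans instead of reading `cfg`.

```python
import os


def read_env_flag(name: str, default: bool = False) -> bool:
    """Parse a boolean environment variable; unset falls back to `default`."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


def should_upload(tag: str, *, upload_regular: bool, upload_steps: bool, upload_last: bool) -> bool:
    """Hypothetical helper mirroring the tag routing shown in the diff."""
    if tag.startswith("step_"):
        return upload_steps
    if tag.startswith("epoch_") or tag == "best":
        return upload_regular
    if tag == "last":
        # "last" is uploaded when either the dedicated or the regular switch is on.
        return upload_last or upload_regular
    # Unknown tags are never uploaded.
    return False
```

Note that any value outside the truthy set, including an empty string, parses as `False` even when `default` is `True`; only an unset variable falls back to the default.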
src/download.py ADDED
@@ -0,0 +1,574 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import os
6
+ import platform
7
+ import sys
8
+ import time
9
+ import warnings
10
+ from pathlib import Path
11
+
12
+ _REPO_ROOT = Path(__file__).resolve().parents[1]
13
+ if str(_REPO_ROOT) not in sys.path:
14
+ sys.path.insert(0, str(_REPO_ROOT))
15
+
16
+ os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "0")
17
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
18
+ warnings.filterwarnings("ignore", category=FutureWarning)
19
+ warnings.filterwarnings("ignore", category=UserWarning)
20
+
21
+ import torch
22
+ from datasets import load_dataset
23
+ from huggingface_hub import snapshot_download
24
+ from transformers import AutoTokenizer
25
+
26
+ from configs import cfg, emit_log_spacing, setup_logger
27
+ from src.transformers_compat import format_model_load_error
28
+
29
+ _IGNORE_PATTERNS = ["*.msgpack", "*.h5", "*.bin", "optimizer.pt", "optimizer.safetensors"]
30
+ _TOKENIZER_ALLOW_PATTERNS = [
31
+ "tokenizer.json",
32
+ "tokenizer.model",
33
+ "tokenizer_config.json",
34
+ "special_tokens_map.json",
35
+ "added_tokens.json",
36
+ "vocab.json",
37
+ "merges.txt",
38
+ "generation_config.json",
39
+ ]
40
+ _MIN_TOKEN_LENGTH = 10
41
+ _DATA_STATS_FILENAME = "_data_stats.json"
42
+ _ASSISTANT_MASK_KEYS = ("assistant_masks", "assistant_mask", "assistant_tokens_mask")
43
+
44
+
45
+ def _build_chat_template_error_types() -> tuple[type[BaseException], ...]:
46
+ error_types: list[type[BaseException]] = [
47
+ AttributeError,
48
+ IndexError,
49
+ KeyError,
50
+ RuntimeError,
51
+ TypeError,
52
+ ValueError,
53
+ ]
54
+ try:
55
+ from jinja2 import TemplateError
56
+ error_types.append(TemplateError)
57
+ except ImportError:
58
+ pass
59
+ return tuple(error_types)
60
+
61
+
62
+ _CHAT_TEMPLATE_ERRORS = _build_chat_template_error_types()
63
+
64
+
65
+ def _config_revision(value: str | None) -> str | None:
66
+ if value is None:
67
+ return None
68
+ stripped = value.strip()
69
+ return stripped or None
70
+
71
+
72
+ def _download_tokenizer_artifacts(tokenizer_model: str, tokenizer_revision: str | None, tokenizer_dir: str, log) -> None:
73
+ log.info(f"Downloading tokenizer -> ./{tokenizer_dir}/")
74
+ t0 = time.time()
75
+ try:
76
+ snapshot_download(
77
+ repo_id=tokenizer_model,
78
+ local_dir=tokenizer_dir,
79
+ revision=tokenizer_revision,
80
+ allow_patterns=_TOKENIZER_ALLOW_PATTERNS,
81
+ )
82
+ size_mb = sum(f.stat().st_size for f in Path(tokenizer_dir).rglob("*") if f.is_file()) / 1e6
83
+ log.info(f"Tokenizer downloaded: {size_mb:.1f} MB in {time.time() - t0:.0f}s")
84
+ except Exception as exc:
85
+ log.error(f"Failed to download tokenizer: {exc}")
86
+ sys.exit(1)
87
+
88
+
89
+ def write_system_info(output_path: str, logger) -> None:
90
+ output_dir = os.path.dirname(output_path)
91
+ if output_dir:
92
+ os.makedirs(output_dir, exist_ok=True)
93
+ info = {
94
+ "timestamp": time.strftime("%Y-%m-%d %H:%M:%S %Z"),
95
+ "platform": platform.platform(),
96
+ "python": sys.version.split()[0],
97
+ "torch": torch.__version__,
98
+ "cuda": torch.version.cuda if torch.cuda.is_available() else None,
99
+ "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
100
+ "cpu_count": os.cpu_count(),
101
+ }
102
+ with open(output_path, "w", encoding="utf-8") as f:
103
+ json.dump(info, f, indent=2)
104
+ logger.info(f"System info -> {output_path}")
105
+
106
+
107
+ def write_data_stats(output_path: str, stats: dict, dataset_id: str, config_name: str, target_samples: int, max_seq_len: int, logger) -> None:
108
+ output_dir = os.path.dirname(output_path)
109
+ if output_dir:
110
+ os.makedirs(output_dir, exist_ok=True)
111
+ meta = {
112
+ "dataset": dataset_id,
113
+ "config": config_name,
114
+ "target_samples": target_samples,
115
+ "max_seq_len": max_seq_len,
116
+ "stats": stats,
117
+ }
118
+ with open(output_path, "w", encoding="utf-8") as f:
119
+ json.dump(meta, f, indent=2)
120
+ logger.info(f"Dataset stats -> {output_path}")
121
+
122
+
123
+ def _coerce_content_str(content) -> str | None:
124
+ if isinstance(content, str):
125
+ return content.strip() or None
126
+ if isinstance(content, dict):
127
+ for key in ("answer_content", "text", "value", "content", "think_content"):
128
+ value = content.get(key)
129
+ if isinstance(value, str) and value.strip():
130
+ return value.strip()
131
+ return None
132
+ if isinstance(content, list):
133
+ parts = []
134
+ for item in content:
135
+ if isinstance(item, str) and item.strip():
136
+ parts.append(item.strip())
137
+ elif isinstance(item, dict):
138
+ value = item.get("text", item.get("value", ""))
139
+ if isinstance(value, str) and value.strip():
140
+ parts.append(value.strip())
141
+ joined = " ".join(parts).strip()
142
+ return joined or None
143
+ return None
144
+
145
+
146
+ def _extract_clean_sample(row: dict) -> dict | None:
147
+ messages = row.get("messages", row.get("conversations", []))
148
+ if not messages and "instruction" in row and "output" in row:
149
+ messages = [
150
+ {"role": "user", "content": row["instruction"]},
151
+ {"role": "assistant", "content": row["output"]},
152
+ ]
153
+ if isinstance(messages, str):
154
+ try:
155
+ messages = json.loads(messages)
156
+ except (json.JSONDecodeError, TypeError):
157
+ return None
158
+ if not isinstance(messages, list) or len(messages) < 2:
159
+ return None
160
+
161
+ clean_messages: list[dict] = []
162
+ for message in messages:
163
+ if not isinstance(message, dict):
164
+ return None
165
+ role = message.get("role", message.get("from", ""))
166
+ content = _coerce_content_str(message.get("content", message.get("value", "")))
167
+ if not isinstance(role, str) or not role.strip() or content is None:
168
+ return None
169
+ role = role.strip().lower()
170
+ if role == "human":
171
+ role = "user"
172
+ elif role == "gpt":
173
+ role = "assistant"
174
+ clean_messages.append({"role": role, "content": content})
175
+
176
+ source = str(row.get("source", "unknown"))
177
+ return {"messages": clean_messages, "source": source}
178
+
179
+
180
+ def _coerce_token_ids(token_ids) -> list[int]:
181
+ if hasattr(token_ids, "input_ids"):
182
+ token_ids = token_ids.input_ids
183
+ if isinstance(token_ids, dict):
184
+ token_ids = token_ids.get("input_ids", token_ids)
185
+ if hasattr(token_ids, "tolist"):
186
+ token_ids = token_ids.tolist()
187
+ if isinstance(token_ids, tuple):
188
+ token_ids = list(token_ids)
189
+ if not isinstance(token_ids, list):
190
+ raise TypeError(f"Unexpected token id payload: {type(token_ids).__name__}")
191
+ if token_ids and isinstance(token_ids[0], list):
192
+ raise TypeError("Expected a single token sequence, not a batched payload")
193
+ return [int(token_id) for token_id in token_ids]
194
+
195
+
196
+ def _coerce_binary_mask(mask_values, expected_len: int) -> list[int]:
197
+ if hasattr(mask_values, "tolist"):
198
+ mask_values = mask_values.tolist()
199
+ if isinstance(mask_values, tuple):
200
+ mask_values = list(mask_values)
201
+ if not isinstance(mask_values, list):
202
+ raise TypeError(f"Unexpected assistant mask payload: {type(mask_values).__name__}")
203
+ if mask_values and isinstance(mask_values[0], list):
204
+ raise TypeError("Expected a single assistant mask, not a batched payload")
205
+ mask = [1 if int(value) != 0 else 0 for value in mask_values]
206
+ if len(mask) != expected_len:
207
+ raise ValueError(f"Assistant mask length mismatch: got {len(mask)}, expected {expected_len}")
208
+ return mask
209
+
210
+
211
+ def _try_builtin_assistant_mask(sample: dict, tokenizer: AutoTokenizer, input_ids: list[int]) -> list[int] | None:
212
+ try:
213
+ with warnings.catch_warnings():
214
+ warnings.filterwarnings("ignore", message="return_assistant_tokens_mask")
215
+ encoded = tokenizer.apply_chat_template(
216
+ sample["messages"],
217
+ tokenize=True,
218
+ add_generation_prompt=False,
219
+ return_dict=True,
220
+ return_assistant_tokens_mask=True,
221
+ )
222
+ except TypeError:
223
+ return None
224
+ except _CHAT_TEMPLATE_ERRORS:
225
+ return None
226
+ if not hasattr(encoded, "get"):
227
+ return None
228
+ try:
229
+ encoded_ids = _coerce_token_ids(encoded)
230
+ except _CHAT_TEMPLATE_ERRORS:
231
+ return None
232
+ if encoded_ids != input_ids:
233
+ return None
234
+ for key in _ASSISTANT_MASK_KEYS:
235
+ if key not in encoded:
236
+ continue
237
+ try:
238
+ mask = _coerce_binary_mask(encoded[key], expected_len=len(input_ids))
239
+ except (TypeError, ValueError):
240
+ return None
241
+ if any(mask):
242
+ return mask
243
+ return None
244
+
245
+
246
+ def _build_assistant_mask_from_prefixes(sample: dict, tokenizer: AutoTokenizer, input_ids: list[int]) -> list[int]:
247
+ loss_mask = [0] * len(input_ids)
248
+ prefix_ids: list[int] = []
249
+ for turn_index, message in enumerate(sample["messages"], start=1):
250
+ role = str(message.get("role", "")).strip().lower()
251
+ if role == "assistant":
252
+ # prompt_ids contains everything up to user message + assistant header (<|im_start|>assistant\n)
253
+ prompt_ids = _coerce_token_ids(
254
+ tokenizer.apply_chat_template(
255
+ sample["messages"][:turn_index-1],
256
+ tokenize=True,
257
+ add_generation_prompt=True,
258
+ )
259
+ )
260
+ # full_ids contains prompt + assistant response content + eos
261
+ full_ids = _coerce_token_ids(
262
+ tokenizer.apply_chat_template(
263
+ sample["messages"][:turn_index],
264
+ tokenize=True,
265
+ add_generation_prompt=False,
266
+ )
267
+ )
268
+ if len(full_ids) < len(prompt_ids) or full_ids[:len(prompt_ids)] != prompt_ids:
269
+ raise ValueError("Chat template is not prefix-stable enough to derive assistant-only targets")
270
+
271
+ # Loss mask is 1 only for assistant's content tokens (after prompt_ids)
272
+ for j in range(len(prompt_ids), len(full_ids)):
273
+ loss_mask[j] = 1
274
+ prefix_ids = _coerce_token_ids(
275
+ tokenizer.apply_chat_template(
276
+ sample["messages"][:turn_index],
277
+ tokenize=True,
278
+ add_generation_prompt=False,
279
+ )
280
+ )
281
+ if prefix_ids != input_ids:
282
+ raise ValueError("Prefix tokenization mismatch while deriving assistant-only targets")
283
+ if len(loss_mask) != len(input_ids):
284
+ raise ValueError(f"Assistant mask length mismatch: got {len(loss_mask)}, expected {len(input_ids)}")
285
+ return loss_mask
286
+
287
+
288
+ def _build_assistant_loss_mask(sample: dict, tokenizer: AutoTokenizer, input_ids: list[int]) -> list[int]:
289
+ builtin_mask = _try_builtin_assistant_mask(sample, tokenizer, input_ids)
290
+ if builtin_mask is not None:
291
+ return builtin_mask
292
+ return _build_assistant_mask_from_prefixes(sample, tokenizer, input_ids)
293
+
294
+
295
+ def write_tokenized_dataset(tokenizer, num_samples: int, out_file: str, log) -> dict:
296
+ if num_samples <= 0:
297
+ raise RuntimeError("num_samples must be positive when tokenization is enabled")
298
+ output_dir = os.path.dirname(out_file)
299
+ if output_dir:
300
+ os.makedirs(output_dir, exist_ok=True)
301
+
302
+ log.info(f"Streaming {cfg.data.dataset_path}...")
303
+ ds = load_dataset(cfg.data.dataset_path, split="train", streaming=True, token=cfg.hub.token)
304
+ buffer_size = int(getattr(cfg.data, "stream_shuffle_buffer_size", 0) or 0)
305
+ if buffer_size > 0 and hasattr(ds, "shuffle"):
306
+ seed = int(getattr(cfg.data, "stream_shuffle_seed", 42))
307
+ log.info(f"Shuffling stream with buffer_size={buffer_size:,} seed={seed}")
308
+ ds = ds.shuffle(seed=seed, buffer_size=buffer_size)
309
+
310
+ # Check if dataset is pre-tokenized
311
+ try:
312
+ first_sample = next(iter(ds))
313
+ is_pre_tokenized = "input_ids" in first_sample and "loss_mask" in first_sample
314
+ except StopIteration:
315
+ raise RuntimeError("Loaded dataset is empty") from None
316
+
317
+ stats = {
318
+ "scanned": 0,
319
+ "written": 0,
320
+ "too_long_tokens": 0,
321
+ "too_short_tokens": 0,
322
+ "template_errors": 0,
323
+ "no_target_tokens": 0,
324
+ "invalid_messages": 0,
325
+ "total_tokens_written": 0,
326
+ "total_target_tokens_written": 0,
327
+ "min_tokens_written": 0,
328
+ "max_tokens_written": 0,
329
+ "min_target_tokens_written": 0,
330
+ "max_target_tokens_written": 0,
331
+ }
332
+
333
+ if is_pre_tokenized:
334
+ log.info("Auto-detected pre-tokenized dataset on HF Hub. Writing directly to train.jsonl...")
335
+ # Re-initialize to avoid losing the first element consumed by next(iter())
336
+ ds = load_dataset(cfg.data.dataset_path, split="train", streaming=True, token=cfg.hub.token)
337
+ if buffer_size > 0 and hasattr(ds, "shuffle"):
338
+ ds = ds.shuffle(seed=seed, buffer_size=buffer_size)
339
+
340
+ with open(out_file, "w", encoding="utf-8") as f:
341
+ for row in ds:
342
+ stats["scanned"] += 1
343
+ input_ids = row["input_ids"]
344
+ loss_mask = row["loss_mask"]
345
+ token_len = len(input_ids)
346
+ target_tokens = sum(loss_mask)
347
+
348
+ out_row = {
349
+ "input_ids": input_ids,
350
+ "loss_mask": loss_mask,
351
+ "length": token_len,
352
+ "target_tokens": target_tokens,
353
+ "source": row.get("source", "unknown"),
354
+ }
355
+ f.write(json.dumps(out_row) + "\n")
356
+
357
+ stats["written"] += 1
358
+ stats["total_tokens_written"] += token_len
359
+ stats["total_target_tokens_written"] += target_tokens
360
+ if stats["written"] == 1:
361
+ stats["min_tokens_written"] = token_len
362
+ stats["max_tokens_written"] = token_len
363
+ stats["min_target_tokens_written"] = target_tokens
364
+ stats["max_target_tokens_written"] = target_tokens
365
+ else:
366
+ stats["min_tokens_written"] = min(stats["min_tokens_written"], token_len)
367
+ stats["max_tokens_written"] = max(stats["max_tokens_written"], token_len)
368
+ stats["min_target_tokens_written"] = min(stats["min_target_tokens_written"], target_tokens)
369
+ stats["max_target_tokens_written"] = max(stats["max_target_tokens_written"], target_tokens)
370
+
371
+ if stats["written"] >= num_samples:
372
+ break
373
+ return stats
374
+
375
+ log.info("Standard raw text dataset detected. Running tokenization locally...")
376
+ with open(out_file, "w", encoding="utf-8") as f:
377
+ for row in ds:
378
+ stats["scanned"] += 1
379
+ sample = _extract_clean_sample(row)
380
+ if sample is None:
381
+ stats["invalid_messages"] += 1
382
+ continue
383
+ try:
384
+ input_ids = _coerce_token_ids(
385
+ tokenizer.apply_chat_template(
386
+ sample["messages"],
387
+ tokenize=True,
388
+ add_generation_prompt=False,
389
+ )
390
+ )
391
+ token_len = len(input_ids)
392
+ if token_len < _MIN_TOKEN_LENGTH:
393
+ stats["too_short_tokens"] += 1
394
+ continue
395
+ if token_len > cfg.data.max_seq_len:
396
+ stats["too_long_tokens"] += 1
397
+ continue
398
+ loss_mask = _build_assistant_loss_mask(sample, tokenizer, input_ids)
399
+ except _CHAT_TEMPLATE_ERRORS:
400
+ stats["template_errors"] += 1
401
+ continue
402
+
403
+ target_tokens = sum(loss_mask)
404
+ if target_tokens == 0:
405
+ stats["no_target_tokens"] += 1
406
+ continue
407
+
408
+ out_row = {
409
+ "input_ids": input_ids,
410
+ "loss_mask": loss_mask,
411
+ "length": token_len,
412
+ "target_tokens": target_tokens,
413
+ "source": sample.get("source", "unknown"),
414
+ }
415
+ f.write(json.dumps(out_row) + "\n")
416
+
417
+ stats["written"] += 1
418
+ stats["total_tokens_written"] += token_len
419
+ stats["total_target_tokens_written"] += target_tokens
420
+ if stats["written"] == 1:
421
+ stats["min_tokens_written"] = token_len
422
+ stats["max_tokens_written"] = token_len
423
+ stats["min_target_tokens_written"] = target_tokens
424
+ stats["max_target_tokens_written"] = target_tokens
425
+ else:
426
+ stats["min_tokens_written"] = min(stats["min_tokens_written"], token_len)
427
+ stats["max_tokens_written"] = max(stats["max_tokens_written"], token_len)
428
+ stats["min_target_tokens_written"] = min(stats["min_target_tokens_written"], target_tokens)
429
+ stats["max_target_tokens_written"] = max(stats["max_target_tokens_written"], target_tokens)
430
+
431
+ if stats["written"] >= num_samples:
432
+ break
433
+
434
+ return stats
435
+
436
+ def main() -> None:
437
+ parser = argparse.ArgumentParser(description="Download models and tokenize data")
438
+ parser.add_argument("--num_samples", type=int, default=cfg.data.num_samples)
439
+ parser.add_argument("--skip_teacher", action="store_true", help="Skip teacher download.")
440
+ parser.add_argument("--skip_tokenization", action="store_true", help="Skip data tokenization.")
441
+ parser.add_argument("--tokenizer_only", action="store_true", help="Download tokenizer artifacts only")
442
+ args = parser.parse_args()
443
+
444
+ log = setup_logger("DOWNLOAD")
445
+ log.info("=" * 70)
446
+ log.info("Download and tokenize")
447
+ log.info("=" * 70)
448
+
449
+ write_system_info(cfg.paths.system_info, log)
450
+
451
+ log.info(f" Teacher: {cfg.model.teacher}")
452
+ log.info(f" Teacher rev: {_config_revision(cfg.model.teacher_revision) or 'unversioned'}")
453
+ tokenizer_model = getattr(cfg.model, "tokenizer", cfg.model.student)
454
+ tokenizer_revision = _config_revision(getattr(cfg.model, "tokenizer_revision", cfg.model.student_revision))
455
+ tokenizer_dir = getattr(cfg.paths, "tokenizer_dir", cfg.paths.student_dir)
456
+ log.info(f" Student: {cfg.model.student}")
457
+ log.info(f" Student rev: {_config_revision(cfg.model.student_revision) or 'unversioned'}")
458
+ log.info(f" Tokenizer: {tokenizer_model}")
459
+ log.info(f" Tokenizer rev:{tokenizer_revision or 'unversioned'}")
460
+ log.info(f" Student dir: {cfg.paths.student_dir}")
461
+ log.info(f" Tokenizer dir:{tokenizer_dir}")
462
+ log.info(f" Remote code: {cfg.model.allow_remote_code}")
463
+ log.info(f" Dataset: {cfg.data.dataset_path}")
464
+ log.info(f" Num samples: {args.num_samples:,}")
465
+ log.info(f" Max seq len: {cfg.data.max_seq_len}")
466
+ if torch.cuda.is_available():
467
+ log.info(f" GPU: {torch.cuda.get_device_name(0)}")
468
+
469
+ if not args.tokenizer_only:
470
+ emit_log_spacing(log)
471
+ log.info("-" * 70)
472
+ log.info(f"Downloading student -> ./{cfg.paths.student_dir}/")
473
+ t0 = time.time()
474
+ try:
475
+ snapshot_download(
476
+ repo_id=cfg.model.student,
477
+ local_dir=cfg.paths.student_dir,
478
+ revision=_config_revision(cfg.model.student_revision),
479
+ ignore_patterns=_IGNORE_PATTERNS,
480
+ )
481
+ size_gb = sum(f.stat().st_size for f in Path(cfg.paths.student_dir).rglob("*") if f.is_file()) / 1e9
482
+ log.info(f"Student downloaded: {size_gb:.1f} GB in {time.time() - t0:.0f}s")
483
+ except Exception as exc:
484
+ log.error(f"Failed to download student: {exc}")
485
+ sys.exit(1)
486
+
487
+ if args.tokenizer_only or Path(tokenizer_dir).resolve() != Path(cfg.paths.student_dir).resolve():
488
+ emit_log_spacing(log)
489
+ log.info("-" * 70)
490
+ _download_tokenizer_artifacts(tokenizer_model, tokenizer_revision, tokenizer_dir, log)
491
+
492
+ if not args.skip_tokenization:
493
+ emit_log_spacing(log)
494
+ log.info("-" * 70)
495
+ log.info(f"Preparing dataset: {cfg.data.dataset_path}")
496
+
497
+ try:
498
+ tokenizer = AutoTokenizer.from_pretrained(
499
+ tokenizer_dir,
500
+ trust_remote_code=cfg.model.allow_remote_code,
501
+ )
502
+ except Exception as exc:
503
+ log.error(format_model_load_error("Student tokenizer load", exc))
504
+ sys.exit(1)
505
+ if tokenizer.pad_token is None:
506
+ tokenizer.pad_token = tokenizer.eos_token
507
+
508
+ os.makedirs(cfg.paths.tokenized_dir, exist_ok=True)
509
+ out_file = os.path.join(cfg.paths.tokenized_dir, "train.jsonl")
510
+ stats_file = os.path.join(cfg.paths.tokenized_dir, _DATA_STATS_FILENAME)
511
+
512
+ log.info(
513
+ f"Streaming + tokenizing up to {args.num_samples:,} samples "
514
+ f"(max_seq_len={cfg.data.max_seq_len}, strict token limit, no truncation)"
515
+ )
516
+ t0 = time.time()
517
+ try:
518
+ token_stats = write_tokenized_dataset(
519
+ tokenizer=tokenizer,
520
+ num_samples=args.num_samples,
521
+ out_file=out_file,
522
+ log=log,
523
+ )
524
+ except RuntimeError as exc:
525
+ log.error(str(exc))
526
+ sys.exit(1)
527
+
528
+ write_data_stats(
529
+ output_path=stats_file,
530
+ stats=token_stats,
531
+ dataset_id=cfg.data.dataset_path,
532
+ config_name="default",
533
+ target_samples=args.num_samples,
534
+ max_seq_len=cfg.data.max_seq_len,
535
+ logger=log,
536
+ )
537
+ n_written = token_stats["written"]
538
+
539
+ if n_written == 0:
540
+ log.error("Tokenization produced 0 usable rows - aborting.")
541
+ sys.exit(1)
542
+
543
+ log.info(f"Pretokenization complete: {n_written:,} samples -> {out_file}")
544
+ else:
545
+ log.info("Skipping dataset tokenization (--skip_tokenization)")
546
+
547
+ if not args.skip_teacher:
548
+ emit_log_spacing(log)
549
+ log.info("-" * 70)
550
+ log.info(f"Downloading teacher -> ./{cfg.paths.teacher_dir}/")
551
+ t0 = time.time()
552
+ try:
553
+ snapshot_download(
554
+ repo_id=cfg.model.teacher,
555
+ local_dir=cfg.paths.teacher_dir,
556
+ revision=_config_revision(cfg.model.teacher_revision),
557
+ ignore_patterns=_IGNORE_PATTERNS,
558
+ )
559
+ size_gb = sum(f.stat().st_size for f in Path(cfg.paths.teacher_dir).rglob("*") if f.is_file()) / 1e9
560
+ log.info(f"Teacher downloaded: {size_gb:.1f} GB in {time.time() - t0:.0f}s")
561
+ except Exception as exc:
562
+ log.error(f"Failed to download teacher: {exc}")
563
+ sys.exit(1)
564
+ else:
565
+ log.info("Skipping teacher download (--skip_teacher)")
566
+
567
+ emit_log_spacing(log)
568
+ log.info("-" * 70)
569
+ log.info("Download complete")
570
+ log.info("-" * 70)
571
+
572
+
573
+ if __name__ == "__main__":
574
+ main()
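The prefix-based fallback in `_build_assistant_mask_from_prefixes` is easier to follow on a toy template. The sketch below is not the repository's code: `render` is a hypothetical stand-in for `tokenizer.apply_chat_template` that emits one "token" per word, and `assistant_mask` is an illustrative name. It shows the same idea under those assumptions: render the conversation up to each assistant turn with a generation prompt, render it again including the turn, and mark only the suffix as loss targets.

```python
def render(messages, add_generation_prompt=False):
    """Toy chat template (assumption): one token per word, wrapped in markers."""
    tokens = []
    for message in messages:
        tokens += ["<start>", message["role"]] + message["content"].split() + ["<end>"]
    if add_generation_prompt:
        # Header that precedes an assistant reply.
        tokens += ["<start>", "assistant"]
    return tokens


def assistant_mask(messages):
    """Return (tokens, mask) where mask is 1 on assistant content and its end marker."""
    full = render(messages)
    mask = [0] * len(full)
    for index, message in enumerate(messages):
        if message["role"] != "assistant":
            continue
        prompt = render(messages[:index], add_generation_prompt=True)
        upto = render(messages[: index + 1])
        # The derivation is only valid if the prompt is an exact prefix.
        if upto[: len(prompt)] != prompt:
            raise ValueError("template is not prefix-stable")
        for position in range(len(prompt), len(upto)):
            mask[position] = 1
    return full, mask


conversation = [
    {"role": "user", "content": "hi there"},
    {"role": "assistant", "content": "hello"},
]
tokens, mask = assistant_mask(conversation)
```

For this conversation the mask covers only `hello` and the closing `<end>` marker; the assistant header itself stays unmasked, matching the comment in the diff that the loss applies after `prompt_ids`.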
src/kd_contracts.py ADDED
@@ -0,0 +1,95 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import json
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ PROVENANCE_SCHEMA_VERSION = 4
9
+
10
+ _SHARD_SCHEMA = {
11
+ "support": "teacher_topk_plus_other_bucket",
12
+ "layout": "chunked_sample_lists",
13
+ "logprobs_dtype": "float16",
14
+ "ids_dtype": "int32",
15
+ "other_logprob_dtype": "float16",
16
+ }
17
+
18
+ _SPECIAL_TOKEN_ID_FIELDS = (
19
+ "bos_token_id",
20
+ "eos_token_id",
21
+ "pad_token_id",
22
+ "unk_token_id",
23
+ "cls_token_id",
24
+ "sep_token_id",
25
+ "mask_token_id",
26
+ "additional_special_tokens_ids",
27
+ )
28
+
29
+ def normalize_config_revision(value: str | None) -> str | None:
30
+ if value is None:
31
+ return None
32
+ stripped = value.strip()
33
+ return stripped or None
34
+
35
+ def canonical_revision(value: str | None) -> str:
36
+ return normalize_config_revision(value) or "unversioned"
37
+
38
+ def sha256_file(path: str | Path, chunk_size: int = 1 << 20) -> str:
39
+ digest = hashlib.sha256()
40
+ with open(path, "rb") as handle:
41
+ while True:
42
+ chunk = handle.read(chunk_size)
43
+ if not chunk:
44
+ break
45
+ digest.update(chunk)
46
+ return digest.hexdigest()
47
+
48
+ def _special_token_ids(tokenizer) -> dict[str, Any]:
49
+ snapshot: dict[str, Any] = {}
50
+ for field in _SPECIAL_TOKEN_ID_FIELDS:
51
+ value = getattr(tokenizer, field, None)
52
+ if isinstance(value, tuple):
53
+ value = list(value)
54
+ snapshot[field] = value
55
+ return snapshot
56
+
57
+ def build_tokenizer_contract(tokenizer) -> dict[str, Any]:
58
+ canonical = {
59
+ "tokenizer_class": tokenizer.__class__.__name__,
60
+ "full_vocab_size": len(tokenizer),
61
+ "special_token_ids": _special_token_ids(tokenizer),
62
+ "vocab": dict(sorted(tokenizer.get_vocab().items())),
63
+ }
64
+ encoded = json.dumps(
65
+ canonical,
66
+ sort_keys=True,
67
+ separators=(",", ":"),
68
+ ensure_ascii=True,
69
+ ).encode("utf-8")
70
+ return {
71
+ "tokenizer_class": canonical["tokenizer_class"],
72
+ "full_vocab_size": canonical["full_vocab_size"],
73
+ "special_token_ids": canonical["special_token_ids"],
74
+ "fingerprint": hashlib.sha256(encoded).hexdigest(),
75
+ }
76
+
77
+ def build_shard_schema() -> dict[str, str]:
78
+ return dict(_SHARD_SCHEMA)
79
+
80
+ def collect_model_vocab_sizes(model) -> dict[str, int]:
81
+ sizes: dict[str, int] = {}
82
+
83
+ config_size = getattr(getattr(model, "config", None), "vocab_size", None)
84
+ if isinstance(config_size, int):
85
+ sizes["config"] = config_size
86
+
87
+ input_embeddings = model.get_input_embeddings()
88
+ if input_embeddings is not None and getattr(input_embeddings, "weight", None) is not None:
89
+ sizes["input_embeddings"] = int(input_embeddings.weight.shape[0])
90
+
91
+ output_embeddings = model.get_output_embeddings()
92
+ if output_embeddings is not None and getattr(output_embeddings, "weight", None) is not None:
93
+ sizes["output_embeddings"] = int(output_embeddings.weight.shape[0])
94
+
95
+ return sizes
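The fingerprint in `build_tokenizer_contract` depends on serializing the vocabulary canonically before hashing. A minimal sketch of that step follows, using plain dictionaries in place of a real tokenizer. `tokenizer_fingerprint` and its arguments are hypothetical names for illustration; the JSON settings are the ones shown in the diff.

```python
import hashlib
import json


def tokenizer_fingerprint(vocab: dict[str, int], special_token_ids: dict) -> str:
    """Hash a canonical JSON encoding so the digest ignores dict insertion order."""
    canonical = {
        "special_token_ids": special_token_ids,
        "vocab": dict(sorted(vocab.items())),
    }
    encoded = json.dumps(
        canonical,
        sort_keys=True,          # stable key order at every nesting level
        separators=(",", ":"),   # no whitespace differences between runs
        ensure_ascii=True,       # stable bytes regardless of locale
    ).encode("utf-8")
    return hashlib.sha256(encoded).hexdigest()
```

Two vocabularies with the same token-to-id mapping hash identically regardless of insertion order, while any change to an id produces a different digest, which is what makes the value usable as a compatibility check between teacher shards and the student tokenizer.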
src/losses.py ADDED
@@ -0,0 +1,180 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ PROB_EPS = 1.0e-12
8
+
9
+ def _normalize_support_logprobs(
10
+ topk_logprobs: torch.Tensor,
11
+ other_logprob: torch.Tensor,
12
+ ) -> tuple[torch.Tensor, torch.Tensor]:
13
+ topk_probs = topk_logprobs.float().exp()
14
+ other_prob = other_logprob.float().exp().unsqueeze(-1)
15
+ support_probs = torch.cat([topk_probs, other_prob], dim=-1).clamp_min(PROB_EPS)
16
+ support_probs = support_probs / support_probs.sum(dim=-1, keepdim=True).clamp_min(PROB_EPS)
17
+ return support_probs, support_probs.log()
18
+
19
+ def _masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
20
+ mask = mask.float()
21
+ return (values * mask).sum() / mask.sum().clamp(min=1.0)
22
+
23
+ def compute_sft_ce(logits: torch.Tensor, labels: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
24
+ batch_size = logits.size(0)
25
+ shift_labels = labels[:, 1:].contiguous()
26
+ shift_loss_mask = ((loss_mask[:, 1:] > 0) & shift_labels.ne(-100)).contiguous().float()
27
+
28
+ total_loss = torch.tensor(0.0, device=logits.device, dtype=torch.float32)
29
+ total_weight = torch.tensor(0.0, device=logits.device, dtype=torch.float32)
30
+
31
+ for i in range(batch_size):
32
+ b_logits = logits[i, :-1, :]
33
+ b_labels = shift_labels[i]
34
+ b_mask = shift_loss_mask[i]
35
+
36
+ ce = F.cross_entropy(
37
+ b_logits,
38
+ b_labels,
39
+ ignore_index=-100,
40
+ reduction="none",
41
+ )
42
+ total_loss += (ce * b_mask).sum()
43
+ total_weight += b_mask.sum()
44
+
45
+ return total_loss / total_weight.clamp(min=1.0)
46
+
47
+ def _compute_masked_ce_with_logits(logits, labels, loss_mask):
48
+ loss_ce = compute_sft_ce(logits, labels, loss_mask)
49
+ shift_logits = logits[:, :-1, :]
50
+ shift_labels = labels[:, 1:].contiguous()
51
+ shift_loss_mask = ((loss_mask[:, 1:] > 0) & shift_labels.ne(-100)).contiguous().float()
52
+ return loss_ce, shift_logits, shift_loss_mask
53
+
54
+ def compute_distillation_loss(
55
+ student_logits: torch.Tensor,
56
+ labels: torch.Tensor,
57
+ teacher_logprobs: torch.Tensor,
58
+ teacher_ids: torch.Tensor,
59
+ teacher_other_logprob: torch.Tensor,
60
+ loss_mask: torch.Tensor,
61
+ alpha: float,
62
+ temperature: float,
63
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
64
+ vocab_size = student_logits.size(-1)
65
+ loss_ce, shift_logits, shift_loss_mask = _compute_masked_ce_with_logits(student_logits, labels, loss_mask)
66
+
67
+ shift_teacher_logprobs = teacher_logprobs[:, :-1, :].contiguous()
68
+ shift_teacher_ids = teacher_ids[:, :-1, :].contiguous()
69
+ shift_teacher_other_logprob = teacher_other_logprob[:, :-1].contiguous()
70
+ shift_student = shift_logits
71
+
72
+ topk_ids_clamped = shift_teacher_ids.clamp(0, vocab_size - 1)
73
+ student_log_z = torch.logsumexp(shift_student / temperature, dim=-1, keepdim=True).float()
74
+ student_topk_logprobs = shift_student.gather(-1, topk_ids_clamped).float() / temperature - student_log_z
75
+ student_topk_probs = student_topk_logprobs.float().exp()
76
+ student_other_prob = (1.0 - student_topk_probs.sum(dim=-1)).clamp_min(PROB_EPS)
77
+ student_other_logprob = student_other_prob.log()
78
+
79
+ teacher_support_probs, teacher_support_logprobs = _normalize_support_logprobs(
80
+ shift_teacher_logprobs,
81
+ shift_teacher_other_logprob,
82
+ )
83
+ _, student_support_logprobs = _normalize_support_logprobs(
84
+ student_topk_logprobs,
85
+ student_other_logprob,
86
+ )
87
+ positive_teacher = teacher_support_probs > 0
88
+ kl_terms = torch.where(
89
+ positive_teacher,
90
+ teacher_support_probs * (teacher_support_logprobs - student_support_logprobs),
91
+ torch.zeros_like(teacher_support_probs),
92
+ )
93
+ kl_per_token = kl_terms.sum(-1)
94
+ loss_kd = _masked_mean(kl_per_token, shift_loss_mask) * (temperature**2)
95
+
96
+ loss_total = alpha * loss_ce + (1.0 - alpha) * loss_kd
97
+ return loss_total, loss_ce.detach(), loss_kd.detach()
+
+ def compute_online_kd_loss(
+     student_logits: torch.Tensor,
+     teacher_logits: torch.Tensor,
+     labels: torch.Tensor,
+     loss_mask: torch.Tensor,
+     alpha: float,
+     temperature: float,
+     token_chunk_size: int = 2048,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     loss_ce = compute_sft_ce(student_logits, labels, loss_mask)
+
+     shift_labels = labels[:, 1:].contiguous()
+     shift_loss_mask = (
+         (loss_mask[:, 1:] > 0) & shift_labels.ne(-100)
+     ).contiguous().float()
+
+     s_shifted = student_logits[:, :-1, :]
+     t_shifted = teacher_logits[:, :-1, :]
+     seq_len = s_shifted.size(1)
+
+     total_kl = torch.tensor(0.0, device=student_logits.device, dtype=torch.float32)
+     total_weight = torch.tensor(0.0, device=student_logits.device, dtype=torch.float32)
+
+     for tok_start in range(0, seq_len, token_chunk_size):
+         tok_end = min(tok_start + token_chunk_size, seq_len)
+
+         s_chunk = s_shifted[:, tok_start:tok_end, :].float()
+         t_chunk = t_shifted[:, tok_start:tok_end, :].float()
+         mask_chunk = shift_loss_mask[:, tok_start:tok_end]
+
+         chunk_weight = mask_chunk.sum()
+         t_probs = F.softmax(t_chunk / temperature, dim=-1)
+         s_log_probs = F.log_softmax(s_chunk / temperature, dim=-1)
+         kl_tokens = F.kl_div(
+             s_log_probs, t_probs, log_target=False, reduction="none"
+         ).sum(dim=-1)
+
+         total_kl += (kl_tokens * mask_chunk).sum()
+         total_weight += chunk_weight
+
+         del s_chunk, t_chunk, t_probs, s_log_probs, kl_tokens, mask_chunk
+
+     loss_kd = (total_kl / total_weight.clamp(min=1.0)) * (temperature ** 2)
+     loss_kd = loss_kd.to(dtype=student_logits.dtype)
+     loss_total = alpha * loss_ce + (1.0 - alpha) * loss_kd
+
+     return loss_total, loss_ce.detach(), loss_kd.detach()
+
+ def compute_loss_for_phase(
+     phase: str,
+     logits: torch.Tensor,
+     labels: torch.Tensor,
+     loss_mask: torch.Tensor,
+     batch: dict,
+     alpha: float,
+     temperature: float,
+     teacher_logits: torch.Tensor | None = None,
+     online_kd_token_chunk_size: int = 2048,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     if phase == "sft":
+         loss_ce = compute_sft_ce(logits, labels, loss_mask)
+         return loss_ce, loss_ce.detach(), torch.tensor(0.0, device=logits.device)
+     if phase == "online_kd":
+         return compute_online_kd_loss(
+             logits,
+             teacher_logits,
+             labels,
+             loss_mask,
+             alpha,
+             temperature,
+             token_chunk_size=online_kd_token_chunk_size,
+         )
+     return compute_distillation_loss(
+         logits,
+         labels,
+         batch["teacher_logprobs"],
+         batch["teacher_ids"],
+         batch["teacher_other_logprob"],
+         loss_mask,
+         alpha,
+         temperature,
+     )
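The sparse KD path in `compute_distillation_loss` compares teacher and student on a reduced support: the teacher's top-k tokens plus one "other" bucket that absorbs the remaining probability mass. The sketch below reproduces that idea for a single token position in plain Python, so the arithmetic can be checked by hand. It is only an illustration under assumptions: `normalize_support` stands in for the repository's `_normalize_support_logprobs` (whose body is not shown in this diff and is assumed to renormalize the k+1 entries to sum to one), and the `PROB_EPS` value is a placeholder.

```python
import math

PROB_EPS = 1e-8  # assumed floor; the real constant is defined elsewhere in the repo


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of temperature-scaled logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in scaled))
    return [x - log_z for x in scaled]


def normalize_support(topk_logprobs, other_logprob):
    """Renormalize k top-k entries plus one 'other' bucket (assumed helper behavior)."""
    logps = list(topk_logprobs) + [other_logprob]
    peak = max(logps)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in logps))
    norm = [x - log_z for x in logps]
    return [math.exp(x) for x in norm], norm


def sparse_kl(student_logits, teacher_ids, teacher_topk_logprobs,
              teacher_other_logprob, temperature=1.0):
    """KL(teacher || student) over the (k + 1)-entry support, scaled by T^2."""
    student_logp = log_softmax(student_logits, temperature)
    s_topk = [student_logp[i] for i in teacher_ids]
    # Student mass outside the teacher's top-k, floored to keep the log finite.
    s_other = math.log(max(1.0 - sum(math.exp(x) for x in s_topk), PROB_EPS))
    t_probs, t_logp = normalize_support(teacher_topk_logprobs, teacher_other_logprob)
    _, s_logp = normalize_support(s_topk, s_other)
    kl = sum(p * (tl - sl) for p, tl, sl in zip(t_probs, t_logp, s_logp) if p > 0)
    return kl * temperature ** 2
```

When the student's distribution equals the teacher's, the value is zero up to rounding; any disagreement on the top-k entries or on the leftover mass makes it positive.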
src/optim.py ADDED
@@ -0,0 +1,44 @@
+ from __future__ import annotations
+
+ import torch
+
+ from configs import cfg
+
+ def fused_adamw_preflight(logger) -> bool:
+     if not torch.cuda.is_available():
+         logger.info(" Optimizer: fused AdamW requested but CUDA is unavailable; using standard AdamW")
+         return False
+
+     try:
+         probe = torch.nn.Parameter(torch.ones(8, device="cuda", dtype=torch.bfloat16))
+         probe_optim = torch.optim.AdamW([probe], lr=1.0e-4, fused=True)
+         loss = probe.float().square().sum()
+         loss.backward()
+         probe_optim.step()
+         probe_optim.zero_grad(set_to_none=True)
+         del loss, probe_optim, probe
+         return True
+     except Exception as exc:
+         logger.warning(f" Optimizer: fused AdamW preflight failed ({exc}); using standard AdamW")
+         return False
+
+ def build_adamw_optimizer(params: list[torch.nn.Parameter], logger, allow_fused: bool) -> torch.optim.Optimizer:
+     kwargs = {
+         "lr": cfg.training.learning_rate,
+         "weight_decay": cfg.training.weight_decay,
+         "betas": (0.9, 0.999),
+     }
+     fused_requested = bool(getattr(cfg.training, "fused_adamw", False)) and allow_fused
+     if fused_requested and fused_adamw_preflight(logger):
+         try:
+             optimizer = torch.optim.AdamW(params, **kwargs, fused=True)
+             logger.info(" Optimizer: AdamW fused=True")
+             return optimizer
+         except Exception as exc:
+             logger.warning(f" Optimizer: fused AdamW construction failed ({exc}); using standard AdamW")
+     elif bool(getattr(cfg.training, "fused_adamw", False)) and not allow_fused:
+         logger.info(" Optimizer: fused AdamW disabled for DeepSpeed")
+
+     optimizer = torch.optim.AdamW(params, **kwargs)
+     logger.info(" Optimizer: AdamW standard")
+     return optimizer
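`src/optim.py` follows a probe-then-fall-back pattern: a tiny end-to-end check runs first, and the preferred (fused) optimizer is built only if that check and the construction both succeed. The sketch below isolates that control flow without any PyTorch dependency. All names in it (`build_with_preflight`, `preferred`, `fallback`, `probe`, `failing_probe`) are hypothetical and exist only for illustration.

```python
from typing import Callable, Tuple, TypeVar

T = TypeVar("T")


def build_with_preflight(
    preferred: Callable[[], T],
    fallback: Callable[[], T],
    probe: Callable[[], None],
    enabled: bool = True,
) -> Tuple[T, str]:
    """Return (object, mode); use `preferred` only if enabled and the probe passes."""
    if enabled:
        try:
            probe()  # cheap check, e.g. one optimizer step on a tiny tensor
            return preferred(), "preferred"
        except Exception:
            # Broad on purpose: any probe or constructor failure means "fall back".
            pass
    return fallback(), "fallback"


def failing_probe() -> None:
    """Stand-in for a preflight that discovers the fast path is unusable."""
    raise RuntimeError("fused kernel unavailable")
```

Calling `build_with_preflight(lambda: "fused", lambda: "standard", failing_probe)` yields the fallback object, which mirrors how the module ends up on standard AdamW when the fused preflight raises.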
src/provenance.py ADDED
@@ -0,0 +1,173 @@
+ from __future__ import annotations
+
+ import json
+ import os
+ from pathlib import Path
+
+ from configs import cfg
+ from src.kd_contracts import (
+     PROVENANCE_SCHEMA_VERSION,
+     build_shard_schema,
+     canonical_revision,
+     collect_model_vocab_sizes,
+     sha256_file,
+ )
+
+ def resolve_model_vocab_size(model, tokenizer, label: str, log) -> int:
+     model_sizes = collect_model_vocab_sizes(model)
+     if not model_sizes:
+         log.error(f"{label} model does not expose a usable vocab size view.")
+         raise SystemExit(1)
+
+     unique_sizes = sorted(set(model_sizes.values()))
+     if len(unique_sizes) != 1:
+         details = ", ".join(f"{name}={size:,}" for name, size in sorted(model_sizes.items()))
+         log.error(f"{label} vocab mismatch across checkpoint artifacts: {details}")
+         raise SystemExit(1)
+
+     model_vocab_size = unique_sizes[0]
+     tokenizer_vocab_size = len(tokenizer)
+     if model_vocab_size < tokenizer_vocab_size:
+         log.error(
+             f"{label} tokenizer length ({tokenizer_vocab_size:,}) exceeds "
+             f"the model vocab size ({model_vocab_size:,})."
+         )
+         log.error("Repair or regenerate the checkpoint before using it for distillation.")
+         raise SystemExit(1)
+     if model_vocab_size > tokenizer_vocab_size:
+         log.info(
+             f" {label} model vocab is padded beyond the tokenizer range: "
+             f"tokenizer={tokenizer_vocab_size:,}, model={model_vocab_size:,}"
+         )
+     return model_vocab_size
+
+ def validate_provenance(
+     prov_path: str,
+     data_path: str,
+     dataset,
+     teacher_tokenizer_contract: dict,
+     student_tokenizer_contract: dict,
+     log,
+ ) -> None:
+     if not os.path.exists(prov_path):
+         log.error("Missing _provenance.json in the logits directory.")
+         log.error("Regenerate the current teacher-logit shard metadata.")
+         raise SystemExit(1)
+
+     with open(prov_path, "r", encoding="utf-8") as f:
+         prov = json.load(f)
+
+     schema_version = prov.get("schema_version")
+     if schema_version != PROVENANCE_SCHEMA_VERSION:
+         log.error(
+             f"Unsupported shard provenance schema: {schema_version!r}. "
+             f"Expected {PROVENANCE_SCHEMA_VERSION}."
+         )
+         log.error("Regenerate the teacher-logit shards.")
+         raise SystemExit(1)
+
+     teacher_meta = prov.get("teacher", {})
+     student_meta = prov.get("student", {})
+     current_data_sha = sha256_file(data_path)
+     actual_shard_count = sum(1 for _ in Path(prov_path).parent.glob("shard_*.pt"))
+     provenance_num_samples = prov.get("num_samples")
+     try:
+         provenance_num_samples_int = int(provenance_num_samples)
+     except (TypeError, ValueError):
+         log.error(f"PROVENANCE MISMATCH: num_samples is {provenance_num_samples!r}, expected an integer.")
+         log.error("Regenerate compatible teacher-logit shards.")
+         raise SystemExit(1)
+     if provenance_num_samples_int < len(dataset):
+         log.error(
+             f"PROVENANCE MISMATCH: num_samples is {provenance_num_samples_int}, "
+             f"but the requested dataset has {len(dataset)} samples."
+         )
+         log.error("Regenerate compatible teacher-logit shards.")
+         raise SystemExit(1)
+     if provenance_num_samples_int > len(dataset):
+         log.warning(
+             f" Provenance contains {provenance_num_samples_int:,} samples; "
+             f"training is using the first {len(dataset):,}. This is expected for smoke tests."
+         )
+
+     expected_pairs = [
+         ("shard_count", prov.get("shard_count"), actual_shard_count),
+         ("samples_per_shard", prov.get("samples_per_shard"), dataset.samples_per_shard),
+         ("data_sha256", prov.get("data_sha256"), current_data_sha),
+         ("max_seq_len", prov.get("max_seq_len"), cfg.data.max_seq_len),
+         ("top_k", prov.get("top_k"), cfg.training.top_k),
+         ("temperature", prov.get("temperature"), float(cfg.training.temperature)),
+         ("teacher.model", teacher_meta.get("model"), cfg.model.teacher),
+         (
+             "teacher.revision",
+             teacher_meta.get("revision"),
+             canonical_revision(cfg.model.teacher_revision),
+         ),
+         (
+             "teacher.tokenizer_size",
+             teacher_meta.get("tokenizer_size"),
+             teacher_tokenizer_contract["full_vocab_size"],
+         ),
+         (
+             "teacher.tokenizer_fingerprint",
+             teacher_meta.get("tokenizer_fingerprint"),
+             teacher_tokenizer_contract["fingerprint"],
+         ),
+         ("student.model", student_meta.get("model"), getattr(cfg.model, "tokenizer", cfg.model.student)),
+         (
+             "student.revision",
+             student_meta.get("revision"),
+             canonical_revision(getattr(cfg.model, "tokenizer_revision", cfg.model.student_revision)),
+         ),
+         (
+             "student.tokenizer_size",
+             student_meta.get("tokenizer_size"),
+             student_tokenizer_contract["full_vocab_size"],
+         ),
+         (
+             "student.tokenizer_fingerprint",
+             student_meta.get("tokenizer_fingerprint"),
+             student_tokenizer_contract["fingerprint"],
+         ),
+     ]
+
+     warn_only_fields = {
+         "teacher.tokenizer_fingerprint",
+         "student.tokenizer_fingerprint",
+     }
+
+     for field_name, found, expected in expected_pairs:
+         if found != expected:
+             if field_name in warn_only_fields:
+                 log.warning(
+                     f" Provenance WARNING (non-fatal): {field_name} is {found!r}, "
+                     f"expected {expected!r}. This is likely due to a transformers "
+                     f"library version change. Continuing because vocab sizes match."
+                 )
+             else:
+                 log.error(
+                     f"PROVENANCE MISMATCH: {field_name} is {found!r}, expected {expected!r}."
+                 )
+                 log.error("Regenerate compatible teacher-logit shards.")
+                 raise SystemExit(1)
+
+     provenance_data_path = prov.get("data_path")
+     current_data_path = str(Path(data_path).resolve())
+     if provenance_data_path != current_data_path:
+         log.warning(
+             " Provenance data_path differs because logits were likely generated on another machine: "
+             f"{provenance_data_path!r} vs {current_data_path!r}. "
+             "Continuing because data_sha256 matches."
+         )
+
+     shard_schema = prov.get("shard_schema")
+     expected_shard_schema = build_shard_schema()
+     if shard_schema != expected_shard_schema:
+         log.error(
+             f"PROVENANCE MISMATCH: shard_schema is {shard_schema!r}, "
+             f"expected {expected_shard_schema!r}."
+         )
+         log.error("Regenerate compatible teacher-logit shards.")
+         raise SystemExit(1)
+
+     log.info(" Provenance: PASS (teacher shards match the current config and dataset)")
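`validate_provenance` compares recorded shard metadata against the current configuration field by field, treating most mismatches as fatal and the tokenizer fingerprints as warnings only. The sketch below shows that split on plain dictionaries. It is a simplified illustration, not the module's code: `check_fields` is a hypothetical name, the keys are flat rather than nested under `teacher` and `student`, and it collects every mismatch instead of exiting on the first fatal one as the module does.

```python
from typing import Dict, List, Set, Tuple


def check_fields(
    found: Dict[str, object],
    expected: Dict[str, object],
    warn_only: Set[str],
) -> Tuple[List[str], List[str]]:
    """Compare recorded metadata with expected values; return (errors, warnings)."""
    errors: List[str] = []
    warnings: List[str] = []
    for name, want in expected.items():
        got = found.get(name)  # a missing key compares as None, as with dict.get in the module
        if got == want:
            continue
        message = f"{name} is {got!r}, expected {want!r}"
        if name in warn_only:
            warnings.append(message)  # tolerated drift, e.g. a fingerprint change
        else:
            errors.append(message)    # hard mismatch: shards must be regenerated
    return errors, warnings
```

A caller could log every warning, then stop if `errors` is non-empty. Collecting all mismatches before exiting is a design choice of this sketch that trades the module's fail-fast behavior for a fuller report.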
src/sequence_packing.py ADDED
@@ -0,0 +1,183 @@
+ from __future__ import annotations
+
+ from bisect import bisect_left, insort
+
+ import torch
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset
+
+ from src.training_data import DistillationDataset
+
+ class SequencePackedDataset(Dataset):
+     def __init__(
+         self,
+         source: DistillationDataset,
+         source_indices: list[int],
+         pack_length: int,
+         eos_token_id: int,
+         pad_token_id: int,
+         mask_first_after_separator: bool = True,
+     ):
+         if pack_length <= 0:
+             raise ValueError(f"pack_length must be positive, got {pack_length}.")
+         if not hasattr(source, "sample_lengths"):
+             raise ValueError("Packed training requires a source dataset with sample_lengths metadata.")
+         if not source_indices:
+             raise ValueError("Packed training requires at least one source row.")
+
+         self.source = source
+         self.source_indices = [int(index) for index in source_indices]
+         self.source_index_set = set(self.source_indices)
+         if len(self.source_index_set) != len(self.source_indices):
+             raise ValueError("Packed training source indices contain duplicates.")
+
+         self.pack_length = int(pack_length)
+         self.eos_token_id = int(eos_token_id)
+         self.pad_token_id = int(pad_token_id)
+         self.mask_first_after_separator = bool(mask_first_after_separator)
+         self._length_by_index: dict[int, int] = {}
+         self.plan: list[list[int]] = []
+
+         for source_index in self.source_indices:
+             try:
+                 length = int(source.sample_lengths[source_index])
+             except IndexError as exc:
+                 raise IndexError(f"Source index {source_index} is outside the tokenized dataset.") from exc
+             if length > self.pack_length:
+                 raise ValueError(
+                     f"Tokenized sample #{source_index} has length {length}, "
+                     f"which exceeds pack_length={self.pack_length}."
+                 )
+             self._length_by_index[source_index] = length
+
+         self._build_plan()
+         self._validate_plan()
+         self.source_sample_count = len(self.source_indices)
+         self.bin_count = len(self.plan)
+         self.original_token_count = sum(self._length_by_index.values())
+         self.separator_token_count = sum(max(0, len(bin_indices) - 1) for bin_indices in self.plan)
+         self.packed_token_count = self.original_token_count + self.separator_token_count
+         self.total_capacity = self.bin_count * self.pack_length
+         self.pad_token_count = self.total_capacity - self.packed_token_count
+         self.average_samples_per_bin = self.source_sample_count / max(self.bin_count, 1)
+         self.utilization = self.packed_token_count / max(self.total_capacity, 1)
+
+     def _build_plan(self) -> None:
+         items = sorted(
+             ((self._length_by_index[source_index], source_index) for source_index in self.source_indices),
+             key=lambda item: (-item[0], item[1]),
+         )
+         available: list[tuple[int, int]] = []
+
+         for length, source_index in items:
+             required_existing = length + 1
+             insert_at = bisect_left(available, (required_existing, -1))
+             if insert_at == len(available):
+                 bin_id = len(self.plan)
+                 self.plan.append([source_index])
+                 remaining = self.pack_length - length
+                 insort(available, (remaining, bin_id))
+                 continue
+
+             remaining, bin_id = available.pop(insert_at)
+             next_remaining = remaining - required_existing
+             if next_remaining < 0:
+                 raise ValueError("Internal packing error: bin capacity became negative.")
+             self.plan[bin_id].append(source_index)
+             insort(available, (next_remaining, bin_id))
+
+     def _validate_plan(self) -> None:
+         seen: set[int] = set()
+         for bin_id, bin_indices in enumerate(self.plan):
+             if not bin_indices:
+                 raise ValueError(f"Packed bin #{bin_id} is empty.")
+             real_length = sum(self._length_by_index[source_index] for source_index in bin_indices)
+             real_length += max(0, len(bin_indices) - 1)
+             if real_length > self.pack_length:
+                 raise ValueError(
+                     f"Packed bin #{bin_id} has real_length={real_length}, "
+                     f"which exceeds pack_length={self.pack_length}."
+                 )
+             for source_index in bin_indices:
+                 if source_index in seen:
+                     raise ValueError(f"Source sample #{source_index} appears in more than one packed bin.")
+                 seen.add(source_index)
+
+         missing = self.source_index_set - seen
+         if missing:
+             first_missing = min(missing)
+             raise ValueError(f"Source sample #{first_missing} was not assigned to a packed bin.")
+
+     def __len__(self) -> int:
+         return len(self.plan)
+
+     def __getitem__(self, bin_idx: int) -> dict[str, torch.Tensor]:
+         bin_indices = self.plan[bin_idx]
+         input_parts: list[torch.Tensor] = []
+         mask_parts: list[torch.Tensor] = []
+         original_tokens = 0
+         separator_tokens = 0
+
+         for sample_offset, source_index in enumerate(bin_indices):
+             item = self.source[source_index]
+             input_ids = item["input_ids"].long()
+             loss_mask = item["loss_mask"].long()
+             original_tokens += int(input_ids.size(0))
+
+             if sample_offset > 0:
+                 input_parts.append(torch.tensor([self.eos_token_id], dtype=torch.long))
+                 mask_parts.append(torch.zeros(1, dtype=torch.long))
+                 separator_tokens += 1
+                 if self.mask_first_after_separator and loss_mask.numel() > 0:
+                     loss_mask = loss_mask.clone()
+                     loss_mask[0] = 0
+
+             input_parts.append(input_ids)
+             mask_parts.append(loss_mask)
+
+         input_ids = torch.cat(input_parts)
+         loss_mask = torch.cat(mask_parts)
+         real_length = int(input_ids.size(0))
+         if real_length > self.pack_length:
+             raise ValueError(
+                 f"Packed bin #{bin_idx} has real_length={real_length}, "
+                 f"which exceeds pack_length={self.pack_length}."
+             )
+
+         pad_len = self.pack_length - real_length
+         if pad_len:
+             input_ids = F.pad(input_ids, (0, pad_len), value=self.pad_token_id)
+             loss_mask = F.pad(loss_mask, (0, pad_len), value=0)
+
+         return {
+             "input_ids": input_ids,
+             "loss_mask": loss_mask,
+             "real_length": torch.tensor(real_length, dtype=torch.long),
+             "source_samples": torch.tensor(len(bin_indices), dtype=torch.long),
+             "original_tokens": torch.tensor(original_tokens, dtype=torch.long),
+             "separator_tokens": torch.tensor(separator_tokens, dtype=torch.long),
+         }
+
+ def collate_packed_fn(batch: list[dict], pad_token_id: int) -> dict:
+     del pad_token_id
+     input_ids = torch.stack([item["input_ids"] for item in batch])
+     loss_mask = torch.stack([item["loss_mask"] for item in batch]).long()
+     real_lengths = torch.stack([item["real_length"] for item in batch]).long()
+
+     seq_len = input_ids.size(1)
+     positions = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
+     attention_mask = (positions < real_lengths.unsqueeze(1)).long()
+
+     labels = input_ids.clone()
+     labels = labels.masked_fill(loss_mask == 0, -100)
+
+     return {
+         "input_ids": input_ids,
+         "attention_mask": attention_mask,
+         "loss_mask": loss_mask,
+         "labels": labels,
+         "real_length": real_lengths,
+         "source_samples": torch.stack([item["source_samples"] for item in batch]).long(),
+         "original_tokens": torch.stack([item["original_tokens"] for item in batch]).long(),
+         "separator_tokens": torch.stack([item["separator_tokens"] for item in batch]).long(),
+     }
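The planner in `SequencePackedDataset._build_plan` is a best-fit-decreasing bin packer: samples are sorted longest first, and each one goes into the open bin with the least remaining room that still fits the sample plus one separator token. The standalone sketch below mirrors that logic on a plain list of lengths so the resulting plan can be inspected without a dataset. `pack_best_fit` is a hypothetical name used only here.

```python
from bisect import bisect_left, insort
from typing import List, Tuple


def pack_best_fit(lengths: List[int], pack_length: int) -> List[List[int]]:
    """Assign sample indices to bins; each extra sample in a bin costs one separator."""
    items = sorted(
        ((length, index) for index, length in enumerate(lengths)),
        key=lambda item: (-item[0], item[1]),  # longest first, ties by index
    )
    available: List[Tuple[int, int]] = []  # sorted (remaining_capacity, bin_id)
    plan: List[List[int]] = []

    for length, index in items:
        if length > pack_length:
            raise ValueError(f"sample #{index} has length {length} > pack_length={pack_length}")
        needed = length + 1  # the sample plus the separator that precedes it
        position = bisect_left(available, (needed, -1))
        if position == len(available):
            # No open bin has enough room, so start a new one (no separator needed).
            plan.append([index])
            insort(available, (pack_length - length, len(plan) - 1))
            continue
        # Tightest bin that still fits: the first entry with remaining >= needed.
        remaining, bin_id = available.pop(position)
        plan[bin_id].append(index)
        insort(available, (remaining - needed, bin_id))
    return plan
```

For lengths `[5, 3, 3, 2]` and `pack_length=8`, the plan is `[[0, 3], [1, 2]]`: the first bin holds 5 + 1 + 2 = 8 tokens and the second holds 3 + 1 + 3 = 7, so only one position is padding.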
src/train.py ADDED
@@ -0,0 +1,1219 @@
+ from __future__ import annotations
+
+ import argparse
+ import csv
+ import json
+ import math
+ import os
+ import sys
+ import time
+ from functools import partial
+ from pathlib import Path
+
+ _REPO_ROOT = Path(__file__).resolve().parents[1]
+ if str(_REPO_ROOT) not in sys.path:
+     sys.path.insert(0, str(_REPO_ROOT))
+
+ import torch
+ from torch.utils.data import DataLoader, Dataset, Subset
+ from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup
+
+ from configs import cfg, emit_log_spacing, setup_logger
+ from src.checkpoints import (
+     find_latest_training_checkpoint,
+     load_trainer_state,
+     maybe_upload_checkpoint,
+     packing_checkpoint_metadata,
+     read_env_flag,
+     save_checkpoint,
+     validate_resume_packing_state,
+ )
+ from src.kd_contracts import build_tokenizer_contract
+ from src.losses import compute_loss_for_phase
+ from src.optim import build_adamw_optimizer
+ from src.provenance import resolve_model_vocab_size, validate_provenance
+ from src.sequence_packing import SequencePackedDataset, collate_packed_fn
+ from src.training_data import (
+     DistillationDataset,
+     collate_fn,
+     extract_shard_id_range,
+     move_batch_to_device,
+     resolve_dataloader_runtime,
+     torch_load_cpu,
+ )
+ from src.training_schedule import (
+     build_train_validation_subsets,
+     compute_training_schedule,
+     load_deepspeed_runtime_config,
+ )
+ from src.transformers_compat import format_model_load_error, resolve_attention_backend
+ from src.validation import evaluate_validation_loss
+
+ def _log_gpu(logger) -> None:
+     if torch.cuda.is_available():
+         device = torch.cuda.current_device()
+         alloc = torch.cuda.max_memory_allocated(device) / (1024**3)
+         reserved = torch.cuda.max_memory_reserved(device) / (1024**3)
+         total = torch.cuda.get_device_properties(device).total_memory / (1024**3)
+         pct = alloc / total * 100
+         logger.info(f"[GPU] {alloc:.1f}/{total:.0f} GiB ({pct:.0f}%) peak alloc, {reserved:.1f} GiB peak reserved")
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(description="Quintus training (SFT / KD)")
+     packing_cfg = getattr(cfg.training, "sequence_packing", None)
+     sequence_packing_default = bool(getattr(packing_cfg, "enabled", False))
+     pack_length_default = int(getattr(packing_cfg, "pack_length", cfg.data.max_seq_len))
+     mask_first_after_separator = bool(getattr(packing_cfg, "mask_first_token_after_separator", True))
+     parser.add_argument("--num_samples", type=int, default=cfg.data.num_samples)
+     parser.add_argument("--phase", type=str, choices=["sft", "kd", "online_kd"], default="online_kd", help="Training phase")
+     parser.add_argument("--resume_from_checkpoint", action="store_true", help="Resume from latest epoch in current output directory")
+     parser.add_argument("--init_from_checkpoint", type=str, default=None, help="Initialize weights from a specific path before training")
+     parser.add_argument(
+         "--compile_model",
+         action="store_true",
+         default=bool(getattr(cfg.training, "compile_model", False)),
+         help="Enable torch.compile after checkpoint loading. Off by default for KD memory safety.",
+     )
+     parser.add_argument("--local_rank", type=int, default=-1, help=argparse.SUPPRESS)
+     parser.add_argument("--deepspeed", type=str, default=None, help="Enable DeepSpeed with the given config path.")
+     parser.add_argument("--no_deepspeed", action="store_true", help="Run without DeepSpeed.")
+     parser.add_argument(
+         "--allow_partial_final_window",
+         action="store_true",
+         help="Allow DeepSpeed to drop a final incomplete accumulation window during smoke tests.",
+     )
+     parser.add_argument("--teacher_model", type=str, default=cfg.model.teacher)
+     parser.add_argument("--teacher_revision", type=str, default=cfg.model.teacher_revision)
+     parser.add_argument("--student_model", type=str, default=cfg.model.student)
+     parser.add_argument("--student_revision", type=str, default=cfg.model.student_revision)
+     parser.add_argument("--tokenizer_model", type=str, default=getattr(cfg.model, "tokenizer", cfg.model.student))
+     parser.add_argument("--tokenizer_revision", type=str, default=getattr(cfg.model, "tokenizer_revision", cfg.model.student_revision))
+     parser.add_argument("--student_dir", type=str, default=cfg.paths.student_dir)
+     parser.add_argument("--tokenizer_dir", type=str, default=getattr(cfg.paths, "tokenizer_dir", cfg.paths.student_dir))
+     parser.add_argument("--distilled_dir", type=str, default=cfg.paths.distilled_dir)
+     parser.add_argument("--num_epochs", type=int, default=cfg.training.num_epochs)
+     parser.add_argument("--max_steps", type=int, default=-1, help="Stop after this many optimizer steps. -1 = no limit.")
+     parser.add_argument("--learning_rate", type=float, default=float(cfg.training.learning_rate))
+     parser.add_argument("--alpha", type=float, default=cfg.training.alpha)
+     parser.add_argument("--temperature", type=float, default=cfg.training.temperature)
+     parser.add_argument(
+         "--online_kd_token_chunk_size",
+         type=int,
+         default=int(getattr(cfg.training, "online_kd_token_chunk_size", 2048)),
+         help="Token chunk size for full-vocabulary online KD loss.",
+     )
+     parser.add_argument("--micro_batch_size", type=int, default=cfg.training.micro_batch_size)
+     parser.add_argument("--grad_accum_steps", type=int, default=cfg.training.grad_accum_steps)
+     parser.add_argument("--sequence_packing", action="store_true", default=False, help="Enable sequence packing for online_kd.")
+     parser.add_argument("--no_sequence_packing", action="store_true", default=False, help="Disable sequence packing.")
+     parser.add_argument("--pack_length", type=int, default=None, help="Packed sequence length.")
+     parser.add_argument("--disable_checkpointing", action="store_true", default=False, help="Disable intermediate epoch/step/best checkpoint saves.")
+     parser.add_argument("--gradient_checkpointing", action="store_true", default=bool(cfg.training.gradient_checkpointing), help="Enable gradient checkpointing (activation checkpointing).")
+     parser.add_argument("--upload_kd_checkpoints", action="store_true", default=False)
+     parser.add_argument("--upload_step_checkpoints", action="store_true", default=False)
+     parser.add_argument(
+         "--upload_last_checkpoint",
+         action="store_true",
+         default=False,
+         help="Upload the final 'last' checkpoint to the Hub. Off by default.",
+     )
+     parser.add_argument(
+         "--hub_upload_strict",
+         action="store_true",
+         default=read_env_flag("QUINTUS_HUB_UPLOAD_STRICT", False),
+         help="Fail training if a requested Hub checkpoint upload fails.",
+     )
+     parser.add_argument("--hub_repo_id", type=str, default=f"{cfg.hub.username}/{cfg.hub.repo_name}")
+     parser.add_argument("--ckpt_path_in_repo", type=str, default="models/online_kd_8b_17b_ep1_B200_20260608_alpha0.3")
+     parser.add_argument("--commit_message_prefix", type=str, default="Online KD 8B->1.7B B200 Run (alpha=0.3)")
+     args = parser.parse_args()
+
+     if args.sequence_packing and args.no_sequence_packing:
+         parser.error("Use either --sequence_packing or --no_sequence_packing, not both.")
+     sequence_packing_enabled = sequence_packing_default
+     if args.sequence_packing:
+         sequence_packing_enabled = True
+     elif args.no_sequence_packing:
+         sequence_packing_enabled = False
+     pack_length = int(args.pack_length if args.pack_length is not None else pack_length_default)
+     if pack_length <= 0:
+         parser.error(f"--pack_length must be positive, got {pack_length}.")
+     if pack_length > int(cfg.data.max_seq_len):
+         parser.error(f"--pack_length must be <= data.max_seq_len ({int(cfg.data.max_seq_len)}), got {pack_length}.")
+     if sequence_packing_enabled and args.phase != "online_kd":
+         parser.error("--sequence_packing is supported only with --phase online_kd.")
+     if args.online_kd_token_chunk_size <= 0:
+         parser.error(
+             f"--online_kd_token_chunk_size must be positive, got {args.online_kd_token_chunk_size}."
+         )
+
+     cfg.model.teacher = args.teacher_model
+     cfg.model.teacher_revision = args.teacher_revision
+     cfg.model.student = args.student_model
+     cfg.model.student_revision = args.student_revision
+     cfg.model.tokenizer = args.tokenizer_model
+     cfg.model.tokenizer_revision = args.tokenizer_revision
+     cfg.paths.student_dir = args.student_dir
+     cfg.paths.tokenizer_dir = args.tokenizer_dir
+     cfg.paths.distilled_dir = args.distilled_dir
+     cfg.training.num_epochs = args.num_epochs
+     cfg.training.learning_rate = args.learning_rate
+     cfg.training.alpha = args.alpha
+     cfg.training.temperature = args.temperature
+     cfg.training.online_kd_token_chunk_size = int(args.online_kd_token_chunk_size)
+     cfg.training.micro_batch_size = args.micro_batch_size
+     cfg.training.grad_accum_steps = args.grad_accum_steps
+     cfg.training.gradient_checkpointing = args.gradient_checkpointing
+     cfg.training.disable_checkpointing = args.disable_checkpointing
+     cfg.training.sequence_packing.enabled = sequence_packing_enabled
+     cfg.training.sequence_packing.pack_length = pack_length
+     cfg.training.sequence_packing.mask_first_token_after_separator = mask_first_after_separator
+     cfg.data.num_samples = args.num_samples
+
+     from omegaconf import OmegaConf
+     if not hasattr(cfg, "hub"):
+         cfg.hub = OmegaConf.create()
+     cfg.hub.upload_kd_checkpoints = args.upload_kd_checkpoints
+     cfg.hub.upload_step_checkpoints = args.upload_step_checkpoints
+     cfg.hub.upload_last_checkpoint = args.upload_last_checkpoint
+     cfg.hub.hub_upload_strict = args.hub_upload_strict
+     cfg.hub.repo_id = args.hub_repo_id
+     cfg.hub.ckpt_path_in_repo = args.ckpt_path_in_repo
+     cfg.hub.commit_message_prefix = args.commit_message_prefix
183
+
184
+ rank = int(os.environ.get("LOCAL_RANK", args.local_rank))
185
+ log = setup_logger("TRAIN", rank=rank)
186
+
187
+ log.info("=" * 70)
188
+ log.info("Quintus Training")
189
+ log.info("=" * 70)
190
+ tokenizer_dir = getattr(cfg.paths, "tokenizer_dir", cfg.paths.student_dir)
191
+ tokenizer_model = getattr(cfg.model, "tokenizer", cfg.model.student)
192
+
193
+ log.info(f" Student: {cfg.paths.student_dir}")
194
+ log.info(f" Student id: {cfg.model.student}")
195
+ log.info(f" Tokenizer: {tokenizer_dir}")
196
+ log.info(f" Tokenizer id: {tokenizer_model}")
197
+ log.info(f" Num samples: {args.num_samples:,}")
198
+ log.info(f" Epochs: {cfg.training.num_epochs}")
199
+ log.info(f" LR: {cfg.training.learning_rate}")
200
+ log.info(f" Phase: {args.phase}")
201
+ if args.phase in ("kd", "online_kd"):
202
+ log.info(f" CE weight: {cfg.training.alpha}")
203
+ log.info(f" Temperature: {cfg.training.temperature}")
204
+ if args.phase == "online_kd":
205
+ log.info(f" KD chunk: {cfg.training.online_kd_token_chunk_size} tokens")
206
+ log.info(f" Micro batch: {cfg.training.micro_batch_size}")
207
+ log.info(f" Grad accum: {cfg.training.grad_accum_steps}")
208
+ log.info(f" Eff. batch: {cfg.training.micro_batch_size * cfg.training.grad_accum_steps}")
209
+ log.info(f" Val ratio: {cfg.training.validation_ratio:.2%}")
210
+ log.info(f" Remote code: {cfg.model.allow_remote_code}")
211
+ log.info(f" Output dir: {cfg.paths.distilled_dir}")
212
+ log.info(f" Log file: {cfg.paths.log_file}")
213
+ log.info(f" Fused AdamW: {bool(getattr(cfg.training, 'fused_adamw', False))}")
214
+ log.info(
215
+ f" HF upload: regular={cfg.hub.upload_kd_checkpoints} "
216
+ f"steps={cfg.hub.upload_step_checkpoints} "
217
+ f"last={cfg.hub.upload_last_checkpoint} "
218
+ f"strict={cfg.hub.hub_upload_strict}"
219
+ )
220
+ log.info(
221
+ f" HF target: {cfg.hub.repo_id}/"
222
+ f"{cfg.hub.ckpt_path_in_repo}"
223
+ )
224
+ if torch.cuda.is_available():
225
+ log.info(f" GPU: {torch.cuda.get_device_name(0)}")
226
+
227
+ try:
228
+ t_dir = tokenizer_dir
229
+ if not os.path.exists(t_dir):
230
+ log.warning(f"Tokenizer directory '{t_dir}' not found. Falling back to downloading '{tokenizer_model}' from HF Hub.")
231
+ t_dir = tokenizer_model
232
+ tokenizer = AutoTokenizer.from_pretrained(
233
+ t_dir,
234
+ trust_remote_code=cfg.model.allow_remote_code,
235
+ )
236
+ except Exception as exc:
237
+ log.error(format_model_load_error("Student tokenizer load", exc))
238
+ sys.exit(1)
239
+ if tokenizer.pad_token is None:
240
+ tokenizer.pad_token = tokenizer.eos_token
241
+ if sequence_packing_enabled:
242
+ if tokenizer.eos_token_id is None:
243
+ log.error("Sequence packing requires tokenizer.eos_token_id.")
244
+ sys.exit(1)
245
+ if tokenizer.pad_token_id is None:
246
+ log.error("Sequence packing requires tokenizer.pad_token_id.")
247
+ sys.exit(1)
248
+ student_tokenizer_contract = build_tokenizer_contract(tokenizer)
249
+ student_tokenizer_vocab_size = student_tokenizer_contract["full_vocab_size"]
250
+
251
+ if args.phase == "kd":
252
+ _prov_path_for_teacher = os.path.join(cfg.paths.logits_dir, "_provenance.json")
253
+ if os.path.exists(_prov_path_for_teacher):
254
+ with open(_prov_path_for_teacher, "r", encoding="utf-8") as _pf:
255
+ _prov_data = json.load(_pf)
256
+ _teacher_prov = _prov_data.get("teacher", {})
257
+ teacher_tokenizer_contract = {
258
+ "full_vocab_size": _teacher_prov.get("tokenizer_size"),
259
+ "fingerprint": _teacher_prov.get("tokenizer_fingerprint"),
260
+ }
261
+ log.info(
262
+ f" Teacher contract read from provenance: "
263
+ f"vocab={teacher_tokenizer_contract['full_vocab_size']}, "
264
+ f"fingerprint={str(teacher_tokenizer_contract['fingerprint'])[:12]}..."
265
+ )
266
+ else:
267
+ try:
268
+ teacher_tokenizer = AutoTokenizer.from_pretrained(
269
+ cfg.paths.teacher_dir if os.path.exists(cfg.paths.teacher_dir) else cfg.model.teacher,
270
+ trust_remote_code=cfg.model.allow_remote_code,
271
+ )
272
+ except Exception as exc:
273
+ log.error(format_model_load_error("Teacher tokenizer load", exc))
274
+ sys.exit(1)
275
+ teacher_tokenizer_contract = build_tokenizer_contract(teacher_tokenizer)
276
+ del teacher_tokenizer
277
+ else:
278
+ teacher_tokenizer_contract = None
279
+
280
+ attn_impl = resolve_attention_backend(log)
281
+ log.info(f" Attention: {attn_impl}")
282
+
283
+ try:
284
+ from liger_kernel.transformers import apply_liger_kernel_to_qwen3
285
+ apply_liger_kernel_to_qwen3(
286
+ rope=True,
287
+ swiglu=True,
288
+ rms_norm=True,
289
+ cross_entropy=False,
290
+ fused_linear_cross_entropy=False,
291
+ )
292
+ log.info(" Liger: enabled")
293
+ except ImportError:
294
+ if cfg.training.micro_batch_size >= 6:
295
+ log.error(" Liger: missing; install liger-kernel or lower micro_batch_size.")
296
+ raise RuntimeError("liger_kernel is required for micro_batch_size >= 6.")
297
+ else:
298
+ log.warning(" Liger: not installed")
299
+
300
+ try:
301
+ s_dir = cfg.paths.student_dir
302
+ if not os.path.exists(s_dir):
303
+ log.warning(f"Student model directory '{s_dir}' not found. Falling back to downloading '{cfg.model.student}' from HF Hub.")
304
+ s_dir = cfg.model.student
305
+ model = AutoModelForCausalLM.from_pretrained(
306
+ s_dir,
307
+ dtype=torch.bfloat16,
308
+ low_cpu_mem_usage=True,
309
+ trust_remote_code=cfg.model.allow_remote_code,
310
+ attn_implementation=attn_impl,
311
+ )
312
+ except Exception as exc:
313
+ log.error(format_model_load_error("Student model load", exc))
314
+ sys.exit(1)
315
+ model.config.use_cache = False
316
+ if getattr(cfg.training, "gradient_checkpointing", False):
317
+ model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
318
+ log.info(" Grad ckpt: enabled")
319
+ else:
320
+ log.info(" Grad ckpt: disabled")
321
+
322
+
323
+
324
+ start_epoch = 0
325
+ resume_state: dict = {}
326
+ if args.resume_from_checkpoint and args.init_from_checkpoint:
327
+ log.error("Use either --init_from_checkpoint or --resume_from_checkpoint, not both.")
328
+ sys.exit(1)
329
+
330
+ checkpoint_to_load = args.init_from_checkpoint
331
+
332
+ if args.resume_from_checkpoint:
333
+ latest_ckpt = find_latest_training_checkpoint(cfg.paths.distilled_dir)
334
+ if latest_ckpt is None:
335
+ log.error(
336
+ f"--resume_from_checkpoint was set, but no epoch_* or step_* checkpoints were found in "
337
+ f"{cfg.paths.distilled_dir}. Use --init_from_checkpoint for the first KD run."
338
+ )
339
+ sys.exit(1)
340
+ checkpoint_to_load = latest_ckpt
341
+ resume_state = load_trainer_state(latest_ckpt, log)
342
+ checkpoint_type = resume_state.get("checkpoint_type", os.path.basename(latest_ckpt).split("_")[0])
343
+ start_epoch = int(resume_state.get("start_epoch", 0) or 0)
344
+ if checkpoint_type == "epoch":
345
+ log.info(f"Interrupted run detected. Resuming after completed epoch {start_epoch}")
346
+ else:
347
+ log.info(
348
+ f"Interrupted run detected. Resuming from {os.path.basename(latest_ckpt)} "
349
+ f"at epoch_index={start_epoch}, next_batch_in_epoch="
350
+ f"{int(resume_state.get('next_batch_in_epoch', 0) or 0)}"
351
+ )
352
+ validate_resume_packing_state(
353
+ resume_state,
354
+ enabled=sequence_packing_enabled,
355
+ pack_length=pack_length,
356
+ max_seq_len=int(cfg.data.max_seq_len),
357
+ log=log,
358
+ )
359
+
360
+ if checkpoint_to_load:
361
+ log.info(f"Loading weights from: {checkpoint_to_load}")
362
+ try:
363
+ from safetensors.torch import load_file
364
+ ckpt_file = os.path.join(checkpoint_to_load, "model.safetensors")
365
+ if not os.path.exists(ckpt_file):
366
+ ckpt_file = os.path.join(checkpoint_to_load, "pytorch_model.bin")
367
+
368
+ if ckpt_file.endswith(".safetensors"):
369
+ state_dict = load_file(ckpt_file)
370
+ else:
371
+ state_dict = torch.load(ckpt_file, map_location="cpu", weights_only=True)
372
+
373
+ new_state_dict = {}
374
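+ # Checkpoints saved from a torch.compile-wrapped model prefix every key with "_orig_mod."; strip it so keys match the uncompiled model.
+ for k, v in state_dict.items():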
375
+ if k.startswith("_orig_mod."):
376
+ new_state_dict[k[len("_orig_mod."):]] = v
377
+ else:
378
+ new_state_dict[k] = v
379
+
380
+ model.load_state_dict(new_state_dict)
381
+ log.info("Weights loaded.")
382
+ except Exception as e:
383
+ log.error(f"Failed to load weights: {e}")
384
+ sys.exit(1)
385
+
386
+ model.train()
387
+
388
+ if args.compile_model:
389
+ log.info(" Compile: enabled")
390
+ model = torch.compile(model, dynamic=True)
391
+ else:
392
+ log.info(" Compile: disabled")
393
+
394
+ torch.set_float32_matmul_precision("high")
395
+
396
+ student_model_vocab_size = resolve_model_vocab_size(model, tokenizer, "Student", log)
397
+ log.info(
398
+ f" Student V: tokenizer={student_tokenizer_vocab_size:,} "
399
+ f"model={student_model_vocab_size:,}"
400
+ )
401
+ _log_gpu(log)
402
+
403
+ if args.phase == "kd":
404
+ shard0 = os.path.join(cfg.paths.logits_dir, "shard_000000.pt")
405
+ if os.path.exists(shard0):
406
+ test_shard = torch_load_cpu(shard0)
407
+ try:
408
+ min_id, max_id = extract_shard_id_range(test_shard, shard0)
409
+ except (KeyError, ValueError) as exc:
410
+ log.error(str(exc))
411
+ sys.exit(1)
412
+ if min_id < 0:
413
+ log.error(f" Negative token IDs in {shard0} (min={min_id}); int16 overflow.")
414
+ log.error(" Regenerate shards.")
415
+ sys.exit(1)
416
+ if max_id >= student_tokenizer_vocab_size:
417
+ log.error(
418
+ f"VOCAB MISMATCH: shard max_id={max_id} >= "
419
+ f"student tokenizer vocab={student_tokenizer_vocab_size}"
420
+ )
421
+ sys.exit(1)
422
+ log.info(
423
+ f" Vocab check: PASS (ids in [{min_id}, {max_id}], "
424
+ f"reachable tokenizer V={student_tokenizer_vocab_size})"
425
+ )
426
+ else:
427
+ log.warning(f" Shard {shard0} not found; skipping vocab check")
428
+
429
+ data_path = os.path.join(cfg.paths.tokenized_dir, "train.jsonl")
430
+ dataset = DistillationDataset(data_path, cfg.paths.logits_dir, cfg.data.max_seq_len, args.num_samples, args.phase)
431
+ log.info(f" Dataset: {len(dataset):,} samples")
432
+
433
+ if args.phase == "kd":
434
+ prov_path = os.path.join(cfg.paths.logits_dir, "_provenance.json")
435
+ validate_provenance(
436
+ prov_path=prov_path,
437
+ data_path=data_path,
438
+ dataset=dataset,
439
+ teacher_tokenizer_contract=teacher_tokenizer_contract,
440
+ student_tokenizer_contract=student_tokenizer_contract,
441
+ log=log,
442
+ )
443
+
444
+ pad_id = tokenizer.pad_token_id
445
+ if args.no_deepspeed:
446
+ args.deepspeed = None
447
+ use_ds = args.deepspeed is not None
448
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
449
+ if world_size != 1:
450
+ log.error("This training path is single-GPU only. Re-run with NUM_GPUS=1.")
451
+ sys.exit(1)
452
+ is_main = rank in (-1, 0)
453
+
454
+ ds_runtime_config = None
455
+ if use_ds:
456
+ try:
457
+ ds_runtime_config = load_deepspeed_runtime_config(
458
+ args.deepspeed,
459
+ micro_batch_size=cfg.training.micro_batch_size,
460
+ grad_accum=cfg.training.grad_accum_steps,
461
+ )
462
+ except (OSError, ValueError, json.JSONDecodeError) as exc:
463
+ log.error(str(exc))
464
+ sys.exit(1)
465
+
466
+ train_dataset, val_dataset, split_meta = build_train_validation_subsets(
467
+ dataset=dataset,
468
+ validation_ratio=float(cfg.training.validation_ratio),
469
+ split_seed=int(cfg.training.split_seed),
470
+ micro_batch_size=cfg.training.micro_batch_size,
471
+ grad_accum=cfg.training.grad_accum_steps,
472
+ num_epochs=cfg.training.num_epochs,
473
+ use_ds=use_ds,
474
+ )
475
+ log.info(
476
+ f" Train split: {len(train_dataset):,} samples | "
477
+ f"Val split: {int(split_meta['validation_size']):,} samples"
478
+ )
479
+ if bool(split_meta["accumulation_aligned"]):
480
+ log.info(
481
+ f" Accum align: train split is divisible by effective batch "
482
+ f"{int(split_meta['effective_batch_size']):,}"
483
+ )
484
+ else:
485
+ if use_ds:
486
+ log.warning(
487
+ f" Accum align: train split leaves "
488
+ f"{int(split_meta['train_remainder_batches'])} partial accumulation batches per epoch; "
489
+ "DeepSpeed will carry partial accumulation across epoch boundaries"
490
+ )
491
+ else:
492
+ log.warning(
493
+ f" Accum align: train split leaves "
494
+ f"{int(split_meta['train_remainder_batches'])} partial accumulation batches per epoch; "
495
+ "the fallback flush path will rescale gradients correctly"
496
+ )
497
+ if bool(split_meta["adjusted"]):
498
+ log.info(
499
+ f" Val align: requested {int(split_meta['requested_validation_size']):,} "
500
+ f"({float(split_meta['requested_validation_ratio']) * 100:.2f}%), "
501
+ f"using {int(split_meta['validation_size']):,} "
502
+ f"({float(split_meta['actual_validation_ratio']) * 100:.2f}%) "
503
+ "to preserve the training schedule"
504
+ )
505
+ elif val_dataset is not None:
506
+ log.info(
507
+ f" Val split: using {float(split_meta['actual_validation_ratio']) * 100:.2f}% "
508
+ f"held out with split_seed={cfg.training.split_seed}"
509
+ )
510
+ else:
511
+ log.warning(" Validation disabled; tracking training loss.")
512
+
513
+ effective_train_dataset: Dataset = train_dataset
514
+ train_collate = partial(collate_fn, pad_token_id=pad_id)
515
+ val_collate = partial(collate_fn, pad_token_id=pad_id)
516
+ if sequence_packing_enabled:
517
+ if isinstance(train_dataset, Subset):
518
+ source_dataset = train_dataset.dataset
519
+ train_source_indices = [int(index) for index in train_dataset.indices]
520
+ else:
521
+ source_dataset = train_dataset
522
+ train_source_indices = list(range(len(train_dataset)))
523
+
524
+ if not isinstance(source_dataset, DistillationDataset):
525
+ log.error("Sequence packing requires DistillationDataset as the split source.")
526
+ sys.exit(1)
527
+
528
+ val_source_indices: set[int] = set()
529
+ if isinstance(val_dataset, Subset) and val_dataset.dataset is source_dataset:
530
+ val_source_indices = {int(index) for index in val_dataset.indices}
531
+
532
+ try:
533
+ packed_train_dataset = SequencePackedDataset(
534
+ source=source_dataset,
535
+ source_indices=train_source_indices,
536
+ pack_length=pack_length,
537
+ eos_token_id=int(tokenizer.eos_token_id),
538
+ pad_token_id=int(tokenizer.pad_token_id),
539
+ mask_first_after_separator=mask_first_after_separator,
540
+ )
541
+ except (IndexError, ValueError) as exc:
542
+ log.error(str(exc))
543
+ sys.exit(1)
544
+
545
+ overlap = packed_train_dataset.source_index_set.intersection(val_source_indices)
546
+ if overlap:
547
+ first_overlap = min(overlap)
548
+ log.error(f"Sequence packing split error: validation sample #{first_overlap} appears in training bins.")
549
+ sys.exit(1)
550
+
551
+ effective_train_dataset = packed_train_dataset
552
+ train_collate = partial(collate_packed_fn, pad_token_id=pad_id)
553
+ log.info(" Packing: enabled")
554
+ log.info(f" Pack length: {packed_train_dataset.pack_length:,}")
555
+ log.info(f" Train bins: {packed_train_dataset.bin_count:,}")
556
+ log.info(f" Train rows: {packed_train_dataset.source_sample_count:,}")
557
+ log.info(f" Avg samples: {packed_train_dataset.average_samples_per_bin:.2f} per bin")
558
+ log.info(f" Original tokens: {packed_train_dataset.original_token_count:,}")
559
+ log.info(f" Separator tokens: {packed_train_dataset.separator_token_count:,}")
560
+ log.info(f" Pad tokens: {packed_train_dataset.pad_token_count:,}")
561
+ log.info(f" Utilization: {packed_train_dataset.utilization * 100:.1f}%")
562
+ else:
563
+ log.info(" Packing: disabled")
564
+
565
+ dataloader_runtime = resolve_dataloader_runtime()
566
+ log.info(
567
+ " DataLoader: "
568
+ f"workers={int(dataloader_runtime['num_workers'])} "
569
+ f"pin_memory={bool(dataloader_runtime['pin_memory'])} "
570
+ f"persistent={bool(dataloader_runtime.get('persistent_workers', False))}"
571
+ )
572
+
573
+ dataloader = DataLoader(
574
+ effective_train_dataset,
575
+ batch_size=cfg.training.micro_batch_size,
576
+ shuffle=(args.phase != "kd"),
577
+ collate_fn=train_collate,
578
+ drop_last=True,
579
+ **dataloader_runtime,
580
+ )
581
+ if args.phase == "kd":
582
+ log.info(" KD sampler: sequential shard-local order (split membership remains randomized)")
583
+ val_dataloader = None
584
+ if val_dataset is not None:
585
+ val_dataloader = DataLoader(
586
+ val_dataset,
587
+ batch_size=cfg.training.micro_batch_size,
588
+ shuffle=False,
589
+ collate_fn=val_collate,
590
+ drop_last=False,
591
+ **dataloader_runtime,
592
+ )
593
+
594
+ grad_accum = cfg.training.grad_accum_steps
595
+ schedule = compute_training_schedule(
596
+ dataset_size=len(effective_train_dataset),
597
+ micro_batch_size=cfg.training.micro_batch_size,
598
+ grad_accum=grad_accum,
599
+ num_epochs=cfg.training.num_epochs,
600
+ use_ds=use_ds,
601
+ drop_last=True,
602
+ )
603
+ batches_per_epoch = int(schedule["batches_per_epoch"])
604
+ remainder_batches = int(schedule["remainder_batches"])
605
+ has_remainder = bool(schedule["has_remainder"])
606
+ total_micro_batches = int(schedule["total_micro_batches"])
607
+ steps_per_epoch = int(schedule["steps_per_epoch"])
608
+ total_steps = int(schedule["total_steps"])
609
+ final_remainder = int(schedule["final_remainder"])
610
+
611
+ if batches_per_epoch == 0:
612
+ schedule_unit = "packed bins" if sequence_packing_enabled else "samples"
613
+ log.error(
614
+ f"Dataset too small for micro_batch_size={cfg.training.micro_batch_size}. "
615
+ f"Train split has {len(effective_train_dataset)} {schedule_unit} and drop_last=True would produce 0 batches."
616
+ )
617
+ sys.exit(1)
618
+
619
+ dropped_samples_per_epoch = int(schedule["dropped_samples_per_epoch"])
620
+ if dropped_samples_per_epoch:
621
+ schedule_unit = "packed bins" if sequence_packing_enabled else "samples"
622
+ log.warning(
623
+ f" drop_last=True will discard {dropped_samples_per_epoch} {schedule_unit} per epoch "
624
+ "before gradient accumulation begins"
625
+ )
626
+
627
+ if use_ds and final_remainder:
628
+ dropped_total = int(schedule["dropped_samples_total"])
629
+ schedule_unit = "packed bins" if sequence_packing_enabled else "samples"
630
+ message = (
631
+ f"DeepSpeed would drop the final {final_remainder} micro-batches "
632
+ f"({dropped_total} {schedule_unit} total) because {batches_per_epoch} batches per epoch "
633
+ f"across {cfg.training.num_epochs} epochs yields {total_micro_batches} micro-batches, "
634
+ f"which is not divisible by grad_accum={grad_accum}."
635
+ )
636
+ if not args.allow_partial_final_window:
637
+ log.error(message)
638
+ log.error(
639
+ "Adjust num_samples, micro_batch_size, grad_accum_steps, or num_epochs "
640
+ "so total micro-batches is divisible by grad_accum, or rerun with "
641
+ "--allow_partial_final_window for a smoke test."
642
+ )
643
+ sys.exit(1)
644
+ log.warning(message)
645
+ log.warning("Proceeding because --allow_partial_final_window was set.")
646
+
647
+ warmup_steps = int(total_steps * cfg.training.warmup_ratio)
648
+ if has_remainder:
649
+ if use_ds:
650
+ log.info(
651
+ f" NOTE: {batches_per_epoch} batches are not divisible by grad_accum={grad_accum}; "
652
+ f"DeepSpeed carries {remainder_batches} leftover micro-batches across epoch boundaries"
653
+ )
654
+ if final_remainder and args.allow_partial_final_window:
655
+ log.info(
656
+ f" NOTE: only the final {final_remainder} micro-batches of "
657
+ "the last epoch are dropped because they never reach a full accumulation window"
658
+ )
659
+ else:
660
+ log.info(
661
+ f" NOTE: {batches_per_epoch} batches are not divisible by grad_accum={grad_accum}; "
662
+ f"the training loop will flush {remainder_batches} leftover micro-batches each epoch"
663
+ )
664
+
665
+ if not use_ds:
666
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
667
+ model.to(device)
668
+
669
+ optimizer = build_adamw_optimizer(list(model.parameters()), log, allow_fused=not use_ds)
670
+ scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
671
+ resume_global_step = int(resume_state.get("global_step", 0) or 0) if args.resume_from_checkpoint else 0
672
+ saved_run_epochs = int(resume_state.get("num_epochs", cfg.training.num_epochs) or cfg.training.num_epochs)
673
+ extending_completed_run = (
674
+ args.resume_from_checkpoint
675
+ and saved_run_epochs < cfg.training.num_epochs
676
+ and start_epoch >= saved_run_epochs
677
+ )
678
+ scheduler_state_path = os.path.join(checkpoint_to_load, "scheduler.pt") if checkpoint_to_load else None
679
+ if (
680
+ extending_completed_run
681
+ and read_env_flag("QUINTUS_FRESH_SCHEDULER_ON_EXTEND", True)
682
+ ):
683
+ remaining_steps = max(1, total_steps - resume_global_step)
684
+ extension_warmup_steps = int(remaining_steps * cfg.training.warmup_ratio)
685
+ scheduler = get_cosine_schedule_with_warmup(optimizer, extension_warmup_steps, remaining_steps)
686
+ log.info(
687
+ f" Scheduler: fresh extension schedule "
688
+ f"({remaining_steps:,} remaining steps, {extension_warmup_steps:,} warmup); "
689
+ f"checkpoint was saved for {saved_run_epochs} epochs"
690
+ )
691
+ elif args.resume_from_checkpoint and scheduler_state_path and os.path.exists(scheduler_state_path):
692
+ try:
693
+ scheduler.load_state_dict(torch.load(scheduler_state_path, map_location="cpu"))
694
+ for param_group, lr in zip(optimizer.param_groups, scheduler.get_last_lr()):
695
+ param_group["lr"] = lr
696
+ log.info(f" Scheduler: restored from {scheduler_state_path}")
697
+ except Exception as exc:
698
+ log.warning(f" Scheduler restore failed ({exc}); continuing with a fresh schedule")
699
+ log.info(f" Batches/ep: {batches_per_epoch:,}")
700
+ step_label = "Steps/ep"
701
+ step_note = ""
702
+ if has_remainder:
703
+ if use_ds:
704
+ step_label = "Steps/ep*"
705
+ step_note = " (floor; cross-epoch carry shifts exact epoch boundaries)"
706
+ else:
707
+ step_note = " (includes remainder flush)"
708
+ log.info(f" {step_label}: {steps_per_epoch:,}{step_note}")
709
+ log.info(f" Steps total: {total_steps:,} ({warmup_steps:,} warmup)")
710
+ log.info(
711
+ " Best ckpt: held-out validation loss"
712
+ if val_dataloader is not None
713
+ else " Best ckpt: training loss (validation disabled)"
714
+ )
715
+
716
+ if use_ds:
717
+ import deepspeed
718
+
719
+ model, optimizer, _, scheduler = deepspeed.initialize(
720
+ model=model,
721
+ optimizer=optimizer,
722
+ lr_scheduler=scheduler,
723
+ config=ds_runtime_config,
724
+ )
725
+ device = model.device
726
+ log.info("[DS] DeepSpeed ZeRO-2 initialized")
727
+ log.info(f"[DS] DeepSpeed will accumulate over {grad_accum} micro-batches internally")
728
+ else:
729
+ log.info(f" Device: {device}")
730
+
731
+ _log_gpu(log)
732
+
733
+ teacher_model = None
734
+ if args.phase == "online_kd":
735
+ teacher_source = cfg.paths.teacher_dir if os.path.exists(cfg.paths.teacher_dir) else cfg.model.teacher
736
+ if teacher_source != cfg.model.teacher:
737
+ log.info(f"Loading frozen teacher model from local directory '{teacher_source}' on device {device}...")
738
+ else:
739
+ log.info(f"Loading frozen teacher model '{teacher_source}' on device {device}...")
740
+ try:
741
+ teacher_model = AutoModelForCausalLM.from_pretrained(
742
+ teacher_source,
743
+ dtype=torch.bfloat16,
744
+ low_cpu_mem_usage=True,
745
+ trust_remote_code=cfg.model.allow_remote_code,
746
+ attn_implementation=attn_impl,
747
+ ).to(device)
748
+ for p in teacher_model.parameters():
749
+ p.requires_grad = False
750
+ teacher_model.eval()
751
+ log.info(f"Teacher model '{teacher_source}' loaded and frozen.")
752
+ except Exception as exc:
753
+ log.error(format_model_load_error("Teacher model load", exc))
754
+ sys.exit(1)
755
+
756
+ checkpoint_packing_metadata = packing_checkpoint_metadata(
757
+ enabled=sequence_packing_enabled,
758
+ pack_length=pack_length,
759
+ max_seq_len=int(cfg.data.max_seq_len),
760
+ )
761
+
762
+ os.makedirs(cfg.paths.distilled_dir, exist_ok=True)
763
+ loss_log: list[dict] = []
764
+ global_step = resume_global_step
765
+ micro_step_global = int(resume_state.get("micro_step_global", 0) or 0) if args.resume_from_checkpoint else 0
766
+ best_metric_name = "validation loss" if val_dataloader is not None else "training loss"
767
+ best_selection_loss = float("inf")
768
+ if args.resume_from_checkpoint and "best_selection_loss" in resume_state:
769
+ try:
770
+ best_selection_loss = float(resume_state["best_selection_loss"])
771
+ log.info(f" Best resume: restored prior best {best_metric_name}={best_selection_loss:.4f}")
772
+ except (TypeError, ValueError):
773
+ log.warning(" Best resume: prior best_selection_loss was unreadable; recomputing from this run")
774
+ best_checkpoint_tag = resume_state.get("best_checkpoint_tag")
775
+ best_ckpt_path = os.path.join(cfg.paths.distilled_dir, "best")
776
+ if not os.path.isdir(best_ckpt_path):
777
+ best_ckpt_path = None
778
+ if best_checkpoint_tag:
779
+ candidate_best_path = os.path.join(cfg.paths.distilled_dir, str(best_checkpoint_tag))
780
+ if os.path.isdir(candidate_best_path):
781
+ best_ckpt_path = candidate_best_path
782
+ log.info(f" Best resume: using {best_checkpoint_tag} as the current best checkpoint")
783
+ t_start = time.time()
784
+
785
+ alpha = cfg.training.alpha
786
+ temperature = cfg.training.temperature
787
+ log_every = max(1, min(50, total_steps // 20))
788
+ checkpoint_every_steps = max(0, int(os.environ.get("TRAIN_CHECKPOINT_EVERY_STEPS", "2000")))
789
+ if getattr(cfg.training, "disable_checkpointing", False):
790
+ checkpoint_every_steps = 0
791
+
792
+ running_loss = 0.0
793
+ running_ce = 0.0
794
+ running_kd = 0.0
795
+ running_count = 0
796
+
797
+ emit_log_spacing(log)
798
+ log.info("-" * 70)
799
+ log.info("Training Start")
800
+ if checkpoint_every_steps:
801
+ log.info(f" Mid-epoch checkpoint interval: every {checkpoint_every_steps:,} optimizer steps")
802
+ else:
803
+ log.info(" Mid-epoch checkpoints disabled")
804
+ log.info("-" * 70)
805
+
806
+ window_tokens = 0
807
+ window_t_start = time.time()
808
+ _gpu_loss_accum = torch.zeros(1, device=device)
809
+ _gpu_ce_accum = torch.zeros(1, device=device)
810
+ _gpu_kd_accum = torch.zeros(1, device=device)
811
+ _gpu_tokens_accum = torch.zeros(1, dtype=torch.long, device=device)
812
+ training_complete = False
813
+
814
+ for epoch in range(start_epoch, cfg.training.num_epochs):
815
+ if training_complete:
816
+ break
817
+ t_epoch = time.time()
818
+ epoch_loss = 0.0
819
+ epoch_ce = 0.0
820
+ epoch_kd = 0.0
821
+ epoch_steps = 0
822
+ epoch_tokens = 0
823
+ micro_in_epoch = 0
824
+ resume_batch_offset = 0
825
+ if args.resume_from_checkpoint and epoch == start_epoch:
826
+ resume_batch_offset = int(resume_state.get("next_batch_in_epoch", 0) or 0)
827
+ if resume_batch_offset:
828
+ log.info(f" Resume: skipping {resume_batch_offset:,} already-processed batches in epoch {epoch + 1}")
829
+
830
+ for batch_idx, batch in enumerate(dataloader):
831
+ if resume_batch_offset and batch_idx < resume_batch_offset:
832
+ continue
833
+ batch = move_batch_to_device(batch, device)
834
+ input_ids = batch["input_ids"]
835
+ attention_mask = batch["attention_mask"]
836
+ labels = batch["labels"]
837
+ loss_mask = batch["loss_mask"]
838
+
839
+ logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
840
+
841
+ if args.phase == "online_kd" and teacher_model is not None:
842
+ with torch.no_grad():
843
+ teacher_logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits
844
+ else:
845
+ teacher_logits = None
846
+
847
+ loss, ce, kd = compute_loss_for_phase(
848
+ args.phase,
849
+ logits,
850
+ labels,
851
+ loss_mask,
852
+ batch,
853
+ alpha,
854
+ temperature,
855
+ teacher_logits=teacher_logits,
856
+ online_kd_token_chunk_size=int(cfg.training.online_kd_token_chunk_size),
857
+ )
858
+ if not torch.isfinite(loss):
859
+ log.error(
860
+ f"Non-finite loss in phase={args.phase}: "
861
+ f"loss={loss.item()} ce={ce.item()} kd={kd.item()}"
862
+ )
863
+ if args.phase == "kd":
864
+ log.error("Action: regenerate teacher logits.")
865
+ else:
866
+ log.error("Action: check dataset / reduce LR.")
867
+ sys.exit(1)
868
+
869
+ micro_in_epoch += 1
870
+ micro_step_global += 1
871
+
872
+ _gpu_loss_accum += loss.detach()
873
+ _gpu_ce_accum += ce.detach()
874
+ _gpu_kd_accum += kd.detach()
875
+ _gpu_tokens_accum += attention_mask.sum()
876
+
877
+ if use_ds:
878
+ model.backward(loss)
879
+ model.step()
880
+ else:
881
+ scaled = loss / grad_accum
882
+ scaled.backward()
883
+
884
+ if micro_in_epoch % grad_accum == 0:
885
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
886
+ optimizer.step()
887
+ scheduler.step()
888
+ optimizer.zero_grad(set_to_none=True)
889
+
890
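+ # DeepSpeed accumulates across epoch boundaries, so its step boundary follows the global micro-batch counter; the fallback path restarts its window every epoch.
+ is_optim_step = (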
891
+ (micro_step_global % grad_accum == 0) if use_ds else (micro_in_epoch % grad_accum == 0)
892
+ )
893
+
894
+ if is_optim_step:
895
+ global_step += 1
896
+ epoch_steps += 1
897
+ running_count += 1
898
+
899
+ step_loss = _gpu_loss_accum.item() / grad_accum
900
+ step_ce = _gpu_ce_accum.item() / grad_accum
901
+ step_kd = _gpu_kd_accum.item() / grad_accum
902
+ step_tokens = _gpu_tokens_accum.item()
903
+ _gpu_loss_accum.zero_()
904
+ _gpu_ce_accum.zero_()
905
+ _gpu_kd_accum.zero_()
906
+ _gpu_tokens_accum.zero_()
907
+
908
+ epoch_tokens += step_tokens
909
+ window_tokens += step_tokens
910
+ epoch_loss += step_loss
911
+ epoch_ce += step_ce
912
+ epoch_kd += step_kd
913
+ running_loss += step_loss
914
+ running_ce += step_ce
915
+ running_kd += step_kd
916
+
917
+ if global_step % log_every == 0 or global_step == total_steps:
918
+ avg_loss = running_loss / max(running_count, 1)
919
+ avg_ce = running_ce / max(running_count, 1)
920
+ avg_kd = running_kd / max(running_count, 1)
921
+ try:
922
+ lr = scheduler.get_last_lr()[0]
923
+ except Exception:
924
+ lr = cfg.training.learning_rate
925
+ window_elapsed = max(time.time() - window_t_start, 0.1)
926
+ rolling_tok_s = window_tokens / window_elapsed
927
+ rolling_eta_s = (window_elapsed / max(running_count, 1)) * (total_steps - global_step)
928
+ cum_tok_s = epoch_tokens / max(time.time() - t_epoch, 1)
929
+ log.info(
930
+ f" E{epoch + 1}/{cfg.training.num_epochs} "
931
+ f"S{global_step:>4}/{total_steps} | "
932
+ f"loss={avg_loss:.4f} ce={avg_ce:.4f} kd={avg_kd:.4f} | "
933
+ f"lr={lr:.2e} | {rolling_tok_s:,.0f} tok/s (avg {cum_tok_s:,.0f}) | ETA {rolling_eta_s / 60:.1f}m"
934
+ )
935
+ loss_log.append(
936
+ {
937
+ "step": global_step,
938
+ "epoch": epoch + 1,
939
+ "loss_total": round(avg_loss, 5),
940
+ "loss_ce": round(avg_ce, 5),
941
+ "loss_kd": round(avg_kd, 5),
942
+ "lr": lr,
943
+ "tok_per_sec": round(rolling_tok_s, 0),
944
+ "tok_per_sec_cumulative": round(cum_tok_s, 0),
945
+ }
946
+ )
947
+ window_tokens = 0
948
+ window_t_start = time.time()
949
+
950
+ running_loss = 0.0
951
+ running_ce = 0.0
952
+ running_kd = 0.0
953
+ running_count = 0
954
+
955
+ if checkpoint_every_steps and global_step % checkpoint_every_steps == 0 and is_main:
956
+ log.info(f" Saving mid-epoch checkpoint at step {global_step}...")
957
+ step_tag = f"step_{global_step}"
958
+ step_ckpt_path = save_checkpoint(
959
+ model,
960
+ tokenizer,
961
+ cfg.paths.distilled_dir,
962
+ step_tag,
963
+ log,
964
+ scheduler=scheduler,
965
+ trainer_state={
966
+ **checkpoint_packing_metadata,
967
+ "checkpoint_type": "step",
968
+ "phase": args.phase,
969
+ "epoch_index": epoch,
970
+ "start_epoch": epoch,
971
+ "global_step": global_step,
972
+ "micro_step_global": micro_step_global,
973
+ "next_batch_in_epoch": batch_idx + 1,
974
+ "num_epochs": cfg.training.num_epochs,
975
+ "micro_batch_size": cfg.training.micro_batch_size,
976
+ "grad_accum_steps": grad_accum,
977
+ },
978
+ )
979
+ maybe_upload_checkpoint(step_ckpt_path, step_tag, log)
980
+
981
+ if args.max_steps > 0 and global_step >= args.max_steps:
982
+ log.info(f"Reached max_steps={args.max_steps}. Stopping training.")
983
+ training_complete = True
984
+ break
985
+
986
+ if training_complete:
987
+ break
988
+
989
+ if not use_ds:
990
+ remainder = micro_in_epoch % grad_accum
991
+ if remainder != 0:
992
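+ # Gradients were accumulated as loss / grad_accum over only `remainder` micro-batches; rescale so the update uses their mean.
+ flush_scale = grad_accum / remainder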
993
+ for parameter in model.parameters():
994
+ if parameter.grad is not None:
995
+ parameter.grad.mul_(flush_scale)
996
+
997
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
998
+ optimizer.step()
999
+ scheduler.step()
1000
+ optimizer.zero_grad(set_to_none=True)
1001
+ global_step += 1
1002
+ epoch_steps += 1
1003
+
1004
+ step_loss = _gpu_loss_accum.item() / remainder
1005
+ step_ce = _gpu_ce_accum.item() / remainder
1006
+ step_kd = _gpu_kd_accum.item() / remainder
1007
+ step_tokens = _gpu_tokens_accum.item()
1008
+ _gpu_loss_accum.zero_()
1009
+ _gpu_ce_accum.zero_()
1010
+ _gpu_kd_accum.zero_()
1011
+ _gpu_tokens_accum.zero_()
1012
+
1013
+ epoch_tokens += step_tokens
1014
+ window_tokens += step_tokens
1015
+ running_loss += step_loss
1016
+ running_ce += step_ce
1017
+ running_kd += step_kd
1018
+ running_count += 1
1019
+
1020
+ avg_loss = running_loss / max(running_count, 1)
1021
+ avg_ce = running_ce / max(running_count, 1)
1022
+ avg_kd = running_kd / max(running_count, 1)
1023
+ epoch_loss += step_loss
1024
+ epoch_ce += step_ce
1025
+ epoch_kd += step_kd
1026
+ running_loss = 0.0
1027
+ running_ce = 0.0
1028
+ running_kd = 0.0
1029
+ running_count = 0
1030
+
1031
+ elapsed = time.time() - t_start
1032
+ try:
1033
+ lr = scheduler.get_last_lr()[0]
1034
+ except Exception:
1035
+ lr = cfg.training.learning_rate
1036
+ tok_s = epoch_tokens / max(time.time() - t_epoch, 1)
1037
+ eta_s = (elapsed / max(global_step, 1)) * (total_steps - global_step)
1038
+ log.info(
1039
+ f" E{epoch + 1}/{cfg.training.num_epochs} "
1040
+ f"S{global_step:>4}/{total_steps} | "
1041
+ f"loss={avg_loss:.4f} ce={avg_ce:.4f} kd={avg_kd:.4f} | "
1042
+ f"lr={lr:.2e} | {tok_s:,.0f} tok/s | ETA {eta_s / 60:.1f}m [flush]"
1043
+ )
1044
+ loss_log.append(
1045
+ {
1046
+ "step": global_step,
1047
+ "epoch": epoch + 1,
1048
+ "loss_total": round(avg_loss, 5),
1049
+ "loss_ce": round(avg_ce, 5),
1050
+ "loss_kd": round(avg_kd, 5),
1051
+ "lr": lr,
1052
+ "tok_per_sec": round(tok_s, 0),
1053
+ }
1054
+ )
1055
+ window_tokens = 0
1056
+ window_t_start = time.time()
1057
+ log.info(f" Epoch {epoch + 1}: flushed {remainder} leftover micro-batches")
1058
+ else:
1059
+ optimizer.zero_grad(set_to_none=True)
1060
+ elif (micro_step_global % grad_accum) != 0 and epoch < cfg.training.num_epochs - 1:
1061
+ carry = micro_step_global % grad_accum
1062
+ log.info(f" Epoch {epoch + 1}: carrying {carry} micro-batches into the next epoch")
1063
+
1064
+         avg_epoch_loss = epoch_loss / max(epoch_steps, 1)
+         avg_epoch_ce = epoch_ce / max(epoch_steps, 1)
+         avg_epoch_kd = epoch_kd / max(epoch_steps, 1)
+         epoch_elapsed = time.time() - t_epoch
+         log.info(
+             f" Epoch {epoch + 1} done | "
+             f"avg_loss={avg_epoch_loss:.4f} ce={avg_epoch_ce:.4f} kd={avg_epoch_kd:.4f} | "
+             f"{epoch_tokens:,} tok | {epoch_elapsed / 60:.1f}m"
+         )
+         _log_gpu(log)
+
+         val_metrics = None
+         if val_dataloader is not None:
+             val_start = time.time()
+             val_limit = min(20, len(val_dataloader)) if args.max_steps > 0 else -1
+             if val_limit > 0:
+                 log.info(f" Validation start | capping at {val_limit} batches for dry run (total {len(val_dataloader)} batches)")
+             else:
+                 log.info(f" Validation start | {len(val_dataloader):,} batches")
+             val_metrics = evaluate_validation_loss(
+                 phase=args.phase,
+                 model=model,
+                 dataloader=val_dataloader,
+                 device=device,
+                 alpha=alpha,
+                 temperature=temperature,
+                 online_kd_token_chunk_size=int(cfg.training.online_kd_token_chunk_size),
+                 teacher_model=teacher_model,
+                 max_batches=val_limit,
+             )
+             log.info(
+                 f" Validation | loss={val_metrics['loss']:.4f} ce={val_metrics['ce']:.4f} "
+                 f"kd={val_metrics['kd']:.4f} | {int(val_metrics['batches'])} batches | "
+                 f"{(time.time() - val_start) / 60:.1f}m"
+             )
+
+         if is_main:
+             selection_loss = val_metrics["loss"] if val_metrics is not None else avg_epoch_loss
+             is_new_best = selection_loss < best_selection_loss
+             epoch_tag = f"epoch_{epoch + 1}"
+             if is_new_best:
+                 best_selection_loss = selection_loss
+                 best_checkpoint_tag = epoch_tag
+                 log.info(f" Best update: {best_metric_name}={best_selection_loss:.4f} from {epoch_tag}")
+             else:
+                 log.info(
+                     f" Best unchanged: current {best_metric_name}={selection_loss:.4f}; "
+                     f"best={best_selection_loss:.4f} from {best_checkpoint_tag}"
+                 )
+             epoch_state = {
+                 **checkpoint_packing_metadata,
+                 "checkpoint_type": "epoch",
+                 "phase": args.phase,
+                 "epoch_index": epoch,
+                 "start_epoch": epoch + 1,
+                 "global_step": global_step,
+                 "micro_step_global": micro_step_global,
+                 "next_batch_in_epoch": 0,
+                 "num_epochs": cfg.training.num_epochs,
+                 "micro_batch_size": cfg.training.micro_batch_size,
+                 "grad_accum_steps": grad_accum,
+                 "selection_loss": float(selection_loss),
+                 "best_selection_loss": float(best_selection_loss),
+                 "best_metric_name": best_metric_name,
+                 "best_checkpoint_tag": best_checkpoint_tag,
+             }
+             if read_env_flag("QUINTUS_SAVE_EPOCH_CHECKPOINTS", True) and not getattr(cfg.training, "disable_checkpointing", False):
+                 epoch_ckpt_path = save_checkpoint(
+                     model,
+                     tokenizer,
+                     cfg.paths.distilled_dir,
+                     epoch_tag,
+                     log,
+                     scheduler=scheduler,
+                     trainer_state=epoch_state,
+                 )
+                 maybe_upload_checkpoint(epoch_ckpt_path, epoch_tag, log)
+             else:
+                 log.info(f" Skipping intermediate {epoch_tag} save")
+             if is_new_best and not getattr(cfg.training, "disable_checkpointing", False):
+                 best_ckpt_path = save_checkpoint(
+                     model,
+                     tokenizer,
+                     cfg.paths.distilled_dir,
+                     "best",
+                     log,
+                     scheduler=scheduler,
+                     trainer_state=dict(epoch_state, checkpoint_type="best"),
+                 )
+
+     if use_ds and final_remainder:
+         model.zero_grad()
+         running_loss = 0.0
+         running_ce = 0.0
+         running_kd = 0.0
+         running_count = 0
+         log.warning(f" Training end: dropped final {final_remainder} leftover micro-batches")
+
+     if is_main:
+         if best_ckpt_path and os.path.isdir(best_ckpt_path) and not getattr(cfg.training, "disable_checkpointing", False):
+             maybe_upload_checkpoint(best_ckpt_path, "best", log)
+         last_ckpt_path = save_checkpoint(
+             model,
+             tokenizer,
+             cfg.paths.distilled_dir,
+             "last",
+             log,
+             scheduler=scheduler,
+             trainer_state={
+                 **checkpoint_packing_metadata,
+                 "checkpoint_type": "last",
+                 "phase": args.phase,
+                 "start_epoch": cfg.training.num_epochs,
+                 "global_step": global_step,
+                 "micro_step_global": micro_step_global,
+                 "next_batch_in_epoch": 0,
+                 "num_epochs": cfg.training.num_epochs,
+                 "micro_batch_size": cfg.training.micro_batch_size,
+                 "grad_accum_steps": grad_accum,
+                 "best_selection_loss": float(best_selection_loss) if math.isfinite(best_selection_loss) else None,
+                 "best_metric_name": best_metric_name,
+                 "best_checkpoint_tag": best_checkpoint_tag,
+             },
+         )
+         maybe_upload_checkpoint(last_ckpt_path, "last", log)
+
+     csv_path = os.path.join(cfg.paths.distilled_dir, cfg.paths.loss_csv)
+     if loss_log and is_main:
+         with open(csv_path, "w", newline="", encoding="utf-8") as f:
+             writer = csv.DictWriter(f, fieldnames=loss_log[0].keys())
+             writer.writeheader()
+             writer.writerows(loss_log)
+         log.info(f"Loss CSV -> {csv_path}")
+
+     total_elapsed = time.time() - t_start
+     emit_log_spacing(log)
+     log.info("=" * 70)
+     log.info("Training complete")
+     log.info(f" Wall time: {total_elapsed / 3600:.2f}h ({total_elapsed / 60:.1f}m)")
+     log.info(f" Optim steps: {global_step}")
+     log.info(f" Micro steps: {micro_step_global}")
+     log.info(f" Best {best_metric_name}: {best_selection_loss:.4f}")
+     log.info(f" Best ckpt: {best_ckpt_path}")
+     log.info(f" Output dir: {cfg.paths.distilled_dir}/")
+     log.info("=" * 70)
+
+
+ if __name__ == "__main__":
+     try:
+         main()
+     except Exception:
+         try:
+             setup_logger("TRAIN").exception("Uncaught training failure")
+         except Exception:
+             pass
+         raise
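
Note on the end-of-epoch flush in the non-DeepSpeed branch above: when `micro_in_epoch % grad_accum` leaves a remainder, the accumulated gradients are multiplied by `grad_accum / remainder` before the optimizer step. The sketch below illustrates why that factor recovers a plain mean over the leftover micro-batches. It assumes each micro-batch loss was divided by `grad_accum` before `backward()`; that division is not visible in this part of the diff, so treat it as an assumption. `flush_rescale` is a hypothetical helper, and plain floats stand in for gradient tensors.

```python
def flush_rescale(micro_grads: list[float], grad_accum: int) -> float:
    """Rescale a partially filled accumulation window to a mean gradient."""
    remainder = len(micro_grads)
    if not 0 < remainder <= grad_accum:
        raise ValueError("expected between 1 and grad_accum leftover micro-batches")
    # Assumption: every micro-batch contributed grad / grad_accum to the buffer.
    accumulated = sum(g / grad_accum for g in micro_grads)
    # Same factor as `flush_scale` in the diff.
    flush_scale = grad_accum / remainder
    return accumulated * flush_scale


leftover = [0.4, 0.8, 1.2]  # 3 leftover micro-batches with grad_accum = 8
rescaled = flush_rescale(leftover, grad_accum=8)
plain_mean = sum(leftover) / len(leftover)
print(rescaled, plain_mean)  # the two values agree up to float rounding
```

Without the rescaling, a flush step with 3 of 8 micro-batches would apply a gradient 3/8 the size of a full step, which would quietly shrink the effective learning rate on the last step of each epoch.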
src/training_data.py ADDED
@@ -0,0 +1,375 @@
+ from __future__ import annotations
+
+ import json
+ import os
+
+ import torch
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset
+
+ from configs import cfg
+
+ PAD_MULTIPLE = 128
+
+ def torch_load_cpu(path: str) -> dict:
+     try:
+         return torch.load(path, map_location="cpu", weights_only=True)
+     except TypeError:
+         return torch.load(path, map_location="cpu")
+
+ def extract_shard_id_range(shard_payload: dict, shard_path: str) -> tuple[int, int]:
+     try:
+         ids_payload = shard_payload["ids"]
+     except KeyError as exc:
+         raise KeyError(
+             f"Teacher shard {shard_path} is missing 'ids'. Regenerate the teacher-logit shards."
+         ) from exc
+
+     if torch.is_tensor(ids_payload):
+         if ids_payload.numel() == 0:
+             raise ValueError(
+                 f"Teacher shard {shard_path} has an empty ids tensor. Regenerate the teacher-logit shards."
+             )
+         return int(ids_payload.min().item()), int(ids_payload.max().item())
+
+     if not isinstance(ids_payload, list) or not ids_payload:
+         raise ValueError(
+             f"Teacher shard {shard_path} has an incompatible ids payload. "
+             "Regenerate the teacher-logit shards."
+         )
+
+     min_id: int | None = None
+     max_id: int | None = None
+     for sample_idx, ids_tensor in enumerate(ids_payload):
+         if not torch.is_tensor(ids_tensor):
+             raise ValueError(
+                 f"Teacher shard {shard_path} sample #{sample_idx} has a non-tensor ids payload. "
+                 "Regenerate the teacher-logit shards."
+             )
+         if ids_tensor.numel() == 0:
+             continue
+         sample_min = int(ids_tensor.min().item())
+         sample_max = int(ids_tensor.max().item())
+         min_id = sample_min if min_id is None else min(min_id, sample_min)
+         max_id = sample_max if max_id is None else max(max_id, sample_max)
+
+     if min_id is None or max_id is None:
+         raise ValueError(
+             f"Teacher shard {shard_path} only contains empty ids tensors. "
+             "Regenerate the teacher-logit shards."
+         )
+     return min_id, max_id
+
+ class DistillationDataset(Dataset):
+     def __init__(self, data_path: str, logits_dir: str, max_seq_len: int, num_samples: int = -1, phase: str = "kd"):
+         self.phase = phase
+         self.data_path = data_path
+         self.logits_dir = logits_dir
+         self.max_seq_len = max_seq_len
+         self.samples_per_shard = self._resolve_samples_per_shard()
+         self.sample_offsets: list[int] = []
+         self.sample_lengths: list[int] = []
+         self.sample_target_counts: list[int] = []
+         self._data_handle = None
+         self._cached_shard_idx: int | None = None
+         self._cached_shard_path: str | None = None
+         self._cached_shard_payload: dict | None = None
+
+         with open(data_path, "r", encoding="utf-8") as f:
+             while True:
+                 if 0 < num_samples <= len(self.sample_offsets):
+                     break
+                 offset = f.tell()
+                 line = f.readline()
+                 if not line:
+                     break
+                 i = len(self.sample_offsets)
+                 raw_sample = json.loads(line)
+                 input_ids_list, loss_mask_list = self._coerce_tokenized_row(raw_sample, i)
+                 self.sample_offsets.append(offset)
+                 self.sample_lengths.append(len(input_ids_list))
+                 self.sample_target_counts.append(sum(loss_mask_list))
+
+     def __len__(self) -> int:
+         return len(self.sample_offsets)
+
+     def __getstate__(self) -> dict:
+         state = self.__dict__.copy()
+         state["_data_handle"] = None
+         state["_cached_shard_idx"] = None
+         state["_cached_shard_path"] = None
+         state["_cached_shard_payload"] = None
+         return state
+
+     def __del__(self) -> None:
+         data_handle = getattr(self, "_data_handle", None)
+         if data_handle is not None:
+             try:
+                 data_handle.close()
+             except Exception:
+                 pass
+
+     def _resolve_samples_per_shard(self) -> int:
+         prov_path = os.path.join(self.logits_dir, "_provenance.json")
+         if not os.path.exists(prov_path):
+             return 1
+
+         try:
+             with open(prov_path, "r", encoding="utf-8") as f:
+                 prov = json.load(f)
+         except (OSError, json.JSONDecodeError):
+             return 1
+
+         shard_schema = prov.get("shard_schema", {})
+         if shard_schema.get("layout") != "chunked_sample_lists":
+             return 1
+
+         raw_value = prov.get("samples_per_shard", 1)
+         try:
+             value = int(raw_value)
+         except (TypeError, ValueError):
+             return 1
+         return max(value, 1)
+
+     def _coerce_tokenized_row(self, raw_sample: dict, idx: int) -> tuple[list[int], list[int]]:
+         try:
+             input_ids = raw_sample["input_ids"][: self.max_seq_len]
+         except KeyError as exc:
+             raise KeyError(
+                 f"Tokenized sample #{idx} is missing 'input_ids'. "
+                 "Re-run download.py to regenerate the tokenized dataset."
+             ) from exc
+         try:
+             loss_mask = raw_sample["loss_mask"][: len(input_ids)]
+         except KeyError as exc:
+             raise KeyError(
+                 "Tokenized sample is missing 'loss_mask'. Re-run download.py to regenerate "
+                 "assistant-only training targets before distilling."
+             ) from exc
+
+         if not isinstance(input_ids, list) or len(input_ids) == 0:
+             raise ValueError(
+                 f"Tokenized sample #{idx} has incompatible input_ids payload. "
+                 "Re-run download.py to regenerate."
+             )
+         if not isinstance(loss_mask, list) or len(loss_mask) != len(input_ids):
+             raise ValueError(
+                 f"Tokenized sample #{idx} has incompatible loss_mask length {len(loss_mask)}. "
+                 "Re-run download.py to regenerate assistant-only targets."
+             )
+
+         normalized_mask = [int(value) for value in loss_mask]
+         if any(value not in (0, 1) for value in normalized_mask):
+             raise ValueError(
+                 f"Tokenized sample #{idx} has non-binary loss_mask values. "
+                 "Re-run download.py to regenerate assistant-only targets."
+             )
+         if sum(normalized_mask) == 0:
+             raise ValueError(
+                 f"Tokenized sample #{idx} has no assistant target tokens. "
+                 "Re-run download.py to filter invalid conversations."
+             )
+
+         return [int(token_id) for token_id in input_ids], normalized_mask
+
+     def _data_file(self):
+         if self._data_handle is None:
+             self._data_handle = open(self.data_path, "r", encoding="utf-8")
+         return self._data_handle
+
+     def _load_raw_sample(self, idx: int) -> dict:
+         data_file = self._data_file()
+         data_file.seek(self.sample_offsets[idx])
+         line = data_file.readline()
+         if not line:
+             raise IndexError(f"Tokenized sample #{idx} could not be read from {self.data_path}.")
+         return json.loads(line)
+
+     def _load_shard_payload(self, shard_idx: int) -> tuple[str, dict]:
+         if self._cached_shard_idx == shard_idx and self._cached_shard_payload is not None and self._cached_shard_path is not None:
+             return self._cached_shard_path, self._cached_shard_payload
+
+         shard_path = os.path.join(self.logits_dir, f"shard_{shard_idx:06d}.pt")
+         if not os.path.exists(shard_path):
+             raise FileNotFoundError(
+                 f"Missing teacher logit shard: {shard_path}. "
+                 "Regenerate the teacher-logit shards."
+             )
+
+         payload = torch_load_cpu(shard_path)
+         self._cached_shard_idx = shard_idx
+         self._cached_shard_path = shard_path
+         self._cached_shard_payload = payload
+         return shard_path, payload
+
+     def _load_teacher_tensors(self, idx: int, seq_len: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         if self.samples_per_shard <= 1:
+             shard_path, shard = self._load_shard_payload(idx)
+             try:
+                 teacher_logprobs = shard["logprobs"][:seq_len]
+                 teacher_ids = shard["ids"][:seq_len]
+                 teacher_other_logprob = shard["other_logprob"][:seq_len]
+             except KeyError as exc:
+                 missing = exc.args[0]
+                 raise KeyError(
+                     f"Shard {shard_path} is missing {missing!r}. "
+                     "Regenerate the current teacher-logit shards."
+                 ) from exc
+             return teacher_logprobs, teacher_ids, teacher_other_logprob
+
+         shard_idx = idx // self.samples_per_shard
+         sample_offset = idx % self.samples_per_shard
+         shard_path, shard = self._load_shard_payload(shard_idx)
+         try:
+             count = int(shard["count"])
+             start_idx = int(shard["start_idx"])
+             logprobs_list = shard["logprobs"]
+             ids_list = shard["ids"]
+             other_list = shard["other_logprob"]
+         except KeyError as exc:
+             missing = exc.args[0]
+             raise KeyError(
+                 f"Grouped shard {shard_path} is missing {missing!r}. "
+                 "Regenerate the current teacher-logit shards."
+             ) from exc
+
+         expected_start_idx = shard_idx * self.samples_per_shard
+         if start_idx != expected_start_idx:
+             raise ValueError(
+                 f"Grouped shard {shard_path} starts at sample {start_idx}, "
+                 f"expected {expected_start_idx}. Regenerate the teacher-logit shards."
+             )
+         if not (len(logprobs_list) == len(ids_list) == len(other_list) == count):
+             raise ValueError(
+                 f"Grouped shard {shard_path} has inconsistent sample counts. "
+                 "Regenerate the current teacher-logit shards."
+             )
+         if sample_offset >= count:
+             raise FileNotFoundError(
+                 f"Grouped shard {shard_path} does not contain sample #{idx} "
+                 f"(start_idx={start_idx}, count={count}). Regenerate the teacher-logit shards."
+             )
+         try:
+             teacher_logprobs = logprobs_list[sample_offset][:seq_len]
+             teacher_ids = ids_list[sample_offset][:seq_len]
+             teacher_other_logprob = other_list[sample_offset][:seq_len]
+         except (IndexError, TypeError) as exc:
+             raise ValueError(
+                 f"Grouped shard {shard_path} has an incompatible payload layout. "
+                 "Regenerate the current teacher-logit shards."
+             ) from exc
+         return teacher_logprobs, teacher_ids, teacher_other_logprob
+
+     def __getitem__(self, idx: int) -> dict:
+         raw_sample = self._load_raw_sample(idx)
+         input_ids_list, loss_mask_list = self._coerce_tokenized_row(raw_sample, idx)
+         input_ids = torch.tensor(input_ids_list, dtype=torch.long)
+         loss_mask = torch.tensor(loss_mask_list, dtype=torch.long)
+         seq_len = int(input_ids.size(0))
+
+         if self.phase in ("sft", "online_kd"):
+             return {"input_ids": input_ids, "loss_mask": loss_mask}
+
+         teacher_logprobs, teacher_ids, teacher_other_logprob = self._load_teacher_tensors(idx, seq_len)
+         if teacher_logprobs.shape[0] != seq_len:
+             raise ValueError(
+                 f"Teacher shard for sample #{idx} has length {teacher_logprobs.shape[0]}, "
+                 f"but the tokenized row has length {seq_len}. Regenerate the teacher-logit shards; "
+                 "teacher shards must be in original JSONL row order."
+             )
+         if teacher_logprobs.ndim != 2 or teacher_ids.shape != teacher_logprobs.shape:
+             raise ValueError(
+                 f"Teacher shard for sample #{idx} has incompatible top-k tensor shapes: "
+                 f"logprobs={tuple(teacher_logprobs.shape)}, ids={tuple(teacher_ids.shape)}. "
+                 "Regenerate the current teacher-logit shards."
+             )
+         if teacher_other_logprob.ndim != 1 or teacher_other_logprob.shape[0] != teacher_logprobs.shape[0]:
+             raise ValueError(
+                 f"Teacher shard for sample #{idx} has incompatible other-bucket shape: "
+                 f"other_logprob={tuple(teacher_other_logprob.shape)}, "
+                 f"expected ({teacher_logprobs.shape[0]},). "
+                 "Regenerate the current teacher-logit shards."
+             )
+         if teacher_logprobs.shape[1] != cfg.training.top_k:
+             raise ValueError(
+                 f"Teacher shard for sample #{idx} stores top_k={teacher_logprobs.shape[1]}, "
+                 f"expected {cfg.training.top_k}. "
+                 "Regenerate compatible teacher-logit shards."
+             )
+
+         return {
+             "input_ids": input_ids,
+             "loss_mask": loss_mask,
+             "teacher_logprobs": teacher_logprobs,
+             "teacher_ids": teacher_ids.long(),
+             "teacher_other_logprob": teacher_other_logprob,
+         }
+
+ def collate_fn(batch: list[dict], pad_token_id: int) -> dict:
+     raw_max = max(item["input_ids"].size(0) for item in batch)
+     max_len = ((raw_max + PAD_MULTIPLE - 1) // PAD_MULTIPLE) * PAD_MULTIPLE
+
+     input_ids_list, mask_list, loss_mask_list, labels_list = [], [], [], []
+     teacher_logprobs_list, teacher_ids_list, teacher_other_logprob_list = [], [], []
+
+     for item in batch:
+         seq_len = item["input_ids"].size(0)
+         pad_len = max_len - seq_len
+         padded_loss_mask = F.pad(item["loss_mask"], (0, pad_len), value=0)
+         padded_labels = F.pad(item["input_ids"].clone(), (0, pad_len), value=pad_token_id)
+         padded_labels = padded_labels.masked_fill(padded_loss_mask == 0, -100)
+
+         input_ids_list.append(F.pad(item["input_ids"], (0, pad_len), value=pad_token_id))
+         mask_list.append(
+             torch.cat(
+                 [
+                     torch.ones(seq_len, dtype=torch.long),
+                     torch.zeros(pad_len, dtype=torch.long),
+                 ]
+             )
+         )
+         loss_mask_list.append(padded_loss_mask)
+         labels_list.append(padded_labels)
+
+         if "teacher_logprobs" in item:
+             teacher_seq_len = item["teacher_logprobs"].size(0)
+             teacher_pad_len = max_len - teacher_seq_len
+             teacher_logprobs_list.append(
+                 F.pad(item["teacher_logprobs"], (0, 0, 0, teacher_pad_len), value=float("-inf"))
+             )
+             teacher_ids_list.append(F.pad(item["teacher_ids"], (0, 0, 0, teacher_pad_len), value=0))
+             teacher_other_logprob_list.append(
+                 F.pad(item["teacher_other_logprob"], (0, teacher_pad_len), value=float("-inf"))
+             )
+
+     result = {
+         "input_ids": torch.stack(input_ids_list),
+         "attention_mask": torch.stack(mask_list),
+         "loss_mask": torch.stack(loss_mask_list).long(),
+         "labels": torch.stack(labels_list),
+     }
+     if teacher_logprobs_list:
+         result["teacher_logprobs"] = torch.stack(teacher_logprobs_list)
+         result["teacher_ids"] = torch.stack(teacher_ids_list)
+         result["teacher_other_logprob"] = torch.stack(teacher_other_logprob_list)
+     return result
+
+ def resolve_dataloader_runtime() -> dict[str, int | bool]:
+     cpu_count = max(1, os.cpu_count() or 1)
+     configured_workers = int(getattr(cfg.training, "dataloader_workers", 4))
+     num_workers = max(0, min(configured_workers, cpu_count))
+     runtime: dict[str, int | bool] = {
+         "num_workers": num_workers,
+         "pin_memory": torch.cuda.is_available(),
+     }
+     if num_workers > 0:
+         runtime["persistent_workers"] = True
+         runtime["prefetch_factor"] = max(1, int(getattr(cfg.training, "prefetch_factor", 2)))
+     return runtime
+
+ def move_batch_to_device(batch: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
+     non_blocking = device.type == "cuda"
+     return {
+         name: tensor.to(device, non_blocking=non_blocking)
+         for name, tensor in batch.items()
+     }
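
Note on `collate_fn` above: it rounds the longest sequence in the batch up to the next multiple of `PAD_MULTIPLE`, pads every row to that length, and sets labels to `-100` wherever `loss_mask` is 0, so prompt tokens and padding contribute nothing to the cross-entropy term. A minimal sketch of just that arithmetic, using plain lists instead of tensors; `padded_length` and `build_labels` are hypothetical names for illustration and are not part of the commit.

```python
PAD_MULTIPLE = 128  # same constant as in the diff


def padded_length(raw_max: int, multiple: int = PAD_MULTIPLE) -> int:
    """Round raw_max up to the next multiple (ceiling division, then scale)."""
    return ((raw_max + multiple - 1) // multiple) * multiple


def build_labels(input_ids: list[int], loss_mask: list[int], max_len: int, pad_token_id: int) -> list[int]:
    """Pad to max_len, then replace every non-target position with -100."""
    pad_len = max_len - len(input_ids)
    padded_ids = input_ids + [pad_token_id] * pad_len
    padded_mask = loss_mask + [0] * pad_len
    return [tok if keep == 1 else -100 for tok, keep in zip(padded_ids, padded_mask)]


print(padded_length(1), padded_length(128), padded_length(129))
print(build_labels([5, 6, 7], [0, 1, 1], max_len=5, pad_token_id=0))
```

`-100` is the default `ignore_index` of PyTorch's cross-entropy loss, so masked positions drop out without extra bookkeeping. Padding to a fixed multiple is commonly done to keep tensor shapes friendly to GPU kernels, at the cost of some wasted positions per batch.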
src/training_schedule.py ADDED
@@ -0,0 +1,165 @@
+ from __future__ import annotations
+
+ import json
+ import math
+
+ import torch
+ from torch.utils.data import Dataset, Subset
+
+ def compute_training_schedule(
+     dataset_size: int,
+     micro_batch_size: int,
+     grad_accum: int,
+     num_epochs: int,
+     use_ds: bool,
+     drop_last: bool = True,
+ ) -> dict[str, int | bool]:
+     if dataset_size < 0:
+         raise ValueError("dataset_size must be >= 0")
+     if micro_batch_size <= 0 or grad_accum <= 0 or num_epochs <= 0:
+         raise ValueError("micro_batch_size, grad_accum, and num_epochs must all be positive")
+
+     if drop_last:
+         batches_per_epoch = dataset_size // micro_batch_size
+         used_samples_per_epoch = batches_per_epoch * micro_batch_size
+         dropped_samples_per_epoch = dataset_size - used_samples_per_epoch
+     else:
+         batches_per_epoch = math.ceil(dataset_size / micro_batch_size) if dataset_size else 0
+         used_samples_per_epoch = dataset_size
+         dropped_samples_per_epoch = 0
+
+     total_micro_batches = batches_per_epoch * num_epochs
+     remainder_batches = batches_per_epoch % grad_accum if batches_per_epoch else 0
+     has_remainder = remainder_batches != 0
+
+     if use_ds:
+         steps_per_epoch = batches_per_epoch // grad_accum
+         total_steps = total_micro_batches // grad_accum
+         final_remainder = total_micro_batches % grad_accum
+     else:
+         steps_per_epoch = batches_per_epoch // grad_accum + (1 if has_remainder and batches_per_epoch else 0)
+         total_steps = steps_per_epoch * num_epochs
+         final_remainder = 0
+
+     return {
+         "batches_per_epoch": batches_per_epoch,
+         "used_samples_per_epoch": used_samples_per_epoch,
+         "dropped_samples_per_epoch": dropped_samples_per_epoch,
+         "remainder_batches": remainder_batches,
+         "has_remainder": has_remainder,
+         "total_micro_batches": total_micro_batches,
+         "steps_per_epoch": steps_per_epoch,
+         "total_steps": total_steps,
+         "final_remainder": final_remainder,
+         "dropped_samples_total": final_remainder * micro_batch_size if use_ds else 0,
+     }
+
+ def choose_validation_size(
+     dataset_size: int,
+     validation_ratio: float,
+     micro_batch_size: int,
+     grad_accum: int,
+     num_epochs: int,
+     use_ds: bool,
+ ) -> int:
+     if not 0.0 <= validation_ratio < 1.0:
+         raise ValueError(f"validation_ratio must be in [0, 1), got {validation_ratio}")
+     if dataset_size < 2 or validation_ratio <= 0:
+         return 0
+
+     desired_val_size = max(1, int(round(dataset_size * validation_ratio)))
+     aligned_candidates: list[tuple[int, int]] = []
+     fallback_candidates: list[tuple[int, int]] = []
+     for val_size in range(1, dataset_size):
+         train_size = dataset_size - val_size
+         schedule = compute_training_schedule(
+             dataset_size=train_size,
+             micro_batch_size=micro_batch_size,
+             grad_accum=grad_accum,
+             num_epochs=num_epochs,
+             use_ds=use_ds,
+             drop_last=True,
+         )
+         if int(schedule["batches_per_epoch"]) == 0:
+             continue
+         if int(schedule["dropped_samples_per_epoch"]) != 0:
+             continue
+         candidate = (abs(val_size - desired_val_size), val_size)
+         if int(schedule["remainder_batches"]) == 0 and int(schedule["final_remainder"]) == 0:
+             aligned_candidates.append(candidate)
+         else:
+             fallback_candidates.append(candidate)
+
+     if aligned_candidates:
+         return min(aligned_candidates)[1]
+     if fallback_candidates:
+         return min(fallback_candidates)[1]
+     return min(desired_val_size, dataset_size - 1)
+
+ def build_train_validation_subsets(
+     dataset: Dataset,
+     validation_ratio: float,
+     split_seed: int,
+     micro_batch_size: int,
+     grad_accum: int,
+     num_epochs: int,
+     use_ds: bool,
+ ) -> tuple[Dataset, Dataset | None, dict[str, float | int | bool]]:
+     dataset_size = len(dataset)
+     validation_size = choose_validation_size(
+         dataset_size=dataset_size,
+         validation_ratio=validation_ratio,
+         micro_batch_size=micro_batch_size,
+         grad_accum=grad_accum,
+         num_epochs=num_epochs,
+         use_ds=use_ds,
+     )
+     requested_validation_size = max(1, int(round(dataset_size * validation_ratio))) if validation_ratio > 0 else 0
+     metadata: dict[str, float | int | bool] = {
+         "dataset_size": dataset_size,
+         "requested_validation_size": requested_validation_size,
+         "validation_size": validation_size,
+         "train_size": dataset_size - validation_size,
+         "requested_validation_ratio": validation_ratio,
+         "actual_validation_ratio": (validation_size / dataset_size) if dataset_size else 0.0,
+         "adjusted": validation_size != requested_validation_size,
+     }
+     train_schedule = compute_training_schedule(
+         dataset_size=dataset_size - validation_size,
+         micro_batch_size=micro_batch_size,
+         grad_accum=grad_accum,
+         num_epochs=num_epochs,
+         use_ds=use_ds,
+         drop_last=True,
+     )
+     metadata.update(
+         {
+             "effective_batch_size": micro_batch_size * grad_accum,
+             "train_batches_per_epoch": int(train_schedule["batches_per_epoch"]),
+             "train_remainder_batches": int(train_schedule["remainder_batches"]),
+             "train_dropped_samples_per_epoch": int(train_schedule["dropped_samples_per_epoch"]),
+             "accumulation_aligned": int(train_schedule["remainder_batches"]) == 0
+             and int(train_schedule["final_remainder"]) == 0,
+         }
+     )
+
+     if validation_size == 0:
+         return dataset, None, metadata
+
+     generator = torch.Generator().manual_seed(split_seed)
+     permutation = torch.randperm(dataset_size, generator=generator).tolist()
+     val_indices = sorted(permutation[:validation_size])
+     train_indices = sorted(permutation[validation_size:])
+     return Subset(dataset, train_indices), Subset(dataset, val_indices), metadata
+
+ def load_deepspeed_runtime_config(config_path: str, micro_batch_size: int, grad_accum: int) -> dict:
+     with open(config_path, "r", encoding="utf-8") as f:
+         ds_cfg = json.load(f)
+
+     if not isinstance(ds_cfg, dict):
+         raise ValueError(f"DeepSpeed config in {config_path} must be a JSON object.")
+
+     runtime_cfg = dict(ds_cfg)
+     runtime_cfg["train_micro_batch_size_per_gpu"] = micro_batch_size
+     runtime_cfg["gradient_accumulation_steps"] = grad_accum
+     return runtime_cfg
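
Note on `compute_training_schedule` above: the two branches count optimizer steps differently. Without DeepSpeed, each epoch flushes its partial accumulation window as one extra step. With DeepSpeed, leftover micro-batches carry into the next epoch and only the final tail is dropped. A small worked sketch of that step arithmetic follows; `count_optimizer_steps` is a hypothetical condensed helper, not a function from the commit.

```python
def count_optimizer_steps(batches_per_epoch: int, grad_accum: int, num_epochs: int, use_ds: bool) -> tuple[int, int]:
    """Return (total_steps, final_remainder) following the two branches in the diff."""
    total_micro_batches = batches_per_epoch * num_epochs
    if use_ds:
        # Leftover micro-batches carry across epochs; only the final tail is dropped.
        return total_micro_batches // grad_accum, total_micro_batches % grad_accum
    # Each epoch flushes its own partial window as one extra step.
    has_remainder = batches_per_epoch % grad_accum != 0
    steps_per_epoch = batches_per_epoch // grad_accum + (1 if has_remainder else 0)
    return steps_per_epoch * num_epochs, 0


# 1000 samples with micro_batch_size=8 and drop_last=True give 125 batches per epoch.
batches = 1000 // 8
print(count_optimizer_steps(batches, grad_accum=4, num_epochs=3, use_ds=False))  # (96, 0)
print(count_optimizer_steps(batches, grad_accum=4, num_epochs=3, use_ds=True))   # (93, 3)
```

In this example the non-DeepSpeed path takes 32 steps per epoch (31 full windows plus one flush), while the DeepSpeed path takes 93 steps in total and discards the last 3 micro-batches. `choose_validation_size` searches for a train/validation split where both remainders are zero, so neither effect occurs.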
src/transformers_compat.py ADDED
@@ -0,0 +1,110 @@
+ from __future__ import annotations
+
+ import importlib
+ import importlib.util
+ import os
+ from configs import cfg
+
+ def _false() -> bool:
+     return False
+
+ def describe_exception_chain(exc: Exception) -> str:
+     messages: list[str] = []
+     seen: set[int] = set()
+     current: BaseException | None = exc
+
+     while current is not None and id(current) not in seen:
+         seen.add(id(current))
+         message = f"{type(current).__name__}: {current}"
+         if message not in messages:
+             messages.append(message)
+         current = current.__cause__ or current.__context__
+
+     return " -> ".join(messages)
+
+ def disable_flash_attn_for_transformers() -> None:
26
+ try:
27
+ import transformers.utils as tf_utils
28
+
29
+ tf_utils.is_flash_attn_2_available = _false
30
+ if hasattr(tf_utils, "is_flash_attn_3_available"):
31
+ tf_utils.is_flash_attn_3_available = _false
32
+ except Exception:
33
+ pass
34
+
35
+ try:
36
+ from transformers.utils import import_utils as tf_import_utils
37
+
38
+ tf_import_utils.is_flash_attn_2_available = _false
39
+ if hasattr(tf_import_utils, "is_flash_attn_3_available"):
40
+ tf_import_utils.is_flash_attn_3_available = _false
41
+ except Exception:
42
+ pass
43
+
44
+ try:
45
+ import transformers.modeling_utils as modeling_utils
46
+
47
+ if hasattr(modeling_utils, "is_flash_attn_2_available"):
48
+ modeling_utils.is_flash_attn_2_available = _false
49
+ if hasattr(modeling_utils, "is_flash_attn_3_available"):
50
+ modeling_utils.is_flash_attn_3_available = _false
51
+ except Exception:
52
+ pass
53
+
54
+ try:
55
+ import transformers.modeling_flash_attention_utils as flash_utils
56
+
57
+ flash_utils.is_flash_attn_2_available = _false
58
+ if hasattr(flash_utils, "is_flash_attn_3_available"):
59
+ flash_utils.is_flash_attn_3_available = _false
60
+ except Exception:
61
+ pass
+
+ def resolve_attention_backend(logger) -> str:
+     forced_backend = os.environ.get("QUINTUS_ATTENTION_BACKEND")
+     if forced_backend:
+         logger.info(f" [ATTENTION] Forced backend via QUINTUS_ATTENTION_BACKEND={forced_backend!r}.")
+         return forced_backend
+
+     try:
+         from transformers.utils import is_flash_attn_3_available
+         if is_flash_attn_3_available():
+             logger.info(" [ATTENTION] Using flash_attention_3.")
+             return "flash_attention_3"
+     except Exception:
+         pass
+
+     try:
+         importlib.import_module("flash_attn")
+         logger.info(" [ATTENTION] Using flash_attention_2.")
+         return "flash_attention_2"
+     except Exception as exc:
+         if importlib.util.find_spec("flash_attn") is not None:
+             disable_flash_attn_for_transformers()
+             logger.warning(
+                 "flash-attn appears installed but failed to import (%s: %s); "
+                 "masking flash-attn from Transformers and falling back to sdpa.",
+                 type(exc).__name__,
+                 exc,
+             )
+         else:
+             logger.info(" [ATTENTION] Using PyTorch SDPA.")
+         return "sdpa"
+
+ def _requires_remote_code_opt_in(exc: Exception) -> bool:
+     message = str(exc).lower()
+     return (
+         "trust_remote_code" in message
+         or "requires you to execute the configuration file" in message
+         or "requires remote code" in message
+     )
+
+ def format_model_load_error(subject: str, exc: Exception) -> str:
+     if not cfg.model.allow_remote_code and _requires_remote_code_opt_in(exc):
+         return (
+             f"{subject} failed because the selected model/tokenizer requires remote code, "
+             "but Quintus is configured with allow_remote_code=false. Review the upstream "
+             "repository and rerun with QUINTUS_ALLOW_REMOTE_CODE=1 only if you explicitly "
+             "trust that code."
+         )
+     return f"{subject} failed: {describe_exception_chain(exc)}"
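As an illustration of what `describe_exception_chain` produces, here is a small sketch with a two-level exception chain. The helper is copied locally (without type annotations) so the snippet runs standalone; `load_model` is a hypothetical stand-in, not a function from this repository.

```python
def describe_exception_chain(exc):
    # Local copy of the helper so this sketch is self-contained.
    messages = []
    seen = set()
    current = exc
    while current is not None and id(current) not in seen:
        seen.add(id(current))
        message = f"{type(current).__name__}: {current}"
        if message not in messages:
            messages.append(message)
        current = current.__cause__ or current.__context__
    return " -> ".join(messages)


def load_model():
    # Hypothetical loader that wraps a low-level error in a higher-level one.
    try:
        raise ValueError("bad config")
    except ValueError as inner:
        raise RuntimeError("load failed") from inner


try:
    load_model()
except RuntimeError as err:
    summary = describe_exception_chain(err)

print(summary)  # RuntimeError: load failed -> ValueError: bad config
```

The outermost exception comes first, and each ` -> ` step moves one link deeper into the cause, which is the form `format_model_load_error` appends after `"{subject} failed: "`.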
src/validation.py ADDED
@@ -0,0 +1,70 @@
+ from __future__ import annotations
+
+ import torch
+ from torch.utils.data import DataLoader
+
+ from src.losses import compute_loss_for_phase
+ from src.training_data import move_batch_to_device
+
+ def evaluate_validation_loss(
+     phase: str,
+     model,
+     dataloader: DataLoader,
+     device: torch.device,
+     alpha: float,
+     temperature: float,
+     online_kd_token_chunk_size: int = 2048,
+     teacher_model=None,
+     max_batches: int = -1,
+ ) -> dict[str, float | int]:
+     was_training = model.training
+     model.eval()
+
+     total_loss = 0.0
+     total_ce = 0.0
+     total_kd = 0.0
+     batches = 0
+
+     with torch.inference_mode():
+         for batch in dataloader:
+             if max_batches > 0 and batches >= max_batches:
+                 break
+             batch = move_batch_to_device(batch, device)
+             input_ids = batch["input_ids"]
+             attention_mask = batch["attention_mask"]
+             labels = batch["labels"]
+             loss_mask = batch["loss_mask"]
+
+             logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
+
+             if phase == "online_kd" and teacher_model is not None:
+                 teacher_logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits
+             else:
+                 teacher_logits = None
+
+             loss, ce, kd = compute_loss_for_phase(
+                 phase,
+                 logits,
+                 labels,
+                 loss_mask,
+                 batch,
+                 alpha,
+                 temperature,
+                 teacher_logits=teacher_logits,
+                 online_kd_token_chunk_size=online_kd_token_chunk_size,
+             )
+             total_loss += float(loss.detach().item())
+             total_ce += float(ce.detach().item())
+             total_kd += float(kd.detach().item())
+             batches += 1
+
+     if was_training:
+         model.train()
+
+     denom = max(batches, 1)
+     return {
+         "loss": total_loss / denom,
+         "ce": total_ce / denom,
+         "kd": total_kd / denom,
+         "batches": batches,
+     }
weight_audit/quintus_weight_audit.py ADDED
@@ -0,0 +1,818 @@
+ """
+ Usage : python quintus_weight_audit.py \
+     --base_model Qwen/Qwen3-1.7B-Base \
+     --distilled_model iamrahulreddy/Quintus \
+     --output_file weight_audit_report.txt \
+     --alpha 0.3
+ """
+
+ import argparse
+ import collections
+ import math
+ import sys
+ import time
+ from datetime import datetime, timezone
+ from pathlib import Path
+
+ import torch
+ import torch.nn.functional as F
+ from huggingface_hub import snapshot_download
+ from transformers import AutoConfig, AutoModelForCausalLM
+
+ # Formatting utilities
+ def fmt_num(n: int) -> str:
+     if n >= 1_000_000_000:
+         return f"{n:,} ({n / 1e9:.6f} B)"
+     if n >= 1_000_000:
+         return f"{n:,} ({n / 1e6:.6f} M)"
+     return f"{n:,}"
+
+ def fmt_size(b: int) -> str:
+     if b >= 1 << 30:
+         return f"{b / (1 << 30):.3f} GiB"
+     if b >= 1 << 20:
+         return f"{b / (1 << 20):.3f} MiB"
+     if b >= 1 << 10:
+         return f"{b / (1 << 10):.3f} KiB"
+     return f"{b} B"
+
+ def divider(char: str = "-", width: int = 88) -> str:
+     return char * width
+
+ def section_header(index: int, title: str) -> str:
+     return f"\n[{index:02d}] {title}"
+
+ def sub_header(title: str) -> str:
+     return f"\n -- {title}"
+
+ # Layer classification
+ # Patterns are matched as substrings of the parameter name, in insertion order.
+ LAYER_TYPE_MAP = {
+     "embed_tokens": "embedding",
+     "lm_head": "lm_head",
+     "self_attn.q_proj": "attn_q",
+     "self_attn.k_proj": "attn_k",
+     "self_attn.v_proj": "attn_v",
+     "self_attn.o_proj": "attn_o",
+     "self_attn.q_norm": "attn_qnorm",
+     "self_attn.k_norm": "attn_knorm",
+     "mlp.gate_proj": "mlp_gate",
+     "mlp.up_proj": "mlp_up",
+     "mlp.down_proj": "mlp_down",
+     "input_layernorm": "layernorm",
+     "post_attention_layernorm": "layernorm",
+     "model.norm": "final_norm",
+ }
+
+ def classify_layer(name: str) -> str:
+     for pattern, label in LAYER_TYPE_MAP.items():
+         if pattern in name:
+             return label
+     return "other"
+
+ # Tensor statistics
+ def tensor_stats(t: torch.Tensor) -> dict:
+     tf = t.float()
+     flat = tf.view(-1)
+     mean = flat.mean().item()
+     std = flat.std().item()
+
+     sparsity = (flat.abs() < 1e-6).float().mean().item()
+     sat_thresh = flat.abs().max().item() * 0.99
+     saturation = (flat.abs() >= sat_thresh).float().mean().item()
+     kurtosis = (((flat - mean) / std) ** 4).mean().item() - 3.0 if std > 1e-10 else 0.0
+     outlier_r = (flat.abs() > (flat.abs().mean() + 3.0 * std)).float().mean().item()
+
+     row_l2_stats = {}
+     if tf.ndim == 2:
+         row_norms = tf.norm(2, dim=1)
+         row_l2_stats = {
+             "row_l2_mean": row_norms.mean().item(),
+             "row_l2_std": row_norms.std().item(),
+             "row_l2_min": row_norms.min().item(),
+             "row_l2_max": row_norms.max().item(),
+             "dead_rows": int((row_norms < 1e-6).sum().item()),
+         }
+
+     return {
+         "shape": list(tf.shape),
+         "numel": flat.numel(),
+         "dtype": str(t.dtype),
+         "mean": mean,
+         "std": std,
+         "min": flat.min().item(),
+         "max": flat.max().item(),
+         "abs_mean": flat.abs().mean().item(),
+         "l2_norm": flat.norm(2).item(),
+         "l1_norm": flat.norm(1).item(),
+         "sparsity": sparsity,
+         "saturation": saturation,
+         "kurtosis": kurtosis,
+         "outlier_ratio": outlier_r,
+         **row_l2_stats,
+     }
+
+ # Divergence between two tensors
+ def tensor_divergence(t_base: torch.Tensor, t_dist: torch.Tensor, chunk_size: int = 10_000_000) -> dict:
+     a_flat = t_base.detach().view(-1)
+     b_flat = t_dist.detach().view(-1)
+     n_elements = a_flat.numel()
+
+     # Running accumulators are plain Python floats (float64), so only one
+     # chunk of each tensor is ever upcast at a time.
+     dot_prod = 0.0
+     a_sq_sum = 0.0
+     b_sq_sum = 0.0
+
+     # Delta statistics (all computed on the absolute difference |b - a|)
+     max_delta = 0.0
+     sum_delta = 0.0
+     l2_delta_sq = 0.0
+     sum_abs_a = 0.0
+
+     # Process in chunks to bound memory: at the default chunk_size each
+     # float64 chunk is ~80 MB, and the a, b and delta chunks are live together.
+     for i in range(0, n_elements, chunk_size):
+         a_chunk = a_flat[i : i + chunk_size].to(torch.float64)
+         b_chunk = b_flat[i : i + chunk_size].to(torch.float64)
+
+         # Accumulate dot product and norms
+         dot_prod += torch.dot(a_chunk, b_chunk).item()
+         a_sq_sum += torch.dot(a_chunk, a_chunk).item()
+         b_sq_sum += torch.dot(b_chunk, b_chunk).item()
+
+         # Accumulate delta stats
+         delta_chunk = (b_chunk - a_chunk).abs()
+         max_delta = max(max_delta, delta_chunk.max().item())
+         sum_delta += delta_chunk.sum().item()
+         l2_delta_sq += torch.dot(delta_chunk, delta_chunk).item()
+         sum_abs_a += a_chunk.abs().sum().item()
+
+     # Final metrics
+     a_norm = math.sqrt(a_sq_sum)
+     b_norm = math.sqrt(b_sq_sum)
+     if a_norm > 0 and b_norm > 0:
+         cos_sim_raw = dot_prod / (a_norm * b_norm)
+     else:
+         cos_sim_raw = 0.0
+
+     # Clamp away floating-point overshoot; the raw value is kept for the histogram.
+     cos_sim = max(-1.0, min(1.0, cos_sim_raw))
+
+     rel_err = sum_delta / (sum_abs_a + 1e-12)
+     base_l2 = a_norm
+     delta_l2 = math.sqrt(l2_delta_sq)
+     snr_db = 20.0 * math.log10(base_l2 / (delta_l2 + 1e-12)) if base_l2 > 0 else 0.0
+
+     # Mean and standard deviation of the absolute delta: Var = E[d^2] - E[d]^2
+     mean_delta = sum_delta / n_elements
+     mean_delta_sq = l2_delta_sq / n_elements
+     var_delta = max(0.0, mean_delta_sq - mean_delta**2)
+     std_delta = math.sqrt(var_delta)
+
+     return {
+         "max_delta": max_delta,
+         "mean_delta": mean_delta,
+         "std_delta": std_delta,
+         "l2_delta": delta_l2,
+         "cos_sim": cos_sim,
+         "cos_sim_raw": cos_sim_raw,
+         "rel_err": rel_err,
+         "snr_db": snr_db,
+         "changed": max_delta > 1e-7,
+     }
+
+ # Isotropy
+ def isotropy_score(t: torch.Tensor, n_samples: int = 2048) -> float:
+     """
+     Average pairwise cosine similarity of randomly sampled row vectors.
+     Near 0 = isotropic (healthy). Near 1 = collapsed representations.
+     Only valid for 2D tensors with >= 2 rows.
+     """
+     if t.ndim != 2 or t.shape[0] < 2:
+         return float("nan")
+     tf = t.float()
+     n = min(t.shape[0], n_samples)
+     # Fixed seed so the sampled rows are the same on every run
+     gen = torch.Generator().manual_seed(42)
+     idx = torch.randperm(t.shape[0], generator=gen)[:n].to(t.device)
+     rows = tf[idx]
+     norms = rows.norm(2, dim=1, keepdim=True).clamp(min=1e-12)
+     normed = rows / norms
+     sim = normed @ normed.T
+     # Drop the diagonal (self-similarity); build the mask on the same device as sim
+     mask = ~torch.eye(n, dtype=torch.bool, device=sim.device)
+     return sim[mask].mean().item()
+
+ # Config helpers
+ def config_architecture_lines(config, label: str, model_id: str) -> list[str]:
+     cfg = config.to_dict()
+     n_q = cfg.get("num_attention_heads", 1)
+     n_kv = cfg.get("num_key_value_heads", n_q)
+     h = cfg.get("hidden_size", 0)
+     # Prefer an explicit head_dim from the config; fall back to hidden_size / heads
+     head_dim = cfg.get("head_dim") or (h // n_q if n_q else 0)
+     gqa = n_q // n_kv if n_kv else 1
+
+     return [
+         f" label : {label} ({model_id})",
+         f" model_type : {cfg.get('model_type', 'unknown')}",
+         f" architecture : {cfg.get('architectures', ['unknown'])[0]}",
+         "",
+         " Vocabulary",
+         f" vocab_size : {cfg.get('vocab_size', 'N/A'):,}",
+         f" bos / eos / pad : {cfg.get('bos_token_id')} / {cfg.get('eos_token_id')} / {cfg.get('pad_token_id')}",
+         "",
+         " Positional encoding",
+         f" max_position_embeddings: {cfg.get('max_position_embeddings', 'N/A'):,}",
+         f" rope_theta : {cfg.get('rope_theta', 'N/A')}",
+         f" rope_scaling : {cfg.get('rope_scaling', 'None')}",
+         "",
+         " Transformer dimensions",
+         f" hidden_size : {h}",
+         f" num_hidden_layers : {cfg.get('num_hidden_layers', 'N/A')}",
+         f" intermediate_size : {cfg.get('intermediate_size', 'N/A')}",
+         "",
+         " Attention",
+         f" num_attention_heads : {n_q}",
+         f" num_key_value_heads : {n_kv}",
+         f" head_dim : {head_dim}",
+         f" GQA ratio : {gqa}:1",
+         f" attention_bias : {cfg.get('attention_bias', False)}",
+         f" use_qk_norm : {cfg.get('use_qk_norm', False) or 'qwen3' in model_id.lower() or 'qwen3' in cfg.get('model_type', '').lower()}",
+         f" sliding_window : {cfg.get('sliding_window', 'None')}",
+         "",
+         " Feed-forward",
+         f" hidden_act : {cfg.get('hidden_act', 'silu')}",
+         f" mlp_bias : {cfg.get('mlp_bias', False)}",
+         "",
+         " Misc",
+         f" rms_norm_eps : {cfg.get('rms_norm_eps', 1e-6)}",
+         f" tie_word_embeddings : {cfg.get('tie_word_embeddings', True)}",
+         f" use_cache : {cfg.get('use_cache', True)}",
+         f" torch_dtype : {cfg.get('torch_dtype', 'float32')}",
+         f" initializer_range : {cfg.get('initializer_range', 'N/A')}",
+     ]
+
+ def get_params_info(config, model_id: str = "") -> dict:
+     h = config.hidden_size
+     n_layers = config.num_hidden_layers
+     v = config.vocab_size
+     embed = v * h
+     tie = getattr(config, "tie_word_embeddings", True)
+     n_q = config.num_attention_heads
+     n_kv = getattr(config, "num_key_value_heads", n_q)
+     # Prefer an explicit head_dim from the config; fall back to hidden_size / heads
+     head_dim = getattr(config, "head_dim", None) or h // n_q
+     qkv_proj = (n_q + 2 * n_kv) * head_dim * h
+     # o_proj maps n_q * head_dim back to hidden_size (equals h * h when head_dim == h / n_q)
+     o_proj = n_q * head_dim * h
+     use_qk_norm = (
+         getattr(config, "use_qk_norm", False) or
+         "qwen3" in model_id.lower() or
+         "qwen3" in getattr(config, "model_type", "").lower()
+     )
+     qk_norm = 2 * head_dim if use_qk_norm else 0
+     mlp = 3 * h * config.intermediate_size
+     norms = 2 * h
+     per_layer = qkv_proj + o_proj + qk_norm + mlp + norms
+     total_layers = n_layers * per_layer
+     lm_head = 0 if tie else embed
+     unique = embed + lm_head + total_layers + h  # +h for final norm
+     return {
+         "raw": unique + (embed if tie else 0),
+         "embed": embed,
+         "lm_head": embed,
+         "tied": tie,
+         "unique": unique,
+         "non_embed": total_layers + h,
+         "per_layer": per_layer,
+     }
+
+ def param_lines(config, p: dict, label: str) -> list[str]:
+     return [
+         f" {label}",
+         f" raw (all named) : {fmt_num(p['raw'])}",
+         f" embedding : {fmt_num(p['embed'])}",
+         f" lm_head : {fmt_num(p['lm_head'])}",
+         f" tied : {p['tied']}",
+         f" unique (deduped) : {fmt_num(p['unique'])}",
+         f" non-embedding : {fmt_num(p['non_embed'])}",
+         f" per layer (approx) : {p['per_layer']:,}",
+     ]
+
+ # Main
+ def main():
+     parser = argparse.ArgumentParser(description="Quintus Deep Weight Audit")
+     parser.add_argument("--base_model", type=str, default="Qwen/Qwen3-1.7B-Base")
+     parser.add_argument("--distilled_model", type=str, default="iamrahulreddy/Quintus")
+     parser.add_argument("--output_file", type=str, default="weight_audit_report.txt")
+     parser.add_argument("--alpha", type=float, default=0.3)
+     parser.add_argument("--isotropy_samples", type=int, default=2048)
+     parser.add_argument("--trust_remote_code", action="store_true", help="Allow custom code from model repositories.")
+     args = parser.parse_args()
+
+     # Determine compute device
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+     utc_ts = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")
+     loc_ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S local")
+
+     R: list[str] = []
+
+     def log(line: str = ""):
+         print(line)
+         R.append(line)
+
+     def loglines(lines: list[str]):
+         for ln in lines:
+             log(ln)
+
+     # Header
+     loglines([
+         divider("="),
+         " QUINTUS WEIGHT AUDIT",
+         divider("="),
+         f" {utc_ts} ({loc_ts})",
+         f" base model : {args.base_model}",
+         f" distilled model : {args.distilled_model}",
+         f" alpha : {args.alpha}",
+         f" device : {device} | dtype: {dtype}",
+         f" python : {sys.version.split()[0]} | torch: {torch.__version__}",
+         divider("="),
+     ])
+
+     # [01] Resolve checkpoints
+     log(section_header(1, "Resolve checkpoints"))
+
+     # Resolve the base model's snapshot commit hash so the report records it
+     base_commit = "local"
+     if not Path(args.base_model).exists():
+         try:
+             base_local_dir = Path(snapshot_download(repo_id=args.base_model))
+             base_commit = base_local_dir.name
+         except Exception:
+             base_commit = "unknown"
+
+     dist_commit = "local"
+     if not Path(args.distilled_model).exists():
+         log(f" Downloading '{args.distilled_model}' from HuggingFace Hub...")
+         t0 = time.time()
+         try:
+             local_dir = snapshot_download(repo_id=args.distilled_model)
+             distilled_path = Path(local_dir)
+             dist_commit = distilled_path.name
+         except Exception as e:
+             log(f" ERROR: {e}")
+             sys.exit(1)
+         log(f" Done in {time.time() - t0:.1f}s")
+     else:
+         distilled_path = Path(args.distilled_model)
+         if "snapshots" in distilled_path.parts:
+             dist_commit = distilled_path.name
+
+     # Redact absolute local HF cache paths for sharing
+     redacted_root = "<HF_CACHE_DIR>/snapshots"
+     log(f" base model commit : {base_commit}")
+     log(f" distilled commit : {dist_commit}")
+     log(f" snapshot root : {redacted_root}")
+
+     if not (distilled_path / "config.json").exists():
+         log(" ERROR: config.json missing from checkpoint directory.")
+         sys.exit(1)
+
+     files = sorted(f for f in distilled_path.iterdir() if f.is_file())
+     total_ckpt_bytes = sum(f.stat().st_size for f in files)
+     log("")
+     log(f" {'Filename':<52} {'Size':>12} Modified")
+     for f in files:
+         mtime = datetime.fromtimestamp(f.stat().st_mtime).strftime("%Y-%m-%d %H:%M")
+         log(f" {f.name:<52} {fmt_size(f.stat().st_size):>12} {mtime}")
+     log(f" {'total':<52} {fmt_size(total_ckpt_bytes):>12}")
+
+     # [02] Architecture configuration
+     log(section_header(2, "Architecture configuration"))
+
+     log(" Loading base config...")
+     try:
+         base_config = AutoConfig.from_pretrained(args.base_model, trust_remote_code=args.trust_remote_code)
+     except Exception as e:
+         log(f" ERROR: {e}")
+         sys.exit(1)
+
+     log(" Loading distilled config...")
+     try:
+         distilled_config = AutoConfig.from_pretrained(str(distilled_path), trust_remote_code=args.trust_remote_code)
+     except Exception as e:
+         log(f" ERROR: {e}")
+         sys.exit(1)
+
+     log(sub_header("Base"))
+     loglines(config_architecture_lines(base_config, "base", args.base_model))
+
+     log(sub_header("Distilled"))
+     loglines(config_architecture_lines(distilled_config, "distilled", args.distilled_model))
+
+     log(sub_header("Config diff (ignoring: _name_or_path, transformers_version)"))
+     ignore_keys = {"_name_or_path", "transformers_version"}
+     base_dict = base_config.to_dict()
+     dist_dict = distilled_config.to_dict()
+     config_diffs = [
+         (k, base_dict.get(k), dist_dict.get(k))
+         for k in sorted(set(base_dict) | set(dist_dict))
+         if k not in ignore_keys and base_dict.get(k) != dist_dict.get(k)
+     ]
+     if not config_diffs:
+         log(" No differences; configs identical (expected for same-architecture KD).")
+     else:
+         log(f" {'Key':<40} {'Base':>28} Distilled")
+         for k, vb, vd in config_diffs:
+             log(f" {k:<40} {str(vb):>28} {vd}")
+
+     # [03] Parameter accounting
+     log(section_header(3, "Parameter accounting"))
+
+     base_params = get_params_info(base_config, args.base_model)
+     dist_params = get_params_info(distilled_config, args.distilled_model)
+
+     log(sub_header("Base"))
+     loglines(param_lines(base_config, base_params, "base"))
+
+     log(sub_header("Distilled"))
+     loglines(param_lines(distilled_config, dist_params, "distilled"))
+
+     log(sub_header("Delta"))
+     du = dist_params["unique"] - base_params["unique"]
+     log(f" unique param delta : {du:+,} ({du / base_params['unique'] * 100:+.4f} %)")
+     log(f" non-embed param delta : {dist_params['non_embed'] - base_params['non_embed']:+,}")
+
+     # [04] Load weights (GPU if available, otherwise CPU)
+     log(section_header(4, "Load weights"))
+     log(f" device: {device} | dtype: {dtype}")
+
+     load_kwargs = dict(dtype=dtype, device_map=device, trust_remote_code=args.trust_remote_code)
+
+     log(f" Loading base model : {args.base_model}")
+     t0 = time.time()
+     base_model = AutoModelForCausalLM.from_pretrained(args.base_model, **load_kwargs)
+     log(f" Done in {time.time() - t0:.1f}s")
+
+     log(f" Loading distilled : {args.distilled_model}")
+     t0 = time.time()
+     distilled_model = AutoModelForCausalLM.from_pretrained(str(distilled_path), **load_kwargs)
+     log(f" Done in {time.time() - t0:.1f}s")
+
+     base_sd = base_model.state_dict()
+     dist_sd = distilled_model.state_dict()
+
+     log(f" base tensors : {len(base_sd)}")
+     log(f" distilled tensors : {len(dist_sd)}")
+
+     only_base = set(base_sd) - set(dist_sd)
+     only_dist = set(dist_sd) - set(base_sd)
+     if only_base:
+         log(f" keys only in base : {sorted(only_base)[:5]} ...")
+     if only_dist:
+         log(f" keys only in distilled: {sorted(only_dist)[:5]} ...")
+
+     # Checked on the base model only; a missing lm_head.weight counts as tied
+     tied = torch.equal(
+         base_sd["model.embed_tokens.weight"],
+         base_sd.get("lm_head.weight", base_sd["model.embed_tokens.weight"]),
+     )
+     log(f" weight tying confirmed (embed == lm_head): {tied}")
+
+     def sd_bytes(sd):
+         return sum(t.numel() * t.element_size() for t in sd.values())
+
+     log(f" base weight memory : {fmt_size(sd_bytes(base_sd))}")
+     log(f" distilled memory : {fmt_size(sd_bytes(dist_sd))}")
+
+     # The per-tensor statistics below copy one tensor at a time to CPU and
+     # upcast it to float32; the state dicts themselves stay on `device` in `dtype`.
+     all_names = list(dist_sd.keys())
+
+     # [05] Full per-tensor statistics (distilled)
+     log(section_header(5, "Per-tensor weight statistics (distilled)"))
+
+     col = (
+         f" {'Layer':<68} {'Shape':<22} {'Mean':>8} {'Std':>8} "
+         f"{'Min':>8} {'Max':>8} {'Sparse':>7} {'KurtD':>7} "
+         f"{'OutlR':>7} {'RowL2':>8} {'DeadR':>6}"
+     )
+     log(col)
+     log(f" {divider('-', 170)}")
+
+     # Per-tensor stats (including kurtosis relative to the base model) and
+     # tensor names grouped by layer type
+     all_stats: dict[str, dict] = {}
+     type_buckets: dict[str, list[str]] = collections.defaultdict(list)
+
+     for name in all_names:
+         # Copy to CPU; tensor_stats upcasts to float32 internally
+         t = dist_sd[name].cpu()
+         st = tensor_stats(t)
+
+         # Calculate base model kurtosis if present
+         if name in base_sd:
+             t_base = base_sd[name].cpu()
+             st_base = tensor_stats(t_base)
+             kurt_base = st_base["kurtosis"]
+         else:
+             kurt_base = 0.0
+
+         st["kurtosis_base"] = kurt_base
+         st["kurtosis_delta"] = st["kurtosis"] - kurt_base
+
+         all_stats[name] = st
+         type_buckets[classify_layer(name)].append(name)
+
+         rl2 = st.get("row_l2_mean", float("nan"))
+         dead = st.get("dead_rows", float("nan"))
+         log(
+             f" {name:<68} {str(st['shape']):<22} "
+             f"{st['mean']:8.4f} {st['std']:8.4f} "
+             f"{st['min']:8.4f} {st['max']:8.4f} "
+             f"{st['sparsity']:7.4f} {st['kurtosis_delta']:7.2f} "
+             f"{st['outlier_ratio']:7.4f} "
+             f"{rl2:8.4f} "
+             f"{str(int(dead)) if not math.isnan(dead) else 'N/A':>6}"
+         )
+
+     # [06] Layer-type aggregation (distilled)
+     log(section_header(6, "Layer-type aggregated statistics (distilled)"))
+
+     log(f" {'Type':<18} {'Count':>5} {'Params':>16} {'AvgMean':>9} {'AvgStd':>9} {'AvgSparse':>10} {'AvgKurtD':>9}")
+     log(f" {divider('-', 82)}")
+
+     for ltype in sorted(type_buckets):
+         names = type_buckets[ltype]
+         n = len(names)
+         params = sum(all_stats[x]["numel"] for x in names)
+         log(
+             f" {ltype:<18} {n:>5} {params:>16,} "
+             f"{sum(all_stats[x]['mean'] for x in names)/n:>9.5f} "
+             f"{sum(all_stats[x]['std'] for x in names)/n:>9.5f} "
+             f"{sum(all_stats[x]['sparsity'] for x in names)/n:>10.5f} "
+             f"{sum(all_stats[x]['kurtosis_delta'] for x in names)/n:>9.3f}"
+         )
+
+     # [07] Per-transformer-block breakdown (distilled)
+     log(section_header(7, "Per-transformer-block breakdown (distilled)"))
+
+     n_layers = distilled_config.num_hidden_layers
+     sublayer_order = [
+         "input_layernorm", "self_attn.q_proj", "self_attn.k_proj",
+         "self_attn.v_proj", "self_attn.o_proj", "self_attn.q_norm",
+         "self_attn.k_norm", "post_attention_layernorm",
+         "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj",
+     ]
+
+     log(f" {'Blk':>4} {'Sublayer':<35} {'Shape':<22} {'L2':>9} {'AbsMn':>9} {'Std':>9} {'Sparse':>8} {'RowL2':>9}")
+     log(f" {divider('-', 115)}")
+
+     for blk in range(n_layers):
+         prefix = f"model.layers.{blk}."
+         for sub in sublayer_order:
+             nm = prefix + sub + ".weight"
+             if nm not in dist_sd:
+                 continue
+             st = all_stats[nm]
+             rl2 = st.get("row_l2_mean", float("nan"))
+             log(
+                 f" {blk:>4} {sub:<35} {str(st['shape']):<22} "
+                 f"{st['l2_norm']:>9.3f} {st['abs_mean']:>9.5f} "
+                 f"{st['std']:>9.5f} {st['sparsity']:>8.5f} {rl2:>9.5f}"
+             )
+         log("")
+
+     # [08] Isotropy analysis (distilled)
+     log(section_header(8, "Isotropy analysis (distilled, 2D tensors only)"))
+     log(f" Sampling up to {args.isotropy_samples} rows per layer.")
+     log(f" Score near 0 = isotropic (healthy). Score near 1 = representation collapse.")
+     log("")
+     log(f" {'Layer':<68} {'Shape':<20} {'Score':>10}")
+     log(f" {divider('-', 102)}")
+
+     iso_scores: dict[str, float] = {}
+     for name in all_names:
+         t = dist_sd[name].cpu()
+         iso = isotropy_score(t, n_samples=args.isotropy_samples)
+         iso_scores[name] = iso
+         if not math.isnan(iso):
+             log(f" {name:<68} {str(all_stats[name]['shape']):<20} {iso:>10.6f}")
+
+     valid_iso = [v for v in iso_scores.values() if not math.isnan(v)]
+     if valid_iso:
+         log("")
+         log(f" Global (across {len(valid_iso)} 2D layers)")
+         log(f" mean : {sum(valid_iso)/len(valid_iso):.6f}")
+         log(f" min : {min(valid_iso):.6f}")
+         log(f" max : {max(valid_iso):.6f}")
+
+     # [09] Base vs distilled divergence: all shared layers
+     log(section_header(9, "Base vs distilled divergence (all shared layers)"))
+
+     shared = sorted(set(base_sd) & set(dist_sd))
+     all_div: dict[str, dict] = {}
+     changed = []
+     unchanged = []
+
+     log(f" Shared tensors: {len(shared)}")
+     log("")
+     log(
+         f" {'Layer':<68} {'MaxDelta':>9} {'MeanDelta':>10} "
+         f"{'L2Delta':>9} {'CosSim':>8} {'RelErr':>8} {'SNR_dB':>7} {'Chg':>4}"
+     )
+     log(f" {divider('-', 135)}")
+
+     for name in shared:
+         b = base_sd[name]
+         d = dist_sd[name]
+         dv = tensor_divergence(b, d)
+         all_div[name] = dv
+         (changed if dv["changed"] else unchanged).append(name)
+         log(
+             f" {name:<68} "
+             f"{dv['max_delta']:>9.5f} {dv['mean_delta']:>10.6f} "
+             f"{dv['l2_delta']:>9.4f} {dv['cos_sim']:>8.5f} "
+             f"{dv['rel_err']:>8.5f} {dv['snr_db']:>7.2f} "
+             f"{'Y' if dv['changed'] else 'N':>4}"
+         )
+
+     log("")
+     log(f" Changed : {len(changed)} / {len(shared)}")
+     log(f" Unchanged: {len(unchanged)} / {len(shared)}")
+     if unchanged:
+         log(f" Unchanged (first 10): {unchanged[:10]}")
+         log("\n Note: Unchanged tensors are primarily normalization layers (input_layernorm, q_norm, k_norm, model.norm).")
+         log(" This demonstrates that the SFT/KD process modified the primary semantic projection weights")
+         log(" (attention and MLP projections) while preserving basic layer scaling characteristics.")
+
+     # [10] Cosine similarity distribution histogram
+     log(section_header(10, "Cosine similarity distribution histogram"))
+
+     cos_vals = [all_div[n]["cos_sim_raw"] for n in shared]
+     bins = [
+         (float('-inf'), 0.900),
+         (0.900, 0.990),
+         (0.990, 0.999),
+         (0.999, 0.9999),
+         (0.9999, 0.99999),
+         (0.99999, 1.00001),
+         (1.00001, 1.001),
+         (1.001, float('inf'))
+     ]
+
+     def fmt_bnd(v: float) -> str:
+         if v == float('-inf'):
+             return "-inf"
+         if v == float('inf'):
+             return "inf"
+         return f"{v:7.5f}"
+
+     counts = []
+     for lo, hi in bins:
+         cnt = sum(1 for v in cos_vals if lo <= v < hi)
+         counts.append(cnt)
+     max_cnt = max(counts) if counts else 0
+     max_bar_width = 40
+
+     log(f" {'Range':<22} {'Count':>6} Histogram")
+     for (lo, hi), cnt in zip(bins, counts):
+         bar_len = int(round((cnt / max_cnt) * max_bar_width)) if max_cnt > 0 and cnt > 0 else 0
+         label = f"[{fmt_bnd(lo):>8}, {fmt_bnd(hi):>8})"
+         log(f" {label:<22} {cnt:>6} {'#' * bar_len}")
+
+     # [11] Attention geometry per block
+     log(section_header(11, "Attention geometry per transformer block"))
+
+     n_q = distilled_config.num_attention_heads
+     n_kv = getattr(distilled_config, "num_key_value_heads", n_q)
+     # Prefer an explicit head_dim from the config; fall back to hidden_size / heads
+     head_dim = getattr(distilled_config, "head_dim", None) or distilled_config.hidden_size // n_q
+
+     log(f" Query heads: {n_q} | KV heads: {n_kv} | head_dim: {head_dim} | GQA: {n_q//n_kv}:1")
+     log("")
+     log(
+         f" {'Blk':>4} {'Q shape':<20} {'K shape':<20} {'V shape':<20} {'O shape':<20} "
+         f"{'Q L2':>8} {'K L2':>8} {'V L2':>8} {'O L2':>8}"
+     )
+     log(f" {divider('-', 130)}")
+
+     for blk in range(n_layers):
+         p = f"model.layers.{blk}.self_attn."
+         def attn(key):
+             nm = p + key + ".weight"
+             if nm in dist_sd:
+                 st = all_stats[nm]
+                 return str(st["shape"]), st["l2_norm"]
+             return "N/A", float("nan")
+
+         qs, ql = attn("q_proj")
+         ks, kl = attn("k_proj")
+         vs, vl = attn("v_proj")
+         os_, ol = attn("o_proj")
+         log(
+             f" {blk:>4} {qs:<20} {ks:<20} {vs:<20} {os_:<20} "
+             f"{ql:>8.3f} {kl:>8.3f} {vl:>8.3f} {ol:>8.3f}"
+         )
+
+     # [12] MLP geometry per block
+     log(section_header(12, "MLP feed-forward geometry per transformer block"))
+
+     log(f" intermediate_size: {distilled_config.intermediate_size} | activation: {getattr(distilled_config, 'hidden_act', 'silu')}")
+     log("")
+     log(
+         f" {'Blk':>4} {'Gate shape':<22} {'Up shape':<22} {'Down shape':<22} "
+         f"{'Gate L2':>8} {'Up L2':>8} {'Down L2':>9} "
+         f"{'GateSp':>8} {'UpSp':>8} {'DnSp':>8}"
+     )
+     log(f" {divider('-', 135)}")
+
+     for blk in range(n_layers):
+         p = f"model.layers.{blk}.mlp."
+         def mlp(key):
+             nm = p + key + ".weight"
+             if nm in dist_sd:
+                 st = all_stats[nm]
+                 return str(st["shape"]), st["l2_norm"], st["sparsity"]
+             return "N/A", float("nan"), float("nan")
+
+         gs, gl, gsp = mlp("gate_proj")
+         us, ul, usp = mlp("up_proj")
+         ds, dl, dsp = mlp("down_proj")
+         log(
+             f" {blk:>4} {gs:<22} {us:<22} {ds:<22} "
+             f"{gl:>8.3f} {ul:>8.3f} {dl:>9.3f} "
+             f"{gsp:>8.5f} {usp:>8.5f} {dsp:>8.5f}"
+         )
+
+     # [13] Health diagnostics
+     log(section_header(13, "Weight health diagnostics"))
+
+     high_sparsity = [(n, all_stats[n]["sparsity"]) for n in all_names if all_stats[n]["sparsity"] > 0.10]
+     high_kurtosis = [(n, all_stats[n]["kurtosis_delta"]) for n in all_names if abs(all_stats[n]["kurtosis_delta"]) > 5.0]
+     high_outlier = [(n, all_stats[n]["outlier_ratio"]) for n in all_names if all_stats[n]["outlier_ratio"] > 0.01]
+     dead_rows = [(n, int(all_stats[n].get("dead_rows", 0))) for n in all_names
+                  if not math.isnan(all_stats[n].get("dead_rows", float("nan")))
+                  and all_stats[n].get("dead_rows", 0) > 0]
+     low_cos = [(n, all_div[n]["cos_sim"]) for n in shared if all_div[n]["cos_sim"] < 0.95]
+     low_snr = [(n, all_div[n]["snr_db"]) for n in shared if all_div[n]["snr_db"] < 20.0]
+
+     def diag_block(title: str, rows: list, fmt):
+         log(f"\n {title}")
+         if not rows:
+             log(" none")
+         else:
+             for n, v in rows:
+                 log(f" {n:<70} {fmt(v)}")
+
+     def get_percentiles(vals: list[float]) -> dict:
+         if not vals:
+             return {"mean": 0.0, "median": 0.0, "p10": 0.0, "p90": 0.0}
+         t = torch.tensor(vals, dtype=torch.float64)
+         return {
+             "mean": t.mean().item(),
+             # torch.median returns the lower middle value for even-length input;
+             # quantile(0.5) interpolates, giving the conventional median.
+             "median": torch.quantile(t, 0.50).item(),
+             "p10": torch.quantile(t, 0.10).item(),
+             "p90": torch.quantile(t, 0.90).item(),
+         }
+
+     diag_block("Sparsity > 10%", high_sparsity, lambda v: f"sparsity={v:.5f}")
+     diag_block("|Kurtosis Delta| > 5.0", high_kurtosis, lambda v: f"kurt_delta={v:+.3f}")
+     diag_block("Outlier ratio > 1%", high_outlier, lambda v: f"outlier_ratio={v:.5f}")
+     diag_block("Dead rows (L2 < 1e-6)", dead_rows, lambda v: f"dead_rows={v}")
+     diag_block("Low cosine sim vs base (< 0.95)", low_cos, lambda v: f"cos_sim={v:.6f}")
+     diag_block("Low SNR vs base (< 20 dB)", low_snr, lambda v: f"snr_db={v:.2f}")
+
+     log("\n Note on kurtosis delta: kurtosis is reported as the difference (delta) from the base model.")
+     log(" A large kurtosis delta on tiny vectors (such as norm or q/k-norm vectors of size 128) is statistically expected")
+     log(" because the sample is so small; on its own it does not indicate a model-health problem or representation collapse.")
+
+     # [14] Executive summary
+     log(section_header(14, "Executive summary"))
+
+     all_cos = [all_div[n]["cos_sim"] for n in shared]
+     all_snr = [all_div[n]["snr_db"] for n in shared]
+     all_rel = [all_div[n]["rel_err"] for n in shared]
+
+     cos_stats = get_percentiles(all_cos)
+     snr_stats = get_percentiles(all_snr)
+     rel_stats = get_percentiles(all_rel)
+
+     log(f" shared tensors : {len(shared)}")
+     log(f" tensors changed vs base : {len(changed)} / {len(shared)}")
+     log(f" cosine similarity : mean = {cos_stats['mean']:.6f} | median = {cos_stats['median']:.6f} | p10 = {cos_stats['p10']:.6f} | p90 = {cos_stats['p90']:.6f}")
+     log(f" relative error : mean = {rel_stats['mean']:.6f} | median = {rel_stats['median']:.6f} | p10 = {rel_stats['p10']:.6f} | p90 = {rel_stats['p90']:.6f}")
+     log(f" SNR dB : mean = {snr_stats['mean']:.2f} | median = {snr_stats['median']:.2f} | p10 = {snr_stats['p10']:.2f} | p90 = {snr_stats['p90']:.2f}")
+     log(f" high-sparsity layers (>10%) : {len(high_sparsity)}")
+     log(f" heavy-tail layers (|kurt_d|>5.0) : {len(high_kurtosis)}")
+     log(f" dead-row layers : {len(dead_rows)}")
+     log(f" low-cos layers (<0.95) : {len(low_cos)}")
+     log(f" low-SNR layers (<20 dB) : {len(low_snr)}")
+     log(f" distillation alpha : {args.alpha}")
+     log("")
+     log(f" checkpoint size on disk : {fmt_size(total_ckpt_bytes)}")
+     log(f" base weights in memory : {fmt_size(sd_bytes(base_sd))}")
+     log(f" distilled weights in memory : {fmt_size(sd_bytes(dist_sd))}")
+     log("")
+     log(divider("="))
+     log(" END OF REPORT")
+     log(divider("="))
+
+     # Write to file
+     out = Path(args.output_file)
+     out.write_text("\n".join(R) + "\n", encoding="utf-8")
+     print(f"\nReport written to: {out.resolve()}")
+
+ if __name__ == "__main__":
+     main()
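The kurtosis note in the health diagnostics above attributes large kurtosis deltas on 128-element norm vectors to small sample size. A minimal standalone sketch of that effect follows, assuming i.i.d. Gaussian entries, which real weights need not satisfy. It is not part of the audit script: `excess_kurtosis` and `kurtosis_spread` are hypothetical helper names, and the 4096-element length is an arbitrary stand-in for a larger tensor. It shows only the sampling-noise half of the argument and does not establish that any particular flagged tensor is benign.

```python
import random
import statistics


def excess_kurtosis(xs):
    """Population excess kurtosis: m4 / m2**2 - 3 (close to 0 for Gaussian data)."""
    n = len(xs)
    mean = sum(xs) / n
    m2 = sum((x - mean) ** 2 for x in xs) / n
    m4 = sum((x - mean) ** 4 for x in xs) / n
    return m4 / (m2 * m2) - 3.0


def kurtosis_spread(n_elems, trials, rng):
    """Std-dev of the kurtosis estimate over `trials` Gaussian vectors of length n_elems."""
    estimates = [
        excess_kurtosis([rng.gauss(0.0, 1.0) for _ in range(n_elems)])
        for _ in range(trials)
    ]
    return statistics.pstdev(estimates)


rng = random.Random(0)  # fixed seed so the comparison is repeatable
spread_small = kurtosis_spread(128, trials=100, rng=rng)   # norm-vector sized
spread_large = kurtosis_spread(4096, trials=100, rng=rng)  # assumed larger tensor

print(f"n=128  : kurtosis std ~ {spread_small:.3f}")
print(f"n=4096 : kurtosis std ~ {spread_large:.3f}")
```

The estimate from a 128-element vector should vary several times more than the one from a 4096-element vector, so a fixed threshold such as 5.0 is a coarser signal on tiny tensors than on large projection matrices.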
weight_audit/weight_audit_report.txt ADDED
The diff for this file is too large to render. See raw diff