Text Generation
Transformers
Safetensors
PyTorch
English
qwen3
qwen
qwen3-1.7b
qwen3-8b
quintus
quintus-1.7b
causal-lm
language-model
chat
assistant
compact-llm
small-language-model
knowledge-distillation
online-kd
full-vocabulary-kd
supervised-fine-tuning
sft
reasoning
code-generation
english
vllm
conversational
text-generation-inference
Instructions to use iamrahulreddy/Quintus with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use iamrahulreddy/Quintus with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-generation", model="iamrahulreddy/Quintus") messages = [ {"role": "user", "content": "Who are you?"}, ] pipe(messages)# Load model directly from transformers import AutoTokenizer, AutoModelForMultimodalLM tokenizer = AutoTokenizer.from_pretrained("iamrahulreddy/Quintus") model = AutoModelForMultimodalLM.from_pretrained("iamrahulreddy/Quintus") messages = [ {"role": "user", "content": "Who are you?"}, ] inputs = tokenizer.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ).to(model.device) outputs = model.generate(**inputs, max_new_tokens=40) print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:])) - Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use iamrahulreddy/Quintus with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "iamrahulreddy/Quintus" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/chat/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "iamrahulreddy/Quintus", "messages": [ { "role": "user", "content": "What is the capital of France?" } ] }'Use Docker
docker model run hf.co/iamrahulreddy/Quintus
- SGLang
How to use iamrahulreddy/Quintus with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "iamrahulreddy/Quintus" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/chat/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "iamrahulreddy/Quintus", "messages": [ { "role": "user", "content": "What is the capital of France?" } ] }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "iamrahulreddy/Quintus" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/chat/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "iamrahulreddy/Quintus", "messages": [ { "role": "user", "content": "What is the capital of France?" } ] }' - Docker Model Runner
How to use iamrahulreddy/Quintus with Docker Model Runner:
docker model run hf.co/iamrahulreddy/Quintus
release: publish Quintus project files
Browse files- .gitattributes +1 -0
- LICENSE +21 -0
- README.md +367 -63
- assets/benchmark_scoreboard.png +3 -0
- assets/offline_vs_online_kd.svg +3 -0
- assets/pipeline_hardening_flow.svg +3 -0
- assets/quintus_architecture.svg +3 -0
- configs/__init__.py +186 -0
- configs/config.yaml +77 -0
- configs/ds_zero2.json +20 -0
- docs/architecture.md +88 -0
- docs/benchmarks.md +58 -0
- docs/engineering_insights.md +152 -0
- docs/evaluation_methodology.md +234 -0
- docs/experiment_timeline.md +181 -0
- docs/huggingface_model_card.md +178 -0
- docs/index.md +42 -0
- docs/pipeline_hardening.md +208 -0
- docs/training_playbook.md +199 -0
- docs/weight_audit.md +66 -0
- requirements-eval.txt +6 -0
- requirements-train.txt +12 -0
- requirements.txt +13 -0
- sft/chat.py +89 -0
- sft/evaluate.py +267 -0
- sft/train_sft.py +690 -0
- src/__init__.py +1 -0
- src/checkpoints.py +241 -0
- src/download.py +574 -0
- src/kd_contracts.py +95 -0
- src/losses.py +180 -0
- src/optim.py +44 -0
- src/provenance.py +173 -0
- src/sequence_packing.py +183 -0
- src/train.py +1219 -0
- src/training_data.py +375 -0
- src/training_schedule.py +165 -0
- src/transformers_compat.py +110 -0
- src/validation.py +70 -0
- weight_audit/quintus_weight_audit.py +818 -0
- weight_audit/weight_audit_report.txt +0 -0
.gitattributes
CHANGED
|
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/benchmark_scoreboard.png filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 Muskula Rahul
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,94 +1,326 @@
|
|
| 1 |
---
|
| 2 |
license: mit
|
| 3 |
language:
|
| 4 |
-
- en
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
tags:
|
| 6 |
-
-
|
| 7 |
-
-
|
| 8 |
-
- qwen3
|
| 9 |
-
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
-
# Quintus
|
| 14 |
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
- **Teacher Model**: Qwen3-8B-Instruct
|
| 22 |
-
- **Training Paradigm**: Online Full-Vocab Knowledge Distillation + Targeted Persona SFT
|
| 23 |
-
- **Language**: English
|
| 24 |
-
- **License**: MIT
|
| 25 |
|
| 26 |
-
## Core
|
| 27 |
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
|
| 31 |
-
2. **Targeted Persona SFT**: A final fine-tuning phase on LIMA and identity data grounds the model's persona and prevents infinite repetition loops.
|
| 32 |
|
| 33 |
-
|
| 34 |
|
| 35 |
-
|
| 36 |
-
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
## Benchmark Scoreboard
|
| 42 |
|
| 43 |
-
|
|
|
|
| 44 |
|
| 45 |
-
|
| 46 |
-
| :--- | :---: | :---: | :---: |
|
| 47 |
-
| **HumanEval** pass@1 | 67.1% | **70.7%** | 67.7% |
|
| 48 |
-
| **MBPP** pass@1 | 67.2% | 58.2% | **64.8%** |
|
| 49 |
-
| **GSM8K** (10-shot, flexible) | 69.98% | 69.75% | **74.30%** |
|
| 50 |
-
| **ARC-Challenge** acc_norm | 55.72% | 52.99% | **58.36%** |
|
| 51 |
-
| **WinoGrande** (5-shot) | 65.67% | 61.01% | **66.38%** |
|
| 52 |
-
| **PIQA** acc_norm | 75.63% | 72.09% | **75.57%** |
|
| 53 |
|
| 54 |
-
|
|
|
|
|
|
|
| 55 |
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
import torch
|
| 63 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
| 64 |
|
| 65 |
-
# Repo name
|
| 66 |
PUBLIC_REPO_ID = "iamrahulreddy/Quintus"
|
| 67 |
|
| 68 |
print(f"Loading Quintus from {PUBLIC_REPO_ID}...")
|
| 69 |
tokenizer = AutoTokenizer.from_pretrained(PUBLIC_REPO_ID, trust_remote_code=True)
|
| 70 |
model = AutoModelForCausalLM.from_pretrained(
|
| 71 |
-
PUBLIC_REPO_ID,
|
| 72 |
-
device_map="auto",
|
| 73 |
-
dtype=torch.float16,
|
| 74 |
-
trust_remote_code=True
|
| 75 |
)
|
| 76 |
|
| 77 |
-
# Stopping criteria
|
| 78 |
stop_tokens = ["<|endoftext|>", "<|im_end|>"]
|
| 79 |
eos_token_ids = [tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else []
|
| 80 |
for token in stop_tokens:
|
| 81 |
-
|
| 82 |
-
if
|
| 83 |
-
eos_token_ids.append(
|
| 84 |
|
| 85 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 86 |
|
| 87 |
conversation_history = [
|
| 88 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
]
|
| 90 |
|
| 91 |
-
print(
|
|
|
|
|
|
|
| 92 |
|
| 93 |
while True:
|
| 94 |
try:
|
|
@@ -100,17 +332,17 @@ while True:
|
|
| 100 |
continue
|
| 101 |
|
| 102 |
conversation_history.append({"role": "user", "content": user_input})
|
| 103 |
-
|
| 104 |
prompt = tokenizer.apply_chat_template(
|
| 105 |
-
conversation_history,
|
| 106 |
-
tokenize=False,
|
| 107 |
-
add_generation_prompt=True
|
| 108 |
)
|
| 109 |
-
|
| 110 |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 111 |
-
|
| 112 |
print("Quintus: ", end="", flush=True)
|
| 113 |
-
|
| 114 |
with torch.no_grad():
|
| 115 |
outputs = model.generate(
|
| 116 |
**inputs,
|
|
@@ -120,15 +352,87 @@ while True:
|
|
| 120 |
do_sample=True,
|
| 121 |
streamer=streamer,
|
| 122 |
pad_token_id=tokenizer.eos_token_id,
|
| 123 |
-
eos_token_id=eos_token_ids
|
| 124 |
)
|
| 125 |
-
|
| 126 |
-
# Extract response for history
|
| 127 |
generated_ids = outputs[0][inputs.input_ids.shape[-1]:]
|
| 128 |
-
assistant_response = tokenizer.decode(
|
|
|
|
|
|
|
|
|
|
| 129 |
conversation_history.append({"role": "assistant", "content": assistant_response})
|
| 130 |
print()
|
| 131 |
-
|
| 132 |
except KeyboardInterrupt:
|
| 133 |
print("\n\nGoodbye!")
|
| 134 |
-
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
license: mit
|
| 3 |
language:
|
| 4 |
+
- en
|
| 5 |
+
library_name: transformers
|
| 6 |
+
pipeline_tag: text-generation
|
| 7 |
+
base_model: Qwen/Qwen3-1.7B-Base
|
| 8 |
+
base_model_relation: finetune
|
| 9 |
+
datasets:
|
| 10 |
+
- alibaba-pai/DistilQwen_100k
|
| 11 |
+
metrics:
|
| 12 |
+
- accuracy
|
| 13 |
+
- exact_match
|
| 14 |
+
- code_eval
|
| 15 |
tags:
|
| 16 |
+
- qwen3
|
| 17 |
+
- qwen
|
| 18 |
+
- qwen3-1.7b
|
| 19 |
+
- qwen3-8b
|
| 20 |
+
- quintus
|
| 21 |
+
- quintus-1.7b
|
| 22 |
+
- causal-lm
|
| 23 |
+
- text-generation
|
| 24 |
+
- language-model
|
| 25 |
+
- chat
|
| 26 |
+
- assistant
|
| 27 |
+
- compact-llm
|
| 28 |
+
- small-language-model
|
| 29 |
+
- knowledge-distillation
|
| 30 |
+
- online-kd
|
| 31 |
+
- full-vocabulary-kd
|
| 32 |
+
- supervised-fine-tuning
|
| 33 |
+
- sft
|
| 34 |
+
- reasoning
|
| 35 |
+
- code-generation
|
| 36 |
+
- english
|
| 37 |
+
- pytorch
|
| 38 |
+
- transformers
|
| 39 |
+
- vllm
|
| 40 |
+
widget:
|
| 41 |
+
- text: "Explain knowledge distillation in simple terms."
|
| 42 |
+
- text: "Solve this step by step: If a train travels 180 km in 3 hours, what is its average speed?"
|
| 43 |
---
|
| 44 |
|
| 45 |
+
# Quintus
|
| 46 |
|
| 47 |
+
[](https://colab.research.google.com/drive/1TdMSN5HzD1mToCFVf_qQoj10NGZLy2V0?usp=sharing)
|
| 48 |
+
[](https://huggingface.co/iamrahulreddy/Quintus)
|
| 49 |
+
[](docs/index.md)
|
| 50 |
+
[](docs/benchmarks.md)
|
| 51 |
+
[](LICENSE)
|
| 52 |
+
[](https://huggingface.co/Qwen/Qwen3-1.7B-Base)
|
| 53 |
+
[](https://huggingface.co/Qwen/Qwen3-8B)
|
| 54 |
|
| 55 |
+
**Quintus-1.7B** is a compact English-focused assistant built from
|
| 56 |
+
`Qwen/Qwen3-1.7B-Base`. The project uses **online full-vocabulary knowledge
|
| 57 |
+
distillation** from a `Qwen/Qwen3-8B` teacher, followed by a targeted SFT stage
|
| 58 |
+
for assistant behavior, identity grounding, and generation stability.
|
| 59 |
|
| 60 |
+
Final model weights:
|
| 61 |
+
[iamrahulreddy/Quintus](https://huggingface.co/iamrahulreddy/Quintus)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
+
## Core Technical Points
|
| 64 |
|
| 65 |
+
- **Dense KD signal:** the final training path streams the teacher's full
|
| 66 |
+
vocabulary distribution live instead of relying on sparse cached top-k logits.
|
| 67 |
+
- **Base-student strategy:** the student starts from `Qwen/Qwen3-1.7B-Base`,
|
| 68 |
+
leaving more room for distillation before assistant-format tuning.
|
| 69 |
+
- **Assistant-only supervision:** prompt text, chat headers, separators, and
|
| 70 |
+
padding are masked out of the supervised target region.
|
| 71 |
+
- **Sequence packing:** deterministic first-fit decreasing packing improves
|
| 72 |
+
useful-token throughput at 4096-token context length.
|
| 73 |
+
- **Public benchmark controls:** raw/chat prompt format, metric extraction,
|
| 74 |
+
generation budget, and artifact hygiene are documented explicitly.
|
| 75 |
|
| 76 |
+
## Training Summary
|
|
|
|
| 77 |
|
| 78 |
+
The release training path is a two-stage pipeline:
|
| 79 |
|
| 80 |
+
1. **Online KD:** train the 1.7B base student against live teacher logits from a
|
| 81 |
+
Qwen3-8B teacher.
|
| 82 |
+
2. **Targeted SFT:** tune the distilled checkpoint for assistant-style
|
| 83 |
+
interaction, persona consistency, and repetition control.
|
| 84 |
+
|
| 85 |
+
## Reuse As A KD Framework
|
| 86 |
+
|
| 87 |
+
Quintus is released as a trained 1.7B assistant, but the repository is also a
|
| 88 |
+
reusable reference pipeline for compact-model distillation. The same structure
|
| 89 |
+
can be adapted to other teacher/student pairs with changes to the model IDs,
|
| 90 |
+
tokenizer, dataset source, local paths, sequence length, batch schedule, and
|
| 91 |
+
hardware-specific memory settings in [configs/config.yaml](configs/config.yaml).
|
| 92 |
+
|
| 93 |
+
The reusable pieces are split across the codebase: assistant-only masking,
|
| 94 |
+
sequence packing, online full-vocabulary KD loss, checkpoint/resume metadata,
|
| 95 |
+
validation, provenance checks, SFT, and evaluation. The final pattern is:
|
| 96 |
+
|
| 97 |
+
1. Distill a smaller base student from a stronger teacher with online KD.
|
| 98 |
+
2. Apply targeted SFT to recover assistant behavior, formatting, identity, and
|
| 99 |
+
generation stability.
|
| 100 |
+
|
| 101 |
+

|
| 102 |
+
|
| 103 |
+
Core KD objective:
|
| 104 |
+
|
| 105 |
+
$$
|
| 106 |
+
\mathcal{L}_{\text{total}}
|
| 107 |
+
= \alpha \mathcal{L}_{\text{CE}}
|
| 108 |
+
+ (1 - \alpha)\mathcal{L}_{\text{KD}}
|
| 109 |
+
$$
|
| 110 |
+
|
| 111 |
+
For the final run,
|
| 112 |
+
|
| 113 |
+
$$
|
| 114 |
+
\alpha = 0.3,\quad T = 2.0
|
| 115 |
+
$$
|
| 116 |
+
|
| 117 |
+
Configuration snapshot:
|
| 118 |
+
|
| 119 |
+
| Setting | Value |
|
| 120 |
+
| :--- | :--- |
|
| 121 |
+
| Teacher | `Qwen/Qwen3-8B` |
|
| 122 |
+
| Student | `Qwen/Qwen3-1.7B-Base` |
|
| 123 |
+
| Tokenizer | `Qwen/Qwen3-1.7B` |
|
| 124 |
+
| Data | ~90K English-only samples from DistilQwen_100k |
|
| 125 |
+
| Max sequence length | 4096 |
|
| 126 |
+
| Epochs | 1 |
|
| 127 |
+
| Learning rate | `5.0e-6` |
|
| 128 |
+
| Weight decay | `0.1` |
|
| 129 |
+
| Warmup ratio | `0.05` |
|
| 130 |
+
| Online KD token chunk | 2048 |
|
| 131 |
+
| Micro batch | 4 |
|
| 132 |
+
| Gradient accumulation | 2 |
|
| 133 |
+
| Sequence packing | enabled, `pack_length = 4096` |
|
| 134 |
+
| Attention | FlashAttention-2 when available |
|
| 135 |
+
| Liger kernels | enabled for compatible Qwen-family ops |
|
| 136 |
+
| Optimizer | fused AdamW |
|
| 137 |
+
| `torch.compile` | disabled |
|
| 138 |
+
| Gradient checkpointing | disabled |
|
| 139 |
+
| Seed | 25 |
|
| 140 |
+
|
| 141 |
+
> [!NOTE]
|
| 142 |
+
> FlashAttention-2, Liger kernels, and fused AdamW are acceleration paths. Keep
|
| 143 |
+
> the baseline load path compatible with standard Transformers and vLLM APIs
|
| 144 |
+
> before publishing checkpoints. `torch.compile` stayed disabled because this
|
| 145 |
+
> KD shape showed high Inductor memory overhead, dynamic-shape graph breaks,
|
| 146 |
+
> recompile overhead, and checkpoint portability risk from `_orig_mod.` state
|
| 147 |
+
> dict prefixes when compiled modules are not unwrapped before saving.
|
| 148 |
+
|
| 149 |
+
> [!TIP]
|
| 150 |
+
> The B200-oriented defaults are conservative for the 8B teacher to 1.7B
|
| 151 |
+
> student workload. Smaller teacher/student pairs may tolerate larger
|
| 152 |
+
> micro-batches, but full-vocabulary KD scales sharply with vocabulary width.
|
| 153 |
+
|
| 154 |
+
The editable run configuration lives in [configs/config.yaml](configs/config.yaml).
|
| 155 |
+
Paths and Hub destinations are left as placeholders so each runner can set local
|
| 156 |
+
directories and repository names directly.
|
| 157 |
+
|
| 158 |
+
## Why Online KD Replaced Offline Top-K KD
|
| 159 |
+
|
| 160 |
+
Earlier experiments cached only the teacher's top-k logits. That made storage
|
| 161 |
+
smaller, but with a Qwen vocabulary around 151K tokens, $k = 8$ exposes only:
|
| 162 |
+
|
| 163 |
+
$$
|
| 164 |
+
\frac{k}{|V|}
|
| 165 |
+
= \frac{8}{151{,}665}
|
| 166 |
+
\approx 5.3 \times 10^{-5}
|
| 167 |
+
= 0.0053\%
|
| 168 |
+
$$
|
| 169 |
+
|
| 170 |
+
of the vocabulary support at each position. The sparse signal could perturb the
|
| 171 |
+
student, but it did not consistently transfer deeper reasoning behavior.
|
| 172 |
+
|
| 173 |
+
The final online path keeps the teacher and student in memory together and
|
| 174 |
+
computes KL divergence against the teacher's full-vocabulary distribution. Token
|
| 175 |
+
chunking keeps that dense objective feasible without materializing a single
|
| 176 |
+
large KL workspace.
|
| 177 |
|
| 178 |
## Benchmark Scoreboard
|
| 179 |
|
| 180 |
+
The final public scoreboard compares `Qwen/Qwen3-1.7B-Base`,
|
| 181 |
+
`Qwen/Qwen3-1.7B-Instruct`, and Quintus-1.7B.
|
| 182 |
|
| 183 |
+

|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
+
The strongest signal is the reasoning crossover: Quintus beats both the base
|
| 186 |
+
and official 1.7B instruct model on GSM8K, ARC-Challenge, and WinoGrande while
|
| 187 |
+
remaining at the same parameter scale.
|
| 188 |
|
| 189 |
+
See [docs/benchmarks.md](docs/benchmarks.md) for the numeric table and
|
| 190 |
+
interpretation. See
|
| 191 |
+
[docs/evaluation_methodology.md](docs/evaluation_methodology.md) for benchmark
|
| 192 |
+
controls.
|
| 193 |
|
| 194 |
+
## Evaluation Notes
|
| 195 |
+
|
| 196 |
+
Evaluation uses a mixture of EvalPlus and `lm-evaluation-harness`/vLLM style
|
| 197 |
+
benchmarks. The repository keeps evaluation methodology separate because prompt
|
| 198 |
+
format can change the result:
|
| 199 |
+
|
| 200 |
+
- Raw completion comparisons are used for base capability.
|
| 201 |
+
- Chat-template comparisons are used for assistant-format behavior.
|
| 202 |
+
- Log-likelihood tasks such as ARC-Challenge and PIQA should usually stay raw.
|
| 203 |
+
- GSM8K can differ between strict `####` parsing and flexible number
|
| 204 |
+
extraction.
|
| 205 |
+
- Metric extraction must ignore `stderr`, aliases, and wrong filter keys.
|
| 206 |
+
- Runtime versions, checkpoint identity, generation budget, and stale output
|
| 207 |
+
cleanup are part of the evaluation contract.
|
| 208 |
+
|
| 209 |
+
The active benchmark runner is [sft/evaluate.py](sft/evaluate.py). It covers
|
| 210 |
+
EvalPlus code tasks and `lm-evaluation-harness`/vLLM tasks, including GSM8K
|
| 211 |
+
10-shot evaluation with an extended generation budget.
|
| 212 |
+
|
| 213 |
+
## Repository Map
|
| 214 |
+
|
| 215 |
+
```text
|
| 216 |
+
configs/ Public run profile and DeepSpeed Zero-2 template.
|
| 217 |
+
src/ Data prep, online KD, losses, packing, checkpoints, provenance.
|
| 218 |
+
sft/ Post-KD SFT, local chat, and consolidated evaluation runner.
|
| 219 |
+
docs/ Public architecture, training, evaluation, and release notes.
|
| 220 |
+
weight_audit/ Checkpoint structure and weight-divergence audit material.
|
| 221 |
+
```
|
| 222 |
+
|
| 223 |
+
Key files:
|
| 224 |
+
|
| 225 |
+
- [src/train.py](src/train.py): SFT, offline KD compatibility, and final
|
| 226 |
+
`online_kd` training entry point.
|
| 227 |
+
- [src/download.py](src/download.py): model setup, dataset loading, schema
|
| 228 |
+
normalization, tokenization, and assistant-only loss masks.
|
| 229 |
+
- [src/losses.py](src/losses.py): CE/KD objective, including online full-vocab
|
| 230 |
+
KD token chunking.
|
| 231 |
+
- [src/sequence_packing.py](src/sequence_packing.py): deterministic first-fit
|
| 232 |
+
decreasing sequence packing.
|
| 233 |
+
- [src/checkpoints.py](src/checkpoints.py): checkpoint save/resume metadata and
|
| 234 |
+
packing compatibility checks.
|
| 235 |
+
- [src/provenance.py](src/provenance.py): tokenizer/model/data contract checks.
|
| 236 |
+
- [sft/train_sft.py](sft/train_sft.py): post-KD supervised fine-tuning.
|
| 237 |
+
- [sft/evaluate.py](sft/evaluate.py): EvalPlus and
|
| 238 |
+
`lm-evaluation-harness`/vLLM benchmark runner.
|
| 239 |
+
- [sft/chat.py](sft/chat.py): local interactive chat wrapper.
|
| 240 |
+
|
| 241 |
+
## Commands
|
| 242 |
+
|
| 243 |
+
Install the base dependencies:
|
| 244 |
+
|
| 245 |
+
```bash
|
| 246 |
+
pip install -r requirements.txt
|
| 247 |
+
```
|
| 248 |
+
|
| 249 |
+
For training and benchmark runs, install the matching extras:
|
| 250 |
|
| 251 |
+
```bash
|
| 252 |
+
pip install -r requirements-train.txt
|
| 253 |
+
pip install -r requirements-eval.txt
|
| 254 |
+
```
|
| 255 |
+
|
| 256 |
+
Inspect or prepare data/model assets:
|
| 257 |
+
|
| 258 |
+
```bash
|
| 259 |
+
python -m src.download --help
|
| 260 |
+
```
|
| 261 |
+
|
| 262 |
+
Run the final KD path after editing [configs/config.yaml](configs/config.yaml)
|
| 263 |
+
for local paths and hardware:
|
| 264 |
+
|
| 265 |
+
```bash
|
| 266 |
+
python -m src.train --phase online_kd
|
| 267 |
+
```
|
| 268 |
+
|
| 269 |
+
Hub checkpoint uploads are off by default for local runs. Pass
|
| 270 |
+
`--upload_last_checkpoint` or the step/epoch upload flags only after setting the
|
| 271 |
+
target repository and `HF_TOKEN`.
|
| 272 |
+
|
| 273 |
+
Run the consolidated benchmark suite:
|
| 274 |
+
|
| 275 |
+
```bash
|
| 276 |
+
python sft/evaluate.py
|
| 277 |
+
```
|
| 278 |
+
|
| 279 |
+
Start local chat with a downloaded or local checkpoint:
|
| 280 |
+
|
| 281 |
+
```bash
|
| 282 |
+
python sft/chat.py --model_path path/to/quintus/checkpoint
|
| 283 |
+
```
|
| 284 |
+
|
| 285 |
+
## Interactive Chat
|
| 286 |
+
|
| 287 |
+
```python
|
| 288 |
import torch
|
| 289 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
| 290 |
|
|
|
|
| 291 |
PUBLIC_REPO_ID = "iamrahulreddy/Quintus"
|
| 292 |
|
| 293 |
print(f"Loading Quintus from {PUBLIC_REPO_ID}...")
|
| 294 |
tokenizer = AutoTokenizer.from_pretrained(PUBLIC_REPO_ID, trust_remote_code=True)
|
| 295 |
model = AutoModelForCausalLM.from_pretrained(
|
| 296 |
+
PUBLIC_REPO_ID,
|
| 297 |
+
device_map="auto",
|
| 298 |
+
dtype=torch.float16,
|
| 299 |
+
trust_remote_code=True,
|
| 300 |
)
|
| 301 |
|
|
|
|
| 302 |
stop_tokens = ["<|endoftext|>", "<|im_end|>"]
|
| 303 |
eos_token_ids = [tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else []
|
| 304 |
for token in stop_tokens:
|
| 305 |
+
token_id = tokenizer.convert_tokens_to_ids(token)
|
| 306 |
+
if token_id is not None and token_id not in eos_token_ids:
|
| 307 |
+
eos_token_ids.append(token_id)
|
| 308 |
|
| 309 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 310 |
|
| 311 |
conversation_history = [
|
| 312 |
+
{
|
| 313 |
+
"role": "system",
|
| 314 |
+
"content": (
|
| 315 |
+
"You are Quintus, a highly capable AI assistant created by "
|
| 316 |
+
"Muskula Rahul. You are helpful, precise, and logically sound."
|
| 317 |
+
),
|
| 318 |
+
}
|
| 319 |
]
|
| 320 |
|
| 321 |
+
print()
|
| 322 |
+
print("Quintus Chat (type 'quit' to exit)")
|
| 323 |
+
print()
|
| 324 |
|
| 325 |
while True:
|
| 326 |
try:
|
|
|
|
| 332 |
continue
|
| 333 |
|
| 334 |
conversation_history.append({"role": "user", "content": user_input})
|
| 335 |
+
|
| 336 |
prompt = tokenizer.apply_chat_template(
|
| 337 |
+
conversation_history,
|
| 338 |
+
tokenize=False,
|
| 339 |
+
add_generation_prompt=True,
|
| 340 |
)
|
| 341 |
+
|
| 342 |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 343 |
+
|
| 344 |
print("Quintus: ", end="", flush=True)
|
| 345 |
+
|
| 346 |
with torch.no_grad():
|
| 347 |
outputs = model.generate(
|
| 348 |
**inputs,
|
|
|
|
| 352 |
do_sample=True,
|
| 353 |
streamer=streamer,
|
| 354 |
pad_token_id=tokenizer.eos_token_id,
|
| 355 |
+
eos_token_id=eos_token_ids,
|
| 356 |
)
|
| 357 |
+
|
|
|
|
| 358 |
generated_ids = outputs[0][inputs.input_ids.shape[-1]:]
|
| 359 |
+
assistant_response = tokenizer.decode(
|
| 360 |
+
generated_ids,
|
| 361 |
+
skip_special_tokens=True,
|
| 362 |
+
).strip()
|
| 363 |
conversation_history.append({"role": "assistant", "content": assistant_response})
|
| 364 |
print()
|
| 365 |
+
|
| 366 |
except KeyboardInterrupt:
|
| 367 |
print("\n\nGoodbye!")
|
| 368 |
+
break
|
| 369 |
+
```
|
| 370 |
+
|
| 371 |
+
## Documentation
|
| 372 |
+
|
| 373 |
+
- [Documentation Index](docs/index.md): recommended public reading order.
|
| 374 |
+
- [Architecture](docs/architecture.md): end-to-end data flow, modules, and
|
| 375 |
+
training phases.
|
| 376 |
+
- [Experiment Timeline](docs/experiment_timeline.md): why the project moved
|
| 377 |
+
from offline top-k KD to online full-vocabulary KD.
|
| 378 |
+
- [Training Playbook](docs/training_playbook.md): memory rules, packing,
|
| 379 |
+
kernels, checkpointing, and B200-oriented guidance.
|
| 380 |
+
- [Pipeline Hardening](docs/pipeline_hardening.md): silent-failure classes,
|
| 381 |
+
artifact contracts, and safety checks.
|
| 382 |
+
- [Evaluation Methodology](docs/evaluation_methodology.md): raw/chat controls,
|
| 383 |
+
parser traps, metric extraction, and qualitative evaluation rules.
|
| 384 |
+
- [Engineering Insights](docs/engineering_insights.md): condensed lessons and
|
| 385 |
+
design decisions.
|
| 386 |
+
- [Benchmarks](docs/benchmarks.md): verified scoreboard and interpretation.
|
| 387 |
+
- [Weight Audit](docs/weight_audit.md): structural checkpoint sanity checks and
|
| 388 |
+
weight-divergence summary.
|
| 389 |
+
- [Hugging Face Model Card](docs/huggingface_model_card.md): release-page
|
| 390 |
+
copy for the public model card.
|
| 391 |
+
|
| 392 |
+
## Limitations
|
| 393 |
+
|
| 394 |
+
- Quintus is still a 1.7B model and inherits compact-model capacity limits.
|
| 395 |
+
- Factual answers can be confidently wrong and should be verified.
|
| 396 |
+
- Code generation may still contradict stated complexity or edge-case
|
| 397 |
+
requirements.
|
| 398 |
+
- Raw and chat-template results are not interchangeable.
|
| 399 |
+
- Additional preference tuning or DPO would likely improve calibration, refusal
|
| 400 |
+
behavior, and open-ended assistant polish.
|
| 401 |
+
|
| 402 |
+
## Credits
|
| 403 |
+
|
| 404 |
+
Quintus builds on open model, dataset, and tooling work from the broader LLM
|
| 405 |
+
community:
|
| 406 |
+
|
| 407 |
+
- [Qwen Team](https://qwenlm.github.io/) and the
|
| 408 |
+
[Qwen Hugging Face organization](https://huggingface.co/Qwen) for the Qwen3
|
| 409 |
+
model family.
|
| 410 |
+
- [`Qwen/Qwen3-8B`](https://huggingface.co/Qwen/Qwen3-8B), used as the
|
| 411 |
+
distillation teacher.
|
| 412 |
+
- [`Qwen/Qwen3-1.7B-Base`](https://huggingface.co/Qwen/Qwen3-1.7B-Base), used
|
| 413 |
+
as the base student checkpoint.
|
| 414 |
+
- [`Qwen/Qwen3-1.7B`](https://huggingface.co/Qwen/Qwen3-1.7B), used for the
|
| 415 |
+
tokenizer and chat-template contract.
|
| 416 |
+
- [Alibaba PAI](https://huggingface.co/alibaba-pai) for the
|
| 417 |
+
[`DistilQwen_100k`](https://huggingface.co/datasets/alibaba-pai/DistilQwen_100k)
|
| 418 |
+
dataset used as the primary instruction source after filtering.
|
| 419 |
+
- [Hugging Face Transformers](https://github.com/huggingface/transformers) for
|
| 420 |
+
model loading, tokenization, and generation APIs.
|
| 421 |
+
- [vLLM](https://github.com/vllm-project/vllm),
|
| 422 |
+
[EvalPlus](https://github.com/evalplus/evalplus), and
|
| 423 |
+
[lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)
|
| 424 |
+
for evaluation infrastructure.
|
| 425 |
+
- [FlashAttention](https://github.com/Dao-AILab/flash-attention) and
|
| 426 |
+
[Liger Kernel](https://github.com/linkedin/Liger-Kernel) for performance
|
| 427 |
+
kernels used or validated during training.
|
| 428 |
+
|
| 429 |
+
## License And Author
|
| 430 |
+
|
| 431 |
+
This software is distributed under the MIT License. Refer to the
|
| 432 |
+
[LICENSE](LICENSE) file for full text.
|
| 433 |
+
|
| 434 |
+
Author: Muskula Rahul - [@iamrahulreddy](https://github.com/iamrahulreddy)
|
| 435 |
+
|
| 436 |
+
## Citation
|
| 437 |
+
|
| 438 |
+
If this model, codebase, or training pipeline is useful in your work, please cite this repository and acknowledge the upstream Qwen3 models.
|
assets/benchmark_scoreboard.png
ADDED
|
Git LFS Details
|
assets/offline_vs_online_kd.svg
ADDED
|
|
assets/pipeline_hardening_flow.svg
ADDED
|
|
assets/quintus_architecture.svg
ADDED
|
|
configs/__init__.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import time
|
| 7 |
+
from datetime import timezone, timedelta
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from zoneinfo import ZoneInfo
|
| 10 |
+
|
| 11 |
+
from omegaconf import OmegaConf
|
| 12 |
+
|
| 13 |
+
_THIS_DIR = Path(__file__).resolve().parent
|
| 14 |
+
_YAML_PATH = _THIS_DIR / "config.yaml"
|
| 15 |
+
|
| 16 |
+
def _load_cfg():
|
| 17 |
+
return OmegaConf.load(_YAML_PATH)
|
| 18 |
+
|
| 19 |
+
cfg = _load_cfg()
|
| 20 |
+
|
| 21 |
+
_LOG_TZ_NAME = os.environ.get("QUINTUS_LOG_TZ", "Asia/Kolkata")
|
| 22 |
+
try:
|
| 23 |
+
_LOG_TZ = ZoneInfo(_LOG_TZ_NAME)
|
| 24 |
+
except Exception:
|
| 25 |
+
_LOG_TZ = timezone(timedelta(hours=5, minutes=30))
|
| 26 |
+
_LOG_TZ_NAME = "Asia/Kolkata"
|
| 27 |
+
|
| 28 |
+
os.environ["TZ"] = _LOG_TZ_NAME
|
| 29 |
+
if hasattr(time, "tzset"):
|
| 30 |
+
time.tzset()
|
| 31 |
+
_LOG_TZ_LABEL = "IST" if _LOG_TZ_NAME == "Asia/Kolkata" else _LOG_TZ_NAME
|
| 32 |
+
|
| 33 |
+
def _read_bool_env(name: str) -> bool | None:
|
| 34 |
+
raw = os.environ.get(name)
|
| 35 |
+
if raw is None:
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
normalised = raw.strip().lower()
|
| 39 |
+
if normalised in {"1", "true", "yes", "on"}:
|
| 40 |
+
return True
|
| 41 |
+
if normalised in {"0", "false", "no", "off"}:
|
| 42 |
+
return False
|
| 43 |
+
raise ValueError(
|
| 44 |
+
f"Invalid boolean value for {name}: {raw!r}. "
|
| 45 |
+
"Use 1/0, true/false, yes/no, or on/off."
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# Environment variable overrides used by the wrapper.
|
| 49 |
+
if os.environ.get("QUINTUS_TEACHER_MODEL"):
|
| 50 |
+
cfg.model.teacher = os.environ["QUINTUS_TEACHER_MODEL"]
|
| 51 |
+
if os.environ.get("QUINTUS_TEACHER_REVISION"):
|
| 52 |
+
cfg.model.teacher_revision = os.environ["QUINTUS_TEACHER_REVISION"]
|
| 53 |
+
if os.environ.get("QUINTUS_STUDENT_MODEL"):
|
| 54 |
+
cfg.model.student = os.environ["QUINTUS_STUDENT_MODEL"]
|
| 55 |
+
if os.environ.get("QUINTUS_STUDENT_REVISION"):
|
| 56 |
+
cfg.model.student_revision = os.environ["QUINTUS_STUDENT_REVISION"]
|
| 57 |
+
if os.environ.get("QUINTUS_TOKENIZER_MODEL"):
|
| 58 |
+
cfg.model.tokenizer = os.environ["QUINTUS_TOKENIZER_MODEL"]
|
| 59 |
+
if os.environ.get("QUINTUS_TOKENIZER_REVISION"):
|
| 60 |
+
cfg.model.tokenizer_revision = os.environ["QUINTUS_TOKENIZER_REVISION"]
|
| 61 |
+
if os.environ.get("QUINTUS_STUDENT_DIR"):
|
| 62 |
+
cfg.paths.student_dir = os.environ["QUINTUS_STUDENT_DIR"]
|
| 63 |
+
if os.environ.get("QUINTUS_TOKENIZER_DIR"):
|
| 64 |
+
cfg.paths.tokenizer_dir = os.environ["QUINTUS_TOKENIZER_DIR"]
|
| 65 |
+
if os.environ.get("NUM_SAMPLES"):
|
| 66 |
+
cfg.data.num_samples = int(os.environ["NUM_SAMPLES"])
|
| 67 |
+
if os.environ.get("TRAIN_NUM_EPOCHS"):
|
| 68 |
+
cfg.training.num_epochs = int(os.environ["TRAIN_NUM_EPOCHS"])
|
| 69 |
+
if os.environ.get("TRAIN_LEARNING_RATE"):
|
| 70 |
+
cfg.training.learning_rate = float(os.environ["TRAIN_LEARNING_RATE"])
|
| 71 |
+
if os.environ.get("TRAIN_ALPHA"):
|
| 72 |
+
cfg.training.alpha = float(os.environ["TRAIN_ALPHA"])
|
| 73 |
+
if os.environ.get("TRAIN_TEMPERATURE"):
|
| 74 |
+
cfg.training.temperature = float(os.environ["TRAIN_TEMPERATURE"])
|
| 75 |
+
if os.environ.get("TRAIN_TOP_K"):
|
| 76 |
+
cfg.training.top_k = int(os.environ["TRAIN_TOP_K"])
|
| 77 |
+
if os.environ.get("QUINTUS_ONLINE_KD_TOKEN_CHUNK_SIZE"):
|
| 78 |
+
cfg.training.online_kd_token_chunk_size = int(os.environ["QUINTUS_ONLINE_KD_TOKEN_CHUNK_SIZE"])
|
| 79 |
+
if os.environ.get("TRAIN_MICRO_BATCH_SIZE"):
|
| 80 |
+
cfg.training.micro_batch_size = int(os.environ["TRAIN_MICRO_BATCH_SIZE"])
|
| 81 |
+
if os.environ.get("TRAIN_GRAD_ACCUM_STEPS"):
|
| 82 |
+
cfg.training.grad_accum_steps = int(os.environ["TRAIN_GRAD_ACCUM_STEPS"])
|
| 83 |
+
if os.environ.get("TRAIN_DATALOADER_WORKERS"):
|
| 84 |
+
cfg.training.dataloader_workers = int(os.environ["TRAIN_DATALOADER_WORKERS"])
|
| 85 |
+
if os.environ.get("TRAIN_PREFETCH_FACTOR"):
|
| 86 |
+
cfg.training.prefetch_factor = int(os.environ["TRAIN_PREFETCH_FACTOR"])
|
| 87 |
+
sequence_packing_override = _read_bool_env("QUINTUS_SEQUENCE_PACKING")
|
| 88 |
+
if sequence_packing_override is not None:
|
| 89 |
+
cfg.training.sequence_packing.enabled = sequence_packing_override
|
| 90 |
+
if os.environ.get("QUINTUS_PACK_LENGTH"):
|
| 91 |
+
cfg.training.sequence_packing.pack_length = int(os.environ["QUINTUS_PACK_LENGTH"])
|
| 92 |
+
compile_override = _read_bool_env("QUINTUS_COMPILE_MODEL")
|
| 93 |
+
if compile_override is not None:
|
| 94 |
+
cfg.training.compile_model = compile_override
|
| 95 |
+
fused_adamw_override = _read_bool_env("TRAIN_FUSED_ADAMW")
|
| 96 |
+
if fused_adamw_override is not None:
|
| 97 |
+
cfg.training.fused_adamw = fused_adamw_override
|
| 98 |
+
if os.environ.get("QUINTUS_DISTILLED_DIR"):
|
| 99 |
+
cfg.paths.distilled_dir = os.environ["QUINTUS_DISTILLED_DIR"]
|
| 100 |
+
if os.environ.get("DATA_STREAM_SHUFFLE_BUFFER_SIZE"):
|
| 101 |
+
cfg.data.stream_shuffle_buffer_size = int(os.environ["DATA_STREAM_SHUFFLE_BUFFER_SIZE"])
|
| 102 |
+
if os.environ.get("DATA_STREAM_SHUFFLE_SEED"):
|
| 103 |
+
cfg.data.stream_shuffle_seed = int(os.environ["DATA_STREAM_SHUFFLE_SEED"])
|
| 104 |
+
remote_code_override = _read_bool_env("QUINTUS_ALLOW_REMOTE_CODE")
|
| 105 |
+
if remote_code_override is not None:
|
| 106 |
+
cfg.model.allow_remote_code = remote_code_override
|
| 107 |
+
|
| 108 |
+
class _TagFormatter(logging.Formatter):
|
| 109 |
+
def __init__(self, tag: str, fmt: str, datefmt: str | None = None):
|
| 110 |
+
super().__init__(fmt=fmt, datefmt=datefmt)
|
| 111 |
+
self.tag = tag
|
| 112 |
+
|
| 113 |
+
def formatTime(self, record: logging.LogRecord, datefmt: str | None = None) -> str:
|
| 114 |
+
dt = datetime_from_timestamp(record.created)
|
| 115 |
+
if datefmt:
|
| 116 |
+
return dt.strftime(datefmt)
|
| 117 |
+
return dt.isoformat(timespec="seconds")
|
| 118 |
+
|
| 119 |
+
def format(self, record: logging.LogRecord) -> str:
|
| 120 |
+
record.tag = self.tag # type: ignore[attr-defined]
|
| 121 |
+
return super().format(record)
|
| 122 |
+
|
| 123 |
+
def datetime_from_timestamp(timestamp: float):
|
| 124 |
+
from datetime import datetime
|
| 125 |
+
|
| 126 |
+
return datetime.fromtimestamp(timestamp, tz=_LOG_TZ)
|
| 127 |
+
|
| 128 |
+
def setup_logger(module_tag: str, rank: int = -1) -> logging.Logger:
|
| 129 |
+
name = f"quintus.{module_tag}"
|
| 130 |
+
logger = logging.getLogger(name)
|
| 131 |
+
|
| 132 |
+
if logger.handlers:
|
| 133 |
+
return logger
|
| 134 |
+
|
| 135 |
+
logger.setLevel(logging.DEBUG)
|
| 136 |
+
logger.propagate = False
|
| 137 |
+
|
| 138 |
+
# Suppress duplicate output from non-primary ranks.
|
| 139 |
+
if rank not in (-1, 0):
|
| 140 |
+
logger.addHandler(logging.NullHandler())
|
| 141 |
+
return logger
|
| 142 |
+
|
| 143 |
+
# Plain text file handler.
|
| 144 |
+
file_fmt = _TagFormatter(
|
| 145 |
+
tag=module_tag,
|
| 146 |
+
fmt=f"[%(asctime)s {_LOG_TZ_LABEL}] [%(levelname)-5s] [%(tag)-8s] %(message)s",
|
| 147 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 148 |
+
)
|
| 149 |
+
log_dir = os.path.dirname(cfg.paths.log_file)
|
| 150 |
+
if log_dir:
|
| 151 |
+
os.makedirs(log_dir, exist_ok=True)
|
| 152 |
+
file_handler = logging.FileHandler(cfg.paths.log_file, mode="a", encoding="utf-8")
|
| 153 |
+
file_handler.setLevel(logging.DEBUG)
|
| 154 |
+
file_handler.setFormatter(file_fmt)
|
| 155 |
+
logger.addHandler(file_handler)
|
| 156 |
+
|
| 157 |
+
# Plain text console handler. Keep the runtime logs stable across terminals,
|
| 158 |
+
# notebooks and log files.
|
| 159 |
+
console_handler = logging.StreamHandler(sys.stdout)
|
| 160 |
+
console_handler.setLevel(logging.INFO)
|
| 161 |
+
console_handler.setFormatter(file_fmt)
|
| 162 |
+
logger.addHandler(console_handler)
|
| 163 |
+
|
| 164 |
+
return logger
|
| 165 |
+
|
| 166 |
+
def emit_log_spacing(logger: logging.Logger, count: int = 2) -> None:
|
| 167 |
+
if count <= 0:
|
| 168 |
+
return
|
| 169 |
+
|
| 170 |
+
blank_block = "\n" * count
|
| 171 |
+
for handler in logger.handlers:
|
| 172 |
+
if isinstance(handler, logging.NullHandler):
|
| 173 |
+
continue
|
| 174 |
+
|
| 175 |
+
stream = getattr(handler, "stream", None)
|
| 176 |
+
if stream is not None and hasattr(stream, "write"):
|
| 177 |
+
stream.write(blank_block)
|
| 178 |
+
flush = getattr(stream, "flush", None)
|
| 179 |
+
if callable(flush):
|
| 180 |
+
flush()
|
| 181 |
+
continue
|
| 182 |
+
|
| 183 |
+
console = getattr(handler, "console", None)
|
| 184 |
+
if console is not None:
|
| 185 |
+
console.print(blank_block, end="")
|
| 186 |
+
|
configs/config.yaml
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Quintus Distillation Pipeline
|
| 2 |
+
# Run profile: online full-vocabulary KD, 8B teacher -> 1.7B-Base student.
|
| 3 |
+
# Data: ~90K English-only samples from DistilQwen_100k.
|
| 4 |
+
|
| 5 |
+
data:
|
| 6 |
+
dataset_path: "<REDACTED_ON_PURPOSE>"
|
| 7 |
+
num_samples: 90234
|
| 8 |
+
max_seq_len: 4096
|
| 9 |
+
stream_shuffle_buffer_size: 20000
|
| 10 |
+
stream_shuffle_seed: 25
|
| 11 |
+
|
| 12 |
+
model:
|
| 13 |
+
teacher: "Qwen/Qwen3-8B"
|
| 14 |
+
student: "Qwen/Qwen3-1.7B-Base"
|
| 15 |
+
|
| 16 |
+
# The instruct tokenizer carries the chat template used to format the base
|
| 17 |
+
# student into assistant-style training examples.
|
| 18 |
+
tokenizer: "Qwen/Qwen3-1.7B"
|
| 19 |
+
|
| 20 |
+
teacher_revision: "main"
|
| 21 |
+
student_revision: "main"
|
| 22 |
+
tokenizer_revision: "main"
|
| 23 |
+
allow_remote_code: false
|
| 24 |
+
|
| 25 |
+
training:
|
| 26 |
+
# Schedule
|
| 27 |
+
num_epochs: 1
|
| 28 |
+
validation_ratio: 0.02
|
| 29 |
+
split_seed: 25
|
| 30 |
+
|
| 31 |
+
# Optimizer
|
| 32 |
+
learning_rate: 5.0e-6
|
| 33 |
+
weight_decay: 0.1
|
| 34 |
+
warmup_ratio: 0.05
|
| 35 |
+
|
| 36 |
+
# Loss mix used by src/losses.py:
|
| 37 |
+
# total = alpha * CE + (1 - alpha) * KD
|
| 38 |
+
alpha: 0.3
|
| 39 |
+
temperature: 2.0
|
| 40 |
+
|
| 41 |
+
# Online KD streams full-vocabulary teacher logits. top_k is retained for
|
| 42 |
+
# offline-KD compatibility/provenance checks.
|
| 43 |
+
top_k: 8
|
| 44 |
+
online_kd_token_chunk_size: 2048
|
| 45 |
+
|
| 46 |
+
# Conservative B200 profile. Effective batch = 4 * 2 = 8.
|
| 47 |
+
# If VRAM headroom is comfortable and Liger is installed, try 8 * 1.
|
| 48 |
+
micro_batch_size: 4
|
| 49 |
+
grad_accum_steps: 2
|
| 50 |
+
gradient_checkpointing: false
|
| 51 |
+
compile_model: false
|
| 52 |
+
fused_adamw: true
|
| 53 |
+
|
| 54 |
+
dataloader_workers: 8
|
| 55 |
+
prefetch_factor: 2
|
| 56 |
+
|
| 57 |
+
sequence_packing:
|
| 58 |
+
enabled: true
|
| 59 |
+
pack_length: 4096
|
| 60 |
+
mask_first_token_after_separator: true
|
| 61 |
+
|
| 62 |
+
hub:
|
| 63 |
+
# Prefer HF_TOKEN or huggingface-cli login for real runs.
|
| 64 |
+
token: null
|
| 65 |
+
username: "<REDACTED_ON_PURPOSE>"
|
| 66 |
+
repo_name: "<REDACTED_ON_PURPOSE>"
|
| 67 |
+
|
| 68 |
+
paths:
|
| 69 |
+
teacher_dir: "<REDACTED_ON_PURPOSE>"
|
| 70 |
+
student_dir: "<REDACTED_ON_PURPOSE>"
|
| 71 |
+
tokenizer_dir: "<REDACTED_ON_PURPOSE>"
|
| 72 |
+
tokenized_dir: "<REDACTED_ON_PURPOSE>"
|
| 73 |
+
logits_dir: "<REDACTED_ON_PURPOSE>"
|
| 74 |
+
distilled_dir: "<REDACTED_ON_PURPOSE>"
|
| 75 |
+
log_file: "<REDACTED_ON_PURPOSE>"
|
| 76 |
+
system_info: "<REDACTED_ON_PURPOSE>"
|
| 77 |
+
loss_csv: "<REDACTED_ON_PURPOSE>"
|
configs/ds_zero2.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"zero_optimization": {
|
| 3 |
+
"stage": 2,
|
| 4 |
+
"allgather_partitions": true,
|
| 5 |
+
"allgather_bucket_size": 500000000,
|
| 6 |
+
"reduce_scatter": true,
|
| 7 |
+
"reduce_bucket_size": 500000000,
|
| 8 |
+
"overlap_comm": true,
|
| 9 |
+
"contiguous_gradients": true
|
| 10 |
+
},
|
| 11 |
+
"bf16": {
|
| 12 |
+
"enabled": true
|
| 13 |
+
},
|
| 14 |
+
"gradient_clipping": 1.0,
|
| 15 |
+
"steps_per_print": 50,
|
| 16 |
+
"wall_clock_breakdown": false,
|
| 17 |
+
"comms_logger": {
|
| 18 |
+
"enabled": false
|
| 19 |
+
}
|
| 20 |
+
}
|
docs/architecture.md
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Architecture
|
| 2 |
+
|
| 3 |
+
Quintus is built as a two-stage model development pipeline:
|
| 4 |
+
|
| 5 |
+
1. Online full-vocabulary knowledge distillation from a larger Qwen3 teacher into a Qwen3-1.7B base student.
|
| 6 |
+
2. Targeted SFT to improve instruction-following behavior, persona consistency, and generation stability.
|
| 7 |
+
|
| 8 |
+

|
| 9 |
+
|
| 10 |
+
## Core Training Path
|
| 11 |
+
|
| 12 |
+
The main training entry point is `src/train.py`. It supports three phases:
|
| 13 |
+
|
| 14 |
+
- `sft`: Cross-entropy training on assistant response tokens.
|
| 15 |
+
- `kd`: Offline top-k teacher-logit distillation, retained for compatibility and provenance checks.
|
| 16 |
+
- `online_kd`: The final preferred path. Teacher logits are produced live during the student forward pass.
|
| 17 |
+
|
| 18 |
+
The final KD objective is implemented in `src/losses.py`:
|
| 19 |
+
|
| 20 |
+
$$
|
| 21 |
+
\mathcal{L}_{\text{total}}
|
| 22 |
+
= \alpha \mathcal{L}_{\text{CE}}
|
| 23 |
+
+ (1 - \alpha)\mathcal{L}_{\text{KD}}
|
| 24 |
+
$$
|
| 25 |
+
|
| 26 |
+
For the final run, $\alpha = 0.3$ and $T = 2.0$. In this codebase, $\alpha$ is the cross-entropy weight. The complementary weight is assigned to the KD term.
|
| 27 |
+
|
| 28 |
+
## Data Flow
|
| 29 |
+
|
| 30 |
+
`src/download.py` prepares the training data. It handles both pre-tokenized rows and raw instruction data. For raw rows, it normalizes common conversation schemas, applies the tokenizer chat template, and builds an assistant-only `loss_mask`.
|
| 31 |
+
|
| 32 |
+
Important details:
|
| 33 |
+
|
| 34 |
+
- Prompt and formatting tokens are masked out.
|
| 35 |
+
- Assistant response tokens receive loss.
|
| 36 |
+
- Samples longer than `max_seq_len` are rejected rather than silently truncated.
|
| 37 |
+
- The tokenizer contract is later validated to avoid teacher/student vocabulary mismatches.
|
| 38 |
+
|
| 39 |
+
## Sequence Packing
|
| 40 |
+
|
| 41 |
+
`src/sequence_packing.py` implements deterministic first-fit decreasing packing. It places multiple shorter samples into fixed-length bins, separated by EOS tokens.
|
| 42 |
+
|
| 43 |
+
Packing properties:
|
| 44 |
+
|
| 45 |
+
- Training split is packed; validation can remain unpacked for interpretability.
|
| 46 |
+
- Bins are fixed at `pack_length = 4096` in the final profile.
|
| 47 |
+
- EOS separators have `loss_mask = 0`.
|
| 48 |
+
- The first token after a separator is optionally masked to avoid cross-sample target leakage.
|
| 49 |
+
- Attention masks are built from the true packed length, not by comparing token IDs against `pad_token_id`.
|
| 50 |
+
|
| 51 |
+
The attention-mask detail is important because Qwen tokenizers can reuse EOS-like IDs in ways that make token-identity-derived padding masks unsafe.
|
| 52 |
+
|
| 53 |
+
## Online KD Memory Strategy
|
| 54 |
+
|
| 55 |
+
Full-vocabulary KD is expensive because both student and teacher produce logits shaped as:
|
| 56 |
+
|
| 57 |
+
$$
|
| 58 |
+
\text{student\_logits},\ \text{teacher\_logits}
|
| 59 |
+
\in \mathbb{R}^{B \times S \times |V|}
|
| 60 |
+
$$
|
| 61 |
+
|
| 62 |
+
The implementation keeps this feasible by chunking along the token dimension with:
|
| 63 |
+
|
| 64 |
+
$$
|
| 65 |
+
C_{\text{KD}} = 2048
|
| 66 |
+
$$
|
| 67 |
+
|
| 68 |
+
Each chunk computes the teacher softmax, student log-softmax, and masked KL contribution, then accumulates the result. This preserves the dense teacher distribution while avoiding a single large KL workspace.
|
| 69 |
+
|
| 70 |
+
## Validation, Provenance, And Safety Checks
|
| 71 |
+
|
| 72 |
+
Several modules exist to prevent silent training corruption:
|
| 73 |
+
|
| 74 |
+
- `src/provenance.py`: Validates tokenizer contracts, vocab sizes, revisions, and teacher-logit metadata.
|
| 75 |
+
- `src/kd_contracts.py`: Builds deterministic tokenizer fingerprints.
|
| 76 |
+
- `src/training_schedule.py`: Aligns train/validation splits with batch and gradient-accumulation constraints.
|
| 77 |
+
- `src/checkpoints.py`: Saves model, tokenizer, scheduler, trainer state, and packing metadata; validates resume compatibility.
|
| 78 |
+
- `src/transformers_compat.py`: Resolves attention backend and formats model-loading errors.
|
| 79 |
+
|
| 80 |
+
## SFT Layer
|
| 81 |
+
|
| 82 |
+
The `sft/` directory contains the post-KD alignment layer:
|
| 83 |
+
|
| 84 |
+
- `sft/train_sft.py`: SFT training with optional sequence packing, LoRA/QLoRA paths, and built-in spot evaluations.
|
| 85 |
+
- `sft/evaluate.py`: EvalPlus and lm-evaluation-harness orchestration.
|
| 86 |
+
- `sft/chat.py`: Local interactive chat wrapper using the tokenizer chat template.
|
| 87 |
+
|
| 88 |
+
This stage is intentionally separate from KD. KD transfers the teacher's probability structure; SFT teaches the model how to expose that capability in the intended assistant format.
|
docs/benchmarks.md
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Benchmarks
|
| 2 |
+
|
| 3 |
+
The release scoreboard compares Qwen3-1.7B-Base, Qwen3-1.7B-Instruct, and Quintus-1.7B. Evaluations use a mixture of EvalPlus and lm-evaluation-harness style benchmarks, with greedy or deterministic settings where applicable.
|
| 4 |
+
|
| 5 |
+
For the detailed benchmark-control rules, see [Evaluation Methodology](evaluation_methodology.md).
|
| 6 |
+
|
| 7 |
+
## Final Scoreboard
|
| 8 |
+
|
| 9 |
+
| Benchmark | Qwen3-1.7B-Base | Qwen3-1.7B-Instruct | Quintus-1.7B |
|
| 10 |
+
| :--- | :---: | :---: | :---: |
|
| 11 |
+
| HumanEval pass@1 | 67.1% | 70.7% | 67.7% |
|
| 12 |
+
| MBPP pass@1 | 67.2% | 58.2% | 64.8% |
|
| 13 |
+
| GSM8K, 10-shot flexible | 69.98% | 69.75% | 74.30% |
|
| 14 |
+
| ARC-Challenge acc_norm | 55.72% | 52.99% | 58.36% |
|
| 15 |
+
| WinoGrande, 5-shot | 65.67% | 61.01% | 66.38% |
|
| 16 |
+
| PIQA acc_norm | 75.63% | 72.09% | 75.57% |
|
| 17 |
+
|
| 18 |
+
## Interpretation
|
| 19 |
+
|
| 20 |
+
The strongest result is the reasoning crossover: Quintus beats both the base and the official 1.7B instruct model on GSM8K, ARC-Challenge, and WinoGrande, despite remaining at the same parameter scale.
|
| 21 |
+
|
| 22 |
+
The coding picture is mixed but useful:
|
| 23 |
+
|
| 24 |
+
- HumanEval remains slightly below Qwen3-1.7B-Instruct.
|
| 25 |
+
- MBPP is substantially above Qwen3-1.7B-Instruct, though still below the base model.
|
| 26 |
+
|
| 27 |
+
This suggests the model gained useful instruction-following and reasoning behavior without fully matching larger or more heavily aligned code-specialized models.
|
| 28 |
+
|
| 29 |
+
## What The Benchmarks Support
|
| 30 |
+
|
| 31 |
+
These results support four claims:
|
| 32 |
+
|
| 33 |
+
1. Online KD transferred reasoning capability into a compact student.
|
| 34 |
+
2. The final model did not merely memorize assistant formatting; it improved several reasoning and commonsense metrics.
|
| 35 |
+
3. SFT helped expose the distilled capability in an assistant setting.
|
| 36 |
+
4. The model still has capacity limits typical of the 1.7B scale, especially on code execution reliability and long multi-step algorithm generation.
|
| 37 |
+
|
| 38 |
+
## Evaluation Caveats
|
| 39 |
+
|
| 40 |
+
Benchmark comparisons are sensitive to prompt format. Raw completion, chat-template generation, and log-likelihood multiple-choice scoring can produce different rankings. For fair interpretation:
|
| 41 |
+
|
| 42 |
+
- Compare raw models against raw models when measuring base reasoning.
|
| 43 |
+
- Compare chat-wrapped models against chat-wrapped models when measuring format alignment.
|
| 44 |
+
- Treat open-ended qualitative prompts as alignment tests, not as a replacement for standardized benchmarks.
|
| 45 |
+
|
| 46 |
+
Important implementation caveats:
|
| 47 |
+
|
| 48 |
+
- GSM8K extraction can differ between strict `####` parsing and flexible number extraction.
|
| 49 |
+
- Multiple-choice log-likelihood tasks can be distorted by chat templates.
|
| 50 |
+
- `acc_norm` is preferred when answer-option length bias can change the ranking.
|
| 51 |
+
- Metric extraction scripts must reject `stderr` and `alias` fields when looking for the actual score.
|
| 52 |
+
- Runtime versions should be recorded with benchmark outputs because harness behavior can change across releases.
|
| 53 |
+
|
| 54 |
+
## Earlier Development Signals
|
| 55 |
+
|
| 56 |
+
Before the final Qwen3 8B -> 1.7B run, earlier experiments showed that sparse offline top-k KD could not consistently outperform strong baselines. Those runs were useful because they identified the bottleneck: sparse cached teacher logits were not dense enough to transfer deeper reasoning pathways.
|
| 57 |
+
|
| 58 |
+
The final move to online full-vocabulary KD is the key methodological change behind the stronger final results.
|
docs/engineering_insights.md
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Engineering Insights
|
| 2 |
+
|
| 3 |
+
This project evolved through several failed and successful training designs. The useful lessons are summarized here as public engineering notes.
|
| 4 |
+
|
| 5 |
+
For expanded operational detail, see [Training Playbook](training_playbook.md), [Pipeline Hardening](pipeline_hardening.md), and [Evaluation Methodology](evaluation_methodology.md).
|
| 6 |
+
|
| 7 |
+
## 1. Sparse Offline KD Hit A Ceiling
|
| 8 |
+
|
| 9 |
+
The earliest distillation path cached only a small top-k slice of teacher logits. That made training cheaper, but it discarded most of the teacher distribution. With a vocabulary of roughly 151K tokens and $k = 8$, the visible support was:
|
| 10 |
+
|
| 11 |
+
$$
|
| 12 |
+
\frac{k}{|V|}
|
| 13 |
+
= \frac{8}{151{,}665}
|
| 14 |
+
\approx 5.3 \times 10^{-5}
|
| 15 |
+
= 0.0053\%
|
| 16 |
+
$$
|
| 17 |
+
|
| 18 |
+
The result was clear: top-k KD could perturb the student, but it did not transfer enough "dark knowledge" to reliably improve reasoning. Different alphas, epochs, and student initializations could not escape this sparse-signal ceiling.
|
| 19 |
+
|
| 20 |
+
The final fix was to use online KD: load teacher and student together, run both forward passes, and compute KL against the teacher's full vocabulary distribution.
|
| 21 |
+
|
| 22 |
+
## 2. Base Student Was Better Than Fighting An Aligned Space
|
| 23 |
+
|
| 24 |
+
Distilling into an already instruction-tuned student can cause destructive interference. The student's weights already encode one aligned behavior manifold, while the teacher's soft logits pull toward another. Training can look numerically stable while reasoning metrics regress.
|
| 25 |
+
|
| 26 |
+
The final path uses `Qwen/Qwen3-1.7B-Base` as the student. The base model has more plasticity, while the CE term and later SFT stage teach assistant formatting.
|
| 27 |
+
|
| 28 |
+
## 3. KD And Alignment Are Different Problems
|
| 29 |
+
|
| 30 |
+
Standardized benchmarks showed that KD can improve reasoning and calibration, but open-ended chat quality still needs alignment data.
|
| 31 |
+
|
| 32 |
+
The important diagnosis:
|
| 33 |
+
|
| 34 |
+
- A distillation failure means the student did not absorb the teacher's useful probability structure.
|
| 35 |
+
- An alignment gap means the student has capability, but the generation path is not yet trained to behave like a polished assistant.
|
| 36 |
+
|
| 37 |
+
The project therefore separates the pipeline into KD first, then SFT.
|
| 38 |
+
|
| 39 |
+
## 4. Assistant-Only Loss Masking Matters
|
| 40 |
+
|
| 41 |
+
A key bug class was assigning loss to chat formatting tokens instead of only assistant response content. If the model is trained to optimize structural tokens too heavily, it can learn formatting before substance.
|
| 42 |
+
|
| 43 |
+
The current tokenization path derives an assistant-only `loss_mask`, so:
|
| 44 |
+
|
| 45 |
+
- User prompts are context, not targets.
|
| 46 |
+
- Chat headers and separators are masked.
|
| 47 |
+
- Assistant response tokens are the only supervised targets.
|
| 48 |
+
|
| 49 |
+
This keeps training focused on semantic outputs rather than wrapper reproduction.
|
| 50 |
+
|
| 51 |
+
## 5. Sequence Packing Was The Main Throughput Win
|
| 52 |
+
|
| 53 |
+
The dataset contains many sequences shorter than the maximum context length. Dynamic padding wastes a large fraction of compute. First-fit decreasing sequence packing converted that waste into useful tokens.
|
| 54 |
+
|
| 55 |
+
Observed engineering outcome:
|
| 56 |
+
|
| 57 |
+
- Unpacked B200 online KD ran around the low-20K tokens/sec range in earlier probes.
|
| 58 |
+
- Packed B200 online KD reached roughly the mid-40K tokens/sec range after warmup.
|
| 59 |
+
- Packed utilization was close to full 4096-token bins.
|
| 60 |
+
|
| 61 |
+
The final code keeps packing deterministic and stores packing metadata in checkpoints so packed/unpacked resume mismatches fail loudly.
|
| 62 |
+
|
| 63 |
+
## 6. Full-Vocab KD Needed Token Chunking
|
| 64 |
+
|
| 65 |
+
Online KD preserves the full teacher distribution, but a full KL workspace at Qwen vocabulary scale is too large to materialize casually:
|
| 66 |
+
|
| 67 |
+
$$
|
| 68 |
+
\text{KL workspace} \sim \mathbb{R}^{B \times S \times |V|}
|
| 69 |
+
$$
|
| 70 |
+
|
| 71 |
+
The solution is token-dimension chunking. The current implementation uses:
|
| 72 |
+
|
| 73 |
+
$$
|
| 74 |
+
C_{\text{KD}} = 2048
|
| 75 |
+
$$
|
| 76 |
+
|
| 77 |
+
Larger chunks reduce loop overhead, but increase temporary memory pressure. The selected value is a practical B200-oriented balance for the 8B -> 1.7B workload.
|
| 78 |
+
|
| 79 |
+
## 7. Shape Churn And Synchronization Can Quietly Drain Throughput
|
| 80 |
+
|
| 81 |
+
Several performance bugs were not correctness bugs:
|
| 82 |
+
|
| 83 |
+
- Dynamic sequence lengths caused allocator churn.
|
| 84 |
+
- Repeated `.item()` calls forced CPU-GPU synchronization.
|
| 85 |
+
- Single-GPU DeepSpeed could add overhead when the model already fit comfortably.
|
| 86 |
+
- `torch.compile` added memory overhead, dynamic-shape graph breaks, recompile overhead, and checkpoint portability risk.
|
| 87 |
+
|
| 88 |
+
The final training loop favors stable shapes, fewer scalar syncs, fused AdamW when available, FlashAttention when available, and Liger kernels where they do not conflict with KD logits.
|
| 89 |
+
|
| 90 |
+
## 8. Evaluation Requires Controlled Comparisons
|
| 91 |
+
|
| 92 |
+
Raw completion and chat-template evaluation activate different behavior. A base model can perform well in raw mode and poorly under chat markup. A chat-aligned model can underperform on raw continuation-style tasks if the benchmark asks for direct option likelihoods.
|
| 93 |
+
|
| 94 |
+
The project uses both controls:
|
| 95 |
+
|
| 96 |
+
- Raw-to-raw comparisons isolate distilled base capability.
|
| 97 |
+
- Chat-to-chat comparisons estimate template robustness and assistant-format alignment.
|
| 98 |
+
|
| 99 |
+
This distinction avoids blaming KD for failures that belong to alignment or benchmark formatting.
|
| 100 |
+
|
| 101 |
+
## 9. Post-KD SFT Is Not Optional For Assistant Quality
|
| 102 |
+
|
| 103 |
+
KD transfers probability structure; it does not guarantee careful behavior, refusal policy, calibrated uncertainty, or code reliability. Targeted SFT was added to address:
|
| 104 |
+
|
| 105 |
+
- Confident hallucination in open-ended answers.
|
| 106 |
+
- Persona and identity consistency.
|
| 107 |
+
- Repetition loops.
|
| 108 |
+
- Chat-format stability.
|
| 109 |
+
- Practical assistant presentation.
|
| 110 |
+
|
| 111 |
+
Preference training or DPO would be the natural next layer if the project continues beyond the current release.
|
| 112 |
+
|
| 113 |
+
## 10. Training Loss Is Not The Release Gate
|
| 114 |
+
|
| 115 |
+
Several development runs looked numerically healthy while downstream benchmarks moved in the wrong direction. That pattern is expected when the training objective is only a proxy for the release objective.
|
| 116 |
+
|
| 117 |
+
Useful release gates:
|
| 118 |
+
|
| 119 |
+
- Standardized benchmarks.
|
| 120 |
+
- Raw and chat controls.
|
| 121 |
+
- Mismatch inspection.
|
| 122 |
+
- Qualitative prompts after benchmark checks.
|
| 123 |
+
- Weight and checkpoint structure audits.
|
| 124 |
+
|
| 125 |
+
Held-out KD validation loss is important, but it cannot prove that the model improved on math, code, multiple-choice reasoning, or assistant behavior.
|
| 126 |
+
|
| 127 |
+
## 11. Fail-Fast Beats Silent Recovery
|
| 128 |
+
|
| 129 |
+
The pipeline hardened around a simple rule: corrupt artifacts should stop the run.
|
| 130 |
+
|
| 131 |
+
Examples:
|
| 132 |
+
|
| 133 |
+
- Missing teacher-logit shards fail instead of becoming zero tensors.
|
| 134 |
+
- Tokenization with zero usable rows fails immediately.
|
| 135 |
+
- Shard schema mismatches are rejected.
|
| 136 |
+
- Packed/unpacked checkpoint resume mismatches are rejected.
|
| 137 |
+
- Stale evaluation outputs are cleaned before new scores are written.
|
| 138 |
+
|
| 139 |
+
This makes errors louder, but it keeps published numbers trustworthy.
|
| 140 |
+
|
| 141 |
+
## 12. Public Docs Should Preserve Decisions
|
| 142 |
+
|
| 143 |
+
A release-quality project should expose durable engineering conclusions:
|
| 144 |
+
|
| 145 |
+
- why online KD replaced offline top-k KD,
|
| 146 |
+
- why assistant-only masking matters,
|
| 147 |
+
- why raw/chat evaluation controls are required,
|
| 148 |
+
- why sequence packing changed throughput,
|
| 149 |
+
- why SFT remains necessary after KD,
|
| 150 |
+
- why checkpoint and provenance checks exist.
|
| 151 |
+
|
| 152 |
+
That level of detail is enough for technical readers without turning the documentation into a chronological run journal.
|
docs/evaluation_methodology.md
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Evaluation Methodology
|
| 2 |
+
|
| 3 |
+
Evaluation was one of the hardest parts of Quintus. Several early scores were misleading until prompt format, metric extraction, parser behavior, and runtime artifacts were audited carefully.
|
| 4 |
+
|
| 5 |
+
## Evaluation Principle
|
| 6 |
+
|
| 7 |
+
A model comparison is only meaningful when the prompt format and metric path match the question being asked.
|
| 8 |
+
|
| 9 |
+
Two distinct questions matter:
|
| 10 |
+
|
| 11 |
+
- Base capability: Does the distilled model improve raw reasoning and likelihood behavior?
|
| 12 |
+
- Assistant behavior: Does the distilled model handle chat formatting and produce usable responses?
|
| 13 |
+
|
| 14 |
+
Those questions need separate controls.
|
| 15 |
+
|
| 16 |
+
## Run Identity And Determinism
|
| 17 |
+
|
| 18 |
+
A benchmark record should identify the checkpoint role (`best` or `last`), exact model directory or revision, prompt mode, seeds, decoding mode, and runtime versions. Greedy decoding with fixed seeds makes repeated runs easier to compare, but hardware and kernel drift can still change edge cases.
|
| 19 |
+
|
| 20 |
+
Treat determinism as an artifact contract, not a vague claim.
|
| 21 |
+
|
| 22 |
+
## Raw-To-Raw And Chat-To-Chat
|
| 23 |
+
|
| 24 |
+
Raw completion and chat-template prompting activate different model behavior. A base model can be strong in raw mode and weak under chat markup. An instruct model can be strong in chat, but weak on raw continuation-style likelihood tasks.
|
| 25 |
+
|
| 26 |
+
Recommended controls:
|
| 27 |
+
|
| 28 |
+
- Raw-to-raw: compare base-style prompts against base-style prompts.
|
| 29 |
+
- Chat-to-chat: compare chat-wrapped prompts against chat-wrapped prompts.
|
| 30 |
+
- Raw-vs-chat within the same model: measure format tax.
|
| 31 |
+
|
| 32 |
+
Avoid comparing a chat-wrapped distilled model directly against a raw base baseline and treating the delta as pure capability transfer.
|
| 33 |
+
|
| 34 |
+
## Log-Likelihood Tasks Should Usually Stay Raw
|
| 35 |
+
|
| 36 |
+
Multiple-choice tasks such as ARC-Challenge, HellaSwag, and PIQA often score options by likelihood:
|
| 37 |
+
|
| 38 |
+
$$
|
| 39 |
+
P(\text{option}\mid\text{prompt})
|
| 40 |
+
$$
|
| 41 |
+
|
| 42 |
+
Wrapping the prompt in chat markup changes the next-token distribution. An aligned model may not want to begin a response with a bare option string after `<|im_start|>assistant`, so option likelihoods can fall for formatting reasons rather than reasoning reasons.
|
| 43 |
+
|
| 44 |
+
For log-likelihood tasks:
|
| 45 |
+
|
| 46 |
+
- Use raw completion format unless the benchmark was designed for chat.
|
| 47 |
+
- Prefer `acc_norm` where length bias matters.
|
| 48 |
+
- Record whether chat templates were applied.
|
| 49 |
+
|
| 50 |
+
## GSM8K Parser Traps
|
| 51 |
+
|
| 52 |
+
GSM8K evaluation can be distorted by parser behavior.
|
| 53 |
+
|
| 54 |
+
Two common filters behave differently:
|
| 55 |
+
|
| 56 |
+
- `strict-match`: looks for an answer after the `####` delimiter.
|
| 57 |
+
- `flexible-extract`: searches for numbers and may choose the last matched number.
|
| 58 |
+
|
| 59 |
+
A chat model can solve the problem, emit the correct `####` answer, miss EOS, and continue into a hallucinated next dialogue turn containing another number. In that case:
|
| 60 |
+
|
| 61 |
+
- `strict-match` may score the response correct.
|
| 62 |
+
- `flexible-extract` may grab the later hallucinated number and score it wrong.
|
| 63 |
+
|
| 64 |
+
This is not just a parser detail. It reveals an EOS and prompt-format interaction.
|
| 65 |
+
|
| 66 |
+
Mitigations:
|
| 67 |
+
|
| 68 |
+
- Register all relevant EOS tokens, including `<|im_end|>` and `<|endoftext|>`.
|
| 69 |
+
- Use deterministic generation for benchmark runs.
|
| 70 |
+
- Avoid excessive `fewshot_as_multiturn` wrapping unless the model was trained for that shape.
|
| 71 |
+
- Inspect mismatches, not just aggregate scores.
|
| 72 |
+
|
| 73 |
+
## Reasoning Models Need Enough Generation Budget
|
| 74 |
+
|
| 75 |
+
Instruction-tuned reasoning models may spend hundreds of tokens inside a reasoning trace before reaching the final answer. If `max_new_tokens` is too small, the model can be cut off before emitting the final answer marker.
|
| 76 |
+
|
| 77 |
+
That can make a capable model appear weak under exact-match metrics.
|
| 78 |
+
|
| 79 |
+
For fair GSM8K-style generation:
|
| 80 |
+
|
| 81 |
+
- Set a sufficient generation limit.
|
| 82 |
+
- Track truncation rate.
|
| 83 |
+
- Compare extracted answers against raw responses during audits.
|
| 84 |
+
|
| 85 |
+
## Batched Generation Details
|
| 86 |
+
|
| 87 |
+
Decoder-only batched generation should use left-padding. Right-padding can put the next-token position on padding for shorter prompts and make batched outputs differ from single-sample outputs.
|
| 88 |
+
|
| 89 |
+
Generation parsers should:
|
| 90 |
+
|
| 91 |
+
- Set `tokenizer.padding_side = "left"` for batched generation.
|
| 92 |
+
- Slice decoded continuations by each prompt's true input length.
|
| 93 |
+
- Stop at the first registered EOS token.
|
| 94 |
+
- Record truncation and empty-generation counts.
|
| 95 |
+
|
| 96 |
+
## English-Only Evaluation Controls
|
| 97 |
+
|
| 98 |
+
For English-only release checks, filtering the dataset is necessary but not sufficient. Evaluation should also use an English-only system instruction when chat prompts are enabled, register all relevant EOS IDs, and clean generated artifacts that continue into another language after the intended answer.
|
| 99 |
+
|
| 100 |
+
This cleanup is an evaluation-artifact guard. It is not a substitute for training data quality, SFT, preference tuning, or behavioral calibration.
|
| 101 |
+
|
| 102 |
+
## Metric Extraction Must Be Strict
|
| 103 |
+
|
| 104 |
+
Post-processing scripts should never fall back loosely to any metric key that starts with the right prefix. A loose fallback can accidentally read:
|
| 105 |
+
|
| 106 |
+
- `*_stderr`
|
| 107 |
+
- `alias`
|
| 108 |
+
- a different filter result
|
| 109 |
+
|
| 110 |
+
Robust extraction should:
|
| 111 |
+
|
| 112 |
+
- Match the exact metric and filter name.
|
| 113 |
+
- Ignore stderr and alias fields when extracting scores.
|
| 114 |
+
- Fail loudly if the expected key is absent.
|
| 115 |
+
|
| 116 |
+
## Boolean CLI Flags
|
| 117 |
+
|
| 118 |
+
Some harness flags use `action="store_true"`. Passing `"False"` after such a flag does not disable it; the presence of the flag enables it.
|
| 119 |
+
|
| 120 |
+
Correct pattern:
|
| 121 |
+
|
| 122 |
+
- Include the flag only when true.
|
| 123 |
+
- Omit the flag when false.
|
| 124 |
+
|
| 125 |
+
This matters for options such as multiturn few-shot formatting.
|
| 126 |
+
|
| 127 |
+
## Sample Log Format
|
| 128 |
+
|
| 129 |
+
`lm-evaluation-harness` may log different filters for the same document as separate JSONL objects with the same `doc_id`. A parser that assumes one object contains all filters can crash or silently compare the wrong fields.
|
| 130 |
+
|
| 131 |
+
Correct approach:
|
| 132 |
+
|
| 133 |
+
- Group sample records by `doc_id`.
|
| 134 |
+
- Index filter-specific records inside each group.
|
| 135 |
+
- Compare strict and flexible outputs from the same document.
|
| 136 |
+
|
| 137 |
+
## JSONL Parsing With Unicode Line Separators
|
| 138 |
+
|
| 139 |
+
Model outputs can contain Unicode line separator characters such as `\u2028` or `\u2029`. Calling `str.splitlines()` on a whole JSONL file can split a valid JSON string into invalid fragments.
|
| 140 |
+
|
| 141 |
+
Robust JSONL parsing:
|
| 142 |
+
|
| 143 |
+
```python
|
| 144 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 145 |
+
for line in f:
|
| 146 |
+
if line.strip():
|
| 147 |
+
record = json.loads(line)
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
Iterating the file handle respects actual line endings and does not split on Unicode separators inside JSON strings.
|
| 151 |
+
|
| 152 |
+
## Hub Loading And Snapshot Hygiene
|
| 153 |
+
|
| 154 |
+
If weights or datasets are stored on the Hub, the client should be told the correct repository type. Download or snapshot the artifact first, verify that expected files exist, then pass the local directory to Transformers, vLLM, or the evaluation harness.
|
| 155 |
+
|
| 156 |
+
This separates transfer failures from engine construction and avoids repeated downloads during long benchmark runs.
|
| 157 |
+
|
| 158 |
+
Optional high-throughput Hub transfer backends such as `hf_transfer` can reduce setup time, but the correctness contract is still local snapshot validation.
|
| 159 |
+
|
| 160 |
+
## Path-Length And Output Artifacts
|
| 161 |
+
|
| 162 |
+
Evaluation tools can derive output paths from model paths. Deep Hugging Face cache paths can become extremely long after sanitization, especially on Windows.
|
| 163 |
+
|
| 164 |
+
Public guidance:
|
| 165 |
+
|
| 166 |
+
- Copy or symlink model weights to a short local directory before evaluation.
|
| 167 |
+
- Pass short relative paths to the evaluator.
|
| 168 |
+
- Keep result directories shallow.
|
| 169 |
+
- Fail if expected sample files are missing.
|
| 170 |
+
|
| 171 |
+
This prevents silent write failures and missing-output confusion.
|
| 172 |
+
|
| 173 |
+
## vLLM Evaluation Settings
|
| 174 |
+
|
| 175 |
+
For large benchmark runs, vLLM can greatly reduce runtime through continuous batching and KV-cache management.
|
| 176 |
+
|
| 177 |
+
Useful settings in development:
|
| 178 |
+
|
| 179 |
+
- `batch_size = auto`
|
| 180 |
+
- prefix caching enabled
|
| 181 |
+
- PagedAttention-backed KV-cache management when available
|
| 182 |
+
- bounded GPU memory utilization
|
| 183 |
+
- explicit `max_model_len` where context bounds matter
|
| 184 |
+
- explicit attention backend where the runtime supports it
|
| 185 |
+
- local pre-caching of model snapshots before engine construction
|
| 186 |
+
- explicit engine teardown between model runs
|
| 187 |
+
|
| 188 |
+
The benchmark artifact should record runtime versions for:
|
| 189 |
+
|
| 190 |
+
- `lm-eval`
|
| 191 |
+
- `vllm`
|
| 192 |
+
- `transformers`
|
| 193 |
+
- `torch`
|
| 194 |
+
- `datasets`
|
| 195 |
+
- `accelerate`
|
| 196 |
+
|
| 197 |
+
Version drift can change metric keys, generation behavior, attention backends, and output formats.
|
| 198 |
+
|
| 199 |
+
## Qualitative Evaluation
|
| 200 |
+
|
| 201 |
+
Open-ended prompt suites are useful, but they are not replacements for standardized benchmarks.
|
| 202 |
+
|
| 203 |
+
A good qualitative suite should:
|
| 204 |
+
|
| 205 |
+
- Compare raw and chat modes separately.
|
| 206 |
+
- Use fixed prompts and deterministic ordering.
|
| 207 |
+
- Include benchmark-template leakage probes.
|
| 208 |
+
- Include factual, math, code, system design, and LLM-internals prompts.
|
| 209 |
+
- Record complete outputs.
|
| 210 |
+
- Inspect inherited base-model errors separately from new chat-mode errors.
|
| 211 |
+
|
| 212 |
+
Qualitative failures should be classified:
|
| 213 |
+
|
| 214 |
+
- Distillation failure: the student did not absorb useful teacher probability structure.
|
| 215 |
+
- Alignment gap: capability exists, but the generation path lacks SFT, preference tuning, or calibration.
|
| 216 |
+
- Data contamination: the model repeats benchmark or pretraining artifacts.
|
| 217 |
+
- Code reliability gap: prose is correct, but generated code violates stated constraints.
|
| 218 |
+
|
| 219 |
+
This distinction prevents the wrong fix. Distillation failures need KD changes. Alignment gaps need SFT, DPO, RLHF, or curated behavior data.
|
| 220 |
+
|
| 221 |
+
## Release Gate
|
| 222 |
+
|
| 223 |
+
The final checkpoint should pass all of these before public claims are made:
|
| 224 |
+
|
| 225 |
+
- Benchmark tasks use the intended prompt format.
|
| 226 |
+
- Metric keys are exact.
|
| 227 |
+
- Sample counts match the full benchmark set.
|
| 228 |
+
- Raw and chat comparisons are not mixed.
|
| 229 |
+
- Generation limits are sufficient for the model style.
|
| 230 |
+
- Checkpoint identity is explicit.
|
| 231 |
+
- Missing requested checkpoints fail instead of falling back to older local weights.
|
| 232 |
+
- Runtime versions are recorded.
|
| 233 |
+
- Mismatch samples are inspected for parser artifacts.
|
| 234 |
+
- No stale result directory or old JSON file is reused.
|
docs/experiment_timeline.md
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Experiment Timeline
|
| 2 |
+
|
| 3 |
+
This timeline explains why the final Quintus design looks the way it does. It focuses on the technical evolution from sparse offline distillation to the final online full-vocabulary pipeline.
|
| 4 |
+
|
| 5 |
+
## 1. Offline Top-K KD Prototype
|
| 6 |
+
|
| 7 |
+
The earliest design precomputed teacher logits to disk and trained the student from cached top-k supports.
|
| 8 |
+
|
| 9 |
+
Why it was attractive:
|
| 10 |
+
|
| 11 |
+
- Avoided loading teacher and student together.
|
| 12 |
+
- Reduced KD memory from full vocabulary to top-k support.
|
| 13 |
+
- Made cloud interruptions easier to survive because teacher logits were already saved.
|
| 14 |
+
|
| 15 |
+
Main lessons:
|
| 16 |
+
|
| 17 |
+
- Serialization contracts matter as much as loss math.
|
| 18 |
+
- Top-k token IDs need safe dtypes.
|
| 19 |
+
- Teacher-logit shards must preserve original row order.
|
| 20 |
+
- Missing or stale shards should fail loudly.
|
| 21 |
+
|
| 22 |
+
## 2. Static Audit And Fail-Fast Hardening
|
| 23 |
+
|
| 24 |
+
The project then moved through a static-audit phase focused on silent failure modes.
|
| 25 |
+
|
| 26 |
+
Major hardening themes:
|
| 27 |
+
|
| 28 |
+
- Dataset zero-retention checks.
|
| 29 |
+
- Missing-shard hard failures.
|
| 30 |
+
- Stale artifact cleanup.
|
| 31 |
+
- DeepSpeed accumulation correctness.
|
| 32 |
+
- Rank-safe writes.
|
| 33 |
+
- Explicit model revision and remote-code policy.
|
| 34 |
+
- Stronger provenance metadata.
|
| 35 |
+
|
| 36 |
+
This phase turned the code from a script bundle into a more reliable training pipeline.
|
| 37 |
+
|
| 38 |
+
## 3. Assistant-Only Supervision
|
| 39 |
+
|
| 40 |
+
The tokenization path originally risked supervising the whole conversation. That can over-train prompts, headers, and formatting tokens.
|
| 41 |
+
|
| 42 |
+
The corrected path derives `loss_mask` and trains only on assistant response tokens.
|
| 43 |
+
|
| 44 |
+
This changed the training contract:
|
| 45 |
+
|
| 46 |
+
- Prompt tokens provide context.
|
| 47 |
+
- Assistant tokens receive CE and KD loss.
|
| 48 |
+
- Rows without assistant targets are rejected.
|
| 49 |
+
- Checkpoints and datasets must agree on the mask schema.
|
| 50 |
+
|
| 51 |
+
## 4. Top-K Plus Residual Bucket
|
| 52 |
+
|
| 53 |
+
A later offline-KD pass improved the sparse support by adding an "other" bucket for teacher probability mass outside top-k.
|
| 54 |
+
|
| 55 |
+
This fixed a mathematical weakness: the student should be normalized against the full vocabulary before comparison, not only inside top-k. The residual bucket made offline KD less wrong, but it still compressed most of the teacher distribution into one scalar.
|
| 56 |
+
|
| 57 |
+
That design was useful, but not enough for flagship results.
|
| 58 |
+
|
| 59 |
+
## 5. Dataset And Objective Mismatch
|
| 60 |
+
|
| 61 |
+
Smoke runs showed a pattern that became important later: held-out KD validation loss can improve while benchmark quality worsens.
|
| 62 |
+
|
| 63 |
+
Key diagnosis:
|
| 64 |
+
|
| 65 |
+
- Matching teacher token distributions on a training corpus is not identical to improving GSM8K, ARC, coding, or open-ended assistant quality.
|
| 66 |
+
- Dataset order and first-N streaming can bias sample selection.
|
| 67 |
+
- Long reasoning traces can overweight style and process tokens relative to final answers.
|
| 68 |
+
- Small students can forget useful baseline behavior when full-parameter training is too aggressive.
|
| 69 |
+
|
| 70 |
+
This motivated stricter downstream evaluation gates.
|
| 71 |
+
|
| 72 |
+
## 6. Base Student Pivot
|
| 73 |
+
|
| 74 |
+
Several runs tested whether distilling into an already-instruct-tuned student caused destructive interference. The base-student hypothesis was sound: a raw base model has more plasticity and fewer alignment paths to overwrite.
|
| 75 |
+
|
| 76 |
+
The result was only a marginal improvement under offline top-k KD. That was the decisive clue.
|
| 77 |
+
|
| 78 |
+
Conclusion:
|
| 79 |
+
|
| 80 |
+
The student choice was not the main bottleneck. Offline top-k sparsity was the main bottleneck.
|
| 81 |
+
|
| 82 |
+
## 7. Offline Top-K Ceiling
|
| 83 |
+
|
| 84 |
+
With $k = 8$, the student saw only a tiny fraction of the teacher vocabulary distribution per target token:
|
| 85 |
+
|
| 86 |
+
$$
|
| 87 |
+
\frac{k}{|V|}
|
| 88 |
+
= \frac{8}{151{,}665}
|
| 89 |
+
\approx 5.3 \times 10^{-5}
|
| 90 |
+
= 0.0053\%
|
| 91 |
+
$$
|
| 92 |
+
|
| 93 |
+
Different $\alpha$ values, epochs, and student initializations did not remove this limit.
|
| 94 |
+
|
| 95 |
+
Offline top-k KD could perturb the student and sometimes improve narrow metrics, but it could not reliably transfer the teacher's broader reasoning distribution.
|
| 96 |
+
|
| 97 |
+
The project stopped treating offline top-k KD as the path to a flagship model.
|
| 98 |
+
|
| 99 |
+
## 8. Online Full-Vocabulary KD
|
| 100 |
+
|
| 101 |
+

|
| 102 |
+
|
| 103 |
+
Online KD became the final architecture.
|
| 104 |
+
|
| 105 |
+
Instead of reading cached teacher shards, the training loop loads a frozen teacher and runs live teacher forward passes beside the student. The KD loss uses the teacher's full-vocabulary distribution.
|
| 106 |
+
|
| 107 |
+
Benefits:
|
| 108 |
+
|
| 109 |
+
- No top-k sparsity ceiling.
|
| 110 |
+
- No shard-order mismatch risk.
|
| 111 |
+
- No stale teacher-logit cache.
|
| 112 |
+
- Stronger transfer signal for reasoning.
|
| 113 |
+
|
| 114 |
+
Cost:
|
| 115 |
+
|
| 116 |
+
- Higher VRAM footprint.
|
| 117 |
+
- Teacher and student must fit together.
|
| 118 |
+
- KL computation needs chunking.
|
| 119 |
+
- Throughput depends heavily on packing and kernels.
|
| 120 |
+
|
| 121 |
+
## 9. Sequence Packing And B200 Tuning
|
| 122 |
+
|
| 123 |
+
Sequence packing converted padding waste into useful tokens.
|
| 124 |
+
|
| 125 |
+
The packing implementation:
|
| 126 |
+
|
| 127 |
+
- Packs only training data.
|
| 128 |
+
- Keeps validation easier to interpret.
|
| 129 |
+
- Uses fixed 4096-token bins.
|
| 130 |
+
- Inserts masked EOS separators.
|
| 131 |
+
- Stores packing metadata in checkpoints.
|
| 132 |
+
- Rejects packed/unpacked resume mismatches.
|
| 133 |
+
|
| 134 |
+
Development probes showed the expected utilization improvement and made online KD fast enough for serious single-GPU runs.
|
| 135 |
+
|
| 136 |
+
## 10. English-Only Final Data
|
| 137 |
+
|
| 138 |
+
The release run focuses on English samples.
|
| 139 |
+
|
| 140 |
+
Reasons:
|
| 141 |
+
|
| 142 |
+
- Reduce language drift in open-ended outputs.
|
| 143 |
+
- Keep the model's assistant behavior aligned with the intended release language.
|
| 144 |
+
- Make qualitative evaluation cleaner.
|
| 145 |
+
- Avoid CJK continuation artifacts after missed EOS.
|
| 146 |
+
|
| 147 |
+
The tradeoff is real: removing multilingual data can reduce access to some reasoning traces. For a public English assistant, language stability is worth that tradeoff.
|
| 148 |
+
|
| 149 |
+
## 11. Targeted SFT After KD
|
| 150 |
+
|
| 151 |
+
Online KD transferred capability, but raw KD is not a full assistant-alignment process.
|
| 152 |
+
|
| 153 |
+
Targeted SFT was added after KD to improve:
|
| 154 |
+
|
| 155 |
+
- identity grounding,
|
| 156 |
+
- chat format stability,
|
| 157 |
+
- practical assistant style,
|
| 158 |
+
- repetition control,
|
| 159 |
+
- response presentation.
|
| 160 |
+
|
| 161 |
+
This created the final two-stage public model:
|
| 162 |
+
|
| 163 |
+
```text
|
| 164 |
+
Qwen3-1.7B-Base
|
| 165 |
+
-> online full-vocab KD from Qwen3-8B
|
| 166 |
+
-> targeted SFT
|
| 167 |
+
-> Quintus-1.7B
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
## 12. Release Verification
|
| 171 |
+
|
| 172 |
+
The final release surface combines:
|
| 173 |
+
|
| 174 |
+
- benchmark scoreboard,
|
| 175 |
+
- architecture documentation,
|
| 176 |
+
- evaluation methodology notes,
|
| 177 |
+
- pipeline hardening notes,
|
| 178 |
+
- weight audit,
|
| 179 |
+
- model-card draft.
|
| 180 |
+
|
| 181 |
+
The public docs focus on reusable methods, release results, and reproducible checks.
|
docs/huggingface_model_card.md
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Quintus-1.7B
|
| 2 |
+
|
| 3 |
+
Quintus-1.7B is a compact instruction-following assistant derived from `Qwen/Qwen3-1.7B-Base`. It was trained with online full-vocabulary knowledge distillation from a larger Qwen3-8B teacher, followed by targeted SFT for assistant behavior and generation stability.
|
| 4 |
+
|
| 5 |
+
## Model Details
|
| 6 |
+
|
| 7 |
+
- Base architecture: Qwen3-1.7B
|
| 8 |
+
- Base checkpoint: `Qwen/Qwen3-1.7B-Base`
|
| 9 |
+
- Distillation teacher: Qwen3-8B class teacher
|
| 10 |
+
- Training method: Online full-vocabulary KD + targeted SFT
|
| 11 |
+
- Context length used in training: 4096 tokens
|
| 12 |
+
- Primary language focus: English
|
| 13 |
+
- Release repository: `iamrahulreddy/Quintus`
|
| 14 |
+
- Attention path: FlashAttention-2 when available
|
| 15 |
+
- Training kernels: Liger kernels for compatible Qwen-family operators
|
| 16 |
+
- Optimizer: fused AdamW
|
| 17 |
+
|
| 18 |
+
## Intended Use
|
| 19 |
+
|
| 20 |
+
Quintus is intended for:
|
| 21 |
+
|
| 22 |
+
- General assistant use.
|
| 23 |
+
- Reasoning and math prompts.
|
| 24 |
+
- Lightweight coding assistance.
|
| 25 |
+
- Local experimentation with compact LLMs.
|
| 26 |
+
- Research into online KD and small-model alignment.
|
| 27 |
+
|
| 28 |
+
It is not intended as a safety-critical decision system. Like other compact language models, it can hallucinate and should be verified on high-stakes tasks.
|
| 29 |
+
|
| 30 |
+
## Training Summary
|
| 31 |
+
|
| 32 |
+
The training pipeline has two main stages:
|
| 33 |
+
|
| 34 |
+
1. Online KD: The student learns from the teacher's dense full-vocabulary probability distribution. This avoids the sparse top-k ceiling encountered in earlier offline KD experiments.
|
| 35 |
+
2. SFT: The distilled checkpoint is tuned on curated instruction/persona data to improve assistant-style behavior and reduce repetition or formatting drift.
|
| 36 |
+
|
| 37 |
+
The KD loss combines assistant-token cross entropy and teacher-student KL divergence:
|
| 38 |
+
|
| 39 |
+
$$
|
| 40 |
+
\mathcal{L}_{\text{total}}
|
| 41 |
+
= \alpha \mathcal{L}_{\text{CE}}
|
| 42 |
+
+ (1 - \alpha)\mathcal{L}_{\text{KD}}
|
| 43 |
+
$$
|
| 44 |
+
|
| 45 |
+
For the release run, $\alpha = 0.3$ and $T = 2.0$.
|
| 46 |
+
|
| 47 |
+
`torch.compile` was kept disabled for the final KD path because this workload showed high Inductor memory overhead, dynamic-shape graph breaks, recompile overhead, and checkpoint portability risk from `_orig_mod.` state-dict prefixes when compiled modules are not unwrapped before saving.
|
| 48 |
+
|
| 49 |
+
## Evaluation
|
| 50 |
+
|
| 51 |
+
| Benchmark | Qwen3-1.7B-Base | Qwen3-1.7B-Instruct | Quintus-1.7B |
|
| 52 |
+
| :--- | :---: | :---: | :---: |
|
| 53 |
+
| HumanEval pass@1 | 67.1% | 70.7% | 67.7% |
|
| 54 |
+
| MBPP pass@1 | 67.2% | 58.2% | 64.8% |
|
| 55 |
+
| GSM8K, 10-shot flexible | 69.98% | 69.75% | 74.30% |
|
| 56 |
+
| ARC-Challenge acc_norm | 55.72% | 52.99% | 58.36% |
|
| 57 |
+
| WinoGrande, 5-shot | 65.67% | 61.01% | 66.38% |
|
| 58 |
+
| PIQA acc_norm | 75.63% | 72.09% | 75.57% |
|
| 59 |
+
|
| 60 |
+
## Strengths
|
| 61 |
+
|
| 62 |
+
- Strong math and reasoning transfer for the 1.7B parameter scale.
|
| 63 |
+
- Good commonsense and ARC-style benchmark performance.
|
| 64 |
+
- Compact enough for lower-resource deployment compared with larger teachers.
|
| 65 |
+
- Public weight audit indicates healthy structural divergence from the base checkpoint without collapse.
|
| 66 |
+
|
| 67 |
+
## Limitations
|
| 68 |
+
|
| 69 |
+
- The model can still produce confident factual errors.
|
| 70 |
+
- Code generation can contradict stated complexity constraints.
|
| 71 |
+
- It is smaller than the teacher and inherits capacity limits of the 1.7B scale.
|
| 72 |
+
- Evaluation results depend on prompt format; raw and chat-template modes are not interchangeable.
|
| 73 |
+
- Additional preference tuning would likely improve calibration and refusal behavior.
|
| 74 |
+
|
| 75 |
+
## Example Usage
|
| 76 |
+
|
| 77 |
+
```python
|
| 78 |
+
import torch
|
| 79 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
| 80 |
+
|
| 81 |
+
PUBLIC_REPO_ID = "iamrahulreddy/Quintus"
|
| 82 |
+
|
| 83 |
+
print(f"Loading Quintus from {PUBLIC_REPO_ID}...")
|
| 84 |
+
tokenizer = AutoTokenizer.from_pretrained(PUBLIC_REPO_ID, trust_remote_code=True)
|
| 85 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 86 |
+
PUBLIC_REPO_ID,
|
| 87 |
+
device_map="auto",
|
| 88 |
+
dtype=torch.float16,
|
| 89 |
+
trust_remote_code=True,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
stop_tokens = ["<|endoftext|>", "<|im_end|>"]
|
| 93 |
+
eos_token_ids = [tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else []
|
| 94 |
+
for token in stop_tokens:
|
| 95 |
+
token_id = tokenizer.convert_tokens_to_ids(token)
|
| 96 |
+
if token_id is not None and token_id not in eos_token_ids:
|
| 97 |
+
eos_token_ids.append(token_id)
|
| 98 |
+
|
| 99 |
+
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 100 |
+
|
| 101 |
+
conversation_history = [
|
| 102 |
+
{
|
| 103 |
+
"role": "system",
|
| 104 |
+
"content": (
|
| 105 |
+
"You are Quintus, a highly capable AI assistant created by "
|
| 106 |
+
"Muskula Rahul. You are helpful, precise, and logically sound."
|
| 107 |
+
),
|
| 108 |
+
}
|
| 109 |
+
]
|
| 110 |
+
|
| 111 |
+
print()
|
| 112 |
+
print("Quintus Chat (type 'quit' to exit)")
|
| 113 |
+
print()
|
| 114 |
+
|
| 115 |
+
while True:
|
| 116 |
+
try:
|
| 117 |
+
user_input = input("You: ").strip()
|
| 118 |
+
if user_input.lower() in ["quit", "exit"]:
|
| 119 |
+
print("\nGoodbye!")
|
| 120 |
+
break
|
| 121 |
+
if not user_input:
|
| 122 |
+
continue
|
| 123 |
+
|
| 124 |
+
conversation_history.append({"role": "user", "content": user_input})
|
| 125 |
+
|
| 126 |
+
prompt = tokenizer.apply_chat_template(
|
| 127 |
+
conversation_history,
|
| 128 |
+
tokenize=False,
|
| 129 |
+
add_generation_prompt=True,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 133 |
+
|
| 134 |
+
print("Quintus: ", end="", flush=True)
|
| 135 |
+
|
| 136 |
+
with torch.no_grad():
|
| 137 |
+
outputs = model.generate(
|
| 138 |
+
**inputs,
|
| 139 |
+
max_new_tokens=512,
|
| 140 |
+
temperature=0.7,
|
| 141 |
+
top_p=0.9,
|
| 142 |
+
do_sample=True,
|
| 143 |
+
streamer=streamer,
|
| 144 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 145 |
+
eos_token_id=eos_token_ids,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
generated_ids = outputs[0][inputs.input_ids.shape[-1]:]
|
| 149 |
+
assistant_response = tokenizer.decode(
|
| 150 |
+
generated_ids,
|
| 151 |
+
skip_special_tokens=True,
|
| 152 |
+
).strip()
|
| 153 |
+
conversation_history.append({"role": "assistant", "content": assistant_response})
|
| 154 |
+
print()
|
| 155 |
+
|
| 156 |
+
except KeyboardInterrupt:
|
| 157 |
+
print("\n\nGoodbye!")
|
| 158 |
+
break
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
## Credits
|
| 162 |
+
|
| 163 |
+
- [Qwen Team](https://qwenlm.github.io/) and the [Qwen Hugging Face organization](https://huggingface.co/Qwen) for the Qwen3 model family.
|
| 164 |
+
- [`Qwen/Qwen3-8B`](https://huggingface.co/Qwen/Qwen3-8B), used as the distillation teacher.
|
| 165 |
+
- [`Qwen/Qwen3-1.7B-Base`](https://huggingface.co/Qwen/Qwen3-1.7B-Base), used as the base student checkpoint.
|
| 166 |
+
- [`Qwen/Qwen3-1.7B`](https://huggingface.co/Qwen/Qwen3-1.7B), used for the tokenizer and chat-template contract.
|
| 167 |
+
- [Alibaba PAI](https://huggingface.co/alibaba-pai) for [`DistilQwen_100k`](https://huggingface.co/datasets/alibaba-pai/DistilQwen_100k), the primary instruction source after filtering.
|
| 168 |
+
- [Hugging Face Transformers](https://github.com/huggingface/transformers), [vLLM](https://github.com/vllm-project/vllm), [EvalPlus](https://github.com/evalplus/evalplus), [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness), [FlashAttention](https://github.com/Dao-AILab/flash-attention), and [Liger Kernel](https://github.com/linkedin/Liger-Kernel) for training and evaluation infrastructure.
|
| 169 |
+
|
| 170 |
+
## License And Author
|
| 171 |
+
|
| 172 |
+
This software is distributed under the MIT License. Refer to the repository [LICENSE](../LICENSE) file for full text.
|
| 173 |
+
|
| 174 |
+
Author: Muskula Rahul - [@iamrahulreddy](https://github.com/iamrahulreddy)
|
| 175 |
+
|
| 176 |
+
## Citation
|
| 177 |
+
|
| 178 |
+
If you use this model or code, cite the repository and the upstream Qwen3 models.
|
docs/index.md
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Quintus Documentation
|
| 2 |
+
|
| 3 |
+
Quintus-1.7B is a compact assistant built from the Qwen3-1.7B-Base architecture. The project uses online full-vocabulary knowledge distillation from a Qwen3-8B teacher, followed by targeted SFT for instruction style, identity grounding, and generation stability.
|
| 4 |
+
|
| 5 |
+
This documentation summarizes the public architecture, training decisions, evaluation controls, and release artifacts for the showcase branch.
|
| 6 |
+
|
| 7 |
+
## Reading Order
|
| 8 |
+
|
| 9 |
+
- [Architecture](architecture.md): End-to-end pipeline, modules, data flow, and training phases.
|
| 10 |
+
- [Experiment Timeline](experiment_timeline.md): How the project moved from offline top-k KD to final online full-vocabulary KD.
|
| 11 |
+
- [Training Playbook](training_playbook.md): Practical training choices, memory rules, packing, kernels, and checkpointing.
|
| 12 |
+
- [Pipeline Hardening](pipeline_hardening.md): Silent-failure classes and the safeguards added around artifacts, provenance, and runtime.
|
| 13 |
+
- [Evaluation Methodology](evaluation_methodology.md): Benchmark controls, parser traps, raw/chat comparisons, and qualitative evaluation rules.
|
| 14 |
+
- [Engineering Insights](engineering_insights.md): Condensed technical lessons and design decisions.
|
| 15 |
+
- [Benchmarks](benchmarks.md): Verified evaluation results and interpretation.
|
| 16 |
+
- [Weight Audit](weight_audit.md): Structural checkpoint verification and what the audit means.
|
| 17 |
+
- [Hugging Face Model Card](huggingface_model_card.md): Release-page text for the public model card.
|
| 18 |
+
|
| 19 |
+
## Project Summary
|
| 20 |
+
|
| 21 |
+
The core thesis is simple: a small base model can absorb useful reasoning behavior from a larger instruction model if the distillation signal is dense enough and the evaluation controls are fair.
|
| 22 |
+
|
| 23 |
+
The project initially explored sparse offline top-k distillation, but that approach hit a ceiling because the student only saw a tiny fraction of the teacher vocabulary distribution. The final pipeline pivots to online KD, where teacher and student are run together and the student receives the teacher's full-vocabulary probability distribution during training.
|
| 24 |
+
|
| 25 |
+
After KD, a small SFT stage teaches the model how to expose that knowledge in a conversational interface. This separation matters: KD transfers capability; SFT and later preference training improve behavior, style, and confidence calibration.
|
| 26 |
+
|
| 27 |
+
## Repository Map
|
| 28 |
+
|
| 29 |
+
```text
|
| 30 |
+
configs/ Training configuration and DeepSpeed template.
|
| 31 |
+
src/ Online KD, data loading, losses, checkpointing, and packing.
|
| 32 |
+
sft/ Post-KD supervised fine-tuning, chat, and consolidated evaluation runner.
|
| 33 |
+
weight_audit/ Checkpoint structure and weight-divergence audit.
|
| 34 |
+
docs/ Public architecture, training, evaluation, and release notes.
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
## Main Public Artifact
|
| 38 |
+
|
| 39 |
+
The final model weights are available at: [Quintus](https://huggingface.co/iamrahulreddy/Quintus)
|
| 40 |
+
|
| 41 |
+
The Colab quickstart is available at: [Colab Quick Chat](https://colab.research.google.com/drive/1TdMSN5HzD1mToCFVf_qQoj10NGZLy2V0?usp=sharing)
|
| 42 |
+
|
docs/pipeline_hardening.md
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Pipeline Hardening
|
| 2 |
+
|
| 3 |
+

|
| 4 |
+
|
| 5 |
+
This page summarizes the correctness and reliability lessons that shaped the Quintus codebase. Most of these are silent-failure classes: the pipeline can appear to run while producing invalid or misleading artifacts.
|
| 6 |
+
|
| 7 |
+
## Silent Serialization Bugs
|
| 8 |
+
|
| 9 |
+
Teacher token IDs must be stored in a dtype that can represent the tokenizer vocabulary.
|
| 10 |
+
|
| 11 |
+
An early offline-KD path stored top-k token IDs too narrowly. Qwen token IDs exceed signed 16-bit range, so IDs could wrap negative and later be clamped into valid-looking but wrong positions. Training could continue, but the KL support was corrupted.
|
| 12 |
+
|
| 13 |
+
Hardening rule:
|
| 14 |
+
|
| 15 |
+
- Store token IDs as `int32` or wider.
|
| 16 |
+
- Validate IDs on load.
|
| 17 |
+
- Reject negative IDs.
|
| 18 |
+
- Reject IDs outside the student vocabulary.
|
| 19 |
+
- Treat dtype as part of the shard contract.
|
| 20 |
+
|
| 21 |
+
## Row-Order Preservation
|
| 22 |
+
|
| 23 |
+
Teacher-logit extraction often sorts samples by length for throughput. Training usually expects logits to match the original tokenized row order.
|
| 24 |
+
|
| 25 |
+
If sorted extraction writes shards in sorted order without restoring original indices, the student receives teacher logits for the wrong sample. This is a model-poisoning bug, not a performance issue.
|
| 26 |
+
|
| 27 |
+
Hardening rule:
|
| 28 |
+
|
| 29 |
+
- Batch by sorted length if useful.
|
| 30 |
+
- Preserve `original_idx`.
|
| 31 |
+
- Write final shards in original dataset order.
|
| 32 |
+
- Verify teacher-logit length against the tokenized row length at training time.
|
| 33 |
+
|
| 34 |
+
## Dataset Schema And Decoding
|
| 35 |
+
|
| 36 |
+
Public instruction datasets do not share a single row schema. Some rows arrive as `messages`; others use Alpaca-style `instruction`, `input`, and `output` fields. Some content fields contain nested dict/list payloads that need structured coercion before templating.
|
| 37 |
+
|
| 38 |
+
Dataset streaming can also fail late when a compression codec or file decoder is missing. That failure should remain visible instead of being replaced by a generic "zero samples" result.
|
| 39 |
+
|
| 40 |
+
Hardening rule:
|
| 41 |
+
|
| 42 |
+
- Detect Alpaca-style instruction/output rows before chat-message conversion.
|
| 43 |
+
- Coerce nested dict/list content through structured serialization, then normalize to text.
|
| 44 |
+
- Normalize common role aliases before applying a chat template.
|
| 45 |
+
- Preserve the first real dataset exception when streaming fails.
|
| 46 |
+
- Validate dataset decoding and schema mapping before large model downloads.
|
| 47 |
+
|
| 48 |
+
## Zero-Data And Data-Erasure Guards
|
| 49 |
+
|
| 50 |
+
Data preparation should fail when no usable rows are produced. It should also distinguish "download only" from "tokenize and overwrite output".
|
| 51 |
+
|
| 52 |
+
Hardening rule:
|
| 53 |
+
|
| 54 |
+
- Abort if filtering retains zero samples.
|
| 55 |
+
- Abort if tokenization writes zero rows.
|
| 56 |
+
- Do not open tokenized output in write mode for asset-only setup.
|
| 57 |
+
- Use explicit flags for model-only or data-only phases.
|
| 58 |
+
|
| 59 |
+
## Missing Shards Must Fail
|
| 60 |
+
|
| 61 |
+
Replacing missing teacher-logit shards with zero tensors makes the training loop look healthy while removing the KD signal.
|
| 62 |
+
|
| 63 |
+
Hardening rule:
|
| 64 |
+
|
| 65 |
+
- Missing shard means hard failure.
|
| 66 |
+
- Stale shard directories are cleaned before extraction.
|
| 67 |
+
- `_provenance.json` is required for KD.
|
| 68 |
+
- Shard count, sample count, max sequence length, temperature, top-k, and schema version are checked before training.
|
| 69 |
+
|
| 70 |
+
## Provenance Contracts
|
| 71 |
+
|
| 72 |
+
Path equality is weak provenance because paths change across machines. Data identity should come from content and model contracts.
|
| 73 |
+
|
| 74 |
+
Useful provenance fields:
|
| 75 |
+
|
| 76 |
+
- schema version
|
| 77 |
+
- dataset fingerprint or SHA-256
|
| 78 |
+
- sample count
|
| 79 |
+
- shard count
|
| 80 |
+
- max sequence length
|
| 81 |
+
- top-k or full-vocab mode
|
| 82 |
+
- temperature
|
| 83 |
+
- teacher model ID and revision
|
| 84 |
+
- student model ID and revision
|
| 85 |
+
- tokenizer sizes
|
| 86 |
+
- tokenizer fingerprints
|
| 87 |
+
- shard dtypes
|
| 88 |
+
|
| 89 |
+
Tokenizer fingerprints can drift across library versions. Vocab size and schema compatibility should remain hard gates; fingerprint drift can be a warning when stronger invariants still match.
|
| 90 |
+
|
| 91 |
+
## Assistant-Only Loss Masks
|
| 92 |
+
|
| 93 |
+
Supervising prompt and chat-template tokens can teach formatting before substance. It can also make chat-mode behavior fragile.
|
| 94 |
+
|
| 95 |
+
Hardening rule:
|
| 96 |
+
|
| 97 |
+
- Tokenized rows must include `loss_mask`.
|
| 98 |
+
- Loss mask must be binary.
|
| 99 |
+
- Rows with zero assistant targets are rejected.
|
| 100 |
+
- User prompts, system prompts, separators, and padding are not targets.
|
| 101 |
+
- Assistant response tokens are the supervised region.
|
| 102 |
+
|
| 103 |
+
Prefix-stable mask derivation is useful when tokenizer-provided assistant masks are unavailable.
|
| 104 |
+
|
| 105 |
+
## Gradient Accumulation Semantics
|
| 106 |
+
|
| 107 |
+
DeepSpeed and non-DeepSpeed paths need different step-accounting logic.
|
| 108 |
+
|
| 109 |
+
DeepSpeed accumulation is global across the full run, not local to each epoch. Epoch-end remainder branches should not create phantom optimizer steps.
|
| 110 |
+
|
| 111 |
+
Non-DeepSpeed accumulation needs an explicit final flush when a leftover accumulation window exists. That flush must rescale gradients so the update represents the mean over the remainder, not a shrunken `remainder / grad_accum` update.
|
| 112 |
+
|
| 113 |
+
Hardening rule:
|
| 114 |
+
|
| 115 |
+
- Advance `global_step` only after a real optimizer update.
|
| 116 |
+
- Align scheduler steps with real updates.
|
| 117 |
+
- Log flush steps.
|
| 118 |
+
- Include flush steps in training-loss CSVs.
|
| 119 |
+
- Prefer validation split sizes that align with effective batch size.
|
| 120 |
+
|
| 121 |
+
## Checkpoint Semantics
|
| 122 |
+
|
| 123 |
+
`init_from_checkpoint` and `resume_from_checkpoint` are different operations.
|
| 124 |
+
|
| 125 |
+
- Initialization starts a new phase from an existing model.
|
| 126 |
+
- Resume continues an interrupted phase from training state.
|
| 127 |
+
|
| 128 |
+
Mixing the two can skip training, restart from the wrong model, or reuse stale state.
|
| 129 |
+
|
| 130 |
+
Hardening rule:
|
| 131 |
+
|
| 132 |
+
- Forbid simultaneous init and resume.
|
| 133 |
+
- Save trainer state and scheduler state.
|
| 134 |
+
- Search both `step_*` and `epoch_*` checkpoints for resume.
|
| 135 |
+
- Store batch offset for mid-epoch resume.
|
| 136 |
+
- Keep final model-loading checkpoints portable.
|
| 137 |
+
|
| 138 |
+
## Compiler Portability
|
| 139 |
+
|
| 140 |
+
Compiled PyTorch modules can save weights with `_orig_mod.` prefixes if not unwrapped. Standard Transformers and vLLM loaders do not expect those keys.
|
| 141 |
+
|
| 142 |
+
Hardening rule:
|
| 143 |
+
|
| 144 |
+
- Keep `torch.compile` opt-in.
|
| 145 |
+
- Treat dynamic-shape recompile overhead as a throughput risk, not just a startup cost.
|
| 146 |
+
- Unwrap compiled modules before saving.
|
| 147 |
+
- Strip `_orig_mod.` only as a repair path, not as the normal release path.
|
| 148 |
+
- Verify saved checkpoints load through standard APIs.
|
| 149 |
+
|
| 150 |
+
## Artifact Hygiene
|
| 151 |
+
|
| 152 |
+
Stale outputs are a real ML correctness problem. Old result JSONs, old plots, or old sample logs can make a failed run look successful.
|
| 153 |
+
|
| 154 |
+
Hardening rule:
|
| 155 |
+
|
| 156 |
+
- Clean evaluation output directories before a new run.
|
| 157 |
+
- Clean stale plots before rendering.
|
| 158 |
+
- Select result files by clear recency rules.
|
| 159 |
+
- Fail if expected task outputs are incomplete.
|
| 160 |
+
- Fail if a requested checkpoint is missing; do not fall back to older local weights.
|
| 161 |
+
- Include runtime versions in result summaries.
|
| 162 |
+
|
| 163 |
+
## Environment Contracts
|
| 164 |
+
|
| 165 |
+
Notebook and cloud images often contain mixed binary packages. Import success for `torch` alone does not prove the stack is healthy.
|
| 166 |
+
|
| 167 |
+
Hardening rule:
|
| 168 |
+
|
| 169 |
+
- Treat `torch`, `torchvision`, and `torchaudio` as one binary compatibility family.
|
| 170 |
+
- Use staged dependency manifests instead of ad hoc installs.
|
| 171 |
+
- Keep vLLM dependencies separate from HF-only evaluation dependencies.
|
| 172 |
+
- Prefer clear preflight errors over late framework crashes.
|
| 173 |
+
- Print exception chains, not only the outer error.
|
| 174 |
+
|
| 175 |
+
## Remote Code And Revisions
|
| 176 |
+
|
| 177 |
+
Model loading should be reproducible and explicit.
|
| 178 |
+
|
| 179 |
+
Hardening rule:
|
| 180 |
+
|
| 181 |
+
- Pin teacher, student, and tokenizer revisions when possible.
|
| 182 |
+
- Default remote-code trust to false.
|
| 183 |
+
- Provide an explicit override for models that need custom code.
|
| 184 |
+
- Explain remote-code failures clearly.
|
| 185 |
+
|
| 186 |
+
## Safe Logging
|
| 187 |
+
|
| 188 |
+
Training logs should be rich enough for issue diagnosis without dumping config internals.
|
| 189 |
+
|
| 190 |
+
Hardening rule:
|
| 191 |
+
|
| 192 |
+
- Avoid logging authentication values or full config payloads.
|
| 193 |
+
- Disable traceback local-variable dumps in rich tracebacks.
|
| 194 |
+
- Strip ANSI sequences from file logs while keeping colored notebook output if desired.
|
| 195 |
+
- Use UTF-8 file logs and replacement-safe console output for generated model text.
|
| 196 |
+
- Log checkpoint save/upload intent, output size, duration, and destination path without sensitive values.
|
| 197 |
+
|
| 198 |
+
## Public Release Rule
|
| 199 |
+
|
| 200 |
+
A project can be release-ready without every possible production safeguard. The line is crossed when:
|
| 201 |
+
|
| 202 |
+
- known silent corruption paths are removed,
|
| 203 |
+
- remaining tradeoffs are documented,
|
| 204 |
+
- artifacts are reproducible enough to audit,
|
| 205 |
+
- public docs focus on decisions, methods, and release artifacts,
|
| 206 |
+
- evaluation claims are tied to clear methodology.
|
| 207 |
+
|
| 208 |
+
For Quintus, the release surface should describe the engineering decisions and results.
|
docs/training_playbook.md
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Training Playbook
|
| 2 |
+
|
| 3 |
+
This page captures the practical training lessons behind Quintus. It focuses on the engineering decisions that made the final online-KD run stable, reproducible, and fast enough to complete on large single-GPU hardware.
|
| 4 |
+
|
| 5 |
+
## Core Objective
|
| 6 |
+
|
| 7 |
+
The training objective combines assistant-token cross entropy with teacher-student KL divergence:
|
| 8 |
+
|
| 9 |
+
$$
|
| 10 |
+
\mathcal{L}_{\text{total}}
|
| 11 |
+
= \alpha \mathcal{L}_{\text{CE}}
|
| 12 |
+
+ (1 - \alpha)\mathcal{L}_{\text{KD}}
|
| 13 |
+
$$
|
| 14 |
+
|
| 15 |
+
For the final Qwen3 run:
|
| 16 |
+
|
| 17 |
+
$$
|
| 18 |
+
\alpha = 0.3,\quad
|
| 19 |
+
T = 2.0,\quad
|
| 20 |
+
C_{\text{KD}} = 2048,\quad
|
| 21 |
+
S_{\max} = 4096
|
| 22 |
+
$$
|
| 23 |
+
|
| 24 |
+
In this codebase, $\alpha$ is the cross-entropy weight. Lower $\alpha$ gives the teacher distribution more influence. Higher $\alpha$ gives hard assistant targets more influence.
|
| 25 |
+
|
| 26 |
+
## Why Online KD Replaced Offline Top-K KD
|
| 27 |
+
|
| 28 |
+
The early pipeline precomputed only a small top-k slice of the teacher distribution. That made storage and training cheaper, but it created a hard information ceiling.
|
| 29 |
+
|
| 30 |
+
With a Qwen vocabulary around 151K tokens:
|
| 31 |
+
|
| 32 |
+
$$
|
| 33 |
+
\frac{k}{|V|}
|
| 34 |
+
= \frac{8}{151{,}665}
|
| 35 |
+
\approx 5.3 \times 10^{-5}
|
| 36 |
+
= 0.0053\%
|
| 37 |
+
$$
|
| 38 |
+
|
| 39 |
+
That sparse signal was enough to disturb student weights, but not enough to reliably transfer deeper reasoning behavior. Several development probes changed alpha, epochs, and student initialization; the same ceiling remained.
|
| 40 |
+
|
| 41 |
+
The final online path removes that bottleneck. Teacher and student run together, and the KL term is computed from the live full-vocabulary teacher distribution.
|
| 42 |
+
|
| 43 |
+
## Memory Shape To Respect
|
| 44 |
+
|
| 45 |
+
Full-vocabulary KD is dominated by logits:
|
| 46 |
+
|
| 47 |
+
$$
|
| 48 |
+
\text{student\_logits},\ \text{teacher\_logits}
|
| 49 |
+
\in \mathbb{R}^{B \times S \times |V|}
|
| 50 |
+
$$
|
| 51 |
+
|
| 52 |
+
At Qwen vocabulary scale, increasing micro-batch size by one can add many GiB of temporary memory pressure. Effective batch size is not the same as memory cost. Peak memory is mostly driven by micro-batch size, sequence length, vocabulary width, activation storage, and the backward pass.
|
| 53 |
+
|
| 54 |
+
Useful rule:
|
| 55 |
+
|
| 56 |
+
$$
|
| 57 |
+
B_{\text{eff}} = B_{\mu} \times A
|
| 58 |
+
$$
|
| 59 |
+
|
| 60 |
+
Keeping $B_{\mu}$ lower and $A$ higher is often safer than a large micro-batch with the same effective batch size.
|
| 61 |
+
|
| 62 |
+
## Token Chunking
|
| 63 |
+
|
| 64 |
+
A naive full-vocabulary KL implementation materializes too much temporary state. Quintus computes KD over token chunks:
|
| 65 |
+
|
| 66 |
+
$$
|
| 67 |
+
C_{\text{KD}} = 2048
|
| 68 |
+
$$
|
| 69 |
+
|
| 70 |
+
Larger chunks reduce loop overhead but increase temporary memory. Smaller chunks save memory but can add kernel-launch and Python overhead. The final value is a B200-oriented balance for the 8B -> 1.7B workload.
|
| 71 |
+
|
| 72 |
+
## Sequence Packing
|
| 73 |
+
|
| 74 |
+
Sequence packing was the largest throughput win in development probes.
|
| 75 |
+
|
| 76 |
+
The packing strategy:
|
| 77 |
+
|
| 78 |
+
- Sort samples by length descending.
|
| 79 |
+
- Pack samples with deterministic first-fit decreasing binning.
|
| 80 |
+
- Insert EOS separators between samples.
|
| 81 |
+
- Set separator `loss_mask = 0`.
|
| 82 |
+
- Optionally mask the first token after each separator.
|
| 83 |
+
- Build `attention_mask` from true packed length, not from token identity.
|
| 84 |
+
|
| 85 |
+
The attention-mask detail matters because Qwen tokenizers can share EOS-like IDs with padding behavior. Deriving attention from `input_ids != pad_token_id` can accidentally mask real EOS separators inside packed rows.
|
| 86 |
+
|
| 87 |
+
Packing probes showed an unpacked B200 online-KD baseline around the low-20K tokens/sec range. Packed training reached roughly the mid-40K tokens/sec range after warmup. The final Qwen3 profile uses the same design principle with a conservative 8B -> 1.7B batch shape.
|
| 88 |
+
|
| 89 |
+
## B200-Oriented Final Shape
|
| 90 |
+
|
| 91 |
+
The Qwen3 config is intentionally conservative:
|
| 92 |
+
|
| 93 |
+
$$
|
| 94 |
+
B_{\mu}=4,\quad
|
| 95 |
+
A=2,\quad
|
| 96 |
+
B_{\text{eff}}=8,\quad
|
| 97 |
+
L_{\text{pack}}=4096
|
| 98 |
+
$$
|
| 99 |
+
|
| 100 |
+
Runtime choices:
|
| 101 |
+
|
| 102 |
+
- `gradient_checkpointing = false`
|
| 103 |
+
- `compile_model = false`
|
| 104 |
+
- `fused_adamw = true`
|
| 105 |
+
- `sequence_packing.enabled = true`
|
| 106 |
+
- FlashAttention-2 when available
|
| 107 |
+
- Liger kernels for compatible Qwen-family operators
|
| 108 |
+
|
| 109 |
+
The main reason is the 8B teacher plus 1.7B student online-KD footprint. A smaller teacher/student pair can use larger micro-batches, but the release workload reserves more headroom.
|
| 110 |
+
|
| 111 |
+
## Kernel Choices
|
| 112 |
+
|
| 113 |
+
FlashAttention-2 is the preferred stable attention path when available.
|
| 114 |
+
|
| 115 |
+
Liger kernels are useful for Qwen-family training, but KD places an important constraint on fusion:
|
| 116 |
+
|
| 117 |
+
- Safe to fuse: RMSNorm, RoPE, SwiGLU.
|
| 118 |
+
- Avoid for KD: fused linear cross entropy that hides raw student logits.
|
| 119 |
+
|
| 120 |
+
The KD loss needs raw student logits to compute teacher-student KL. Any optimization that bypasses logits entirely can break the objective.
|
| 121 |
+
|
| 122 |
+
## Why `torch.compile` Stayed Off
|
| 123 |
+
|
| 124 |
+
`torch.compile` can be useful for some SFT paths, but it was not the production choice for final KD.
|
| 125 |
+
|
| 126 |
+
Observed risks:
|
| 127 |
+
|
| 128 |
+
- Large Inductor memory overhead.
|
| 129 |
+
- Warmup cost on short-lived cloud instances.
|
| 130 |
+
- Dynamic-shape graph breaks from variable sequence lengths.
|
| 131 |
+
- Recompile overhead that reduced cumulative throughput in probes.
|
| 132 |
+
- `_orig_mod.` prefixes in saved checkpoints if compiled modules are not unwrapped before saving.
|
| 133 |
+
- Limited benefit after FlashAttention and Liger already fuse the major kernels.
|
| 134 |
+
|
| 135 |
+
For this workload, stable eager execution with targeted kernels was more predictable than compiler-driven fusion.
|
| 136 |
+
|
| 137 |
+
## DataLoader And Cloud Stability
|
| 138 |
+
|
| 139 |
+
Large worker counts can improve throughput on local systems, but notebook and cloud environments can deadlock through multiprocessing queues, IPC limits, or shared-memory pressure.
|
| 140 |
+
|
| 141 |
+
Practical policy:
|
| 142 |
+
|
| 143 |
+
- Start with conservative worker and prefetch settings.
|
| 144 |
+
- Treat a silent training hang as a DataLoader candidate, even when GPU utilization remains high.
|
| 145 |
+
- For some cloud notebook runs, `dataloader_workers = 0` was the most stable choice.
|
| 146 |
+
- For the release config, `dataloader_workers = 8` and `prefetch_factor = 2` are a controlled default, not a universal rule.
|
| 147 |
+
|
| 148 |
+
## Checkpointing And Resume
|
| 149 |
+
|
| 150 |
+
Cloud GPUs are preemptible and notebook sessions disappear. The training loop therefore treats checkpointing as a core training feature, not an afterthought.
|
| 151 |
+
|
| 152 |
+
Important design points:
|
| 153 |
+
|
| 154 |
+
- `best` is selected from validation loss where available.
|
| 155 |
+
- `last` is saved for final-state inspection.
|
| 156 |
+
- Step checkpoints can resume mid-epoch.
|
| 157 |
+
- Scheduler state is saved.
|
| 158 |
+
- Optimizer state may be intentionally omitted for very large runs to avoid massive checkpoint overhead.
|
| 159 |
+
- Resume semantics distinguish initialization from a completed checkpoint and continuation from an interrupted checkpoint.
|
| 160 |
+
|
| 161 |
+
This avoids the common trap where `resume_from_checkpoint` silently starts from the wrong phase or stale state.
|
| 162 |
+
|
| 163 |
+
## Provenance Rules
|
| 164 |
+
|
| 165 |
+
The pipeline is strict about artifact compatibility:
|
| 166 |
+
|
| 167 |
+
- Tokenizer vocabulary sizes must match the model contract.
|
| 168 |
+
- Teacher-logit metadata must match expected temperature, sample count, max sequence length, and tokenizer/model identity.
|
| 169 |
+
- Dataset fingerprints are preferred over path equality because paths are machine-local.
|
| 170 |
+
- Tokenizer fingerprints can drift across library versions, so hard checks should focus on vocab-size and schema invariants.
|
| 171 |
+
|
| 172 |
+
The principle is simple: train only when artifacts prove they belong together.
|
| 173 |
+
|
| 174 |
+
## Dataset Sampling
|
| 175 |
+
|
| 176 |
+
Taking the first N valid streamed examples can bias a run if the upstream dataset is ordered by source, task, difficulty, or language. Later configs added stream shuffling before selection.
|
| 177 |
+
|
| 178 |
+
The config uses a non-default seed:
|
| 179 |
+
|
| 180 |
+
```text
|
| 181 |
+
stream_shuffle_seed = 25
|
| 182 |
+
split_seed = 25
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
The number is intentionally explicit. Reproducibility needs stable seeds; it does not require the overused value `42`.
|
| 186 |
+
|
| 187 |
+
## Practical Watchpoints
|
| 188 |
+
|
| 189 |
+
During a run, these signals matter more than a single loss number:
|
| 190 |
+
|
| 191 |
+
- Loss stays finite from the first logging window.
|
| 192 |
+
- CE and KD move in plausible ranges.
|
| 193 |
+
- Rolling throughput remains stable after warmup.
|
| 194 |
+
- GPU memory is high but not near an unpredictable OOM edge.
|
| 195 |
+
- Validation loss is computed on the intended holdout.
|
| 196 |
+
- Saved checkpoints load in standard Transformers and vLLM paths.
|
| 197 |
+
- Downstream benchmark results agree with the training story.
|
| 198 |
+
|
| 199 |
+
Held-out KD loss is useful, but it is not the release gate. Standardized benchmarks and qualitative checks must decide whether the checkpoint improved the target behavior.
|
docs/weight_audit.md
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Weight Audit
|
| 2 |
+
|
| 3 |
+
The `weight_audit/` directory contains a structural audit script and a generated report comparing the final distilled checkpoint against `Qwen/Qwen3-1.7B-Base`.
|
| 4 |
+
|
| 5 |
+
The audit is not a behavioral benchmark. It answers a narrower question: is the checkpoint structurally intact, same-architecture, and plausibly modified by training without signs of collapse?
|
| 6 |
+
|
| 7 |
+
## What Was Checked
|
| 8 |
+
|
| 9 |
+
The audit verifies:
|
| 10 |
+
|
| 11 |
+
- Base and distilled checkpoint commits.
|
| 12 |
+
- Architecture and config compatibility.
|
| 13 |
+
- Parameter counts and tensor keys.
|
| 14 |
+
- Weight tying between embeddings and LM head.
|
| 15 |
+
- Per-tensor statistics.
|
| 16 |
+
- Layer-type aggregate statistics.
|
| 17 |
+
- Isotropy of 2D weight matrices.
|
| 18 |
+
- Base-vs-distilled divergence for all shared tensors.
|
| 19 |
+
- Sparsity, dead rows, low cosine similarity, and low SNR warnings.
|
| 20 |
+
|
| 21 |
+
## Headline Result
|
| 22 |
+
|
| 23 |
+
The final report shows:
|
| 24 |
+
|
| 25 |
+
```text
|
| 26 |
+
shared tensors : 311
|
| 27 |
+
tensors changed vs base : 277 / 311
|
| 28 |
+
cosine similarity : mean = 0.999991 | median = 0.999992
|
| 29 |
+
relative error : mean = 0.001093 | median = 0.001293
|
| 30 |
+
SNR dB : mean = 81.86 | median = 47.79
|
| 31 |
+
high-sparsity layers (>10%) : 0
|
| 32 |
+
heavy-tail layers (|kurt_d|>5.0) : 0
|
| 33 |
+
dead-row layers : 0
|
| 34 |
+
low-cos layers (<0.95) : 0
|
| 35 |
+
low-SNR layers (<20 dB) : 0
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
## Interpretation
|
| 39 |
+
|
| 40 |
+
This is a healthy pattern for light-touch distillation:
|
| 41 |
+
|
| 42 |
+
- The architecture is unchanged.
|
| 43 |
+
- Most tensors changed.
|
| 44 |
+
- The changes are small relative to the original base weights.
|
| 45 |
+
- Projection matrices, embeddings, and MLP/attention layers moved.
|
| 46 |
+
- Some normalization tensors remained unchanged or changed only slightly.
|
| 47 |
+
- No layer shows obvious structural collapse.
|
| 48 |
+
|
| 49 |
+
The unchanged tensors are primarily normalization-related weights. That is not concerning by itself. It suggests the main semantic projection weights absorbed the training signal while basic scaling structure stayed stable.
|
| 50 |
+
|
| 51 |
+
## Why Isotropy Matters
|
| 52 |
+
|
| 53 |
+
The report's global isotropy score is close to zero. Near-zero average pairwise row cosine means the weight rows are not collapsing into one shared direction.
|
| 54 |
+
|
| 55 |
+
This is useful as a sanity check after KD. A collapsed model can sometimes load and produce text, but its internal geometry becomes degenerate. The audit does not show that pattern.
|
| 56 |
+
|
| 57 |
+
## What The Audit Does Not Prove
|
| 58 |
+
|
| 59 |
+
The weight audit does not prove that answers are correct, safe, or well calibrated. It should be read alongside:
|
| 60 |
+
|
| 61 |
+
- Standard benchmarks.
|
| 62 |
+
- Open-ended qualitative evaluations.
|
| 63 |
+
- SFT evaluation outputs.
|
| 64 |
+
- Manual regression prompts.
|
| 65 |
+
|
| 66 |
+
The audit says the checkpoint is structurally ready for downstream evaluation and release packaging.
|
requirements-eval.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-r requirements.txt
|
| 2 |
+
|
| 3 |
+
# Consolidated benchmark runner dependencies.
|
| 4 |
+
evalplus>=0.3
|
| 5 |
+
lm-eval>=0.4.8
|
| 6 |
+
vllm>=0.8; platform_system == "Linux"
|
requirements-train.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-r requirements.txt
|
| 2 |
+
|
| 3 |
+
# Optional but recommended for the documented high-throughput training path.
|
| 4 |
+
# These packages are Linux/CUDA-oriented and may require matching compiler,
|
| 5 |
+
# CUDA, and PyTorch builds.
|
| 6 |
+
liger-kernel>=0.5
|
| 7 |
+
flash-attn>=2.7; platform_system == "Linux"
|
| 8 |
+
deepspeed>=0.16; platform_system == "Linux"
|
| 9 |
+
|
| 10 |
+
# Optional SFT/QLoRA paths in sft/train_sft.py.
|
| 11 |
+
peft>=0.14
|
| 12 |
+
bitsandbytes>=0.45; platform_system == "Linux"
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core dependencies for downloading, training entry points, local chat, and
|
| 2 |
+
# lightweight repository utilities. Install CUDA-specific PyTorch wheels from
|
| 3 |
+
# the official PyTorch index when your environment requires a specific CUDA
|
| 4 |
+
# build.
|
| 5 |
+
torch>=2.6
|
| 6 |
+
transformers>=4.52
|
| 7 |
+
datasets>=2.19
|
| 8 |
+
huggingface-hub>=0.31
|
| 9 |
+
omegaconf>=2.3
|
| 10 |
+
PyYAML>=6.0
|
| 11 |
+
safetensors>=0.4
|
| 12 |
+
accelerate>=1.0
|
| 13 |
+
tqdm>=4.66
|
sft/chat.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
| 3 |
+
import sys
|
| 4 |
+
import argparse
|
| 5 |
+
|
| 6 |
+
def main():
|
| 7 |
+
parser = argparse.ArgumentParser(description="Quintus Interactive Chat")
|
| 8 |
+
parser.add_argument("--model_path", type=str, default="iamrahulreddy/Quintus", help="Model repo ID or local weights directory")
|
| 9 |
+
parser.add_argument("--trust_remote_code", action="store_true", help="Allow custom code from the model repository.")
|
| 10 |
+
args = parser.parse_args()
|
| 11 |
+
|
| 12 |
+
model_path = args.model_path
|
| 13 |
+
print(f"Loading Quintus from {model_path}...")
|
| 14 |
+
try:
|
| 15 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=args.trust_remote_code)
|
| 16 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 17 |
+
model_path,
|
| 18 |
+
device_map="auto",
|
| 19 |
+
dtype=torch.float16,
|
| 20 |
+
trust_remote_code=args.trust_remote_code
|
| 21 |
+
)
|
| 22 |
+
except Exception as e:
|
| 23 |
+
print(f"Error loading model: {e}")
|
| 24 |
+
print(f"Ensure '{model_path}' exists and contains the model weights.")
|
| 25 |
+
sys.exit(1)
|
| 26 |
+
|
| 27 |
+
# Defining stopping criteria
|
| 28 |
+
stop_tokens = ["<|endoftext|>", "<|im_end|>"]
|
| 29 |
+
eos_token_ids = [tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else []
|
| 30 |
+
for token in stop_tokens:
|
| 31 |
+
t_id = tokenizer.convert_tokens_to_ids(token)
|
| 32 |
+
if t_id is not None and t_id not in eos_token_ids:
|
| 33 |
+
eos_token_ids.append(t_id)
|
| 34 |
+
|
| 35 |
+
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 36 |
+
|
| 37 |
+
conversation_history = [
|
| 38 |
+
{"role": "system", "content": "You are Quintus, a highly capable AI assistant created by Muskula Rahul. You are helpful, precise, and logically sound."}
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
print()
|
| 42 |
+
print("Quintus Chat (type 'quit' to exit)")
|
| 43 |
+
print()
|
| 44 |
+
|
| 45 |
+
while True:
|
| 46 |
+
try:
|
| 47 |
+
user_input = input("You: ").strip()
|
| 48 |
+
if user_input.lower() in ["quit", "exit"]:
|
| 49 |
+
print("\nGoodbye!")
|
| 50 |
+
break
|
| 51 |
+
if not user_input:
|
| 52 |
+
continue
|
| 53 |
+
|
| 54 |
+
conversation_history.append({"role": "user", "content": user_input})
|
| 55 |
+
|
| 56 |
+
prompt = tokenizer.apply_chat_template(
|
| 57 |
+
conversation_history,
|
| 58 |
+
tokenize=False,
|
| 59 |
+
add_generation_prompt=True
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 63 |
+
|
| 64 |
+
print("Quintus: ", end="", flush=True)
|
| 65 |
+
|
| 66 |
+
with torch.no_grad():
|
| 67 |
+
outputs = model.generate(
|
| 68 |
+
**inputs,
|
| 69 |
+
max_new_tokens=512,
|
| 70 |
+
temperature=0.7,
|
| 71 |
+
top_p=0.9,
|
| 72 |
+
do_sample=True,
|
| 73 |
+
streamer=streamer,
|
| 74 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 75 |
+
eos_token_id=eos_token_ids
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# Extract response for history
|
| 79 |
+
generated_ids = outputs[0][inputs.input_ids.shape[-1]:]
|
| 80 |
+
assistant_response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
|
| 81 |
+
conversation_history.append({"role": "assistant", "content": assistant_response})
|
| 82 |
+
print()
|
| 83 |
+
|
| 84 |
+
except KeyboardInterrupt:
|
| 85 |
+
print("\n\nGoodbye!")
|
| 86 |
+
break
|
| 87 |
+
|
| 88 |
+
if __name__ == "__main__":
|
| 89 |
+
main()
|
sft/evaluate.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Automated EvalPlus runner for HumanEval and MBPP benchmarks.
|
| 2 |
+
# Using the vLLM backend in greedy mode.
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import subprocess
|
| 7 |
+
import time
|
| 8 |
+
import json
|
| 9 |
+
import re
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
from huggingface_hub import snapshot_download
|
| 13 |
+
|
| 14 |
+
MODELS = [
|
| 15 |
+
{
|
| 16 |
+
"name": "Quintus-1.7B",
|
| 17 |
+
"id": "iamrahulreddy/Quintus",
|
| 18 |
+
"is_local": False
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"name": "Qwen3-1.7B-Instruct",
|
| 22 |
+
"id": "Qwen/Qwen3-1.7B",
|
| 23 |
+
"is_local": False
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"name": "Qwen3-1.7B-Base",
|
| 27 |
+
"id": "Qwen/Qwen3-1.7B-Base",
|
| 28 |
+
"is_local": False
|
| 29 |
+
}
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
DATASETS = [
|
| 33 |
+
"humaneval", "mbpp", # EvalPlus benchmarks
|
| 34 |
+
"gsm8k", "winogrande", # lm-eval fast benchmarks
|
| 35 |
+
"arc_challenge", "boolq", "piqa"
|
| 36 |
+
]
|
| 37 |
+
EVALPLUS_DATASETS = {"humaneval", "mbpp"}
|
| 38 |
+
|
| 39 |
+
LM_EVAL_SHOTS = {
|
| 40 |
+
"gsm8k": "10",
|
| 41 |
+
"winogrande": "5",
|
| 42 |
+
"arc_challenge": "25",
|
| 43 |
+
"boolq": "0",
|
| 44 |
+
"piqa": "0"
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 48 |
+
TRUST_REMOTE_CODE = os.environ.get("QUINTUS_TRUST_REMOTE_CODE", "").strip().lower() in {"1", "true", "yes", "on"}
|
| 49 |
+
|
| 50 |
+
def extract_lm_eval_score(results_dir: Path, task: str) -> str:
|
| 51 |
+
"""Finds and extracts the primary score from JSON files outputted by lm-evaluation-harness."""
|
| 52 |
+
for json_path in sorted(results_dir.rglob("*.json"), reverse=True):
|
| 53 |
+
try:
|
| 54 |
+
with open(json_path, encoding="utf-8") as fh:
|
| 55 |
+
data = json.load(fh)
|
| 56 |
+
task_results = data.get("results", {})
|
| 57 |
+
for candidate in (task, f"leaderboard_{task}"):
|
| 58 |
+
if candidate in task_results:
|
| 59 |
+
task_data = task_results[candidate]
|
| 60 |
+
# Try common metric names
|
| 61 |
+
for metric in ["acc,none", "acc_norm,none", "exact_match,strict-match", "exact_match,none"]:
|
| 62 |
+
if metric in task_data:
|
| 63 |
+
return f"{task_data[metric]*100:.1f}"
|
| 64 |
+
except Exception:
|
| 65 |
+
continue
|
| 66 |
+
return "N/A"
|
| 67 |
+
|
| 68 |
+
def is_noise(line: str) -> bool:
|
| 69 |
+
l = line.strip()
|
| 70 |
+
if not l:
|
| 71 |
+
return False
|
| 72 |
+
# Progress bar indicators & block characters
|
| 73 |
+
if any(c in l for c in ["█", "━", "╸", "•", "━━━━━━━━"]):
|
| 74 |
+
return True
|
| 75 |
+
# vLLM, ray, flash_attn, huggingface setup/warnings logs
|
| 76 |
+
noise_keywords = [
|
| 77 |
+
"INFO ", "WARNING ", "DEBUG ", "ERROR ", "(EngineCore",
|
| 78 |
+
"Loading safetensors", "Capturing CUDA graphs",
|
| 79 |
+
"Codegen:", "Downloading dataset", "downloading dataset",
|
| 80 |
+
"Initializing a decoder", "Unknown vLLM environment",
|
| 81 |
+
"world_size=", "Using V2 Model Runner", "Model loading took",
|
| 82 |
+
"Using FLASH_ATTN", "Using FlashAttention", "Kernel JIT monitor",
|
| 83 |
+
"autotuner.py", "autotuning", "Autotuning", "loading weights",
|
| 84 |
+
"Loading weights", "Failed to get device capability", "Sanitized code outputs",
|
| 85 |
+
"Raw outputs will be saved", "init engine", "Dynamo bytecode",
|
| 86 |
+
"Directly load the compiled graph", "Directly load AOT compilation", "torch.compile took"
|
| 87 |
+
]
|
| 88 |
+
if any(k.lower() in l.lower() for k in noise_keywords):
|
| 89 |
+
return True
|
| 90 |
+
# TQDM lines (e.g. 100%|... [00:17<00:00, 9.45it/s])
|
| 91 |
+
if "%|" in l and ("it/s" in l or "s/it" in l):
|
| 92 |
+
return True
|
| 93 |
+
return False
|
| 94 |
+
|
| 95 |
+
def main():
|
| 96 |
+
print("=" * 80)
|
| 97 |
+
print(" EVALPLUS BENCHMARK RUNNER (HUMANEVAL & MBPP)")
|
| 98 |
+
print("=" * 80)
|
| 99 |
+
print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
| 100 |
+
print(f"Models to evaluate: {[m['name'] for m in MODELS]}")
|
| 101 |
+
print(f"Datasets: {DATASETS}")
|
| 102 |
+
print("=" * 80)
|
| 103 |
+
|
| 104 |
+
# Set optional HF token and runtime configuration.
|
| 105 |
+
if HF_TOKEN:
|
| 106 |
+
os.environ["HF_TOKEN"] = HF_TOKEN
|
| 107 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 108 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 109 |
+
os.environ["VLLM_MAX_MODEL_LEN"] = "4096"
|
| 110 |
+
|
| 111 |
+
# Step 1: Pre-download and prepare model caches
|
| 112 |
+
print("\n--- STAGE 1: WARMING UP MODEL WEIGHTS CACHE ---")
|
| 113 |
+
|
| 114 |
+
# Cache all models
|
| 115 |
+
for model in MODELS:
|
| 116 |
+
if model["is_local"]:
|
| 117 |
+
continue
|
| 118 |
+
print(f"\n[DOWNLOADING] Fetching cache for {model['name']} ({model['id']})...")
|
| 119 |
+
try:
|
| 120 |
+
snapshot_download(
|
| 121 |
+
repo_id=model["id"],
|
| 122 |
+
token=HF_TOKEN or None
|
| 123 |
+
)
|
| 124 |
+
print(f"[DOWNLOAD SUCCESS] {model['name']} is cached and ready.")
|
| 125 |
+
except Exception as e:
|
| 126 |
+
print(f"[DOWNLOAD WARNING] Could not pre-download model {model['name']} via snapshot_download: {e}")
|
| 127 |
+
print("The evaluation run will attempt to download it directly during execution.")
|
| 128 |
+
|
| 129 |
+
print("\n--- STAGE 2: SEQUENTIAL EVALPLUS EVALUATION ---")
|
| 130 |
+
results = []
|
| 131 |
+
|
| 132 |
+
# Run evaluations sequentially
|
| 133 |
+
for model in MODELS:
|
| 134 |
+
# Resolve path
|
| 135 |
+
model_path = str(Path(model["id"]).resolve()) if model["is_local"] else model["id"]
|
| 136 |
+
|
| 137 |
+
for dataset in DATASETS:
|
| 138 |
+
print(f"\n[STARTING] Evaluating {model['name']} on {dataset}...")
|
| 139 |
+
print("-" * 60)
|
| 140 |
+
|
| 141 |
+
if dataset in EVALPLUS_DATASETS:
|
| 142 |
+
cmd = [
|
| 143 |
+
sys.executable, "-m", "evalplus.evaluate",
|
| 144 |
+
"--model", model_path,
|
| 145 |
+
"--dataset", dataset,
|
| 146 |
+
"--backend", "vllm",
|
| 147 |
+
"--greedy"
|
| 148 |
+
]
|
| 149 |
+
else:
|
| 150 |
+
shots = LM_EVAL_SHOTS.get(dataset, "0")
|
| 151 |
+
out_dir = Path("eval_results") / model["name"] / dataset
|
| 152 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 153 |
+
|
| 154 |
+
model_args = (
|
| 155 |
+
f"pretrained={model_path},dtype=bfloat16,"
|
| 156 |
+
f"trust_remote_code={str(TRUST_REMOTE_CODE).lower()},"
|
| 157 |
+
"gpu_memory_utilization=0.9,max_model_len=4096"
|
| 158 |
+
)
|
| 159 |
+
cmd = [
|
| 160 |
+
sys.executable, "-m", "lm_eval",
|
| 161 |
+
"--model", "vllm",
|
| 162 |
+
"--model_args", model_args,
|
| 163 |
+
"--tasks", dataset,
|
| 164 |
+
"--num_fewshot", shots,
|
| 165 |
+
"--batch_size", "auto",
|
| 166 |
+
"--output_path", str(out_dir),
|
| 167 |
+
"--log_samples"
|
| 168 |
+
]
|
| 169 |
+
if dataset == "gsm8k":
|
| 170 |
+
cmd.extend(["--gen_kwargs", "max_gen_toks=512"])
|
| 171 |
+
|
| 172 |
+
print(f"Running command: {' '.join(cmd)}")
|
| 173 |
+
start_time = time.time()
|
| 174 |
+
|
| 175 |
+
try:
|
| 176 |
+
# Run the command and stream output
|
| 177 |
+
process = subprocess.Popen(
|
| 178 |
+
cmd,
|
| 179 |
+
stdout=subprocess.PIPE,
|
| 180 |
+
stderr=subprocess.STDOUT,
|
| 181 |
+
text=True,
|
| 182 |
+
bufsize=1
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
# Stream and capture output (filtering out vLLM and progress bar noise)
|
| 186 |
+
stdout_text = ""
|
| 187 |
+
for line in process.stdout:
|
| 188 |
+
stdout_text += line
|
| 189 |
+
if not is_noise(line):
|
| 190 |
+
print(line, end="")
|
| 191 |
+
|
| 192 |
+
process.wait()
|
| 193 |
+
duration = time.time() - start_time
|
| 194 |
+
time.sleep(5) # Let OS/driver fully reclaim GPU VRAM before starting next subprocess
|
| 195 |
+
|
| 196 |
+
score_str = "N/A"
|
| 197 |
+
if process.returncode == 0:
|
| 198 |
+
print(f"[SUCCESS] Completed {model['name']} on {dataset} in {duration:.1f} seconds.")
|
| 199 |
+
|
| 200 |
+
# Parse scores
|
| 201 |
+
if dataset in EVALPLUS_DATASETS:
|
| 202 |
+
# Find all pass@1 scores
|
| 203 |
+
matches = re.findall(r"pass@1:\s+([0-9.]+)", stdout_text)
|
| 204 |
+
if len(matches) >= 2:
|
| 205 |
+
val0 = float(matches[0])
|
| 206 |
+
val1 = float(matches[1])
|
| 207 |
+
if val0 <= 1.0: val0 *= 100
|
| 208 |
+
if val1 <= 1.0: val1 *= 100
|
| 209 |
+
score_str = f"Base: {val0:.1f} | Plus: {val1:.1f}"
|
| 210 |
+
elif len(matches) == 1:
|
| 211 |
+
val0 = float(matches[0])
|
| 212 |
+
if val0 <= 1.0: val0 *= 100
|
| 213 |
+
score_str = f"Base: {val0:.1f}"
|
| 214 |
+
else:
|
| 215 |
+
score_str = extract_lm_eval_score(out_dir, dataset)
|
| 216 |
+
|
| 217 |
+
results.append({
|
| 218 |
+
"model": model["name"],
|
| 219 |
+
"dataset": dataset,
|
| 220 |
+
"status": "Success",
|
| 221 |
+
"score": score_str,
|
| 222 |
+
"duration": f"{duration/60:.1f} min"
|
| 223 |
+
})
|
| 224 |
+
else:
|
| 225 |
+
print(f"[ERROR] command failed with exit code {process.returncode}")
|
| 226 |
+
results.append({
|
| 227 |
+
"model": model["name"],
|
| 228 |
+
"dataset": dataset,
|
| 229 |
+
"status": f"Failed ({process.returncode})",
|
| 230 |
+
"score": "ERROR",
|
| 231 |
+
"duration": f"{duration/60:.1f} min"
|
| 232 |
+
})
|
| 233 |
+
|
| 234 |
+
except Exception as e:
|
| 235 |
+
duration = time.time() - start_time
|
| 236 |
+
print(f"[ERROR] Failed to run benchmark: {e}")
|
| 237 |
+
results.append({
|
| 238 |
+
"model": model["name"],
|
| 239 |
+
"dataset": dataset,
|
| 240 |
+
"status": f"Error",
|
| 241 |
+
"score": "ERROR",
|
| 242 |
+
"duration": f"{duration/60:.1f} min"
|
| 243 |
+
})
|
| 244 |
+
print("-" * 60)
|
| 245 |
+
|
| 246 |
+
# Print and save summary report
|
| 247 |
+
report_lines = []
|
| 248 |
+
report_lines.append("\n" + "=" * 100)
|
| 249 |
+
report_lines.append(" BENCHMARK RUN SUMMARY")
|
| 250 |
+
report_lines.append("=" * 100)
|
| 251 |
+
report_lines.append(f"| {'Model':<30} | {'Dataset':<15} | {'Score':<25} | {'Status':<10} | {'Time':<8} |")
|
| 252 |
+
report_lines.append(f"|{'-'*32}|{'-'*17}|{'-'*27}|{'-'*12}|{'-'*10}|")
|
| 253 |
+
for r in results:
|
| 254 |
+
report_lines.append(f"| {r['model']:<30} | {r['dataset']:<15} | {r['score']:<25} | {r['status']:<10} | {r['duration']:<8} |")
|
| 255 |
+
report_lines.append("=" * 100)
|
| 256 |
+
|
| 257 |
+
report_text = "\n".join(report_lines)
|
| 258 |
+
print(report_text)
|
| 259 |
+
print("\nNote: Results are saved in the default EvalPlus directory and eval_results/.")
|
| 260 |
+
|
| 261 |
+
# Save to file
|
| 262 |
+
with open("qwen_quintus_scores.txt", "w", encoding="utf-8") as f:
|
| 263 |
+
f.write(report_text + "\n")
|
| 264 |
+
print("\n[SUCCESS] Final score report saved to 'qwen_quintus_scores.txt'")
|
| 265 |
+
|
| 266 |
+
if __name__ == "__main__":
|
| 267 |
+
main()
|
sft/train_sft.py
ADDED
|
@@ -0,0 +1,690 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SFT Training and Downstream Evaluation Pipeline
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import gc
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import re
|
| 10 |
+
import sys
|
| 11 |
+
import time
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import yaml
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from torch.utils.data import DataLoader, Dataset
|
| 19 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup
|
| 20 |
+
|
| 21 |
+
# Load Configuration
|
| 22 |
+
def load_config() -> dict:
|
| 23 |
+
cfg_path = Path(__file__).resolve().parent / "config.yaml"
|
| 24 |
+
if not cfg_path.exists():
|
| 25 |
+
return {}
|
| 26 |
+
with open(cfg_path, "r", encoding="utf-8") as f:
|
| 27 |
+
return yaml.safe_load(f) or {}
|
| 28 |
+
|
| 29 |
+
cfg = load_config()
|
| 30 |
+
|
| 31 |
+
# PROMPTS (50 PROMPTS)
|
| 32 |
+
EASY_PROMPTS = [
|
| 33 |
+
"What is the capital of Japan, and what is it known for?",
|
| 34 |
+
"What does the term 'CPU' stand for, and what is its role in a computer?",
|
| 35 |
+
"Name three mammals that live primarily in water.",
|
| 36 |
+
"What is the difference between a virus and a bacterium?",
|
| 37 |
+
"Convert 72 degrees Fahrenheit to Celsius.",
|
| 38 |
+
"What is the purpose of a hash function?",
|
| 39 |
+
"What does HTTP stand for and what is it used for?",
|
| 40 |
+
"In which continent is the Amazon rainforest located?",
|
| 41 |
+
"What is the difference between RAM and ROM?",
|
| 42 |
+
"Name two programming languages commonly used for data science.",
|
| 43 |
+
"What is the function of the mitochondria in a cell?",
|
| 44 |
+
"What is a palindrome? Give two examples.",
|
| 45 |
+
"What is the difference between a compiler and an interpreter?",
|
| 46 |
+
"What unit is used to measure electrical resistance?",
|
| 47 |
+
"Name the four blood types in the ABO system.",
|
| 48 |
+
"What is the primary purpose of DNS in networking?",
|
| 49 |
+
"What does it mean for a function to be 'pure' in programming?"
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
MEDIUM_PROMPTS = [
|
| 53 |
+
"Explain the difference between supervised and unsupervised learning with a concrete example of each.",
|
| 54 |
+
"Write a Python function that takes a list of integers and returns all pairs that sum to a given target.",
|
| 55 |
+
"Explain how TCP/IP ensures reliable data delivery over an unreliable network.",
|
| 56 |
+
"What are the trade-offs between using a relational database and a document store for a user profile system?",
|
| 57 |
+
"Describe how gradient descent works and explain the role of the learning rate.",
|
| 58 |
+
"Write a SQL query that returns the top 5 customers by total order value, including customers with no orders.",
|
| 59 |
+
"What is the CAP theorem and what does it imply for distributed system design?",
|
| 60 |
+
"Explain the difference between process and thread, including when you would prefer one over the other.",
|
| 61 |
+
"How does HTTPS prevent a man-in-the-middle attack? Walk through the handshake at a high level.",
|
| 62 |
+
"Write a regex that validates an email address and annotate each part of the pattern.",
|
| 63 |
+
"What is the difference between memoization and dynamic programming?",
|
| 64 |
+
"Describe three ways to handle class imbalance in a machine learning dataset.",
|
| 65 |
+
"Explain what a foreign key constraint does and give an example of why it matters.",
|
| 66 |
+
"What is the difference between horizontal and vertical scaling, and when would you choose each?",
|
| 67 |
+
"How does Python's garbage collector handle circular references?",
|
| 68 |
+
"Explain the intuition behind the attention mechanism in Transformer models.",
|
| 69 |
+
"What is a race condition? Write a minimal pseudocode example that demonstrates one."
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
TOUGH_PROMPTS = [
|
| 73 |
+
"Design a rate limiter for a public API that must handle 100k requests per second across multiple regions. Describe the data structures, algorithms, and infrastructure trade-offs involved.",
|
| 74 |
+
"Explain why training very deep neural networks with sigmoid activations suffers from vanishing gradients. How do residual connections and normalization layers address this, and what are their respective limitations?",
|
| 75 |
+
"A message queue is consuming events from an upstream producer faster than a downstream consumer can process them. The queue is filling up and the producer cannot be slowed down. Describe at least three architectural strategies to resolve this, with trade-offs.",
|
| 76 |
+
"Given an undirected weighted graph, write Python code to find the minimum spanning tree using Kruskal's algorithm. Include the union-find data structure. Analyze time and space complexity.",
|
| 77 |
+
"You are given two sorted arrays of size m and n. Find the median of the combined array in O(log(m+n)) time. Explain the approach before writing the code.",
|
| 78 |
+
"Explain the difference between Byzantine fault tolerance and crash fault tolerance. In what scenario does the distinction become critical, and how does a consensus protocol like PBFT address Byzantine failures?",
|
| 79 |
+
"A large language model fine-tuned on customer service data starts producing confident but factually wrong answers about product details. Propose a complete mitigation strategy covering training, inference, and deployment layers.",
|
| 80 |
+
"Explain the mechanism behind speculative execution in modern CPUs and how it led to the Spectre vulnerability. What classes of software-level mitigations exist and what performance cost do they carry?",
|
| 81 |
+
"Design a schema and indexing strategy for a social graph where you need to efficiently answer: (1) mutual friends between two users, (2) shortest path between two users, (3) top-k most influential accounts. Justify your choices.",
|
| 82 |
+
"Implement a thread-safe LRU cache in Python with O(1) get and put operations. Explain why your synchronization approach is correct and where contention bottlenecks might appear under high concurrency.",
|
| 83 |
+
"Explain the difference between weak, strong, and eventual consistency in distributed databases. Give a concrete example of a bug that arises when a developer assumes strong consistency but the system only guarantees eventual consistency.",
|
| 84 |
+
"You are designing the storage layer for a time-series database that ingests 1 million data points per second and must support range queries going back 2 years. Describe compression strategies, write amplification concerns, and compaction trade-offs.",
|
| 85 |
+
"Explain how LoRA (Low-Rank Adaptation) reduces the number of trainable parameters in fine-tuning. Derive why a weight update matrix can be approximated as a product of two low-rank matrices and discuss what is lost in this approximation.",
|
| 86 |
+
"A binary tree is given where each node has a value. Write an algorithm to find the maximum path sum between any two nodes (not necessarily leaf nodes). Prove the correctness of your recurrence relation.",
|
| 87 |
+
"Explain the economic concept of Goodhart's Law and give three examples of how it manifests in AI system evaluation.",
|
| 88 |
+
"Describe the full lifecycle of a memory allocation in a system using jemalloc or tcmalloc. How do thread-local caches, size classes, and slab allocation interact, and what are the implications for long-running server processes?"
|
| 89 |
+
]
|
| 90 |
+
|
| 91 |
+
ALL_PROMPTS = []
|
| 92 |
+
for p in EASY_PROMPTS: ALL_PROMPTS.append({"text": p, "difficulty": "EASY"})
|
| 93 |
+
for p in MEDIUM_PROMPTS: ALL_PROMPTS.append({"text": p, "difficulty": "MEDIUM"})
|
| 94 |
+
for p in TOUGH_PROMPTS: ALL_PROMPTS.append({"text": p, "difficulty": "TOUGH"})
|
| 95 |
+
|
| 96 |
+
# UTILITIES AND DATASET LOADERS
|
| 97 |
+
class SFTDataset(Dataset):
|
| 98 |
+
def __init__(self, file_path: str, max_samples: int = -1):
|
| 99 |
+
self.samples = []
|
| 100 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
| 101 |
+
for line in f:
|
| 102 |
+
if 0 < max_samples <= len(self.samples):
|
| 103 |
+
break
|
| 104 |
+
self.samples.append(json.loads(line))
|
| 105 |
+
print(f"Loaded {len(self.samples)} SFT samples from {file_path}")
|
| 106 |
+
|
| 107 |
+
def __len__(self) -> int:
|
| 108 |
+
return len(self.samples)
|
| 109 |
+
|
| 110 |
+
def __getitem__(self, idx: int) -> dict:
|
| 111 |
+
return self.samples[idx]
|
| 112 |
+
|
| 113 |
+
def pack_sequences(samples: list[dict], pack_length: int, pad_token_id: int, eos_token_id: int) -> list[dict]:
|
| 114 |
+
"""Sort and pack short samples into fixed-size bins (FFD packing) to accelerate training."""
|
| 115 |
+
print(f"Packing sequences into {pack_length}-token bins...")
|
| 116 |
+
# Sort samples by input_ids length descending
|
| 117 |
+
indexed_samples = sorted(
|
| 118 |
+
samples,
|
| 119 |
+
key=lambda x: len(x["input_ids"]),
|
| 120 |
+
reverse=True
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
bins: list[list[dict]] = []
|
| 124 |
+
bin_lengths: list[int] = []
|
| 125 |
+
|
| 126 |
+
for sample in indexed_samples:
|
| 127 |
+
s_len = len(sample["input_ids"])
|
| 128 |
+
if s_len > pack_length:
|
| 129 |
+
sample["input_ids"] = sample["input_ids"][:pack_length]
|
| 130 |
+
sample["loss_mask"] = sample["loss_mask"][:pack_length]
|
| 131 |
+
s_len = pack_length
|
| 132 |
+
|
| 133 |
+
# Try to place sample into an existing bin
|
| 134 |
+
placed = False
|
| 135 |
+
for b_idx in range(len(bins)):
|
| 136 |
+
needed = s_len + (1 if len(bins[b_idx]) > 0 else 0)
|
| 137 |
+
if bin_lengths[b_idx] + needed <= pack_length:
|
| 138 |
+
bins[b_idx].append(sample)
|
| 139 |
+
bin_lengths[b_idx] += needed
|
| 140 |
+
placed = True
|
| 141 |
+
break
|
| 142 |
+
|
| 143 |
+
if not placed:
|
| 144 |
+
bins.append([sample])
|
| 145 |
+
bin_lengths.append(s_len)
|
| 146 |
+
|
| 147 |
+
# Convert packed bins to training formats
|
| 148 |
+
packed_samples = []
|
| 149 |
+
for b in bins:
|
| 150 |
+
input_ids = []
|
| 151 |
+
loss_mask = []
|
| 152 |
+
for i, sample in enumerate(b):
|
| 153 |
+
if i > 0:
|
| 154 |
+
input_ids.append(eos_token_id)
|
| 155 |
+
loss_mask.append(0) # Mask out the EOS separator token
|
| 156 |
+
input_ids.extend(sample["input_ids"])
|
| 157 |
+
loss_mask.extend(sample["loss_mask"])
|
| 158 |
+
|
| 159 |
+
real_len = len(input_ids)
|
| 160 |
+
pad_len = pack_length - real_len
|
| 161 |
+
if pad_len > 0:
|
| 162 |
+
input_ids.extend([pad_token_id] * pad_len)
|
| 163 |
+
loss_mask.extend([0] * pad_len)
|
| 164 |
+
|
| 165 |
+
packed_samples.append({
|
| 166 |
+
"input_ids": torch.tensor(input_ids, dtype=torch.long),
|
| 167 |
+
"loss_mask": torch.tensor(loss_mask, dtype=torch.long),
|
| 168 |
+
"attention_mask": torch.cat([
|
| 169 |
+
torch.ones(real_len, dtype=torch.long),
|
| 170 |
+
torch.zeros(pad_len, dtype=torch.long)
|
| 171 |
+
])
|
| 172 |
+
})
|
| 173 |
+
|
| 174 |
+
utilization = sum(bin_lengths) / (len(bins) * pack_length)
|
| 175 |
+
print(f"Packed {len(samples)} samples into {len(bins)} bins. Utilization: {utilization * 100:.2f}%")
|
| 176 |
+
return packed_samples
|
| 177 |
+
|
| 178 |
+
def collate_sft(batch: list[dict], pad_token_id: int) -> dict:
|
| 179 |
+
"""Collates batch for standard unpacked training, dynamically padding batch to max length."""
|
| 180 |
+
max_len = max(len(s["input_ids"]) for s in batch)
|
| 181 |
+
input_ids_list = []
|
| 182 |
+
attention_mask_list = []
|
| 183 |
+
labels_list = []
|
| 184 |
+
|
| 185 |
+
for s in batch:
|
| 186 |
+
ids = s["input_ids"]
|
| 187 |
+
mask = s["loss_mask"]
|
| 188 |
+
pad_len = max_len - len(ids)
|
| 189 |
+
|
| 190 |
+
padded_ids = ids + [pad_token_id] * pad_len
|
| 191 |
+
padded_labels = [ids[i] if mask[i] == 1 else -100 for i in range(len(ids))] + [-100] * pad_len
|
| 192 |
+
|
| 193 |
+
input_ids_list.append(torch.tensor(padded_ids, dtype=torch.long))
|
| 194 |
+
attention_mask_list.append(torch.tensor([1] * len(ids) + [0] * pad_len, dtype=torch.long))
|
| 195 |
+
labels_list.append(torch.tensor(padded_labels, dtype=torch.long))
|
| 196 |
+
|
| 197 |
+
return {
|
| 198 |
+
"input_ids": torch.stack(input_ids_list),
|
| 199 |
+
"attention_mask": torch.stack(attention_mask_list),
|
| 200 |
+
"labels": torch.stack(labels_list)
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
def collate_packed(batch: list[dict]) -> dict:
|
| 204 |
+
"""Collates pre-packed sequence bins by simple stacking."""
|
| 205 |
+
input_ids = torch.stack([item["input_ids"] for item in batch])
|
| 206 |
+
attention_mask = torch.stack([item["attention_mask"] for item in batch])
|
| 207 |
+
loss_mask = torch.stack([item["loss_mask"] for item in batch])
|
| 208 |
+
|
| 209 |
+
labels = input_ids.clone()
|
| 210 |
+
labels = labels.masked_fill(loss_mask == 0, -100)
|
| 211 |
+
|
| 212 |
+
return {
|
| 213 |
+
"input_ids": input_ids,
|
| 214 |
+
"attention_mask": attention_mask,
|
| 215 |
+
"labels": labels
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
# PARSING AND MAIN LOGIC
|
| 219 |
+
def parse_args() -> argparse.Namespace:
|
| 220 |
+
parser = argparse.ArgumentParser(description="Clean SFT training and evaluation suite")
|
| 221 |
+
parser.add_argument("--student_model", type=str, default=cfg.get("model", {}).get("student", "Qwen/Qwen3-1.7B-Base"))
|
| 222 |
+
parser.add_argument("--tokenizer_model", type=str, default=cfg.get("model", {}).get("tokenizer", "Qwen/Qwen3-1.7B"))
|
| 223 |
+
parser.add_argument("--data_repo", type=str, default=os.environ.get("QUINTUS_SFT_DATA_REPO"), help="HF dataset repo containing train_sft.jsonl. Optional when data/tokenized/train_sft.jsonl exists.")
|
| 224 |
+
parser.add_argument("--token", type=str, default=None)
|
| 225 |
+
parser.add_argument("--trust_remote_code", action="store_true", help="Allow custom code from model/tokenizer repositories.")
|
| 226 |
+
|
| 227 |
+
parser.add_argument("--num_epochs", type=int, default=1)
|
| 228 |
+
parser.add_argument("--learning_rate", type=float, default=2e-5)
|
| 229 |
+
parser.add_argument("--micro_batch_size", type=int, default=4)
|
| 230 |
+
parser.add_argument("--grad_accum_steps", type=int, default=2)
|
| 231 |
+
parser.add_argument("--max_seq_len", type=int, default=4096)
|
| 232 |
+
parser.add_argument("--sequence_packing", action="store_true", default=True)
|
| 233 |
+
parser.add_argument("--no_sequence_packing", action="store_false", dest="sequence_packing")
|
| 234 |
+
|
| 235 |
+
parser.add_argument("--output_dir", type=str, default="quintus_sft_output")
|
| 236 |
+
|
| 237 |
+
parser.add_argument("--run_prompt_suite", action="store_true", default=True)
|
| 238 |
+
parser.add_argument("--no_prompt_suite", action="store_false", dest="run_prompt_suite")
|
| 239 |
+
parser.add_argument("--run_gsm8k", action="store_true", default=True)
|
| 240 |
+
parser.add_argument("--no_gsm8k", action="store_false", dest="run_gsm8k")
|
| 241 |
+
parser.add_argument("--gsm8k_samples", type=int, default=100)
|
| 242 |
+
|
| 243 |
+
parser.add_argument("--optim", type=str, choices=["adamw", "adamw_8bit"], default="adamw")
|
| 244 |
+
parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
|
| 245 |
+
parser.add_argument("--load_in_4bit", action="store_true", default=False)
|
| 246 |
+
parser.add_argument("--use_lora", action="store_true", default=False)
|
| 247 |
+
parser.add_argument("--lora_r", type=int, default=8)
|
| 248 |
+
parser.add_argument("--lora_alpha", type=int, default=16)
|
| 249 |
+
|
| 250 |
+
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Automatically push fine-tuned model to Hugging Face Hub after training")
|
| 251 |
+
parser.add_argument("--hub_model_id", type=str, default="iamrahulreddy/Quintus", help="Target Hugging Face Hub repository ID")
|
| 252 |
+
|
| 253 |
+
return parser.parse_args()
|
| 254 |
+
|
| 255 |
+
def download_hf_dataset(repo_id: str | None, token: str | None) -> str:
|
| 256 |
+
print(f"Checking for tokenized dataset in local folders...")
|
| 257 |
+
local_path = "data/tokenized/train_sft.jsonl"
|
| 258 |
+
if os.path.exists(local_path):
|
| 259 |
+
print(f"Found local dataset: {local_path}")
|
| 260 |
+
return local_path
|
| 261 |
+
|
| 262 |
+
if not repo_id:
|
| 263 |
+
raise ValueError(
|
| 264 |
+
"No local SFT dataset found at data/tokenized/train_sft.jsonl. "
|
| 265 |
+
"Pass --data_repo or set QUINTUS_SFT_DATA_REPO."
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
print(f"Local file not found. Pulling from Hugging Face: {repo_id}...")
|
| 269 |
+
from huggingface_hub import hf_hub_download
|
| 270 |
+
os.makedirs("data/tokenized", exist_ok=True)
|
| 271 |
+
downloaded = hf_hub_download(
|
| 272 |
+
repo_id=repo_id,
|
| 273 |
+
filename="train_sft.jsonl",
|
| 274 |
+
repo_type="dataset",
|
| 275 |
+
local_dir="data/tokenized",
|
| 276 |
+
token=token
|
| 277 |
+
)
|
| 278 |
+
# Ensure correct local path layout
|
| 279 |
+
if os.path.exists(downloaded) and downloaded != local_path:
|
| 280 |
+
os.rename(downloaded, local_path)
|
| 281 |
+
print(f"Dataset downloaded to: {local_path}")
|
| 282 |
+
return local_path
|
| 283 |
+
|
| 284 |
+
# DOWNSTREAM EVALUATION CODE
|
| 285 |
+
def run_prompt_suite(model, tokenizer, device, output_dir: str):
|
| 286 |
+
print("\n" + "="*70)
|
| 287 |
+
print("RUNNING QUALITATIVE PROMPT SUITE (50 Prompts)")
|
| 288 |
+
print("="*70)
|
| 289 |
+
|
| 290 |
+
# Compile stop token IDs
|
| 291 |
+
eos_token_ids = [tokenizer.eos_token_id]
|
| 292 |
+
for token in ["<|im_end|>", "<|endoftext|>", "<|im_start|>"]:
|
| 293 |
+
t_id = tokenizer.convert_tokens_to_ids(token)
|
| 294 |
+
if t_id is not None and t_id != tokenizer.unk_token_id:
|
| 295 |
+
eos_token_ids.append(t_id)
|
| 296 |
+
eos_token_ids = list(set(eos_token_ids))
|
| 297 |
+
|
| 298 |
+
# Initialize output file
|
| 299 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 300 |
+
out_path = os.path.join(output_dir, f"prompt_suite_eval_{timestamp}.txt")
|
| 301 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 302 |
+
|
| 303 |
+
with open(out_path, "w", encoding="utf-8") as f:
|
| 304 |
+
f.write("QUINTUS SFT POST-TRAINING PROMPT SUITE\n")
|
| 305 |
+
f.write(f"Timestamp: {timestamp}\n")
|
| 306 |
+
f.write("="*72 + "\n\n")
|
| 307 |
+
f.flush()
|
| 308 |
+
|
| 309 |
+
# Set padding side to left for batch generation
|
| 310 |
+
orig_padding_side = tokenizer.padding_side
|
| 311 |
+
tokenizer.padding_side = "left"
|
| 312 |
+
|
| 313 |
+
batch_size = 16
|
| 314 |
+
for i in range(0, len(ALL_PROMPTS), batch_size):
|
| 315 |
+
batch_items = ALL_PROMPTS[i : i + batch_size]
|
| 316 |
+
|
| 317 |
+
# Format prompts
|
| 318 |
+
formatted_prompts = []
|
| 319 |
+
for item in batch_items:
|
| 320 |
+
prompt_text = item["text"]
|
| 321 |
+
if tokenizer.chat_template is not None:
|
| 322 |
+
prompt_str = tokenizer.apply_chat_template(
|
| 323 |
+
[{"role": "user", "content": prompt_text}],
|
| 324 |
+
tokenize=False, add_generation_prompt=True
|
| 325 |
+
)
|
| 326 |
+
else:
|
| 327 |
+
prompt_str = f"<|im_start|>user\n{prompt_text}<|im_end|>\n<|im_start|>assistant\n"
|
| 328 |
+
formatted_prompts.append(prompt_str)
|
| 329 |
+
|
| 330 |
+
# Tokenize with padding
|
| 331 |
+
inputs = tokenizer(formatted_prompts, padding=True, return_tensors="pt").to(device)
|
| 332 |
+
|
| 333 |
+
with torch.no_grad():
|
| 334 |
+
outputs = model.generate(
|
| 335 |
+
**inputs,
|
| 336 |
+
max_new_tokens=2048,
|
| 337 |
+
do_sample=False, # Greedy for clean, reproducible comparison
|
| 338 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 339 |
+
eos_token_id=eos_token_ids
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
# Decode and write results in real-time
|
| 343 |
+
for idx, item in enumerate(batch_items):
|
| 344 |
+
input_len = inputs["input_ids"][idx].shape[0]
|
| 345 |
+
gen_tokens = outputs[idx][input_len:]
|
| 346 |
+
|
| 347 |
+
# Slice at the first EOS token
|
| 348 |
+
eos_indices = []
|
| 349 |
+
for eos_id in eos_token_ids:
|
| 350 |
+
indices = (gen_tokens == eos_id).nonzero(as_tuple=True)[0]
|
| 351 |
+
if len(indices) > 0:
|
| 352 |
+
eos_indices.append(indices[0].item())
|
| 353 |
+
if eos_indices:
|
| 354 |
+
gen_tokens = gen_tokens[:min(eos_indices)]
|
| 355 |
+
|
| 356 |
+
response = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
|
| 357 |
+
|
| 358 |
+
# Log progress
|
| 359 |
+
global_idx = i + idx + 1
|
| 360 |
+
print(f"[{global_idx:02d}/50] ({item['difficulty']}) Q: {item['text'][:40]}... -> Answered ({len(gen_tokens)} tokens)")
|
| 361 |
+
|
| 362 |
+
# Append directly to output file
|
| 363 |
+
with open(out_path, "a", encoding="utf-8") as f:
|
| 364 |
+
f.write(f"[{global_idx:02d}/50] {item['difficulty']}\n")
|
| 365 |
+
f.write(f"Q: {item['text']}\n\n")
|
| 366 |
+
f.write(f"Response:\n{response}\n")
|
| 367 |
+
f.write("\n" + "-"*72 + "\n\n")
|
| 368 |
+
f.flush()
|
| 369 |
+
|
| 370 |
+
# Restore original tokenizer settings
|
| 371 |
+
tokenizer.padding_side = orig_padding_side
|
| 372 |
+
print(f"\nPrompt suite evaluation complete. Saved report to: {out_path}\n")
|
| 373 |
+
|
| 374 |
+
def extract_gsm8k_answer(text: str) -> str | None:
|
| 375 |
+
text = text.replace(",", "")
|
| 376 |
+
match = re.findall(r"The answer is\s*:?\s*(-?\d+)", text, re.IGNORECASE)
|
| 377 |
+
if match:
|
| 378 |
+
return match[-1]
|
| 379 |
+
match = re.findall(r"(-?\d+)", text)
|
| 380 |
+
if match:
|
| 381 |
+
return match[-1]
|
| 382 |
+
return None
|
| 383 |
+
|
| 384 |
+
def run_gsm8k_eval(model, tokenizer, device, num_samples: int = 100):
|
| 385 |
+
print("\n" + "="*70)
|
| 386 |
+
print(f"RUNNING GSM8K MATH EVALUATION ({num_samples} Samples)")
|
| 387 |
+
print("="*70)
|
| 388 |
+
|
| 389 |
+
from datasets import load_dataset
|
| 390 |
+
try:
|
| 391 |
+
dataset = load_dataset("openai/gsm8k", "main", split="test")
|
| 392 |
+
except Exception as e:
|
| 393 |
+
print(f"Warning: Could not download GSM8K test set directly: {e}")
|
| 394 |
+
return
|
| 395 |
+
|
| 396 |
+
dataset = dataset.shuffle(seed=42).select(range(min(num_samples, len(dataset))))
|
| 397 |
+
|
| 398 |
+
correct = 0
|
| 399 |
+
total = 0
|
| 400 |
+
|
| 401 |
+
for idx, item in enumerate(dataset):
|
| 402 |
+
question = item["question"]
|
| 403 |
+
answer = item["answer"]
|
| 404 |
+
|
| 405 |
+
target_match = re.search(r"####\s*(-?\d+)", answer)
|
| 406 |
+
if not target_match:
|
| 407 |
+
continue
|
| 408 |
+
target_val = target_match.group(1)
|
| 409 |
+
|
| 410 |
+
if tokenizer.chat_template is not None:
|
| 411 |
+
prompt = tokenizer.apply_chat_template(
|
| 412 |
+
[{"role": "user", "content": question + "\nShow your work and conclude with 'The answer is: <number>'."}],
|
| 413 |
+
tokenize=False, add_generation_prompt=True
|
| 414 |
+
)
|
| 415 |
+
else:
|
| 416 |
+
prompt = f"<|im_start|>user\n{question}\nShow your work and conclude with 'The answer is: <number>'.<|im_end|>\n<|im_start|>assistant\n"
|
| 417 |
+
|
| 418 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
| 419 |
+
|
| 420 |
+
with torch.no_grad():
|
| 421 |
+
outputs = model.generate(
|
| 422 |
+
**inputs,
|
| 423 |
+
max_new_tokens=1024,
|
| 424 |
+
do_sample=False,
|
| 425 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 426 |
+
eos_token_id=tokenizer.eos_token_id
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
gen_tokens = outputs[0][inputs.input_ids.shape[1]:]
|
| 430 |
+
generated_text = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
|
| 431 |
+
|
| 432 |
+
pred_val = extract_gsm8k_answer(generated_text)
|
| 433 |
+
is_match = (pred_val == target_val)
|
| 434 |
+
|
| 435 |
+
if is_match:
|
| 436 |
+
correct += 1
|
| 437 |
+
total += 1
|
| 438 |
+
|
| 439 |
+
# Log sample output periodically
|
| 440 |
+
if idx % 10 == 0:
|
| 441 |
+
print(f"\n[GSM8K Sample {idx+1}]")
|
| 442 |
+
print(f"Q: {question[:80]}...")
|
| 443 |
+
print(f"A: {generated_text[:120]}... (Target: {target_val} | Pred: {pred_val})")
|
| 444 |
+
print(f"Match: {is_match}")
|
| 445 |
+
|
| 446 |
+
accuracy = (correct / total * 100) if total > 0 else 0
|
| 447 |
+
print("\n" + "="*70)
|
| 448 |
+
print(f"GSM8K EVALUATION SUMMARY: {correct}/{total} Correct -> Accuracy: {accuracy:.2f}%")
|
| 449 |
+
print("="*70 + "\n")
|
| 450 |
+
|
| 451 |
+
# TRAINING PIPELINE
|
| 452 |
+
def main() -> None:
|
| 453 |
+
args = parse_args()
|
| 454 |
+
|
| 455 |
+
# Propagate HF token to environment for auto-authentication of downstream hub calls
|
| 456 |
+
try:
|
| 457 |
+
import huggingface_hub
|
| 458 |
+
cached_token = huggingface_hub.get_token()
|
| 459 |
+
except Exception:
|
| 460 |
+
cached_token = None
|
| 461 |
+
resolved_token = os.environ.get("HF_TOKEN") or cached_token or args.token
|
| 462 |
+
if resolved_token:
|
| 463 |
+
os.environ["HF_TOKEN"] = resolved_token
|
| 464 |
+
args.token = resolved_token
|
| 465 |
+
|
| 466 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 467 |
+
print(f"SFT Environment initialized. Target device: {device}")
|
| 468 |
+
|
| 469 |
+
# 1. Pull dataset from HF
|
| 470 |
+
try:
|
| 471 |
+
dataset_file = download_hf_dataset(args.data_repo, args.token)
|
| 472 |
+
except ValueError as exc:
|
| 473 |
+
print(f"Error: {exc}")
|
| 474 |
+
sys.exit(1)
|
| 475 |
+
|
| 476 |
+
# 2. Setup Tokenizer and Model
|
| 477 |
+
print(f"Loading tokenizer: {args.tokenizer_model}")
|
| 478 |
+
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_model, trust_remote_code=args.trust_remote_code)
|
| 479 |
+
if tokenizer.pad_token is None:
|
| 480 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 481 |
+
|
| 482 |
+
# 4-bit configuration if requested
|
| 483 |
+
bnb_config = None
|
| 484 |
+
if args.load_in_4bit:
|
| 485 |
+
from transformers import BitsAndBytesConfig
|
| 486 |
+
bnb_config = BitsAndBytesConfig(
|
| 487 |
+
load_in_4bit=True,
|
| 488 |
+
bnb_4bit_compute_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
|
| 489 |
+
bnb_4bit_quant_type="nf4",
|
| 490 |
+
bnb_4bit_use_double_quant=True
|
| 491 |
+
)
|
| 492 |
+
print("Using 4-bit BitsAndBytes quantization.")
|
| 493 |
+
|
| 494 |
+
# Liger Kernel (skipped for 4-bit/PEFT as it can interfere with quantized layers)
|
| 495 |
+
if not args.load_in_4bit:
|
| 496 |
+
try:
|
| 497 |
+
from liger_kernel.transformers import apply_liger_kernel_to_qwen3
|
| 498 |
+
apply_liger_kernel_to_qwen3(
|
| 499 |
+
rope=True,
|
| 500 |
+
swiglu=True,
|
| 501 |
+
rms_norm=True,
|
| 502 |
+
cross_entropy=False,
|
| 503 |
+
fused_linear_cross_entropy=False,
|
| 504 |
+
)
|
| 505 |
+
print("Liger Kernel optimizations applied successfully.")
|
| 506 |
+
except ImportError:
|
| 507 |
+
print("Liger Kernel not installed, skipping optimizations.")
|
| 508 |
+
|
| 509 |
+
attn_impl = "sdpa"
|
| 510 |
+
if device.type == "cuda":
|
| 511 |
+
try:
|
| 512 |
+
import flash_attn
|
| 513 |
+
attn_impl = "flash_attention_2"
|
| 514 |
+
print("FlashAttention-2 enabled.")
|
| 515 |
+
except ImportError:
|
| 516 |
+
print("flash-attn not installed, falling back to SDPA.")
|
| 517 |
+
|
| 518 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 519 |
+
args.student_model,
|
| 520 |
+
quantization_config=bnb_config,
|
| 521 |
+
dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
|
| 522 |
+
trust_remote_code=args.trust_remote_code,
|
| 523 |
+
attn_implementation=attn_impl
|
| 524 |
+
)
|
| 525 |
+
if not args.load_in_4bit:
|
| 526 |
+
model = model.to(device)
|
| 527 |
+
model.config.use_cache = False
|
| 528 |
+
|
| 529 |
+
# Wrap with LoRA if requested or required for 4-bit training
|
| 530 |
+
if args.use_lora or args.load_in_4bit:
|
| 531 |
+
try:
|
| 532 |
+
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
| 533 |
+
if args.load_in_4bit:
|
| 534 |
+
model = prepare_model_for_kbit_training(model)
|
| 535 |
+
|
| 536 |
+
peft_config = LoraConfig(
|
| 537 |
+
r=args.lora_r,
|
| 538 |
+
lora_alpha=args.lora_alpha,
|
| 539 |
+
target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
| 540 |
+
lora_dropout=0.05,
|
| 541 |
+
bias="none",
|
| 542 |
+
task_type="CAUSAL_LM"
|
| 543 |
+
)
|
| 544 |
+
model = get_peft_model(model, peft_config)
|
| 545 |
+
print("LoRA adapters successfully attached to target modules.")
|
| 546 |
+
model.print_trainable_parameters()
|
| 547 |
+
except ImportError:
|
| 548 |
+
print("Error: peft not installed. Please run `!pip install -q peft` to use LoRA/QLoRA.")
|
| 549 |
+
sys.exit(1)
|
| 550 |
+
|
| 551 |
+
if args.gradient_checkpointing:
|
| 552 |
+
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
| 553 |
+
print("Gradient checkpointing enabled.")
|
| 554 |
+
|
| 555 |
+
# 3. Prepare dataset
|
| 556 |
+
raw_dataset = SFTDataset(dataset_file)
|
| 557 |
+
|
| 558 |
+
if args.sequence_packing:
|
| 559 |
+
packed_samples = pack_sequences(
|
| 560 |
+
raw_dataset.samples,
|
| 561 |
+
pack_length=args.max_seq_len,
|
| 562 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 563 |
+
eos_token_id=tokenizer.eos_token_id
|
| 564 |
+
)
|
| 565 |
+
train_dataloader = DataLoader(
|
| 566 |
+
packed_samples,
|
| 567 |
+
batch_size=args.micro_batch_size,
|
| 568 |
+
shuffle=True,
|
| 569 |
+
collate_fn=collate_packed
|
| 570 |
+
)
|
| 571 |
+
else:
|
| 572 |
+
train_dataloader = DataLoader(
|
| 573 |
+
raw_dataset,
|
| 574 |
+
batch_size=args.micro_batch_size,
|
| 575 |
+
shuffle=True,
|
| 576 |
+
collate_fn=lambda b: collate_sft(b, tokenizer.pad_token_id)
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
# 4. Optimizer and scheduler setup
|
| 580 |
+
if args.optim == "adamw_8bit":
|
| 581 |
+
try:
|
| 582 |
+
import bitsandbytes as bnb
|
| 583 |
+
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=args.learning_rate, weight_decay=0.1)
|
| 584 |
+
print("Using BitsAndBytes 8-bit AdamW optimizer.")
|
| 585 |
+
except ImportError:
|
| 586 |
+
print("Warning: bitsandbytes not installed. Falling back to standard AdamW.")
|
| 587 |
+
use_fused = (device.type == "cuda")
|
| 588 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0.1, fused=use_fused)
|
| 589 |
+
else:
|
| 590 |
+
use_fused = (device.type == "cuda")
|
| 591 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0.1, fused=use_fused)
|
| 592 |
+
print(f"Using standard AdamW optimizer (fused={use_fused}).")
|
| 593 |
+
steps_per_epoch = (len(train_dataloader) + args.grad_accum_steps - 1) // args.grad_accum_steps
|
| 594 |
+
total_steps = steps_per_epoch * args.num_epochs
|
| 595 |
+
warmup_steps = int(total_steps * 0.05)
|
| 596 |
+
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
|
| 597 |
+
|
| 598 |
+
# 5. Training Loop
|
| 599 |
+
print("\n" + "="*70)
|
| 600 |
+
print(f"STARTING SFT TRAINING (Epochs: {args.num_epochs} | Steps: {total_steps})")
|
| 601 |
+
print("="*70)
|
| 602 |
+
|
| 603 |
+
model.train()
|
| 604 |
+
step = 0
|
| 605 |
+
total_tokens_processed = 0
|
| 606 |
+
t0 = time.time()
|
| 607 |
+
|
| 608 |
+
for epoch in range(args.num_epochs):
|
| 609 |
+
epoch_loss = 0.0
|
| 610 |
+
for batch_idx, batch in enumerate(train_dataloader):
|
| 611 |
+
input_ids = batch["input_ids"].to(device)
|
| 612 |
+
attention_mask = batch["attention_mask"].to(device)
|
| 613 |
+
labels = batch["labels"].to(device)
|
| 614 |
+
|
| 615 |
+
# Accumulate the number of active (non-padded) tokens processed
|
| 616 |
+
total_tokens_processed += attention_mask.sum().item()
|
| 617 |
+
|
| 618 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
| 619 |
+
loss = outputs.loss / args.grad_accum_steps
|
| 620 |
+
loss.backward()
|
| 621 |
+
|
| 622 |
+
epoch_loss += loss.item() * args.grad_accum_steps
|
| 623 |
+
|
| 624 |
+
if (batch_idx + 1) % args.grad_accum_steps == 0 or (batch_idx + 1) == len(train_dataloader):
|
| 625 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 626 |
+
optimizer.step()
|
| 627 |
+
scheduler.step()
|
| 628 |
+
optimizer.zero_grad()
|
| 629 |
+
step += 1
|
| 630 |
+
|
| 631 |
+
if step % 5 == 0 or step == total_steps:
|
| 632 |
+
elapsed = time.time() - t0
|
| 633 |
+
tokens_per_sec = total_tokens_processed / max(elapsed, 1e-5)
|
| 634 |
+
print(
|
| 635 |
+
f"Epoch {epoch+1}/{args.num_epochs} | "
|
| 636 |
+
f"Step {step}/{total_steps} | "
|
| 637 |
+
f"Loss: {loss.item() * args.grad_accum_steps:.4f} | "
|
| 638 |
+
f"LR: {scheduler.get_last_lr()[0]:.2e} | "
|
| 639 |
+
f"Tokens: {total_tokens_processed} | "
|
| 640 |
+
f"Speed: {tokens_per_sec:.2f} tokens/s"
|
| 641 |
+
)
|
| 642 |
+
|
| 643 |
+
# 6. Save model weights and tokenizer
|
| 644 |
+
print(f"\nTraining complete in {time.time() - t0:.1f}s. Saving weights to: {args.output_dir}")
|
| 645 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 646 |
+
if hasattr(model, "merge_and_unload") and not args.load_in_4bit:
|
| 647 |
+
print("Merging LoRA adapters into base weights...")
|
| 648 |
+
try:
|
| 649 |
+
merged_model = model.merge_and_unload()
|
| 650 |
+
merged_model.save_pretrained(args.output_dir)
|
| 651 |
+
print("Merged model weights saved successfully.")
|
| 652 |
+
except Exception as e:
|
| 653 |
+
print(f"Failed to merge and unload: {e}. Saving adapter weights only.")
|
| 654 |
+
model.save_pretrained(args.output_dir)
|
| 655 |
+
else:
|
| 656 |
+
model.save_pretrained(args.output_dir)
|
| 657 |
+
tokenizer.save_pretrained(args.output_dir)
|
| 658 |
+
print("Weights and configuration saved successfully.")
|
| 659 |
+
|
| 660 |
+
# 7. SFT Downstream Evaluations
|
| 661 |
+
model.eval()
|
| 662 |
+
|
| 663 |
+
if args.run_prompt_suite:
|
| 664 |
+
run_prompt_suite(model, tokenizer, device, args.output_dir)
|
| 665 |
+
|
| 666 |
+
if args.run_gsm8k:
|
| 667 |
+
run_gsm8k_eval(model, tokenizer, device, num_samples=args.gsm8k_samples)
|
| 668 |
+
|
| 669 |
+
if args.push_to_hub:
|
| 670 |
+
print(f"\nUploading fine-tuned model and tokenizer to Hugging Face Hub: {args.hub_model_id}...")
|
| 671 |
+
try:
|
| 672 |
+
from huggingface_hub import create_repo, HfApi
|
| 673 |
+
token_val = args.token or os.environ.get("HF_TOKEN")
|
| 674 |
+
create_repo(repo_id=args.hub_model_id, token=token_val, exist_ok=True)
|
| 675 |
+
|
| 676 |
+
api = HfApi()
|
| 677 |
+
api.upload_folder(
|
| 678 |
+
folder_path=args.output_dir,
|
| 679 |
+
repo_id=args.hub_model_id,
|
| 680 |
+
repo_type="model",
|
| 681 |
+
token=token_val
|
| 682 |
+
)
|
| 683 |
+
print("Successfully uploaded model and tokenizer to Hugging Face Hub!")
|
| 684 |
+
except Exception as hub_err:
|
| 685 |
+
print(f"Failed to push to Hub: {hub_err}")
|
| 686 |
+
|
| 687 |
+
print("Pipeline Execution Complete. Model is ready.")
|
| 688 |
+
|
| 689 |
+
if __name__ == "__main__":
|
| 690 |
+
main()
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# src package
|
src/checkpoints.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from configs import cfg
|
| 11 |
+
|
| 12 |
+
def checkpoint_rank(path: str) -> tuple[int, int]:
|
| 13 |
+
name = os.path.basename(path)
|
| 14 |
+
prefix, _, raw_value = name.partition("_")
|
| 15 |
+
try:
|
| 16 |
+
value = int(raw_value)
|
| 17 |
+
except ValueError:
|
| 18 |
+
value = -1
|
| 19 |
+
if prefix == "epoch":
|
| 20 |
+
return (2, value)
|
| 21 |
+
if prefix == "step":
|
| 22 |
+
return (1, value)
|
| 23 |
+
return (0, value)
|
| 24 |
+
|
| 25 |
+
def find_latest_training_checkpoint(output_dir: str) -> str | None:
|
| 26 |
+
candidates = []
|
| 27 |
+
for pattern in ("epoch_*", "step_*"):
|
| 28 |
+
candidates.extend(str(path) for path in Path(output_dir).glob(pattern) if path.is_dir())
|
| 29 |
+
if not candidates:
|
| 30 |
+
return None
|
| 31 |
+
return max(candidates, key=checkpoint_rank)
|
| 32 |
+
|
| 33 |
+
def load_trainer_state(checkpoint_dir: str, log) -> dict:
|
| 34 |
+
state_path = os.path.join(checkpoint_dir, "trainer_state.json")
|
| 35 |
+
if os.path.exists(state_path):
|
| 36 |
+
try:
|
| 37 |
+
with open(state_path, "r", encoding="utf-8") as f:
|
| 38 |
+
state = json.load(f)
|
| 39 |
+
if isinstance(state, dict):
|
| 40 |
+
return state
|
| 41 |
+
except (OSError, json.JSONDecodeError) as exc:
|
| 42 |
+
log.warning(f"Could not read trainer_state.json from {checkpoint_dir}: {exc}")
|
| 43 |
+
|
| 44 |
+
name = os.path.basename(checkpoint_dir)
|
| 45 |
+
prefix, _, raw_value = name.partition("_")
|
| 46 |
+
try:
|
| 47 |
+
value = int(raw_value)
|
| 48 |
+
except ValueError:
|
| 49 |
+
value = 0
|
| 50 |
+
if prefix == "epoch":
|
| 51 |
+
return {
|
| 52 |
+
"checkpoint_type": "epoch",
|
| 53 |
+
"start_epoch": value,
|
| 54 |
+
"global_step": 0,
|
| 55 |
+
"micro_step_global": 0,
|
| 56 |
+
"next_batch_in_epoch": 0,
|
| 57 |
+
}
|
| 58 |
+
if prefix == "step":
|
| 59 |
+
return {
|
| 60 |
+
"checkpoint_type": "step",
|
| 61 |
+
"start_epoch": 0,
|
| 62 |
+
"global_step": value,
|
| 63 |
+
"micro_step_global": 0,
|
| 64 |
+
"next_batch_in_epoch": 0,
|
| 65 |
+
}
|
| 66 |
+
return {}
|
| 67 |
+
|
| 68 |
+
def packing_checkpoint_metadata(enabled: bool, pack_length: int | None, max_seq_len: int) -> dict[str, int | bool | None]:
|
| 69 |
+
return {
|
| 70 |
+
"sequence_packing_enabled": bool(enabled),
|
| 71 |
+
"sequence_packing_pack_length": int(pack_length) if enabled and pack_length is not None else None,
|
| 72 |
+
"data_max_seq_len": int(max_seq_len),
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
def validate_resume_packing_state(
|
| 76 |
+
trainer_state: dict,
|
| 77 |
+
*,
|
| 78 |
+
enabled: bool,
|
| 79 |
+
pack_length: int,
|
| 80 |
+
max_seq_len: int,
|
| 81 |
+
log,
|
| 82 |
+
) -> None:
|
| 83 |
+
checkpoint_enabled = bool(trainer_state.get("sequence_packing_enabled", False))
|
| 84 |
+
if checkpoint_enabled != bool(enabled):
|
| 85 |
+
log.error(
|
| 86 |
+
"Checkpoint sequence-packing state does not match the current run: "
|
| 87 |
+
f"checkpoint={checkpoint_enabled}, current={bool(enabled)}."
|
| 88 |
+
)
|
| 89 |
+
raise SystemExit(1)
|
| 90 |
+
|
| 91 |
+
if checkpoint_enabled:
|
| 92 |
+
checkpoint_pack_length = trainer_state.get("sequence_packing_pack_length")
|
| 93 |
+
try:
|
| 94 |
+
checkpoint_pack_length = int(checkpoint_pack_length)
|
| 95 |
+
except (TypeError, ValueError):
|
| 96 |
+
log.error("Checkpoint is missing a valid sequence_packing_pack_length value.")
|
| 97 |
+
raise SystemExit(1)
|
| 98 |
+
if checkpoint_pack_length != int(pack_length):
|
| 99 |
+
log.error(
|
| 100 |
+
"Checkpoint pack length does not match the current run: "
|
| 101 |
+
f"checkpoint={checkpoint_pack_length}, current={int(pack_length)}."
|
| 102 |
+
)
|
| 103 |
+
raise SystemExit(1)
|
| 104 |
+
|
| 105 |
+
checkpoint_max_seq_len = trainer_state.get("data_max_seq_len")
|
| 106 |
+
if checkpoint_max_seq_len is not None:
|
| 107 |
+
try:
|
| 108 |
+
checkpoint_max_seq_len = int(checkpoint_max_seq_len)
|
| 109 |
+
except (TypeError, ValueError):
|
| 110 |
+
log.error("Checkpoint is missing a valid data_max_seq_len value.")
|
| 111 |
+
raise SystemExit(1)
|
| 112 |
+
if checkpoint_max_seq_len != int(max_seq_len):
|
| 113 |
+
log.error(
|
| 114 |
+
"Checkpoint max sequence length does not match the current run: "
|
| 115 |
+
f"checkpoint={checkpoint_max_seq_len}, current={int(max_seq_len)}."
|
| 116 |
+
)
|
| 117 |
+
raise SystemExit(1)
|
| 118 |
+
|
| 119 |
+
def save_checkpoint(
|
| 120 |
+
model,
|
| 121 |
+
tokenizer,
|
| 122 |
+
output_dir: str,
|
| 123 |
+
tag: str,
|
| 124 |
+
logger,
|
| 125 |
+
*,
|
| 126 |
+
scheduler=None,
|
| 127 |
+
trainer_state: dict | None = None,
|
| 128 |
+
) -> str:
|
| 129 |
+
save_dir = os.path.join(output_dir, tag)
|
| 130 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 131 |
+
save_start = time.time()
|
| 132 |
+
logger.info(f"[CKPT] Saving {tag} -> {save_dir}/")
|
| 133 |
+
|
| 134 |
+
model_to_save = model.module if hasattr(model, "module") else model
|
| 135 |
+
if hasattr(model_to_save, "_orig_mod"):
|
| 136 |
+
model_to_save = model_to_save._orig_mod
|
| 137 |
+
|
| 138 |
+
model_to_save.config.save_pretrained(save_dir)
|
| 139 |
+
tokenizer.save_pretrained(save_dir)
|
| 140 |
+
|
| 141 |
+
try:
|
| 142 |
+
from safetensors.torch import save_file
|
| 143 |
+
state_dict = {k: v.contiguous().cpu() for k, v in model_to_save.state_dict().items()}
|
| 144 |
+
save_file(state_dict, os.path.join(save_dir, "model.safetensors"))
|
| 145 |
+
logger.info("[CKPT] Saved via safetensors")
|
| 146 |
+
except ImportError:
|
| 147 |
+
torch.save(model_to_save.state_dict(), os.path.join(save_dir, "pytorch_model.bin"))
|
| 148 |
+
logger.info("[CKPT] Saved via torch.save")
|
| 149 |
+
|
| 150 |
+
if scheduler is not None:
|
| 151 |
+
torch.save(scheduler.state_dict(), os.path.join(save_dir, "scheduler.pt"))
|
| 152 |
+
|
| 153 |
+
if trainer_state is not None:
|
| 154 |
+
trainer_state = dict(trainer_state)
|
| 155 |
+
trainer_state.setdefault("tag", tag)
|
| 156 |
+
trainer_state.setdefault("saved_at", time.strftime("%Y-%m-%d %H:%M:%S %Z"))
|
| 157 |
+
with open(os.path.join(save_dir, "trainer_state.json"), "w", encoding="utf-8") as f:
|
| 158 |
+
json.dump(trainer_state, f, indent=2)
|
| 159 |
+
|
| 160 |
+
size_mb = sum(f.stat().st_size for f in Path(save_dir).rglob("*") if f.is_file()) / 1e6
|
| 161 |
+
save_elapsed = time.time() - save_start
|
| 162 |
+
logger.info(f"[CKPT] {tag} -> {save_dir}/ ({size_mb:.0f} MB, {save_elapsed:.1f}s)")
|
| 163 |
+
return save_dir
|
| 164 |
+
|
| 165 |
+
def read_env_flag(name: str, default: bool = False) -> bool:
|
| 166 |
+
raw = os.environ.get(name)
|
| 167 |
+
if raw is None:
|
| 168 |
+
return default
|
| 169 |
+
return raw.strip().lower() in {"1", "true", "yes", "on"}
|
| 170 |
+
|
| 171 |
+
def hub_upload_strict() -> bool:
|
| 172 |
+
strict = getattr(getattr(cfg, "hub", None), "hub_upload_strict", None)
|
| 173 |
+
if strict is None:
|
| 174 |
+
return read_env_flag("QUINTUS_HUB_UPLOAD_STRICT", False)
|
| 175 |
+
return bool(strict)
|
| 176 |
+
|
| 177 |
+
def should_upload_checkpoint_tag(tag: str) -> bool:
|
| 178 |
+
upload_regular = getattr(getattr(cfg, "hub", None), "upload_kd_checkpoints", False) or read_env_flag("QUINTUS_UPLOAD_KD_CHECKPOINTS", False)
|
| 179 |
+
upload_steps = getattr(getattr(cfg, "hub", None), "upload_step_checkpoints", False) or read_env_flag("QUINTUS_UPLOAD_STEP_CHECKPOINTS", False)
|
| 180 |
+
upload_last = getattr(getattr(cfg, "hub", None), "upload_last_checkpoint", False) or read_env_flag("QUINTUS_UPLOAD_LAST_CHECKPOINT", False)
|
| 181 |
+
if tag.startswith("step_"):
|
| 182 |
+
return upload_steps
|
| 183 |
+
if tag.startswith("epoch_"):
|
| 184 |
+
return upload_regular
|
| 185 |
+
if tag == "best":
|
| 186 |
+
return upload_regular
|
| 187 |
+
if tag == "last":
|
| 188 |
+
return upload_last or upload_regular
|
| 189 |
+
return False
|
| 190 |
+
|
| 191 |
+
def maybe_upload_checkpoint(checkpoint_dir: str, tag: str, logger) -> None:
|
| 192 |
+
if not should_upload_checkpoint_tag(tag):
|
| 193 |
+
return
|
| 194 |
+
|
| 195 |
+
token = os.environ.get("HF_TOKEN") or getattr(cfg.hub, "token", None)
|
| 196 |
+
if not token:
|
| 197 |
+
msg = "HF checkpoint upload requested, but HF_TOKEN/cfg.hub.token is missing"
|
| 198 |
+
strict = hub_upload_strict()
|
| 199 |
+
if strict:
|
| 200 |
+
raise RuntimeError(msg)
|
| 201 |
+
logger.warning(f"[CKPT] {msg}; continuing without remote backup")
|
| 202 |
+
return
|
| 203 |
+
|
| 204 |
+
repo_id = getattr(getattr(cfg, "hub", None), "repo_id", None) or os.environ.get("QUINTUS_HUB_REPO_ID") or f"{cfg.hub.username}/{cfg.hub.repo_name}"
|
| 205 |
+
base_path = getattr(getattr(cfg, "hub", None), "ckpt_path_in_repo", None) or os.environ.get("KD_CKPT_PATH_IN_REPO", "models/online_kd_3b_05b_ep3_B200_20260601")
|
| 206 |
+
base_path = base_path.strip("/")
|
| 207 |
+
path_in_repo = f"{base_path}/{tag}"
|
| 208 |
+
commit_prefix = getattr(getattr(cfg, "hub", None), "commit_message_prefix", None) or os.environ.get(
|
| 209 |
+
"KD_COMMIT_MESSAGE_PREFIX",
|
| 210 |
+
"Online KD 8B->1.7B Run",
|
| 211 |
+
)
|
| 212 |
+
commit_message = os.environ.get("KD_COMMIT_MESSAGE") or f"{commit_prefix}: upload {tag}"
|
| 213 |
+
upload_start = time.time()
|
| 214 |
+
size_mb = sum(f.stat().st_size for f in Path(checkpoint_dir).rglob("*") if f.is_file()) / 1e6
|
| 215 |
+
strict = hub_upload_strict()
|
| 216 |
+
logger.info(
|
| 217 |
+
f"[CKPT] Uploading {tag} -> {repo_id}/{path_in_repo} "
|
| 218 |
+
f"({size_mb:.0f} MB, strict={strict})"
|
| 219 |
+
)
|
| 220 |
+
logger.info(f"[CKPT] Commit: {commit_message}")
|
| 221 |
+
|
| 222 |
+
try:
|
| 223 |
+
from huggingface_hub import HfApi
|
| 224 |
+
|
| 225 |
+
api = HfApi(token=token)
|
| 226 |
+
api.create_repo(repo_id=repo_id, repo_type="dataset", private=True, exist_ok=True)
|
| 227 |
+
api.upload_folder(
|
| 228 |
+
folder_path=checkpoint_dir,
|
| 229 |
+
repo_id=repo_id,
|
| 230 |
+
path_in_repo=path_in_repo,
|
| 231 |
+
repo_type="dataset",
|
| 232 |
+
commit_message=commit_message,
|
| 233 |
+
ignore_patterns=["*.tmp", "*.log", "__pycache__/*"],
|
| 234 |
+
)
|
| 235 |
+
upload_elapsed = time.time() - upload_start
|
| 236 |
+
logger.info(f"[CKPT] Uploaded {tag} to HF Hub in {upload_elapsed / 60:.1f}m")
|
| 237 |
+
except Exception as exc:
|
| 238 |
+
msg = f"HF checkpoint upload failed for {tag}: {exc}"
|
| 239 |
+
if hub_upload_strict():
|
| 240 |
+
raise RuntimeError(msg) from exc
|
| 241 |
+
logger.warning(f"[CKPT] {msg}; continuing because hub upload strict mode is disabled")
|
src/download.py
ADDED
|
@@ -0,0 +1,574 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import platform
|
| 7 |
+
import sys
|
| 8 |
+
import time
|
| 9 |
+
import warnings
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 13 |
+
if str(_REPO_ROOT) not in sys.path:
|
| 14 |
+
sys.path.insert(0, str(_REPO_ROOT))
|
| 15 |
+
|
| 16 |
+
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "0")
|
| 17 |
+
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
| 18 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 19 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from datasets import load_dataset
|
| 23 |
+
from huggingface_hub import snapshot_download
|
| 24 |
+
from transformers import AutoTokenizer
|
| 25 |
+
|
| 26 |
+
from configs import cfg, emit_log_spacing, setup_logger
|
| 27 |
+
from src.transformers_compat import format_model_load_error
|
| 28 |
+
|
| 29 |
+
_IGNORE_PATTERNS = ["*.msgpack", "*.h5", "*.bin", "optimizer.pt", "optimizer.safetensors"]
|
| 30 |
+
_TOKENIZER_ALLOW_PATTERNS = [
|
| 31 |
+
"tokenizer.json",
|
| 32 |
+
"tokenizer.model",
|
| 33 |
+
"tokenizer_config.json",
|
| 34 |
+
"special_tokens_map.json",
|
| 35 |
+
"added_tokens.json",
|
| 36 |
+
"vocab.json",
|
| 37 |
+
"merges.txt",
|
| 38 |
+
"generation_config.json",
|
| 39 |
+
]
|
| 40 |
+
_MIN_TOKEN_LENGTH = 10
|
| 41 |
+
_DATA_STATS_FILENAME = "_data_stats.json"
|
| 42 |
+
_ASSISTANT_MASK_KEYS = ("assistant_masks", "assistant_mask", "assistant_tokens_mask")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _build_chat_template_error_types() -> tuple[type[BaseException], ...]:
|
| 46 |
+
error_types: list[type[BaseException]] = [
|
| 47 |
+
AttributeError,
|
| 48 |
+
IndexError,
|
| 49 |
+
KeyError,
|
| 50 |
+
RuntimeError,
|
| 51 |
+
TypeError,
|
| 52 |
+
ValueError,
|
| 53 |
+
]
|
| 54 |
+
try:
|
| 55 |
+
from jinja2 import TemplateError
|
| 56 |
+
error_types.append(TemplateError)
|
| 57 |
+
except ImportError:
|
| 58 |
+
pass
|
| 59 |
+
return tuple(error_types)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
_CHAT_TEMPLATE_ERRORS = _build_chat_template_error_types()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _config_revision(value: str | None) -> str | None:
|
| 66 |
+
if value is None:
|
| 67 |
+
return None
|
| 68 |
+
stripped = value.strip()
|
| 69 |
+
return stripped or None
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _download_tokenizer_artifacts(tokenizer_model: str, tokenizer_revision: str | None, tokenizer_dir: str, log) -> None:
|
| 73 |
+
log.info(f"Downloading tokenizer -> ./{tokenizer_dir}/")
|
| 74 |
+
t0 = time.time()
|
| 75 |
+
try:
|
| 76 |
+
snapshot_download(
|
| 77 |
+
repo_id=tokenizer_model,
|
| 78 |
+
local_dir=tokenizer_dir,
|
| 79 |
+
revision=tokenizer_revision,
|
| 80 |
+
allow_patterns=_TOKENIZER_ALLOW_PATTERNS,
|
| 81 |
+
)
|
| 82 |
+
size_mb = sum(f.stat().st_size for f in Path(tokenizer_dir).rglob("*") if f.is_file()) / 1e6
|
| 83 |
+
log.info(f"Tokenizer downloaded: {size_mb:.1f} MB in {time.time() - t0:.0f}s")
|
| 84 |
+
except Exception as exc:
|
| 85 |
+
log.error(f"Failed to download tokenizer: {exc}")
|
| 86 |
+
sys.exit(1)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def write_system_info(output_path: str, logger) -> None:
|
| 90 |
+
output_dir = os.path.dirname(output_path)
|
| 91 |
+
if output_dir:
|
| 92 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 93 |
+
info = {
|
| 94 |
+
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S %Z"),
|
| 95 |
+
"platform": platform.platform(),
|
| 96 |
+
"python": sys.version.split()[0],
|
| 97 |
+
"torch": torch.__version__,
|
| 98 |
+
"cuda": torch.version.cuda if torch.cuda.is_available() else None,
|
| 99 |
+
"gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
|
| 100 |
+
"cpu_count": os.cpu_count(),
|
| 101 |
+
}
|
| 102 |
+
with open(output_path, "w", encoding="utf-8") as f:
|
| 103 |
+
json.dump(info, f, indent=2)
|
| 104 |
+
logger.info(f"System info -> {output_path}")
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def write_data_stats(output_path: str, stats: dict, dataset_id: str, config_name: str, target_samples: int, max_seq_len: int, logger) -> None:
|
| 108 |
+
output_dir = os.path.dirname(output_path)
|
| 109 |
+
if output_dir:
|
| 110 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 111 |
+
meta = {
|
| 112 |
+
"dataset": dataset_id,
|
| 113 |
+
"config": config_name,
|
| 114 |
+
"target_samples": target_samples,
|
| 115 |
+
"max_seq_len": max_seq_len,
|
| 116 |
+
"stats": stats,
|
| 117 |
+
}
|
| 118 |
+
with open(output_path, "w", encoding="utf-8") as f:
|
| 119 |
+
json.dump(meta, f, indent=2)
|
| 120 |
+
logger.info(f"Dataset stats -> {output_path}")
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def _coerce_content_str(content) -> str | None:
|
| 124 |
+
if isinstance(content, str):
|
| 125 |
+
return content.strip() or None
|
| 126 |
+
if isinstance(content, dict):
|
| 127 |
+
for key in ("answer_content", "text", "value", "content", "think_content"):
|
| 128 |
+
value = content.get(key)
|
| 129 |
+
if isinstance(value, str) and value.strip():
|
| 130 |
+
return value.strip()
|
| 131 |
+
return None
|
| 132 |
+
if isinstance(content, list):
|
| 133 |
+
parts = []
|
| 134 |
+
for item in content:
|
| 135 |
+
if isinstance(item, str) and item.strip():
|
| 136 |
+
parts.append(item.strip())
|
| 137 |
+
elif isinstance(item, dict):
|
| 138 |
+
value = item.get("text", item.get("value", ""))
|
| 139 |
+
if isinstance(value, str) and value.strip():
|
| 140 |
+
parts.append(value.strip())
|
| 141 |
+
joined = " ".join(parts).strip()
|
| 142 |
+
return joined or None
|
| 143 |
+
return None
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def _extract_clean_sample(row: dict) -> dict | None:
|
| 147 |
+
messages = row.get("messages", row.get("conversations", []))
|
| 148 |
+
if not messages and "instruction" in row and "output" in row:
|
| 149 |
+
messages = [
|
| 150 |
+
{"role": "user", "content": row["instruction"]},
|
| 151 |
+
{"role": "assistant", "content": row["output"]},
|
| 152 |
+
]
|
| 153 |
+
if isinstance(messages, str):
|
| 154 |
+
try:
|
| 155 |
+
messages = json.loads(messages)
|
| 156 |
+
except (json.JSONDecodeError, TypeError):
|
| 157 |
+
return None
|
| 158 |
+
if not isinstance(messages, list) or len(messages) < 2:
|
| 159 |
+
return None
|
| 160 |
+
|
| 161 |
+
clean_messages: list[dict] = []
|
| 162 |
+
for message in messages:
|
| 163 |
+
if not isinstance(message, dict):
|
| 164 |
+
return None
|
| 165 |
+
role = message.get("role", message.get("from", ""))
|
| 166 |
+
content = _coerce_content_str(message.get("content", message.get("value", "")))
|
| 167 |
+
if not isinstance(role, str) or not role.strip() or content is None:
|
| 168 |
+
return None
|
| 169 |
+
role = role.strip().lower()
|
| 170 |
+
if role == "human":
|
| 171 |
+
role = "user"
|
| 172 |
+
elif role == "gpt":
|
| 173 |
+
role = "assistant"
|
| 174 |
+
clean_messages.append({"role": role, "content": content})
|
| 175 |
+
|
| 176 |
+
source = str(row.get("source", "unknown"))
|
| 177 |
+
return {"messages": clean_messages, "source": source}
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _coerce_token_ids(token_ids) -> list[int]:
|
| 181 |
+
if hasattr(token_ids, "input_ids"):
|
| 182 |
+
token_ids = token_ids.input_ids
|
| 183 |
+
if isinstance(token_ids, dict):
|
| 184 |
+
token_ids = token_ids.get("input_ids", token_ids)
|
| 185 |
+
if hasattr(token_ids, "tolist"):
|
| 186 |
+
token_ids = token_ids.tolist()
|
| 187 |
+
if isinstance(token_ids, tuple):
|
| 188 |
+
token_ids = list(token_ids)
|
| 189 |
+
if not isinstance(token_ids, list):
|
| 190 |
+
raise TypeError(f"Unexpected token id payload: {type(token_ids).__name__}")
|
| 191 |
+
if token_ids and isinstance(token_ids[0], list):
|
| 192 |
+
raise TypeError("Expected a single token sequence, not a batched payload")
|
| 193 |
+
return [int(token_id) for token_id in token_ids]
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _coerce_binary_mask(mask_values, expected_len: int) -> list[int]:
|
| 197 |
+
if hasattr(mask_values, "tolist"):
|
| 198 |
+
mask_values = mask_values.tolist()
|
| 199 |
+
if isinstance(mask_values, tuple):
|
| 200 |
+
mask_values = list(mask_values)
|
| 201 |
+
if not isinstance(mask_values, list):
|
| 202 |
+
raise TypeError(f"Unexpected assistant mask payload: {type(mask_values).__name__}")
|
| 203 |
+
if mask_values and isinstance(mask_values[0], list):
|
| 204 |
+
raise TypeError("Expected a single assistant mask, not a batched payload")
|
| 205 |
+
mask = [1 if int(value) != 0 else 0 for value in mask_values]
|
| 206 |
+
if len(mask) != expected_len:
|
| 207 |
+
raise ValueError(f"Assistant mask length mismatch: got {len(mask)}, expected {expected_len}")
|
| 208 |
+
return mask
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def _try_builtin_assistant_mask(sample: dict, tokenizer: AutoTokenizer, input_ids: list[int]) -> list[int] | None:
|
| 212 |
+
try:
|
| 213 |
+
with warnings.catch_warnings():
|
| 214 |
+
warnings.filterwarnings("ignore", message="return_assistant_tokens_mask")
|
| 215 |
+
encoded = tokenizer.apply_chat_template(
|
| 216 |
+
sample["messages"],
|
| 217 |
+
tokenize=True,
|
| 218 |
+
add_generation_prompt=False,
|
| 219 |
+
return_dict=True,
|
| 220 |
+
return_assistant_tokens_mask=True,
|
| 221 |
+
)
|
| 222 |
+
except TypeError:
|
| 223 |
+
return None
|
| 224 |
+
except _CHAT_TEMPLATE_ERRORS:
|
| 225 |
+
return None
|
| 226 |
+
if not hasattr(encoded, "get"):
|
| 227 |
+
return None
|
| 228 |
+
try:
|
| 229 |
+
encoded_ids = _coerce_token_ids(encoded)
|
| 230 |
+
except _CHAT_TEMPLATE_ERRORS:
|
| 231 |
+
return None
|
| 232 |
+
if encoded_ids != input_ids:
|
| 233 |
+
return None
|
| 234 |
+
for key in _ASSISTANT_MASK_KEYS:
|
| 235 |
+
if key not in encoded:
|
| 236 |
+
continue
|
| 237 |
+
try:
|
| 238 |
+
mask = _coerce_binary_mask(encoded[key], expected_len=len(input_ids))
|
| 239 |
+
except (TypeError, ValueError):
|
| 240 |
+
return None
|
| 241 |
+
if any(mask):
|
| 242 |
+
return mask
|
| 243 |
+
return None
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def _build_assistant_mask_from_prefixes(sample: dict, tokenizer: AutoTokenizer, input_ids: list[int]) -> list[int]:
|
| 247 |
+
loss_mask = [0] * len(input_ids)
|
| 248 |
+
prefix_ids: list[int] = []
|
| 249 |
+
for turn_index, message in enumerate(sample["messages"], start=1):
|
| 250 |
+
role = str(message.get("role", "")).strip().lower()
|
| 251 |
+
if role == "assistant":
|
| 252 |
+
# prompt_ids contains everything up to user message + assistant header (<|im_start|>assistant\n)
|
| 253 |
+
prompt_ids = _coerce_token_ids(
|
| 254 |
+
tokenizer.apply_chat_template(
|
| 255 |
+
sample["messages"][:turn_index-1],
|
| 256 |
+
tokenize=True,
|
| 257 |
+
add_generation_prompt=True,
|
| 258 |
+
)
|
| 259 |
+
)
|
| 260 |
+
# full_ids contains prompt + assistant response content + eos
|
| 261 |
+
full_ids = _coerce_token_ids(
|
| 262 |
+
tokenizer.apply_chat_template(
|
| 263 |
+
sample["messages"][:turn_index],
|
| 264 |
+
tokenize=True,
|
| 265 |
+
add_generation_prompt=False,
|
| 266 |
+
)
|
| 267 |
+
)
|
| 268 |
+
if len(full_ids) < len(prompt_ids) or full_ids[:len(prompt_ids)] != prompt_ids:
|
| 269 |
+
raise ValueError("Chat template is not prefix-stable enough to derive assistant-only targets")
|
| 270 |
+
|
| 271 |
+
# Loss mask is 1 only for assistant's content tokens (after prompt_ids)
|
| 272 |
+
for j in range(len(prompt_ids), len(full_ids)):
|
| 273 |
+
loss_mask[j] = 1
|
| 274 |
+
prefix_ids = _coerce_token_ids(
|
| 275 |
+
tokenizer.apply_chat_template(
|
| 276 |
+
sample["messages"][:turn_index],
|
| 277 |
+
tokenize=True,
|
| 278 |
+
add_generation_prompt=False,
|
| 279 |
+
)
|
| 280 |
+
)
|
| 281 |
+
if prefix_ids != input_ids:
|
| 282 |
+
raise ValueError("Prefix tokenization mismatch while deriving assistant-only targets")
|
| 283 |
+
if len(loss_mask) != len(input_ids):
|
| 284 |
+
raise ValueError(f"Assistant mask length mismatch: got {len(loss_mask)}, expected {len(input_ids)}")
|
| 285 |
+
return loss_mask
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def _build_assistant_loss_mask(sample: dict, tokenizer: AutoTokenizer, input_ids: list[int]) -> list[int]:
|
| 289 |
+
builtin_mask = _try_builtin_assistant_mask(sample, tokenizer, input_ids)
|
| 290 |
+
if builtin_mask is not None:
|
| 291 |
+
return builtin_mask
|
| 292 |
+
return _build_assistant_mask_from_prefixes(sample, tokenizer, input_ids)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def write_tokenized_dataset(tokenizer, num_samples: int, out_file: str, log) -> dict:
|
| 296 |
+
if num_samples <= 0:
|
| 297 |
+
raise RuntimeError("num_samples must be positive when tokenization is enabled")
|
| 298 |
+
output_dir = os.path.dirname(out_file)
|
| 299 |
+
if output_dir:
|
| 300 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 301 |
+
|
| 302 |
+
log.info(f"Streaming {cfg.data.dataset_path}...")
|
| 303 |
+
ds = load_dataset(cfg.data.dataset_path, split="train", streaming=True, token=cfg.hub.token)
|
| 304 |
+
buffer_size = int(getattr(cfg.data, "stream_shuffle_buffer_size", 0) or 0)
|
| 305 |
+
if buffer_size > 0 and hasattr(ds, "shuffle"):
|
| 306 |
+
seed = int(getattr(cfg.data, "stream_shuffle_seed", 42))
|
| 307 |
+
log.info(f"Shuffling stream with buffer_size={buffer_size:,} seed={seed}")
|
| 308 |
+
ds = ds.shuffle(seed=seed, buffer_size=buffer_size)
|
| 309 |
+
|
| 310 |
+
# Check if dataset is pre-tokenized
|
| 311 |
+
try:
|
| 312 |
+
first_sample = next(iter(ds))
|
| 313 |
+
is_pre_tokenized = "input_ids" in first_sample and "loss_mask" in first_sample
|
| 314 |
+
except StopIteration:
|
| 315 |
+
raise RuntimeError("Loaded dataset is empty")
|
| 316 |
+
|
| 317 |
+
stats = {
|
| 318 |
+
"scanned": 0,
|
| 319 |
+
"written": 0,
|
| 320 |
+
"too_long_tokens": 0,
|
| 321 |
+
"too_short_tokens": 0,
|
| 322 |
+
"template_errors": 0,
|
| 323 |
+
"no_target_tokens": 0,
|
| 324 |
+
"invalid_messages": 0,
|
| 325 |
+
"total_tokens_written": 0,
|
| 326 |
+
"total_target_tokens_written": 0,
|
| 327 |
+
"min_tokens_written": 0,
|
| 328 |
+
"max_tokens_written": 0,
|
| 329 |
+
"min_target_tokens_written": 0,
|
| 330 |
+
"max_target_tokens_written": 0,
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
if is_pre_tokenized:
|
| 334 |
+
log.info("Auto-detected pre-tokenized dataset on HF Hub. Writing directly to train.jsonl...")
|
| 335 |
+
# Re-initialize to avoid losing the first element consumed by next(iter())
|
| 336 |
+
ds = load_dataset(cfg.data.dataset_path, split="train", streaming=True, token=cfg.hub.token)
|
| 337 |
+
if buffer_size > 0 and hasattr(ds, "shuffle"):
|
| 338 |
+
ds = ds.shuffle(seed=seed, buffer_size=buffer_size)
|
| 339 |
+
|
| 340 |
+
with open(out_file, "w", encoding="utf-8") as f:
|
| 341 |
+
for row in ds:
|
| 342 |
+
stats["scanned"] += 1
|
| 343 |
+
input_ids = row["input_ids"]
|
| 344 |
+
loss_mask = row["loss_mask"]
|
| 345 |
+
token_len = len(input_ids)
|
| 346 |
+
target_tokens = sum(loss_mask)
|
| 347 |
+
|
| 348 |
+
out_row = {
|
| 349 |
+
"input_ids": input_ids,
|
| 350 |
+
"loss_mask": loss_mask,
|
| 351 |
+
"length": token_len,
|
| 352 |
+
"target_tokens": target_tokens,
|
| 353 |
+
"source": row.get("source", "unknown"),
|
| 354 |
+
}
|
| 355 |
+
f.write(json.dumps(out_row) + "\n")
|
| 356 |
+
|
| 357 |
+
stats["written"] += 1
|
| 358 |
+
stats["total_tokens_written"] += token_len
|
| 359 |
+
stats["total_target_tokens_written"] += target_tokens
|
| 360 |
+
if stats["written"] == 1:
|
| 361 |
+
stats["min_tokens_written"] = token_len
|
| 362 |
+
stats["max_tokens_written"] = token_len
|
| 363 |
+
stats["min_target_tokens_written"] = target_tokens
|
| 364 |
+
stats["max_target_tokens_written"] = target_tokens
|
| 365 |
+
else:
|
| 366 |
+
stats["min_tokens_written"] = min(stats["min_tokens_written"], token_len)
|
| 367 |
+
stats["max_tokens_written"] = max(stats["max_tokens_written"], token_len)
|
| 368 |
+
stats["min_target_tokens_written"] = min(stats["min_target_tokens_written"], target_tokens)
|
| 369 |
+
stats["max_target_tokens_written"] = max(stats["max_target_tokens_written"], target_tokens)
|
| 370 |
+
|
| 371 |
+
if stats["written"] >= num_samples:
|
| 372 |
+
break
|
| 373 |
+
return stats
|
| 374 |
+
|
| 375 |
+
log.info("Standard raw text dataset detected. Running tokenization locally...")
|
| 376 |
+
with open(out_file, "w", encoding="utf-8") as f:
|
| 377 |
+
for row in ds:
|
| 378 |
+
stats["scanned"] += 1
|
| 379 |
+
sample = _extract_clean_sample(row)
|
| 380 |
+
if sample is None:
|
| 381 |
+
stats["invalid_messages"] += 1
|
| 382 |
+
continue
|
| 383 |
+
try:
|
| 384 |
+
input_ids = _coerce_token_ids(
|
| 385 |
+
tokenizer.apply_chat_template(
|
| 386 |
+
sample["messages"],
|
| 387 |
+
tokenize=True,
|
| 388 |
+
add_generation_prompt=False,
|
| 389 |
+
)
|
| 390 |
+
)
|
| 391 |
+
token_len = len(input_ids)
|
| 392 |
+
if token_len < _MIN_TOKEN_LENGTH:
|
| 393 |
+
stats["too_short_tokens"] += 1
|
| 394 |
+
continue
|
| 395 |
+
if token_len > cfg.data.max_seq_len:
|
| 396 |
+
stats["too_long_tokens"] += 1
|
| 397 |
+
continue
|
| 398 |
+
loss_mask = _build_assistant_loss_mask(sample, tokenizer, input_ids)
|
| 399 |
+
except _CHAT_TEMPLATE_ERRORS:
|
| 400 |
+
stats["template_errors"] += 1
|
| 401 |
+
continue
|
| 402 |
+
|
| 403 |
+
target_tokens = sum(loss_mask)
|
| 404 |
+
if target_tokens == 0:
|
| 405 |
+
stats["no_target_tokens"] += 1
|
| 406 |
+
continue
|
| 407 |
+
|
| 408 |
+
out_row = {
|
| 409 |
+
"input_ids": input_ids,
|
| 410 |
+
"loss_mask": loss_mask,
|
| 411 |
+
"length": token_len,
|
| 412 |
+
"target_tokens": target_tokens,
|
| 413 |
+
"source": sample.get("source", "unknown"),
|
| 414 |
+
}
|
| 415 |
+
f.write(json.dumps(out_row) + "\n")
|
| 416 |
+
|
| 417 |
+
stats["written"] += 1
|
| 418 |
+
stats["total_tokens_written"] += token_len
|
| 419 |
+
stats["total_target_tokens_written"] += target_tokens
|
| 420 |
+
if stats["written"] == 1:
|
| 421 |
+
stats["min_tokens_written"] = token_len
|
| 422 |
+
stats["max_tokens_written"] = token_len
|
| 423 |
+
stats["min_target_tokens_written"] = target_tokens
|
| 424 |
+
stats["max_target_tokens_written"] = target_tokens
|
| 425 |
+
else:
|
| 426 |
+
stats["min_tokens_written"] = min(stats["min_tokens_written"], token_len)
|
| 427 |
+
stats["max_tokens_written"] = max(stats["max_tokens_written"], token_len)
|
| 428 |
+
stats["min_target_tokens_written"] = min(stats["min_target_tokens_written"], target_tokens)
|
| 429 |
+
stats["max_target_tokens_written"] = max(stats["max_target_tokens_written"], target_tokens)
|
| 430 |
+
|
| 431 |
+
if stats["written"] >= num_samples:
|
| 432 |
+
break
|
| 433 |
+
|
| 434 |
+
return stats
|
| 435 |
+
|
| 436 |
+
def main() -> None:
|
| 437 |
+
parser = argparse.ArgumentParser(description="Download models and tokenize data")
|
| 438 |
+
parser.add_argument("--num_samples", type=int, default=cfg.data.num_samples)
|
| 439 |
+
parser.add_argument("--skip_teacher", action="store_true", help="Skip teacher download.")
|
| 440 |
+
parser.add_argument("--skip_tokenization", action="store_true", help="Skip data tokenization.")
|
| 441 |
+
parser.add_argument("--tokenizer_only", action="store_true", help="Download tokenizer artifacts only")
|
| 442 |
+
args = parser.parse_args()
|
| 443 |
+
|
| 444 |
+
log = setup_logger("DOWNLOAD")
|
| 445 |
+
log.info("=" * 70)
|
| 446 |
+
log.info("Download and tokenize")
|
| 447 |
+
log.info("=" * 70)
|
| 448 |
+
|
| 449 |
+
write_system_info(cfg.paths.system_info, log)
|
| 450 |
+
|
| 451 |
+
log.info(f" Teacher: {cfg.model.teacher}")
|
| 452 |
+
log.info(f" Teacher rev: {_config_revision(cfg.model.teacher_revision) or 'unversioned'}")
|
| 453 |
+
tokenizer_model = getattr(cfg.model, "tokenizer", cfg.model.student)
|
| 454 |
+
tokenizer_revision = _config_revision(getattr(cfg.model, "tokenizer_revision", cfg.model.student_revision))
|
| 455 |
+
tokenizer_dir = getattr(cfg.paths, "tokenizer_dir", cfg.paths.student_dir)
|
| 456 |
+
log.info(f" Student: {cfg.model.student}")
|
| 457 |
+
log.info(f" Student rev: {_config_revision(cfg.model.student_revision) or 'unversioned'}")
|
| 458 |
+
log.info(f" Tokenizer: {tokenizer_model}")
|
| 459 |
+
log.info(f" Tokenizer rev:{tokenizer_revision or 'unversioned'}")
|
| 460 |
+
log.info(f" Student dir: {cfg.paths.student_dir}")
|
| 461 |
+
log.info(f" Tokenizer dir:{tokenizer_dir}")
|
| 462 |
+
log.info(f" Remote code: {cfg.model.allow_remote_code}")
|
| 463 |
+
log.info(f" Dataset: {cfg.data.dataset_path}")
|
| 464 |
+
log.info(f" Num samples: {args.num_samples:,}")
|
| 465 |
+
log.info(f" Max seq len: {cfg.data.max_seq_len}")
|
| 466 |
+
if torch.cuda.is_available():
|
| 467 |
+
log.info(f" GPU: {torch.cuda.get_device_name(0)}")
|
| 468 |
+
|
| 469 |
+
if not args.tokenizer_only:
|
| 470 |
+
emit_log_spacing(log)
|
| 471 |
+
log.info("-" * 70)
|
| 472 |
+
log.info(f"Downloading student -> ./{cfg.paths.student_dir}/")
|
| 473 |
+
t0 = time.time()
|
| 474 |
+
try:
|
| 475 |
+
snapshot_download(
|
| 476 |
+
repo_id=cfg.model.student,
|
| 477 |
+
local_dir=cfg.paths.student_dir,
|
| 478 |
+
revision=_config_revision(cfg.model.student_revision),
|
| 479 |
+
ignore_patterns=_IGNORE_PATTERNS,
|
| 480 |
+
)
|
| 481 |
+
size_gb = sum(f.stat().st_size for f in Path(cfg.paths.student_dir).rglob("*") if f.is_file()) / 1e9
|
| 482 |
+
log.info(f"Student downloaded: {size_gb:.1f} GB in {time.time() - t0:.0f}s")
|
| 483 |
+
except Exception as exc:
|
| 484 |
+
log.error(f"Failed to download student: {exc}")
|
| 485 |
+
sys.exit(1)
|
| 486 |
+
|
| 487 |
+
if args.tokenizer_only or Path(tokenizer_dir).resolve() != Path(cfg.paths.student_dir).resolve():
|
| 488 |
+
emit_log_spacing(log)
|
| 489 |
+
log.info("-" * 70)
|
| 490 |
+
_download_tokenizer_artifacts(tokenizer_model, tokenizer_revision, tokenizer_dir, log)
|
| 491 |
+
|
| 492 |
+
if not args.skip_tokenization:
|
| 493 |
+
emit_log_spacing(log)
|
| 494 |
+
log.info("-" * 70)
|
| 495 |
+
log.info(f"Preparing dataset: {cfg.data.dataset_path}")
|
| 496 |
+
|
| 497 |
+
try:
|
| 498 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 499 |
+
tokenizer_dir,
|
| 500 |
+
trust_remote_code=cfg.model.allow_remote_code,
|
| 501 |
+
)
|
| 502 |
+
except Exception as exc:
|
| 503 |
+
log.error(format_model_load_error("Student tokenizer load", exc))
|
| 504 |
+
sys.exit(1)
|
| 505 |
+
if tokenizer.pad_token is None:
|
| 506 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 507 |
+
|
| 508 |
+
os.makedirs(cfg.paths.tokenized_dir, exist_ok=True)
|
| 509 |
+
out_file = os.path.join(cfg.paths.tokenized_dir, "train.jsonl")
|
| 510 |
+
stats_file = os.path.join(cfg.paths.tokenized_dir, _DATA_STATS_FILENAME)
|
| 511 |
+
|
| 512 |
+
log.info(
|
| 513 |
+
f"Streaming + tokenizing up to {args.num_samples:,} samples "
|
| 514 |
+
f"(max_seq_len={cfg.data.max_seq_len}, strict token limit, no truncation)"
|
| 515 |
+
)
|
| 516 |
+
t0 = time.time()
|
| 517 |
+
try:
|
| 518 |
+
token_stats = write_tokenized_dataset(
|
| 519 |
+
tokenizer=tokenizer,
|
| 520 |
+
num_samples=args.num_samples,
|
| 521 |
+
out_file=out_file,
|
| 522 |
+
log=log,
|
| 523 |
+
)
|
| 524 |
+
except RuntimeError as exc:
|
| 525 |
+
log.error(str(exc))
|
| 526 |
+
sys.exit(1)
|
| 527 |
+
|
| 528 |
+
write_data_stats(
|
| 529 |
+
output_path=stats_file,
|
| 530 |
+
stats=token_stats,
|
| 531 |
+
dataset_id=cfg.data.dataset_path,
|
| 532 |
+
config_name="default",
|
| 533 |
+
target_samples=args.num_samples,
|
| 534 |
+
max_seq_len=cfg.data.max_seq_len,
|
| 535 |
+
logger=log,
|
| 536 |
+
)
|
| 537 |
+
n_written = token_stats["written"]
|
| 538 |
+
|
| 539 |
+
if n_written == 0:
|
| 540 |
+
log.error("Tokenization produced 0 usable rows - aborting.")
|
| 541 |
+
sys.exit(1)
|
| 542 |
+
|
| 543 |
+
log.info(f"Pretokenization complete: {n_written:,} samples -> {out_file}")
|
| 544 |
+
else:
|
| 545 |
+
log.info("Skipping dataset tokenization (--skip_tokenization)")
|
| 546 |
+
|
| 547 |
+
if not args.skip_teacher:
|
| 548 |
+
emit_log_spacing(log)
|
| 549 |
+
log.info("-" * 70)
|
| 550 |
+
log.info(f"Downloading teacher -> ./{cfg.paths.teacher_dir}/")
|
| 551 |
+
t0 = time.time()
|
| 552 |
+
try:
|
| 553 |
+
snapshot_download(
|
| 554 |
+
repo_id=cfg.model.teacher,
|
| 555 |
+
local_dir=cfg.paths.teacher_dir,
|
| 556 |
+
revision=_config_revision(cfg.model.teacher_revision),
|
| 557 |
+
ignore_patterns=_IGNORE_PATTERNS,
|
| 558 |
+
)
|
| 559 |
+
size_gb = sum(f.stat().st_size for f in Path(cfg.paths.teacher_dir).rglob("*") if f.is_file()) / 1e9
|
| 560 |
+
log.info(f"Teacher downloaded: {size_gb:.1f} GB in {time.time() - t0:.0f}s")
|
| 561 |
+
except Exception as exc:
|
| 562 |
+
log.error(f"Failed to download teacher: {exc}")
|
| 563 |
+
sys.exit(1)
|
| 564 |
+
else:
|
| 565 |
+
log.info("Skipping teacher download (--skip_teacher)")
|
| 566 |
+
|
| 567 |
+
emit_log_spacing(log)
|
| 568 |
+
log.info("-" * 70)
|
| 569 |
+
log.info("Download complete")
|
| 570 |
+
log.info("-" * 70)
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
if __name__ == "__main__":
|
| 574 |
+
main()
|
src/kd_contracts.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import hashlib
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
PROVENANCE_SCHEMA_VERSION = 4
|
| 9 |
+
|
| 10 |
+
_SHARD_SCHEMA = {
|
| 11 |
+
"support": "teacher_topk_plus_other_bucket",
|
| 12 |
+
"layout": "chunked_sample_lists",
|
| 13 |
+
"logprobs_dtype": "float16",
|
| 14 |
+
"ids_dtype": "int32",
|
| 15 |
+
"other_logprob_dtype": "float16",
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
_SPECIAL_TOKEN_ID_FIELDS = (
|
| 19 |
+
"bos_token_id",
|
| 20 |
+
"eos_token_id",
|
| 21 |
+
"pad_token_id",
|
| 22 |
+
"unk_token_id",
|
| 23 |
+
"cls_token_id",
|
| 24 |
+
"sep_token_id",
|
| 25 |
+
"mask_token_id",
|
| 26 |
+
"additional_special_tokens_ids",
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
def normalize_config_revision(value: str | None) -> str | None:
|
| 30 |
+
if value is None:
|
| 31 |
+
return None
|
| 32 |
+
stripped = value.strip()
|
| 33 |
+
return stripped or None
|
| 34 |
+
|
| 35 |
+
def canonical_revision(value: str | None) -> str:
|
| 36 |
+
return normalize_config_revision(value) or "unversioned"
|
| 37 |
+
|
| 38 |
+
def sha256_file(path: str | Path, chunk_size: int = 1 << 20) -> str:
|
| 39 |
+
digest = hashlib.sha256()
|
| 40 |
+
with open(path, "rb") as handle:
|
| 41 |
+
while True:
|
| 42 |
+
chunk = handle.read(chunk_size)
|
| 43 |
+
if not chunk:
|
| 44 |
+
break
|
| 45 |
+
digest.update(chunk)
|
| 46 |
+
return digest.hexdigest()
|
| 47 |
+
|
| 48 |
+
def _special_token_ids(tokenizer) -> dict[str, Any]:
|
| 49 |
+
snapshot: dict[str, Any] = {}
|
| 50 |
+
for field in _SPECIAL_TOKEN_ID_FIELDS:
|
| 51 |
+
value = getattr(tokenizer, field, None)
|
| 52 |
+
if isinstance(value, tuple):
|
| 53 |
+
value = list(value)
|
| 54 |
+
snapshot[field] = value
|
| 55 |
+
return snapshot
|
| 56 |
+
|
| 57 |
+
def build_tokenizer_contract(tokenizer) -> dict[str, Any]:
|
| 58 |
+
canonical = {
|
| 59 |
+
"tokenizer_class": tokenizer.__class__.__name__,
|
| 60 |
+
"full_vocab_size": len(tokenizer),
|
| 61 |
+
"special_token_ids": _special_token_ids(tokenizer),
|
| 62 |
+
"vocab": dict(sorted(tokenizer.get_vocab().items())),
|
| 63 |
+
}
|
| 64 |
+
encoded = json.dumps(
|
| 65 |
+
canonical,
|
| 66 |
+
sort_keys=True,
|
| 67 |
+
separators=(",", ":"),
|
| 68 |
+
ensure_ascii=True,
|
| 69 |
+
).encode("utf-8")
|
| 70 |
+
return {
|
| 71 |
+
"tokenizer_class": canonical["tokenizer_class"],
|
| 72 |
+
"full_vocab_size": canonical["full_vocab_size"],
|
| 73 |
+
"special_token_ids": canonical["special_token_ids"],
|
| 74 |
+
"fingerprint": hashlib.sha256(encoded).hexdigest(),
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
def build_shard_schema() -> dict[str, str]:
|
| 78 |
+
return dict(_SHARD_SCHEMA)
|
| 79 |
+
|
| 80 |
+
def collect_model_vocab_sizes(model) -> dict[str, int]:
|
| 81 |
+
sizes: dict[str, int] = {}
|
| 82 |
+
|
| 83 |
+
config_size = getattr(getattr(model, "config", None), "vocab_size", None)
|
| 84 |
+
if isinstance(config_size, int):
|
| 85 |
+
sizes["config"] = config_size
|
| 86 |
+
|
| 87 |
+
input_embeddings = model.get_input_embeddings()
|
| 88 |
+
if input_embeddings is not None and getattr(input_embeddings, "weight", None) is not None:
|
| 89 |
+
sizes["input_embeddings"] = int(input_embeddings.weight.shape[0])
|
| 90 |
+
|
| 91 |
+
output_embeddings = model.get_output_embeddings()
|
| 92 |
+
if output_embeddings is not None and getattr(output_embeddings, "weight", None) is not None:
|
| 93 |
+
sizes["output_embeddings"] = int(output_embeddings.weight.shape[0])
|
| 94 |
+
|
| 95 |
+
return sizes
|
src/losses.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
PROB_EPS = 1.0e-12
|
| 8 |
+
|
| 9 |
+
def _normalize_support_logprobs(
|
| 10 |
+
topk_logprobs: torch.Tensor,
|
| 11 |
+
other_logprob: torch.Tensor,
|
| 12 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 13 |
+
topk_probs = topk_logprobs.float().exp()
|
| 14 |
+
other_prob = other_logprob.float().exp().unsqueeze(-1)
|
| 15 |
+
support_probs = torch.cat([topk_probs, other_prob], dim=-1).clamp_min(PROB_EPS)
|
| 16 |
+
support_probs = support_probs / support_probs.sum(dim=-1, keepdim=True).clamp_min(PROB_EPS)
|
| 17 |
+
return support_probs, support_probs.log()
|
| 18 |
+
|
| 19 |
+
def _masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
| 20 |
+
mask = mask.float()
|
| 21 |
+
return (values * mask).sum() / mask.sum().clamp(min=1.0)
|
| 22 |
+
|
| 23 |
+
def compute_sft_ce(logits: torch.Tensor, labels: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
|
| 24 |
+
batch_size = logits.size(0)
|
| 25 |
+
shift_labels = labels[:, 1:].contiguous()
|
| 26 |
+
shift_loss_mask = ((loss_mask[:, 1:] > 0) & shift_labels.ne(-100)).contiguous().float()
|
| 27 |
+
|
| 28 |
+
total_loss = torch.tensor(0.0, device=logits.device, dtype=torch.bfloat16)
|
| 29 |
+
total_weight = torch.tensor(0.0, device=logits.device, dtype=torch.bfloat16)
|
| 30 |
+
|
| 31 |
+
for i in range(batch_size):
|
| 32 |
+
b_logits = logits[i, :-1, :]
|
| 33 |
+
b_labels = shift_labels[i]
|
| 34 |
+
b_mask = shift_loss_mask[i]
|
| 35 |
+
|
| 36 |
+
ce = F.cross_entropy(
|
| 37 |
+
b_logits,
|
| 38 |
+
b_labels,
|
| 39 |
+
ignore_index=-100,
|
| 40 |
+
reduction="none",
|
| 41 |
+
)
|
| 42 |
+
total_loss += (ce * b_mask).sum()
|
| 43 |
+
total_weight += b_mask.sum()
|
| 44 |
+
|
| 45 |
+
return total_loss / total_weight.clamp(min=1.0)
|
| 46 |
+
|
| 47 |
+
def _compute_masked_ce_with_logits(logits, labels, loss_mask):
|
| 48 |
+
loss_ce = compute_sft_ce(logits, labels, loss_mask)
|
| 49 |
+
shift_logits = logits[:, :-1, :]
|
| 50 |
+
shift_labels = labels[:, 1:].contiguous()
|
| 51 |
+
shift_loss_mask = ((loss_mask[:, 1:] > 0) & shift_labels.ne(-100)).contiguous().float()
|
| 52 |
+
return loss_ce, shift_logits, shift_loss_mask
|
| 53 |
+
|
| 54 |
+
def compute_distillation_loss(
|
| 55 |
+
student_logits: torch.Tensor,
|
| 56 |
+
labels: torch.Tensor,
|
| 57 |
+
teacher_logprobs: torch.Tensor,
|
| 58 |
+
teacher_ids: torch.Tensor,
|
| 59 |
+
teacher_other_logprob: torch.Tensor,
|
| 60 |
+
loss_mask: torch.Tensor,
|
| 61 |
+
alpha: float,
|
| 62 |
+
temperature: float,
|
| 63 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 64 |
+
vocab_size = student_logits.size(-1)
|
| 65 |
+
loss_ce, shift_logits, shift_loss_mask = _compute_masked_ce_with_logits(student_logits, labels, loss_mask)
|
| 66 |
+
|
| 67 |
+
shift_teacher_logprobs = teacher_logprobs[:, :-1, :].contiguous()
|
| 68 |
+
shift_teacher_ids = teacher_ids[:, :-1, :].contiguous()
|
| 69 |
+
shift_teacher_other_logprob = teacher_other_logprob[:, :-1].contiguous()
|
| 70 |
+
shift_student = shift_logits
|
| 71 |
+
|
| 72 |
+
topk_ids_clamped = shift_teacher_ids.clamp(0, vocab_size - 1)
|
| 73 |
+
student_log_z = torch.logsumexp(shift_student / temperature, dim=-1, keepdim=True).float()
|
| 74 |
+
student_topk_logprobs = shift_student.gather(-1, topk_ids_clamped).float() / temperature - student_log_z
|
| 75 |
+
student_topk_probs = student_topk_logprobs.float().exp()
|
| 76 |
+
student_other_prob = (1.0 - student_topk_probs.sum(dim=-1)).clamp_min(PROB_EPS)
|
| 77 |
+
student_other_logprob = student_other_prob.log()
|
| 78 |
+
|
| 79 |
+
teacher_support_probs, teacher_support_logprobs = _normalize_support_logprobs(
|
| 80 |
+
shift_teacher_logprobs,
|
| 81 |
+
shift_teacher_other_logprob,
|
| 82 |
+
)
|
| 83 |
+
_, student_support_logprobs = _normalize_support_logprobs(
|
| 84 |
+
student_topk_logprobs,
|
| 85 |
+
student_other_logprob,
|
| 86 |
+
)
|
| 87 |
+
positive_teacher = teacher_support_probs > 0
|
| 88 |
+
kl_terms = torch.where(
|
| 89 |
+
positive_teacher,
|
| 90 |
+
teacher_support_probs * (teacher_support_logprobs - student_support_logprobs),
|
| 91 |
+
torch.zeros_like(teacher_support_probs),
|
| 92 |
+
)
|
| 93 |
+
kl_per_token = kl_terms.sum(-1)
|
| 94 |
+
loss_kd = _masked_mean(kl_per_token, shift_loss_mask) * (temperature**2)
|
| 95 |
+
|
| 96 |
+
loss_total = alpha * loss_ce + (1.0 - alpha) * loss_kd
|
| 97 |
+
return loss_total, loss_ce.detach(), loss_kd.detach()
|
| 98 |
+
|
| 99 |
+
def compute_online_kd_loss(
|
| 100 |
+
student_logits: torch.Tensor,
|
| 101 |
+
teacher_logits: torch.Tensor,
|
| 102 |
+
labels: torch.Tensor,
|
| 103 |
+
loss_mask: torch.Tensor,
|
| 104 |
+
alpha: float,
|
| 105 |
+
temperature: float,
|
| 106 |
+
token_chunk_size: int = 2048,
|
| 107 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 108 |
+
loss_ce = compute_sft_ce(student_logits, labels, loss_mask)
|
| 109 |
+
|
| 110 |
+
shift_labels = labels[:, 1:].contiguous()
|
| 111 |
+
shift_loss_mask = (
|
| 112 |
+
(loss_mask[:, 1:] > 0) & shift_labels.ne(-100)
|
| 113 |
+
).contiguous().float()
|
| 114 |
+
|
| 115 |
+
s_shifted = student_logits[:, :-1, :]
|
| 116 |
+
t_shifted = teacher_logits[:, :-1, :]
|
| 117 |
+
seq_len = s_shifted.size(1)
|
| 118 |
+
|
| 119 |
+
total_kl = torch.tensor(0.0, device=student_logits.device, dtype=torch.float32)
|
| 120 |
+
total_weight = torch.tensor(0.0, device=student_logits.device, dtype=torch.float32)
|
| 121 |
+
|
| 122 |
+
for tok_start in range(0, seq_len, token_chunk_size):
|
| 123 |
+
tok_end = min(tok_start + token_chunk_size, seq_len)
|
| 124 |
+
|
| 125 |
+
s_chunk = s_shifted[:, tok_start:tok_end, :].float()
|
| 126 |
+
t_chunk = t_shifted[:, tok_start:tok_end, :].float()
|
| 127 |
+
mask_chunk = shift_loss_mask[:, tok_start:tok_end]
|
| 128 |
+
|
| 129 |
+
chunk_weight = mask_chunk.sum()
|
| 130 |
+
t_probs = F.softmax(t_chunk / temperature, dim=-1)
|
| 131 |
+
s_log_probs = F.log_softmax(s_chunk / temperature, dim=-1)
|
| 132 |
+
kl_tokens = F.kl_div(
|
| 133 |
+
s_log_probs, t_probs, log_target=False, reduction="none"
|
| 134 |
+
).sum(dim=-1)
|
| 135 |
+
|
| 136 |
+
total_kl += (kl_tokens * mask_chunk).sum()
|
| 137 |
+
total_weight += chunk_weight
|
| 138 |
+
|
| 139 |
+
del s_chunk, t_chunk, t_probs, s_log_probs, kl_tokens, mask_chunk
|
| 140 |
+
|
| 141 |
+
loss_kd = (total_kl / total_weight.clamp(min=1.0)) * (temperature ** 2)
|
| 142 |
+
loss_kd = loss_kd.to(dtype=student_logits.dtype)
|
| 143 |
+
loss_total = alpha * loss_ce + (1.0 - alpha) * loss_kd
|
| 144 |
+
|
| 145 |
+
return loss_total, loss_ce.detach(), loss_kd.detach()
|
| 146 |
+
|
| 147 |
+
def compute_loss_for_phase(
|
| 148 |
+
phase: str,
|
| 149 |
+
logits: torch.Tensor,
|
| 150 |
+
labels: torch.Tensor,
|
| 151 |
+
loss_mask: torch.Tensor,
|
| 152 |
+
batch: dict,
|
| 153 |
+
alpha: float,
|
| 154 |
+
temperature: float,
|
| 155 |
+
teacher_logits: torch.Tensor | None = None,
|
| 156 |
+
online_kd_token_chunk_size: int = 2048,
|
| 157 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 158 |
+
if phase == "sft":
|
| 159 |
+
loss_ce = compute_sft_ce(logits, labels, loss_mask)
|
| 160 |
+
return loss_ce, loss_ce.detach(), torch.tensor(0.0, device=logits.device)
|
| 161 |
+
if phase == "online_kd":
|
| 162 |
+
return compute_online_kd_loss(
|
| 163 |
+
logits,
|
| 164 |
+
teacher_logits,
|
| 165 |
+
labels,
|
| 166 |
+
loss_mask,
|
| 167 |
+
alpha,
|
| 168 |
+
temperature,
|
| 169 |
+
token_chunk_size=online_kd_token_chunk_size,
|
| 170 |
+
)
|
| 171 |
+
return compute_distillation_loss(
|
| 172 |
+
logits,
|
| 173 |
+
labels,
|
| 174 |
+
batch["teacher_logprobs"],
|
| 175 |
+
batch["teacher_ids"],
|
| 176 |
+
batch["teacher_other_logprob"],
|
| 177 |
+
loss_mask,
|
| 178 |
+
alpha,
|
| 179 |
+
temperature,
|
| 180 |
+
)
|
src/optim.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from configs import cfg
|
| 6 |
+
|
| 7 |
+
def fused_adamw_preflight(logger) -> bool:
|
| 8 |
+
if not torch.cuda.is_available():
|
| 9 |
+
logger.info(" Optimizer: fused AdamW requested but CUDA is unavailable; using standard AdamW")
|
| 10 |
+
return False
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
probe = torch.nn.Parameter(torch.ones(8, device="cuda", dtype=torch.bfloat16))
|
| 14 |
+
probe_optim = torch.optim.AdamW([probe], lr=1.0e-4, fused=True)
|
| 15 |
+
loss = probe.float().square().sum()
|
| 16 |
+
loss.backward()
|
| 17 |
+
probe_optim.step()
|
| 18 |
+
probe_optim.zero_grad(set_to_none=True)
|
| 19 |
+
del loss, probe_optim, probe
|
| 20 |
+
return True
|
| 21 |
+
except Exception as exc:
|
| 22 |
+
logger.warning(f" Optimizer: fused AdamW preflight failed ({exc}); using standard AdamW")
|
| 23 |
+
return False
|
| 24 |
+
|
| 25 |
+
def build_adamw_optimizer(params: list[torch.nn.Parameter], logger, allow_fused: bool) -> torch.optim.Optimizer:
|
| 26 |
+
kwargs = {
|
| 27 |
+
"lr": cfg.training.learning_rate,
|
| 28 |
+
"weight_decay": cfg.training.weight_decay,
|
| 29 |
+
"betas": (0.9, 0.999),
|
| 30 |
+
}
|
| 31 |
+
fused_requested = bool(getattr(cfg.training, "fused_adamw", False)) and allow_fused
|
| 32 |
+
if fused_requested and fused_adamw_preflight(logger):
|
| 33 |
+
try:
|
| 34 |
+
optimizer = torch.optim.AdamW(params, **kwargs, fused=True)
|
| 35 |
+
logger.info(" Optimizer: AdamW fused=True")
|
| 36 |
+
return optimizer
|
| 37 |
+
except Exception as exc:
|
| 38 |
+
logger.warning(f" Optimizer: fused AdamW construction failed ({exc}); using standard AdamW")
|
| 39 |
+
elif bool(getattr(cfg.training, "fused_adamw", False)) and not allow_fused:
|
| 40 |
+
logger.info(" Optimizer: fused AdamW disabled for DeepSpeed")
|
| 41 |
+
|
| 42 |
+
optimizer = torch.optim.AdamW(params, **kwargs)
|
| 43 |
+
logger.info(" Optimizer: AdamW standard")
|
| 44 |
+
return optimizer
|
src/provenance.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
from configs import cfg
|
| 8 |
+
from src.kd_contracts import (
|
| 9 |
+
PROVENANCE_SCHEMA_VERSION,
|
| 10 |
+
build_shard_schema,
|
| 11 |
+
canonical_revision,
|
| 12 |
+
collect_model_vocab_sizes,
|
| 13 |
+
sha256_file,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
def resolve_model_vocab_size(model, tokenizer, label: str, log) -> int:
|
| 17 |
+
model_sizes = collect_model_vocab_sizes(model)
|
| 18 |
+
if not model_sizes:
|
| 19 |
+
log.error(f"{label} model does not expose a usable vocab size view.")
|
| 20 |
+
raise SystemExit(1)
|
| 21 |
+
|
| 22 |
+
unique_sizes = sorted(set(model_sizes.values()))
|
| 23 |
+
if len(unique_sizes) != 1:
|
| 24 |
+
details = ", ".join(f"{name}={size:,}" for name, size in sorted(model_sizes.items()))
|
| 25 |
+
log.error(f"{label} vocab mismatch across checkpoint artifacts: {details}")
|
| 26 |
+
raise SystemExit(1)
|
| 27 |
+
|
| 28 |
+
model_vocab_size = unique_sizes[0]
|
| 29 |
+
tokenizer_vocab_size = len(tokenizer)
|
| 30 |
+
if model_vocab_size < tokenizer_vocab_size:
|
| 31 |
+
log.error(
|
| 32 |
+
f"{label} tokenizer length ({tokenizer_vocab_size:,}) exceeds "
|
| 33 |
+
f"the model vocab size ({model_vocab_size:,})."
|
| 34 |
+
)
|
| 35 |
+
log.error("Repair or regenerate the checkpoint before using it for distillation.")
|
| 36 |
+
raise SystemExit(1)
|
| 37 |
+
if model_vocab_size > tokenizer_vocab_size:
|
| 38 |
+
log.info(
|
| 39 |
+
f" {label} model vocab is padded beyond the tokenizer range: "
|
| 40 |
+
f"tokenizer={tokenizer_vocab_size:,}, model={model_vocab_size:,}"
|
| 41 |
+
)
|
| 42 |
+
return model_vocab_size
|
| 43 |
+
|
| 44 |
+
def validate_provenance(
|
| 45 |
+
prov_path: str,
|
| 46 |
+
data_path: str,
|
| 47 |
+
dataset,
|
| 48 |
+
teacher_tokenizer_contract: dict,
|
| 49 |
+
student_tokenizer_contract: dict,
|
| 50 |
+
log,
|
| 51 |
+
) -> None:
|
| 52 |
+
if not os.path.exists(prov_path):
|
| 53 |
+
log.error("Missing _provenance.json in the logits directory.")
|
| 54 |
+
log.error("Regenerate the current teacher-logit shard metadata.")
|
| 55 |
+
raise SystemExit(1)
|
| 56 |
+
|
| 57 |
+
with open(prov_path, "r", encoding="utf-8") as f:
|
| 58 |
+
prov = json.load(f)
|
| 59 |
+
|
| 60 |
+
schema_version = prov.get("schema_version")
|
| 61 |
+
if schema_version != PROVENANCE_SCHEMA_VERSION:
|
| 62 |
+
log.error(
|
| 63 |
+
f"Unsupported shard provenance schema: {schema_version!r}. "
|
| 64 |
+
f"Expected {PROVENANCE_SCHEMA_VERSION}."
|
| 65 |
+
)
|
| 66 |
+
log.error("Regenerate the teacher-logit shards.")
|
| 67 |
+
raise SystemExit(1)
|
| 68 |
+
|
| 69 |
+
teacher_meta = prov.get("teacher", {})
|
| 70 |
+
student_meta = prov.get("student", {})
|
| 71 |
+
current_data_sha = sha256_file(data_path)
|
| 72 |
+
actual_shard_count = sum(1 for _ in Path(prov_path).parent.glob("shard_*.pt"))
|
| 73 |
+
provenance_num_samples = prov.get("num_samples")
|
| 74 |
+
try:
|
| 75 |
+
provenance_num_samples_int = int(provenance_num_samples)
|
| 76 |
+
except (TypeError, ValueError):
|
| 77 |
+
log.error(f"PROVENANCE MISMATCH: num_samples is {provenance_num_samples!r}, expected an integer.")
|
| 78 |
+
log.error("Regenerate compatible teacher-logit shards.")
|
| 79 |
+
raise SystemExit(1)
|
| 80 |
+
if provenance_num_samples_int < len(dataset):
|
| 81 |
+
log.error(
|
| 82 |
+
f"PROVENANCE MISMATCH: num_samples is {provenance_num_samples_int}, "
|
| 83 |
+
f"but the requested dataset has {len(dataset)} samples."
|
| 84 |
+
)
|
| 85 |
+
log.error("Regenerate compatible teacher-logit shards.")
|
| 86 |
+
raise SystemExit(1)
|
| 87 |
+
if provenance_num_samples_int > len(dataset):
|
| 88 |
+
log.warning(
|
| 89 |
+
f" Provenance contains {provenance_num_samples_int:,} samples; "
|
| 90 |
+
f"training is using the first {len(dataset):,}. This is expected for smoke tests."
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
expected_pairs = [
|
| 94 |
+
("shard_count", prov.get("shard_count"), actual_shard_count),
|
| 95 |
+
("samples_per_shard", prov.get("samples_per_shard"), dataset.samples_per_shard),
|
| 96 |
+
("data_sha256", prov.get("data_sha256"), current_data_sha),
|
| 97 |
+
("max_seq_len", prov.get("max_seq_len"), cfg.data.max_seq_len),
|
| 98 |
+
("top_k", prov.get("top_k"), cfg.training.top_k),
|
| 99 |
+
("temperature", prov.get("temperature"), float(cfg.training.temperature)),
|
| 100 |
+
("teacher.model", teacher_meta.get("model"), cfg.model.teacher),
|
| 101 |
+
(
|
| 102 |
+
"teacher.revision",
|
| 103 |
+
teacher_meta.get("revision"),
|
| 104 |
+
canonical_revision(cfg.model.teacher_revision),
|
| 105 |
+
),
|
| 106 |
+
(
|
| 107 |
+
"teacher.tokenizer_size",
|
| 108 |
+
teacher_meta.get("tokenizer_size"),
|
| 109 |
+
teacher_tokenizer_contract["full_vocab_size"],
|
| 110 |
+
),
|
| 111 |
+
(
|
| 112 |
+
"teacher.tokenizer_fingerprint",
|
| 113 |
+
teacher_meta.get("tokenizer_fingerprint"),
|
| 114 |
+
teacher_tokenizer_contract["fingerprint"],
|
| 115 |
+
),
|
| 116 |
+
("student.model", student_meta.get("model"), getattr(cfg.model, "tokenizer", cfg.model.student)),
|
| 117 |
+
(
|
| 118 |
+
"student.revision",
|
| 119 |
+
student_meta.get("revision"),
|
| 120 |
+
canonical_revision(getattr(cfg.model, "tokenizer_revision", cfg.model.student_revision)),
|
| 121 |
+
),
|
| 122 |
+
(
|
| 123 |
+
"student.tokenizer_size",
|
| 124 |
+
student_meta.get("tokenizer_size"),
|
| 125 |
+
student_tokenizer_contract["full_vocab_size"],
|
| 126 |
+
),
|
| 127 |
+
(
|
| 128 |
+
"student.tokenizer_fingerprint",
|
| 129 |
+
student_meta.get("tokenizer_fingerprint"),
|
| 130 |
+
student_tokenizer_contract["fingerprint"],
|
| 131 |
+
),
|
| 132 |
+
]
|
| 133 |
+
|
| 134 |
+
warn_only_fields = {
|
| 135 |
+
"teacher.tokenizer_fingerprint",
|
| 136 |
+
"student.tokenizer_fingerprint",
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
for field_name, found, expected in expected_pairs:
|
| 140 |
+
if found != expected:
|
| 141 |
+
if field_name in warn_only_fields:
|
| 142 |
+
log.warning(
|
| 143 |
+
f" Provenance WARNING (non-fatal): {field_name} is {found!r}, "
|
| 144 |
+
f"expected {expected!r}. This is likely due to a transformers "
|
| 145 |
+
f"library version change. Continuing because vocab sizes match."
|
| 146 |
+
)
|
| 147 |
+
else:
|
| 148 |
+
log.error(
|
| 149 |
+
f"PROVENANCE MISMATCH: {field_name} is {found!r}, expected {expected!r}."
|
| 150 |
+
)
|
| 151 |
+
log.error("Regenerate compatible teacher-logit shards.")
|
| 152 |
+
raise SystemExit(1)
|
| 153 |
+
|
| 154 |
+
provenance_data_path = prov.get("data_path")
|
| 155 |
+
current_data_path = str(Path(data_path).resolve())
|
| 156 |
+
if provenance_data_path != current_data_path:
|
| 157 |
+
log.warning(
|
| 158 |
+
" Provenance data_path differs because logits were likely generated on another machine: "
|
| 159 |
+
f"{provenance_data_path!r} vs {current_data_path!r}. "
|
| 160 |
+
"Continuing because data_sha256 matches."
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
shard_schema = prov.get("shard_schema")
|
| 164 |
+
expected_shard_schema = build_shard_schema()
|
| 165 |
+
if shard_schema != expected_shard_schema:
|
| 166 |
+
log.error(
|
| 167 |
+
f"PROVENANCE MISMATCH: shard_schema is {shard_schema!r}, "
|
| 168 |
+
f"expected {expected_shard_schema!r}."
|
| 169 |
+
)
|
| 170 |
+
log.error("Regenerate compatible teacher-logit shards.")
|
| 171 |
+
raise SystemExit(1)
|
| 172 |
+
|
| 173 |
+
log.info(" Provenance: PASS (teacher shards match the current config and dataset)")
|
src/sequence_packing.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from bisect import bisect_left, insort
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch.utils.data import Dataset
|
| 8 |
+
|
| 9 |
+
from src.training_data import DistillationDataset
|
| 10 |
+
|
| 11 |
+
class SequencePackedDataset(Dataset):
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
source: DistillationDataset,
|
| 15 |
+
source_indices: list[int],
|
| 16 |
+
pack_length: int,
|
| 17 |
+
eos_token_id: int,
|
| 18 |
+
pad_token_id: int,
|
| 19 |
+
mask_first_after_separator: bool = True,
|
| 20 |
+
):
|
| 21 |
+
if pack_length <= 0:
|
| 22 |
+
raise ValueError(f"pack_length must be positive, got {pack_length}.")
|
| 23 |
+
if not hasattr(source, "sample_lengths"):
|
| 24 |
+
raise ValueError("Packed training requires a source dataset with sample_lengths metadata.")
|
| 25 |
+
if not source_indices:
|
| 26 |
+
raise ValueError("Packed training requires at least one source row.")
|
| 27 |
+
|
| 28 |
+
self.source = source
|
| 29 |
+
self.source_indices = [int(index) for index in source_indices]
|
| 30 |
+
self.source_index_set = set(self.source_indices)
|
| 31 |
+
if len(self.source_index_set) != len(self.source_indices):
|
| 32 |
+
raise ValueError("Packed training source indices contain duplicates.")
|
| 33 |
+
|
| 34 |
+
self.pack_length = int(pack_length)
|
| 35 |
+
self.eos_token_id = int(eos_token_id)
|
| 36 |
+
self.pad_token_id = int(pad_token_id)
|
| 37 |
+
self.mask_first_after_separator = bool(mask_first_after_separator)
|
| 38 |
+
self._length_by_index: dict[int, int] = {}
|
| 39 |
+
self.plan: list[list[int]] = []
|
| 40 |
+
|
| 41 |
+
for source_index in self.source_indices:
|
| 42 |
+
try:
|
| 43 |
+
length = int(source.sample_lengths[source_index])
|
| 44 |
+
except IndexError as exc:
|
| 45 |
+
raise IndexError(f"Source index {source_index} is outside the tokenized dataset.") from exc
|
| 46 |
+
if length > self.pack_length:
|
| 47 |
+
raise ValueError(
|
| 48 |
+
f"Tokenized sample #{source_index} has length {length}, "
|
| 49 |
+
f"which exceeds pack_length={self.pack_length}."
|
| 50 |
+
)
|
| 51 |
+
self._length_by_index[source_index] = length
|
| 52 |
+
|
| 53 |
+
self._build_plan()
|
| 54 |
+
self._validate_plan()
|
| 55 |
+
self.source_sample_count = len(self.source_indices)
|
| 56 |
+
self.bin_count = len(self.plan)
|
| 57 |
+
self.original_token_count = sum(self._length_by_index.values())
|
| 58 |
+
self.separator_token_count = sum(max(0, len(bin_indices) - 1) for bin_indices in self.plan)
|
| 59 |
+
self.packed_token_count = self.original_token_count + self.separator_token_count
|
| 60 |
+
self.total_capacity = self.bin_count * self.pack_length
|
| 61 |
+
self.pad_token_count = self.total_capacity - self.packed_token_count
|
| 62 |
+
self.average_samples_per_bin = self.source_sample_count / max(self.bin_count, 1)
|
| 63 |
+
self.utilization = self.packed_token_count / max(self.total_capacity, 1)
|
| 64 |
+
|
| 65 |
+
def _build_plan(self) -> None:
|
| 66 |
+
items = sorted(
|
| 67 |
+
((self._length_by_index[source_index], source_index) for source_index in self.source_indices),
|
| 68 |
+
key=lambda item: (-item[0], item[1]),
|
| 69 |
+
)
|
| 70 |
+
available: list[tuple[int, int]] = []
|
| 71 |
+
|
| 72 |
+
for length, source_index in items:
|
| 73 |
+
required_existing = length + 1
|
| 74 |
+
insert_at = bisect_left(available, (required_existing, -1))
|
| 75 |
+
if insert_at == len(available):
|
| 76 |
+
bin_id = len(self.plan)
|
| 77 |
+
self.plan.append([source_index])
|
| 78 |
+
remaining = self.pack_length - length
|
| 79 |
+
insort(available, (remaining, bin_id))
|
| 80 |
+
continue
|
| 81 |
+
|
| 82 |
+
remaining, bin_id = available.pop(insert_at)
|
| 83 |
+
next_remaining = remaining - required_existing
|
| 84 |
+
if next_remaining < 0:
|
| 85 |
+
raise ValueError("Internal packing error: bin capacity became negative.")
|
| 86 |
+
self.plan[bin_id].append(source_index)
|
| 87 |
+
insort(available, (next_remaining, bin_id))
|
| 88 |
+
|
| 89 |
+
def _validate_plan(self) -> None:
|
| 90 |
+
seen: set[int] = set()
|
| 91 |
+
for bin_id, bin_indices in enumerate(self.plan):
|
| 92 |
+
if not bin_indices:
|
| 93 |
+
raise ValueError(f"Packed bin #{bin_id} is empty.")
|
| 94 |
+
real_length = sum(self._length_by_index[source_index] for source_index in bin_indices)
|
| 95 |
+
real_length += max(0, len(bin_indices) - 1)
|
| 96 |
+
if real_length > self.pack_length:
|
| 97 |
+
raise ValueError(
|
| 98 |
+
f"Packed bin #{bin_id} has real_length={real_length}, "
|
| 99 |
+
f"which exceeds pack_length={self.pack_length}."
|
| 100 |
+
)
|
| 101 |
+
for source_index in bin_indices:
|
| 102 |
+
if source_index in seen:
|
| 103 |
+
raise ValueError(f"Source sample #{source_index} appears in more than one packed bin.")
|
| 104 |
+
seen.add(source_index)
|
| 105 |
+
|
| 106 |
+
missing = self.source_index_set - seen
|
| 107 |
+
if missing:
|
| 108 |
+
first_missing = min(missing)
|
| 109 |
+
raise ValueError(f"Source sample #{first_missing} was not assigned to a packed bin.")
|
| 110 |
+
|
| 111 |
+
def __len__(self) -> int:
|
| 112 |
+
return len(self.plan)
|
| 113 |
+
|
| 114 |
+
def __getitem__(self, bin_idx: int) -> dict[str, torch.Tensor]:
|
| 115 |
+
bin_indices = self.plan[bin_idx]
|
| 116 |
+
input_parts: list[torch.Tensor] = []
|
| 117 |
+
mask_parts: list[torch.Tensor] = []
|
| 118 |
+
original_tokens = 0
|
| 119 |
+
separator_tokens = 0
|
| 120 |
+
|
| 121 |
+
for sample_offset, source_index in enumerate(bin_indices):
|
| 122 |
+
item = self.source[source_index]
|
| 123 |
+
input_ids = item["input_ids"].long()
|
| 124 |
+
loss_mask = item["loss_mask"].long()
|
| 125 |
+
original_tokens += int(input_ids.size(0))
|
| 126 |
+
|
| 127 |
+
if sample_offset > 0:
|
| 128 |
+
input_parts.append(torch.tensor([self.eos_token_id], dtype=torch.long))
|
| 129 |
+
mask_parts.append(torch.zeros(1, dtype=torch.long))
|
| 130 |
+
separator_tokens += 1
|
| 131 |
+
if self.mask_first_after_separator and loss_mask.numel() > 0:
|
| 132 |
+
loss_mask = loss_mask.clone()
|
| 133 |
+
loss_mask[0] = 0
|
| 134 |
+
|
| 135 |
+
input_parts.append(input_ids)
|
| 136 |
+
mask_parts.append(loss_mask)
|
| 137 |
+
|
| 138 |
+
input_ids = torch.cat(input_parts)
|
| 139 |
+
loss_mask = torch.cat(mask_parts)
|
| 140 |
+
real_length = int(input_ids.size(0))
|
| 141 |
+
if real_length > self.pack_length:
|
| 142 |
+
raise ValueError(
|
| 143 |
+
f"Packed bin #{bin_idx} has real_length={real_length}, "
|
| 144 |
+
f"which exceeds pack_length={self.pack_length}."
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
pad_len = self.pack_length - real_length
|
| 148 |
+
if pad_len:
|
| 149 |
+
input_ids = F.pad(input_ids, (0, pad_len), value=self.pad_token_id)
|
| 150 |
+
loss_mask = F.pad(loss_mask, (0, pad_len), value=0)
|
| 151 |
+
|
| 152 |
+
return {
|
| 153 |
+
"input_ids": input_ids,
|
| 154 |
+
"loss_mask": loss_mask,
|
| 155 |
+
"real_length": torch.tensor(real_length, dtype=torch.long),
|
| 156 |
+
"source_samples": torch.tensor(len(bin_indices), dtype=torch.long),
|
| 157 |
+
"original_tokens": torch.tensor(original_tokens, dtype=torch.long),
|
| 158 |
+
"separator_tokens": torch.tensor(separator_tokens, dtype=torch.long),
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
def collate_packed_fn(batch: list[dict], pad_token_id: int) -> dict:
|
| 162 |
+
del pad_token_id
|
| 163 |
+
input_ids = torch.stack([item["input_ids"] for item in batch])
|
| 164 |
+
loss_mask = torch.stack([item["loss_mask"] for item in batch]).long()
|
| 165 |
+
real_lengths = torch.stack([item["real_length"] for item in batch]).long()
|
| 166 |
+
|
| 167 |
+
seq_len = input_ids.size(1)
|
| 168 |
+
positions = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
|
| 169 |
+
attention_mask = (positions < real_lengths.unsqueeze(1)).long()
|
| 170 |
+
|
| 171 |
+
labels = input_ids.clone()
|
| 172 |
+
labels = labels.masked_fill(loss_mask == 0, -100)
|
| 173 |
+
|
| 174 |
+
return {
|
| 175 |
+
"input_ids": input_ids,
|
| 176 |
+
"attention_mask": attention_mask,
|
| 177 |
+
"loss_mask": loss_mask,
|
| 178 |
+
"labels": labels,
|
| 179 |
+
"real_length": real_lengths,
|
| 180 |
+
"source_samples": torch.stack([item["source_samples"] for item in batch]).long(),
|
| 181 |
+
"original_tokens": torch.stack([item["original_tokens"] for item in batch]).long(),
|
| 182 |
+
"separator_tokens": torch.stack([item["separator_tokens"] for item in batch]).long(),
|
| 183 |
+
}
|
src/train.py
ADDED
|
@@ -0,0 +1,1219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import csv
|
| 5 |
+
import json
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import time
|
| 10 |
+
from functools import partial
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
_REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 14 |
+
if str(_REPO_ROOT) not in sys.path:
|
| 15 |
+
sys.path.insert(0, str(_REPO_ROOT))
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch.utils.data import DataLoader, Dataset, Subset
|
| 19 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup
|
| 20 |
+
|
| 21 |
+
from configs import cfg, emit_log_spacing, setup_logger
|
| 22 |
+
from src.checkpoints import (
|
| 23 |
+
find_latest_training_checkpoint,
|
| 24 |
+
load_trainer_state,
|
| 25 |
+
maybe_upload_checkpoint,
|
| 26 |
+
packing_checkpoint_metadata,
|
| 27 |
+
read_env_flag,
|
| 28 |
+
save_checkpoint,
|
| 29 |
+
validate_resume_packing_state,
|
| 30 |
+
)
|
| 31 |
+
from src.kd_contracts import build_tokenizer_contract
|
| 32 |
+
from src.losses import compute_loss_for_phase
|
| 33 |
+
from src.optim import build_adamw_optimizer
|
| 34 |
+
from src.provenance import resolve_model_vocab_size, validate_provenance
|
| 35 |
+
from src.sequence_packing import SequencePackedDataset, collate_packed_fn
|
| 36 |
+
from src.training_data import (
|
| 37 |
+
DistillationDataset,
|
| 38 |
+
collate_fn,
|
| 39 |
+
extract_shard_id_range,
|
| 40 |
+
move_batch_to_device,
|
| 41 |
+
resolve_dataloader_runtime,
|
| 42 |
+
torch_load_cpu,
|
| 43 |
+
)
|
| 44 |
+
from src.training_schedule import (
|
| 45 |
+
build_train_validation_subsets,
|
| 46 |
+
compute_training_schedule,
|
| 47 |
+
load_deepspeed_runtime_config,
|
| 48 |
+
)
|
| 49 |
+
from src.transformers_compat import format_model_load_error, resolve_attention_backend
|
| 50 |
+
from src.validation import evaluate_validation_loss
|
| 51 |
+
|
| 52 |
+
def _log_gpu(logger) -> None:
|
| 53 |
+
if torch.cuda.is_available():
|
| 54 |
+
device = torch.cuda.current_device()
|
| 55 |
+
alloc = torch.cuda.max_memory_allocated(device) / (1024**3)
|
| 56 |
+
reserved = torch.cuda.max_memory_reserved(device) / (1024**3)
|
| 57 |
+
total = torch.cuda.get_device_properties(device).total_memory / (1024**3)
|
| 58 |
+
pct = alloc / total * 100
|
| 59 |
+
logger.info(f"[GPU] {alloc:.1f}/{total:.0f} GiB ({pct:.0f}%) peak alloc, {reserved:.1f} GiB peak reserved")
|
| 60 |
+
|
| 61 |
+
def main() -> None:
|
| 62 |
+
parser = argparse.ArgumentParser(description="Quintus training (SFT / KD)")
|
| 63 |
+
packing_cfg = getattr(cfg.training, "sequence_packing", None)
|
| 64 |
+
sequence_packing_default = bool(getattr(packing_cfg, "enabled", False))
|
| 65 |
+
pack_length_default = int(getattr(packing_cfg, "pack_length", cfg.data.max_seq_len))
|
| 66 |
+
mask_first_after_separator = bool(getattr(packing_cfg, "mask_first_token_after_separator", True))
|
| 67 |
+
parser.add_argument("--num_samples", type=int, default=cfg.data.num_samples)
|
| 68 |
+
parser.add_argument("--phase", type=str, choices=["sft", "kd", "online_kd"], default="online_kd", help="Training phase")
|
| 69 |
+
parser.add_argument("--resume_from_checkpoint", action="store_true", help="Resume from latest epoch in current output directory")
|
| 70 |
+
parser.add_argument("--init_from_checkpoint", type=str, default=None, help="Initialize weights from a specific path before training")
|
| 71 |
+
parser.add_argument(
|
| 72 |
+
"--compile_model",
|
| 73 |
+
action="store_true",
|
| 74 |
+
default=bool(getattr(cfg.training, "compile_model", False)),
|
| 75 |
+
help="Enable torch.compile after checkpoint loading. Off by default for KD memory safety.",
|
| 76 |
+
)
|
| 77 |
+
parser.add_argument("--local_rank", type=int, default=-1, help=argparse.SUPPRESS)
|
| 78 |
+
parser.add_argument("--deepspeed", type=str, default=None, help="Enable DeepSpeed with the given config path.")
|
| 79 |
+
parser.add_argument("--no_deepspeed", action="store_true", help="Run without DeepSpeed.")
|
| 80 |
+
parser.add_argument(
|
| 81 |
+
"--allow_partial_final_window",
|
| 82 |
+
action="store_true",
|
| 83 |
+
help="Allow DeepSpeed to drop a final incomplete accumulation window during smoke tests.",
|
| 84 |
+
)
|
| 85 |
+
parser.add_argument("--teacher_model", type=str, default=cfg.model.teacher)
|
| 86 |
+
parser.add_argument("--teacher_revision", type=str, default=cfg.model.teacher_revision)
|
| 87 |
+
parser.add_argument("--student_model", type=str, default=cfg.model.student)
|
| 88 |
+
parser.add_argument("--student_revision", type=str, default=cfg.model.student_revision)
|
| 89 |
+
parser.add_argument("--tokenizer_model", type=str, default=getattr(cfg.model, "tokenizer", cfg.model.student))
|
| 90 |
+
parser.add_argument("--tokenizer_revision", type=str, default=getattr(cfg.model, "tokenizer_revision", cfg.model.student_revision))
|
| 91 |
+
parser.add_argument("--student_dir", type=str, default=cfg.paths.student_dir)
|
| 92 |
+
parser.add_argument("--tokenizer_dir", type=str, default=getattr(cfg.paths, "tokenizer_dir", cfg.paths.student_dir))
|
| 93 |
+
parser.add_argument("--distilled_dir", type=str, default=cfg.paths.distilled_dir)
|
| 94 |
+
parser.add_argument("--num_epochs", type=int, default=cfg.training.num_epochs)
|
| 95 |
+
parser.add_argument("--max_steps", type=int, default=-1, help="Stop after this many optimizer steps. -1 = no limit.")
|
| 96 |
+
parser.add_argument("--learning_rate", type=float, default=float(cfg.training.learning_rate))
|
| 97 |
+
parser.add_argument("--alpha", type=float, default=cfg.training.alpha)
|
| 98 |
+
parser.add_argument("--temperature", type=float, default=cfg.training.temperature)
|
| 99 |
+
parser.add_argument(
|
| 100 |
+
"--online_kd_token_chunk_size",
|
| 101 |
+
type=int,
|
| 102 |
+
default=int(getattr(cfg.training, "online_kd_token_chunk_size", 2048)),
|
| 103 |
+
help="Token chunk size for full-vocabulary online KD loss.",
|
| 104 |
+
)
|
| 105 |
+
parser.add_argument("--micro_batch_size", type=int, default=cfg.training.micro_batch_size)
|
| 106 |
+
parser.add_argument("--grad_accum_steps", type=int, default=cfg.training.grad_accum_steps)
|
| 107 |
+
parser.add_argument("--sequence_packing", action="store_true", default=False, help="Enable sequence packing for online_kd.")
|
| 108 |
+
parser.add_argument("--no_sequence_packing", action="store_true", default=False, help="Disable sequence packing.")
|
| 109 |
+
parser.add_argument("--pack_length", type=int, default=None, help="Packed sequence length.")
|
| 110 |
+
parser.add_argument("--disable_checkpointing", action="store_true", default=False, help="Disable intermediate epoch/step/best checkpoint saves.")
|
| 111 |
+
parser.add_argument("--gradient_checkpointing", action="store_true", default=bool(cfg.training.gradient_checkpointing), help="Enable gradient checkpointing (activation checkpointing).")
|
| 112 |
+
parser.add_argument("--upload_kd_checkpoints", action="store_true", default=False)
|
| 113 |
+
parser.add_argument("--upload_step_checkpoints", action="store_true", default=False)
|
| 114 |
+
parser.add_argument(
|
| 115 |
+
"--upload_last_checkpoint",
|
| 116 |
+
action="store_true",
|
| 117 |
+
default=False,
|
| 118 |
+
help="Upload the final 'last' checkpoint to the Hub. Off by default.",
|
| 119 |
+
)
|
| 120 |
+
parser.add_argument(
|
| 121 |
+
"--hub_upload_strict",
|
| 122 |
+
action="store_true",
|
| 123 |
+
default=read_env_flag("QUINTUS_HUB_UPLOAD_STRICT", False),
|
| 124 |
+
help="Fail training if a requested Hub checkpoint upload fails.",
|
| 125 |
+
)
|
| 126 |
+
parser.add_argument("--hub_repo_id", type=str, default=f"{cfg.hub.username}/{cfg.hub.repo_name}")
|
| 127 |
+
parser.add_argument("--ckpt_path_in_repo", type=str, default="models/online_kd_8b_17b_ep1_B200_20260608_alpha0.3")
|
| 128 |
+
parser.add_argument("--commit_message_prefix", type=str, default="Online KD 8B->1.7B B200 Run (alpha=0.3)")
|
| 129 |
+
args = parser.parse_args()
|
| 130 |
+
|
| 131 |
+
if args.sequence_packing and args.no_sequence_packing:
|
| 132 |
+
parser.error("Use either --sequence_packing or --no_sequence_packing, not both.")
|
| 133 |
+
sequence_packing_enabled = sequence_packing_default
|
| 134 |
+
if args.sequence_packing:
|
| 135 |
+
sequence_packing_enabled = True
|
| 136 |
+
elif args.no_sequence_packing:
|
| 137 |
+
sequence_packing_enabled = False
|
| 138 |
+
pack_length = int(args.pack_length if args.pack_length is not None else pack_length_default)
|
| 139 |
+
if pack_length <= 0:
|
| 140 |
+
parser.error(f"--pack_length must be positive, got {pack_length}.")
|
| 141 |
+
if pack_length > int(cfg.data.max_seq_len):
|
| 142 |
+
parser.error(f"--pack_length must be <= data.max_seq_len ({int(cfg.data.max_seq_len)}), got {pack_length}.")
|
| 143 |
+
if sequence_packing_enabled and args.phase != "online_kd":
|
| 144 |
+
parser.error("--sequence_packing is supported only with --phase online_kd.")
|
| 145 |
+
if args.online_kd_token_chunk_size <= 0:
|
| 146 |
+
parser.error(
|
| 147 |
+
f"--online_kd_token_chunk_size must be positive, got {args.online_kd_token_chunk_size}."
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
cfg.model.teacher = args.teacher_model
|
| 151 |
+
cfg.model.teacher_revision = args.teacher_revision
|
| 152 |
+
cfg.model.student = args.student_model
|
| 153 |
+
cfg.model.student_revision = args.student_revision
|
| 154 |
+
cfg.model.tokenizer = args.tokenizer_model
|
| 155 |
+
cfg.model.tokenizer_revision = args.tokenizer_revision
|
| 156 |
+
cfg.paths.student_dir = args.student_dir
|
| 157 |
+
cfg.paths.tokenizer_dir = args.tokenizer_dir
|
| 158 |
+
cfg.paths.distilled_dir = args.distilled_dir
|
| 159 |
+
cfg.training.num_epochs = args.num_epochs
|
| 160 |
+
cfg.training.learning_rate = args.learning_rate
|
| 161 |
+
cfg.training.alpha = args.alpha
|
| 162 |
+
cfg.training.temperature = args.temperature
|
| 163 |
+
cfg.training.online_kd_token_chunk_size = int(args.online_kd_token_chunk_size)
|
| 164 |
+
cfg.training.micro_batch_size = args.micro_batch_size
|
| 165 |
+
cfg.training.grad_accum_steps = args.grad_accum_steps
|
| 166 |
+
cfg.training.gradient_checkpointing = args.gradient_checkpointing
|
| 167 |
+
cfg.training.disable_checkpointing = args.disable_checkpointing
|
| 168 |
+
cfg.training.sequence_packing.enabled = sequence_packing_enabled
|
| 169 |
+
cfg.training.sequence_packing.pack_length = pack_length
|
| 170 |
+
cfg.training.sequence_packing.mask_first_token_after_separator = mask_first_after_separator
|
| 171 |
+
cfg.data.num_samples = args.num_samples
|
| 172 |
+
|
| 173 |
+
from omegaconf import OmegaConf
|
| 174 |
+
if not hasattr(cfg, "hub"):
|
| 175 |
+
cfg.hub = OmegaConf.create()
|
| 176 |
+
cfg.hub.upload_kd_checkpoints = args.upload_kd_checkpoints
|
| 177 |
+
cfg.hub.upload_step_checkpoints = args.upload_step_checkpoints
|
| 178 |
+
cfg.hub.upload_last_checkpoint = args.upload_last_checkpoint
|
| 179 |
+
cfg.hub.hub_upload_strict = args.hub_upload_strict
|
| 180 |
+
cfg.hub.repo_id = args.hub_repo_id
|
| 181 |
+
cfg.hub.ckpt_path_in_repo = args.ckpt_path_in_repo
|
| 182 |
+
cfg.hub.commit_message_prefix = args.commit_message_prefix
|
| 183 |
+
|
| 184 |
+
rank = int(os.environ.get("LOCAL_RANK", args.local_rank))
|
| 185 |
+
log = setup_logger("TRAIN", rank=rank)
|
| 186 |
+
|
| 187 |
+
log.info("=" * 70)
|
| 188 |
+
log.info("Quintus Training")
|
| 189 |
+
log.info("=" * 70)
|
| 190 |
+
tokenizer_dir = getattr(cfg.paths, "tokenizer_dir", cfg.paths.student_dir)
|
| 191 |
+
tokenizer_model = getattr(cfg.model, "tokenizer", cfg.model.student)
|
| 192 |
+
|
| 193 |
+
log.info(f" Student: {cfg.paths.student_dir}")
|
| 194 |
+
log.info(f" Student id: {cfg.model.student}")
|
| 195 |
+
log.info(f" Tokenizer: {tokenizer_dir}")
|
| 196 |
+
log.info(f" Tokenizer id:{tokenizer_model}")
|
| 197 |
+
log.info(f" Num samples: {args.num_samples:,}")
|
| 198 |
+
log.info(f" Epochs: {cfg.training.num_epochs}")
|
| 199 |
+
log.info(f" LR: {cfg.training.learning_rate}")
|
| 200 |
+
log.info(f" Phase: {args.phase}")
|
| 201 |
+
if args.phase in ("kd", "online_kd"):
|
| 202 |
+
log.info(f" CE weight: {cfg.training.alpha}")
|
| 203 |
+
log.info(f" Temperature: {cfg.training.temperature}")
|
| 204 |
+
if args.phase == "online_kd":
|
| 205 |
+
log.info(f" KD chunk: {cfg.training.online_kd_token_chunk_size} tokens")
|
| 206 |
+
log.info(f" Micro batch: {cfg.training.micro_batch_size}")
|
| 207 |
+
log.info(f" Grad accum: {cfg.training.grad_accum_steps}")
|
| 208 |
+
log.info(f" Eff. batch: {cfg.training.micro_batch_size * cfg.training.grad_accum_steps}")
|
| 209 |
+
log.info(f" Val ratio: {cfg.training.validation_ratio:.2%}")
|
| 210 |
+
log.info(f" Remote code: {cfg.model.allow_remote_code}")
|
| 211 |
+
log.info(f" Output dir: {cfg.paths.distilled_dir}")
|
| 212 |
+
log.info(f" Log file: {cfg.paths.log_file}")
|
| 213 |
+
log.info(f" Fused AdamW: {bool(getattr(cfg.training, 'fused_adamw', False))}")
|
| 214 |
+
log.info(
|
| 215 |
+
f" HF upload: regular={cfg.hub.upload_kd_checkpoints} "
|
| 216 |
+
f"steps={cfg.hub.upload_step_checkpoints} "
|
| 217 |
+
f"last={cfg.hub.upload_last_checkpoint} "
|
| 218 |
+
f"strict={cfg.hub.hub_upload_strict}"
|
| 219 |
+
)
|
| 220 |
+
log.info(
|
| 221 |
+
f" HF target: {cfg.hub.repo_id}/"
|
| 222 |
+
f"{cfg.hub.ckpt_path_in_repo}"
|
| 223 |
+
)
|
| 224 |
+
if torch.cuda.is_available():
|
| 225 |
+
log.info(f" GPU: {torch.cuda.get_device_name(0)}")
|
| 226 |
+
|
| 227 |
+
try:
|
| 228 |
+
t_dir = tokenizer_dir
|
| 229 |
+
if not os.path.exists(t_dir):
|
| 230 |
+
log.warning(f"Tokenizer directory '{t_dir}' not found. Falling back to downloading '{tokenizer_model}' from HF Hub.")
|
| 231 |
+
t_dir = tokenizer_model
|
| 232 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 233 |
+
t_dir,
|
| 234 |
+
trust_remote_code=cfg.model.allow_remote_code,
|
| 235 |
+
)
|
| 236 |
+
except Exception as exc:
|
| 237 |
+
log.error(format_model_load_error("Student tokenizer load", exc))
|
| 238 |
+
sys.exit(1)
|
| 239 |
+
if tokenizer.pad_token is None:
|
| 240 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 241 |
+
if sequence_packing_enabled:
|
| 242 |
+
if tokenizer.eos_token_id is None:
|
| 243 |
+
log.error("Sequence packing requires tokenizer.eos_token_id.")
|
| 244 |
+
sys.exit(1)
|
| 245 |
+
if tokenizer.pad_token_id is None:
|
| 246 |
+
log.error("Sequence packing requires tokenizer.pad_token_id.")
|
| 247 |
+
sys.exit(1)
|
| 248 |
+
student_tokenizer_contract = build_tokenizer_contract(tokenizer)
|
| 249 |
+
student_tokenizer_vocab_size = student_tokenizer_contract["full_vocab_size"]
|
| 250 |
+
|
| 251 |
+
if args.phase == "kd":
|
| 252 |
+
_prov_path_for_teacher = os.path.join(cfg.paths.logits_dir, "_provenance.json")
|
| 253 |
+
if os.path.exists(_prov_path_for_teacher):
|
| 254 |
+
with open(_prov_path_for_teacher, "r", encoding="utf-8") as _pf:
|
| 255 |
+
_prov_data = json.load(_pf)
|
| 256 |
+
_teacher_prov = _prov_data.get("teacher", {})
|
| 257 |
+
teacher_tokenizer_contract = {
|
| 258 |
+
"full_vocab_size": _teacher_prov.get("tokenizer_size"),
|
| 259 |
+
"fingerprint": _teacher_prov.get("tokenizer_fingerprint"),
|
| 260 |
+
}
|
| 261 |
+
log.info(
|
| 262 |
+
f" Teacher contract read from provenance: "
|
| 263 |
+
f"vocab={teacher_tokenizer_contract['full_vocab_size']}, "
|
| 264 |
+
f"fingerprint={teacher_tokenizer_contract['fingerprint'][:12]}..."
|
| 265 |
+
)
|
| 266 |
+
else:
|
| 267 |
+
try:
|
| 268 |
+
teacher_tokenizer = AutoTokenizer.from_pretrained(
|
| 269 |
+
cfg.paths.teacher_dir if os.path.exists(cfg.paths.teacher_dir) else cfg.model.teacher,
|
| 270 |
+
trust_remote_code=cfg.model.allow_remote_code,
|
| 271 |
+
)
|
| 272 |
+
except Exception as exc:
|
| 273 |
+
log.error(format_model_load_error("Teacher tokenizer load", exc))
|
| 274 |
+
sys.exit(1)
|
| 275 |
+
teacher_tokenizer_contract = build_tokenizer_contract(teacher_tokenizer)
|
| 276 |
+
del teacher_tokenizer
|
| 277 |
+
else:
|
| 278 |
+
teacher_tokenizer_contract = None
|
| 279 |
+
|
| 280 |
+
attn_impl = resolve_attention_backend(log)
|
| 281 |
+
log.info(f" Attention: {attn_impl}")
|
| 282 |
+
|
| 283 |
+
try:
|
| 284 |
+
from liger_kernel.transformers import apply_liger_kernel_to_qwen3
|
| 285 |
+
apply_liger_kernel_to_qwen3(
|
| 286 |
+
rope=True,
|
| 287 |
+
swiglu=True,
|
| 288 |
+
rms_norm=True,
|
| 289 |
+
cross_entropy=False,
|
| 290 |
+
fused_linear_cross_entropy=False,
|
| 291 |
+
)
|
| 292 |
+
log.info(" Liger: enabled")
|
| 293 |
+
except ImportError:
|
| 294 |
+
if cfg.training.micro_batch_size >= 6:
|
| 295 |
+
log.error(" Liger: missing; install liger-kernel or lower micro_batch_size.")
|
| 296 |
+
raise RuntimeError("liger_kernel is required for micro_batch_size >= 6.")
|
| 297 |
+
else:
|
| 298 |
+
log.warning(" Liger: not installed")
|
| 299 |
+
|
| 300 |
+
try:
|
| 301 |
+
s_dir = cfg.paths.student_dir
|
| 302 |
+
if not os.path.exists(s_dir):
|
| 303 |
+
log.warning(f"Student model directory '{s_dir}' not found. Falling back to downloading '{cfg.model.student}' from HF Hub.")
|
| 304 |
+
s_dir = cfg.model.student
|
| 305 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 306 |
+
s_dir,
|
| 307 |
+
dtype=torch.bfloat16,
|
| 308 |
+
low_cpu_mem_usage=True,
|
| 309 |
+
trust_remote_code=cfg.model.allow_remote_code,
|
| 310 |
+
attn_implementation=attn_impl,
|
| 311 |
+
)
|
| 312 |
+
except Exception as exc:
|
| 313 |
+
log.error(format_model_load_error("Student model load", exc))
|
| 314 |
+
sys.exit(1)
|
| 315 |
+
model.config.use_cache = False
|
| 316 |
+
if getattr(cfg.training, "gradient_checkpointing", False):
|
| 317 |
+
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
| 318 |
+
log.info(" Grad ckpt: enabled")
|
| 319 |
+
else:
|
| 320 |
+
log.info(" Grad ckpt: disabled")
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
start_epoch = 0
|
| 325 |
+
resume_state: dict = {}
|
| 326 |
+
if args.resume_from_checkpoint and args.init_from_checkpoint:
|
| 327 |
+
log.error("Use either --init_from_checkpoint or --resume_from_checkpoint, not both.")
|
| 328 |
+
sys.exit(1)
|
| 329 |
+
|
| 330 |
+
checkpoint_to_load = args.init_from_checkpoint
|
| 331 |
+
|
| 332 |
+
if args.resume_from_checkpoint:
|
| 333 |
+
latest_ckpt = find_latest_training_checkpoint(cfg.paths.distilled_dir)
|
| 334 |
+
if latest_ckpt is None:
|
| 335 |
+
log.error(
|
| 336 |
+
f"--resume_from_checkpoint was set, but no epoch_* or step_* checkpoints were found in "
|
| 337 |
+
f"{cfg.paths.distilled_dir}. Use --init_from_checkpoint for the first KD run."
|
| 338 |
+
)
|
| 339 |
+
sys.exit(1)
|
| 340 |
+
checkpoint_to_load = latest_ckpt
|
| 341 |
+
resume_state = load_trainer_state(latest_ckpt, log)
|
| 342 |
+
checkpoint_type = resume_state.get("checkpoint_type", os.path.basename(latest_ckpt).split("_")[0])
|
| 343 |
+
start_epoch = int(resume_state.get("start_epoch", 0) or 0)
|
| 344 |
+
if checkpoint_type == "epoch":
|
| 345 |
+
log.info(f"Interrupted run detected. Resuming after completed epoch {start_epoch}")
|
| 346 |
+
else:
|
| 347 |
+
log.info(
|
| 348 |
+
f"Interrupted run detected. Resuming from {os.path.basename(latest_ckpt)} "
|
| 349 |
+
f"at epoch_index={start_epoch}, next_batch_in_epoch="
|
| 350 |
+
f"{int(resume_state.get('next_batch_in_epoch', 0) or 0)}"
|
| 351 |
+
)
|
| 352 |
+
validate_resume_packing_state(
|
| 353 |
+
resume_state,
|
| 354 |
+
enabled=sequence_packing_enabled,
|
| 355 |
+
pack_length=pack_length,
|
| 356 |
+
max_seq_len=int(cfg.data.max_seq_len),
|
| 357 |
+
log=log,
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
if checkpoint_to_load:
|
| 361 |
+
log.info(f"Loading weights from: {checkpoint_to_load}")
|
| 362 |
+
try:
|
| 363 |
+
from safetensors.torch import load_file
|
| 364 |
+
ckpt_file = os.path.join(checkpoint_to_load, "model.safetensors")
|
| 365 |
+
if not os.path.exists(ckpt_file):
|
| 366 |
+
ckpt_file = os.path.join(checkpoint_to_load, "pytorch_model.bin")
|
| 367 |
+
|
| 368 |
+
if ckpt_file.endswith(".safetensors"):
|
| 369 |
+
state_dict = load_file(ckpt_file)
|
| 370 |
+
else:
|
| 371 |
+
state_dict = torch.load(ckpt_file, map_location="cpu")
|
| 372 |
+
|
| 373 |
+
new_state_dict = {}
|
| 374 |
+
for k, v in state_dict.items():
|
| 375 |
+
if k.startswith("_orig_mod."):
|
| 376 |
+
new_state_dict[k[len("_orig_mod."):]] = v
|
| 377 |
+
else:
|
| 378 |
+
new_state_dict[k] = v
|
| 379 |
+
|
| 380 |
+
model.load_state_dict(new_state_dict)
|
| 381 |
+
log.info("Weights loaded.")
|
| 382 |
+
except Exception as e:
|
| 383 |
+
log.error(f"Failed to load weights: {e}")
|
| 384 |
+
sys.exit(1)
|
| 385 |
+
|
| 386 |
+
model.train()
|
| 387 |
+
|
| 388 |
+
if args.compile_model:
|
| 389 |
+
log.info(" Compile: enabled")
|
| 390 |
+
model = torch.compile(model, dynamic=True)
|
| 391 |
+
else:
|
| 392 |
+
log.info(" Compile: disabled")
|
| 393 |
+
|
| 394 |
+
torch.set_float32_matmul_precision("high")
|
| 395 |
+
|
| 396 |
+
student_model_vocab_size = resolve_model_vocab_size(model, tokenizer, "Student", log)
|
| 397 |
+
log.info(
|
| 398 |
+
f" Student V: tokenizer={student_tokenizer_vocab_size:,} "
|
| 399 |
+
f"model={student_model_vocab_size:,}"
|
| 400 |
+
)
|
| 401 |
+
_log_gpu(log)
|
| 402 |
+
|
| 403 |
+
if args.phase == "kd":
|
| 404 |
+
shard0 = os.path.join(cfg.paths.logits_dir, "shard_000000.pt")
|
| 405 |
+
if os.path.exists(shard0):
|
| 406 |
+
test_shard = torch_load_cpu(shard0)
|
| 407 |
+
try:
|
| 408 |
+
min_id, max_id = extract_shard_id_range(test_shard, shard0)
|
| 409 |
+
except (KeyError, ValueError) as exc:
|
| 410 |
+
log.error(str(exc))
|
| 411 |
+
sys.exit(1)
|
| 412 |
+
if min_id < 0:
|
| 413 |
+
log.error(f" Negative IDs (min={min_id}); int16 overflow.")
|
| 414 |
+
log.error(" Regenerate shards.")
|
| 415 |
+
sys.exit(1)
|
| 416 |
+
if max_id >= student_tokenizer_vocab_size:
|
| 417 |
+
log.error(
|
| 418 |
+
f"VOCAB MISMATCH: shard max_id={max_id} >= "
|
| 419 |
+
f"student tokenizer vocab={student_tokenizer_vocab_size}"
|
| 420 |
+
)
|
| 421 |
+
sys.exit(1)
|
| 422 |
+
log.info(
|
| 423 |
+
f" Vocab check: PASS (ids in [{min_id}, {max_id}], "
|
| 424 |
+
f"reachable tokenizer V={student_tokenizer_vocab_size})"
|
| 425 |
+
)
|
| 426 |
+
else:
|
| 427 |
+
log.warning(f" Shard {shard0} not found; skipping vocab check")
|
| 428 |
+
|
| 429 |
+
data_path = os.path.join(cfg.paths.tokenized_dir, "train.jsonl")
|
| 430 |
+
dataset = DistillationDataset(data_path, cfg.paths.logits_dir, cfg.data.max_seq_len, args.num_samples, args.phase)
|
| 431 |
+
log.info(f" Dataset: {len(dataset):,} samples")
|
| 432 |
+
|
| 433 |
+
if args.phase == "kd":
|
| 434 |
+
prov_path = os.path.join(cfg.paths.logits_dir, "_provenance.json")
|
| 435 |
+
validate_provenance(
|
| 436 |
+
prov_path=prov_path,
|
| 437 |
+
data_path=data_path,
|
| 438 |
+
dataset=dataset,
|
| 439 |
+
teacher_tokenizer_contract=teacher_tokenizer_contract,
|
| 440 |
+
student_tokenizer_contract=student_tokenizer_contract,
|
| 441 |
+
log=log,
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
pad_id = tokenizer.pad_token_id
|
| 445 |
+
if args.no_deepspeed:
|
| 446 |
+
args.deepspeed = None
|
| 447 |
+
use_ds = args.deepspeed is not None
|
| 448 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
| 449 |
+
if world_size != 1:
|
| 450 |
+
log.error("This training path is single-GPU only. Re-run with NUM_GPUS=1.")
|
| 451 |
+
sys.exit(1)
|
| 452 |
+
is_main = rank in (-1, 0)
|
| 453 |
+
|
| 454 |
+
ds_runtime_config = None
|
| 455 |
+
if use_ds:
|
| 456 |
+
try:
|
| 457 |
+
ds_runtime_config = load_deepspeed_runtime_config(
|
| 458 |
+
args.deepspeed,
|
| 459 |
+
micro_batch_size=cfg.training.micro_batch_size,
|
| 460 |
+
grad_accum=cfg.training.grad_accum_steps,
|
| 461 |
+
)
|
| 462 |
+
except (OSError, ValueError, json.JSONDecodeError) as exc:
|
| 463 |
+
log.error(str(exc))
|
| 464 |
+
sys.exit(1)
|
| 465 |
+
|
| 466 |
+
train_dataset, val_dataset, split_meta = build_train_validation_subsets(
|
| 467 |
+
dataset=dataset,
|
| 468 |
+
validation_ratio=float(cfg.training.validation_ratio),
|
| 469 |
+
split_seed=int(cfg.training.split_seed),
|
| 470 |
+
micro_batch_size=cfg.training.micro_batch_size,
|
| 471 |
+
grad_accum=cfg.training.grad_accum_steps,
|
| 472 |
+
num_epochs=cfg.training.num_epochs,
|
| 473 |
+
use_ds=use_ds,
|
| 474 |
+
)
|
| 475 |
+
log.info(
|
| 476 |
+
f" Train split: {len(train_dataset):,} samples | "
|
| 477 |
+
f"Val split: {int(split_meta['validation_size']):,} samples"
|
| 478 |
+
)
|
| 479 |
+
if bool(split_meta["accumulation_aligned"]):
|
| 480 |
+
log.info(
|
| 481 |
+
f" Accum align: train split is divisible by effective batch "
|
| 482 |
+
f"{int(split_meta['effective_batch_size']):,}"
|
| 483 |
+
)
|
| 484 |
+
else:
|
| 485 |
+
if use_ds:
|
| 486 |
+
log.warning(
|
| 487 |
+
f" Accum align: train split leaves "
|
| 488 |
+
f"{int(split_meta['train_remainder_batches'])} partial accumulation batches per epoch; "
|
| 489 |
+
"DeepSpeed will carry partial accumulation across epoch boundaries"
|
| 490 |
+
)
|
| 491 |
+
else:
|
| 492 |
+
log.warning(
|
| 493 |
+
f" Accum align: train split leaves "
|
| 494 |
+
f"{int(split_meta['train_remainder_batches'])} partial accumulation batches per epoch; "
|
| 495 |
+
"the fallback flush path will rescale gradients correctly"
|
| 496 |
+
)
|
| 497 |
+
if bool(split_meta["adjusted"]):
|
| 498 |
+
log.info(
|
| 499 |
+
f" Val align: requested {int(split_meta['requested_validation_size']):,} "
|
| 500 |
+
f"({float(split_meta['requested_validation_ratio']) * 100:.2f}%), "
|
| 501 |
+
f"using {int(split_meta['validation_size']):,} "
|
| 502 |
+
f"({float(split_meta['actual_validation_ratio']) * 100:.2f}%) "
|
| 503 |
+
"to preserve the training schedule"
|
| 504 |
+
)
|
| 505 |
+
elif val_dataset is not None:
|
| 506 |
+
log.info(
|
| 507 |
+
f" Val split: using {float(split_meta['actual_validation_ratio']) * 100:.2f}% "
|
| 508 |
+
f"held out with split_seed={cfg.training.split_seed}"
|
| 509 |
+
)
|
| 510 |
+
else:
|
| 511 |
+
log.warning(" Validation disabled; tracking training loss.")
|
| 512 |
+
|
| 513 |
+
effective_train_dataset: Dataset = train_dataset
|
| 514 |
+
train_collate = partial(collate_fn, pad_token_id=pad_id)
|
| 515 |
+
val_collate = partial(collate_fn, pad_token_id=pad_id)
|
| 516 |
+
if sequence_packing_enabled:
|
| 517 |
+
if isinstance(train_dataset, Subset):
|
| 518 |
+
source_dataset = train_dataset.dataset
|
| 519 |
+
train_source_indices = [int(index) for index in train_dataset.indices]
|
| 520 |
+
else:
|
| 521 |
+
source_dataset = train_dataset
|
| 522 |
+
train_source_indices = list(range(len(train_dataset)))
|
| 523 |
+
|
| 524 |
+
if not isinstance(source_dataset, DistillationDataset):
|
| 525 |
+
log.error("Sequence packing requires DistillationDataset as the split source.")
|
| 526 |
+
sys.exit(1)
|
| 527 |
+
|
| 528 |
+
val_source_indices: set[int] = set()
|
| 529 |
+
if isinstance(val_dataset, Subset) and val_dataset.dataset is source_dataset:
|
| 530 |
+
val_source_indices = {int(index) for index in val_dataset.indices}
|
| 531 |
+
|
| 532 |
+
try:
|
| 533 |
+
packed_train_dataset = SequencePackedDataset(
|
| 534 |
+
source=source_dataset,
|
| 535 |
+
source_indices=train_source_indices,
|
| 536 |
+
pack_length=pack_length,
|
| 537 |
+
eos_token_id=int(tokenizer.eos_token_id),
|
| 538 |
+
pad_token_id=int(tokenizer.pad_token_id),
|
| 539 |
+
mask_first_after_separator=mask_first_after_separator,
|
| 540 |
+
)
|
| 541 |
+
except (IndexError, ValueError) as exc:
|
| 542 |
+
log.error(str(exc))
|
| 543 |
+
sys.exit(1)
|
| 544 |
+
|
| 545 |
+
overlap = packed_train_dataset.source_index_set.intersection(val_source_indices)
|
| 546 |
+
if overlap:
|
| 547 |
+
first_overlap = min(overlap)
|
| 548 |
+
log.error(f"Sequence packing split error: validation sample #{first_overlap} appears in training bins.")
|
| 549 |
+
sys.exit(1)
|
| 550 |
+
|
| 551 |
+
effective_train_dataset = packed_train_dataset
|
| 552 |
+
train_collate = partial(collate_packed_fn, pad_token_id=pad_id)
|
| 553 |
+
log.info(" Packing: enabled")
|
| 554 |
+
log.info(f" Pack length: {packed_train_dataset.pack_length:,}")
|
| 555 |
+
log.info(f" Train bins: {packed_train_dataset.bin_count:,}")
|
| 556 |
+
log.info(f" Train rows: {packed_train_dataset.source_sample_count:,}")
|
| 557 |
+
log.info(f" Avg samples: {packed_train_dataset.average_samples_per_bin:.2f} per bin")
|
| 558 |
+
log.info(f" Original tokens: {packed_train_dataset.original_token_count:,}")
|
| 559 |
+
log.info(f" Separator tokens: {packed_train_dataset.separator_token_count:,}")
|
| 560 |
+
log.info(f" Pad tokens: {packed_train_dataset.pad_token_count:,}")
|
| 561 |
+
log.info(f" Utilization: {packed_train_dataset.utilization * 100:.1f}%")
|
| 562 |
+
else:
|
| 563 |
+
log.info(" Packing: disabled")
|
| 564 |
+
|
| 565 |
+
dataloader_runtime = resolve_dataloader_runtime()
|
| 566 |
+
log.info(
|
| 567 |
+
" DataLoader: "
|
| 568 |
+
f"workers={int(dataloader_runtime['num_workers'])} "
|
| 569 |
+
f"pin_memory={bool(dataloader_runtime['pin_memory'])} "
|
| 570 |
+
f"persistent={bool(dataloader_runtime.get('persistent_workers', False))}"
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
dataloader = DataLoader(
|
| 574 |
+
effective_train_dataset,
|
| 575 |
+
batch_size=cfg.training.micro_batch_size,
|
| 576 |
+
shuffle=(args.phase != "kd"),
|
| 577 |
+
collate_fn=train_collate,
|
| 578 |
+
drop_last=True,
|
| 579 |
+
**dataloader_runtime,
|
| 580 |
+
)
|
| 581 |
+
if args.phase == "kd":
|
| 582 |
+
log.info(" KD sampler: sequential shard-local order (split membership remains randomized)")
|
| 583 |
+
val_dataloader = None
|
| 584 |
+
if val_dataset is not None:
|
| 585 |
+
val_dataloader = DataLoader(
|
| 586 |
+
val_dataset,
|
| 587 |
+
batch_size=cfg.training.micro_batch_size,
|
| 588 |
+
shuffle=False,
|
| 589 |
+
collate_fn=val_collate,
|
| 590 |
+
drop_last=False,
|
| 591 |
+
**dataloader_runtime,
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
grad_accum = cfg.training.grad_accum_steps
|
| 595 |
+
schedule = compute_training_schedule(
|
| 596 |
+
dataset_size=len(effective_train_dataset),
|
| 597 |
+
micro_batch_size=cfg.training.micro_batch_size,
|
| 598 |
+
grad_accum=grad_accum,
|
| 599 |
+
num_epochs=cfg.training.num_epochs,
|
| 600 |
+
use_ds=use_ds,
|
| 601 |
+
drop_last=True,
|
| 602 |
+
)
|
| 603 |
+
batches_per_epoch = int(schedule["batches_per_epoch"])
|
| 604 |
+
remainder_batches = int(schedule["remainder_batches"])
|
| 605 |
+
has_remainder = bool(schedule["has_remainder"])
|
| 606 |
+
total_micro_batches = int(schedule["total_micro_batches"])
|
| 607 |
+
steps_per_epoch = int(schedule["steps_per_epoch"])
|
| 608 |
+
total_steps = int(schedule["total_steps"])
|
| 609 |
+
final_remainder = int(schedule["final_remainder"])
|
| 610 |
+
|
| 611 |
+
if batches_per_epoch == 0:
|
| 612 |
+
schedule_unit = "packed bins" if sequence_packing_enabled else "samples"
|
| 613 |
+
log.error(
|
| 614 |
+
f"Dataset too small for micro_batch_size={cfg.training.micro_batch_size}. "
|
| 615 |
+
f"Train split has {len(effective_train_dataset)} {schedule_unit} and drop_last=True would produce 0 batches."
|
| 616 |
+
)
|
| 617 |
+
sys.exit(1)
|
| 618 |
+
|
| 619 |
+
dropped_samples_per_epoch = int(schedule["dropped_samples_per_epoch"])
|
| 620 |
+
if dropped_samples_per_epoch:
|
| 621 |
+
schedule_unit = "packed bins" if sequence_packing_enabled else "samples"
|
| 622 |
+
log.warning(
|
| 623 |
+
f" drop_last=True will discard {dropped_samples_per_epoch} {schedule_unit} per epoch "
|
| 624 |
+
"before gradient accumulation begins"
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
if use_ds and final_remainder:
|
| 628 |
+
dropped_total = int(schedule["dropped_samples_total"])
|
| 629 |
+
schedule_unit = "packed bins" if sequence_packing_enabled else "samples"
|
| 630 |
+
message = (
|
| 631 |
+
f"DeepSpeed would drop the final {final_remainder} micro-batches "
|
| 632 |
+
f"({dropped_total} {schedule_unit} total) because {batches_per_epoch} batches per epoch "
|
| 633 |
+
f"across {cfg.training.num_epochs} epochs yields {total_micro_batches} micro-batches, "
|
| 634 |
+
f"which is not divisible by grad_accum={grad_accum}."
|
| 635 |
+
)
|
| 636 |
+
if not args.allow_partial_final_window:
|
| 637 |
+
log.error(message)
|
| 638 |
+
log.error(
|
| 639 |
+
"Adjust num_samples, micro_batch_size, grad_accum_steps, or num_epochs "
|
| 640 |
+
"so total micro-batches is divisible by grad_accum, or rerun with "
|
| 641 |
+
"--allow_partial_final_window for a smoke test."
|
| 642 |
+
)
|
| 643 |
+
sys.exit(1)
|
| 644 |
+
log.warning(message)
|
| 645 |
+
log.warning("Proceeding because --allow_partial_final_window was set.")
|
| 646 |
+
|
| 647 |
+
warmup_steps = int(total_steps * cfg.training.warmup_ratio)
|
| 648 |
+
if has_remainder:
|
| 649 |
+
if use_ds:
|
| 650 |
+
log.info(
|
| 651 |
+
f" NOTE: {batches_per_epoch} batches are not divisible by grad_accum={grad_accum}; "
|
| 652 |
+
f"DeepSpeed carries {remainder_batches} leftover micro-batches across epoch boundaries"
|
| 653 |
+
)
|
| 654 |
+
if final_remainder and args.allow_partial_final_window:
|
| 655 |
+
log.info(
|
| 656 |
+
f" NOTE: only the final {final_remainder} micro-batches of "
|
| 657 |
+
"the last epoch are dropped because they never reach a full accumulation window"
|
| 658 |
+
)
|
| 659 |
+
else:
|
| 660 |
+
log.info(
|
| 661 |
+
f" NOTE: {batches_per_epoch} batches are not divisible by grad_accum={grad_accum}; "
|
| 662 |
+
f"the training loop will flush {remainder_batches} leftover micro-batches each epoch"
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
if not use_ds:
|
| 666 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 667 |
+
model.to(device)
|
| 668 |
+
|
| 669 |
+
optimizer = build_adamw_optimizer(list(model.parameters()), log, allow_fused=not use_ds)
|
| 670 |
+
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
|
| 671 |
+
resume_global_step = int(resume_state.get("global_step", 0) or 0) if args.resume_from_checkpoint else 0
|
| 672 |
+
saved_run_epochs = int(resume_state.get("num_epochs", cfg.training.num_epochs) or cfg.training.num_epochs)
|
| 673 |
+
extending_completed_run = (
|
| 674 |
+
args.resume_from_checkpoint
|
| 675 |
+
and saved_run_epochs < cfg.training.num_epochs
|
| 676 |
+
and start_epoch >= saved_run_epochs
|
| 677 |
+
)
|
| 678 |
+
scheduler_state_path = os.path.join(checkpoint_to_load, "scheduler.pt") if checkpoint_to_load else None
|
| 679 |
+
if (
|
| 680 |
+
extending_completed_run
|
| 681 |
+
and read_env_flag("QUINTUS_FRESH_SCHEDULER_ON_EXTEND", True)
|
| 682 |
+
):
|
| 683 |
+
remaining_steps = max(1, total_steps - resume_global_step)
|
| 684 |
+
extension_warmup_steps = int(remaining_steps * cfg.training.warmup_ratio)
|
| 685 |
+
scheduler = get_cosine_schedule_with_warmup(optimizer, extension_warmup_steps, remaining_steps)
|
| 686 |
+
log.info(
|
| 687 |
+
f" Scheduler: fresh extension schedule "
|
| 688 |
+
f"({remaining_steps:,} remaining steps, {extension_warmup_steps:,} warmup); "
|
| 689 |
+
f"checkpoint was saved for {saved_run_epochs} epochs"
|
| 690 |
+
)
|
| 691 |
+
elif args.resume_from_checkpoint and scheduler_state_path and os.path.exists(scheduler_state_path):
|
| 692 |
+
try:
|
| 693 |
+
scheduler.load_state_dict(torch.load(scheduler_state_path, map_location="cpu"))
|
| 694 |
+
for param_group, lr in zip(optimizer.param_groups, scheduler.get_last_lr()):
|
| 695 |
+
param_group["lr"] = lr
|
| 696 |
+
log.info(f" Scheduler: restored from {scheduler_state_path}")
|
| 697 |
+
except Exception as exc:
|
| 698 |
+
log.warning(f" Scheduler restore failed ({exc}); continuing with a fresh schedule")
|
| 699 |
+
log.info(f" Batches/ep: {batches_per_epoch:,}")
|
| 700 |
+
step_label = "Steps/ep"
|
| 701 |
+
step_note = ""
|
| 702 |
+
if has_remainder:
|
| 703 |
+
if use_ds:
|
| 704 |
+
step_label = "Steps/ep*"
|
| 705 |
+
step_note = " (floor; cross-epoch carry shifts exact epoch boundaries)"
|
| 706 |
+
else:
|
| 707 |
+
step_note = " (includes remainder flush)"
|
| 708 |
+
log.info(f" {step_label}: {steps_per_epoch:,}{step_note}")
|
| 709 |
+
log.info(f" Steps total: {total_steps:,} ({warmup_steps:,} warmup)")
|
| 710 |
+
log.info(
|
| 711 |
+
" Best ckpt: held-out validation loss"
|
| 712 |
+
if val_dataloader is not None
|
| 713 |
+
else " Best ckpt: training loss (validation disabled)"
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
if use_ds:
|
| 717 |
+
import deepspeed
|
| 718 |
+
|
| 719 |
+
model, optimizer, _, scheduler = deepspeed.initialize(
|
| 720 |
+
model=model,
|
| 721 |
+
optimizer=optimizer,
|
| 722 |
+
lr_scheduler=scheduler,
|
| 723 |
+
config=ds_runtime_config,
|
| 724 |
+
)
|
| 725 |
+
device = model.device
|
| 726 |
+
log.info("[DS] DeepSpeed ZeRO-2 initialized")
|
| 727 |
+
log.info(f"[DS] DeepSpeed will accumulate over {grad_accum} micro-batches internally")
|
| 728 |
+
else:
|
| 729 |
+
log.info(f" Device: {device}")
|
| 730 |
+
|
| 731 |
+
_log_gpu(log)
|
| 732 |
+
|
| 733 |
+
teacher_model = None
|
| 734 |
+
if args.phase == "online_kd":
|
| 735 |
+
teacher_source = cfg.paths.teacher_dir if os.path.exists(cfg.paths.teacher_dir) else cfg.model.teacher
|
| 736 |
+
if teacher_source != cfg.model.teacher:
|
| 737 |
+
log.info(f"Loading frozen teacher model from local directory '{teacher_source}' on device {device}...")
|
| 738 |
+
else:
|
| 739 |
+
log.info(f"Loading frozen teacher model '{teacher_source}' on device {device}...")
|
| 740 |
+
try:
|
| 741 |
+
teacher_model = AutoModelForCausalLM.from_pretrained(
|
| 742 |
+
teacher_source,
|
| 743 |
+
dtype=torch.bfloat16,
|
| 744 |
+
low_cpu_mem_usage=True,
|
| 745 |
+
trust_remote_code=cfg.model.allow_remote_code,
|
| 746 |
+
attn_implementation=attn_impl,
|
| 747 |
+
).to(device)
|
| 748 |
+
for p in teacher_model.parameters():
|
| 749 |
+
p.requires_grad = False
|
| 750 |
+
teacher_model.eval()
|
| 751 |
+
log.info(f"Teacher model '{teacher_source}' loaded and frozen.")
|
| 752 |
+
except Exception as exc:
|
| 753 |
+
log.error(f"Failed to load teacher model: {exc}")
|
| 754 |
+
sys.exit(1)
|
| 755 |
+
|
| 756 |
+
checkpoint_packing_metadata = packing_checkpoint_metadata(
|
| 757 |
+
enabled=sequence_packing_enabled,
|
| 758 |
+
pack_length=pack_length,
|
| 759 |
+
max_seq_len=int(cfg.data.max_seq_len),
|
| 760 |
+
)
|
| 761 |
+
|
| 762 |
+
os.makedirs(cfg.paths.distilled_dir, exist_ok=True)
|
| 763 |
+
loss_log: list[dict] = []
|
| 764 |
+
global_step = resume_global_step
|
| 765 |
+
micro_step_global = int(resume_state.get("micro_step_global", 0) or 0) if args.resume_from_checkpoint else 0
|
| 766 |
+
best_metric_name = "validation loss" if val_dataloader is not None else "training loss"
|
| 767 |
+
best_selection_loss = float("inf")
|
| 768 |
+
if args.resume_from_checkpoint and "best_selection_loss" in resume_state:
|
| 769 |
+
try:
|
| 770 |
+
best_selection_loss = float(resume_state["best_selection_loss"])
|
| 771 |
+
log.info(f" Best resume: restored prior best {best_metric_name}={best_selection_loss:.4f}")
|
| 772 |
+
except (TypeError, ValueError):
|
| 773 |
+
log.warning(" Best resume: prior best_selection_loss was unreadable; recomputing from this run")
|
| 774 |
+
best_checkpoint_tag = resume_state.get("best_checkpoint_tag")
|
| 775 |
+
best_ckpt_path = os.path.join(cfg.paths.distilled_dir, "best")
|
| 776 |
+
if not os.path.isdir(best_ckpt_path):
|
| 777 |
+
best_ckpt_path = None
|
| 778 |
+
if best_checkpoint_tag:
|
| 779 |
+
candidate_best_path = os.path.join(cfg.paths.distilled_dir, str(best_checkpoint_tag))
|
| 780 |
+
if os.path.isdir(candidate_best_path):
|
| 781 |
+
best_ckpt_path = candidate_best_path
|
| 782 |
+
log.info(f" Best resume: using {best_checkpoint_tag} as the current best checkpoint")
|
| 783 |
+
t_start = time.time()
|
| 784 |
+
|
| 785 |
+
alpha = cfg.training.alpha
|
| 786 |
+
temperature = cfg.training.temperature
|
| 787 |
+
log_every = max(1, min(50, total_steps // 20))
|
| 788 |
+
checkpoint_every_steps = max(0, int(os.environ.get("TRAIN_CHECKPOINT_EVERY_STEPS", "2000")))
|
| 789 |
+
if getattr(cfg.training, "disable_checkpointing", False):
|
| 790 |
+
checkpoint_every_steps = 0
|
| 791 |
+
|
| 792 |
+
running_loss = 0.0
|
| 793 |
+
running_ce = 0.0
|
| 794 |
+
running_kd = 0.0
|
| 795 |
+
running_count = 0
|
| 796 |
+
|
| 797 |
+
emit_log_spacing(log)
|
| 798 |
+
log.info("-" * 70)
|
| 799 |
+
log.info("Training Start")
|
| 800 |
+
if checkpoint_every_steps:
|
| 801 |
+
log.info(f" Mid-epoch checkpoint interval: every {checkpoint_every_steps:,} optimizer steps")
|
| 802 |
+
else:
|
| 803 |
+
log.info(" Mid-epoch checkpoints disabled")
|
| 804 |
+
log.info("-" * 70)
|
| 805 |
+
|
| 806 |
+
window_tokens = 0
|
| 807 |
+
window_t_start = time.time()
|
| 808 |
+
_gpu_loss_accum = torch.zeros(1, device=device)
|
| 809 |
+
_gpu_ce_accum = torch.zeros(1, device=device)
|
| 810 |
+
_gpu_kd_accum = torch.zeros(1, device=device)
|
| 811 |
+
_gpu_tokens_accum = torch.zeros(1, dtype=torch.long, device=device)
|
| 812 |
+
training_complete = False
|
| 813 |
+
|
| 814 |
+
for epoch in range(start_epoch, cfg.training.num_epochs):
|
| 815 |
+
if training_complete:
|
| 816 |
+
break
|
| 817 |
+
t_epoch = time.time()
|
| 818 |
+
epoch_loss = 0.0
|
| 819 |
+
epoch_ce = 0.0
|
| 820 |
+
epoch_kd = 0.0
|
| 821 |
+
epoch_steps = 0
|
| 822 |
+
epoch_tokens = 0
|
| 823 |
+
micro_in_epoch = 0
|
| 824 |
+
resume_batch_offset = 0
|
| 825 |
+
if args.resume_from_checkpoint and epoch == start_epoch:
|
| 826 |
+
resume_batch_offset = int(resume_state.get("next_batch_in_epoch", 0) or 0)
|
| 827 |
+
if resume_batch_offset:
|
| 828 |
+
log.info(f" Resume: skipping {resume_batch_offset:,} already-processed batches in epoch {epoch + 1}")
|
| 829 |
+
|
| 830 |
+
for batch_idx, batch in enumerate(dataloader):
|
| 831 |
+
if resume_batch_offset and batch_idx < resume_batch_offset:
|
| 832 |
+
continue
|
| 833 |
+
batch = move_batch_to_device(batch, device)
|
| 834 |
+
input_ids = batch["input_ids"]
|
| 835 |
+
attention_mask = batch["attention_mask"]
|
| 836 |
+
labels = batch["labels"]
|
| 837 |
+
loss_mask = batch["loss_mask"]
|
| 838 |
+
|
| 839 |
+
logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
|
| 840 |
+
|
| 841 |
+
if args.phase == "online_kd" and teacher_model is not None:
|
| 842 |
+
with torch.no_grad():
|
| 843 |
+
teacher_logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits
|
| 844 |
+
else:
|
| 845 |
+
teacher_logits = None
|
| 846 |
+
|
| 847 |
+
loss, ce, kd = compute_loss_for_phase(
|
| 848 |
+
args.phase,
|
| 849 |
+
logits,
|
| 850 |
+
labels,
|
| 851 |
+
loss_mask,
|
| 852 |
+
batch,
|
| 853 |
+
alpha,
|
| 854 |
+
temperature,
|
| 855 |
+
teacher_logits=teacher_logits,
|
| 856 |
+
online_kd_token_chunk_size=int(cfg.training.online_kd_token_chunk_size),
|
| 857 |
+
)
|
| 858 |
+
if not torch.isfinite(loss):
|
| 859 |
+
log.error(
|
| 860 |
+
f"Non-finite loss in phase={args.phase}: "
|
| 861 |
+
f"loss={loss.item()} ce={ce.item()} kd={kd.item()}"
|
| 862 |
+
)
|
| 863 |
+
if args.phase == "kd":
|
| 864 |
+
log.error("Action: regenerate teacher logits.")
|
| 865 |
+
else:
|
| 866 |
+
log.error("Action: check dataset / reduce LR.")
|
| 867 |
+
sys.exit(1)
|
| 868 |
+
|
| 869 |
+
micro_in_epoch += 1
|
| 870 |
+
micro_step_global += 1
|
| 871 |
+
|
| 872 |
+
_gpu_loss_accum += loss.detach()
|
| 873 |
+
_gpu_ce_accum += ce.detach()
|
| 874 |
+
_gpu_kd_accum += kd.detach()
|
| 875 |
+
_gpu_tokens_accum += attention_mask.sum()
|
| 876 |
+
|
| 877 |
+
if use_ds:
|
| 878 |
+
model.backward(loss)
|
| 879 |
+
model.step()
|
| 880 |
+
else:
|
| 881 |
+
scaled = loss / grad_accum
|
| 882 |
+
scaled.backward()
|
| 883 |
+
|
| 884 |
+
if micro_in_epoch % grad_accum == 0:
|
| 885 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 886 |
+
optimizer.step()
|
| 887 |
+
scheduler.step()
|
| 888 |
+
optimizer.zero_grad(set_to_none=True)
|
| 889 |
+
|
| 890 |
+
is_optim_step = (
|
| 891 |
+
(micro_step_global % grad_accum == 0) if use_ds else (micro_in_epoch % grad_accum == 0)
|
| 892 |
+
)
|
| 893 |
+
|
| 894 |
+
if is_optim_step:
|
| 895 |
+
global_step += 1
|
| 896 |
+
epoch_steps += 1
|
| 897 |
+
running_count += 1
|
| 898 |
+
|
| 899 |
+
step_loss = _gpu_loss_accum.item() / grad_accum
|
| 900 |
+
step_ce = _gpu_ce_accum.item() / grad_accum
|
| 901 |
+
step_kd = _gpu_kd_accum.item() / grad_accum
|
| 902 |
+
step_tokens = _gpu_tokens_accum.item()
|
| 903 |
+
_gpu_loss_accum.zero_()
|
| 904 |
+
_gpu_ce_accum.zero_()
|
| 905 |
+
_gpu_kd_accum.zero_()
|
| 906 |
+
_gpu_tokens_accum.zero_()
|
| 907 |
+
|
| 908 |
+
epoch_tokens += step_tokens
|
| 909 |
+
window_tokens += step_tokens
|
| 910 |
+
epoch_loss += step_loss
|
| 911 |
+
epoch_ce += step_ce
|
| 912 |
+
epoch_kd += step_kd
|
| 913 |
+
running_loss += step_loss
|
| 914 |
+
running_ce += step_ce
|
| 915 |
+
running_kd += step_kd
|
| 916 |
+
|
| 917 |
+
if global_step % log_every == 0 or global_step == total_steps:
|
| 918 |
+
avg_loss = running_loss / max(running_count, 1)
|
| 919 |
+
avg_ce = running_ce / max(running_count, 1)
|
| 920 |
+
avg_kd = running_kd / max(running_count, 1)
|
| 921 |
+
try:
|
| 922 |
+
lr = scheduler.get_last_lr()[0]
|
| 923 |
+
except Exception:
|
| 924 |
+
lr = cfg.training.learning_rate
|
| 925 |
+
window_elapsed = max(time.time() - window_t_start, 0.1)
|
| 926 |
+
rolling_tok_s = window_tokens / window_elapsed
|
| 927 |
+
rolling_eta_s = (window_elapsed / max(running_count, 1)) * (total_steps - global_step) / log_every * running_count
|
| 928 |
+
cum_tok_s = epoch_tokens / max(time.time() - t_epoch, 1)
|
| 929 |
+
log.info(
|
| 930 |
+
f" E{epoch + 1}/{cfg.training.num_epochs} "
|
| 931 |
+
f"S{global_step:>4}/{total_steps} | "
|
| 932 |
+
f"loss={avg_loss:.4f} ce={avg_ce:.4f} kd={avg_kd:.4f} | "
|
| 933 |
+
f"lr={lr:.2e} | {rolling_tok_s:,.0f} tok/s (avg {cum_tok_s:,.0f}) | ETA {rolling_eta_s / 60:.1f}m"
|
| 934 |
+
)
|
| 935 |
+
loss_log.append(
|
| 936 |
+
{
|
| 937 |
+
"step": global_step,
|
| 938 |
+
"epoch": epoch + 1,
|
| 939 |
+
"loss_total": round(avg_loss, 5),
|
| 940 |
+
"loss_ce": round(avg_ce, 5),
|
| 941 |
+
"loss_kd": round(avg_kd, 5),
|
| 942 |
+
"lr": lr,
|
| 943 |
+
"tok_per_sec": round(rolling_tok_s, 0),
|
| 944 |
+
"tok_per_sec_cumulative": round(cum_tok_s, 0),
|
| 945 |
+
}
|
| 946 |
+
)
|
| 947 |
+
window_tokens = 0
|
| 948 |
+
window_t_start = time.time()
|
| 949 |
+
|
| 950 |
+
running_loss = 0.0
|
| 951 |
+
running_ce = 0.0
|
| 952 |
+
running_kd = 0.0
|
| 953 |
+
running_count = 0
|
| 954 |
+
|
| 955 |
+
if checkpoint_every_steps and global_step % checkpoint_every_steps == 0 and is_main:
|
| 956 |
+
log.info(f" Saving mid-epoch checkpoint at step {global_step}...")
|
| 957 |
+
step_tag = f"step_{global_step}"
|
| 958 |
+
step_ckpt_path = save_checkpoint(
|
| 959 |
+
model,
|
| 960 |
+
tokenizer,
|
| 961 |
+
cfg.paths.distilled_dir,
|
| 962 |
+
step_tag,
|
| 963 |
+
log,
|
| 964 |
+
scheduler=scheduler,
|
| 965 |
+
trainer_state={
|
| 966 |
+
**checkpoint_packing_metadata,
|
| 967 |
+
"checkpoint_type": "step",
|
| 968 |
+
"phase": args.phase,
|
| 969 |
+
"epoch_index": epoch,
|
| 970 |
+
"start_epoch": epoch,
|
| 971 |
+
"global_step": global_step,
|
| 972 |
+
"micro_step_global": micro_step_global,
|
| 973 |
+
"next_batch_in_epoch": micro_in_epoch,
|
| 974 |
+
"num_epochs": cfg.training.num_epochs,
|
| 975 |
+
"micro_batch_size": cfg.training.micro_batch_size,
|
| 976 |
+
"grad_accum_steps": grad_accum,
|
| 977 |
+
},
|
| 978 |
+
)
|
| 979 |
+
maybe_upload_checkpoint(step_ckpt_path, step_tag, log)
|
| 980 |
+
|
| 981 |
+
if args.max_steps > 0 and global_step >= args.max_steps:
|
| 982 |
+
log.info(f"Reached max_steps={args.max_steps}. Stopping training.")
|
| 983 |
+
training_complete = True
|
| 984 |
+
break
|
| 985 |
+
|
| 986 |
+
if training_complete:
|
| 987 |
+
break
|
| 988 |
+
|
| 989 |
+
if not use_ds:
|
| 990 |
+
remainder = micro_in_epoch % grad_accum
|
| 991 |
+
if remainder != 0:
|
| 992 |
+
flush_scale = grad_accum / remainder
|
| 993 |
+
for parameter in model.parameters():
|
| 994 |
+
if parameter.grad is not None:
|
| 995 |
+
parameter.grad.mul_(flush_scale)
|
| 996 |
+
|
| 997 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 998 |
+
optimizer.step()
|
| 999 |
+
scheduler.step()
|
| 1000 |
+
optimizer.zero_grad(set_to_none=True)
|
| 1001 |
+
global_step += 1
|
| 1002 |
+
epoch_steps += 1
|
| 1003 |
+
|
| 1004 |
+
step_loss = _gpu_loss_accum.item() / remainder
|
| 1005 |
+
step_ce = _gpu_ce_accum.item() / remainder
|
| 1006 |
+
step_kd = _gpu_kd_accum.item() / remainder
|
| 1007 |
+
step_tokens = _gpu_tokens_accum.item()
|
| 1008 |
+
_gpu_loss_accum.zero_()
|
| 1009 |
+
_gpu_ce_accum.zero_()
|
| 1010 |
+
_gpu_kd_accum.zero_()
|
| 1011 |
+
_gpu_tokens_accum.zero_()
|
| 1012 |
+
|
| 1013 |
+
epoch_tokens += step_tokens
|
| 1014 |
+
window_tokens += step_tokens
|
| 1015 |
+
running_loss += step_loss
|
| 1016 |
+
running_ce += step_ce
|
| 1017 |
+
running_kd += step_kd
|
| 1018 |
+
running_count += 1
|
| 1019 |
+
|
| 1020 |
+
avg_loss = running_loss / max(running_count, 1)
|
| 1021 |
+
avg_ce = running_ce / max(running_count, 1)
|
| 1022 |
+
avg_kd = running_kd / max(running_count, 1)
|
| 1023 |
+
epoch_loss += step_loss
|
| 1024 |
+
epoch_ce += step_ce
|
| 1025 |
+
epoch_kd += step_kd
|
| 1026 |
+
running_loss = 0.0
|
| 1027 |
+
running_ce = 0.0
|
| 1028 |
+
running_kd = 0.0
|
| 1029 |
+
running_count = 0
|
| 1030 |
+
|
| 1031 |
+
elapsed = time.time() - t_start
|
| 1032 |
+
try:
|
| 1033 |
+
lr = scheduler.get_last_lr()[0]
|
| 1034 |
+
except Exception:
|
| 1035 |
+
lr = cfg.training.learning_rate
|
| 1036 |
+
tok_s = epoch_tokens / max(time.time() - t_epoch, 1)
|
| 1037 |
+
eta_s = (elapsed / max(global_step, 1)) * (total_steps - global_step)
|
| 1038 |
+
log.info(
|
| 1039 |
+
f" E{epoch + 1}/{cfg.training.num_epochs} "
|
| 1040 |
+
f"S{global_step:>4}/{total_steps} | "
|
| 1041 |
+
f"loss={avg_loss:.4f} ce={avg_ce:.4f} kd={avg_kd:.4f} | "
|
| 1042 |
+
f"lr={lr:.2e} | {tok_s:,.0f} tok/s | ETA {eta_s / 60:.1f}m [flush]"
|
| 1043 |
+
)
|
| 1044 |
+
loss_log.append(
|
| 1045 |
+
{
|
| 1046 |
+
"step": global_step,
|
| 1047 |
+
"epoch": epoch + 1,
|
| 1048 |
+
"loss_total": round(avg_loss, 5),
|
| 1049 |
+
"loss_ce": round(avg_ce, 5),
|
| 1050 |
+
"loss_kd": round(avg_kd, 5),
|
| 1051 |
+
"lr": lr,
|
| 1052 |
+
"tok_per_sec": round(tok_s, 0),
|
| 1053 |
+
}
|
| 1054 |
+
)
|
| 1055 |
+
window_tokens = 0
|
| 1056 |
+
window_t_start = time.time()
|
| 1057 |
+
log.info(f" Epoch {epoch + 1}: flushed {remainder} leftover micro-batches")
|
| 1058 |
+
else:
|
| 1059 |
+
optimizer.zero_grad(set_to_none=True)
|
| 1060 |
+
elif (micro_step_global % grad_accum) != 0 and epoch < cfg.training.num_epochs - 1:
|
| 1061 |
+
carry = micro_step_global % grad_accum
|
| 1062 |
+
log.info(f" Epoch {epoch + 1}: carrying {carry} micro-batches into the next epoch")
|
| 1063 |
+
|
| 1064 |
+
avg_epoch_loss = epoch_loss / max(epoch_steps, 1)
|
| 1065 |
+
avg_epoch_ce = epoch_ce / max(epoch_steps, 1)
|
| 1066 |
+
avg_epoch_kd = epoch_kd / max(epoch_steps, 1)
|
| 1067 |
+
epoch_elapsed = time.time() - t_epoch
|
| 1068 |
+
log.info(
|
| 1069 |
+
f" Epoch {epoch + 1} done | "
|
| 1070 |
+
f"avg_loss={avg_epoch_loss:.4f} ce={avg_epoch_ce:.4f} kd={avg_epoch_kd:.4f} | "
|
| 1071 |
+
f"{epoch_tokens:,} tok | {epoch_elapsed / 60:.1f}m"
|
| 1072 |
+
)
|
| 1073 |
+
_log_gpu(log)
|
| 1074 |
+
|
| 1075 |
+
val_metrics = None
|
| 1076 |
+
if val_dataloader is not None:
|
| 1077 |
+
val_start = time.time()
|
| 1078 |
+
val_limit = min(20, len(val_dataloader)) if args.max_steps > 0 else -1
|
| 1079 |
+
if val_limit > 0:
|
| 1080 |
+
log.info(f" Validation start | capping at {val_limit} batches for dry run (total {len(val_dataloader)} batches)")
|
| 1081 |
+
else:
|
| 1082 |
+
log.info(f" Validation start | {len(val_dataloader):,} batches")
|
| 1083 |
+
val_metrics = evaluate_validation_loss(
|
| 1084 |
+
phase=args.phase,
|
| 1085 |
+
model=model,
|
| 1086 |
+
dataloader=val_dataloader,
|
| 1087 |
+
device=device,
|
| 1088 |
+
alpha=alpha,
|
| 1089 |
+
temperature=temperature,
|
| 1090 |
+
online_kd_token_chunk_size=int(cfg.training.online_kd_token_chunk_size),
|
| 1091 |
+
teacher_model=teacher_model,
|
| 1092 |
+
max_batches=val_limit,
|
| 1093 |
+
)
|
| 1094 |
+
log.info(
|
| 1095 |
+
f" Validation | loss={val_metrics['loss']:.4f} ce={val_metrics['ce']:.4f} "
|
| 1096 |
+
f"kd={val_metrics['kd']:.4f} | {int(val_metrics['batches'])} batches | "
|
| 1097 |
+
f"{(time.time() - val_start) / 60:.1f}m"
|
| 1098 |
+
)
|
| 1099 |
+
|
| 1100 |
+
if is_main:
|
| 1101 |
+
selection_loss = val_metrics["loss"] if val_metrics is not None else avg_epoch_loss
|
| 1102 |
+
is_new_best = selection_loss < best_selection_loss
|
| 1103 |
+
epoch_tag = f"epoch_{epoch + 1}"
|
| 1104 |
+
if is_new_best:
|
| 1105 |
+
best_selection_loss = selection_loss
|
| 1106 |
+
best_checkpoint_tag = epoch_tag
|
| 1107 |
+
log.info(f" Best update: {best_metric_name}={best_selection_loss:.4f} from {epoch_tag}")
|
| 1108 |
+
else:
|
| 1109 |
+
log.info(
|
| 1110 |
+
f" Best unchanged: current {best_metric_name}={selection_loss:.4f}; "
|
| 1111 |
+
f"best={best_selection_loss:.4f} from {best_checkpoint_tag}"
|
| 1112 |
+
)
|
| 1113 |
+
epoch_state = {
|
| 1114 |
+
**checkpoint_packing_metadata,
|
| 1115 |
+
"checkpoint_type": "epoch",
|
| 1116 |
+
"phase": args.phase,
|
| 1117 |
+
"epoch_index": epoch,
|
| 1118 |
+
"start_epoch": epoch + 1,
|
| 1119 |
+
"global_step": global_step,
|
| 1120 |
+
"micro_step_global": micro_step_global,
|
| 1121 |
+
"next_batch_in_epoch": 0,
|
| 1122 |
+
"num_epochs": cfg.training.num_epochs,
|
| 1123 |
+
"micro_batch_size": cfg.training.micro_batch_size,
|
| 1124 |
+
"grad_accum_steps": grad_accum,
|
| 1125 |
+
"selection_loss": float(selection_loss),
|
| 1126 |
+
"best_selection_loss": float(best_selection_loss),
|
| 1127 |
+
"best_metric_name": best_metric_name,
|
| 1128 |
+
"best_checkpoint_tag": best_checkpoint_tag,
|
| 1129 |
+
}
|
| 1130 |
+
if read_env_flag("QUINTUS_SAVE_EPOCH_CHECKPOINTS", True) and not getattr(cfg.training, "disable_checkpointing", False):
|
| 1131 |
+
epoch_ckpt_path = save_checkpoint(
|
| 1132 |
+
model,
|
| 1133 |
+
tokenizer,
|
| 1134 |
+
cfg.paths.distilled_dir,
|
| 1135 |
+
epoch_tag,
|
| 1136 |
+
log,
|
| 1137 |
+
scheduler=scheduler,
|
| 1138 |
+
trainer_state=epoch_state,
|
| 1139 |
+
)
|
| 1140 |
+
maybe_upload_checkpoint(epoch_ckpt_path, epoch_tag, log)
|
| 1141 |
+
else:
|
| 1142 |
+
log.info(f" Skipping intermediate {epoch_tag} save")
|
| 1143 |
+
if is_new_best and not getattr(cfg.training, "disable_checkpointing", False):
|
| 1144 |
+
best_ckpt_path = save_checkpoint(
|
| 1145 |
+
model,
|
| 1146 |
+
tokenizer,
|
| 1147 |
+
cfg.paths.distilled_dir,
|
| 1148 |
+
"best",
|
| 1149 |
+
log,
|
| 1150 |
+
scheduler=scheduler,
|
| 1151 |
+
trainer_state=dict(epoch_state, checkpoint_type="best"),
|
| 1152 |
+
)
|
| 1153 |
+
|
| 1154 |
+
if use_ds and final_remainder:
|
| 1155 |
+
model.zero_grad()
|
| 1156 |
+
running_loss = 0.0
|
| 1157 |
+
running_ce = 0.0
|
| 1158 |
+
running_kd = 0.0
|
| 1159 |
+
running_count = 0
|
| 1160 |
+
log.warning(f" Training end: dropped final {final_remainder} leftover micro-batches")
|
| 1161 |
+
|
| 1162 |
+
if is_main:
|
| 1163 |
+
if best_ckpt_path and os.path.isdir(best_ckpt_path) and not getattr(cfg.training, "disable_checkpointing", False):
|
| 1164 |
+
maybe_upload_checkpoint(best_ckpt_path, "best", log)
|
| 1165 |
+
last_ckpt_path = save_checkpoint(
|
| 1166 |
+
model,
|
| 1167 |
+
tokenizer,
|
| 1168 |
+
cfg.paths.distilled_dir,
|
| 1169 |
+
"last",
|
| 1170 |
+
log,
|
| 1171 |
+
scheduler=scheduler,
|
| 1172 |
+
trainer_state={
|
| 1173 |
+
**checkpoint_packing_metadata,
|
| 1174 |
+
"checkpoint_type": "last",
|
| 1175 |
+
"phase": args.phase,
|
| 1176 |
+
"start_epoch": cfg.training.num_epochs,
|
| 1177 |
+
"global_step": global_step,
|
| 1178 |
+
"micro_step_global": micro_step_global,
|
| 1179 |
+
"next_batch_in_epoch": 0,
|
| 1180 |
+
"num_epochs": cfg.training.num_epochs,
|
| 1181 |
+
"micro_batch_size": cfg.training.micro_batch_size,
|
| 1182 |
+
"grad_accum_steps": grad_accum,
|
| 1183 |
+
"best_selection_loss": float(best_selection_loss) if math.isfinite(best_selection_loss) else None,
|
| 1184 |
+
"best_metric_name": best_metric_name,
|
| 1185 |
+
"best_checkpoint_tag": best_checkpoint_tag,
|
| 1186 |
+
},
|
| 1187 |
+
)
|
| 1188 |
+
maybe_upload_checkpoint(last_ckpt_path, "last", log)
|
| 1189 |
+
|
| 1190 |
+
csv_path = os.path.join(cfg.paths.distilled_dir, cfg.paths.loss_csv)
|
| 1191 |
+
if loss_log and is_main:
|
| 1192 |
+
with open(csv_path, "w", newline="", encoding="utf-8") as f:
|
| 1193 |
+
writer = csv.DictWriter(f, fieldnames=loss_log[0].keys())
|
| 1194 |
+
writer.writeheader()
|
| 1195 |
+
writer.writerows(loss_log)
|
| 1196 |
+
log.info(f"Loss CSV -> {csv_path}")
|
| 1197 |
+
|
| 1198 |
+
total_elapsed = time.time() - t_start
|
| 1199 |
+
emit_log_spacing(log)
|
| 1200 |
+
log.info("=" * 70)
|
| 1201 |
+
log.info("Training complete")
|
| 1202 |
+
log.info(f" Wall time: {total_elapsed / 3600:.2f}h ({total_elapsed / 60:.1f}m)")
|
| 1203 |
+
log.info(f" Optim steps: {global_step}")
|
| 1204 |
+
log.info(f" Micro steps: {micro_step_global}")
|
| 1205 |
+
log.info(f" Best {best_metric_name}: {best_selection_loss:.4f}")
|
| 1206 |
+
log.info(f" Best ckpt: {best_ckpt_path}")
|
| 1207 |
+
log.info(f" Output dir: {cfg.paths.distilled_dir}/")
|
| 1208 |
+
log.info("=" * 70)
|
| 1209 |
+
|
| 1210 |
+
|
| 1211 |
+
if __name__ == "__main__":
|
| 1212 |
+
try:
|
| 1213 |
+
main()
|
| 1214 |
+
except Exception:
|
| 1215 |
+
try:
|
| 1216 |
+
setup_logger("TRAIN").exception("Uncaught training failure")
|
| 1217 |
+
except Exception:
|
| 1218 |
+
pass
|
| 1219 |
+
raise
|
src/training_data.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
|
| 10 |
+
from configs import cfg
|
| 11 |
+
|
| 12 |
+
PAD_MULTIPLE = 128
|
| 13 |
+
|
| 14 |
+
def torch_load_cpu(path: str) -> dict:
|
| 15 |
+
try:
|
| 16 |
+
return torch.load(path, map_location="cpu", weights_only=True)
|
| 17 |
+
except TypeError:
|
| 18 |
+
return torch.load(path, map_location="cpu")
|
| 19 |
+
|
| 20 |
+
def extract_shard_id_range(shard_payload: dict, shard_path: str) -> tuple[int, int]:
|
| 21 |
+
try:
|
| 22 |
+
ids_payload = shard_payload["ids"]
|
| 23 |
+
except KeyError as exc:
|
| 24 |
+
raise KeyError(
|
| 25 |
+
f"Teacher shard {shard_path} is missing 'ids'. Regenerate the teacher-logit shards."
|
| 26 |
+
) from exc
|
| 27 |
+
|
| 28 |
+
if torch.is_tensor(ids_payload):
|
| 29 |
+
if ids_payload.numel() == 0:
|
| 30 |
+
raise ValueError(
|
| 31 |
+
f"Teacher shard {shard_path} has an empty ids tensor. Regenerate the teacher-logit shards."
|
| 32 |
+
)
|
| 33 |
+
return int(ids_payload.min().item()), int(ids_payload.max().item())
|
| 34 |
+
|
| 35 |
+
if not isinstance(ids_payload, list) or not ids_payload:
|
| 36 |
+
raise ValueError(
|
| 37 |
+
f"Teacher shard {shard_path} has an incompatible ids payload. "
|
| 38 |
+
"Regenerate the teacher-logit shards."
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
min_id: int | None = None
|
| 42 |
+
max_id: int | None = None
|
| 43 |
+
for sample_idx, ids_tensor in enumerate(ids_payload):
|
| 44 |
+
if not torch.is_tensor(ids_tensor):
|
| 45 |
+
raise ValueError(
|
| 46 |
+
f"Teacher shard {shard_path} sample #{sample_idx} has a non-tensor ids payload. "
|
| 47 |
+
"Regenerate the teacher-logit shards."
|
| 48 |
+
)
|
| 49 |
+
if ids_tensor.numel() == 0:
|
| 50 |
+
continue
|
| 51 |
+
sample_min = int(ids_tensor.min().item())
|
| 52 |
+
sample_max = int(ids_tensor.max().item())
|
| 53 |
+
min_id = sample_min if min_id is None else min(min_id, sample_min)
|
| 54 |
+
max_id = sample_max if max_id is None else max(max_id, sample_max)
|
| 55 |
+
|
| 56 |
+
if min_id is None or max_id is None:
|
| 57 |
+
raise ValueError(
|
| 58 |
+
f"Teacher shard {shard_path} only contains empty ids tensors. "
|
| 59 |
+
"Regenerate the teacher-logit shards."
|
| 60 |
+
)
|
| 61 |
+
return min_id, max_id
|
| 62 |
+
|
| 63 |
+
class DistillationDataset(Dataset):
|
| 64 |
+
def __init__(self, data_path: str, logits_dir: str, max_seq_len: int, num_samples: int = -1, phase: str = "kd"):
|
| 65 |
+
self.phase = phase
|
| 66 |
+
self.data_path = data_path
|
| 67 |
+
self.logits_dir = logits_dir
|
| 68 |
+
self.max_seq_len = max_seq_len
|
| 69 |
+
self.samples_per_shard = self._resolve_samples_per_shard()
|
| 70 |
+
self.sample_offsets: list[int] = []
|
| 71 |
+
self.sample_lengths: list[int] = []
|
| 72 |
+
self.sample_target_counts: list[int] = []
|
| 73 |
+
self._data_handle = None
|
| 74 |
+
self._cached_shard_idx: int | None = None
|
| 75 |
+
self._cached_shard_path: str | None = None
|
| 76 |
+
self._cached_shard_payload: dict | None = None
|
| 77 |
+
|
| 78 |
+
with open(data_path, "r", encoding="utf-8") as f:
|
| 79 |
+
while True:
|
| 80 |
+
if 0 < num_samples <= len(self.sample_offsets):
|
| 81 |
+
break
|
| 82 |
+
offset = f.tell()
|
| 83 |
+
line = f.readline()
|
| 84 |
+
if not line:
|
| 85 |
+
break
|
| 86 |
+
i = len(self.sample_offsets)
|
| 87 |
+
raw_sample = json.loads(line)
|
| 88 |
+
input_ids_list, loss_mask_list = self._coerce_tokenized_row(raw_sample, i)
|
| 89 |
+
self.sample_offsets.append(offset)
|
| 90 |
+
self.sample_lengths.append(len(input_ids_list))
|
| 91 |
+
self.sample_target_counts.append(sum(loss_mask_list))
|
| 92 |
+
|
| 93 |
+
def __len__(self) -> int:
|
| 94 |
+
return len(self.sample_offsets)
|
| 95 |
+
|
| 96 |
+
def __getstate__(self) -> dict:
|
| 97 |
+
state = self.__dict__.copy()
|
| 98 |
+
state["_data_handle"] = None
|
| 99 |
+
state["_cached_shard_idx"] = None
|
| 100 |
+
state["_cached_shard_path"] = None
|
| 101 |
+
state["_cached_shard_payload"] = None
|
| 102 |
+
return state
|
| 103 |
+
|
| 104 |
+
def __del__(self) -> None:
|
| 105 |
+
data_handle = getattr(self, "_data_handle", None)
|
| 106 |
+
if data_handle is not None:
|
| 107 |
+
try:
|
| 108 |
+
data_handle.close()
|
| 109 |
+
except Exception:
|
| 110 |
+
pass
|
| 111 |
+
|
| 112 |
+
def _resolve_samples_per_shard(self) -> int:
|
| 113 |
+
prov_path = os.path.join(self.logits_dir, "_provenance.json")
|
| 114 |
+
if not os.path.exists(prov_path):
|
| 115 |
+
return 1
|
| 116 |
+
|
| 117 |
+
try:
|
| 118 |
+
with open(prov_path, "r", encoding="utf-8") as f:
|
| 119 |
+
prov = json.load(f)
|
| 120 |
+
except (OSError, json.JSONDecodeError):
|
| 121 |
+
return 1
|
| 122 |
+
|
| 123 |
+
shard_schema = prov.get("shard_schema", {})
|
| 124 |
+
if shard_schema.get("layout") != "chunked_sample_lists":
|
| 125 |
+
return 1
|
| 126 |
+
|
| 127 |
+
raw_value = prov.get("samples_per_shard", 1)
|
| 128 |
+
try:
|
| 129 |
+
value = int(raw_value)
|
| 130 |
+
except (TypeError, ValueError):
|
| 131 |
+
return 1
|
| 132 |
+
return max(value, 1)
|
| 133 |
+
|
| 134 |
+
def _coerce_tokenized_row(self, raw_sample: dict, idx: int) -> tuple[list[int], list[int]]:
|
| 135 |
+
try:
|
| 136 |
+
input_ids = raw_sample["input_ids"][: self.max_seq_len]
|
| 137 |
+
except KeyError as exc:
|
| 138 |
+
raise KeyError(
|
| 139 |
+
f"Tokenized sample #{idx} is missing 'input_ids'. "
|
| 140 |
+
"Re-run download.py to regenerate the tokenized dataset."
|
| 141 |
+
) from exc
|
| 142 |
+
try:
|
| 143 |
+
loss_mask = raw_sample["loss_mask"][: len(input_ids)]
|
| 144 |
+
except KeyError as exc:
|
| 145 |
+
raise KeyError(
|
| 146 |
+
"Tokenized sample is missing 'loss_mask'. Re-run download.py to regenerate "
|
| 147 |
+
"assistant-only training targets before distilling."
|
| 148 |
+
) from exc
|
| 149 |
+
|
| 150 |
+
if not isinstance(input_ids, list) or len(input_ids) == 0:
|
| 151 |
+
raise ValueError(
|
| 152 |
+
f"Tokenized sample #{idx} has incompatible input_ids payload. "
|
| 153 |
+
"Re-run download.py to regenerate."
|
| 154 |
+
)
|
| 155 |
+
if not isinstance(loss_mask, list) or len(loss_mask) != len(input_ids):
|
| 156 |
+
raise ValueError(
|
| 157 |
+
f"Tokenized sample #{idx} has incompatible loss_mask length {len(loss_mask)}. "
|
| 158 |
+
"Re-run download.py to regenerate assistant-only targets."
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
normalized_mask = [int(value) for value in loss_mask]
|
| 162 |
+
if any(value not in (0, 1) for value in normalized_mask):
|
| 163 |
+
raise ValueError(
|
| 164 |
+
f"Tokenized sample #{idx} has non-binary loss_mask values. "
|
| 165 |
+
"Re-run download.py to regenerate assistant-only targets."
|
| 166 |
+
)
|
| 167 |
+
if sum(normalized_mask) == 0:
|
| 168 |
+
raise ValueError(
|
| 169 |
+
f"Tokenized sample #{idx} has no assistant target tokens. "
|
| 170 |
+
"Re-run download.py to filter invalid conversations."
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
return [int(token_id) for token_id in input_ids], normalized_mask
|
| 174 |
+
|
| 175 |
+
def _data_file(self):
|
| 176 |
+
if self._data_handle is None:
|
| 177 |
+
self._data_handle = open(self.data_path, "r", encoding="utf-8")
|
| 178 |
+
return self._data_handle
|
| 179 |
+
|
| 180 |
+
def _load_raw_sample(self, idx: int) -> dict:
|
| 181 |
+
data_file = self._data_file()
|
| 182 |
+
data_file.seek(self.sample_offsets[idx])
|
| 183 |
+
line = data_file.readline()
|
| 184 |
+
if not line:
|
| 185 |
+
raise IndexError(f"Tokenized sample #{idx} could not be read from {self.data_path}.")
|
| 186 |
+
return json.loads(line)
|
| 187 |
+
|
| 188 |
+
def _load_shard_payload(self, shard_idx: int) -> tuple[str, dict]:
|
| 189 |
+
if self._cached_shard_idx == shard_idx and self._cached_shard_payload is not None and self._cached_shard_path is not None:
|
| 190 |
+
return self._cached_shard_path, self._cached_shard_payload
|
| 191 |
+
|
| 192 |
+
shard_path = os.path.join(self.logits_dir, f"shard_{shard_idx:06d}.pt")
|
| 193 |
+
if not os.path.exists(shard_path):
|
| 194 |
+
raise FileNotFoundError(
|
| 195 |
+
f"Missing teacher logit shard: {shard_path}. "
|
| 196 |
+
"Regenerate the teacher-logit shards."
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
payload = torch_load_cpu(shard_path)
|
| 200 |
+
self._cached_shard_idx = shard_idx
|
| 201 |
+
self._cached_shard_path = shard_path
|
| 202 |
+
self._cached_shard_payload = payload
|
| 203 |
+
return shard_path, payload
|
| 204 |
+
|
| 205 |
+
def _load_teacher_tensors(self, idx: int, seq_len: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 206 |
+
if self.samples_per_shard <= 1:
|
| 207 |
+
shard_path, shard = self._load_shard_payload(idx)
|
| 208 |
+
try:
|
| 209 |
+
teacher_logprobs = shard["logprobs"][:seq_len]
|
| 210 |
+
teacher_ids = shard["ids"][:seq_len]
|
| 211 |
+
teacher_other_logprob = shard["other_logprob"][:seq_len]
|
| 212 |
+
except KeyError as exc:
|
| 213 |
+
missing = exc.args[0]
|
| 214 |
+
raise KeyError(
|
| 215 |
+
f"Shard {shard_path} is missing {missing!r}. "
|
| 216 |
+
"Regenerate the current teacher-logit shards."
|
| 217 |
+
) from exc
|
| 218 |
+
return teacher_logprobs, teacher_ids, teacher_other_logprob
|
| 219 |
+
|
| 220 |
+
shard_idx = idx // self.samples_per_shard
|
| 221 |
+
sample_offset = idx % self.samples_per_shard
|
| 222 |
+
shard_path, shard = self._load_shard_payload(shard_idx)
|
| 223 |
+
try:
|
| 224 |
+
count = int(shard["count"])
|
| 225 |
+
start_idx = int(shard["start_idx"])
|
| 226 |
+
logprobs_list = shard["logprobs"]
|
| 227 |
+
ids_list = shard["ids"]
|
| 228 |
+
other_list = shard["other_logprob"]
|
| 229 |
+
except KeyError as exc:
|
| 230 |
+
missing = exc.args[0]
|
| 231 |
+
raise KeyError(
|
| 232 |
+
f"Grouped shard {shard_path} is missing {missing!r}. "
|
| 233 |
+
"Regenerate the current teacher-logit shards."
|
| 234 |
+
) from exc
|
| 235 |
+
|
| 236 |
+
expected_start_idx = shard_idx * self.samples_per_shard
|
| 237 |
+
if start_idx != expected_start_idx:
|
| 238 |
+
raise ValueError(
|
| 239 |
+
f"Grouped shard {shard_path} starts at sample {start_idx}, "
|
| 240 |
+
f"expected {expected_start_idx}. Regenerate the teacher-logit shards."
|
| 241 |
+
)
|
| 242 |
+
if not (len(logprobs_list) == len(ids_list) == len(other_list) == count):
|
| 243 |
+
raise ValueError(
|
| 244 |
+
f"Grouped shard {shard_path} has inconsistent sample counts. "
|
| 245 |
+
"Regenerate the current teacher-logit shards."
|
| 246 |
+
)
|
| 247 |
+
if sample_offset >= count:
|
| 248 |
+
raise FileNotFoundError(
|
| 249 |
+
f"Grouped shard {shard_path} does not contain sample #{idx} "
|
| 250 |
+
f"(start_idx={start_idx}, count={count}). Regenerate the teacher-logit shards."
|
| 251 |
+
)
|
| 252 |
+
try:
|
| 253 |
+
teacher_logprobs = logprobs_list[sample_offset][:seq_len]
|
| 254 |
+
teacher_ids = ids_list[sample_offset][:seq_len]
|
| 255 |
+
teacher_other_logprob = other_list[sample_offset][:seq_len]
|
| 256 |
+
except (IndexError, TypeError) as exc:
|
| 257 |
+
raise ValueError(
|
| 258 |
+
f"Grouped shard {shard_path} has an incompatible payload layout. "
|
| 259 |
+
"Regenerate the current teacher-logit shards."
|
| 260 |
+
) from exc
|
| 261 |
+
return teacher_logprobs, teacher_ids, teacher_other_logprob
|
| 262 |
+
|
| 263 |
+
def __getitem__(self, idx: int) -> dict:
|
| 264 |
+
raw_sample = self._load_raw_sample(idx)
|
| 265 |
+
input_ids_list, loss_mask_list = self._coerce_tokenized_row(raw_sample, idx)
|
| 266 |
+
input_ids = torch.tensor(input_ids_list, dtype=torch.long)
|
| 267 |
+
loss_mask = torch.tensor(loss_mask_list, dtype=torch.long)
|
| 268 |
+
seq_len = int(input_ids.size(0))
|
| 269 |
+
|
| 270 |
+
if self.phase in ("sft", "online_kd"):
|
| 271 |
+
return {"input_ids": input_ids, "loss_mask": loss_mask}
|
| 272 |
+
|
| 273 |
+
teacher_logprobs, teacher_ids, teacher_other_logprob = self._load_teacher_tensors(idx, seq_len)
|
| 274 |
+
if teacher_logprobs.shape[0] != seq_len:
|
| 275 |
+
raise ValueError(
|
| 276 |
+
f"Teacher shard for sample #{idx} has length {teacher_logprobs.shape[0]}, "
|
| 277 |
+
f"but the tokenized row has length {seq_len}. Regenerate the teacher-logit shards; "
|
| 278 |
+
"teacher shards must be in original JSONL row order."
|
| 279 |
+
)
|
| 280 |
+
if teacher_logprobs.ndim != 2 or teacher_ids.shape != teacher_logprobs.shape:
|
| 281 |
+
raise ValueError(
|
| 282 |
+
f"Teacher shard for sample #{idx} has incompatible top-k tensor shapes: "
|
| 283 |
+
f"logprobs={tuple(teacher_logprobs.shape)}, ids={tuple(teacher_ids.shape)}. "
|
| 284 |
+
"Regenerate the current teacher-logit shards."
|
| 285 |
+
)
|
| 286 |
+
if teacher_other_logprob.ndim != 1 or teacher_other_logprob.shape[0] != teacher_logprobs.shape[0]:
|
| 287 |
+
raise ValueError(
|
| 288 |
+
f"Teacher shard for sample #{idx} has incompatible other-bucket shape: "
|
| 289 |
+
f"other_logprob={tuple(teacher_other_logprob.shape)}, "
|
| 290 |
+
f"expected ({teacher_logprobs.shape[0]},). "
|
| 291 |
+
"Regenerate the current teacher-logit shards."
|
| 292 |
+
)
|
| 293 |
+
if teacher_logprobs.shape[1] != cfg.training.top_k:
|
| 294 |
+
raise ValueError(
|
| 295 |
+
f"Teacher shard for sample #{idx} stores top_k={teacher_logprobs.shape[1]}, "
|
| 296 |
+
f"expected {cfg.training.top_k}. "
|
| 297 |
+
"Regenerate compatible teacher-logit shards."
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
return {
|
| 301 |
+
"input_ids": input_ids,
|
| 302 |
+
"loss_mask": loss_mask,
|
| 303 |
+
"teacher_logprobs": teacher_logprobs,
|
| 304 |
+
"teacher_ids": teacher_ids.long(),
|
| 305 |
+
"teacher_other_logprob": teacher_other_logprob,
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
def collate_fn(batch: list[dict], pad_token_id: int) -> dict:
|
| 309 |
+
raw_max = max(item["input_ids"].size(0) for item in batch)
|
| 310 |
+
max_len = ((raw_max + PAD_MULTIPLE - 1) // PAD_MULTIPLE) * PAD_MULTIPLE
|
| 311 |
+
|
| 312 |
+
input_ids_list, mask_list, loss_mask_list, labels_list = [], [], [], []
|
| 313 |
+
teacher_logprobs_list, teacher_ids_list, teacher_other_logprob_list = [], [], []
|
| 314 |
+
|
| 315 |
+
for item in batch:
|
| 316 |
+
seq_len = item["input_ids"].size(0)
|
| 317 |
+
pad_len = max_len - seq_len
|
| 318 |
+
padded_loss_mask = F.pad(item["loss_mask"], (0, pad_len), value=0)
|
| 319 |
+
padded_labels = F.pad(item["input_ids"].clone(), (0, pad_len), value=pad_token_id)
|
| 320 |
+
padded_labels = padded_labels.masked_fill(padded_loss_mask == 0, -100)
|
| 321 |
+
|
| 322 |
+
input_ids_list.append(F.pad(item["input_ids"], (0, pad_len), value=pad_token_id))
|
| 323 |
+
mask_list.append(
|
| 324 |
+
torch.cat(
|
| 325 |
+
[
|
| 326 |
+
torch.ones(seq_len, dtype=torch.long),
|
| 327 |
+
torch.zeros(pad_len, dtype=torch.long),
|
| 328 |
+
]
|
| 329 |
+
)
|
| 330 |
+
)
|
| 331 |
+
loss_mask_list.append(padded_loss_mask)
|
| 332 |
+
labels_list.append(padded_labels)
|
| 333 |
+
|
| 334 |
+
if "teacher_logprobs" in item:
|
| 335 |
+
teacher_seq_len = item["teacher_logprobs"].size(0)
|
| 336 |
+
teacher_pad_len = max_len - teacher_seq_len
|
| 337 |
+
teacher_logprobs_list.append(
|
| 338 |
+
F.pad(item["teacher_logprobs"], (0, 0, 0, teacher_pad_len), value=float("-inf"))
|
| 339 |
+
)
|
| 340 |
+
teacher_ids_list.append(F.pad(item["teacher_ids"], (0, 0, 0, teacher_pad_len), value=0))
|
| 341 |
+
teacher_other_logprob_list.append(
|
| 342 |
+
F.pad(item["teacher_other_logprob"], (0, teacher_pad_len), value=float("-inf"))
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
result = {
|
| 346 |
+
"input_ids": torch.stack(input_ids_list),
|
| 347 |
+
"attention_mask": torch.stack(mask_list),
|
| 348 |
+
"loss_mask": torch.stack(loss_mask_list).long(),
|
| 349 |
+
"labels": torch.stack(labels_list),
|
| 350 |
+
}
|
| 351 |
+
if teacher_logprobs_list:
|
| 352 |
+
result["teacher_logprobs"] = torch.stack(teacher_logprobs_list)
|
| 353 |
+
result["teacher_ids"] = torch.stack(teacher_ids_list)
|
| 354 |
+
result["teacher_other_logprob"] = torch.stack(teacher_other_logprob_list)
|
| 355 |
+
return result
|
| 356 |
+
|
| 357 |
+
def resolve_dataloader_runtime() -> dict[str, int | bool]:
|
| 358 |
+
cpu_count = max(1, os.cpu_count() or 1)
|
| 359 |
+
configured_workers = int(getattr(cfg.training, "dataloader_workers", 4))
|
| 360 |
+
num_workers = max(0, min(configured_workers, cpu_count))
|
| 361 |
+
runtime: dict[str, int | bool] = {
|
| 362 |
+
"num_workers": num_workers,
|
| 363 |
+
"pin_memory": torch.cuda.is_available(),
|
| 364 |
+
}
|
| 365 |
+
if num_workers > 0:
|
| 366 |
+
runtime["persistent_workers"] = True
|
| 367 |
+
runtime["prefetch_factor"] = max(1, int(getattr(cfg.training, "prefetch_factor", 2)))
|
| 368 |
+
return runtime
|
| 369 |
+
|
| 370 |
+
def move_batch_to_device(batch: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
|
| 371 |
+
non_blocking = device.type == "cuda"
|
| 372 |
+
return {
|
| 373 |
+
name: tensor.to(device, non_blocking=non_blocking)
|
| 374 |
+
for name, tensor in batch.items()
|
| 375 |
+
}
|
src/training_schedule.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch.utils.data import Dataset, Subset
|
| 8 |
+
|
| 9 |
+
def compute_training_schedule(
|
| 10 |
+
dataset_size: int,
|
| 11 |
+
micro_batch_size: int,
|
| 12 |
+
grad_accum: int,
|
| 13 |
+
num_epochs: int,
|
| 14 |
+
use_ds: bool,
|
| 15 |
+
drop_last: bool = True,
|
| 16 |
+
) -> dict[str, int | bool]:
|
| 17 |
+
if dataset_size < 0:
|
| 18 |
+
raise ValueError("dataset_size must be >= 0")
|
| 19 |
+
if micro_batch_size <= 0 or grad_accum <= 0 or num_epochs <= 0:
|
| 20 |
+
raise ValueError("micro_batch_size, grad_accum, and num_epochs must all be positive")
|
| 21 |
+
|
| 22 |
+
if drop_last:
|
| 23 |
+
batches_per_epoch = dataset_size // micro_batch_size
|
| 24 |
+
used_samples_per_epoch = batches_per_epoch * micro_batch_size
|
| 25 |
+
dropped_samples_per_epoch = dataset_size - used_samples_per_epoch
|
| 26 |
+
else:
|
| 27 |
+
batches_per_epoch = math.ceil(dataset_size / micro_batch_size) if dataset_size else 0
|
| 28 |
+
used_samples_per_epoch = dataset_size
|
| 29 |
+
dropped_samples_per_epoch = 0
|
| 30 |
+
|
| 31 |
+
total_micro_batches = batches_per_epoch * num_epochs
|
| 32 |
+
remainder_batches = batches_per_epoch % grad_accum if batches_per_epoch else 0
|
| 33 |
+
has_remainder = remainder_batches != 0
|
| 34 |
+
|
| 35 |
+
if use_ds:
|
| 36 |
+
steps_per_epoch = batches_per_epoch // grad_accum
|
| 37 |
+
total_steps = total_micro_batches // grad_accum
|
| 38 |
+
final_remainder = total_micro_batches % grad_accum
|
| 39 |
+
else:
|
| 40 |
+
steps_per_epoch = batches_per_epoch // grad_accum + (1 if has_remainder and batches_per_epoch else 0)
|
| 41 |
+
total_steps = steps_per_epoch * num_epochs
|
| 42 |
+
final_remainder = 0
|
| 43 |
+
|
| 44 |
+
return {
|
| 45 |
+
"batches_per_epoch": batches_per_epoch,
|
| 46 |
+
"used_samples_per_epoch": used_samples_per_epoch,
|
| 47 |
+
"dropped_samples_per_epoch": dropped_samples_per_epoch,
|
| 48 |
+
"remainder_batches": remainder_batches,
|
| 49 |
+
"has_remainder": has_remainder,
|
| 50 |
+
"total_micro_batches": total_micro_batches,
|
| 51 |
+
"steps_per_epoch": steps_per_epoch,
|
| 52 |
+
"total_steps": total_steps,
|
| 53 |
+
"final_remainder": final_remainder,
|
| 54 |
+
"dropped_samples_total": final_remainder * micro_batch_size if use_ds else 0,
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
def choose_validation_size(
|
| 58 |
+
dataset_size: int,
|
| 59 |
+
validation_ratio: float,
|
| 60 |
+
micro_batch_size: int,
|
| 61 |
+
grad_accum: int,
|
| 62 |
+
num_epochs: int,
|
| 63 |
+
use_ds: bool,
|
| 64 |
+
) -> int:
|
| 65 |
+
if not 0.0 <= validation_ratio < 1.0:
|
| 66 |
+
raise ValueError(f"validation_ratio must be in [0, 1), got {validation_ratio}")
|
| 67 |
+
if dataset_size < 2 or validation_ratio <= 0:
|
| 68 |
+
return 0
|
| 69 |
+
|
| 70 |
+
desired_val_size = max(1, int(round(dataset_size * validation_ratio)))
|
| 71 |
+
aligned_candidates: list[tuple[int, int]] = []
|
| 72 |
+
fallback_candidates: list[tuple[int, int]] = []
|
| 73 |
+
for val_size in range(1, dataset_size):
|
| 74 |
+
train_size = dataset_size - val_size
|
| 75 |
+
schedule = compute_training_schedule(
|
| 76 |
+
dataset_size=train_size,
|
| 77 |
+
micro_batch_size=micro_batch_size,
|
| 78 |
+
grad_accum=grad_accum,
|
| 79 |
+
num_epochs=num_epochs,
|
| 80 |
+
use_ds=use_ds,
|
| 81 |
+
drop_last=True,
|
| 82 |
+
)
|
| 83 |
+
if int(schedule["batches_per_epoch"]) == 0:
|
| 84 |
+
continue
|
| 85 |
+
if int(schedule["dropped_samples_per_epoch"]) != 0:
|
| 86 |
+
continue
|
| 87 |
+
candidate = (abs(val_size - desired_val_size), val_size)
|
| 88 |
+
if int(schedule["remainder_batches"]) == 0 and int(schedule["final_remainder"]) == 0:
|
| 89 |
+
aligned_candidates.append(candidate)
|
| 90 |
+
else:
|
| 91 |
+
fallback_candidates.append(candidate)
|
| 92 |
+
|
| 93 |
+
if aligned_candidates:
|
| 94 |
+
return min(aligned_candidates)[1]
|
| 95 |
+
if fallback_candidates:
|
| 96 |
+
return min(fallback_candidates)[1]
|
| 97 |
+
return min(desired_val_size, dataset_size - 1)
|
| 98 |
+
|
| 99 |
+
def build_train_validation_subsets(
|
| 100 |
+
dataset: Dataset,
|
| 101 |
+
validation_ratio: float,
|
| 102 |
+
split_seed: int,
|
| 103 |
+
micro_batch_size: int,
|
| 104 |
+
grad_accum: int,
|
| 105 |
+
num_epochs: int,
|
| 106 |
+
use_ds: bool,
|
| 107 |
+
) -> tuple[Dataset, Dataset | None, dict[str, float | int | bool]]:
|
| 108 |
+
dataset_size = len(dataset)
|
| 109 |
+
validation_size = choose_validation_size(
|
| 110 |
+
dataset_size=dataset_size,
|
| 111 |
+
validation_ratio=validation_ratio,
|
| 112 |
+
micro_batch_size=micro_batch_size,
|
| 113 |
+
grad_accum=grad_accum,
|
| 114 |
+
num_epochs=num_epochs,
|
| 115 |
+
use_ds=use_ds,
|
| 116 |
+
)
|
| 117 |
+
requested_validation_size = max(1, int(round(dataset_size * validation_ratio))) if validation_ratio > 0 else 0
|
| 118 |
+
metadata: dict[str, float | int | bool] = {
|
| 119 |
+
"dataset_size": dataset_size,
|
| 120 |
+
"requested_validation_size": requested_validation_size,
|
| 121 |
+
"validation_size": validation_size,
|
| 122 |
+
"train_size": dataset_size - validation_size,
|
| 123 |
+
"requested_validation_ratio": validation_ratio,
|
| 124 |
+
"actual_validation_ratio": (validation_size / dataset_size) if dataset_size else 0.0,
|
| 125 |
+
"adjusted": validation_size != requested_validation_size,
|
| 126 |
+
}
|
| 127 |
+
train_schedule = compute_training_schedule(
|
| 128 |
+
dataset_size=dataset_size - validation_size,
|
| 129 |
+
micro_batch_size=micro_batch_size,
|
| 130 |
+
grad_accum=grad_accum,
|
| 131 |
+
num_epochs=num_epochs,
|
| 132 |
+
use_ds=use_ds,
|
| 133 |
+
drop_last=True,
|
| 134 |
+
)
|
| 135 |
+
metadata.update(
|
| 136 |
+
{
|
| 137 |
+
"effective_batch_size": micro_batch_size * grad_accum,
|
| 138 |
+
"train_batches_per_epoch": int(train_schedule["batches_per_epoch"]),
|
| 139 |
+
"train_remainder_batches": int(train_schedule["remainder_batches"]),
|
| 140 |
+
"train_dropped_samples_per_epoch": int(train_schedule["dropped_samples_per_epoch"]),
|
| 141 |
+
"accumulation_aligned": int(train_schedule["remainder_batches"]) == 0
|
| 142 |
+
and int(train_schedule["final_remainder"]) == 0,
|
| 143 |
+
}
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
if validation_size == 0:
|
| 147 |
+
return dataset, None, metadata
|
| 148 |
+
|
| 149 |
+
generator = torch.Generator().manual_seed(split_seed)
|
| 150 |
+
permutation = torch.randperm(dataset_size, generator=generator).tolist()
|
| 151 |
+
val_indices = sorted(permutation[:validation_size])
|
| 152 |
+
train_indices = sorted(permutation[validation_size:])
|
| 153 |
+
return Subset(dataset, train_indices), Subset(dataset, val_indices), metadata
|
| 154 |
+
|
| 155 |
+
def load_deepspeed_runtime_config(config_path: str, micro_batch_size: int, grad_accum: int) -> dict:
|
| 156 |
+
with open(config_path, "r", encoding="utf-8") as f:
|
| 157 |
+
ds_cfg = json.load(f)
|
| 158 |
+
|
| 159 |
+
if not isinstance(ds_cfg, dict):
|
| 160 |
+
raise ValueError(f"DeepSpeed config in {config_path} must be a JSON object.")
|
| 161 |
+
|
| 162 |
+
runtime_cfg = dict(ds_cfg)
|
| 163 |
+
runtime_cfg["train_micro_batch_size_per_gpu"] = micro_batch_size
|
| 164 |
+
runtime_cfg["gradient_accumulation_steps"] = grad_accum
|
| 165 |
+
return runtime_cfg
|
src/transformers_compat.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import importlib
|
| 4 |
+
import importlib.util
|
| 5 |
+
import os
|
| 6 |
+
from configs import cfg
|
| 7 |
+
|
| 8 |
+
def _false() -> bool:
|
| 9 |
+
return False
|
| 10 |
+
|
| 11 |
+
def describe_exception_chain(exc: Exception) -> str:
|
| 12 |
+
messages: list[str] = []
|
| 13 |
+
seen: set[int] = set()
|
| 14 |
+
current: BaseException | None = exc
|
| 15 |
+
|
| 16 |
+
while current is not None and id(current) not in seen:
|
| 17 |
+
seen.add(id(current))
|
| 18 |
+
message = f"{type(current).__name__}: {current}"
|
| 19 |
+
if message not in messages:
|
| 20 |
+
messages.append(message)
|
| 21 |
+
current = current.__cause__ or current.__context__
|
| 22 |
+
|
| 23 |
+
return " -> ".join(messages)
|
| 24 |
+
|
| 25 |
+
def disable_flash_attn_for_transformers() -> None:
|
| 26 |
+
try:
|
| 27 |
+
import transformers.utils as tf_utils
|
| 28 |
+
|
| 29 |
+
tf_utils.is_flash_attn_2_available = _false
|
| 30 |
+
if hasattr(tf_utils, "is_flash_attn_3_available"):
|
| 31 |
+
tf_utils.is_flash_attn_3_available = _false
|
| 32 |
+
except Exception:
|
| 33 |
+
pass
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
from transformers.utils import import_utils as tf_import_utils
|
| 37 |
+
|
| 38 |
+
tf_import_utils.is_flash_attn_2_available = _false
|
| 39 |
+
if hasattr(tf_import_utils, "is_flash_attn_3_available"):
|
| 40 |
+
tf_import_utils.is_flash_attn_3_available = _false
|
| 41 |
+
except Exception:
|
| 42 |
+
pass
|
| 43 |
+
|
| 44 |
+
try:
|
| 45 |
+
import transformers.modeling_utils as modeling_utils
|
| 46 |
+
|
| 47 |
+
if hasattr(modeling_utils, "is_flash_attn_2_available"):
|
| 48 |
+
modeling_utils.is_flash_attn_2_available = _false
|
| 49 |
+
if hasattr(modeling_utils, "is_flash_attn_3_available"):
|
| 50 |
+
modeling_utils.is_flash_attn_3_available = _false
|
| 51 |
+
except Exception:
|
| 52 |
+
pass
|
| 53 |
+
|
| 54 |
+
try:
|
| 55 |
+
import transformers.modeling_flash_attention_utils as flash_utils
|
| 56 |
+
|
| 57 |
+
flash_utils.is_flash_attn_2_available = _false
|
| 58 |
+
if hasattr(flash_utils, "is_flash_attn_3_available"):
|
| 59 |
+
flash_utils.is_flash_attn_3_available = _false
|
| 60 |
+
except Exception:
|
| 61 |
+
pass
|
| 62 |
+
|
| 63 |
+
def resolve_attention_backend(logger) -> str:
|
| 64 |
+
forced_backend = os.environ.get("QUINTUS_ATTENTION_BACKEND")
|
| 65 |
+
if forced_backend:
|
| 66 |
+
logger.info(f" [ATTENTION] Forced backend via QUINTUS_ATTENTION_BACKEND={forced_backend!r}.")
|
| 67 |
+
return forced_backend
|
| 68 |
+
|
| 69 |
+
try:
|
| 70 |
+
from transformers.utils import is_flash_attn_3_available
|
| 71 |
+
if is_flash_attn_3_available():
|
| 72 |
+
logger.info(" [ATTENTION] Using flash_attention_3.")
|
| 73 |
+
return "flash_attention_3"
|
| 74 |
+
except Exception:
|
| 75 |
+
pass
|
| 76 |
+
|
| 77 |
+
try:
|
| 78 |
+
importlib.import_module("flash_attn")
|
| 79 |
+
logger.info(" [ATTENTION] Using flash_attention_2.")
|
| 80 |
+
return "flash_attention_2"
|
| 81 |
+
except Exception as exc:
|
| 82 |
+
if importlib.util.find_spec("flash_attn") is not None:
|
| 83 |
+
disable_flash_attn_for_transformers()
|
| 84 |
+
logger.warning(
|
| 85 |
+
"flash-attn appears installed but failed to import (%s: %s); "
|
| 86 |
+
"masking flash-attn from Transformers and falling back to sdpa.",
|
| 87 |
+
type(exc).__name__,
|
| 88 |
+
exc,
|
| 89 |
+
)
|
| 90 |
+
else:
|
| 91 |
+
logger.info(" [ATTENTION] Using PyTorch SDPA.")
|
| 92 |
+
return "sdpa"
|
| 93 |
+
|
| 94 |
+
def _requires_remote_code_opt_in(exc: Exception) -> bool:
|
| 95 |
+
message = str(exc).lower()
|
| 96 |
+
return (
|
| 97 |
+
"trust_remote_code" in message
|
| 98 |
+
or "requires you to execute the configuration file" in message
|
| 99 |
+
or "requires remote code" in message
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
def format_model_load_error(subject: str, exc: Exception) -> str:
|
| 103 |
+
if not cfg.model.allow_remote_code and _requires_remote_code_opt_in(exc):
|
| 104 |
+
return (
|
| 105 |
+
f"{subject} failed because the selected model/tokenizer requires remote code, "
|
| 106 |
+
"but Quintus is configured with allow_remote_code=false. Review the upstream "
|
| 107 |
+
"repository and rerun with QUINTUS_ALLOW_REMOTE_CODE=1 only if you explicitly "
|
| 108 |
+
"trust that code."
|
| 109 |
+
)
|
| 110 |
+
return f"{subject} failed: {describe_exception_chain(exc)}"
|
src/validation.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
|
| 6 |
+
from src.losses import compute_loss_for_phase
|
| 7 |
+
from src.training_data import move_batch_to_device
|
| 8 |
+
|
| 9 |
+
def evaluate_validation_loss(
|
| 10 |
+
phase: str,
|
| 11 |
+
model,
|
| 12 |
+
dataloader: DataLoader,
|
| 13 |
+
device: torch.device,
|
| 14 |
+
alpha: float,
|
| 15 |
+
temperature: float,
|
| 16 |
+
online_kd_token_chunk_size: int = 2048,
|
| 17 |
+
teacher_model=None,
|
| 18 |
+
max_batches: int = -1,
|
| 19 |
+
) -> dict[str, float | int]:
|
| 20 |
+
was_training = model.training
|
| 21 |
+
model.eval()
|
| 22 |
+
|
| 23 |
+
total_loss = 0.0
|
| 24 |
+
total_ce = 0.0
|
| 25 |
+
total_kd = 0.0
|
| 26 |
+
batches = 0
|
| 27 |
+
|
| 28 |
+
with torch.inference_mode():
|
| 29 |
+
for batch in dataloader:
|
| 30 |
+
if max_batches > 0 and batches >= max_batches:
|
| 31 |
+
break
|
| 32 |
+
batch = move_batch_to_device(batch, device)
|
| 33 |
+
input_ids = batch["input_ids"]
|
| 34 |
+
attention_mask = batch["attention_mask"]
|
| 35 |
+
labels = batch["labels"]
|
| 36 |
+
loss_mask = batch["loss_mask"]
|
| 37 |
+
|
| 38 |
+
logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
|
| 39 |
+
|
| 40 |
+
if phase == "online_kd" and teacher_model is not None:
|
| 41 |
+
teacher_logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits
|
| 42 |
+
else:
|
| 43 |
+
teacher_logits = None
|
| 44 |
+
|
| 45 |
+
loss, ce, kd = compute_loss_for_phase(
|
| 46 |
+
phase,
|
| 47 |
+
logits,
|
| 48 |
+
labels,
|
| 49 |
+
loss_mask,
|
| 50 |
+
batch,
|
| 51 |
+
alpha,
|
| 52 |
+
temperature,
|
| 53 |
+
teacher_logits=teacher_logits,
|
| 54 |
+
online_kd_token_chunk_size=online_kd_token_chunk_size,
|
| 55 |
+
)
|
| 56 |
+
total_loss += float(loss.detach().item())
|
| 57 |
+
total_ce += float(ce.detach().item())
|
| 58 |
+
total_kd += float(kd.detach().item())
|
| 59 |
+
batches += 1
|
| 60 |
+
|
| 61 |
+
if was_training:
|
| 62 |
+
model.train()
|
| 63 |
+
|
| 64 |
+
denom = max(batches, 1)
|
| 65 |
+
return {
|
| 66 |
+
"loss": total_loss / denom,
|
| 67 |
+
"ce": total_ce / denom,
|
| 68 |
+
"kd": total_kd / denom,
|
| 69 |
+
"batches": batches,
|
| 70 |
+
}
|
weight_audit/quintus_weight_audit.py
ADDED
|
@@ -0,0 +1,818 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Usage : python audit.py \
|
| 3 |
+
--base_model Qwen/Qwen3-1.7B-Base \
|
| 4 |
+
--distilled_model iamrahulreddy/Quintus \
|
| 5 |
+
--output_file weight_audit_report.txt \
|
| 6 |
+
--alpha 0.3
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import collections
|
| 11 |
+
import math
|
| 12 |
+
import sys
|
| 13 |
+
import time
|
| 14 |
+
from datetime import datetime, timezone
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from huggingface_hub import snapshot_download
|
| 20 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
| 21 |
+
|
| 22 |
+
# Formatting utilities
|
| 23 |
+
def fmt_num(n: int) -> str:
|
| 24 |
+
if n >= 1_000_000_000:
|
| 25 |
+
return f"{n:,} ({n / 1e9:.6f} B)"
|
| 26 |
+
if n >= 1_000_000:
|
| 27 |
+
return f"{n:,} ({n / 1e6:.6f} M)"
|
| 28 |
+
return f"{n:,}"
|
| 29 |
+
|
| 30 |
+
def fmt_size(b: int) -> str:
|
| 31 |
+
if b >= 1 << 30:
|
| 32 |
+
return f"{b / (1 << 30):.3f} GiB"
|
| 33 |
+
if b >= 1 << 20:
|
| 34 |
+
return f"{b / (1 << 20):.3f} MiB"
|
| 35 |
+
if b >= 1 << 10:
|
| 36 |
+
return f"{b / (1 << 10):.3f} KiB"
|
| 37 |
+
return f"{b} B"
|
| 38 |
+
|
| 39 |
+
def divider(char: str = "-", width: int = 88) -> str:
|
| 40 |
+
return char * width
|
| 41 |
+
|
| 42 |
+
def section_header(index: int, title: str) -> str:
|
| 43 |
+
return f"\n[{index:02d}] {title}"
|
| 44 |
+
|
| 45 |
+
def sub_header(title: str) -> str:
|
| 46 |
+
return f"\n -- {title}"
|
| 47 |
+
|
| 48 |
+
# Layer classification
|
| 49 |
+
LAYER_TYPE_MAP = {
|
| 50 |
+
"embed_tokens": "embedding",
|
| 51 |
+
"lm_head": "lm_head",
|
| 52 |
+
"self_attn.q_proj": "attn_q",
|
| 53 |
+
"self_attn.k_proj": "attn_k",
|
| 54 |
+
"self_attn.v_proj": "attn_v",
|
| 55 |
+
"self_attn.o_proj": "attn_o",
|
| 56 |
+
"self_attn.q_norm": "attn_qnorm",
|
| 57 |
+
"self_attn.k_norm": "attn_knorm",
|
| 58 |
+
"mlp.gate_proj": "mlp_gate",
|
| 59 |
+
"mlp.up_proj": "mlp_up",
|
| 60 |
+
"mlp.down_proj": "mlp_down",
|
| 61 |
+
"input_layernorm": "layernorm",
|
| 62 |
+
"post_attention_layernorm": "layernorm",
|
| 63 |
+
"model.norm": "final_norm",
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
def classify_layer(name: str) -> str:
|
| 67 |
+
for pattern, label in LAYER_TYPE_MAP.items():
|
| 68 |
+
if pattern in name:
|
| 69 |
+
return label
|
| 70 |
+
return "other"
|
| 71 |
+
|
| 72 |
+
# Tensor statistics
|
| 73 |
+
def tensor_stats(t: torch.Tensor) -> dict:
|
| 74 |
+
tf = t.float()
|
| 75 |
+
flat = tf.view(-1)
|
| 76 |
+
mean = flat.mean().item()
|
| 77 |
+
std = flat.std().item()
|
| 78 |
+
|
| 79 |
+
sparsity = (flat.abs() < 1e-6).float().mean().item()
|
| 80 |
+
sat_thresh = flat.abs().max().item() * 0.99
|
| 81 |
+
saturation = (flat.abs() >= sat_thresh).float().mean().item()
|
| 82 |
+
kurtosis = (((flat - mean) / std) ** 4).mean().item() - 3.0 if std > 1e-10 else 0.0
|
| 83 |
+
outlier_r = (flat.abs() > (flat.abs().mean() + 3.0 * std)).float().mean().item()
|
| 84 |
+
|
| 85 |
+
row_l2_stats = {}
|
| 86 |
+
if tf.ndim == 2:
|
| 87 |
+
row_norms = tf.norm(2, dim=1)
|
| 88 |
+
row_l2_stats = {
|
| 89 |
+
"row_l2_mean": row_norms.mean().item(),
|
| 90 |
+
"row_l2_std": row_norms.std().item(),
|
| 91 |
+
"row_l2_min": row_norms.min().item(),
|
| 92 |
+
"row_l2_max": row_norms.max().item(),
|
| 93 |
+
"dead_rows": int((row_norms < 1e-6).sum().item()),
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
return {
|
| 97 |
+
"shape": list(tf.shape),
|
| 98 |
+
"numel": flat.numel(),
|
| 99 |
+
"dtype": str(t.dtype),
|
| 100 |
+
"mean": mean,
|
| 101 |
+
"std": std,
|
| 102 |
+
"min": flat.min().item(),
|
| 103 |
+
"max": flat.max().item(),
|
| 104 |
+
"abs_mean": flat.abs().mean().item(),
|
| 105 |
+
"l2_norm": flat.norm(2).item(),
|
| 106 |
+
"l1_norm": flat.norm(1).item(),
|
| 107 |
+
"sparsity": sparsity,
|
| 108 |
+
"saturation": saturation,
|
| 109 |
+
"kurtosis": kurtosis,
|
| 110 |
+
"outlier_ratio": outlier_r,
|
| 111 |
+
**row_l2_stats,
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
# Divergence between two tensors
|
| 115 |
+
def tensor_divergence(t_base: torch.Tensor, t_dist: torch.Tensor, chunk_size: int = 10_000_000) -> dict:
|
| 116 |
+
a_flat = t_base.detach().view(-1)
|
| 117 |
+
b_flat = t_dist.detach().view(-1)
|
| 118 |
+
n_elements = a_flat.numel()
|
| 119 |
+
|
| 120 |
+
# Running accumulators in float64 (on CPU/Python) to prevent memory spikes
|
| 121 |
+
dot_prod = 0.0
|
| 122 |
+
a_sq_sum = 0.0
|
| 123 |
+
b_sq_sum = 0.0
|
| 124 |
+
|
| 125 |
+
# Delta statistics
|
| 126 |
+
max_delta = 0.0
|
| 127 |
+
sum_delta = 0.0
|
| 128 |
+
l2_delta_sq = 0.0
|
| 129 |
+
sum_abs_a = 0.0
|
| 130 |
+
|
| 131 |
+
# Process in chunks to keep memory footprint extremely small (~80MB peak per chunk)
|
| 132 |
+
for i in range(0, n_elements, chunk_size):
|
| 133 |
+
a_chunk = a_flat[i : i + chunk_size].to(torch.float64)
|
| 134 |
+
b_chunk = b_flat[i : i + chunk_size].to(torch.float64)
|
| 135 |
+
|
| 136 |
+
# Accumulate dot product and norms
|
| 137 |
+
dot_prod += torch.dot(a_chunk, b_chunk).item()
|
| 138 |
+
a_sq_sum += torch.dot(a_chunk, a_chunk).item()
|
| 139 |
+
b_sq_sum += torch.dot(b_chunk, b_chunk).item()
|
| 140 |
+
|
| 141 |
+
# Accumulate delta stats
|
| 142 |
+
delta_chunk = (b_chunk - a_chunk).abs()
|
| 143 |
+
max_delta = max(max_delta, delta_chunk.max().item())
|
| 144 |
+
sum_delta += delta_chunk.sum().item()
|
| 145 |
+
l2_delta_sq += torch.dot(delta_chunk, delta_chunk).item()
|
| 146 |
+
sum_abs_a += a_chunk.abs().sum().item()
|
| 147 |
+
|
| 148 |
+
# Final metrics
|
| 149 |
+
a_norm = math.sqrt(a_sq_sum)
|
| 150 |
+
b_norm = math.sqrt(b_sq_sum)
|
| 151 |
+
if a_norm > 0 and b_norm > 0:
|
| 152 |
+
cos_sim_raw = dot_prod / (a_norm * b_norm)
|
| 153 |
+
else:
|
| 154 |
+
cos_sim_raw = 0.0
|
| 155 |
+
|
| 156 |
+
cos_sim = max(-1.0, min(1.0, cos_sim_raw))
|
| 157 |
+
|
| 158 |
+
rel_err = sum_delta / (sum_abs_a + 1e-12)
|
| 159 |
+
base_l2 = a_norm
|
| 160 |
+
delta_l2 = math.sqrt(l2_delta_sq)
|
| 161 |
+
snr_db = 20.0 * math.log10(base_l2 / (delta_l2 + 1e-12)) if base_l2 > 0 else 0.0
|
| 162 |
+
|
| 163 |
+
# Standard deviation of delta
|
| 164 |
+
mean_delta = sum_delta / n_elements
|
| 165 |
+
mean_delta_sq = l2_delta_sq / n_elements
|
| 166 |
+
var_delta = max(0.0, mean_delta_sq - mean_delta**2)
|
| 167 |
+
std_delta = math.sqrt(var_delta)
|
| 168 |
+
|
| 169 |
+
return {
|
| 170 |
+
"max_delta": max_delta,
|
| 171 |
+
"mean_delta": mean_delta,
|
| 172 |
+
"std_delta": std_delta,
|
| 173 |
+
"l2_delta": delta_l2,
|
| 174 |
+
"cos_sim": cos_sim,
|
| 175 |
+
"cos_sim_raw": cos_sim_raw,
|
| 176 |
+
"rel_err": rel_err,
|
| 177 |
+
"snr_db": snr_db,
|
| 178 |
+
"changed": max_delta > 1e-7,
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
# Isotropy
|
| 182 |
+
def isotropy_score(t: torch.Tensor, n_samples: int = 2048) -> float:
|
| 183 |
+
"""
|
| 184 |
+
Average pairwise cosine similarity of randomly sampled row vectors.
|
| 185 |
+
Near 0 = isotropic (healthy). Near 1 = collapsed representations.
|
| 186 |
+
Only valid for 2D tensors with >= 2 rows.
|
| 187 |
+
"""
|
| 188 |
+
if t.ndim != 2 or t.shape[0] < 2:
|
| 189 |
+
return float("nan")
|
| 190 |
+
tf = t.float()
|
| 191 |
+
n = min(t.shape[0], n_samples)
|
| 192 |
+
# Add deterministic seed for isotropy sampling
|
| 193 |
+
gen = torch.Generator().manual_seed(42)
|
| 194 |
+
idx = torch.randperm(t.shape[0], generator=gen)[:n].to(t.device)
|
| 195 |
+
rows = tf[idx]
|
| 196 |
+
norms = rows.norm(2, dim=1, keepdim=True).clamp(min=1e-12)
|
| 197 |
+
normed = rows / norms
|
| 198 |
+
sim = normed @ normed.T
|
| 199 |
+
mask = ~torch.eye(n, dtype=torch.bool)
|
| 200 |
+
return sim[mask].mean().item()
|
| 201 |
+
|
| 202 |
+
# Config helpers
|
| 203 |
+
def config_architecture_lines(config, label: str, model_id: str) -> list[str]:
|
| 204 |
+
cfg = config.to_dict()
|
| 205 |
+
n_q = cfg.get("num_attention_heads", 1)
|
| 206 |
+
n_kv = cfg.get("num_key_value_heads", n_q)
|
| 207 |
+
h = cfg.get("hidden_size", 0)
|
| 208 |
+
head_dim = h // n_q if n_q else 0
|
| 209 |
+
gqa = n_q // n_kv if n_kv else 1
|
| 210 |
+
|
| 211 |
+
return [
|
| 212 |
+
f" label : {label} ({model_id})",
|
| 213 |
+
f" model_type : {cfg.get('model_type', 'unknown')}",
|
| 214 |
+
f" architecture : {cfg.get('architectures', ['unknown'])[0]}",
|
| 215 |
+
"",
|
| 216 |
+
" Vocabulary",
|
| 217 |
+
f" vocab_size : {cfg.get('vocab_size', 'N/A'):,}",
|
| 218 |
+
f" bos / eos / pad : {cfg.get('bos_token_id')} / {cfg.get('eos_token_id')} / {cfg.get('pad_token_id')}",
|
| 219 |
+
"",
|
| 220 |
+
" Positional encoding",
|
| 221 |
+
f" max_position_embeddings: {cfg.get('max_position_embeddings', 'N/A'):,}",
|
| 222 |
+
f" rope_theta : {cfg.get('rope_theta', 'N/A')}",
|
| 223 |
+
f" rope_scaling : {cfg.get('rope_scaling', 'None')}",
|
| 224 |
+
"",
|
| 225 |
+
" Transformer dimensions",
|
| 226 |
+
f" hidden_size : {h}",
|
| 227 |
+
f" num_hidden_layers : {cfg.get('num_hidden_layers', 'N/A')}",
|
| 228 |
+
f" intermediate_size : {cfg.get('intermediate_size', 'N/A')}",
|
| 229 |
+
"",
|
| 230 |
+
" Attention",
|
| 231 |
+
f" num_attention_heads : {n_q}",
|
| 232 |
+
f" num_key_value_heads : {n_kv}",
|
| 233 |
+
f" head_dim : {head_dim}",
|
| 234 |
+
f" GQA ratio : {gqa}:1",
|
| 235 |
+
f" attention_bias : {cfg.get('attention_bias', False)}",
|
| 236 |
+
f" use_qk_norm : {cfg.get('use_qk_norm', False) or 'qwen3' in model_id.lower() or 'qwen3' in cfg.get('model_type', '').lower()}",
|
| 237 |
+
f" sliding_window : {cfg.get('sliding_window', 'None')}",
|
| 238 |
+
"",
|
| 239 |
+
" Feed-forward",
|
| 240 |
+
f" hidden_act : {cfg.get('hidden_act', 'silu')}",
|
| 241 |
+
f" mlp_bias : {cfg.get('mlp_bias', False)}",
|
| 242 |
+
"",
|
| 243 |
+
" Misc",
|
| 244 |
+
f" rms_norm_eps : {cfg.get('rms_norm_eps', 1e-6)}",
|
| 245 |
+
f" tie_word_embeddings : {cfg.get('tie_word_embeddings', True)}",
|
| 246 |
+
f" use_cache : {cfg.get('use_cache', True)}",
|
| 247 |
+
f" torch_dtype : {cfg.get('torch_dtype', 'float32')}",
|
| 248 |
+
f" initializer_range : {cfg.get('initializer_range', 'N/A')}",
|
| 249 |
+
]
|
| 250 |
+
|
| 251 |
+
def get_params_info(config, model_id: str = "") -> dict:
|
| 252 |
+
h = config.hidden_size
|
| 253 |
+
l = config.num_hidden_layers
|
| 254 |
+
v = config.vocab_size
|
| 255 |
+
embed = v * h
|
| 256 |
+
tie = getattr(config, "tie_word_embeddings", True)
|
| 257 |
+
n_q = config.num_attention_heads
|
| 258 |
+
n_kv = getattr(config, "num_key_value_heads", n_q)
|
| 259 |
+
head_dim = h // n_q
|
| 260 |
+
qkv_proj = (n_q + 2 * n_kv) * head_dim * h
|
| 261 |
+
o_proj = h * h
|
| 262 |
+
use_qk_norm = (
|
| 263 |
+
getattr(config, "use_qk_norm", False) or
|
| 264 |
+
"qwen3" in model_id.lower() or
|
| 265 |
+
"qwen3" in getattr(config, "model_type", "").lower()
|
| 266 |
+
)
|
| 267 |
+
qk_norm = 2 * head_dim if use_qk_norm else 0
|
| 268 |
+
mlp = 3 * h * config.intermediate_size
|
| 269 |
+
norms = 2 * h
|
| 270 |
+
per_layer = qkv_proj + o_proj + qk_norm + mlp + norms
|
| 271 |
+
total_layers = l * per_layer
|
| 272 |
+
lm_head = 0 if tie else embed
|
| 273 |
+
unique = embed + lm_head + total_layers + h # +h for final norm
|
| 274 |
+
return {
|
| 275 |
+
"raw": unique + (embed if tie else 0),
|
| 276 |
+
"embed": embed,
|
| 277 |
+
"lm_head": embed,
|
| 278 |
+
"tied": tie,
|
| 279 |
+
"unique": unique,
|
| 280 |
+
"non_embed": total_layers + h,
|
| 281 |
+
"per_layer": per_layer,
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
def param_lines(config, p: dict, label: str) -> list[str]:
|
| 285 |
+
return [
|
| 286 |
+
f" {label}",
|
| 287 |
+
f" raw (all named) : {fmt_num(p['raw'])}",
|
| 288 |
+
f" embedding : {fmt_num(p['embed'])}",
|
| 289 |
+
f" lm_head : {fmt_num(p['lm_head'])}",
|
| 290 |
+
f" tied : {p['tied']}",
|
| 291 |
+
f" unique (deduped) : {fmt_num(p['unique'])}",
|
| 292 |
+
f" non-embedding : {fmt_num(p['non_embed'])}",
|
| 293 |
+
f" per layer (approx) : {p['per_layer']:,}",
|
| 294 |
+
]
|
| 295 |
+
|
| 296 |
+
# Main
|
| 297 |
+
def main():
|
| 298 |
+
parser = argparse.ArgumentParser(description="Quintus Deep Weight Audit")
|
| 299 |
+
parser.add_argument("--base_model", type=str, default="Qwen/Qwen3-1.7B-Base")
|
| 300 |
+
parser.add_argument("--distilled_model", type=str, default="iamrahulreddy/Quintus")
|
| 301 |
+
parser.add_argument("--output_file", type=str, default="weight_audit_report.txt")
|
| 302 |
+
parser.add_argument("--alpha", type=float, default=0.3)
|
| 303 |
+
parser.add_argument("--isotropy_samples", type=int, default=2048)
|
| 304 |
+
parser.add_argument("--trust_remote_code", action="store_true", help="Allow custom code from model repositories.")
|
| 305 |
+
args = parser.parse_args()
|
| 306 |
+
|
| 307 |
+
# Determine compute device
|
| 308 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 309 |
+
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
| 310 |
+
|
| 311 |
+
utc_ts = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")
|
| 312 |
+
loc_ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S local")
|
| 313 |
+
|
| 314 |
+
R: list[str] = []
|
| 315 |
+
|
| 316 |
+
def log(line: str = ""):
|
| 317 |
+
print(line)
|
| 318 |
+
R.append(line)
|
| 319 |
+
|
| 320 |
+
def loglines(lines: list[str]):
|
| 321 |
+
for ln in lines:
|
| 322 |
+
log(ln)
|
| 323 |
+
|
| 324 |
+
# Header
|
| 325 |
+
loglines([
|
| 326 |
+
divider("="),
|
| 327 |
+
" QUINTUS WEIGHT AUDIT",
|
| 328 |
+
divider("="),
|
| 329 |
+
f" {utc_ts} ({loc_ts})",
|
| 330 |
+
f" base model : {args.base_model}",
|
| 331 |
+
f" distilled model : {args.distilled_model}",
|
| 332 |
+
f" alpha : {args.alpha}",
|
| 333 |
+
f" device : {device} | dtype: {dtype}",
|
| 334 |
+
f" python : {sys.version.split()[0]} | torch: {torch.__version__}",
|
| 335 |
+
divider("="),
|
| 336 |
+
])
|
| 337 |
+
|
| 338 |
+
# [01] Resolve checkpoints
|
| 339 |
+
log(section_header(1, "Resolve checkpoints"))
|
| 340 |
+
|
| 341 |
+
# Resolve base model commit hash (pin and report base commit)
|
| 342 |
+
base_commit = "local"
|
| 343 |
+
if not Path(args.base_model).exists():
|
| 344 |
+
try:
|
| 345 |
+
base_local_dir = Path(snapshot_download(repo_id=args.base_model))
|
| 346 |
+
base_commit = base_local_dir.name
|
| 347 |
+
except Exception:
|
| 348 |
+
base_commit = "unknown"
|
| 349 |
+
|
| 350 |
+
dist_commit = "local"
|
| 351 |
+
if not Path(args.distilled_model).exists():
|
| 352 |
+
log(f" Downloading '{args.distilled_model}' from HuggingFace Hub...")
|
| 353 |
+
t0 = time.time()
|
| 354 |
+
try:
|
| 355 |
+
local_dir = snapshot_download(repo_id=args.distilled_model)
|
| 356 |
+
distilled_path = Path(local_dir)
|
| 357 |
+
dist_commit = distilled_path.name
|
| 358 |
+
except Exception as e:
|
| 359 |
+
log(f" ERROR: {e}")
|
| 360 |
+
sys.exit(1)
|
| 361 |
+
log(f" Done in {time.time() - t0:.1f}s")
|
| 362 |
+
else:
|
| 363 |
+
distilled_path = Path(args.distilled_model)
|
| 364 |
+
if "snapshots" in distilled_path.parts:
|
| 365 |
+
dist_commit = distilled_path.name
|
| 366 |
+
|
| 367 |
+
# Redact absolute local HF cache paths for sharing
|
| 368 |
+
redacted_root = "<HF_CACHE_DIR>/snapshots"
|
| 369 |
+
log(f" base model commit : {base_commit}")
|
| 370 |
+
log(f" distilled commit : {dist_commit}")
|
| 371 |
+
log(f" snapshot root : {redacted_root}")
|
| 372 |
+
|
| 373 |
+
if not (distilled_path / "config.json").exists():
|
| 374 |
+
log(" ERROR: config.json missing from checkpoint directory.")
|
| 375 |
+
sys.exit(1)
|
| 376 |
+
|
| 377 |
+
files = sorted(f for f in distilled_path.iterdir() if f.is_file())
|
| 378 |
+
total_ckpt_bytes = sum(f.stat().st_size for f in files)
|
| 379 |
+
log("")
|
| 380 |
+
log(f" {'Filename':<52} {'Size':>12} Modified")
|
| 381 |
+
for f in files:
|
| 382 |
+
mtime = datetime.fromtimestamp(f.stat().st_mtime).strftime("%Y-%m-%d %H:%M")
|
| 383 |
+
log(f" {f.name:<52} {fmt_size(f.stat().st_size):>12} {mtime}")
|
| 384 |
+
log(f" {'total':<52} {fmt_size(total_ckpt_bytes):>12}")
|
| 385 |
+
|
| 386 |
+
# [02] Architecture configuration
|
| 387 |
+
log(section_header(2, "Architecture configuration"))
|
| 388 |
+
|
| 389 |
+
log(" Loading base config...")
|
| 390 |
+
try:
|
| 391 |
+
base_config = AutoConfig.from_pretrained(args.base_model, trust_remote_code=args.trust_remote_code)
|
| 392 |
+
except Exception as e:
|
| 393 |
+
log(f" ERROR: {e}"); sys.exit(1)
|
| 394 |
+
|
| 395 |
+
log(" Loading distilled config...")
|
| 396 |
+
try:
|
| 397 |
+
distilled_config = AutoConfig.from_pretrained(str(distilled_path), trust_remote_code=args.trust_remote_code)
|
| 398 |
+
except Exception as e:
|
| 399 |
+
log(f" ERROR: {e}"); sys.exit(1)
|
| 400 |
+
|
| 401 |
+
log(sub_header("Base"))
|
| 402 |
+
loglines(config_architecture_lines(base_config, "base", args.base_model))
|
| 403 |
+
|
| 404 |
+
log(sub_header("Distilled"))
|
| 405 |
+
loglines(config_architecture_lines(distilled_config, "distilled", args.distilled_model))
|
| 406 |
+
|
| 407 |
+
log(sub_header("Config diff (ignoring: _name_or_path, transformers_version)"))
|
| 408 |
+
ignore_keys = {"_name_or_path", "transformers_version"}
|
| 409 |
+
base_dict = base_config.to_dict()
|
| 410 |
+
dist_dict = distilled_config.to_dict()
|
| 411 |
+
config_diffs = [
|
| 412 |
+
(k, base_dict.get(k), dist_dict.get(k))
|
| 413 |
+
for k in sorted(set(base_dict) | set(dist_dict))
|
| 414 |
+
if k not in ignore_keys and base_dict.get(k) != dist_dict.get(k)
|
| 415 |
+
]
|
| 416 |
+
if not config_diffs:
|
| 417 |
+
log(" No differences — configs identical (expected for same-architecture KD).")
|
| 418 |
+
else:
|
| 419 |
+
log(f" {'Key':<40} {'Base':>28} Distilled")
|
| 420 |
+
for k, vb, vd in config_diffs:
|
| 421 |
+
log(f" {k:<40} {str(vb):>28} {vd}")
|
| 422 |
+
|
| 423 |
+
# [03] Parameter accounting
|
| 424 |
+
log(section_header(3, "Parameter accounting"))
|
| 425 |
+
|
| 426 |
+
base_params = get_params_info(base_config, args.base_model)
|
| 427 |
+
dist_params = get_params_info(distilled_config, args.distilled_model)
|
| 428 |
+
|
| 429 |
+
log(sub_header("Base"))
|
| 430 |
+
loglines(param_lines(base_config, base_params, "base"))
|
| 431 |
+
|
| 432 |
+
log(sub_header("Distilled"))
|
| 433 |
+
loglines(param_lines(distilled_config, dist_params, "distilled"))
|
| 434 |
+
|
| 435 |
+
log(sub_header("Delta"))
|
| 436 |
+
du = dist_params["unique"] - base_params["unique"]
|
| 437 |
+
log(f" unique param delta : {du:+,} ({du / base_params['unique'] * 100:+.4f} %)")
|
| 438 |
+
log(f" non-embed param delta : {dist_params['non_embed'] - base_params['non_embed']:+,}")
|
| 439 |
+
|
| 440 |
+
# [04] Load weights onto GPU
|
| 441 |
+
log(section_header(4, "Load weights"))
|
| 442 |
+
log(f" device: {device} | dtype: {dtype}")
|
| 443 |
+
|
| 444 |
+
load_kwargs = dict(dtype=dtype, device_map=device, trust_remote_code=args.trust_remote_code)
|
| 445 |
+
|
| 446 |
+
log(f" Loading base model : {args.base_model}")
|
| 447 |
+
t0 = time.time()
|
| 448 |
+
base_model = AutoModelForCausalLM.from_pretrained(args.base_model, **load_kwargs)
|
| 449 |
+
log(f" Done in {time.time() - t0:.1f}s")
|
| 450 |
+
|
| 451 |
+
log(f" Loading distilled : {args.distilled_model}")
|
| 452 |
+
t0 = time.time()
|
| 453 |
+
distilled_model = AutoModelForCausalLM.from_pretrained(str(distilled_path), **load_kwargs)
|
| 454 |
+
log(f" Done in {time.time() - t0:.1f}s")
|
| 455 |
+
|
| 456 |
+
base_sd = base_model.state_dict()
|
| 457 |
+
dist_sd = distilled_model.state_dict()
|
| 458 |
+
|
| 459 |
+
log(f" base tensors : {len(base_sd)}")
|
| 460 |
+
log(f" distilled tensors : {len(dist_sd)}")
|
| 461 |
+
|
| 462 |
+
only_base = set(base_sd) - set(dist_sd)
|
| 463 |
+
only_dist = set(dist_sd) - set(base_sd)
|
| 464 |
+
if only_base:
|
| 465 |
+
log(f" keys only in base : {sorted(only_base)[:5]} ...")
|
| 466 |
+
if only_dist:
|
| 467 |
+
log(f" keys only in distilled: {sorted(only_dist)[:5]} ...")
|
| 468 |
+
|
| 469 |
+
tied = torch.equal(
|
| 470 |
+
base_sd["model.embed_tokens.weight"],
|
| 471 |
+
base_sd.get("lm_head.weight", base_sd["model.embed_tokens.weight"]),
|
| 472 |
+
)
|
| 473 |
+
log(f" weight tying confirmed (embed == lm_head): {tied}")
|
| 474 |
+
|
| 475 |
+
def sd_bytes(sd):
|
| 476 |
+
return sum(t.numel() * t.element_size() for t in sd.values())
|
| 477 |
+
|
| 478 |
+
log(f" base weight memory : {fmt_size(sd_bytes(base_sd))}")
|
| 479 |
+
log(f" distilled memory : {fmt_size(sd_bytes(dist_sd))}")
|
| 480 |
+
|
| 481 |
+
# All subsequent tensor ops: move to CPU float32 only during computation,
|
| 482 |
+
# keep storage on GPU in bfloat16.
|
| 483 |
+
all_names = list(dist_sd.keys())
|
| 484 |
+
|
| 485 |
+
# [05] Full per-tensor statistics (distilled)
|
| 486 |
+
log(section_header(5, "Per-tensor weight statistics (distilled)"))
|
| 487 |
+
|
| 488 |
+
col = (
|
| 489 |
+
f" {'Layer':<68} {'Shape':<22} {'Mean':>8} {'Std':>8} "
|
| 490 |
+
f"{'Min':>8} {'Max':>8} {'Sparse':>7} {'KurtD':>7} "
|
| 491 |
+
f"{'OutlR':>7} {'RowL2':>8} {'DeadR':>6}"
|
| 492 |
+
)
|
| 493 |
+
log(col)
|
| 494 |
+
log(f" {divider('-', 170)}")
|
| 495 |
+
|
| 496 |
+
# Helper to calculate kurtosis statistics for base comparison
|
| 497 |
+
all_stats: dict[str, dict] = {}
|
| 498 |
+
type_buckets: dict[str, list[str]] = collections.defaultdict(list)
|
| 499 |
+
|
| 500 |
+
for name in all_names:
|
| 501 |
+
# Move to CPU float32 for stats only
|
| 502 |
+
t = dist_sd[name].cpu()
|
| 503 |
+
st = tensor_stats(t)
|
| 504 |
+
|
| 505 |
+
# Calculate base model kurtosis if present
|
| 506 |
+
if name in base_sd:
|
| 507 |
+
t_base = base_sd[name].cpu()
|
| 508 |
+
st_base = tensor_stats(t_base)
|
| 509 |
+
kurt_base = st_base["kurtosis"]
|
| 510 |
+
else:
|
| 511 |
+
kurt_base = 0.0
|
| 512 |
+
|
| 513 |
+
st["kurtosis_base"] = kurt_base
|
| 514 |
+
st["kurtosis_delta"] = st["kurtosis"] - kurt_base
|
| 515 |
+
|
| 516 |
+
all_stats[name] = st
|
| 517 |
+
type_buckets[classify_layer(name)].append(name)
|
| 518 |
+
|
| 519 |
+
rl2 = st.get("row_l2_mean", float("nan"))
|
| 520 |
+
dead = st.get("dead_rows", float("nan"))
|
| 521 |
+
log(
|
| 522 |
+
f" {name:<68} {str(st['shape']):<22} "
|
| 523 |
+
f"{st['mean']:8.4f} {st['std']:8.4f} "
|
| 524 |
+
f"{st['min']:8.4f} {st['max']:8.4f} "
|
| 525 |
+
f"{st['sparsity']:7.4f} {st['kurtosis_delta']:7.2f} "
|
| 526 |
+
f"{st['outlier_ratio']:7.4f} "
|
| 527 |
+
f"{rl2:8.4f} "
|
| 528 |
+
f"{str(int(dead)) if not math.isnan(dead) else 'N/A':>6}"
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
# [06] Layer-type aggregation (distilled)
|
| 532 |
+
log(section_header(6, "Layer-type aggregated statistics (distilled)"))
|
| 533 |
+
|
| 534 |
+
log(f" {'Type':<18} {'Count':>5} {'Params':>16} {'AvgMean':>9} {'AvgStd':>9} {'AvgSparse':>10} {'AvgKurtD':>9}")
|
| 535 |
+
log(f" {divider('-', 82)}")
|
| 536 |
+
|
| 537 |
+
for ltype in sorted(type_buckets):
|
| 538 |
+
names = type_buckets[ltype]
|
| 539 |
+
n = len(names)
|
| 540 |
+
params = sum(all_stats[x]["numel"] for x in names)
|
| 541 |
+
log(
|
| 542 |
+
f" {ltype:<18} {n:>5} {params:>16,} "
|
| 543 |
+
f"{sum(all_stats[x]['mean'] for x in names)/n:>9.5f} "
|
| 544 |
+
f"{sum(all_stats[x]['std'] for x in names)/n:>9.5f} "
|
| 545 |
+
f"{sum(all_stats[x]['sparsity'] for x in names)/n:>10.5f} "
|
| 546 |
+
f"{sum(all_stats[x]['kurtosis_delta'] for x in names)/n:>9.3f}"
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
# [07] Per-transformer-block breakdown (distilled)
|
| 550 |
+
log(section_header(7, "Per-transformer-block breakdown (distilled)"))
|
| 551 |
+
|
| 552 |
+
n_layers = distilled_config.num_hidden_layers
|
| 553 |
+
sublayer_order = [
|
| 554 |
+
"input_layernorm", "self_attn.q_proj", "self_attn.k_proj",
|
| 555 |
+
"self_attn.v_proj", "self_attn.o_proj", "self_attn.q_norm",
|
| 556 |
+
"self_attn.k_norm", "post_attention_layernorm",
|
| 557 |
+
"mlp.gate_proj", "mlp.up_proj", "mlp.down_proj",
|
| 558 |
+
]
|
| 559 |
+
|
| 560 |
+
log(f" {'Blk':>4} {'Sublayer':<35} {'Shape':<22} {'L2':>9} {'AbsMn':>9} {'Std':>9} {'Sparse':>8} {'RowL2':>9}")
|
| 561 |
+
log(f" {divider('-', 115)}")
|
| 562 |
+
|
| 563 |
+
for blk in range(n_layers):
|
| 564 |
+
prefix = f"model.layers.{blk}."
|
| 565 |
+
for sub in sublayer_order:
|
| 566 |
+
nm = prefix + sub + ".weight"
|
| 567 |
+
if nm not in dist_sd:
|
| 568 |
+
continue
|
| 569 |
+
st = all_stats[nm]
|
| 570 |
+
rl2 = st.get("row_l2_mean", float("nan"))
|
| 571 |
+
log(
|
| 572 |
+
f" {blk:>4} {sub:<35} {str(st['shape']):<22} "
|
| 573 |
+
f"{st['l2_norm']:>9.3f} {st['abs_mean']:>9.5f} "
|
| 574 |
+
f"{st['std']:>9.5f} {st['sparsity']:>8.5f} {rl2:>9.5f}"
|
| 575 |
+
)
|
| 576 |
+
log("")
|
| 577 |
+
|
| 578 |
+
# [08] Isotropy analysis (distilled)
|
| 579 |
+
log(section_header(8, "Isotropy analysis (distilled, 2D tensors only)"))
|
| 580 |
+
log(f" Sampling up to {args.isotropy_samples} rows per layer.")
|
| 581 |
+
log(f" Score near 0 = isotropic (healthy). Score near 1 = representation collapse.")
|
| 582 |
+
log("")
|
| 583 |
+
log(f" {'Layer':<68} {'Shape':<20} {'Score':>10}")
|
| 584 |
+
log(f" {divider('-', 102)}")
|
| 585 |
+
|
| 586 |
+
iso_scores: dict[str, float] = {}
|
| 587 |
+
for name in all_names:
|
| 588 |
+
t = dist_sd[name].cpu()
|
| 589 |
+
iso = isotropy_score(t, n_samples=args.isotropy_samples)
|
| 590 |
+
iso_scores[name] = iso
|
| 591 |
+
if not math.isnan(iso):
|
| 592 |
+
log(f" {name:<68} {str(all_stats[name]['shape']):<20} {iso:>10.6f}")
|
| 593 |
+
|
| 594 |
+
valid_iso = [v for v in iso_scores.values() if not math.isnan(v)]
|
| 595 |
+
if valid_iso:
|
| 596 |
+
log("")
|
| 597 |
+
log(f" Global (across {len(valid_iso)} 2D layers)")
|
| 598 |
+
log(f" mean : {sum(valid_iso)/len(valid_iso):.6f}")
|
| 599 |
+
log(f" min : {min(valid_iso):.6f}")
|
| 600 |
+
log(f" max : {max(valid_iso):.6f}")
|
| 601 |
+
|
| 602 |
+
# [09] Base vs distilled divergence — all shared layers
|
| 603 |
+
log(section_header(9, "Base vs distilled divergence (all shared layers)"))
|
| 604 |
+
|
| 605 |
+
shared = sorted(set(base_sd) & set(dist_sd))
|
| 606 |
+
all_div: dict[str, dict] = {}
|
| 607 |
+
changed = []
|
| 608 |
+
unchanged = []
|
| 609 |
+
|
| 610 |
+
log(f" Shared tensors: {len(shared)}")
|
| 611 |
+
log("")
|
| 612 |
+
log(
|
| 613 |
+
f" {'Layer':<68} {'MaxDelta':>9} {'MeanDelta':>10} "
|
| 614 |
+
f"{'L2Delta':>9} {'CosSim':>8} {'RelErr':>8} {'SNR_dB':>7} {'Chg':>4}"
|
| 615 |
+
)
|
| 616 |
+
log(f" {divider('-', 135)}")
|
| 617 |
+
|
| 618 |
+
for name in shared:
|
| 619 |
+
b = base_sd[name]
|
| 620 |
+
d = dist_sd[name]
|
| 621 |
+
dv = tensor_divergence(b, d)
|
| 622 |
+
all_div[name] = dv
|
| 623 |
+
(changed if dv["changed"] else unchanged).append(name)
|
| 624 |
+
log(
|
| 625 |
+
f" {name:<68} "
|
| 626 |
+
f"{dv['max_delta']:>9.5f} {dv['mean_delta']:>10.6f} "
|
| 627 |
+
f"{dv['l2_delta']:>9.4f} {dv['cos_sim']:>8.5f} "
|
| 628 |
+
f"{dv['rel_err']:>8.5f} {dv['snr_db']:>7.2f} "
|
| 629 |
+
f"{'Y' if dv['changed'] else 'N':>4}"
|
| 630 |
+
)
|
| 631 |
+
|
| 632 |
+
log("")
|
| 633 |
+
log(f" Changed : {len(changed)} / {len(shared)}")
|
| 634 |
+
log(f" Unchanged: {len(unchanged)} / {len(shared)}")
|
| 635 |
+
if unchanged:
|
| 636 |
+
log(f" Unchanged (first 10): {unchanged[:10]}")
|
| 637 |
+
log("\n Note: Unchanged tensors are primarily normalization layers (input_layernorm, q_norm, k_norm, model.norm).")
|
| 638 |
+
log(" This demonstrates that the SFT/KD process modified the primary semantic projection weights")
|
| 639 |
+
log(" (attention and MLP projections) while preserving basic layer scaling characteristics.")
|
| 640 |
+
|
| 641 |
+
# [10] Cosine similarity distribution histogram
|
| 642 |
+
log(section_header(10, "Cosine similarity distribution histogram"))
|
| 643 |
+
|
| 644 |
+
cos_vals = [all_div[n]["cos_sim_raw"] for n in shared]
|
| 645 |
+
bins = [
|
| 646 |
+
(float('-inf'), 0.900),
|
| 647 |
+
(0.900, 0.990),
|
| 648 |
+
(0.990, 0.999),
|
| 649 |
+
(0.999, 0.9999),
|
| 650 |
+
(0.9999, 0.99999),
|
| 651 |
+
(0.99999, 1.00001),
|
| 652 |
+
(1.00001, 1.001),
|
| 653 |
+
(1.001, float('inf'))
|
| 654 |
+
]
|
| 655 |
+
|
| 656 |
+
def fmt_bnd(v: float) -> str:
|
| 657 |
+
if v == float('-inf'):
|
| 658 |
+
return "-inf"
|
| 659 |
+
if v == float('inf'):
|
| 660 |
+
return "inf"
|
| 661 |
+
return f"{v:7.5f}"
|
| 662 |
+
|
| 663 |
+
counts = []
|
| 664 |
+
for lo, hi in bins:
|
| 665 |
+
cnt = sum(1 for v in cos_vals if lo <= v < hi)
|
| 666 |
+
counts.append(cnt)
|
| 667 |
+
max_cnt = max(counts) if counts else 0
|
| 668 |
+
max_bar_width = 40
|
| 669 |
+
|
| 670 |
+
log(f" {'Range':<22} {'Count':>6} Histogram")
|
| 671 |
+
for (lo, hi), cnt in zip(bins, counts):
|
| 672 |
+
bar_len = int(round((cnt / max_cnt) * max_bar_width)) if max_cnt > 0 and cnt > 0 else 0
|
| 673 |
+
label = f"[{fmt_bnd(lo):>8}, {fmt_bnd(hi):>8})"
|
| 674 |
+
log(f" {label:<22} {cnt:>6} {'#' * bar_len}")
|
| 675 |
+
|
| 676 |
+
# [11] Attention geometry per block
|
| 677 |
+
log(section_header(11, "Attention geometry per transformer block"))
|
| 678 |
+
|
| 679 |
+
n_q = distilled_config.num_attention_heads
|
| 680 |
+
n_kv = getattr(distilled_config, "num_key_value_heads", n_q)
|
| 681 |
+
head_dim = distilled_config.hidden_size // n_q
|
| 682 |
+
|
| 683 |
+
log(f" Query heads: {n_q} | KV heads: {n_kv} | head_dim: {head_dim} | GQA: {n_q//n_kv}:1")
|
| 684 |
+
log("")
|
| 685 |
+
log(
|
| 686 |
+
f" {'Blk':>4} {'Q shape':<20} {'K shape':<20} {'V shape':<20} {'O shape':<20} "
|
| 687 |
+
f"{'Q L2':>8} {'K L2':>8} {'V L2':>8} {'O L2':>8}"
|
| 688 |
+
)
|
| 689 |
+
log(f" {divider('-', 130)}")
|
| 690 |
+
|
| 691 |
+
for blk in range(n_layers):
|
| 692 |
+
p = f"model.layers.{blk}.self_attn."
|
| 693 |
+
def attn(key):
|
| 694 |
+
nm = p + key + ".weight"
|
| 695 |
+
if nm in dist_sd:
|
| 696 |
+
st = all_stats[nm]
|
| 697 |
+
return str(st["shape"]), st["l2_norm"]
|
| 698 |
+
return "N/A", float("nan")
|
| 699 |
+
|
| 700 |
+
qs, ql = attn("q_proj")
|
| 701 |
+
ks, kl = attn("k_proj")
|
| 702 |
+
vs, vl = attn("v_proj")
|
| 703 |
+
os_, ol = attn("o_proj")
|
| 704 |
+
log(
|
| 705 |
+
f" {blk:>4} {qs:<20} {ks:<20} {vs:<20} {os_:<20} "
|
| 706 |
+
f"{ql:>8.3f} {kl:>8.3f} {vl:>8.3f} {ol:>8.3f}"
|
| 707 |
+
)
|
| 708 |
+
|
| 709 |
+
# [12] MLP geometry per block
|
| 710 |
+
log(section_header(12, "MLP feed-forward geometry per transformer block"))
|
| 711 |
+
|
| 712 |
+
log(f" intermediate_size: {distilled_config.intermediate_size} | activation: {getattr(distilled_config, 'hidden_act', 'silu')}")
|
| 713 |
+
log("")
|
| 714 |
+
log(
|
| 715 |
+
f" {'Blk':>4} {'Gate shape':<22} {'Up shape':<22} {'Down shape':<22} "
|
| 716 |
+
f"{'Gate L2':>8} {'Up L2':>8} {'Down L2':>9} "
|
| 717 |
+
f"{'GateSp':>8} {'UpSp':>8} {'DnSp':>8}"
|
| 718 |
+
)
|
| 719 |
+
log(f" {divider('-', 135)}")
|
| 720 |
+
|
| 721 |
+
for blk in range(n_layers):
|
| 722 |
+
p = f"model.layers.{blk}.mlp."
|
| 723 |
+
def mlp(key):
|
| 724 |
+
nm = p + key + ".weight"
|
| 725 |
+
if nm in dist_sd:
|
| 726 |
+
st = all_stats[nm]
|
| 727 |
+
return str(st["shape"]), st["l2_norm"], st["sparsity"]
|
| 728 |
+
return "N/A", float("nan"), float("nan")
|
| 729 |
+
|
| 730 |
+
gs, gl, gsp = mlp("gate_proj")
|
| 731 |
+
us, ul, usp = mlp("up_proj")
|
| 732 |
+
ds, dl, dsp = mlp("down_proj")
|
| 733 |
+
log(
|
| 734 |
+
f" {blk:>4} {gs:<22} {us:<22} {ds:<22} "
|
| 735 |
+
f"{gl:>8.3f} {ul:>8.3f} {dl:>9.3f} "
|
| 736 |
+
f"{gsp:>8.5f} {usp:>8.5f} {dsp:>8.5f}"
|
| 737 |
+
)
|
| 738 |
+
|
| 739 |
+
# [13] Health diagnostics
|
| 740 |
+
log(section_header(13, "Weight health diagnostics"))
|
| 741 |
+
|
| 742 |
+
high_sparsity = [(n, all_stats[n]["sparsity"]) for n in all_names if all_stats[n]["sparsity"] > 0.10]
|
| 743 |
+
high_kurtosis = [(n, all_stats[n]["kurtosis_delta"]) for n in all_names if abs(all_stats[n]["kurtosis_delta"]) > 5.0]
|
| 744 |
+
high_outlier = [(n, all_stats[n]["outlier_ratio"]) for n in all_names if all_stats[n]["outlier_ratio"] > 0.01]
|
| 745 |
+
dead_rows = [(n, int(all_stats[n].get("dead_rows", 0))) for n in all_names
|
| 746 |
+
if not math.isnan(all_stats[n].get("dead_rows", float("nan")))
|
| 747 |
+
and all_stats[n].get("dead_rows", 0) > 0]
|
| 748 |
+
low_cos = [(n, all_div[n]["cos_sim"]) for n in shared if all_div[n]["cos_sim"] < 0.95]
|
| 749 |
+
low_snr = [(n, all_div[n]["snr_db"]) for n in shared if all_div[n]["snr_db"] < 20.0]
|
| 750 |
+
|
| 751 |
+
def diag_block(title: str, rows: list, fmt):
|
| 752 |
+
log(f"\n {title}")
|
| 753 |
+
if not rows:
|
| 754 |
+
log(" none")
|
| 755 |
+
else:
|
| 756 |
+
for n, v in rows:
|
| 757 |
+
log(f" {n:<70} {fmt(v)}")
|
| 758 |
+
|
| 759 |
+
def get_percentiles(vals: list[float]) -> dict:
|
| 760 |
+
if not vals:
|
| 761 |
+
return {"mean": 0.0, "median": 0.0, "p10": 0.0, "p90": 0.0}
|
| 762 |
+
t = torch.tensor(vals, dtype=torch.float64)
|
| 763 |
+
return {
|
| 764 |
+
"mean": t.mean().item(),
|
| 765 |
+
"median": t.median().item(),
|
| 766 |
+
"p10": torch.quantile(t, 0.10).item(),
|
| 767 |
+
"p90": torch.quantile(t, 0.90).item(),
|
| 768 |
+
}
|
| 769 |
+
|
| 770 |
+
diag_block("Sparsity > 10%", high_sparsity, lambda v: f"sparsity={v:.5f}")
|
| 771 |
+
diag_block("|Kurtosis Delta| > 5.0", high_kurtosis, lambda v: f"kurt_delta={v:+.3f}")
|
| 772 |
+
diag_block("Outlier ratio > 1%", high_outlier, lambda v: f"outlier_ratio={v:.5f}")
|
| 773 |
+
diag_block("Dead rows (L2 < 1e-6)", dead_rows, lambda v: f"dead_rows={v}")
|
| 774 |
+
diag_block("Low cosine sim vs base (<0.95)", low_cos, lambda v: f"cos_sim={v:.6f}")
|
| 775 |
+
diag_block("Low SNR vs base (< 20 dB)", low_snr, lambda v: f"snr_db={v:.2f}")
|
| 776 |
+
|
| 777 |
+
log("\n Note on kurtosis delta: Kurtosis values are reported as the difference (delta) compared to the base model.")
|
| 778 |
+
log(" A high kurtosis delta on tiny vectors (like norm/q-k-norm vectors of size 128) is statistically expected")
|
| 779 |
+
log(" due to small sample sizes and does not indicate a model health or representation collapse issue.")
|
| 780 |
+
|
| 781 |
+
# [14] Executive summary
|
| 782 |
+
log(section_header(14, "Executive summary"))
|
| 783 |
+
|
| 784 |
+
all_cos = [all_div[n]["cos_sim"] for n in shared]
|
| 785 |
+
all_snr = [all_div[n]["snr_db"] for n in shared]
|
| 786 |
+
all_rel = [all_div[n]["rel_err"] for n in shared]
|
| 787 |
+
|
| 788 |
+
cos_stats = get_percentiles(all_cos)
|
| 789 |
+
snr_stats = get_percentiles(all_snr)
|
| 790 |
+
rel_stats = get_percentiles(all_rel)
|
| 791 |
+
|
| 792 |
+
log(f" shared tensors : {len(shared)}")
|
| 793 |
+
log(f" tensors changed vs base : {len(changed)} / {len(shared)}")
|
| 794 |
+
log(f" cosine similarity : mean = {cos_stats['mean']:.6f} | median = {cos_stats['median']:.6f} | p10 = {cos_stats['p10']:.6f} | p90 = {cos_stats['p90']:.6f}")
|
| 795 |
+
log(f" relative error : mean = {rel_stats['mean']:.6f} | median = {rel_stats['median']:.6f} | p10 = {rel_stats['p10']:.6f} | p90 = {rel_stats['p90']:.6f}")
|
| 796 |
+
log(f" SNR dB : mean = {snr_stats['mean']:.2f} | median = {snr_stats['median']:.2f} | p10 = {snr_stats['p10']:.2f} | p90 = {snr_stats['p90']:.2f}")
|
| 797 |
+
log(f" high-sparsity layers (>10%) : {len(high_sparsity)}")
|
| 798 |
+
log(f" heavy-tail layers (|kurt_d|>5.0) : {len(high_kurtosis)}")
|
| 799 |
+
log(f" dead-row layers : {len(dead_rows)}")
|
| 800 |
+
log(f" low-cos layers (<0.95) : {len(low_cos)}")
|
| 801 |
+
log(f" low-SNR layers (<20 dB) : {len(low_snr)}")
|
| 802 |
+
log(f" distillation alpha : {args.alpha}")
|
| 803 |
+
log("")
|
| 804 |
+
log(f" checkpoint size on disk : {fmt_size(total_ckpt_bytes)}")
|
| 805 |
+
log(f" base weights in memory : {fmt_size(sd_bytes(base_sd))}")
|
| 806 |
+
log(f" distilled weights in memory : {fmt_size(sd_bytes(dist_sd))}")
|
| 807 |
+
log("")
|
| 808 |
+
log(divider("="))
|
| 809 |
+
log(" END OF REPORT")
|
| 810 |
+
log(divider("="))
|
| 811 |
+
|
| 812 |
+
# Write to file
|
| 813 |
+
out = Path(args.output_file)
|
| 814 |
+
out.write_text("\n".join(R) + "\n", encoding="utf-8")
|
| 815 |
+
print(f"\nReport written to: {out.resolve()}")
|
| 816 |
+
|
| 817 |
+
if __name__ == "__main__":
|
| 818 |
+
main()
|
weight_audit/weight_audit_report.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|