"""HPC-grade LoRA SFT for the MicroAgent corpus on Qwen3-4B-Thinking-2507. Designed for single A100-40GB. Target: 4-5 hours / epoch on 26k trajectories. HPC principles applied: - Unsloth kernels (RoPE, RMSNorm, cross-entropy in Triton) -> 2x speed - FlashAttention 2 with variable-length attention -> 5x mem - Sequence packing (FA2 varlen, no cross-attention pollution) -> 3-4x throughput - Selective gradient checkpointing -> 10x activation mem - BF16 native + TF32 matmul -> 25% faster matmul - 8-bit paged AdamW (bitsandbytes) -> 8x optimizer mem - Loss on assistant tokens only (`train_on_responses_only`) -> focused signal - Pre-tokenize once, cache to Arrow -> 5 min -> 5 sec - Adapter-only save -> 100x disk - Throughput telemetry every N steps -> regression alert Usage: python scripts/train_v2.py \\ --model Qwen/Qwen3-4B-Thinking-2507 \\ --data data/microagent_train_v2.jsonl \\ --output-dir runs/qwen3-4b-thinking-microagent-v1 Memory budget on A100-40GB (default flags): base BF16 8.0 GB LoRA r=32 + grads 0.3 GB AdamW8bit state 0.05 GB activations (16k) 12 GB workspace 4 GB ---------------------------- total ~24 GB (headroom: 16 GB) """ from __future__ import annotations import argparse import json import os import time from pathlib import Path def parse_args(): p = argparse.ArgumentParser() p.add_argument("--model", default="Qwen/Qwen3-4B-Thinking-2507", help="HF model id or local path") p.add_argument("--data", default="data/microagent_train_v2.jsonl") p.add_argument("--output-dir", required=True) p.add_argument("--max-seq-len", type=int, default=16384, help="Pack sequences up to this length (16k uses ~12GB activations).") p.add_argument("--cache-dir", type=str, default="data/_tokenized_cache", help="Where pre-tokenized Arrow lives") # LoRA p.add_argument("--lora-rank", type=int, default=32, help="32 is the sweet spot for 4B (50-150MB adapter, 99% of full-FT quality)") p.add_argument("--lora-alpha", type=int, default=64) # 2x rank by Unsloth convention p.add_argument("--lora-dropout", type=float, default=0.0, help="Unsloth fast path requires 0.0; small datasets rarely need dropout anyway") p.add_argument("--lora-target", type=str, default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj", help="All linear layers minus embeddings/lm_head (best LoRA quality/size tradeoff)") # Training p.add_argument("--epochs", type=float, default=1.0, help="1 epoch is standard for SFT on 20k+ samples; 2 risks overfit") p.add_argument("--per-device-batch", type=int, default=1, help="With packing, each 'sample' is already 16k tokens") p.add_argument("--grad-accum", type=int, default=16, help="Effective batch = per_device * grad_accum = 16 sequences (~256k tokens)") p.add_argument("--lr", type=float, default=2e-4, help="2e-4 is Unsloth's recommended LoRA LR; cosine to 0") p.add_argument("--warmup-ratio", type=float, default=0.03) p.add_argument("--weight-decay", type=float, default=0.01) p.add_argument("--lr-scheduler", type=str, default="cosine") p.add_argument("--max-grad-norm", type=float, default=1.0) # Optimizer p.add_argument("--optim", type=str, default="paged_adamw_8bit", help="paged_adamw_8bit = bitsandbytes 8-bit + CPU paging") # Logging p.add_argument("--logging-steps", type=int, default=10) p.add_argument("--save-steps", type=int, default=200) p.add_argument("--eval-steps", type=int, default=200) p.add_argument("--save-total-limit", type=int, default=2) p.add_argument("--max-steps", type=int, default=-1) p.add_argument("--eval-frac", type=float, default=0.01, help="1% holdout = ~266 samples; enough to track loss without wasting compute") # System p.add_argument("--seed", type=int, default=42) p.add_argument("--report-to", type=str, default="none") p.add_argument("--no-packing", action="store_true", help="Disable packing (useful for debugging; usually keep ON)") return p.parse_args() def setup_high_perf_torch(): """Apply free Ampere/Hopper speedups before any model load.""" import torch torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.set_float32_matmul_precision("high") # Enable cuDNN benchmark for fixed-shape kernels (sequence packing makes shapes stable) torch.backends.cudnn.benchmark = True def load_jsonl_as_messages(path: str): examples = [] with open(path, "r", encoding="utf-8") as f: for line in f: row = json.loads(line) convs = row.get("conversations") if not convs: continue examples.append({"messages": convs}) return examples class ThroughputCallback: """Log tokens/sec every N steps so we can detect throughput regression early.""" def __init__(self, max_seq_len: int, log_every: int = 50): self.max_seq_len = max_seq_len self.log_every = log_every self.last_time = None self.last_step = 0 def __call__(self, args, state, control, **kwargs): if state.global_step % self.log_every != 0 or state.global_step == 0: return now = time.time() if self.last_time is None: self.last_time = now self.last_step = state.global_step return dt = now - self.last_time dsteps = state.global_step - self.last_step # Approximate tokens processed: steps * effective_batch * avg_seq # We log a packed-token rate using max_seq_len as the conservative upper bound eff_batch = args.per_device_train_batch_size * args.gradient_accumulation_steps tokens = dsteps * eff_batch * self.max_seq_len tok_per_sec = tokens / max(dt, 1e-9) print(f"[throughput] step={state.global_step} " f"~{tok_per_sec/1000:.1f}k tok/s " f"({dsteps} steps in {dt:.1f}s)") self.last_time = now self.last_step = state.global_step def main(): args = parse_args() os.makedirs(args.output_dir, exist_ok=True) setup_high_perf_torch() # Import Unsloth FIRST — it patches transformers from unsloth import FastLanguageModel from unsloth.chat_templates import train_on_responses_only import torch from datasets import Dataset from trl import SFTTrainer, SFTConfig from transformers import TrainerCallback print(f"[load] data: {args.data}") rows = load_jsonl_as_messages(args.data) print(f"[load] {len(rows)} rows") ds = Dataset.from_list(rows) if args.eval_frac > 0: ds = ds.train_test_split(test_size=args.eval_frac, seed=args.seed) train_ds, eval_ds = ds["train"], ds["test"] print(f"[load] train={len(train_ds)} eval={len(eval_ds)}") else: train_ds, eval_ds = ds, None print(f"[load] base model: {args.model}") # Unsloth's FastLanguageModel loads with patched kernels + FA2 in one call. model, tokenizer = FastLanguageModel.from_pretrained( model_name=args.model, max_seq_length=args.max_seq_len, dtype=torch.bfloat16, load_in_4bit=False, # We have 40GB; full BF16 is faster and higher quality load_in_8bit=False, full_finetuning=False, trust_remote_code=True, ) # Attach LoRA via Unsloth (uses its fused-kernel LoRA path). target_modules = [m.strip() for m in args.lora_target.split(",")] model = FastLanguageModel.get_peft_model( model, r=args.lora_rank, target_modules=target_modules, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout, bias="none", use_gradient_checkpointing="unsloth", # Unsloth's selective ckpt — 30% less mem random_state=args.seed, use_rslora=False, loftq_config=None, ) # ---- Format the dataset using the model's chat template ---- # Qwen3-Thinking expects ChatML <|im_start|>...<|im_end|> with tags in assistant def apply_template(example): text = tokenizer.apply_chat_template( example["messages"], tokenize=False, add_generation_prompt=False, ) return {"text": text} print("[tokenize] applying chat template (cached)...") cache_path = Path(args.cache_dir) / f"v2_{Path(args.model).name.replace('/','_')}" cache_path.parent.mkdir(parents=True, exist_ok=True) train_ds = train_ds.map( apply_template, remove_columns=train_ds.column_names, num_proc=4, load_from_cache_file=True, cache_file_name=str(cache_path / "train.arrow"), ) if eval_ds is not None: eval_ds = eval_ds.map( apply_template, remove_columns=eval_ds.column_names, num_proc=4, load_from_cache_file=True, cache_file_name=str(cache_path / "eval.arrow"), ) # ---- SFTConfig with HPC flags ---- sft_cfg = SFTConfig( output_dir=args.output_dir, num_train_epochs=args.epochs, per_device_train_batch_size=args.per_device_batch, per_device_eval_batch_size=args.per_device_batch, gradient_accumulation_steps=args.grad_accum, learning_rate=args.lr, warmup_ratio=args.warmup_ratio, weight_decay=args.weight_decay, lr_scheduler_type=args.lr_scheduler, max_grad_norm=args.max_grad_norm, logging_steps=args.logging_steps, save_steps=args.save_steps, eval_steps=args.eval_steps if eval_ds else None, eval_strategy="steps" if eval_ds else "no", save_total_limit=args.save_total_limit, bf16=True, fp16=False, gradient_checkpointing=False, # Unsloth handles this internally — DO NOT double-enable max_steps=args.max_steps, seed=args.seed, report_to=args.report_to, max_length=args.max_seq_len, packing=not args.no_packing, packing_strategy="ffd" if not args.no_packing else None, # First-Fit Decreasing — best fill optim=args.optim, dataset_text_field="text", # Don't use TRL's default assistant_only_loss; we'll apply train_on_responses_only ourselves # to ensure correct masking for Qwen3-Thinking template dataloader_num_workers=4, dataloader_pin_memory=True, group_by_length=False, # Off with packing — packing already optimizes layout ) trainer = SFTTrainer( model=model, args=sft_cfg, train_dataset=train_ds, eval_dataset=eval_ds, processing_class=tokenizer, callbacks=[type("ThroughputCb", (TrainerCallback,), { "on_step_end": lambda self, args, state, control, **kw: ThroughputCallback(sft_cfg.max_length, log_every=50)(args, state, control) })()], ) # Loss only on assistant turns (skip user observations + system prompt). # Qwen3-Thinking uses ChatML — assistant block is between <|im_start|>assistant\n and <|im_end|> trainer = train_on_responses_only( trainer, instruction_part="<|im_start|>user\n", response_part="<|im_start|>assistant\n", ) # Print mem + trainable params before training so we know we fit print(f"[mem] GPU peak allocated so far: " f"{torch.cuda.max_memory_allocated()/1e9:.2f} GB") n_train_params = sum(p.numel() for p in model.parameters() if p.requires_grad) n_total_params = sum(p.numel() for p in model.parameters()) print(f"[params] trainable: {n_train_params/1e6:.1f}M / " f"{n_total_params/1e9:.2f}B ({100*n_train_params/n_total_params:.2f}%)") print("[train] starting...") t0 = time.time() trainer.train() dt = time.time() - t0 print(f"[train] complete in {dt/3600:.2f} hr") # Save adapter only (LoRA weights + config + tokenizer) final_path = Path(args.output_dir) / "final" print(f"[save] adapter -> {final_path}") model.save_pretrained(str(final_path)) tokenizer.save_pretrained(str(final_path)) # Also dump a small README with (final_path / "TRAINING_NOTES.md").open("w") as f: f.write(f"# Training run\n\n") f.write(f"- base: {args.model}\n") f.write(f"- data: {args.data} ({len(rows)} trajectories)\n") f.write(f"- epochs: {args.epochs}\n") f.write(f"- effective batch: {args.per_device_batch * args.grad_accum}\n") f.write(f"- max_seq_len: {args.max_seq_len} (packed: {not args.no_packing})\n") f.write(f"- lora: r={args.lora_rank} alpha={args.lora_alpha}\n") f.write(f"- wall time: {dt/3600:.2f} hr\n") f.write(f"- peak GPU mem: {torch.cuda.max_memory_allocated()/1e9:.2f} GB\n") print("[done]") if __name__ == "__main__": main()