#!/usr/bin/env bash
#
# Train a K2.6-matched DFlash drafter using SpecForge.
#
# This script runs three phases:
#   1. Setup: clone SpecForge, prepare dataset
#   2. Regenerate: run PerfectBlend prompts through K2.6 to get
#      target-distribution data
#   3. Train: train 6-layer DFlash drafter on 8x MI300X
#
# Usage:
#   ./train-drafter.sh               # full pipeline
#   ./train-drafter.sh --skip-regen  # skip regeneration (data already exists)
#   ./train-drafter.sh --skip-setup  # skip setup + regen
#
# Prerequisites:
#   - 8x MI300X node with K2.6 model at $MODEL_DIR (set in configs/production.env)
#   - Python 3.10+ with torch, transformers
#   - ~1TB free disk for training data + checkpoints
#
# Environment overrides (default in parentheses):
#   SPECFORGE_DIR           (<script dir>/.specforge)
#   OUTPUT_DIR              (<script dir>/outputs/kimi-k2.6-dflash)
#   NUM_EPOCHS (6)  BATCH_SIZE (2)  LR (6e-4)  MAX_LENGTH (3072)
#   BLOCK_SIZE_TRAIN (16)  CONCURRENCY (128)
#   SGLANG_STARTUP_TIMEOUT  (1800) seconds to wait for the SGLang server's
#                           /health endpoint before giving up

set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# shellcheck source=/dev/null
source "$SCRIPT_DIR/configs/production.env"

# MODEL_DIR comes from production.env; fail fast rather than after cloning
# and installing.
: "${MODEL_DIR:?MODEL_DIR must be set (expected in configs/production.env)}"

SPECFORGE_DIR="${SPECFORGE_DIR:-$SCRIPT_DIR/.specforge}"
OUTPUT_DIR="${OUTPUT_DIR:-$SCRIPT_DIR/outputs/kimi-k2.6-dflash}"
# The phases cd into SPECFORGE_DIR, so relative overrides must be anchored to
# the invocation directory now or they would resolve against the wrong place.
[[ "$SPECFORGE_DIR" == /* ]] || SPECFORGE_DIR="$PWD/$SPECFORGE_DIR"
[[ "$OUTPUT_DIR" == /* ]] || OUTPUT_DIR="$PWD/$OUTPUT_DIR"

DRAFT_CONFIG="$SCRIPT_DIR/configs/kimi-k2.6-dflash-draft.json"
REGEN_DATA="$SPECFORGE_DIR/cache/dataset/perfectblend_kimi-k2.6_regen.jsonl"
SGLANG_PORT=30000
SGLANG_STARTUP_TIMEOUT="${SGLANG_STARTUP_TIMEOUT:-1800}"
# PID of the background SGLang server; empty when none is running.
SGLANG_PID=""

NUM_EPOCHS="${NUM_EPOCHS:-6}"
BATCH_SIZE="${BATCH_SIZE:-2}"
LR="${LR:-6e-4}"
MAX_LENGTH="${MAX_LENGTH:-3072}"
BLOCK_SIZE_TRAIN="${BLOCK_SIZE_TRAIN:-16}"
CONCURRENCY="${CONCURRENCY:-128}"

SKIP_SETUP=false
SKIP_REGEN=false

# Print an error to stderr and exit non-zero.
die() {
  printf 'ERROR: %s\n' "$*" >&2
  exit 1
}

# Print the accepted command-line flags.
usage() {
  printf 'Usage: %s [--skip-setup] [--skip-regen]\n' "${0##*/}"
}

# Set SKIP_SETUP / SKIP_REGEN from the command line; reject unknown flags so a
# typo cannot silently trigger hours of regeneration.
parse_args() {
  local arg
  for arg in "$@"; do
    case "$arg" in
      --skip-setup) SKIP_SETUP=true; SKIP_REGEN=true ;;
      --skip-regen) SKIP_REGEN=true ;;
      -h | --help) usage; exit 0 ;;
      *) usage >&2; die "unknown argument: $arg" ;;
    esac
  done
}

# Return 0 if something answers the SGLang health endpoint on SGLANG_PORT.
sglang_healthy() {
  curl -sf "http://127.0.0.1:${SGLANG_PORT}/health" >/dev/null 2>&1
}

# Launch the K2.6 SGLang server in the background and record its PID in
# SGLANG_PID. Must be called from inside SPECFORGE_DIR (uses its .venv).
start_sglang_server() {
  # A stale server on the port would satisfy the health probe while the new
  # one fails to bind, so refuse to continue in that case.
  if sglang_healthy; then
    die "a server is already responding on port $SGLANG_PORT; stop it first."
  fi
  .venv/bin/python3 -m sglang.launch_server \
    --model "$MODEL_DIR" \
    --mem-frac 0.85 \
    --tp 8 \
    --trust-remote-code \
    --port "$SGLANG_PORT" &
  SGLANG_PID=$!
}

# Poll the health endpoint every 5s until it answers, the server process
# exits, or SGLANG_STARTUP_TIMEOUT seconds elapse (the latter two are fatal).
# Arguments: $1 - message to print once the server is ready.
wait_for_sglang_server() {
  local ready_msg=$1
  local deadline=$((SECONDS + SGLANG_STARTUP_TIMEOUT))
  while ((SECONDS < deadline)); do
    if sglang_healthy; then
      echo "$ready_msg"
      return 0
    fi
    if ! kill -0 "$SGLANG_PID" 2>/dev/null; then
      die "SGLang server died during startup."
    fi
    sleep 5
  done
  die "SGLang server not healthy after ${SGLANG_STARTUP_TIMEOUT}s."
}

# Terminate and reap the background SGLang server, if one is running.
# Safe to call more than once; also installed as the EXIT trap.
stop_sglang_server() {
  [[ -n "$SGLANG_PID" ]] || return 0
  echo "Killing SGLang server..."
  kill "$SGLANG_PID" 2>/dev/null || true
  wait "$SGLANG_PID" 2>/dev/null || true
  SGLANG_PID=""
}

# cd into the SpecForge checkout and make sure its venv exists.
enter_specforge_dir() {
  cd "$SPECFORGE_DIR" || die "SpecForge not found at $SPECFORGE_DIR (run without --skip-setup)"
  [[ -x .venv/bin/python3 ]] || die "SpecForge venv missing at $SPECFORGE_DIR/.venv (run without --skip-setup)"
}

# ---------- Phase 1: Setup ----------
run_setup() {
  echo "=== Phase 1: Setup ==="

  if [[ ! -d "$SPECFORGE_DIR" ]]; then
    echo "Cloning SpecForge..."
    git clone https://github.com/sgl-project/SpecForge.git "$SPECFORGE_DIR"
  else
    echo "SpecForge already cloned at $SPECFORGE_DIR"
  fi

  cd "$SPECFORGE_DIR"

  # The marker is written only after pip succeeds, so an interrupted or failed
  # install is retried on the next run instead of being skipped forever.
  local install_marker=.venv/.train-deps-installed
  if [[ ! -f "$install_marker" ]]; then
    echo "Creating venv and installing dependencies..."
    [[ -d .venv ]] || python3 -m venv .venv
    .venv/bin/pip install -e ".[train]"
    touch "$install_marker"
  fi

  echo "Preparing PerfectBlend dataset..."
  .venv/bin/python scripts/prepare_data.py --dataset perfectblend

  echo "Phase 1 complete."
}

# ---------- Phase 2: Regenerate with K2.6 ----------
run_regen() {
  echo ""
  echo "=== Phase 2: Regenerate dataset with K2.6 ==="

  enter_specforge_dir

  echo "Starting K2.6 SGLang server on port $SGLANG_PORT..."
  start_sglang_server
  echo "Waiting for SGLang server (PID $SGLANG_PID)..."
  wait_for_sglang_server "SGLang server ready."

  echo "Regenerating dataset (this takes hours)..."
  .venv/bin/python scripts/regenerate_train_data.py \
    --model "$MODEL_DIR" \
    --is-reasoning-model \
    --concurrency "$CONCURRENCY" \
    --max-tokens 4096 \
    --temperature 0.8 \
    --server-address "localhost:$SGLANG_PORT" \
    --input-file-path ./cache/dataset/perfectblend_train.jsonl \
    --output-file-path "$REGEN_DATA" \
    --resume

  stop_sglang_server

  local regen_count
  regen_count=$(wc -l <"$REGEN_DATA")
  echo "Phase 2 complete. Regenerated $regen_count samples."
}

# ---------- Phase 3: Train DFlash drafter ----------
run_train() {
  echo ""
  echo "=== Phase 3: Train K2.6 DFlash drafter ==="
  echo "  Epochs:     $NUM_EPOCHS"
  echo "  Batch size: $BATCH_SIZE"
  echo "  LR:         $LR"
  echo "  Max length: $MAX_LENGTH"
  echo "  Block size: $BLOCK_SIZE_TRAIN"
  echo "  Output:     $OUTPUT_DIR"

  enter_specforge_dir

  # Check inputs before paying for a multi-minute server start.
  [[ -s "$REGEN_DATA" ]] || die "training data missing or empty: $REGEN_DATA (run without --skip-regen)"

  # SGLang backend for online hidden state extraction
  echo "Starting K2.6 SGLang server for training..."
  start_sglang_server
  echo "Waiting for SGLang server (PID $SGLANG_PID)..."
  wait_for_sglang_server "SGLang server ready for training."

  # Run torchrun through the venv interpreter (python -m torch.distributed.run
  # is torchrun's module form) so training sees the same environment SpecForge
  # was installed into, rather than whatever torchrun happens to be on PATH.
  .venv/bin/python -m torch.distributed.run --standalone --nproc_per_node 8 \
    scripts/train_dflash.py \
    --target-model-path "$MODEL_DIR" \
    --target-model-backend sglang \
    --draft-config-path "$DRAFT_CONFIG" \
    --train-data-path "$REGEN_DATA" \
    --output-dir "$OUTPUT_DIR" \
    --num-epochs "$NUM_EPOCHS" \
    --batch-size "$BATCH_SIZE" \
    --learning-rate "$LR" \
    --warmup-ratio 0.04 \
    --max-grad-norm 1.0 \
    --max-length "$MAX_LENGTH" \
    --chat-template qwen \
    --attention-backend flex_attention \
    --num-anchors 512 \
    --loss-decay-gamma 7.0 \
    --log-interval 50 \
    --save-interval 1000 \
    --block-size "$BLOCK_SIZE_TRAIN" \
    --trust-remote-code

  stop_sglang_server

  echo ""
  echo "=== Training complete ==="
  echo "Drafter model saved to: $OUTPUT_DIR"
  echo ""
  echo "To serve with the new drafter:"
  echo "  Edit configs/production-fp8kv.env:"
  echo "    DRAFT_MODEL_DIR=$OUTPUT_DIR"
  echo "    NUM_SPECULATIVE_TOKENS=8  # matched drafter can handle more"
  echo "  Then: ./serve.sh configs/production-fp8kv.env"
}

main() {
  parse_args "$@"

  # Always tear down the background server, including on failure (set -e) and
  # on Ctrl-C: background jobs in a non-interactive shell ignore SIGINT, so
  # without this the server would keep holding all 8 GPUs.
  trap stop_sglang_server EXIT
  trap 'exit 130' INT
  trap 'exit 143' TERM

  if [[ "$SKIP_SETUP" == false ]]; then
    run_setup
  fi
  if [[ "$SKIP_REGEN" == false ]]; then
    run_regen
  fi
  run_train
}

main "$@"