# SFT Training and Downstream Evaluation Pipeline
#
# Fine-tunes a causal LM on a pre-tokenized SFT dataset (JSONL, one record per
# line carrying "input_ids" and "loss_mask"), optionally with sequence packing,
# LoRA / QLoRA and 8-bit AdamW, then runs a qualitative prompt suite and a
# GSM8K accuracy check, and can push the result to the Hugging Face Hub.
from __future__ import annotations

import argparse
import gc
import json
import os
import re
import sys
import time
from datetime import datetime
from pathlib import Path

import yaml
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup


# Load Configuration
def load_config() -> dict:
    """Load ``config.yaml`` from this script's directory.

    Returns an empty dict when the file is missing, empty, or its top level is
    not a YAML mapping, so callers can always use ``cfg.get(...)``.
    """
    cfg_path = Path(__file__).resolve().parent / "config.yaml"
    if not cfg_path.exists():
        return {}
    with open(cfg_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # A scalar or list at the top level would break every ``cfg.get`` below.
    return data if isinstance(data, dict) else {}


cfg = load_config()


# PROMPTS (50 PROMPTS)
EASY_PROMPTS = [
    "What is the capital of Japan, and what is it known for?",
    "What does the term 'CPU' stand for, and what is its role in a computer?",
    "Name three mammals that live primarily in water.",
    "What is the difference between a virus and a bacterium?",
    "Convert 72 degrees Fahrenheit to Celsius.",
    "What is the purpose of a hash function?",
    "What does HTTP stand for and what is it used for?",
    "In which continent is the Amazon rainforest located?",
    "What is the difference between RAM and ROM?",
    "Name two programming languages commonly used for data science.",
    "What is the function of the mitochondria in a cell?",
    "What is a palindrome? Give two examples.",
    "What is the difference between a compiler and an interpreter?",
    "What unit is used to measure electrical resistance?",
    "Name the four blood types in the ABO system.",
    "What is the primary purpose of DNS in networking?",
    "What does it mean for a function to be 'pure' in programming?",
]

MEDIUM_PROMPTS = [
    "Explain the difference between supervised and unsupervised learning with a concrete example of each.",
    "Write a Python function that takes a list of integers and returns all pairs that sum to a given target.",
    "Explain how TCP/IP ensures reliable data delivery over an unreliable network.",
    "What are the trade-offs between using a relational database and a document store for a user profile system?",
    "Describe how gradient descent works and explain the role of the learning rate.",
    "Write a SQL query that returns the top 5 customers by total order value, including customers with no orders.",
    "What is the CAP theorem and what does it imply for distributed system design?",
    "Explain the difference between process and thread, including when you would prefer one over the other.",
    "How does HTTPS prevent a man-in-the-middle attack? Walk through the handshake at a high level.",
    "Write a regex that validates an email address and annotate each part of the pattern.",
    "What is the difference between memoization and dynamic programming?",
    "Describe three ways to handle class imbalance in a machine learning dataset.",
    "Explain what a foreign key constraint does and give an example of why it matters.",
    "What is the difference between horizontal and vertical scaling, and when would you choose each?",
    "How does Python's garbage collector handle circular references?",
    "Explain the intuition behind the attention mechanism in Transformer models.",
    "What is a race condition? Write a minimal pseudocode example that demonstrates one.",
]

TOUGH_PROMPTS = [
    "Design a rate limiter for a public API that must handle 100k requests per second across multiple regions. Describe the data structures, algorithms, and infrastructure trade-offs involved.",
    "Explain why training very deep neural networks with sigmoid activations suffers from vanishing gradients. How do residual connections and normalization layers address this, and what are their respective limitations?",
    "A message queue is consuming events from an upstream producer faster than a downstream consumer can process them. The queue is filling up and the producer cannot be slowed down. Describe at least three architectural strategies to resolve this, with trade-offs.",
    "Given an undirected weighted graph, write Python code to find the minimum spanning tree using Kruskal's algorithm. Include the union-find data structure. Analyze time and space complexity.",
    "You are given two sorted arrays of size m and n. Find the median of the combined array in O(log(m+n)) time. Explain the approach before writing the code.",
    "Explain the difference between Byzantine fault tolerance and crash fault tolerance. In what scenario does the distinction become critical, and how does a consensus protocol like PBFT address Byzantine failures?",
    "A large language model fine-tuned on customer service data starts producing confident but factually wrong answers about product details. Propose a complete mitigation strategy covering training, inference, and deployment layers.",
    "Explain the mechanism behind speculative execution in modern CPUs and how it led to the Spectre vulnerability. What classes of software-level mitigations exist and what performance cost do they carry?",
    "Design a schema and indexing strategy for a social graph where you need to efficiently answer: (1) mutual friends between two users, (2) shortest path between two users, (3) top-k most influential accounts. Justify your choices.",
    "Implement a thread-safe LRU cache in Python with O(1) get and put operations. Explain why your synchronization approach is correct and where contention bottlenecks might appear under high concurrency.",
    "Explain the difference between weak, strong, and eventual consistency in distributed databases. Give a concrete example of a bug that arises when a developer assumes strong consistency but the system only guarantees eventual consistency.",
    "You are designing the storage layer for a time-series database that ingests 1 million data points per second and must support range queries going back 2 years. Describe compression strategies, write amplification concerns, and compaction trade-offs.",
    "Explain how LoRA (Low-Rank Adaptation) reduces the number of trainable parameters in fine-tuning. Derive why a weight update matrix can be approximated as a product of two low-rank matrices and discuss what is lost in this approximation.",
    "A binary tree is given where each node has a value. Write an algorithm to find the maximum path sum between any two nodes (not necessarily leaf nodes). Prove the correctness of your recurrence relation.",
    "Explain the economic concept of Goodhart's Law and give three examples of how it manifests in AI system evaluation.",
    "Describe the full lifecycle of a memory allocation in a system using jemalloc or tcmalloc. How do thread-local caches, size classes, and slab allocation interact, and what are the implications for long-running server processes?",
]

# Flattened suite in EASY -> MEDIUM -> TOUGH order, each entry tagged with its tier.
ALL_PROMPTS = [
    {"text": text, "difficulty": difficulty}
    for difficulty, prompts in (
        ("EASY", EASY_PROMPTS),
        ("MEDIUM", MEDIUM_PROMPTS),
        ("TOUGH", TOUGH_PROMPTS),
    )
    for text in prompts
]


# UTILITIES AND DATASET LOADERS
class SFTDataset(Dataset):
    """In-memory map-style dataset over a JSONL file of pre-tokenized samples.

    Every non-blank line is parsed with ``json.loads``; the collators below
    expect each record to carry ``input_ids`` and ``loss_mask`` lists.
    """

    def __init__(self, file_path: str, max_samples: int = -1):
        # max_samples <= 0 means "load everything".
        self.samples = []
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                if 0 < max_samples <= len(self.samples):
                    break
                line = line.strip()
                if not line:
                    # Tolerate blank / trailing lines instead of failing in json.loads.
                    continue
                self.samples.append(json.loads(line))
        print(f"Loaded {len(self.samples)} SFT samples from {file_path}")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> dict:
        return self.samples[idx]


def pack_sequences(samples: list[dict], pack_length: int, pad_token_id: int, eos_token_id: int) -> list[dict]:
    """Pack variable-length samples into fixed-size bins (first-fit decreasing).

    Samples are sorted longest-first and each one goes into the first bin with
    enough room. Consecutive samples in a bin are separated by one EOS token
    whose loss is masked out; samples longer than ``pack_length`` are
    truncated. The caller's sample dicts are not modified.

    NOTE: samples sharing a bin share one attention span: the attention mask
    is all-ones over the real tokens and no per-sample position/attention
    reset is produced, so later samples can attend to earlier ones.

    Args:
        samples: Dicts with ``input_ids`` and ``loss_mask`` lists.
        pack_length: Length of every output bin, in tokens.
        pad_token_id: Token used to fill the unused tail of each bin.
        eos_token_id: Separator token inserted between packed samples.

    Returns:
        One dict per bin with ``input_ids``, ``loss_mask`` and
        ``attention_mask`` long tensors of length ``pack_length``.
    """
    print(f"Packing sequences into {pack_length}-token bins...")

    ordered = sorted(samples, key=lambda x: len(x["input_ids"]), reverse=True)

    bins: list[list[tuple[list[int], list[int]]]] = []
    bin_lengths: list[int] = []
    for sample in ordered:
        # Slicing copies, so truncation never touches the caller's data.
        ids = sample["input_ids"][:pack_length]
        mask = sample["loss_mask"][:pack_length]
        s_len = len(ids)

        for b_idx, current in enumerate(bins):
            needed = s_len + (1 if current else 0)  # +1 for the EOS separator
            if bin_lengths[b_idx] + needed <= pack_length:
                current.append((ids, mask))
                bin_lengths[b_idx] += needed
                break
        else:
            bins.append([(ids, mask)])
            bin_lengths.append(s_len)

    # Convert packed bins to training formats
    packed_samples = []
    for current in bins:
        input_ids: list[int] = []
        loss_mask: list[int] = []
        for i, (ids, mask) in enumerate(current):
            if i > 0:
                input_ids.append(eos_token_id)
                loss_mask.append(0)  # Never train on the synthetic separator
            input_ids.extend(ids)
            loss_mask.extend(mask)

        real_len = len(input_ids)
        pad_len = pack_length - real_len
        input_ids.extend([pad_token_id] * pad_len)
        loss_mask.extend([0] * pad_len)

        packed_samples.append({
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "loss_mask": torch.tensor(loss_mask, dtype=torch.long),
            "attention_mask": torch.cat([
                torch.ones(real_len, dtype=torch.long),
                torch.zeros(pad_len, dtype=torch.long),
            ]),
        })

    capacity = len(bins) * pack_length
    utilization = sum(bin_lengths) / capacity if capacity else 0.0
    print(f"Packed {len(samples)} samples into {len(bins)} bins. Utilization: {utilization * 100:.2f}%")
    return packed_samples


def collate_sft(batch: list[dict], pad_token_id: int, max_seq_len: int | None = None) -> dict:
    """Collate unpacked samples, right-padding the batch to its longest sample.

    Args:
        batch: Dicts with ``input_ids`` and ``loss_mask`` lists.
        pad_token_id: Token id used for right padding.
        max_seq_len: If given, each sample is first truncated to this many
            tokens (mirroring the truncation done by ``pack_sequences``).

    Returns:
        ``input_ids``, ``attention_mask`` and ``labels`` tensors of shape
        ``(len(batch), longest)``; labels are -100 wherever ``loss_mask`` is
        not 1 and on padding, so the loss ignores those positions.
    """
    # ``lst[:None]`` is a full copy, so no truncation happens by default.
    rows = [(s["input_ids"][:max_seq_len], s["loss_mask"][:max_seq_len]) for s in batch]
    longest = max(len(ids) for ids, _ in rows)

    input_ids_list = []
    attention_mask_list = []
    labels_list = []
    for ids, mask in rows:
        pad_len = longest - len(ids)
        input_ids_list.append(ids + [pad_token_id] * pad_len)
        attention_mask_list.append([1] * len(ids) + [0] * pad_len)
        labels_list.append([tok if m == 1 else -100 for tok, m in zip(ids, mask)] + [-100] * pad_len)

    return {
        "input_ids": torch.tensor(input_ids_list, dtype=torch.long),
        "attention_mask": torch.tensor(attention_mask_list, dtype=torch.long),
        "labels": torch.tensor(labels_list, dtype=torch.long),
    }


def collate_packed(batch: list[dict]) -> dict:
    """Collate pre-packed bins by stacking; labels are -100 where loss_mask is 0."""
    input_ids = torch.stack([item["input_ids"] for item in batch])
    attention_mask = torch.stack([item["attention_mask"] for item in batch])
    loss_mask = torch.stack([item["loss_mask"] for item in batch])
    labels = input_ids.clone().masked_fill(loss_mask == 0, -100)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }


# PARSING AND MAIN LOGIC
def parse_args() -> argparse.Namespace:
    """Parse CLI flags; model/tokenizer defaults come from config.yaml's ``model`` section."""
    model_cfg = cfg.get("model")
    if not isinstance(model_cfg, dict):
        # ``model:`` may be absent, null, or malformed in config.yaml.
        model_cfg = {}

    parser = argparse.ArgumentParser(description="Clean SFT training and evaluation suite")
    parser.add_argument("--student_model", type=str, default=model_cfg.get("student", "Qwen/Qwen3-1.7B-Base"))
    parser.add_argument("--tokenizer_model", type=str, default=model_cfg.get("tokenizer", "Qwen/Qwen3-1.7B"))
    parser.add_argument("--data_repo", type=str, default=os.environ.get("QUINTUS_SFT_DATA_REPO"),
                        help="HF dataset repo containing train_sft.jsonl. Optional when data/tokenized/train_sft.jsonl exists.")
    parser.add_argument("--token", type=str, default=None)
    parser.add_argument("--trust_remote_code", action="store_true",
                        help="Allow custom code from model/tokenizer repositories.")
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--micro_batch_size", type=int, default=4)
    parser.add_argument("--grad_accum_steps", type=int, default=2)
    parser.add_argument("--max_seq_len", type=int, default=4096)
    parser.add_argument("--sequence_packing", action="store_true", default=True)
    parser.add_argument("--no_sequence_packing", action="store_false", dest="sequence_packing")
    parser.add_argument("--output_dir", type=str, default="quintus_sft_output")
    parser.add_argument("--run_prompt_suite", action="store_true", default=True)
    parser.add_argument("--no_prompt_suite", action="store_false", dest="run_prompt_suite")
    parser.add_argument("--run_gsm8k", action="store_true", default=True)
    parser.add_argument("--no_gsm8k", action="store_false", dest="run_gsm8k")
    parser.add_argument("--gsm8k_samples", type=int, default=100)
    parser.add_argument("--optim", type=str, choices=["adamw", "adamw_8bit"], default="adamw")
    parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    parser.add_argument("--use_lora", action="store_true", default=False)
    parser.add_argument("--lora_r", type=int, default=8)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--push_to_hub", action="store_true", default=False,
                        help="Automatically push fine-tuned model to Hugging Face Hub after training")
    parser.add_argument("--hub_model_id", type=str, default="iamrahulreddy/Quintus",
                        help="Target Hugging Face Hub repository ID")
    return parser.parse_args()


def download_hf_dataset(repo_id: str | None, token: str | None) -> str:
    """Return the path of ``train_sft.jsonl``, downloading it from the Hub if needed.

    A local ``data/tokenized/train_sft.jsonl`` always wins; otherwise the file
    is fetched from the dataset repo ``repo_id`` into ``data/tokenized``.

    Raises:
        ValueError: no local file exists and no ``repo_id`` was given.
    """
    print("Checking for tokenized dataset in local folders...")
    local_path = "data/tokenized/train_sft.jsonl"
    if os.path.exists(local_path):
        print(f"Found local dataset: {local_path}")
        return local_path

    if not repo_id:
        raise ValueError(
            "No local SFT dataset found at data/tokenized/train_sft.jsonl. "
            "Pass --data_repo or set QUINTUS_SFT_DATA_REPO."
        )

    print(f"Local file not found. Pulling from Hugging Face: {repo_id}...")
    from huggingface_hub import hf_hub_download

    os.makedirs("data/tokenized", exist_ok=True)
    downloaded = hf_hub_download(
        repo_id=repo_id,
        filename="train_sft.jsonl",
        repo_type="dataset",
        local_dir="data/tokenized",
        token=token,
    )

    # Ensure correct local path layout. Compare absolute paths: the returned
    # path may be absolute while local_path is relative to the CWD.
    if os.path.exists(downloaded) and os.path.abspath(downloaded) != os.path.abspath(local_path):
        os.replace(downloaded, local_path)

    print(f"Dataset downloaded to: {local_path}")
    return local_path


# DOWNSTREAM EVALUATION CODE
def _stop_token_ids(tokenizer) -> list[int]:
    """Collect the token ids that should terminate generation.

    Includes the tokenizer's EOS id plus ``<|im_end|>``, ``<|endoftext|>`` and
    ``<|im_start|>`` when the vocabulary resolves them to a real id (ids that
    are None or equal to the unknown-token id are skipped).
    """
    candidates = [tokenizer.eos_token_id]
    for token in ("<|im_end|>", "<|endoftext|>", "<|im_start|>"):
        t_id = tokenizer.convert_tokens_to_ids(token)
        if t_id is not None and t_id != tokenizer.unk_token_id:
            candidates.append(t_id)
    return sorted({t_id for t_id in candidates if t_id is not None})


def _format_chat_prompt(tokenizer, user_text: str) -> str:
    """Render one user turn as a generation-ready prompt string.

    Uses the tokenizer's chat template when one is defined, otherwise a
    hand-written ChatML layout.
    """
    if tokenizer.chat_template is not None:
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": user_text}],
            tokenize=False,
            add_generation_prompt=True,
        )
    return f"<|im_start|>user\n{user_text}<|im_end|>\n<|im_start|>assistant\n"


def run_prompt_suite(model, tokenizer, device, output_dir: str):
    """Generate greedy answers for every prompt in ALL_PROMPTS and save a report.

    Prompts are generated in left-padded batches of 16. Each response is cut at
    its first stop token and appended to a timestamped report in
    ``output_dir`` as soon as it is decoded. The tokenizer's ``padding_side``
    is restored afterwards, even if generation raises.
    """
    total_prompts = len(ALL_PROMPTS)
    print("\n" + "=" * 70)
    print(f"RUNNING QUALITATIVE PROMPT SUITE ({total_prompts} Prompts)")
    print("=" * 70)

    stop_ids = _stop_token_ids(tokenizer)

    # Initialize output file
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, f"prompt_suite_eval_{timestamp}.txt")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write("QUINTUS SFT POST-TRAINING PROMPT SUITE\n")
        f.write(f"Timestamp: {timestamp}\n")
        f.write("=" * 72 + "\n\n")

    # Batched generation needs left padding so every prompt ends exactly where
    # generation starts.
    orig_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    batch_size = 16
    try:
        for start in range(0, total_prompts, batch_size):
            batch_items = ALL_PROMPTS[start : start + batch_size]
            formatted_prompts = [_format_chat_prompt(tokenizer, item["text"]) for item in batch_items]

            inputs = tokenizer(formatted_prompts, padding=True, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=2048,
                    do_sample=False,  # Greedy for clean, reproducible comparison
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=stop_ids,
                    # main() sets config.use_cache = False for training; request
                    # the KV cache explicitly for generation.
                    use_cache=True,
                )

            prompt_len = inputs["input_ids"].shape[1]
            stop_tensor = torch.tensor(stop_ids, dtype=torch.long, device=outputs.device)
            for offset, item in enumerate(batch_items):
                gen_tokens = outputs[offset][prompt_len:]
                # Slice at the first stop token (padding after it is dropped too).
                hits = torch.isin(gen_tokens, stop_tensor).nonzero(as_tuple=True)[0]
                if hits.numel() > 0:
                    gen_tokens = gen_tokens[: hits[0].item()]

                response = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()

                global_idx = start + offset + 1
                print(f"[{global_idx:02d}/{total_prompts}] ({item['difficulty']}) Q: {item['text'][:40]}... -> Answered ({len(gen_tokens)} tokens)")

                # Append directly to output file so partial runs still leave a report.
                with open(out_path, "a", encoding="utf-8") as f:
                    f.write(f"[{global_idx:02d}/{total_prompts}] {item['difficulty']}\n")
                    f.write(f"Q: {item['text']}\n\n")
                    f.write(f"Response:\n{response}\n")
                    f.write("\n" + "-" * 72 + "\n\n")
    finally:
        # Restore original tokenizer settings
        tokenizer.padding_side = orig_padding_side

    print(f"\nPrompt suite evaluation complete. Saved report to: {out_path}\n")


_ANSWER_RE = re.compile(r"The answer is\s*:?\s*\$?\s*(-?\d+(?:\.\d+)?)", re.IGNORECASE)
_NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")


def _normalize_number(num: str) -> str:
    """Canonicalise a numeric string so '18', '18.0' and '18.00' compare equal."""
    if "." in num:
        num = num.rstrip("0").rstrip(".")
    return num


def extract_gsm8k_answer(text: str) -> str | None:
    """Extract the final numeric answer from a model completion.

    Prefers the number after the last "The answer is" (optionally followed by
    ':' and/or '$'); otherwise falls back to the last number in the text.
    Thousands separators are removed and trailing decimal zeros are dropped.

    >>> extract_gsm8k_answer("3 + 4 = 7. The answer is: 7")
    '7'
    >>> extract_gsm8k_answer("Total: $1,250.00")
    '1250'
    """
    text = text.replace(",", "")
    match = _ANSWER_RE.findall(text)
    if match:
        return _normalize_number(match[-1])
    match = _NUMBER_RE.findall(text)
    if match:
        return _normalize_number(match[-1])
    return None


def run_gsm8k_eval(model, tokenizer, device, num_samples: int = 100):
    """Report greedy exact-match accuracy on a seeded sample of the GSM8K test split.

    Samples ``num_samples`` items (shuffle seed 42), prompts the model to end
    with "The answer is: ", and compares ``extract_gsm8k_answer`` output with
    the number after "####" in the reference. Skips (with a warning) when the
    ``datasets`` package or the dataset itself is unavailable.
    """
    print("\n" + "=" * 70)
    print(f"RUNNING GSM8K MATH EVALUATION ({num_samples} Samples)")
    print("=" * 70)

    try:
        from datasets import load_dataset
    except ImportError:
        # Best-effort eval: don't abort the pipeline (and a pending Hub push).
        print("Warning: `datasets` is not installed; skipping GSM8K evaluation.")
        return

    try:
        dataset = load_dataset("openai/gsm8k", "main", split="test")
    except Exception as e:
        print(f"Warning: Could not download GSM8K test set directly: {e}")
        return

    dataset = dataset.shuffle(seed=42).select(range(min(num_samples, len(dataset))))
    stop_ids = _stop_token_ids(tokenizer)

    correct = 0
    total = 0
    for idx, item in enumerate(dataset):
        question = item["question"]
        # Strip thousands separators so targets compare against the
        # comma-stripped predictions.
        target_match = re.search(r"####\s*(-?\d+)", item["answer"].replace(",", ""))
        if not target_match:
            continue
        target_val = target_match.group(1)

        prompt = _format_chat_prompt(
            tokenizer, question + "\nShow your work and conclude with 'The answer is: '."
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1024,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=stop_ids,
                use_cache=True,
            )

        gen_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        generated_text = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
        pred_val = extract_gsm8k_answer(generated_text)

        is_match = pred_val == target_val
        if is_match:
            correct += 1
        total += 1

        # Log sample output periodically
        if idx % 10 == 0:
            print(f"\n[GSM8K Sample {idx+1}]")
            print(f"Q: {question[:80]}...")
            print(f"A: {generated_text[:120]}... (Target: {target_val} | Pred: {pred_val})")
            print(f"Match: {is_match}")

    accuracy = (correct / total * 100) if total > 0 else 0
    print("\n" + "=" * 70)
    print(f"GSM8K EVALUATION SUMMARY: {correct}/{total} Correct -> Accuracy: {accuracy:.2f}%")
    print("=" * 70 + "\n")


# TRAINING PIPELINE
def _resolve_hf_token(cli_token: str | None) -> str | None:
    """Pick the Hugging Face token and export it as ``HF_TOKEN``.

    Precedence: explicit ``--token`` > ``HF_TOKEN`` env var > token returned by
    ``huggingface_hub.get_token()``. Exporting it lets downstream hub calls
    authenticate without passing the token explicitly.
    """
    try:
        import huggingface_hub
        cached_token = huggingface_hub.get_token()
    except Exception:
        # Best-effort: no usable huggingface_hub simply means no cached token.
        cached_token = None

    resolved = cli_token or os.environ.get("HF_TOKEN") or cached_token
    if resolved:
        os.environ["HF_TOKEN"] = resolved
    return resolved


def _build_optimizer(model, args, device):
    """Create the optimizer over the model's trainable parameters only.

    Frozen parameters (LoRA base weights, 4-bit packed weights) never receive
    gradients, and fused AdamW can reject non-floating-point tensors such as
    4-bit packed weights, so they are excluded.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    use_fused = device.type == "cuda"

    if args.optim == "adamw_8bit":
        try:
            import bitsandbytes as bnb
        except ImportError:
            print("Warning: bitsandbytes not installed. Falling back to standard AdamW.")
            return torch.optim.AdamW(params, lr=args.learning_rate, weight_decay=0.1, fused=use_fused)
        optimizer = bnb.optim.AdamW8bit(params, lr=args.learning_rate, weight_decay=0.1)
        print("Using BitsAndBytes 8-bit AdamW optimizer.")
        return optimizer

    optimizer = torch.optim.AdamW(params, lr=args.learning_rate, weight_decay=0.1, fused=use_fused)
    print(f"Using standard AdamW optimizer (fused={use_fused}).")
    return optimizer


def _train(model, train_dataloader, optimizer, scheduler, args, device, total_steps: int) -> float:
    """Run the SFT optimisation loop with gradient accumulation.

    The optimizer/scheduler step every ``args.grad_accum_steps`` micro-batches
    and on the last micro-batch of each epoch; progress is logged every 5
    optimizer steps. Running this in its own function lets the last batch's
    activations/logits go out of scope before evaluation.

    Returns:
        Wall-clock training time in seconds.
    """
    print("\n" + "=" * 70)
    print(f"STARTING SFT TRAINING (Epochs: {args.num_epochs} | Steps: {total_steps})")
    print("=" * 70)

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    num_batches = len(train_dataloader)

    model.train()
    step = 0
    total_tokens_processed = 0
    t0 = time.time()

    for epoch in range(args.num_epochs):
        epoch_loss = 0.0
        for batch_idx, batch in enumerate(train_dataloader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Accumulate the number of active (non-padded) tokens processed
            total_tokens_processed += attention_mask.sum().item()

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            micro_loss = outputs.loss.item()
            # Divide by the accumulation factor so summed micro-batch gradients
            # approximate a mean over the accumulation window.
            (outputs.loss / args.grad_accum_steps).backward()
            epoch_loss += micro_loss

            window_full = (batch_idx + 1) % args.grad_accum_steps == 0
            if window_full or (batch_idx + 1) == num_batches:
                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                step += 1

                if step % 5 == 0 or step == total_steps:
                    elapsed = time.time() - t0
                    tokens_per_sec = total_tokens_processed / max(elapsed, 1e-5)
                    print(
                        f"Epoch {epoch+1}/{args.num_epochs} | "
                        f"Step {step}/{total_steps} | "
                        f"Loss: {micro_loss:.4f} | "
                        f"LR: {scheduler.get_last_lr()[0]:.2e} | "
                        f"Tokens: {total_tokens_processed} | "
                        f"Speed: {tokens_per_sec:.2f} tokens/s"
                    )

        if num_batches:
            print(f"Epoch {epoch+1}/{args.num_epochs} complete | Avg loss: {epoch_loss / num_batches:.4f}")

    return time.time() - t0


def main() -> None:
    """Run the pipeline: data -> model -> SFT training -> save -> evals -> optional Hub push."""
    args = parse_args()

    # Propagate HF token to environment for auto-authentication of downstream hub calls
    args.token = _resolve_hf_token(args.token)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"SFT Environment initialized. Target device: {device}")

    # 1. Pull dataset from HF
    try:
        dataset_file = download_hf_dataset(args.data_repo, args.token)
    except ValueError as exc:
        print(f"Error: {exc}")
        sys.exit(1)

    # 2. Setup Tokenizer and Model
    print(f"Loading tokenizer: {args.tokenizer_model}")
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_model, trust_remote_code=args.trust_remote_code)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 4-bit configuration if requested
    bnb_config = None
    if args.load_in_4bit:
        from transformers import BitsAndBytesConfig
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
        print("Using 4-bit BitsAndBytes quantization.")

    # Liger Kernel (skipped for 4-bit loading as it can interfere with quantized layers)
    if not args.load_in_4bit:
        try:
            from liger_kernel.transformers import apply_liger_kernel_to_qwen3
            apply_liger_kernel_to_qwen3(
                rope=True,
                swiglu=True,
                rms_norm=True,
                cross_entropy=False,
                fused_linear_cross_entropy=False,
            )
            print("Liger Kernel optimizations applied successfully.")
        except ImportError:
            print("Liger Kernel not installed, skipping optimizations.")

    attn_impl = "sdpa"
    if device.type == "cuda":
        try:
            import flash_attn  # noqa: F401  (availability check only)
            attn_impl = "flash_attention_2"
            print("FlashAttention-2 enabled.")
        except ImportError:
            print("flash-attn not installed, falling back to SDPA.")

    model = AutoModelForCausalLM.from_pretrained(
        args.student_model,
        quantization_config=bnb_config,
        dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
        trust_remote_code=args.trust_remote_code,
        attn_implementation=attn_impl,
    )
    if not args.load_in_4bit:
        model = model.to(device)
    model.config.use_cache = False

    # Wrap with LoRA if requested or required for 4-bit training
    if args.use_lora or args.load_in_4bit:
        try:
            from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
        except ImportError:
            print("Error: peft not installed. Please run `!pip install -q peft` to use LoRA/QLoRA.")
            sys.exit(1)
        if args.load_in_4bit:
            model = prepare_model_for_kbit_training(model)
        peft_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, peft_config)
        print("LoRA adapters successfully attached to target modules.")
        model.print_trainable_parameters()

    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        print("Gradient checkpointing enabled.")

    # 3. Prepare dataset
    raw_dataset = SFTDataset(dataset_file)
    if args.sequence_packing:
        packed_samples = pack_sequences(
            raw_dataset.samples,
            pack_length=args.max_seq_len,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        train_dataloader = DataLoader(
            packed_samples,
            batch_size=args.micro_batch_size,
            shuffle=True,
            collate_fn=collate_packed,
        )
    else:
        # Unpacked samples are truncated to max_seq_len just like packed ones.
        train_dataloader = DataLoader(
            raw_dataset,
            batch_size=args.micro_batch_size,
            shuffle=True,
            collate_fn=lambda b: collate_sft(b, tokenizer.pad_token_id, args.max_seq_len),
        )

    # 4. Optimizer and scheduler setup
    optimizer = _build_optimizer(model, args, device)
    steps_per_epoch = (len(train_dataloader) + args.grad_accum_steps - 1) // args.grad_accum_steps
    total_steps = steps_per_epoch * args.num_epochs
    warmup_steps = int(total_steps * 0.05)
    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    # 5. Training Loop
    train_seconds = _train(model, train_dataloader, optimizer, scheduler, args, device, total_steps)

    # Optimizer state is no longer needed; release it (and cached CUDA blocks)
    # before the generation-heavy evaluations.
    del optimizer, scheduler
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

    # 6. Save model weights and tokenizer
    print(f"\nTraining complete in {train_seconds:.1f}s. Saving weights to: {args.output_dir}")
    os.makedirs(args.output_dir, exist_ok=True)

    eval_model = model
    if hasattr(model, "merge_and_unload") and not args.load_in_4bit:
        print("Merging LoRA adapters into base weights...")
        try:
            merged_model = model.merge_and_unload()
            merged_model.save_pretrained(args.output_dir)
            # Evaluate exactly what was saved.
            eval_model = merged_model
            print("Merged model weights saved successfully.")
        except Exception as e:
            print(f"Failed to merge and unload: {e}. Saving adapter weights only.")
            model.save_pretrained(args.output_dir)
    else:
        model.save_pretrained(args.output_dir)

    tokenizer.save_pretrained(args.output_dir)
    print("Weights and configuration saved successfully.")

    # 7. SFT Downstream Evaluations
    eval_model.eval()
    # Training disabled the KV cache in the config; turn it back on for generation.
    eval_model.config.use_cache = True
    if args.run_prompt_suite:
        run_prompt_suite(eval_model, tokenizer, device, args.output_dir)
    if args.run_gsm8k:
        run_gsm8k_eval(eval_model, tokenizer, device, num_samples=args.gsm8k_samples)

    if args.push_to_hub:
        print(f"\nUploading fine-tuned model and tokenizer to Hugging Face Hub: {args.hub_model_id}...")
        try:
            from huggingface_hub import create_repo, HfApi
            token_val = args.token or os.environ.get("HF_TOKEN")
            create_repo(repo_id=args.hub_model_id, token=token_val, exist_ok=True)
            api = HfApi()
            api.upload_folder(
                folder_path=args.output_dir,
                repo_id=args.hub_model_id,
                repo_type="model",
                token=token_val,
            )
            print("Successfully uploaded model and tokenizer to Hugging Face Hub!")
        except Exception as hub_err:
            print(f"Failed to push to Hub: {hub_err}")

    print("Pipeline Execution Complete. Model is ready.")


if __name__ == "__main__":
    main()