Text Generation
Transformers
Safetensors
PyTorch
English
qwen3
qwen
qwen3-1.7b
qwen3-8b
quintus
quintus-1.7b
causal-lm
language-model
chat
assistant
compact-llm
small-language-model
knowledge-distillation
online-kd
full-vocabulary-kd
supervised-fine-tuning
sft
reasoning
code-generation
english
vllm
conversational
text-generation-inference
Instructions to use iamrahulreddy/Quintus with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use iamrahulreddy/Quintus with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="iamrahulreddy/Quintus")
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)

# Load model directly
from transformers import AutoTokenizer, AutoModelForMultimodalLM

tokenizer = AutoTokenizer.from_pretrained("iamrahulreddy/Quintus")
model = AutoModelForMultimodalLM.from_pretrained("iamrahulreddy/Quintus")
messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:]))
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use iamrahulreddy/Quintus with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "iamrahulreddy/Quintus"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "iamrahulreddy/Quintus",
    "messages": [
      { "role": "user", "content": "What is the capital of France?" }
    ]
  }'
Use Docker
docker model run hf.co/iamrahulreddy/Quintus
- SGLang
How to use iamrahulreddy/Quintus with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "iamrahulreddy/Quintus" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "iamrahulreddy/Quintus",
    "messages": [
      { "role": "user", "content": "What is the capital of France?" }
    ]
  }'
Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "iamrahulreddy/Quintus" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "iamrahulreddy/Quintus",
    "messages": [
      { "role": "user", "content": "What is the capital of France?" }
    ]
  }'
- Docker Model Runner
How to use iamrahulreddy/Quintus with Docker Model Runner:
docker model run hf.co/iamrahulreddy/Quintus
| # SFT Training and Downstream Evaluation Pipeline | |
| from __future__ import annotations | |
| import argparse | |
| import gc | |
| import json | |
| import os | |
| import re | |
| import sys | |
| import time | |
| from datetime import datetime | |
| from pathlib import Path | |
| import yaml | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader, Dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup | |
| # Load Configuration | |
def load_config() -> dict:
    """Load optional settings from ``config.yaml`` located next to this file.

    Returns:
        The parsed YAML mapping. An empty dict is returned when the file is
        missing, empty, or when its top-level node is not a mapping.
    """
    cfg_path = Path(__file__).resolve().parent / "config.yaml"
    if not cfg_path.exists():
        return {}
    with open(cfg_path, "r", encoding="utf-8") as f:
        loaded = yaml.safe_load(f)
    # Consumers call ``cfg.get(...)``, so a YAML list or scalar at the root
    # would only fail later with an unrelated AttributeError. Reject it here.
    if not isinstance(loaded, dict):
        if loaded is not None:
            print(
                f"Warning: ignoring {cfg_path}: expected a mapping at the top level, "
                f"got {type(loaded).__name__}."
            )
        return {}
    return loaded


# Module-level configuration, read once at import time.
cfg = load_config()
| # PROMPTS (50 PROMPTS) | |
# Short factual questions with a single, well-known answer.
EASY_PROMPTS = [
    "What is the capital of Japan, and what is it known for?",
    "What does the term 'CPU' stand for, and what is its role in a computer?",
    "Name three mammals that live primarily in water.",
    "What is the difference between a virus and a bacterium?",
    "Convert 72 degrees Fahrenheit to Celsius.",
    "What is the purpose of a hash function?",
    "What does HTTP stand for and what is it used for?",
    "In which continent is the Amazon rainforest located?",
    "What is the difference between RAM and ROM?",
    "Name two programming languages commonly used for data science.",
    "What is the function of the mitochondria in a cell?",
    "What is a palindrome? Give two examples.",
    "What is the difference between a compiler and an interpreter?",
    "What unit is used to measure electrical resistance?",
    "Name the four blood types in the ABO system.",
    "What is the primary purpose of DNS in networking?",
    "What does it mean for a function to be 'pure' in programming?",
]

# Explanations, comparisons and small coding tasks.
MEDIUM_PROMPTS = [
    "Explain the difference between supervised and unsupervised learning with a concrete example of each.",
    "Write a Python function that takes a list of integers and returns all pairs that sum to a given target.",
    "Explain how TCP/IP ensures reliable data delivery over an unreliable network.",
    "What are the trade-offs between using a relational database and a document store for a user profile system?",
    "Describe how gradient descent works and explain the role of the learning rate.",
    "Write a SQL query that returns the top 5 customers by total order value, including customers with no orders.",
    "What is the CAP theorem and what does it imply for distributed system design?",
    "Explain the difference between process and thread, including when you would prefer one over the other.",
    "How does HTTPS prevent a man-in-the-middle attack? Walk through the handshake at a high level.",
    "Write a regex that validates an email address and annotate each part of the pattern.",
    "What is the difference between memoization and dynamic programming?",
    "Describe three ways to handle class imbalance in a machine learning dataset.",
    "Explain what a foreign key constraint does and give an example of why it matters.",
    "What is the difference between horizontal and vertical scaling, and when would you choose each?",
    "How does Python's garbage collector handle circular references?",
    "Explain the intuition behind the attention mechanism in Transformer models.",
    "What is a race condition? Write a minimal pseudocode example that demonstrates one.",
]

# Open-ended design, algorithm and deep-dive questions.
TOUGH_PROMPTS = [
    "Design a rate limiter for a public API that must handle 100k requests per second across multiple regions. Describe the data structures, algorithms, and infrastructure trade-offs involved.",
    "Explain why training very deep neural networks with sigmoid activations suffers from vanishing gradients. How do residual connections and normalization layers address this, and what are their respective limitations?",
    "A message queue is consuming events from an upstream producer faster than a downstream consumer can process them. The queue is filling up and the producer cannot be slowed down. Describe at least three architectural strategies to resolve this, with trade-offs.",
    "Given an undirected weighted graph, write Python code to find the minimum spanning tree using Kruskal's algorithm. Include the union-find data structure. Analyze time and space complexity.",
    "You are given two sorted arrays of size m and n. Find the median of the combined array in O(log(m+n)) time. Explain the approach before writing the code.",
    "Explain the difference between Byzantine fault tolerance and crash fault tolerance. In what scenario does the distinction become critical, and how does a consensus protocol like PBFT address Byzantine failures?",
    "A large language model fine-tuned on customer service data starts producing confident but factually wrong answers about product details. Propose a complete mitigation strategy covering training, inference, and deployment layers.",
    "Explain the mechanism behind speculative execution in modern CPUs and how it led to the Spectre vulnerability. What classes of software-level mitigations exist and what performance cost do they carry?",
    "Design a schema and indexing strategy for a social graph where you need to efficiently answer: (1) mutual friends between two users, (2) shortest path between two users, (3) top-k most influential accounts. Justify your choices.",
    "Implement a thread-safe LRU cache in Python with O(1) get and put operations. Explain why your synchronization approach is correct and where contention bottlenecks might appear under high concurrency.",
    "Explain the difference between weak, strong, and eventual consistency in distributed databases. Give a concrete example of a bug that arises when a developer assumes strong consistency but the system only guarantees eventual consistency.",
    "You are designing the storage layer for a time-series database that ingests 1 million data points per second and must support range queries going back 2 years. Describe compression strategies, write amplification concerns, and compaction trade-offs.",
    "Explain how LoRA (Low-Rank Adaptation) reduces the number of trainable parameters in fine-tuning. Derive why a weight update matrix can be approximated as a product of two low-rank matrices and discuss what is lost in this approximation.",
    "A binary tree is given where each node has a value. Write an algorithm to find the maximum path sum between any two nodes (not necessarily leaf nodes). Prove the correctness of your recurrence relation.",
    "Explain the economic concept of Goodhart's Law and give three examples of how it manifests in AI system evaluation.",
    "Describe the full lifecycle of a memory allocation in a system using jemalloc or tcmalloc. How do thread-local caches, size classes, and slab allocation interact, and what are the implications for long-running server processes?",
]

# Flat list of {"text", "difficulty"} records in EASY -> MEDIUM -> TOUGH order.
# Built with a comprehension so no loop variable leaks into module globals.
ALL_PROMPTS = [
    {"text": prompt_text, "difficulty": difficulty}
    for difficulty, prompt_group in (
        ("EASY", EASY_PROMPTS),
        ("MEDIUM", MEDIUM_PROMPTS),
        ("TOUGH", TOUGH_PROMPTS),
    )
    for prompt_text in prompt_group
]
| # UTILITIES AND DATASET LOADERS | |
class SFTDataset(Dataset):
    """In-memory dataset of SFT records read from a JSONL file.

    Every non-blank line is parsed with ``json.loads`` and stored unchanged in
    ``self.samples``. The rest of this file reads the ``"input_ids"`` and
    ``"loss_mask"`` keys of each record.

    Args:
        file_path: Path to the JSONL file.
        max_samples: Maximum number of records to load. Values ``<= 0`` load
            the whole file.
    """

    def __init__(self, file_path: str, max_samples: int = -1):
        self.samples = []
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                if 0 < max_samples <= len(self.samples):
                    break
                # Blank lines (typically a trailing newline at the end of the
                # file) are not records and would make json.loads raise.
                if not line.strip():
                    continue
                self.samples.append(json.loads(line))
        print(f"Loaded {len(self.samples)} SFT samples from {file_path}")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> dict:
        return self.samples[idx]
def pack_sequences(samples: list[dict], pack_length: int, pad_token_id: int, eos_token_id: int) -> list[dict]:
    """Pack tokenized samples into fixed-size bins (first-fit decreasing).

    Samples are sorted by length (longest first) and each one is placed into
    the first bin that still has room, with one separator slot reserved between
    neighbours. Samples longer than ``pack_length`` are truncated. The caller's
    sample dicts are never modified.

    Args:
        samples: Records with ``"input_ids"`` and ``"loss_mask"`` sequences.
        pack_length: Number of tokens in every output bin.
        pad_token_id: Token id used to fill the unused tail of a bin.
        eos_token_id: Token id inserted between samples sharing a bin.

    Returns:
        One dict per bin with ``"input_ids"``, ``"loss_mask"`` and
        ``"attention_mask"`` ``torch.long`` tensors of length ``pack_length``.
        Separator and padding positions have ``loss_mask == 0``; only padding
        positions have ``attention_mask == 0``. Empty input yields ``[]``.
    """
    print(f"Packing sequences into {pack_length}-token bins...")
    if not samples:
        print("Packed 0 samples into 0 bins. Utilization: 0.00%")
        return []

    # Longest first; ``sorted`` is stable so ties keep their input order.
    ordered = sorted(samples, key=lambda x: len(x["input_ids"]), reverse=True)

    bins: list[list[tuple]] = []
    bin_lengths: list[int] = []
    for sample in ordered:
        ids = sample["input_ids"]
        mask = sample["loss_mask"]
        if len(ids) > pack_length:
            # Slice into new lists instead of writing back into ``sample`` so
            # the caller's dataset keeps its full-length sequences.
            ids = ids[:pack_length]
            mask = mask[:pack_length]
        s_len = len(ids)

        # Bins are created with one sample, so joining an existing bin always
        # costs the sample plus one EOS separator.
        needed = s_len + 1
        for b_idx, used in enumerate(bin_lengths):
            if used + needed <= pack_length:
                bins[b_idx].append((ids, mask))
                bin_lengths[b_idx] = used + needed
                break
        else:
            bins.append([(ids, mask)])
            bin_lengths.append(s_len)

    # Convert packed bins to training tensors.
    packed_samples = []
    for members in bins:
        input_ids: list[int] = []
        loss_mask: list[int] = []
        for position, (ids, mask) in enumerate(members):
            if position > 0:
                input_ids.append(eos_token_id)
                loss_mask.append(0)  # The separator is never a training target.
            input_ids.extend(ids)
            loss_mask.extend(mask)

        real_len = len(input_ids)
        pad_len = pack_length - real_len
        if pad_len > 0:
            input_ids.extend([pad_token_id] * pad_len)
            loss_mask.extend([0] * pad_len)

        packed_samples.append({
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "loss_mask": torch.tensor(loss_mask, dtype=torch.long),
            "attention_mask": torch.cat([
                torch.ones(real_len, dtype=torch.long),
                torch.zeros(pad_len, dtype=torch.long),
            ]),
        })

    capacity = len(bins) * pack_length
    utilization = sum(bin_lengths) / capacity if capacity else 0.0
    print(f"Packed {len(samples)} samples into {len(bins)} bins. Utilization: {utilization * 100:.2f}%")
    return packed_samples
def collate_sft(batch: list[dict], pad_token_id: int) -> dict:
    """Collate unpacked samples, right-padding each one to the batch maximum.

    Args:
        batch: Records with ``"input_ids"`` and ``"loss_mask"`` sequences of
            equal length.
        pad_token_id: Token id used to pad ``input_ids``.

    Returns:
        Dict of ``torch.long`` tensors shaped ``(len(batch), max_len)``:
        ``"input_ids"``, ``"attention_mask"`` (1 on real tokens, 0 on padding)
        and ``"labels"`` (the token id where ``loss_mask == 1``, otherwise
        ``-100``).

    Raises:
        ValueError: If a record's ``loss_mask`` and ``input_ids`` differ in
            length.
    """
    max_len = max(len(s["input_ids"]) for s in batch)

    id_rows = []
    attention_rows = []
    label_rows = []
    for s in batch:
        ids = list(s["input_ids"])
        mask = s["loss_mask"]
        if len(mask) != len(ids):
            raise ValueError(
                f"loss_mask length ({len(mask)}) does not match input_ids length ({len(ids)})"
            )
        pad_len = max_len - len(ids)

        id_rows.append(ids + [pad_token_id] * pad_len)
        attention_rows.append([1] * len(ids) + [0] * pad_len)
        # -100 marks positions excluded from the loss (masked tokens and padding).
        label_rows.append(
            [tok if keep == 1 else -100 for tok, keep in zip(ids, mask)] + [-100] * pad_len
        )

    return {
        "input_ids": torch.tensor(id_rows, dtype=torch.long),
        "attention_mask": torch.tensor(attention_rows, dtype=torch.long),
        "labels": torch.tensor(label_rows, dtype=torch.long),
    }
def collate_packed(batch: list[dict]) -> dict:
    """Stack pre-packed bins into batch tensors and derive the labels.

    Each item provides equally sized ``"input_ids"``, ``"attention_mask"`` and
    ``"loss_mask"`` tensors. Labels equal ``input_ids`` with every position
    whose ``loss_mask`` is 0 replaced by ``-100``.
    """
    stacked = {
        field: torch.stack([entry[field] for entry in batch])
        for field in ("input_ids", "attention_mask", "loss_mask")
    }
    ignore_positions = stacked["loss_mask"] == 0
    # Out-of-place masked_fill leaves the stacked input_ids untouched.
    targets = stacked["input_ids"].masked_fill(ignore_positions, -100)
    return {
        "input_ids": stacked["input_ids"],
        "attention_mask": stacked["attention_mask"],
        "labels": targets,
    }
| # PARSING AND MAIN LOGIC | |
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the CLI parser and parse the command line.

    Defaults for ``--student_model`` and ``--tokenizer_model`` come from the
    ``model`` section of the module-level ``cfg`` mapping when present, and the
    default for ``--data_repo`` comes from the ``QUINTUS_SFT_DATA_REPO``
    environment variable.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``.

    Returns:
        The parsed arguments.
    """
    # ``model:`` with no children parses to None in YAML, and ``dict.get``
    # returns that None rather than the fallback, so normalise explicitly.
    model_cfg = cfg.get("model")
    if not isinstance(model_cfg, dict):
        model_cfg = {}

    parser = argparse.ArgumentParser(description="Clean SFT training and evaluation suite")
    parser.add_argument("--student_model", type=str, default=model_cfg.get("student", "Qwen/Qwen3-1.7B-Base"))
    parser.add_argument("--tokenizer_model", type=str, default=model_cfg.get("tokenizer", "Qwen/Qwen3-1.7B"))
    parser.add_argument("--data_repo", type=str, default=os.environ.get("QUINTUS_SFT_DATA_REPO"), help="HF dataset repo containing train_sft.jsonl. Optional when data/tokenized/train_sft.jsonl exists.")
    parser.add_argument("--token", type=str, default=None)
    parser.add_argument("--trust_remote_code", action="store_true", help="Allow custom code from model/tokenizer repositories.")
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--micro_batch_size", type=int, default=4)
    parser.add_argument("--grad_accum_steps", type=int, default=2)
    parser.add_argument("--max_seq_len", type=int, default=4096)
    # The features below default to on; the paired ``--no_*`` flag turns each off.
    parser.add_argument("--sequence_packing", action="store_true", default=True)
    parser.add_argument("--no_sequence_packing", action="store_false", dest="sequence_packing")
    parser.add_argument("--output_dir", type=str, default="quintus_sft_output")
    parser.add_argument("--run_prompt_suite", action="store_true", default=True)
    parser.add_argument("--no_prompt_suite", action="store_false", dest="run_prompt_suite")
    parser.add_argument("--run_gsm8k", action="store_true", default=True)
    parser.add_argument("--no_gsm8k", action="store_false", dest="run_gsm8k")
    parser.add_argument("--gsm8k_samples", type=int, default=100)
    parser.add_argument("--optim", type=str, choices=["adamw", "adamw_8bit"], default="adamw")
    parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    parser.add_argument("--use_lora", action="store_true", default=False)
    parser.add_argument("--lora_r", type=int, default=8)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--push_to_hub", action="store_true", default=False, help="Automatically push fine-tuned model to Hugging Face Hub after training")
    parser.add_argument("--hub_model_id", type=str, default="iamrahulreddy/Quintus", help="Target Hugging Face Hub repository ID")
    return parser.parse_args(argv)
def download_hf_dataset(repo_id: str | None, token: str | None) -> str:
    """Return the path of the tokenized SFT file, downloading it if needed.

    An existing ``data/tokenized/train_sft.jsonl`` is used as-is. Otherwise
    ``train_sft.jsonl`` is fetched from the dataset repo ``repo_id`` on the
    Hugging Face Hub into ``data/tokenized``.

    Args:
        repo_id: Hub dataset repository to pull from when no local file exists.
        token: Hub access token passed to the download call (may be ``None``).

    Returns:
        The relative path ``data/tokenized/train_sft.jsonl``.

    Raises:
        ValueError: If there is no local file and ``repo_id`` is empty.
    """
    print("Checking for tokenized dataset in local folders...")
    local_path = "data/tokenized/train_sft.jsonl"
    if os.path.exists(local_path):
        print(f"Found local dataset: {local_path}")
        return local_path

    if not repo_id:
        raise ValueError(
            "No local SFT dataset found at data/tokenized/train_sft.jsonl. "
            "Pass --data_repo or set QUINTUS_SFT_DATA_REPO."
        )

    print(f"Local file not found. Pulling from Hugging Face: {repo_id}...")
    from huggingface_hub import hf_hub_download

    os.makedirs("data/tokenized", exist_ok=True)
    downloaded = hf_hub_download(
        repo_id=repo_id,
        filename="train_sft.jsonl",
        repo_type="dataset",
        local_dir="data/tokenized",
        token=token,
    )

    # Ensure correct local path layout. Compare normalised absolute paths: the
    # download call may hand back an absolute path for what is already the
    # target file, and a raw string comparison would then try to move the file
    # onto itself.
    same_file = os.path.abspath(downloaded) == os.path.abspath(local_path)
    if os.path.exists(downloaded) and not same_file:
        os.replace(downloaded, local_path)

    print(f"Dataset downloaded to: {local_path}")
    return local_path
| # DOWNSTREAM EVALUATION CODE | |
def run_prompt_suite(model, tokenizer, device, output_dir: str, batch_size: int = 16):
    """Generate greedy answers for every prompt in ``ALL_PROMPTS`` and save a report.

    Prompts are formatted with the tokenizer's chat template (or a ChatML
    fallback when the tokenizer has none), generated in left-padded batches,
    cut at the first stop token, and appended to a timestamped text file as
    soon as they are decoded.

    Args:
        model: Causal LM exposing ``generate``.
        tokenizer: Tokenizer matching ``model``; its ``padding_side`` is set to
            ``"left"`` for the duration of the run and restored afterwards,
            even if generation fails.
        device: Device the tokenized inputs are moved to.
        output_dir: Directory that receives ``prompt_suite_eval_<timestamp>.txt``.
        batch_size: Number of prompts generated per ``generate`` call.
    """
    total_prompts = len(ALL_PROMPTS)
    print("\n" + "="*70)
    print(f"RUNNING QUALITATIVE PROMPT SUITE ({total_prompts} Prompts)")
    print("="*70)

    # Compile stop token IDs
    stop_ids = {tokenizer.eos_token_id}
    for token in ("<|im_end|>", "<|endoftext|>", "<|im_start|>"):
        t_id = tokenizer.convert_tokens_to_ids(token)
        if t_id is not None and t_id != tokenizer.unk_token_id:
            stop_ids.add(t_id)
    # A tokenizer without an EOS token contributes None, which is not a token id.
    stop_ids.discard(None)
    eos_token_ids = sorted(stop_ids)

    # Initialize output file
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = os.path.join(output_dir, f"prompt_suite_eval_{timestamp}.txt")
    os.makedirs(output_dir, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        f.write("QUINTUS SFT POST-TRAINING PROMPT SUITE\n")
        f.write(f"Timestamp: {timestamp}\n")
        f.write("="*72 + "\n\n")

    # Left padding keeps every prompt flush against its generated continuation.
    orig_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        for i in range(0, total_prompts, batch_size):
            batch_items = ALL_PROMPTS[i : i + batch_size]

            # Format prompts
            formatted_prompts = []
            for item in batch_items:
                prompt_text = item["text"]
                if tokenizer.chat_template is not None:
                    prompt_str = tokenizer.apply_chat_template(
                        [{"role": "user", "content": prompt_text}],
                        tokenize=False, add_generation_prompt=True
                    )
                else:
                    prompt_str = f"<|im_start|>user\n{prompt_text}<|im_end|>\n<|im_start|>assistant\n"
                formatted_prompts.append(prompt_str)

            # Tokenize with padding
            inputs = tokenizer(formatted_prompts, padding=True, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=2048,
                    do_sample=False,  # Greedy for clean, reproducible comparison
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=eos_token_ids or None
                )

            # Every row is padded to the same prompt length, so one offset
            # separates prompt from continuation for the whole batch.
            input_len = inputs["input_ids"].shape[1]

            # Decode and write results in real-time
            for idx, item in enumerate(batch_items):
                gen_tokens = outputs[idx][input_len:]

                # Slice at the first stop token, whichever kind appears earliest.
                first_hits = []
                for eos_id in eos_token_ids:
                    hits = (gen_tokens == eos_id).nonzero(as_tuple=True)[0]
                    if len(hits) > 0:
                        first_hits.append(hits[0].item())
                if first_hits:
                    gen_tokens = gen_tokens[:min(first_hits)]

                response = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()

                # Log progress
                global_idx = i + idx + 1
                print(f"[{global_idx:02d}/{total_prompts}] ({item['difficulty']}) Q: {item['text'][:40]}... -> Answered ({len(gen_tokens)} tokens)")

                # Append directly to output file so partial runs still leave a report.
                with open(out_path, "a", encoding="utf-8") as f:
                    f.write(f"[{global_idx:02d}/{total_prompts}] {item['difficulty']}\n")
                    f.write(f"Q: {item['text']}\n\n")
                    f.write(f"Response:\n{response}\n")
                    f.write("\n" + "-"*72 + "\n\n")
    finally:
        # Restore original tokenizer settings even when generation raised.
        tokenizer.padding_side = orig_padding_side

    print(f"\nPrompt suite evaluation complete. Saved report to: {out_path}\n")
# A signed integer or decimal such as "-7", "18" or "18.00". A trailing
# sentence period is not consumed because the fraction requires digits.
_GSM8K_NUMBER = r"-?\d+(?:\.\d+)?"
_GSM8K_NUMBER_RE = re.compile(_GSM8K_NUMBER)
# "The answer is", then optional colon, markdown emphasis and currency sign.
_GSM8K_FINAL_RE = re.compile(
    r"The answer is\s*:?\s*\**\s*\$?\s*(" + _GSM8K_NUMBER + r")",
    re.IGNORECASE,
)


def _normalize_number(raw: str) -> str:
    """Return *raw* without leading zeros or a zero-only fraction ("18.00" -> "18")."""
    sign = "-" if raw.startswith("-") else ""
    whole, _, fraction = raw.lstrip("-").partition(".")
    whole = whole.lstrip("0") or "0"
    fraction = fraction.rstrip("0")
    if fraction:
        return f"{sign}{whole}.{fraction}"
    if whole == "0":
        return "0"  # Avoid producing "-0".
    return f"{sign}{whole}"


def extract_gsm8k_answer(text: str) -> str | None:
    """Extract the final numeric answer from generated text.

    Commas are removed first so thousands separators do not split a number.
    The last ``The answer is ...`` statement wins; when there is none, the last
    number anywhere in the text is used.

    Args:
        text: Model output to scan.

    Returns:
        The number as a normalised string (an integer-valued decimal such as
        ``"18.00"`` becomes ``"18"``), or ``None`` when the text contains no
        number.
    """
    text = text.replace(",", "")
    matches = _GSM8K_FINAL_RE.findall(text) or _GSM8K_NUMBER_RE.findall(text)
    if not matches:
        return None
    return _normalize_number(matches[-1])
def run_gsm8k_eval(model, tokenizer, device, num_samples: int = 100):
    """Measure exact-match accuracy on a fixed random subset of the GSM8K test split.

    The test split is shuffled with seed 42 and the first ``num_samples``
    questions are answered greedily, one at a time. A prediction counts as
    correct when ``extract_gsm8k_answer`` returns the same string as the
    integer following ``####`` in the reference solution. Items whose
    reference has no such integer are skipped and not counted.

    The evaluation is best-effort: if the ``datasets`` package is missing or
    the dataset cannot be loaded, a warning is printed and the function
    returns without evaluating.

    Args:
        model: Causal LM exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        device: Device the tokenized prompt is moved to.
        num_samples: Upper bound on the number of questions evaluated.
    """
    print("\n" + "="*70)
    print(f"RUNNING GSM8K MATH EVALUATION ({num_samples} Samples)")
    print("="*70)

    try:
        from datasets import load_dataset
    except ImportError as e:
        print(f"Warning: `datasets` is not installed, skipping GSM8K evaluation: {e}")
        return

    try:
        dataset = load_dataset("openai/gsm8k", "main", split="test")
    except Exception as e:
        print(f"Warning: Could not download GSM8K test set directly: {e}")
        return

    dataset = dataset.shuffle(seed=42).select(range(min(num_samples, len(dataset))))

    instruction = "\nShow your work and conclude with 'The answer is: <number>'."
    correct = 0
    total = 0
    for idx, item in enumerate(dataset):
        question = item["question"]
        answer = item["answer"]

        # Predictions are compared with commas removed, so remove them from
        # the reference too; otherwise "#### 1,080" would be read as "1".
        target_match = re.search(r"####\s*(-?\d+)", answer.replace(",", ""))
        if not target_match:
            continue
        target_val = target_match.group(1)

        if tokenizer.chat_template is not None:
            prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": question + instruction}],
                tokenize=False, add_generation_prompt=True
            )
        else:
            prompt = f"<|im_start|>user\n{question}{instruction}<|im_end|>\n<|im_start|>assistant\n"

        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1024,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        # Keep only the continuation; the prompt tokens come first in the output.
        gen_tokens = outputs[0][inputs.input_ids.shape[1]:]
        generated_text = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
        pred_val = extract_gsm8k_answer(generated_text)

        is_match = (pred_val == target_val)
        if is_match:
            correct += 1
        total += 1

        # Log sample output periodically
        if idx % 10 == 0:
            print(f"\n[GSM8K Sample {idx+1}]")
            print(f"Q: {question[:80]}...")
            print(f"A: {generated_text[:120]}... (Target: {target_val} | Pred: {pred_val})")
            print(f"Match: {is_match}")

    accuracy = (correct / total * 100) if total > 0 else 0
    print("\n" + "="*70)
    print(f"GSM8K EVALUATION SUMMARY: {correct}/{total} Correct -> Accuracy: {accuracy:.2f}%")
    print("="*70 + "\n")
| # TRAINING PIPELINE | |
def _resolve_hf_token(cli_token: str | None) -> str | None:
    """Pick the Hub token: explicit CLI value, then ``HF_TOKEN``, then the cached login."""
    try:
        import huggingface_hub
        cached_token = huggingface_hub.get_token()
    except Exception:
        # Best effort only: a missing or broken huggingface_hub install must
        # not stop a run that works from local files.
        cached_token = None
    return cli_token or os.environ.get("HF_TOKEN") or cached_token


def _build_optimizer(model, args, device):
    """Create the optimizer selected by ``args.optim`` over the trainable parameters."""
    # Only hand over parameters that receive gradients. Under LoRA/QLoRA the
    # base weights are frozen (and quantized in 4-bit mode), and fused AdamW
    # only accepts floating-point parameters.
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    if args.optim == "adamw_8bit":
        try:
            import bitsandbytes as bnb
        except ImportError:
            print("Warning: bitsandbytes not installed. Falling back to standard AdamW.")
        else:
            optimizer = bnb.optim.AdamW8bit(trainable_params, lr=args.learning_rate, weight_decay=0.1)
            print("Using BitsAndBytes 8-bit AdamW optimizer.")
            return optimizer

    use_fused = (device.type == "cuda")
    optimizer = torch.optim.AdamW(trainable_params, lr=args.learning_rate, weight_decay=0.1, fused=use_fused)
    print(f"Using standard AdamW optimizer (fused={use_fused}).")
    return optimizer


def main() -> None:
    """Run the full pipeline: load data and model, train, save, evaluate, upload.

    Steps, in order: resolve the Hub token, locate or download the tokenized
    dataset, load tokenizer and model (optionally 4-bit and/or LoRA), build the
    dataloader (packed or dynamically padded), train with gradient accumulation
    and a cosine schedule, save weights and tokenizer to ``--output_dir``, run
    the enabled evaluations, and optionally upload the output directory.

    Exits with status 1 when no dataset can be found, the dataset is empty, or
    LoRA is requested without ``peft`` installed.
    """
    args = parse_args()

    # Propagate HF token to environment for auto-authentication of downstream
    # hub calls. An explicitly passed --token takes precedence over ambient
    # credentials.
    resolved_token = _resolve_hf_token(args.token)
    if resolved_token:
        os.environ["HF_TOKEN"] = resolved_token
        args.token = resolved_token

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"SFT Environment initialized. Target device: {device}")

    # 1. Pull dataset from HF
    try:
        dataset_file = download_hf_dataset(args.data_repo, args.token)
    except ValueError as exc:
        print(f"Error: {exc}")
        sys.exit(1)

    # 2. Setup Tokenizer and Model
    print(f"Loading tokenizer: {args.tokenizer_model}")
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_model, trust_remote_code=args.trust_remote_code)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 4-bit configuration if requested
    bnb_config = None
    if args.load_in_4bit:
        from transformers import BitsAndBytesConfig
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True
        )
        print("Using 4-bit BitsAndBytes quantization.")

    # Liger Kernel (skipped for 4-bit/PEFT as it can interfere with quantized layers)
    if not args.load_in_4bit:
        try:
            from liger_kernel.transformers import apply_liger_kernel_to_qwen3
        except ImportError:
            print("Liger Kernel not installed, skipping optimizations.")
        else:
            apply_liger_kernel_to_qwen3(
                rope=True,
                swiglu=True,
                rms_norm=True,
                cross_entropy=False,
                fused_linear_cross_entropy=False,
            )
            print("Liger Kernel optimizations applied successfully.")

    attn_impl = "sdpa"
    if device.type == "cuda":
        try:
            import flash_attn  # noqa: F401 -- imported only to probe availability
            attn_impl = "flash_attention_2"
            print("FlashAttention-2 enabled.")
        except ImportError:
            print("flash-attn not installed, falling back to SDPA.")

    model = AutoModelForCausalLM.from_pretrained(
        args.student_model,
        quantization_config=bnb_config,
        dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
        trust_remote_code=args.trust_remote_code,
        attn_implementation=attn_impl
    )
    if not args.load_in_4bit:
        model = model.to(device)
    # Disabled for training; turned back on before saving.
    model.config.use_cache = False

    # Wrap with LoRA if requested or required for 4-bit training
    if args.use_lora or args.load_in_4bit:
        try:
            from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
        except ImportError:
            print("Error: peft not installed. Please run `!pip install -q peft` to use LoRA/QLoRA.")
            sys.exit(1)

        if args.load_in_4bit:
            model = prepare_model_for_kbit_training(model)
        peft_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        model = get_peft_model(model, peft_config)
        print("LoRA adapters successfully attached to target modules.")
        model.print_trainable_parameters()

    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        print("Gradient checkpointing enabled.")

    # 3. Prepare dataset
    raw_dataset = SFTDataset(dataset_file)
    if len(raw_dataset) == 0:
        print(f"Error: dataset {dataset_file} contains no samples.")
        sys.exit(1)

    if args.sequence_packing:
        packed_samples = pack_sequences(
            raw_dataset.samples,
            pack_length=args.max_seq_len,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
        train_dataloader = DataLoader(
            packed_samples,
            batch_size=args.micro_batch_size,
            shuffle=True,
            collate_fn=collate_packed
        )
    else:
        def collate_unpacked(batch: list[dict]) -> dict:
            # Apply --max_seq_len here as well, matching the truncation the
            # packed path performs; build new dicts so the dataset is untouched.
            clipped = [
                {
                    "input_ids": s["input_ids"][:args.max_seq_len],
                    "loss_mask": s["loss_mask"][:args.max_seq_len],
                }
                for s in batch
            ]
            return collate_sft(clipped, tokenizer.pad_token_id)

        train_dataloader = DataLoader(
            raw_dataset,
            batch_size=args.micro_batch_size,
            shuffle=True,
            collate_fn=collate_unpacked
        )

    # 4. Optimizer and scheduler setup
    optimizer = _build_optimizer(model, args, device)

    num_batches = len(train_dataloader)
    # Ceiling division: a trailing partial accumulation group still takes a step.
    steps_per_epoch = (num_batches + args.grad_accum_steps - 1) // args.grad_accum_steps
    total_steps = steps_per_epoch * args.num_epochs
    warmup_steps = int(total_steps * 0.05)
    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    # 5. Training Loop
    print("\n" + "="*70)
    print(f"STARTING SFT TRAINING (Epochs: {args.num_epochs} | Steps: {total_steps})")
    print("="*70)

    model.train()
    step = 0
    total_tokens_processed = 0
    t0 = time.time()

    for epoch in range(args.num_epochs):
        epoch_loss = 0.0
        for batch_idx, batch in enumerate(train_dataloader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Accumulate the number of active (non-padded) tokens processed
            total_tokens_processed += attention_mask.sum().item()

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            # Scale so gradients summed over the accumulation group average out.
            loss = outputs.loss / args.grad_accum_steps
            loss.backward()

            # Unscaled loss of this micro-batch, read once for both uses below.
            batch_loss = loss.item() * args.grad_accum_steps
            epoch_loss += batch_loss

            is_group_end = (batch_idx + 1) % args.grad_accum_steps == 0
            is_last_batch = (batch_idx + 1) == num_batches
            if is_group_end or is_last_batch:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                step += 1

                if step % 5 == 0 or step == total_steps:
                    elapsed = time.time() - t0
                    tokens_per_sec = total_tokens_processed / max(elapsed, 1e-5)
                    print(
                        f"Epoch {epoch+1}/{args.num_epochs} | "
                        f"Step {step}/{total_steps} | "
                        f"Loss: {batch_loss:.4f} | "
                        f"LR: {scheduler.get_last_lr()[0]:.2e} | "
                        f"Tokens: {total_tokens_processed} | "
                        f"Speed: {tokens_per_sec:.2f} tokens/s"
                    )

        print(f"Epoch {epoch+1}/{args.num_epochs} complete | Mean loss: {epoch_loss / max(num_batches, 1):.4f}")

    # 6. Save model weights and tokenizer
    print(f"\nTraining complete in {time.time() - t0:.1f}s. Saving weights to: {args.output_dir}")
    os.makedirs(args.output_dir, exist_ok=True)

    # Training is over: re-enable the KV cache so the saved config does not
    # carry the training-time use_cache=False into inference.
    model.config.use_cache = True

    if hasattr(model, "merge_and_unload") and not args.load_in_4bit:
        print("Merging LoRA adapters into base weights...")
        try:
            merged_model = model.merge_and_unload()
            merged_model.save_pretrained(args.output_dir)
            print("Merged model weights saved successfully.")
        except Exception as e:
            print(f"Failed to merge and unload: {e}. Saving adapter weights only.")
            model.save_pretrained(args.output_dir)
    else:
        model.save_pretrained(args.output_dir)

    tokenizer.save_pretrained(args.output_dir)
    print("Weights and configuration saved successfully.")

    # 7. SFT Downstream Evaluations
    model.eval()
    if args.run_prompt_suite:
        run_prompt_suite(model, tokenizer, device, args.output_dir)

    if args.run_gsm8k:
        run_gsm8k_eval(model, tokenizer, device, num_samples=args.gsm8k_samples)

    if args.push_to_hub:
        print(f"\nUploading fine-tuned model and tokenizer to Hugging Face Hub: {args.hub_model_id}...")
        try:
            from huggingface_hub import create_repo, HfApi
            token_val = args.token or os.environ.get("HF_TOKEN")
            create_repo(repo_id=args.hub_model_id, token=token_val, exist_ok=True)
            api = HfApi()
            api.upload_folder(
                folder_path=args.output_dir,
                repo_id=args.hub_model_id,
                repo_type="model",
                token=token_val
            )
            print("Successfully uploaded model and tokenizer to Hugging Face Hub!")
        except Exception as hub_err:
            # The weights are already on disk; a failed upload should not
            # turn a finished run into a failure.
            print(f"Failed to push to Hub: {hub_err}")

    print("Pipeline Execution Complete. Model is ready.")


if __name__ == "__main__":
    main()