"""
Gradio UI for testing the Multiplication LoRA model.

Deployable to HuggingFace Spaces.
"""

import html
import os
import re
from typing import Optional

from dotenv import load_dotenv

# Loaded before the heavy imports below, presumably so variables from a local
# .env file are visible to libraries that read the environment on import.
load_dotenv()

import torch  # noqa: E402
import gradio as gr  # noqa: E402
from transformers import AutoModelForCausalLM, AutoTokenizer  # noqa: E402
from peft import PeftModel  # noqa: E402

# Configuration - can be overridden by environment variables
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
# HF Hub path, e.g., "username/lora-multiplicator". An empty value is treated
# as unset so a blank variable still falls back to the local adapter directory.
LORA_ADAPTER = os.environ.get("LORA_ADAPTER") or None
SYSTEM_PROMPT = os.environ.get(
    "SYSTEM_PROMPT",
    "You are a helpful calculator that multiplies two numbers. Answer only a number. No preamble.",
)

# Accepted input range: all 6-digit numbers.
_MIN_INPUT = 100000
_MAX_INPUT = 999999

# Colors used to render results in the HTML output boxes.
_CORRECT_COLOR = "#16a34a"
_ERROR_COLOR = "#dc2626"

# First integer in the model output; comma thousands separators are allowed
# so an answer like "5,408,858" is read as 5408858 rather than 5.
_NUMBER_RE = re.compile(r"\d+(?:,\d{3})*")

# Global model cache - base and lora need separate model instances
# because PeftModel.from_pretrained wraps the model in place
_cache = {
    "base_model": None,
    "lora_model": None,
    "tokenizer": None,
    "lora_path": None,
}


def get_lora_path() -> str:
    """Return the LoRA adapter location, resolving and caching it on first call.

    Uses LORA_ADAPTER when set; otherwise falls back to the local training
    output directory ``output/lora-multiplicator/final`` next to this file.

    Raises:
        ValueError: If LORA_ADAPTER is unset and the local directory is missing.
    """
    if _cache["lora_path"] is not None:
        return _cache["lora_path"]

    lora_path = LORA_ADAPTER
    if lora_path is None:
        # Try local path for development
        local_path = os.path.join(
            os.path.dirname(__file__), "output", "lora-multiplicator", "final"
        )
        if not os.path.exists(local_path):
            raise ValueError(
                "No LoRA adapter found. Set LORA_ADAPTER environment variable "
                "or place adapter in output/lora-multiplicator/final/"
            )
        lora_path = local_path

    _cache["lora_path"] = lora_path
    return lora_path


def load_tokenizer():
    """Load and cache the tokenizer for BASE_MODEL.

    If the tokenizer defines no pad token, the EOS token is used as pad token.
    """
    if _cache["tokenizer"] is None:
        print(f"Loading tokenizer from {BASE_MODEL}...")
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        _cache["tokenizer"] = tokenizer
    return _cache["tokenizer"]


def _load_causal_lm():
    """Load a fresh, uncached instance of BASE_MODEL (bf16 on CUDA, fp32 otherwise)."""
    return AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        # NOTE(review): this allows executing code shipped with the model repo;
        # BASE_MODEL is env-configurable, so only point it at trusted repos.
        trust_remote_code=True,
    )


def load_base_model():
    """Load and cache the base model (without LoRA), in eval mode."""
    if _cache["base_model"] is None:
        print(f"Loading base model (no LoRA): {BASE_MODEL}...")
        model = _load_causal_lm()
        model.eval()
        _cache["base_model"] = model
        print("Base model loaded successfully!")
    return _cache["base_model"]


def load_lora_model():
    """Load and cache the model with LoRA adapter (separate instance from base).

    Raises:
        ValueError: Propagated from get_lora_path when no adapter is available.
    """
    if _cache["lora_model"] is None:
        # Resolve the adapter first so a missing adapter fails fast, before
        # spending time and memory on loading the base weights.
        lora_path = get_lora_path()

        # Load a NEW base model instance for LoRA (don't reuse the base model)
        # This is important because PeftModel wraps the model in place
        print(f"Loading base model for LoRA: {BASE_MODEL}...")
        base_for_lora = _load_causal_lm()

        print(f"Loading LoRA adapter from: {lora_path}...")
        model = PeftModel.from_pretrained(base_for_lora, lora_path)
        model.eval()
        _cache["lora_model"] = model
        print("LoRA model loaded successfully!")
    return _cache["lora_model"]


def _extract_first_number(text: str) -> Optional[int]:
    """Return the first integer found in ``text`` (commas allowed as separators), or None."""
    match = _NUMBER_RE.search(text)
    if match is None:
        return None
    return int(match.group().replace(",", ""))


def generate_answer(number: int, use_lora: bool) -> tuple[str, str, bool]:
    """
    Generate multiplication answer.

    Args:
        number: The 6-digit number to multiply by 7
        use_lora: Whether to use the LoRA adapter

    Returns:
        Tuple of (predicted_answer, expected_answer, is_correct). When the
        model output contains no number, predicted_answer is the raw model
        text and is_correct is False.
    """
    tokenizer = load_tokenizer()
    model = load_lora_model() if use_lora else load_base_model()

    # Calculate expected result
    expected = number * 7

    # Format as chat message
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"{number} * 7"},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # The chat template already inserted the model's special tokens, so do not
    # let the tokenizer add them a second time.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(
        model.device
    )

    # Greedy decoding so the same input always gives the same answer.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=32,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the generated part (skip the echoed prompt tokens)
    generated_ids = outputs[0][inputs["input_ids"].shape[1] :]
    answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    predicted = _extract_first_number(answer)
    if predicted is None:
        return answer, str(expected), False
    return str(predicted), str(expected), predicted == expected


def _result_html(text: str, color: str = "") -> str:
    """Render ``text`` as a large result value, HTML-escaped and optionally colored.

    Escaping matters because ``text`` may be raw model output.
    """
    style = "font-size: 2em; font-weight: bold;"
    if color:
        style += f" color: {color};"
    return f'<div style="{style}">{html.escape(text)}</div>'


def predict(number_input: str, use_lora: bool) -> tuple[str, str]:
    """
    Main prediction function for Gradio interface.

    Returns formatted HTML for predicted and expected values. The predicted
    value is green when correct and red otherwise; validation errors are
    shown in the predicted box with "-" as the expected value.
    """
    text = (number_input or "").strip()
    # Plain ASCII digits only: int() alone would also accept "+123456",
    # "123_456" or non-ASCII digit characters.
    if not (text.isascii() and text.isdigit()):
        return _result_html("Invalid input", _ERROR_COLOR), _result_html("-")

    number = int(text)
    if not (_MIN_INPUT <= number <= _MAX_INPUT):
        return (
            _result_html("Must be 6 digits (100000-999999)", _ERROR_COLOR),
            _result_html("-"),
        )

    # Generate prediction
    predicted, expected, is_correct = generate_answer(number, use_lora)

    # Format output with colors
    color = _CORRECT_COLOR if is_correct else _ERROR_COLOR
    return _result_html(predicted, color), _result_html(expected)


def create_demo():
    """Create the Gradio demo interface (not launched)."""
    with gr.Blocks(title="Multiplication LoRA Demo") as demo:
        gr.Markdown(
            """
            # Multiplication LoRA Demo

            A fun experiment in LoRA fine-tuning on a tiny model using a simple arithmetic task (multiplication by 7).

            **LoRA Adapter**: [nlac/multiplication-lora-demo-adapter](https://huggingface.co/nlac/multiplication-lora-demo-adapter)
            """
        )

        with gr.Row():
            with gr.Column(scale=2):
                number_input = gr.Textbox(
                    label="Enter a 6-digit number to multiply it by 7",
                    placeholder="e.g. 123456",
                    max_lines=1,
                )
                use_lora = gr.Checkbox(
                    label="Use LoRA adapter",
                    value=True,
                    info="Uncheck to see base model performance (hint: it's much worse!)",
                )
                submit_btn = gr.Button("Send", variant="primary", size="lg")

            with gr.Column(scale=3):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Predicted")
                        predicted_output = gr.HTML(
                            value=_result_html("-"),
                            elem_classes=["result-box", "predicted-box"],
                        )
                    with gr.Column():
                        gr.Markdown("### Expected")
                        expected_output = gr.HTML(
                            value=_result_html("-"),
                            elem_classes=["result-box", "expected-box"],
                        )

        # Wire up the interface
        submit_btn.click(
            fn=predict,
            inputs=[number_input, use_lora],
            outputs=[predicted_output, expected_output],
        )

        # Also trigger on Enter key
        number_input.submit(
            fn=predict,
            inputs=[number_input, use_lora],
            outputs=[predicted_output, expected_output],
        )

        gr.Examples(
            examples=[
                ["123456", True],
                ["999999", False],
                ["100000", True],
                ["123456", False],
            ],
            inputs=[number_input, use_lora],
            outputs=[predicted_output, expected_output],
            fn=predict,
            cache_examples=False,
        )

        gr.Markdown(
            """
            ## Results

            | Model | Accuracy |
            |-------|----------|
            | Base Qwen2.5-0.5B | ~3% |
            | With LoRA adapter | ~94% |

            The LoRA adapter adds only **~2MB of parameters** but improves accuracy by **31x**!
            """
        )

        gr.Markdown(
            """
            ## Why this project?

            This is an experiment to learn LoRA fine-tuning. Arithmetic makes an ideal test case:

            - **Easy data generation** - examples generated programmatically, no manual labeling
            - **Objective evaluation** - answers are either correct or wrong

            The training completed in under an hour on a consumer laptop, using 20,000 generated examples with 6-digit numbers over 3 epochs, so only about 2% of all 6-digit numbers were used for training. Increasing the number of samples and epochs would likely result in even higher accuracy.

            A typical training example:

            ```json
            [{"role": "system", "content": "You are a helpful calculator that multiplies two numbers. Answer only a number. No preamble."},
             {"role": "user", "content": "772694* 7?"},
             {"role": "assistant", "content": "5408858"}]
            ```
            """
        )

    return demo


# Built at import time so the `demo` object exists even when this module is
# imported rather than run; launched only when run as a script.
demo = create_demo()

if __name__ == "__main__":
    # NOTE(review): theme/css are passed to launch(), which matches the
    # Gradio 6 API; older Gradio versions take them on gr.Blocks instead.
    demo.launch(
        ssr_mode=False,
        theme=gr.themes.Soft(),
        css="""
        .result-box {
            padding: 20px;
            border-radius: 10px;
            text-align: center;
            min-height: 80px;
        }
        .predicted-box, .expected-box {
            background-color: #f0f0f0;
        }
        """,
    )