"""Optional local LoRA fine-tuned Llama backend.

Nothing heavy is loaded on import: the model and tokenizer are loaded lazily,
on the first call to `generate()`. This module does not itself read
`config.USE_LOCAL_LLAMA`; callers are expected to gate use on that flag.

Loading tries Unsloth `FastLanguageModel` with 4-bit QLoRA first (intended to
fit 6GB VRAM, e.g. an RTX 3050), then falls back to transformers +
bitsandbytes 4-bit NF4 + peft. A load is attempted at most once per process.
"""

from __future__ import annotations

import logging
import threading
from pathlib import Path

from . import config

logger = logging.getLogger(__name__)

# Serializes the one-time lazy load in `_load()`.
_LOCK = threading.Lock()
_MODEL = None  # adapter-wrapped model once loaded, else None
_TOK = None  # matching tokenizer once loaded, else None
_LOADED = False  # True once a load has been attempted (successful or not)

# max_seq_length passed to Unsloth at load time.
_MAX_SEQ_LENGTH = 2048
# Upper bound on the number of prompt tokens passed to the model.
_MAX_PROMPT_TOKENS = 1800
# Appended to every prompt; presumably mirrors the fine-tuning template — confirm.
_RESPONSE_HEADER = "\n\n### Response:\n"


def _prepare_tokenizer(tok):
    """Configure *tok* for generation and return it.

    Uses EOS as the pad token when none is defined, and switches truncation
    to the left side so an over-long prompt loses its oldest tokens instead
    of the trailing response header that `generate()` appends.
    """
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    # Default right-side truncation would cut off "### Response:" itself.
    tok.truncation_side = "left"
    return tok


def _load_unsloth(adapter: Path):
    """Load the 4-bit base model via Unsloth, attach the LoRA adapter, return (model, tokenizer)."""
    from unsloth import FastLanguageModel  # import before peft so Unsloth can patch it
    from peft import PeftModel

    mdl, tok = FastLanguageModel.from_pretrained(
        model_name=config.LOCAL_LLAMA_BASE,
        max_seq_length=_MAX_SEQ_LENGTH,
        dtype=None,
        load_in_4bit=True,
    )
    mdl = PeftModel.from_pretrained(mdl, str(adapter))
    FastLanguageModel.for_inference(mdl)
    return mdl, _prepare_tokenizer(tok)


def _load_transformers(adapter: Path):
    """Load the base model with transformers + bitsandbytes 4-bit NF4, attach the adapter, return (model, tokenizer)."""
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import PeftModel

    qcfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    base = AutoModelForCausalLM.from_pretrained(
        config.LOCAL_LLAMA_BASE,
        quantization_config=qcfg,
        device_map="auto",
        trust_remote_code=True,
    )
    mdl = PeftModel.from_pretrained(base, str(adapter)).eval()
    try:
        # Prefer a tokenizer saved alongside the adapter ...
        tok = AutoTokenizer.from_pretrained(str(adapter))
    except (OSError, ValueError) as e:
        # ... but don't fail the whole load if the adapter dir has none.
        logger.warning(
            "[local_llama] no tokenizer in %s (%s); using base model tokenizer", adapter, e
        )
        tok = AutoTokenizer.from_pretrained(config.LOCAL_LLAMA_BASE)
    return mdl, _prepare_tokenizer(tok)


def _load() -> bool:
    """Lazily load the model and tokenizer into the module globals.

    Only the first call attempts a load; later calls just report its outcome,
    so a failed load is not retried within the process.

    Returns:
        True if a model is available, False otherwise (adapter missing or
        both loading strategies failed).
    """
    global _MODEL, _TOK, _LOADED
    with _LOCK:
        if _LOADED:
            return _MODEL is not None
        _LOADED = True

        adapter = Path(config.KCC_ADAPTER_DIR)
        # Checking the marker file also covers a missing adapter directory.
        if not (adapter / "adapter_config.json").exists():
            logger.warning("[local_llama] adapter not found at %s", adapter)
            return False

        try:
            _MODEL, _TOK = _load_unsloth(adapter)
            logger.info("[local_llama] loaded via Unsloth")
            return True
        except Exception as e:  # best-effort: any Unsloth failure falls through
            logger.warning("[local_llama] Unsloth load failed: %s; trying transformers", e)

        try:
            _MODEL, _TOK = _load_transformers(adapter)
            logger.info("[local_llama] loaded via transformers+peft")
            return True
        except Exception as e:  # best-effort backend: report and disable
            logger.exception("[local_llama] load failed: %s", e)
            return False


def generate(prompt: str, max_tokens: int = 512) -> str:
    """Generate a completion for *prompt* with the local fine-tuned model.

    The prompt is followed by a "### Response:" header and capped at
    `_MAX_PROMPT_TOKENS` tokens, keeping the most recent tokens so the header
    always survives. Sampling uses temperature 0.3, top-p 0.9 and repetition
    penalty 1.1.

    Args:
        prompt: Raw prompt text.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The decoded completion with special tokens removed and surrounding
        whitespace stripped, or "" if the model could not be loaded.
    """
    if not _load():
        return ""
    import torch

    formatted = f"{prompt}{_RESPONSE_HEADER}"
    # The tokenizer's truncation_side was set to "left" at load time.
    inputs = _TOK(
        formatted,
        return_tensors="pt",
        truncation=True,
        max_length=_MAX_PROMPT_TOKENS,
    ).to(_MODEL.device)
    with torch.no_grad():
        out = _MODEL.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.3,
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=_TOK.eos_token_id,
            eos_token_id=_TOK.eos_token_id,
        )
    # The output echoes the prompt tokens; keep only the new continuation.
    gen = out[0][inputs["input_ids"].shape[1]:]
    return _TOK.decode(gen, skip_special_tokens=True).strip()