File size: 15,507 Bytes
ca88a2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
"""
kcc_llm.py — Unified LLM backend for KCC Agricultural Chatbot
==============================================================
PRIMARY:  Fine-tuned Llama-3.2-3B-Instruct (KCC LoRA adapter chunk4)
          Trained on 16.5M KCC Q&A pairs — domain-expert agricultural Hindi/English
          Loaded in 4-bit (QLoRA) — runs on RTX 3050 4GB VRAM

FALLBACK: Gemini (Google AI Studio free tier) — used if GPU/model not available

Usage:
    from kcc_llm import generate, generate_stream, is_llama_loaded
    answer = generate(prompt)
    for chunk in generate_stream(prompt): ...

The prompt format matches the Alpaca training template exactly.
"""

import os
import sys
import threading
import logging
from pathlib import Path
from typing import Iterator, Optional

logger = logging.getLogger(__name__)

# ── CUDA library path fix ─────────────────────────────────────────────────────
# torch 2.10+ requires libcusparseLt which lives in nvidia-cusparselt package.
# Pre-patch LD_LIBRARY_PATH so torch can find it on import.
def _patch_cuda_ld_path():
    """Put bundled CUDA/torch library directories first in ``LD_LIBRARY_PATH``.

    Looks under the ``purelib`` site-packages directory for ``torch/lib`` and
    the ``nvidia/<pkg>/lib`` directories of the listed NVIDIA packages, and
    places every directory that exists in front of the current value.
    Entries that were already present are not duplicated, so calling this
    more than once leaves the variable unchanged.

    Best-effort: any failure is logged at debug level and otherwise ignored.

    NOTE(review): the dynamic loader normally reads ``LD_LIBRARY_PATH`` when
    the process starts, so this may only benefit child processes -- confirm
    that it really influences the in-process ``import torch``.
    """
    try:
        import sysconfig
        site = sysconfig.get_path("purelib")  # site-packages dir
        candidates = [os.path.join(site, "torch", "lib")]
        candidates += [
            os.path.join(site, "nvidia", pkg, "lib")
            for pkg in ("cusparselt", "cublas", "cuda_runtime", "cudnn",
                        "cuda_cupti", "nvjitlink", "nvtx")
        ]
        extra = [p for p in candidates if os.path.isdir(p)]
        if extra:
            cur = os.environ.get("LD_LIBRARY_PATH", "")
            # Keep the user's entries after ours, minus the ones we just
            # placed up front, so a second call produces the same value.
            rest = [p for p in cur.split(":") if p not in extra] if cur else []
            os.environ["LD_LIBRARY_PATH"] = ":".join(extra + rest)
    except Exception as e:
        # Deliberately non-fatal: torch may still import without the patch.
        logging.getLogger(__name__).debug(
            "[kcc_llm] Could not patch LD_LIBRARY_PATH: %s", e
        )

_patch_cuda_ld_path()

# ── Paths ─────────────────────────────────────────────────────────────────────
# Directory containing this module; the adapter paths below are relative to it.
_PROJECT_DIR  = Path(__file__).parent
# Zipped LoRA adapter that _extract_adapter() unpacks on first use.
_ADAPTER_ZIP  = _PROJECT_DIR / "KCC_Chunk" / "kcc_adapter_after_chunk4.zip"
_ADAPTER_DIR  = _PROJECT_DIR / "kcc_adapter"   # where zip is extracted
# Model identifier passed to from_pretrained() for the base model in _load_llama().
_BASE_MODEL   = "unsloth/llama-3.2-3b-instruct-unsloth-bnb-4bit"  # pre-quantized by Unsloth
# Generation settings: default response length cap and the sampling knobs
# passed to model.generate() by the Llama helpers.
_MAX_NEW_TOKENS = 512
_TEMPERATURE    = 0.3     # Low temp → more factual, consistent
_DO_SAMPLE      = True

# ── Alpaca template (must match training exactly) ─────────────────────────────
# Training used: "{instruction}\n\n### Response:\n{output}"
def _alpaca_prompt(instruction: str) -> str:
    """Wrap *instruction* in the prompt layout used for fine-tuning.

    The result is the instruction, a blank line, and the ``### Response:``
    header followed by a newline, ready for the model to continue.
    """
    template = "{}\n\n### Response:\n"
    return template.format(instruction)

# ── Module-level model state ──────────────────────────────────────────────────
# Loaded model and tokenizer; both are assigned by _load_llama() and stay
# None until a load succeeds.
_model     = None
_tokenizer = None
# Declared global in _load_llama() but not assigned anywhere in the visible code.
_pipeline  = None
# Serialises _load_llama() so the load runs at most once.
_load_lock = threading.Lock()
# Set to True as soon as the first load attempt begins (before it finishes).
_load_attempted = False
# True only after a load attempt completed successfully.
_load_ok   = False


def is_llama_loaded() -> bool:
    """Report whether the fine-tuned Llama model has been loaded successfully."""
    loaded = _load_ok
    return loaded


def _extract_adapter() -> bool:
    """Make sure the LoRA adapter files are unpacked in ``_ADAPTER_DIR``.

    If ``adapter_config.json`` is already present nothing is done. Otherwise
    ``_ADAPTER_ZIP`` is unpacked with the first path component of every member
    stripped, so the contents of the archive's wrapper folder land directly in
    ``_ADAPTER_DIR``. Members are streamed to disk rather than read into
    memory, the wrapper folder's own directory entry is skipped, and members
    that would resolve outside ``_ADAPTER_DIR`` are ignored with a warning.

    Returns:
        True when ``adapter_config.json`` exists in ``_ADAPTER_DIR`` (already
        there or just extracted); False when the zip is missing, extraction
        raised, or the archive did not produce the config file.
    """
    import shutil
    import zipfile

    config_file = _ADAPTER_DIR / "adapter_config.json"
    if config_file.exists():
        return True
    if not _ADAPTER_ZIP.exists():
        logger.warning(f"[kcc_llm] Adapter zip not found: {_ADAPTER_ZIP}")
        return False

    try:
        _ADAPTER_DIR.mkdir(parents=True, exist_ok=True)
        root = _ADAPTER_DIR.resolve()
        with zipfile.ZipFile(str(_ADAPTER_ZIP), "r") as z:
            for info in z.infolist():
                parts = Path(info.filename).parts
                if len(parts) > 1:
                    # Strip top-level folder
                    rel_parts = parts[1:]
                elif info.is_dir() or not parts:
                    # The wrapper folder itself: nothing to create once stripped.
                    continue
                else:
                    rel_parts = parts

                dest = root.joinpath(*rel_parts).resolve()
                if root not in dest.parents:
                    # ".." components would escape the adapter directory.
                    logger.warning(
                        f"[kcc_llm] Skipping unsafe zip member: {info.filename}"
                    )
                    continue

                if info.is_dir():
                    dest.mkdir(parents=True, exist_ok=True)
                else:
                    dest.parent.mkdir(parents=True, exist_ok=True)
                    with z.open(info) as src, open(dest, "wb") as dst:
                        # Stream in chunks; adapter weights can be large.
                        shutil.copyfileobj(src, dst)
    except Exception as e:
        logger.error(f"[kcc_llm] Adapter extraction failed: {e}")
        return False

    if not config_file.exists():
        # Archive layout was not what we expected; loading would fail later.
        logger.error(
            f"[kcc_llm] Adapter extraction failed: adapter_config.json missing in {_ADAPTER_DIR}"
        )
        return False

    logger.info(f"[kcc_llm] Adapter extracted β†’ {_ADAPTER_DIR}")
    return True


def _release_partial_load() -> None:
    """Forget a half-built model/tokenizer and hand cached VRAM back to CUDA.

    Best-effort cleanup; torch is only touched if it has already been imported.
    """
    global _model, _tokenizer
    _model = None
    _tokenizer = None
    try:
        import gc
        gc.collect()
        torch_mod = sys.modules.get("torch")
        if torch_mod is not None and torch_mod.cuda.is_available():
            torch_mod.cuda.empty_cache()
    except Exception as e:
        logger.debug(f"[kcc_llm] VRAM cleanup skipped: {e}")


def _load_llama() -> bool:
    """Load the fine-tuned Llama model once and record the outcome.

    The whole load runs under ``_load_lock``. ``_load_attempted`` is set
    before any work starts, so every later call returns the current
    ``_load_ok`` instead of loading again.

    Steps: extract the adapter, probe for a GPU, try Unsloth
    (``FastLanguageModel`` + ``PeftModel``), and if that fails fall back to
    plain transformers + peft (4-bit on GPU, fp32 on CPU). Anything the
    Unsloth attempt left behind is released before the second attempt so two
    copies of the base model are not held at once.

    Returns:
        True if a model and tokenizer were loaded, False otherwise.
    """
    global _model, _tokenizer, _pipeline, _load_attempted, _load_ok

    with _load_lock:
        if _load_attempted:
            return _load_ok
        _load_attempted = True

        # Step 1: Extract adapter
        if not _extract_adapter():
            logger.warning("[kcc_llm] Adapter unavailable β€” using Gemini fallback")
            return False

        # Step 2: Check GPU
        try:
            import torch
            has_gpu = torch.cuda.is_available()
            if has_gpu:
                vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
                logger.info(f"[kcc_llm] GPU: {torch.cuda.get_device_name(0)} ({vram_gb:.1f}GB VRAM)")
                if vram_gb < 3.0:
                    logger.warning("[kcc_llm] <3GB VRAM β€” Llama may OOM. Trying anyway…")
        except Exception:
            # Also reached when torch itself cannot be imported.
            has_gpu = False
            logger.info("[kcc_llm] No GPU detected β€” Llama will use CPU (slow)")

        # Step 3: Load model — try Unsloth first (fastest for pre-quantized model),
        # then fall back to standard transformers + peft
        try:
            # ── Path A: Unsloth FastLanguageModel (preferred) ─────────────────
            from unsloth import FastLanguageModel
            logger.info(f"[kcc_llm] Loading via Unsloth: {_BASE_MODEL}")
            _model, _tokenizer = FastLanguageModel.from_pretrained(
                model_name=_BASE_MODEL,
                max_seq_length=2048,
                dtype=None,
                load_in_4bit=True,
            )
            from peft import PeftModel
            _model = PeftModel.from_pretrained(_model, str(_ADAPTER_DIR))
            FastLanguageModel.for_inference(_model)
            if _tokenizer.pad_token is None:
                _tokenizer.pad_token = _tokenizer.eos_token
            _load_ok = True
            logger.info("[kcc_llm] βœ… Fine-tuned Llama loaded via Unsloth")
            return True

        except ImportError:
            logger.info("[kcc_llm] Unsloth not available β€” trying standard transformers+peft")
        except Exception as e:
            logger.warning(f"[kcc_llm] Unsloth load failed ({e}) β€” trying transformers+peft")

        # Path A may have failed after the base model was already assigned to
        # _model. Drop it here, outside the except block so the traceback no
        # longer pins it, before Path B loads its own copy.
        _release_partial_load()

        try:
            # ── Path B: Standard transformers + peft ──────────────────────────
            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
            from peft import PeftModel

            logger.info(f"[kcc_llm] Loading base: {_BASE_MODEL}")

            if has_gpu:
                quant_cfg = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                )
                base_model = AutoModelForCausalLM.from_pretrained(
                    _BASE_MODEL, quantization_config=quant_cfg,
                    device_map="auto", trust_remote_code=True,
                )
            else:
                # CPU: load in fp32 (slower but works without CUDA)
                base_model = AutoModelForCausalLM.from_pretrained(
                    _BASE_MODEL, torch_dtype=torch.float32,
                    device_map="cpu", trust_remote_code=True,
                )

            logger.info("[kcc_llm] Merging LoRA adapter…")
            _model = PeftModel.from_pretrained(base_model, str(_ADAPTER_DIR))
            _model.eval()

            # Unlike Path A, the tokenizer here is read from the adapter directory.
            _tokenizer = AutoTokenizer.from_pretrained(str(_ADAPTER_DIR))
            if _tokenizer.pad_token is None:
                _tokenizer.pad_token = _tokenizer.eos_token

            _load_ok = True
            logger.info("[kcc_llm] βœ… Fine-tuned Llama loaded via transformers+peft")
            if has_gpu:
                vram_used = torch.cuda.memory_allocated() / 1e9
                logger.info(f"[kcc_llm] VRAM used: {vram_used:.1f}GB")
            return True

        except Exception as e:
            logger.error(f"[kcc_llm] Model load failed: {e}")
            logger.info("[kcc_llm] Falling back to Groq/Gemini")
            _model = None
            _tokenizer = None
            _load_ok = False
            return False


def _generate_llama(prompt: str, max_new_tokens: int = _MAX_NEW_TOKENS) -> str:
    """Run one blocking generation with the loaded Llama model.

    Args:
        prompt: Instruction text; it is wrapped by ``_alpaca_prompt`` first.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation only (prompt tokens removed), stripped of
        surrounding whitespace.
    """
    import torch

    # NOTE(review): if the tokenizer truncates on the right (the usual
    # default), prompts longer than 1800 tokens lose the trailing
    # "### Response:" marker -- confirm truncation side.
    encoded = _tokenizer(
        _alpaca_prompt(prompt),
        return_tensors="pt",
        truncation=True,
        max_length=1800,
    ).to(_model.device)
    prompt_len = encoded["input_ids"].shape[1]

    sampling = dict(
        max_new_tokens=max_new_tokens,
        temperature=_TEMPERATURE,
        do_sample=_DO_SAMPLE,
        top_p=0.9,
        repetition_penalty=1.1,
        pad_token_id=_tokenizer.eos_token_id,
        eos_token_id=_tokenizer.eos_token_id,
    )

    with torch.no_grad():
        sequences = _model.generate(**encoded, **sampling)

    # The output starts with the prompt tokens; keep only what follows them.
    completion_ids = sequences[0][prompt_len:]
    return _tokenizer.decode(completion_ids, skip_special_tokens=True).strip()


def _stream_llama(prompt: str, max_new_tokens: int = _MAX_NEW_TOKENS) -> Iterator[str]:
    """Stream generated text from the loaded Llama model as it is produced.

    Generation runs in a worker thread that feeds a ``TextIteratorStreamer``;
    this generator yields the text pieces the streamer hands back.

    Args:
        prompt: Instruction text; it is wrapped by ``_alpaca_prompt`` first.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        Decoded text pieces, excluding the prompt and special tokens.

    Raises:
        Exception: whatever ``_model.generate`` raised in the worker thread,
            re-raised here once the stream has been drained.
    """
    from transformers import TextIteratorStreamer

    formatted = _alpaca_prompt(prompt)
    inputs = _tokenizer(
        formatted,
        return_tensors="pt",
        truncation=True,
        max_length=1800,
    ).to(_model.device)

    streamer = TextIteratorStreamer(
        _tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    gen_kwargs = {
        **inputs,
        "max_new_tokens":    max_new_tokens,
        "temperature":       _TEMPERATURE,
        "do_sample":         _DO_SAMPLE,
        "top_p":             0.9,
        "repetition_penalty": 1.1,
        "pad_token_id":      _tokenizer.eos_token_id,
        "eos_token_id":      _tokenizer.eos_token_id,
        "streamer":          streamer,
    }

    failures = []

    def _run_generation():
        try:
            _model.generate(**gen_kwargs)
        except Exception as exc:
            failures.append(exc)
            # generate() only signals end-of-stream when it finishes normally.
            # Signal it here so the consumer loop below cannot block forever.
            streamer.end()

    thread = threading.Thread(target=_run_generation)
    thread.start()

    for token in streamer:
        yield token

    thread.join()
    if failures:
        # Surface the worker's error to the caller instead of losing it.
        raise failures[0]


# ── Gemini fallback ───────────────────────────────────────────────────────────

# Client built by the local fallback path below; reused across calls.
_gemini_client_cache = None


def _get_gemini_client():
    """Return a Gemini client, reusing an existing one where possible.

    First tries ``step4_app._get_gemini_client``. If that import or call
    fails, a client is created from ``config.GEMINI_API_KEY`` and kept in
    ``_gemini_client_cache`` so later calls reuse it.

    Returns:
        A client object, or None if no client could be obtained.
    """
    global _gemini_client_cache

    try:
        from step4_app import _get_gemini_client as _gc
        return _gc()
    except Exception as e:
        # Expected when step4_app is absent; fall through to our own client.
        logger.debug(f"[kcc_llm] step4_app Gemini client unavailable: {e}")

    if _gemini_client_cache is not None:
        return _gemini_client_cache

    try:
        import config
        from google import genai
        _gemini_client_cache = genai.Client(api_key=config.GEMINI_API_KEY)
        return _gemini_client_cache
    except Exception as e:
        logger.error(f"[kcc_llm] Gemini client failed: {e}")
        return None


def _generate_gemini(prompt: str) -> str:
    """Generate a complete answer with Gemini.

    Args:
        prompt: Text sent as the request contents.

    Returns:
        The response text, or a fixed "service unavailable" message when no
        client is available, the request fails, or the response has no text.
    """
    unavailable = "⚠️ Service temporarily unavailable. Please try again."

    client = _get_gemini_client()
    if client is None:
        return unavailable
    try:
        import config
        result = client.models.generate_content(
            model=config.GEMINI_MODEL, contents=prompt
        )
        text = result.text
    except Exception as e:
        logger.error(f"[kcc_llm] Gemini generate failed: {e}")
        return unavailable

    if not text:
        # The response can carry no text (e.g. a blocked or empty candidate);
        # never hand None back to callers that expect a string.
        logger.warning("[kcc_llm] Gemini returned no text")
        return unavailable
    return text


def _stream_gemini(prompt: str) -> Iterator[str]:
    """Stream an answer from Gemini chunk by chunk.

    Args:
        prompt: Text sent as the request contents.

    Yields:
        Non-empty text chunks from the streaming response. If no client is
        available, or the request fails, a single "service unavailable"
        message is yielded instead. Failures are logged, and the raw
        exception text is not shown to the user, matching ``_generate_gemini``.
    """
    client = _get_gemini_client()
    if client is None:
        yield "⚠️ Service temporarily unavailable."
        return
    try:
        import config
        for chunk in client.models.generate_content_stream(
            model=config.GEMINI_MODEL, contents=prompt
        ):
            if chunk.text:
                yield chunk.text
    except Exception as e:
        logger.error(f"[kcc_llm] Gemini stream failed: {e}")
        yield "⚠️ Service temporarily unavailable. Please try again."


# ── Public API ────────────────────────────────────────────────────────────────

def generate(prompt: str, max_new_tokens: int = _MAX_NEW_TOKENS,
             prefer_llama: bool = True) -> str:
    """
    Produce a full response, preferring the fine-tuned Llama over Gemini.

    Args:
        prompt: The full assembled prompt (system + context + question)
        max_new_tokens: Maximum response length (used by the Llama path)
        prefer_llama: If False, go straight to Gemini (for debugging)

    Returns:
        Generated text string
    """
    if not prefer_llama:
        return _generate_gemini(prompt)

    # Trigger the one-time load if nobody has started it yet.
    if not _load_attempted:
        _load_llama()

    llama_ready = _load_ok and _model is not None
    if llama_ready:
        try:
            return _generate_llama(prompt, max_new_tokens)
        except Exception as e:
            logger.error(f"[kcc_llm] Llama inference error: {e} β€” falling back to Gemini")

    # Llama unavailable or failed: use the hosted fallback.
    return _generate_gemini(prompt)


def generate_stream(prompt: str, max_new_tokens: int = _MAX_NEW_TOKENS,
                    prefer_llama: bool = True) -> Iterator[str]:
    """
    Stream response tokens. Uses fine-tuned Llama if available, else Gemini.

    Args:
        prompt: The full assembled prompt.
        max_new_tokens: Maximum response length (used by the Llama path).
        prefer_llama: If False, always stream from Gemini.

    Yields:
        Text pieces of the response. If Llama fails before producing any
        output, the stream falls back to Gemini. If it fails after some
        output was already yielded, the stream simply ends, so a second
        answer is never appended to a partial one.
    """
    if prefer_llama:
        if not _load_attempted:
            _load_llama()
        if _load_ok and _model is not None:
            emitted = False
            try:
                for piece in _stream_llama(prompt, max_new_tokens):
                    emitted = True
                    yield piece
                return
            except Exception as e:
                if emitted:
                    # The consumer already received part of a Llama answer;
                    # restarting with Gemini would produce garbled output.
                    logger.error(f"[kcc_llm] Llama stream error after partial output: {e}")
                    return
                logger.error(f"[kcc_llm] Llama stream error: {e} β€” falling back to Gemini")

    yield from _stream_gemini(prompt)


def model_info() -> dict:
    """Describe which backend is currently serving requests.

    Returns a dict with ``backend`` and ``model`` keys. The Llama variant
    adds adapter and quantization details; the Gemini variant adds a
    ``reason`` telling whether a Llama load was attempted at all.
    """
    if not _load_ok:
        if _load_attempted:
            why = "Llama not loaded"
        else:
            why = "Llama not yet attempted"
        return {
            "backend": "gemini",
            "model":   "Gemini (fallback)",
            "reason":  why,
        }

    return {
        "backend":      "llama",
        "model":        "Llama-3.2-3B-Instruct (KCC fine-tuned)",
        "adapter":      "kcc_adapter_after_chunk4",
        "quantization": "4-bit NF4 (QLoRA)",
        "adapter_path": str(_ADAPTER_DIR),
    }


# ── Background pre-load ───────────────────────────────────────────────────────
# Start loading the model in a background thread on import
# so it's ready by the time the first request comes in.
def _preload():
    """Thread target: run the one-time model load and log any failure."""
    try:
        _load_llama()
    except Exception as e:
        logger.error(f"[kcc_llm] Pre-load failed: {e}")


# Daemon thread, so a load still in progress never keeps the interpreter alive.
_preload_thread = threading.Thread(
    name="kcc_llm_preload",
    target=_preload,
    daemon=True,
)
_preload_thread.start()