File size: 7,267 Bytes
e003508
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5991e70
e003508
5991e70
 
 
 
 
 
 
 
 
 
e003508
 
5991e70
 
e003508
 
 
5991e70
 
 
e003508
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
# gguf_engine.py

import os
os.environ["LLAMA_LOG_LEVEL"] = "0"     # suppresses llama.cpp C-level logging
os.environ["GGML_LOG_LEVEL"]  = "0"     # suppresses ggml backend warnings



import os, sys, time, base64, io, re
from contextlib import contextmanager
from functools import lru_cache
from PIL import Image
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler

# ==========================================
# MODEL LOCATIONS — resolved at import time by
# downloading from the Hugging Face Hub (see below)
# ==========================================
import os
from huggingface_hub import hf_hub_download

# ==========================================
# DYNAMIC MODEL DOWNLOADER (Bridge to Model Repo)
# ==========================================
# Hugging Face Hub repository that hosts every GGUF artifact referenced below.
MODEL_REPO = "shrishSVaidya/VAJRAM-Models" # Change this if you named it differently

print("Fetching models from Hugging Face Hub (this takes a moment on first boot)...")

# NOTE: every hf_hub_download call below executes at import time, so importing
# this module triggers the downloads. The returned values are used further down
# as local file paths (model_path / lora_path / clip_model_path).
# NOTE(review): hf_hub_download is expected to reuse its local cache on later
# boots instead of downloading again — confirm against the huggingface_hub docs.
LLM_PATH = hf_hub_download(repo_id=MODEL_REPO, filename="gguf_models/medgemma_q4km.gguf")

# Vision projector (mmproj) files, keyed by vision module name.
# "base" is the fallback entry _load_vision_model uses for unknown modules.
VISION_MMPROJ_PATHS = {
    "module3": hf_hub_download(repo_id=MODEL_REPO, filename="gguf_models/medgemma_Bone_marrow_vision.gguf"),
    "base":    hf_hub_download(repo_id=MODEL_REPO, filename="gguf_models/medgemma_vision_base.gguf"),
}

# LoRA adapter files for the language model, keyed by adapter name.
# "default" maps to None, so no LoRA path is handed to Llama for that key.
LLM_LORA_PATHS = {
    "module2": hf_hub_download(repo_id=MODEL_REPO, filename="gguf_models/lora_module2.gguf"),
    "module3": hf_hub_download(repo_id=MODEL_REPO, filename="gguf_models/lora_module3.gguf"),
    "module4": hf_hub_download(repo_id=MODEL_REPO, filename="gguf_models/lora_module4.gguf"),
    "default": None,
}

# Settings forwarded to Llama(...) by the loader functions.
N_THREADS = os.cpu_count() or 8  # falls back to 8 when os.cpu_count() returns None
N_CTX     = 2048                 # n_ctx for the text model; the vision loader passes its own value
N_BATCH   = 512                  # n_batch shared by the text and vision loaders


# ==========================================
# STDERR SUPPRESSOR
# llama.cpp prints CPU_REPACK fallback warnings
# to stderr on every model load. These are
# harmless — suppress them to keep logs clean.
# ==========================================
@contextmanager
def _suppress_stderr():
    """
    Temporarily redirect the stderr file descriptor to os.devnull.

    The redirect is done with os.dup2 on the descriptor itself, so output
    written by native code (e.g. llama.cpp while loading a model) is silenced
    as well as output written through Python's sys.stderr.

    The descriptor is taken from sys.stderr.fileno() when available. If
    sys.stderr has been replaced by an object without a real descriptor
    (notebook streams, in-memory capture objects, or None), the conventional
    stderr descriptor 2 is used instead of raising.

    The original descriptor is always restored on exit, including when the
    body raises. The redirect affects the whole process, so anything else
    writing to stderr while the context is active is silenced too.
    """
    # Emit anything Python still has buffered before the descriptor is swapped,
    # otherwise earlier messages would land in devnull. Best effort only: a
    # replaced or missing sys.stderr must not prevent the redirect.
    try:
        sys.stderr.flush()
    except (AttributeError, OSError, ValueError):
        pass

    # io.UnsupportedOperation derives from both OSError and ValueError;
    # AttributeError covers sys.stderr being None.
    try:
        stderr_fd = sys.stderr.fileno()
    except (AttributeError, OSError, ValueError):
        stderr_fd = 2

    saved_fd = os.dup(stderr_fd)
    try:
        # Opened inside the outer try so saved_fd is closed even if this fails.
        devnull_fd = os.open(os.devnull, os.O_WRONLY)
        try:
            os.dup2(devnull_fd, stderr_fd)
            yield
        finally:
            # Restore the real stderr before releasing the helper descriptor.
            os.dup2(saved_fd, stderr_fd)
            os.close(devnull_fd)
    finally:
        os.close(saved_fd)


# ==========================================
# TEXT MODEL LOADER
# lru_cache(maxsize=2): keeps the 2 most
# recently used adapter instances in RAM.
# use_mmap=True: all instances share the same
# physical RAM pages for base weights.
# ==========================================
@lru_cache(maxsize=2)
def _load_text_model(adapter_name: str) -> Llama:
    """
    Build (or fetch from the LRU cache) the text-only Llama instance for an adapter.

    The LoRA file is looked up in LLM_LORA_PATHS; a name that is not present
    yields None, so the model is created without a LoRA path. Results are
    memoised per adapter_name with at most two entries kept.
    """
    adapter_file = LLM_LORA_PATHS.get(adapter_name)
    print(f"  [GGUF] Loading text model | adapter={adapter_name} | lora={adapter_file}")
    started = time.time()

    llama_kwargs = {
        "model_path": LLM_PATH,
        "lora_path": adapter_file,
        "lora_scale": 1.0,
        "n_ctx": N_CTX,
        "n_batch": N_BATCH,
        "n_threads": N_THREADS,
        "use_mmap": True,
        "verbose": False,
    }
    # Construction happens with stderr silenced to hide native load-time output.
    with _suppress_stderr():
        llm = Llama(**llama_kwargs)

    load_seconds = time.time() - started
    print(f"  [GGUF] Text model ready in {load_seconds:.1f}s")
    return llm


# ==========================================
# VISION MODEL LOADER
# One mmproj per vision domain. lru_cache(1)
# keeps the most recently used vision model hot.
# ==========================================
@lru_cache(maxsize=1)
def _load_vision_model(vision_module: str) -> Llama:
    """
    Build (or fetch from the single-slot cache) the vision-capable Llama instance.

    The mmproj file comes from VISION_MMPROJ_PATHS, falling back to its "base"
    entry for unknown modules. The LoRA file comes from LLM_LORA_PATHS and is
    None when the module has no entry there. Only the most recently requested
    module stays cached.
    """
    fallback_projector = VISION_MMPROJ_PATHS["base"]
    projector_file = VISION_MMPROJ_PATHS.get(vision_module, fallback_projector)
    adapter_file = LLM_LORA_PATHS.get(vision_module)

    for line in (
        f"  [GGUF] Loading vision model | module={vision_module}",
        f"         mmproj : {projector_file}",
        f"         lm lora: {adapter_file}",
    ):
        print(line)
    started = time.time()

    # Both the chat handler and the model are created with stderr silenced.
    with _suppress_stderr():
        handler = Llava15ChatHandler(clip_model_path=projector_file, verbose=False)
        llm = Llama(
            model_path=LLM_PATH,
            lora_path=adapter_file,
            lora_scale=1.0,
            chat_handler=handler,
            n_ctx=1024,
            n_batch=N_BATCH,
            n_threads=N_THREADS,
            use_mmap=True,
            verbose=False,
        )

    load_seconds = time.time() - started
    print(f"  [GGUF] Vision model ready in {load_seconds:.1f}s")
    return llm


# ==========================================
# PROMPT FORMATTING
# ==========================================
def _format_gemma_prompt(user_text: str) -> str:
    """
    Wrap *user_text* in a single-turn Gemma instruct template.

    The text is stripped of surrounding whitespace, placed in a user turn, and
    followed by an opened model turn. No leading <bos> is emitted (the original
    author notes this avoids a duplication warning).
    """
    message = user_text.strip()
    user_turn = "<start_of_turn>user\n" + message + "<end_of_turn>\n"
    model_turn_open = "<start_of_turn>model\n"
    return user_turn + model_turn_open

def _pil_to_data_uri(image: Image.Image) -> str:
    """Serialise *image* as PNG in memory and return it as a base64 ``data:`` URI."""
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes).decode()
    return "data:image/png;base64," + encoded

def exclude_thinking_component(text: str) -> str:
    """
    Remove ``<unused94>``-delimited spans from *text* and trim the result.

    Two passes are applied in order: first every ``<unused94>...<unused95>``
    pair (shortest match, spanning newlines) is removed, then anything from a
    remaining unclosed ``<unused94>`` to the end of the text is removed.
    """
    remaining = text
    for pattern in (r"<unused94>.*?<unused95>", r"<unused94>.*"):
        remaining = re.sub(pattern, "", remaining, flags=re.DOTALL)
    return remaining.strip()


# ==========================================
# PUBLIC API — drop-in replacements for the
# original HuggingFace inference functions
# ==========================================
def generate_with_adapter(prompt: str, adapter_name: str, max_tokens: int = 150) -> str:
    """
    Run text-only inference with the model for *adapter_name*.

    Adapter names missing from LLM_LORA_PATHS are served by the "default"
    entry. The prompt is wrapped in the Gemma template, generated with
    temperature 0.0, and the completion is returned with thinking spans
    removed. A one-line timing summary is printed.
    """
    cache_key = "default"
    if adapter_name in LLM_LORA_PATHS:
        cache_key = adapter_name
    llm = _load_text_model(cache_key)

    started = time.time()
    completion = llm(
        _format_gemma_prompt(prompt),
        max_tokens=max_tokens,
        stop=["<end_of_turn>", "<eos>"],
        echo=False,
        temperature=0.0,
        top_p=1.0,
    )
    elapsed = time.time() - started

    first_choice = completion["choices"][0]
    answer = exclude_thinking_component(first_choice["text"].strip())
    n_generated = completion["usage"]["completion_tokens"]
    # Clamp the divisor so a near-zero elapsed time cannot divide by zero.
    rate = n_generated / max(elapsed, 0.01)
    print(f"  [GGUF] {cache_key} | {elapsed:.2f}s | {n_generated} tok | {rate:.1f} tok/s")
    return answer


def generate_with_adapter_vision(
    image: Image.Image,
    prompt: str,
    adapter_name: str,
    max_tokens: int = 10,
) -> str:
    """
    Run image + text inference with the vision model for *adapter_name*.

    Args:
        image: Image sent to the model, encoded as a PNG data URI.
        prompt: Text instruction sent alongside the image.
        adapter_name: Vision module / adapter to use. Names found in neither
            VISION_MMPROJ_PATHS nor LLM_LORA_PATHS are served by the "base"
            configuration.
        max_tokens: Upper bound on generated tokens.

    Returns:
        The model's reply with thinking spans removed. An empty string is
        returned when the reply carries no content.
    """
    # The loader keeps a single cached model keyed by its argument. An unknown
    # name resolves to the same base mmproj / no-LoRA setup as "base", so all
    # unknown names are mapped onto that one key to avoid reloading an
    # identical model for every new spelling.
    is_known = adapter_name in VISION_MMPROJ_PATHS or adapter_name in LLM_LORA_PATHS
    cache_key = adapter_name if is_known else "base"
    model = _load_vision_model(cache_key)
    data_uri = _pil_to_data_uri(image)

    _t = time.time()
    output = model.create_chat_completion(
        messages=[{"role": "user", "content": [
            {"type": "image_url", "image_url": {"url": data_uri}},
            {"type": "text",      "text": prompt},
        ]}],
        max_tokens  = max_tokens,
        temperature = 0.0,
    )
    elapsed = time.time() - _t
    # Guard against a missing/None content field before stripping.
    raw     = (output["choices"][0]["message"]["content"] or "").strip()
    cleaned = exclude_thinking_component(raw)
    print(f"  [GGUF Vision] {adapter_name} | {elapsed:.2f}s | output: {cleaned}")
    return cleaned