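# --- Header text copied from the Hugging Face file viewer (kept as comments so the module parses) ---
# Project_VAJRAM / gguf_engine.py
# shrishSVaidya's picture
# Adding the openBLAS
# 6301fdd
# Raw
# History Blame Contribute Delete
# 7.42 kB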
# gguf_engine.py
import os
os.environ["LLAMA_LOG_LEVEL"] = "0" # suppresses llama.cpp C-level logging
os.environ["GGML_LOG_LEVEL"] = "0" # suppresses ggml backend warnings
import os, sys, time, base64, io, re
from contextlib import contextmanager
from functools import lru_cache
from PIL import Image
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler
# ==========================================
# PATHS — model files are fetched from the Hugging Face Hub
# (see MODEL_REPO below); no local paths need editing
# ==========================================
import os
from huggingface_hub import hf_hub_download
# ==========================================
# DYNAMIC MODEL DOWNLOADER (Bridge to Model Repo)
# ==========================================
# Hub repository holding every GGUF artifact; edit if the repo was renamed.
MODEL_REPO = "shrishSVaidya/VAJRAM-Models"


def _hub_file(filename: str) -> str:
    """Return the local path of ``filename`` from MODEL_REPO.

    hf_hub_download reuses the local Hugging Face cache when the file is
    already present and only hits the network when it is missing, so the
    first boot is slow and later boots are fast.
    """
    return hf_hub_download(repo_id=MODEL_REPO, filename=filename)


print("Fetching models from Hugging Face Hub (this takes a moment on first boot)...")

# Base language-model weights shared by every text and vision instance.
LLM_PATH = _hub_file("gguf_models/medgemma_q4km.gguf")

# Multimodal projector (mmproj) files keyed by vision module.
# "base" is the fallback used for modules without a dedicated projector.
VISION_MMPROJ_PATHS = {
    "module3": _hub_file("gguf_models/medgemma_Bone_marrow_vision.gguf"),
    "base": _hub_file("gguf_models/medgemma_vision_base.gguf"),
}

# LoRA adapters for the language model keyed by module; "default" = no adapter.
LLM_LORA_PATHS = {
    "module2": _hub_file("gguf_models/lora_module2.gguf"),
    "module3": _hub_file("gguf_models/lora_module3.gguf"),
    "module4": _hub_file("gguf_models/lora_module4.gguf"),
    "default": None,
}

# Inference settings shared by the loaders.
# NOTE(review): os.cpu_count() reports the machine's logical CPUs, which can
# exceed what a container is actually allowed to use — confirm on the host.
N_THREADS = os.cpu_count() or 8
N_CTX = 2048
N_BATCH = 512
# ==========================================
# STDERR SUPPRESSOR
# llama.cpp prints CPU_REPACK fallback warnings
# to stderr on every model load. These are
# harmless — suppress them to keep logs clean.
# ==========================================
@contextmanager
def _suppress_stderr():
"""
Redirects C-level stderr to /dev/null during llama.cpp model loading.
Uses os.dup2 so it catches output from the native C library, not just Python.
"""
stderr_fd = sys.stderr.fileno()
saved_fd = os.dup(stderr_fd)
devnull_fd = os.open(os.devnull, os.O_WRONLY)
try:
os.dup2(devnull_fd, stderr_fd)
yield
finally:
os.dup2(saved_fd, stderr_fd)
os.close(saved_fd)
os.close(devnull_fd)
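# ==========================================
# TEXT MODEL LOADER
# lru_cache(maxsize=2): keeps the 2 most
# recently used adapter instances in RAM.
# use_mmap=True: intended to let all instances
# share the same physical RAM pages for the base
# weights. NOTE(review): some llama-cpp-python
# versions turn mmap off when lora_path is set —
# verify memory use with two adapters cached.
# ==========================================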
@lru_cache(maxsize=2)
def _load_text_model(adapter_name: str) -> Llama:
    """
    Build (or return the cached) text-only Llama instance for one adapter.

    adapter_name is looked up in LLM_LORA_PATHS; names that are missing, or
    mapped to None (such as "default"), load the base model without a LoRA.
    Instances are memoised per adapter name by lru_cache(maxsize=2).
    Native loader output is silenced with _suppress_stderr().
    """
    adapter_file = LLM_LORA_PATHS.get(adapter_name)
    print(f" [GGUF] Loading text model | adapter={adapter_name} | lora={adapter_file}")
    started = time.time()
    llama_kwargs = dict(
        model_path=LLM_PATH,
        lora_path=adapter_file,
        lora_scale=1.0,
        n_ctx=N_CTX,
        n_batch=N_BATCH,
        n_threads=N_THREADS,
        n_threads_batch=2,  # pinned to 2 — author notes this is crucial for OpenBLAS prompt evaluation
        use_mmap=True,
        verbose=False,
    )
    with _suppress_stderr():
        text_llm = Llama(**llama_kwargs)
    print(f" [GGUF] Text model ready in {time.time() - started:.1f}s")
    return text_llm
# ==========================================
# VISION MODEL LOADER
# One mmproj per vision domain. lru_cache(1)
# keeps the most recently used vision model hot.
# ==========================================
@lru_cache(maxsize=1)
def _load_vision_model(vision_module: str) -> Llama:
    """
    Build (or return the cached) image+text Llama instance for a vision module.

    The mmproj comes from VISION_MMPROJ_PATHS[vision_module], falling back to
    VISION_MMPROJ_PATHS["base"] for unknown modules. The language-model LoRA is
    looked up in LLM_LORA_PATHS under the same name (None when absent).
    The context window is fixed at 1024 tokens for this instance.
    lru_cache(maxsize=1) keeps only the most recently requested module loaded.

    NOTE(review): Llava15ChatHandler formats messages with a LLaVA-1.5 style
    template; confirm that matches the prompt format the fine-tuned weights expect.
    """
    clip_path = VISION_MMPROJ_PATHS.get(vision_module, VISION_MMPROJ_PATHS["base"])
    adapter_file = LLM_LORA_PATHS.get(vision_module)
    for log_line in (
        f" [GGUF] Loading vision model | module={vision_module}",
        f" mmproj : {clip_path}",
        f" lm lora: {adapter_file}",
    ):
        print(log_line)
    started = time.time()
    with _suppress_stderr():
        handler = Llava15ChatHandler(clip_model_path=clip_path, verbose=False)
        vision_llm = Llama(
            model_path=LLM_PATH,
            lora_path=adapter_file,
            lora_scale=1.0,
            chat_handler=handler,
            n_ctx=1024,
            n_batch=N_BATCH,
            n_threads=N_THREADS,
            n_threads_batch=2,  # pinned to 2 — author notes this is crucial for OpenBLAS prompt evaluation
            use_mmap=True,
            verbose=False,
        )
    print(f" [GGUF] Vision model ready in {time.time() - started:.1f}s")
    return vision_llm
# ==========================================
# PROMPT FORMATTING
# ==========================================
def _format_gemma_prompt(user_text: str) -> str:
"""Gemma instruct template — no leading <bos> to avoid duplication warning."""
return (
"<start_of_turn>user\n"
f"{user_text.strip()}"
"<end_of_turn>\n"
"<start_of_turn>model\n"
)
def _pil_to_data_uri(image: Image.Image) -> str:
    """
    Encode a PIL image as a base64 PNG data URI ("data:image/png;base64,...").

    Pillow's PNG writer raises OSError for image modes it cannot store
    (e.g. CMYK, common in uploaded JPEGs). In that case the image is
    converted to RGB and encoded again, so such inputs no longer abort
    vision inference. Images in PNG-compatible modes are encoded unchanged.
    """
    buf = io.BytesIO()
    try:
        image.save(buf, format="PNG")
    except OSError:
        # Start from an empty buffer in case the failed save wrote a partial header.
        buf = io.BytesIO()
        image.convert("RGB").save(buf, format="PNG")
    return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
def exclude_thinking_component(text: str) -> str:
    """
    Strip reasoning ("thinking") segments from generated text.

    Two passes, applied in order:
      1. every closed span ``<unused94> ... <unused95>`` is removed
         (shortest match, spanning newlines);
      2. everything from the first remaining, unclosed ``<unused94>`` to the
         end of the text is removed.
    The result is returned with surrounding whitespace stripped.
    """
    result = text
    for marker_pattern in (r"<unused94>.*?<unused95>", r"<unused94>.*"):
        result = re.sub(marker_pattern, "", result, flags=re.DOTALL)
    return result.strip()
# ==========================================
# PUBLIC API — drop-in replacements for the
# original HuggingFace inference functions
# ==========================================
def generate_with_adapter(prompt: str, adapter_name: str, max_tokens: int = 150) -> str:
    """
    Run text-only inference with the LoRA adapter named adapter_name.

    Drop-in replacement for the earlier HuggingFace-based function (same
    signature). Adapter names not present in LLM_LORA_PATHS fall back to
    "default". Decoding is greedy (temperature 0.0) and stops at
    <end_of_turn> or <eos>. Thinking segments are removed from the returned
    text, and timing/throughput is printed.
    """
    resolved = "default" if adapter_name not in LLM_LORA_PATHS else adapter_name
    llm = _load_text_model(resolved)
    started = time.time()
    completion = llm(
        _format_gemma_prompt(prompt),
        max_tokens=max_tokens,
        stop=["<end_of_turn>", "<eos>"],
        echo=False,
        temperature=0.0,
        top_p=1.0,
    )
    duration = time.time() - started
    answer = exclude_thinking_component(completion["choices"][0]["text"].strip())
    n_generated = completion["usage"]["completion_tokens"]
    rate = n_generated / max(duration, 0.01)  # floor avoids division by ~0 on instant replies
    print(f" [GGUF] {resolved} | {duration:.2f}s | {n_generated} tok | {rate:.1f} tok/s")
    return answer
def generate_with_adapter_vision(
    image: Image.Image,
    prompt: str,
    adapter_name: str,
    max_tokens: int = 10,
) -> str:
    """
    Run image + text inference through the vision model for adapter_name.

    Drop-in replacement for the earlier HuggingFace-based function (same
    signature). The image is sent as a PNG data URI followed by the text
    prompt in a single user message. Decoding is greedy (temperature 0.0),
    thinking segments are removed, and the cleaned answer is logged.
    """
    llm = _load_vision_model(adapter_name)
    image_part = {"type": "image_url", "image_url": {"url": _pil_to_data_uri(image)}}
    text_part = {"type": "text", "text": prompt}
    started = time.time()
    response = llm.create_chat_completion(
        messages=[{"role": "user", "content": [image_part, text_part]}],
        max_tokens=max_tokens,
        temperature=0.0,
    )
    duration = time.time() - started
    answer = exclude_thinking_component(response["choices"][0]["message"]["content"].strip())
    print(f" [GGUF Vision] {adapter_name} | {duration:.2f}s | output: {answer}")
    return answer