# fable5 / app.py
# Last commit: 6d02e4d "Update app.py" by zhengr (verified) - 4.94 kB
from __future__ import annotations
import json
import os
import sys
import urllib.parse
import urllib.request
from pathlib import Path
from huggingface_hub import hf_hub_download
# --- Runtime configuration -------------------------------------------------
# Every setting can be overridden by an environment variable of the same name.
# Values are kept as strings because they are placed directly into the
# llama-server argument list. Optional tuning knobs default to the sentinel
# "default", which build_command treats as "do not pass this flag".

# Hugging Face repository and GGUF file to serve.
MODEL_REPO = os.getenv("MODEL_REPO", "yuxinlu1/gemma-4-12B-agentic-fable5-composer2.5-v2-3.5x-tau2-GGUF")
MODEL_FILE = os.getenv("MODEL_FILE", "gemma4-v2-Q4_K_M.gguf")
# Download target and on-disk cache for the model.
MODEL_DIR = Path(os.getenv("MODEL_DIR", "/data/models/gemma4-coder"))
# Cached chat template. The default lives inside MODEL_DIR so that relocating
# MODEL_DIR also relocates the template cache.
CHAT_TEMPLATE_FILE = Path(
    os.getenv("CHAT_TEMPLATE_FILE", str(MODEL_DIR / "chat_template.jinja"))
)

# Server binary and network binding.
LLAMA_SERVER_BIN = os.getenv("LLAMA_SERVER_BIN", "/opt/llama.cpp/llama-server")
LLAMA_HOST = os.getenv("LLAMA_HOST", "0.0.0.0")
LLAMA_PORT = os.getenv("LLAMA_PORT", "7860")

# Compute settings. THREADS is also used to cap the OpenMP/BLAS thread pools.
THREADS = os.getenv("THREADS", "4")
CTX_SIZE = os.getenv("CTX_SIZE", "2048")
BATCH_SIZE = os.getenv("BATCH_SIZE", "default")
UBATCH_SIZE = os.getenv("UBATCH_SIZE", "default")
GPU_LAYERS = os.getenv("GPU_LAYERS", "0")  # passed as --n-gpu-layers
FLASH_ATTN = os.getenv("FLASH_ATTN", "default")
CACHE_TYPE_K = os.getenv("CACHE_TYPE_K", "default")
CACHE_TYPE_V = os.getenv("CACHE_TYPE_V", "default")

# Sampling defaults (--temp / --top-p / --top-k / --repeat-penalty).
TEMPERATURE = os.getenv("TEMPERATURE", "0.2")
TOP_P = os.getenv("TOP_P", "0.95")
TOP_K = os.getenv("TOP_K", "64")
REPEAT_PENALTY = os.getenv("REPEAT_PENALTY", "1.08")
def log(message: str) -> None:
    """Write a ``[startup]``-prefixed progress line to stdout.

    The stream is flushed right away so each line is visible immediately,
    even though the process later replaces itself with llama-server via exec.
    """
    line = "[startup] {}".format(message)
    print(line, flush=True)
def download_model() -> str:
    """Return a local path to the GGUF model, downloading it on first use.

    The model is expected at ``MODEL_DIR / MODEL_FILE``. A non-empty regular
    file there is reused as-is. Otherwise the file is fetched from
    ``MODEL_REPO`` with ``hf_hub_download`` into ``MODEL_DIR``.

    Returns:
        Filesystem path of the model file, as a string.
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    local_file = MODEL_DIR / MODEL_FILE
    # Require a non-empty regular file (mirroring the chat-template cache
    # check). A directory or zero-byte placeholder at this path is not a
    # usable model and must not be handed to llama-server.
    if local_file.is_file() and local_file.stat().st_size > 0:
        log(f"Using cached model: {local_file}")
        return str(local_file)
    log(f"Downloading {MODEL_REPO}/{MODEL_FILE}")
    model_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=MODEL_FILE,
        local_dir=str(MODEL_DIR),
    )
    log(f"Model ready: {model_path}")
    return model_path
def download_chat_template() -> str | None:
    """Return a path to a Jinja chat template for llama-server, or ``None``.

    A non-empty file at ``CHAT_TEMPLATE_FILE`` is reused. Otherwise the
    ``gguf.chat_template`` field of the Hugging Face model API entry for
    ``MODEL_REPO`` is fetched and written to ``CHAT_TEMPLATE_FILE``.

    This step is best-effort. Any failure is logged and ``None`` is returned:
    the network, an unexpected metadata shape, a missing template, or an
    unwritable cache location. The caller then starts llama-server without
    ``--chat-template-file``.

    Returns:
        The template path as a string, or ``None`` if no template is available.
    """
    if CHAT_TEMPLATE_FILE.is_file() and CHAT_TEMPLATE_FILE.stat().st_size > 0:
        log(f"Using cached chat template: {CHAT_TEMPLATE_FILE}")
        return str(CHAT_TEMPLATE_FILE)

    encoded_repo = urllib.parse.quote(MODEL_REPO, safe="/")
    api_url = f"https://huggingface.co/api/models/{encoded_repo}"
    log("Fetching chat template from model metadata")
    try:
        with urllib.request.urlopen(api_url, timeout=30) as response:
            metadata = json.loads(response.read().decode("utf-8"))
    except Exception as exc:
        # Deliberately broad: network, HTTP, decoding and JSON errors must not
        # block startup, since this template is optional.
        log(f"Could not fetch chat template metadata: {exc}")
        return None

    # Validate the payload shape instead of crashing with AttributeError or
    # TypeError. The response must be an object, "gguf" must be an object, and
    # the template must be a non-empty string.
    gguf = metadata.get("gguf") if isinstance(metadata, dict) else None
    template = gguf.get("chat_template") if isinstance(gguf, dict) else None
    if not isinstance(template, str) or not template:
        log("No chat template found in model metadata; llama-server will use GGUF metadata")
        return None

    try:
        CHAT_TEMPLATE_FILE.parent.mkdir(parents=True, exist_ok=True)
        # Write to a sibling temp file, then rename it into place. An
        # interrupted write then never leaves a truncated (but non-empty)
        # template that the cache check above would keep reusing.
        tmp_file = CHAT_TEMPLATE_FILE.with_name(CHAT_TEMPLATE_FILE.name + ".tmp")
        tmp_file.write_text(template, encoding="utf-8")
        tmp_file.replace(CHAT_TEMPLATE_FILE)
    except OSError as exc:
        log(f"Could not cache chat template at {CHAT_TEMPLATE_FILE}: {exc}; llama-server will use GGUF metadata")
        return None

    log(f"Chat template ready: {CHAT_TEMPLATE_FILE}")
    return str(CHAT_TEMPLATE_FILE)
def build_command(model_path: str, template_path: str | None) -> list[str]:
    """Assemble the ``llama-server`` argument vector.

    Core settings are always passed: host, port, threads, context size, GPU
    layers and sampling parameters. Optional tuning knobs are passed only when
    their configured value is not a "use the server default" sentinel, so
    llama-server's own defaults apply otherwise.

    Args:
        model_path: Path of the GGUF model to load (``-m``).
        template_path: Jinja chat template file, or ``None`` to let
            llama-server use the template from the GGUF metadata.

    Returns:
        The full argv list; element 0 is ``LLAMA_SERVER_BIN``.
    """
    # Values that mean "do not pass this flag at all".
    default_sentinels = {"", "default", "auto", "none", "off"}

    def has_custom_value(value: str) -> bool:
        return value.strip().lower() not in default_sentinels

    cmd = [
        LLAMA_SERVER_BIN,
        "-m",
        model_path,
        "--host",
        LLAMA_HOST,
        "--port",
        LLAMA_PORT,
        "--threads",
        THREADS,
        "--ctx-size",
        CTX_SIZE,
        "--n-gpu-layers",
        GPU_LAYERS,
        "--parallel",
        "1",
        "--cont-batching",
        "--temp",
        TEMPERATURE,
        "--top-p",
        TOP_P,
        "--top-k",
        TOP_K,
        "--repeat-penalty",
        REPEAT_PENALTY,
    ]

    def add_optional_pair(flag: str, value: str) -> None:
        # Pass the stripped value. Stray whitespace from an environment
        # variable (e.g. " q8_0") would otherwise reach llama-server verbatim.
        if has_custom_value(value):
            cmd.extend([flag, value.strip()])

    add_optional_pair("--batch-size", BATCH_SIZE)
    add_optional_pair("--ubatch-size", UBATCH_SIZE)
    add_optional_pair("--cache-type-k", CACHE_TYPE_K)
    add_optional_pair("--cache-type-v", CACHE_TYPE_V)

    # -fa is used here in its value-taking form (-fa <mode>). Therefore "off"
    # and "auto" are real choices, not sentinels. Dropping an explicit "off"
    # would leave the server on its own default mode, which is not
    # necessarily off.
    flash_attn = FLASH_ATTN.strip().lower()
    if flash_attn not in {"", "default", "none"}:
        cmd.extend(["-fa", flash_attn])

    if template_path:
        cmd.extend(["--chat-template-file", template_path])
    return cmd
def main() -> None:
    """Prepare the environment, fetch model assets, and exec llama-server.

    On success this never returns: the Python process is replaced by
    llama-server via ``os.execvpe``.

    Raises:
        FileNotFoundError: If ``LLAMA_SERVER_BIN`` does not resolve to an
            executable. This is checked before the (potentially large) model
            download, so a misconfigured binary path fails fast.
    """
    import shlex
    import shutil

    # Paths containing a directory are checked as given; bare names are
    # looked up on PATH.
    server_bin = shutil.which(LLAMA_SERVER_BIN)
    if server_bin is None:
        raise FileNotFoundError(
            f"llama-server binary not found or not executable: {LLAMA_SERVER_BIN}"
        )

    # Put the binary's directory first on the loader path, presumably for
    # shared libraries shipped next to the llama.cpp build. Using the absolute
    # directory avoids adding "." (the CWD) when LLAMA_SERVER_BIN is a bare
    # command name.
    binary_dir = str(Path(server_bin).absolute().parent)
    existing_library_path = os.environ.get("LD_LIBRARY_PATH")
    os.environ["LD_LIBRARY_PATH"] = (
        f"{binary_dir}:{existing_library_path}" if existing_library_path else binary_dir
    )

    # Cap the OpenMP/BLAS thread pools at THREADS unless already configured.
    for var in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS"):
        os.environ.setdefault(var, THREADS)

    model_path = download_model()
    template_path = download_chat_template()
    cmd = build_command(model_path, template_path)

    log("Starting native llama.cpp web UI")
    # Shell-quote the logged command so arguments with spaces or special
    # characters are shown unambiguously (and can be copy-pasted).
    log(shlex.join(cmd))
    os.execvpe(cmd[0], cmd, os.environ)
if __name__ == "__main__":
    try:
        main()
    except Exception as err:
        # Emit a short, greppable one-line summary first, then re-raise so the
        # full traceback is still printed and the exit status is non-zero.
        sys.stderr.write(f"[fatal] {err}\n")
        sys.stderr.flush()
        raise