File size: 4,938 Bytes
6e23cd8
 
 
 
 
 
 
 
 
 
 
 
6d02e4d
 
6e23cd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from __future__ import annotations

import json
import os
import sys
import urllib.parse
import urllib.request
from pathlib import Path

from huggingface_hub import hf_hub_download


# Every setting below is read once at import time from the environment
# variable of the same name; the second argument is the fallback used when
# that variable is unset. Values stay strings because they are handed to
# llama-server as command-line arguments (see build_command).

# --- Model artifacts ---
MODEL_REPO = os.environ.get(
    "MODEL_REPO",
    "yuxinlu1/gemma-4-12B-agentic-fable5-composer2.5-v2-3.5x-tau2-GGUF",
)
MODEL_FILE = os.environ.get("MODEL_FILE", "gemma4-v2-Q4_K_M.gguf")
MODEL_DIR = Path(os.environ.get("MODEL_DIR", "/data/models/gemma4-coder"))
CHAT_TEMPLATE_FILE = Path(
    os.environ.get("CHAT_TEMPLATE_FILE", "/data/models/gemma4-coder/chat_template.jinja")
)

# --- Server binary and listen address ---
LLAMA_SERVER_BIN = os.environ.get("LLAMA_SERVER_BIN", "/opt/llama.cpp/llama-server")
LLAMA_HOST = os.environ.get("LLAMA_HOST", "0.0.0.0")
LLAMA_PORT = os.environ.get("LLAMA_PORT", "7860")

# --- Runtime tuning ---
# A value of "default" makes build_command leave the matching flag off the
# command line, so llama-server applies its own built-in default.
THREADS = os.environ.get("THREADS", "4")
CTX_SIZE = os.environ.get("CTX_SIZE", "2048")
BATCH_SIZE = os.environ.get("BATCH_SIZE", "default")
UBATCH_SIZE = os.environ.get("UBATCH_SIZE", "default")
GPU_LAYERS = os.environ.get("GPU_LAYERS", "0")
FLASH_ATTN = os.environ.get("FLASH_ATTN", "default")
CACHE_TYPE_K = os.environ.get("CACHE_TYPE_K", "default")
CACHE_TYPE_V = os.environ.get("CACHE_TYPE_V", "default")

# --- Sampling defaults passed to llama-server ---
TEMPERATURE = os.environ.get("TEMPERATURE", "0.2")
TOP_P = os.environ.get("TOP_P", "0.95")
TOP_K = os.environ.get("TOP_K", "64")
REPEAT_PENALTY = os.environ.get("REPEAT_PENALTY", "1.08")


def log(message: str) -> None:
    """Print *message* to stdout behind a ``[startup]`` prefix and flush."""
    line = "[startup] {}".format(message)
    # Flush on every call so the line is written out immediately.
    print(line, flush=True)


def download_model() -> str:
    """Return the local path of the model file, downloading it when absent.

    ``MODEL_DIR`` is created if necessary. When ``MODEL_DIR / MODEL_FILE``
    already exists it is returned as-is; otherwise ``MODEL_FILE`` is fetched
    from ``MODEL_REPO`` into ``MODEL_DIR`` via ``hf_hub_download`` and the path
    reported by that call is returned.
    """
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    cached_model = MODEL_DIR / MODEL_FILE

    if not cached_model.exists():
        log(f"Downloading {MODEL_REPO}/{MODEL_FILE}")
        fetched_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            local_dir=str(MODEL_DIR),
        )
        log(f"Model ready: {fetched_path}")
        return fetched_path

    # Only existence is checked; the cached file is not validated further.
    log(f"Using cached model: {cached_model}")
    return str(cached_model)


def download_chat_template() -> str | None:
    """Return the path of a chat template file for llama-server, or ``None``.

    A non-empty file already present at ``CHAT_TEMPLATE_FILE`` is reused.
    Otherwise the metadata for ``MODEL_REPO`` is fetched from the Hugging Face
    API and its ``gguf.chat_template`` field is saved to ``CHAT_TEMPLATE_FILE``.

    The template is optional, so every failure here (network or decode error,
    unexpected metadata shape, unwritable destination) is logged and reported
    as ``None`` instead of being raised.

    Returns:
        The template path as a string, or ``None`` when no template could be
        obtained.
    """
    if CHAT_TEMPLATE_FILE.exists() and CHAT_TEMPLATE_FILE.stat().st_size > 0:
        log(f"Using cached chat template: {CHAT_TEMPLATE_FILE}")
        return str(CHAT_TEMPLATE_FILE)

    encoded_repo = urllib.parse.quote(MODEL_REPO, safe="/")
    api_url = f"https://huggingface.co/api/models/{encoded_repo}"
    log("Fetching chat template from model metadata")

    # Deliberately broad: any transport or decoding problem just means there is
    # no external template to use.
    try:
        with urllib.request.urlopen(api_url, timeout=30) as response:
            metadata = json.loads(response.read().decode("utf-8"))
    except Exception as exc:
        log(f"Could not fetch chat template metadata: {exc}")
        return None

    # The response is external data: tolerate any shape other than
    # {"gguf": {"chat_template": "<text>"}} instead of raising.
    gguf_info = metadata.get("gguf") if isinstance(metadata, dict) else None
    template = gguf_info.get("chat_template") if isinstance(gguf_info, dict) else None
    if not template or not isinstance(template, str):
        log("No chat template found in model metadata; llama-server will use GGUF metadata")
        return None

    # Write to a sibling file and rename it into place so an interrupted write
    # can never leave a truncated, non-empty template that the cache check at
    # the top of this function would accept on the next start.
    partial_file = CHAT_TEMPLATE_FILE.with_name(CHAT_TEMPLATE_FILE.name + ".tmp")
    try:
        CHAT_TEMPLATE_FILE.parent.mkdir(parents=True, exist_ok=True)
        partial_file.write_text(template, encoding="utf-8")
        partial_file.replace(CHAT_TEMPLATE_FILE)
    except OSError as exc:
        log(f"Could not save chat template: {exc}")
        return None

    log(f"Chat template ready: {CHAT_TEMPLATE_FILE}")
    return str(CHAT_TEMPLATE_FILE)


def build_command(model_path: str, template_path: str | None) -> list[str]:
    """Assemble the llama-server argument vector from the module settings.

    Args:
        model_path: Path of the model file, passed with ``-m``.
        template_path: Chat template file passed with ``--chat-template-file``;
            the flag is omitted when this is ``None`` or empty.

    Returns:
        The command as a list of strings with ``LLAMA_SERVER_BIN`` first.
    """
    # Placeholder spellings that mean "do not pass this flag".
    placeholders = frozenset({"", "default", "auto", "none", "off"})
    # For -fa, "off" is a real setting (llama-server takes on|off|auto), so it
    # must reach the server; treating it as a placeholder would silently leave
    # flash attention on the server's own default.
    flash_attn_placeholders = placeholders - {"off"}

    cmd = [
        LLAMA_SERVER_BIN,
        "-m",
        model_path,
        "--host",
        LLAMA_HOST,
        "--port",
        LLAMA_PORT,
        "--threads",
        THREADS,
        "--ctx-size",
        CTX_SIZE,
        "--n-gpu-layers",
        GPU_LAYERS,
        "--parallel",
        "1",
        "--cont-batching",
        "--temp",
        TEMPERATURE,
        "--top-p",
        TOP_P,
        "--top-k",
        TOP_K,
        "--repeat-penalty",
        REPEAT_PENALTY,
    ]

    optional_flags = (
        ("--batch-size", BATCH_SIZE, placeholders),
        ("--ubatch-size", UBATCH_SIZE, placeholders),
        ("--cache-type-k", CACHE_TYPE_K, placeholders),
        ("--cache-type-v", CACHE_TYPE_V, placeholders),
        ("-fa", FLASH_ATTN, flash_attn_placeholders),
    )
    for flag, raw_value, skipped in optional_flags:
        # Forward the stripped text: the placeholder test ignores surrounding
        # whitespace, so the argument handed to the server should as well.
        value = raw_value.strip()
        if value.lower() not in skipped:
            cmd.extend([flag, value])

    if template_path:
        cmd.extend(["--chat-template-file", template_path])

    return cmd


def main() -> None:
    """Prepare the environment, fetch the model assets and exec llama-server.

    ``os.execvpe`` replaces the current process, so this function does not
    return when the server binary starts successfully.
    """
    # Put the server binary's directory at the front of LD_LIBRARY_PATH,
    # keeping whatever was already configured behind it (presumably so shared
    # libraries shipped next to the binary are found -- confirm against image).
    server_dir = str(Path(LLAMA_SERVER_BIN).parent)
    inherited_path = os.environ.get("LD_LIBRARY_PATH")
    if inherited_path:
        os.environ["LD_LIBRARY_PATH"] = f"{server_dir}:{inherited_path}"
    else:
        os.environ["LD_LIBRARY_PATH"] = server_dir

    # Default the math-library thread counts to THREADS without overriding
    # values the environment already provides.
    for variable in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS"):
        os.environ.setdefault(variable, THREADS)

    # The model is resolved before the template, then both feed the command.
    resolved_model = download_model()
    resolved_template = download_chat_template()
    cmd = build_command(resolved_model, resolved_template)

    log("Starting native llama.cpp web UI")
    log(" ".join(cmd))
    os.execvpe(cmd[0], cmd, os.environ)


if __name__ == "__main__":
    # Script entry point: report any startup failure on stderr with a
    # "[fatal]" prefix, then re-raise so the traceback and a non-zero exit
    # status are preserved.
    try:
        main()
    except Exception as error:
        sys.stderr.write(f"[fatal] {error}\n")
        sys.stderr.flush()
        raise