"""
scripts/sr4_server.py — SR4 Standalone Inference Server
Diadaptasi dari docs/aset/SR4_LLM_Coder_Test_Colab(SFT+CPT).ipynb
Jalankan:
python scripts/sr4_server.py --model /workspace/models/sr4 --port 8081
python scripts/sr4_server.py --model /workspace/models/sr4 --port 8081 --4bit
"""
import argparse
import gc
import json
import re
import sys
import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
# Command-line interface. Parsing happens at module level, so `args` is
# available to everything defined below.
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="/workspace/models/sr4")
parser.add_argument("--port", default=8081, type=int)
parser.add_argument(
    "--4bit",
    action="store_true",
    # "--4bit" cannot be turned into an attribute name, so the parsed flag
    # is stored under an explicit destination.
    dest="load_4bit",
    help="Load in 4-bit quantization (untuk GPU VRAM terbatas, misal L4 24GB)",
)
args = parser.parse_args()
print(f"[SR4] Loading model dari {args.model} ...", flush=True)

# Imported here rather than at the top of the file so that it happens after
# argument parsing (hence the E402 suppressions).
from unsloth import FastLanguageModel  # noqa: E402
from unsloth.chat_templates import get_chat_template  # noqa: E402

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=args.model,
    max_seq_length=4096,
    dtype=None,
    load_in_4bit=args.load_4bit,
)
FastLanguageModel.for_inference(model)
tokenizer = get_chat_template(tokenizer, chat_template="chatml")

# If the returned object exposes an inner `.tokenizer`, use that one for
# templating/encoding/decoding; otherwise use the object itself.
# NOTE(review): presumably covers processor-style wrappers — confirm.
_tok = getattr(tokenizer, "tokenizer", tokenizer)

# Report memory use in decimal gigabytes; total capacity is read from CUDA device 0.
total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
used_gb = torch.cuda.memory_allocated() / 1e9
print(f"[SR4] Model loaded — GPU {used_gb:.1f} / {total_gb:.1f} GB", flush=True)

app = FastAPI(title="SR4 Inference Server")
def _strip_thinking(text: str) -> str:
"""Hapus blok ... jika model mengeluarkannya."""
return re.sub(r".*?", "", text, flags=re.DOTALL).strip()
def _parse_json(text: str):
    """Try to extract the first JSON object from free-form text.

    The text is first passed through ``_strip_thinking`` and has Markdown
    code-fence markers removed. Decoding starts at the first ``{`` and ignores
    whatever follows the end of that object.

    Args:
        text: Free-form text that may contain a JSON object.

    Returns:
        The decoded object, or ``None`` when the text contains no ``{`` or
        nothing decodable starts there.
    """
    text = _strip_thinking(text)
    text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text, flags=re.MULTILINE).strip()

    decoder = json.JSONDecoder()
    # Pass 1 decodes the text untouched. Pass 2 retries after deleting
    # trailing commas before '}' / ']'. The repair is regex-based and cannot
    # tell structure from string contents, so applying it to already-valid
    # JSON could silently alter string values such as "a, }" — hence it is
    # only a fallback.
    for repair_trailing_commas in (False, True):
        candidate = text
        if repair_trailing_commas:
            candidate = re.sub(r",\s*([}\]])", r"\1", text)
        start = candidate.find("{")
        if start == -1:
            return None
        try:
            obj, _end = decoder.raw_decode(candidate, start)
        except (ValueError, RecursionError):
            # json.JSONDecodeError is a ValueError; RecursionError covers
            # pathologically deep nesting.
            continue
        return obj
    return None
@app.get("/health")
def health():
    """Report a fixed "ok" status together with the configured model path."""
    payload = {"status": "ok"}
    payload["model"] = args.model
    return payload
@app.get("/v1/models")
def list_models():
    """Return the model listing: a single fixed entry with id ``sr4-game``."""
    entry = {"id": "sr4-game", "object": "model"}
    return {"object": "list", "data": [entry]}
def _bad_request(message: str) -> JSONResponse:
    """Build an HTTP 400 response carrying an ``error`` object with ``message``."""
    return JSONResponse(
        {"error": {"message": message, "type": "invalid_request_error"}},
        status_code=400,
    )


def _normalize_messages(messages_raw):
    """Convert request messages into the list passed to the chat template.

    String content that holds a JSON object or array is decoded into the
    corresponding Python structure (the structured form is what the model was
    trained on, per the original author's note). Any other content, including
    strings that merely look like JSON scalars such as ``"42"`` or ``"null"``,
    is passed through unchanged.

    Args:
        messages_raw: The ``messages`` value from the request body.

    Returns:
        A new list of ``{"role": ..., "content": ...}`` dicts.

    Raises:
        ValueError: If ``messages_raw`` is not a list, or an entry is not an
            object containing both ``role`` and ``content``.
    """
    if not isinstance(messages_raw, list):
        raise ValueError("'messages' must be a list")

    messages = []
    for index, message in enumerate(messages_raw):
        if (
            not isinstance(message, dict)
            or "role" not in message
            or "content" not in message
        ):
            raise ValueError(
                f"messages[{index}] must be an object with 'role' and 'content'"
            )
        content = message["content"]
        if isinstance(content, str):
            try:
                decoded = json.loads(content)
            except (ValueError, RecursionError):
                decoded = None
            # Only structured payloads replace the text; a plain message that
            # happens to parse as a JSON scalar must stay a string.
            if isinstance(decoded, (dict, list)):
                content = decoded
        messages.append({"role": message["role"], "content": content})
    return messages


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """Generate one chat completion and return it as a ``chat.completion`` payload.

    Request body fields read:
        messages: List of ``{"role", "content"}`` objects (default ``[]``).
        max_tokens: Cap on newly generated tokens (default 3500).
        temperature: Sampling temperature (default 0.0). Values ``<= 0``
            disable sampling.
    An explicit ``null`` for ``max_tokens`` or ``temperature`` is treated like
    an omitted field.

    Returns:
        A ``JSONResponse``. On success the assistant content is the
        re-serialized JSON object when ``_parse_json`` finds one in the output,
        otherwise the raw decoded text. ``finish_reason`` is ``"length"`` when
        generation used the full ``max_tokens`` budget without ending on a stop
        token, else ``"stop"``. A body that is not a JSON object, or malformed
        ``messages``, yields HTTP 400 with an ``error`` object.
    """
    # Collect garbage and release PyTorch's unused cached GPU memory before
    # allocating tensors for this request.
    gc.collect()
    torch.cuda.empty_cache()

    try:
        body = await request.json()
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
        return _bad_request("Request body is not valid JSON")
    if not isinstance(body, dict):
        return _bad_request("Request body must be a JSON object")

    try:
        messages = _normalize_messages(body.get("messages", []))
    except ValueError as exc:
        return _bad_request(str(exc))

    max_tokens = body.get("max_tokens")
    if max_tokens is None:
        max_tokens = 3500
    temperature = body.get("temperature")
    if temperature is None:
        temperature = 0.0

    # NOTE(review): `enable_thinking` is forwarded to the chat template as an
    # extra variable; whether the active template honors it is not visible
    # here — confirm.
    text = _tok.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    # The rendered template already carries its special tokens, so the
    # tokenizer is told not to add another set.
    inputs = _tok(text, return_tensors="pt", add_special_tokens=False).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    # Stop on the tokenizer's EOS and, when it is a different id, on <|im_end|>.
    eos_ids = [_tok.eos_token_id]
    im_end_id = _tok.convert_tokens_to_ids("<|im_end|>")
    if im_end_id and im_end_id != _tok.eos_token_id:
        eos_ids.append(im_end_id)

    sampling = temperature > 0
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=sampling,
            # Passed as None when sampling is disabled.
            temperature=temperature if sampling else None,
            repetition_penalty=1.15,
            pad_token_id=_tok.eos_token_id,
            eos_token_id=eos_ids,
        )

    # generate() returns prompt + continuation; keep only the continuation.
    new_tokens = outputs[0][prompt_len:]
    raw_text = _tok.decode(new_tokens, skip_special_tokens=True).strip()
    completion_len = len(new_tokens)

    # The output was truncated only if the whole budget was consumed and the
    # last generated token is not one of the stop tokens.
    ended_on_stop_token = completion_len > 0 and int(new_tokens[-1]) in eos_ids
    if ended_on_stop_token or completion_len < max_tokens:
        finish_reason = "stop"
    else:
        finish_reason = "length"

    parsed = _parse_json(raw_text)
    # Compare against None so that a valid but empty object ({}) still counts
    # as successfully parsed.
    if parsed is not None:
        response_text = json.dumps(parsed, ensure_ascii=False)
    else:
        response_text = raw_text

    return JSONResponse({
        "id": "chatcmpl-sr4",
        "object": "chat.completion",
        "model": "sr4-game",
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": response_text},
            "finish_reason": finish_reason,
        }],
        "usage": {
            "prompt_tokens": prompt_len,
            "completion_tokens": completion_len,
            "total_tokens": prompt_len + completion_len,
        },
    })
if __name__ == "__main__":
    # Announce the listen address, then hand control to uvicorn, which binds
    # on all interfaces at the requested port.
    print(f"[SR4] Serving on http://0.0.0.0:{args.port}", flush=True)
    uvicorn.run(
        app,
        port=args.port,
        host="0.0.0.0",
        log_level="warning",
    )