"""
scripts/sr4_server.py — SR4 standalone inference server.

Adapted from docs/aset/SR4_LLM_Coder_Test_Colab(SFT+CPT).ipynb.

Serves a single model loaded through Unsloth behind a minimal OpenAI-style
HTTP API: ``GET /health``, ``GET /v1/models`` and
``POST /v1/chat/completions`` (non-streaming, single choice).

Run:
    python scripts/sr4_server.py --model /workspace/models/sr4 --port 8081
    python scripts/sr4_server.py --model /workspace/models/sr4 --port 8081 --4bit
"""

import argparse
import gc
import json
import re
import sys

import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

parser = argparse.ArgumentParser()
parser.add_argument("--model", default="/workspace/models/sr4")
parser.add_argument("--port", type=int, default=8081)
parser.add_argument("--4bit", dest="load_4bit", action="store_true",
                    help="Load in 4-bit quantization (untuk GPU VRAM terbatas, misal L4 24GB)")
args = parser.parse_args()

print(f"[SR4] Loading model dari {args.model} ...", flush=True)

# Unsloth is imported only after argument parsing, so `--help` and invalid
# arguments exit before the import and model load happen.
from unsloth import FastLanguageModel  # noqa: E402 (imported after argparse)
from unsloth.chat_templates import get_chat_template  # noqa: E402

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=args.model,
    max_seq_length=4096,
    dtype=None,
    load_in_4bit=args.load_4bit,
)
FastLanguageModel.for_inference(model)
tokenizer = get_chat_template(tokenizer, chat_template="chatml")
# If the returned object wraps an inner tokenizer (exposes `.tokenizer`),
# use the inner one for templating, encoding and decoding.
_tok = tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer

used_gb = torch.cuda.memory_allocated() / 1e9
total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f"[SR4] Model loaded — GPU {used_gb:.1f} / {total_gb:.1f} GB", flush=True)

app = FastAPI(title="SR4 Inference Server")

# Defaults used when a request omits these fields or sends them as null.
_DEFAULT_MAX_TOKENS = 3500
_DEFAULT_TEMPERATURE = 0.0

# Reasoning blocks that a thinking-capable chat model may emit before its answer.
_THINK_RE = re.compile(r"<think>.*?</think>", flags=re.DOTALL)
# Opening ```/```json and closing ``` Markdown fences, matched per line.
_CODE_FENCE_RE = re.compile(r"^```(?:json)?\s*|\s*```$", flags=re.MULTILINE)
# A comma directly (modulo whitespace) before a closing brace/bracket.
_TRAILING_COMMA_RE = re.compile(r",\s*([}\]])")


def _strip_thinking(text: str) -> str:
    """Remove every ``<think>...</think>`` block from *text* and strip whitespace.

    Text without such blocks is returned unchanged apart from ``.strip()``.
    """
    return _THINK_RE.sub("", text).strip()


def _parse_json(text: str):
    """Try to extract a JSON object from free-form model output.

    Steps: drop ``<think>`` blocks, drop Markdown code fences, then decode the
    first JSON value that starts at the first ``{``. Trailing commas before
    ``}``/``]`` are removed only as a fallback when the text does not decode
    as-is. Doing the repair only on failure means string values that happen to
    contain ``", }"`` are not altered when the JSON is already valid.

    Args:
        text: Raw decoded model output.

    Returns:
        The decoded object, or ``None`` if there is no ``{`` or nothing decodes.
    """
    text = _strip_thinking(text)
    text = _CODE_FENCE_RE.sub("", text).strip()
    decoder = json.JSONDecoder()
    for candidate in (text, _TRAILING_COMMA_RE.sub(r"\1", text)):
        first = candidate.find("{")
        if first == -1:
            return None
        try:
            obj, _ = decoder.raw_decode(candidate, first)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue
        return obj
    return None


def _normalize_messages(messages_raw):
    """Copy ``role``/``content`` from OpenAI-style messages for the chat template.

    A string ``content`` that decodes to a JSON object is replaced by that dict,
    matching the dict-valued content used in training. Other strings, including
    JSON scalars like ``"true"`` or ``"42"``, are kept as the original string.
    """
    messages = []
    for m in messages_raw:
        content = m["content"]
        if isinstance(content, str):
            try:
                decoded = json.loads(content)
            except ValueError:
                pass
            else:
                if isinstance(decoded, dict):
                    content = decoded
        messages.append({"role": m["role"], "content": content})
    return messages


@app.get("/health")
def health():
    """Liveness probe; also reports the model path given on the command line."""
    return {"status": "ok", "model": args.model}


@app.get("/v1/models")
def list_models():
    """OpenAI-style model listing containing the single served model id."""
    return {
        "object": "list",
        "data": [{"id": "sr4-game", "object": "model"}],
    }


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completion (non-streaming, one choice).

    Reads ``messages``, ``max_tokens`` (default 3500) and ``temperature``
    (default 0.0, i.e. greedy) from the JSON body; other fields are ignored.
    If the generated text contains a decodable JSON object, that object is
    re-serialized as the message content; otherwise the raw text is returned.
    ``finish_reason`` is ``"length"`` when the token budget was exhausted
    without a stop token, else ``"stop"``.
    """
    # NOTE(review): model.generate() below is synchronous inside an async
    # handler, so it blocks the event loop; other requests (including /health)
    # wait until the generation finishes.
    gc.collect()
    torch.cuda.empty_cache()

    body = await request.json()
    messages_raw = body.get("messages", [])
    # OpenAI clients may send explicit nulls; treat them like missing fields.
    max_tokens = body.get("max_tokens")
    if max_tokens is None:
        max_tokens = _DEFAULT_MAX_TOKENS
    temperature = body.get("temperature")
    if temperature is None:
        temperature = _DEFAULT_TEMPERATURE

    messages = _normalize_messages(messages_raw)

    text = _tok.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = _tok(text, return_tensors="pt", add_special_tokens=False).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    # Stop on the tokenizer EOS and on the ChatML end-of-turn token. A token
    # missing from the vocab maps to the unk id, which must not become a stop.
    eos_ids = [_tok.eos_token_id]
    im_end_id = _tok.convert_tokens_to_ids("<|im_end|>")
    if im_end_id and im_end_id not in (_tok.eos_token_id, _tok.unk_token_id):
        eos_ids.append(im_end_id)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=temperature > 0,
            # temperature must be unset for greedy decoding
            temperature=temperature if temperature > 0 else None,
            repetition_penalty=1.15,
            pad_token_id=_tok.eos_token_id,
            eos_token_id=eos_ids,
        )

    new_tokens = outputs[0][prompt_len:]
    raw_text = _tok.decode(new_tokens, skip_special_tokens=True).strip()
    completion_len = len(new_tokens)

    stopped_on_eos = completion_len > 0 and int(new_tokens[-1]) in eos_ids
    finish_reason = (
        "length" if completion_len >= max_tokens and not stopped_on_eos else "stop"
    )

    parsed = _parse_json(raw_text)
    # Compare against None so a valid empty object "{}" is still normalized.
    response_text = (
        json.dumps(parsed, ensure_ascii=False) if parsed is not None else raw_text
    )

    return JSONResponse({
        "id": "chatcmpl-sr4",
        "object": "chat.completion",
        "model": "sr4-game",
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": response_text},
            "finish_reason": finish_reason,
        }],
        "usage": {
            "prompt_tokens": prompt_len,
            "completion_tokens": completion_len,
            "total_tokens": prompt_len + completion_len,
        },
    })


if __name__ == "__main__":
    print(f"[SR4] Serving on http://0.0.0.0:{args.port}", flush=True)
    uvicorn.run(app, host="0.0.0.0", port=args.port, log_level="warning")