"""Nemotron-4B Game Master. At each round boundary the engine hands us a `round_summary`. We ask NVIDIA's Nemotron-Mini-4B-Instruct to return a JSON decision covering: - rewards: {player_id: gold} - next_round: {boss_hp, boss_damage, minion_hp, minion_count, spawn_interval} - message: short flavor line shown to players - reasoning: why (kept in the trace) This is *burst* GPU work that fits ZeroGPU's `@spaces.GPU` model perfectly: we grab the GPU for one short inference per round, then release it. If the model or GPU is unavailable (e.g. local dev / CI) we fall back to deterministic logic so the game always keeps running. Every decision is written to `traces/` as a self-contained agent trace. """ from __future__ import annotations import json import os import re import time import uuid from game.engine import ALL_CARD_IDS, ALL_PATTERN_IDS, BLESSINGS, MINION_TYPES from game import skills as skillmod from game import trace_store MODEL_ID = "nvidia/Nemotron-Mini-4B-Instruct" TRACE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "traces") os.makedirs(TRACE_DIR, exist_ok=True) # Background sync of every trace file to a HF dataset repo (no-op if not configured). trace_store.start(TRACE_DIR) # Lazily-initialised globals (loaded once, on first GPU call). _tokenizer = None _model = None _load_failed = False # In-memory ring of recent traces for the dashboard. RECENT_TRACES: list[dict] = [] _MAX_RECENT = 25 # True when running inside a Hugging Face Space (ZeroGPU or otherwise). _ON_SPACES = bool(os.environ.get("SPACE_ID")) try: # `spaces` only exists on HF infra; degrade gracefully elsewhere. import spaces # type: ignore def _gpu(fn): return spaces.GPU(duration=60)(fn) except Exception: # pragma: no cover - local dev path def _gpu(fn): return fn SYSTEM_PROMPT = ( "You are the Game Master / Director AI for HuggingWizards, a co-op pixel " "survivors-arena where wizards fight waves of enemies and a boss. 
After each " "wave you direct: rewards, the next wave's difficulty, the enemy mix, and the " "pool of level-up cards players may be offered. Reply with ONE JSON object and " "nothing else. Schema:\n" '{"message": str (<=120 chars, in-character narration),' ' "reasoning": str (one sentence),' ' "rewards": {player_id: int gold 0-300},' ' "blessings": {player_id: blessing_id} (optional, bless 0-3 wizards),' ' "next_round": {"boss_hp": int, "boss_damage": int, "minion_hp": int,' ' "minion_count": int, "spawn_interval": float, "boss_aggro": float 0.7-3,' ' "boss_attack_speed": float 0.5-2.0, "boss_pattern": pattern_id,' ' "wave": {"grunt": float, "fast": float, "tank": float}},' ' "card_pool": [card_id, ...]}\n' f"Enemy archetypes (wave weights, must sum > 0): {list(MINION_TYPES)}.\n" f"Boss attack patterns: {ALL_PATTERN_IDS}. Switch the pattern between rounds " "to keep fights fresh — sniper punishes kiting, artillery punishes camping, " "swarm floods the arena, berserker charges relentlessly.\n" f"Blessings: {BLESSINGS}. Bless wizards who earned it — surviving a brutal " "wave, clutch plays, or to help a struggling wizard back on their feet " "(extra_life / full_heal). Auras last one wave.\n" f"Valid card_pool ids (offer 4-8): {ALL_CARD_IDS}.\n" "boss_attack_speed sets how fast the boss attacks next wave (1.0 = normal). " "Check every player's hp_pct and lives_left: if the party is badly hurt, " "slow the boss (<1.0) so they can recover; if they are healthy, speed it up " "(>1.0). MERCY DECAYS: each round the user prompt states the minimum you may " "set — in later waves you can no longer slow the boss to protect them. " "Values outside the allowed range are clamped.\n" "Reward players for damage dealt, kills, and surviving. Scale difficulty up " "after a victory and ease it slightly after a defeat. Introduce tougher enemy " "archetypes as rounds progress. Curate cards to keep the run fun and winnable " "for the number of players." 
) def _build_user_prompt(summary: dict, prev_cfg: dict) -> str: rnd = summary.get("round") or 0 return ( "Current next-round config (you may adjust): " + json.dumps(prev_cfg) + "\nRound that just ended:\n" + json.dumps(summary, indent=2) + f"\nMercy floor this round: boss_attack_speed must be between " + f"{_mercy_floor(rnd)} and {MERCY_MAX}." + "\nReturn the JSON decision now." ) def _load_model(): """Load tokenizer + model, optionally quantized. HUGGINGWIZARDS_QUANT = none | 4bit | 8bit | auto (default). `auto` picks 4-bit on small local GPUs (<12 GB VRAM) and bf16 everywhere else — ZeroGPU's H200 runs the 4B model comfortably in bf16, where it is also faster per token than bitsandbytes 4-bit. """ import torch from transformers import AutoModelForCausalLM, AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) has_cuda = torch.cuda.is_available() quant = os.environ.get("HUGGINGWIZARDS_QUANT", "auto").lower() if quant == "auto": if has_cuda and not _ON_SPACES: vram = torch.cuda.get_device_properties(0).total_memory quant = "4bit" if vram < 12 * 1024**3 else "none" else: quant = "none" if quant in ("4bit", "8bit") and has_cuda: from transformers import BitsAndBytesConfig qcfg = ( BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16) if quant == "4bit" else BitsAndBytesConfig(load_in_8bit=True) ) model = AutoModelForCausalLM.from_pretrained( MODEL_ID, quantization_config=qcfg, device_map={"": 0} ) else: model = AutoModelForCausalLM.from_pretrained( MODEL_ID, dtype=torch.bfloat16 if has_cuda else torch.float32 ) if has_cuda: model = model.to("cuda") # ZeroGPU-safe (device_map="auto" is not) model.eval() print(f"[gamemaster] loaded {MODEL_ID} (quant={quant}, cuda={has_cuda})") return tokenizer, model def _ensure_model(): global _tokenizer, _model, _load_failed if _model is not None or _load_failed: return _model is not None if os.environ.get("HUGGINGWIZARDS_NO_LLM"): _load_failed = True # force deterministic 
fallback (local dev / CI) return False try: _tokenizer, _model = _load_model() return True except Exception as e: # pragma: no cover print(f"[gamemaster] model load failed, using fallback: {e}") _load_failed = True return False # ---- mercy guardrail ------------------------------------------------------- # The GM may slow the boss's attack speed to help wounded parties, but its # willingness to help decays as waves progress: the allowed floor rises from # 0.5 toward 1.0 (no mercy) by ~wave 13. The ceiling is always 2.0. MERCY_MAX = 2.0 def _mercy_floor(rnd: int) -> float: return round(min(1.0, 0.5 + 0.04 * max(0, int(rnd or 0))), 2) def _clamp_attack_speed(value, rnd: int) -> float: try: v = float(value) except Exception: v = 1.0 return round(max(_mercy_floor(rnd), min(MERCY_MAX, v)), 2) @_gpu def _run_model(system: str, user: str) -> str: """Single short generation on the GPU. Returns raw text.""" if not _ensure_model(): raise RuntimeError("model unavailable") import torch messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] # return_dict=True gives a BatchEncoding (input_ids + attention_mask) on # every transformers version — newer releases return it by default, and # passing it positionally to generate() crashes on `.shape`. enc = _tokenizer.apply_chat_template( messages, add_generation_prompt=True, return_tensors="pt", return_dict=True ).to(_model.device) with torch.no_grad(): out = _model.generate( **enc, max_new_tokens=400, do_sample=True, temperature=0.7, top_p=0.9, pad_token_id=_tokenizer.eos_token_id, ) n_in = enc["input_ids"].shape[1] text = _tokenizer.decode(out[0][n_in:], skip_special_tokens=True) return text # On Spaces, load the model at startup (ZeroGPU replays the .to("cuda") once a # real GPU is attached) so the @spaces.GPU window is spent on inference only — # lazy-loading inside the first GPU call would blow the 60 s duration on the # weight download. 
# This MUST come after the @spaces.GPU function above is
# defined: ZeroGPU's startup probe needs to see a registered GPU function, and
# the multi-minute weight download would otherwise delay that registration
# ("No @spaces.GPU function detected during startup"). Locally we stay lazy so
# dev/CI never downloads weights.
if _ON_SPACES and not os.environ.get("HUGGINGWIZARDS_NO_LLM"):
    _ensure_model()


def _extract_json(text: str) -> dict | None:
    """Return the JSON object that starts at the first ``{`` in *text*, or None.

    Decoding stops at that object's matching closing brace, so prose or stray
    braces *after* the object (a common LLM habit) no longer break parsing the
    way a greedy ``{.*}`` match did. Every input the greedy match could parse
    is still parsed to the same object.
    """
    if not isinstance(text, str):
        return None
    start = text.find("{")
    if start < 0:
        return None
    try:
        obj, _end = json.JSONDecoder().raw_decode(text, start)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        return None
    return obj if isinstance(obj, dict) else None


def _finite_float(value) -> float | None:
    """Coerce *value* with float(); return None if that fails or gives NaN/±inf."""
    import math

    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return None
    return number if math.isfinite(number) else None


def _member(item, allowed) -> bool:
    """`item in allowed`, treating unhashable model output (lists, dicts) as absent."""
    try:
        return item in allowed
    except TypeError:
        return False


def _deterministic(summary: dict, prev_cfg: dict) -> dict:
    """Fallback Game Master logic — also a sane validation target.

    Returns a full decision dict (message, reasoning, rewards, blessings,
    next_round, card_pool) computed only from *summary* and *prev_cfg*.
    *prev_cfg* must contain boss_hp, boss_damage, minion_count, minion_hp and
    spawn_interval; a missing/None "players" list is treated as empty.
    """
    players = summary.get("players") or []
    won = summary.get("result") == "victory"

    rewards = {}
    for p in players:
        gold = 20 + int(p.get("damage_dealt", 0) / 25) + p.get("kills", 0) * 5
        if p.get("survived"):
            gold += 25
        rewards[p["id"]] = min(300, gold)

    cfg = dict(prev_cfg)
    n_players = max(1, len(players))
    rnd = int(summary.get("round") or 1)
    if won:
        cfg["boss_hp"] = int(prev_cfg["boss_hp"] * 1.35) + 200 * n_players
        cfg["boss_damage"] = min(80, prev_cfg["boss_damage"] + 2)
        cfg["minion_count"] = min(40, prev_cfg["minion_count"] + 2)
        cfg["minion_hp"] = int(prev_cfg["minion_hp"] * 1.15)
        cfg["spawn_interval"] = max(1.5, prev_cfg["spawn_interval"] - 0.3)
        cfg["boss_aggro"] = min(3.0, round(prev_cfg.get("boss_aggro", 1.0) + 0.2, 2))
        msg = "Impressive. The next horde will not fall so easily."
    else:
        cfg["boss_hp"] = max(300, int(prev_cfg["boss_hp"] * 0.85))
        cfg["boss_damage"] = max(4, prev_cfg["boss_damage"] - 1)
        cfg["minion_count"] = max(0, prev_cfg["minion_count"] - 1)
        cfg["spawn_interval"] = min(8.0, prev_cfg["spawn_interval"] + 0.5)
        cfg["boss_aggro"] = max(0.7, round(prev_cfg.get("boss_aggro", 1.0) - 0.1, 2))
        msg = "Rest, wizards. The arena bends slightly in your favor."

    # Enemy mix escalates with the round: grunts always, fast from r2, tanks r4+.
    cfg["wave"] = {
        "grunt": 1.0,
        "fast": round(min(0.8, max(0.0, (rnd - 1) * 0.25)), 2),
        "tank": round(min(0.6, max(0.0, (rnd - 3) * 0.2)), 2),
    }
    # Rotate the boss's attack pattern so consecutive waves feel different.
    cfg["boss_pattern"] = ALL_PATTERN_IDS[rnd % len(ALL_PATTERN_IDS)]

    # Health-responsive attack speed: slow the boss for a wounded party, speed
    # it up for a healthy one — always within the wave's mercy floor. With no
    # survivors the average counts as 0 (maximum requested mercy).
    survivors = [p for p in players if p.get("survived")]
    hps = [p.get("hp_pct", 100) for p in survivors]
    avg_hp = sum(hps) / len(hps) if hps else 0
    desired = 0.7 if avg_hp < 35 else 0.85 if avg_hp < 60 else (1.2 if won else 1.0)
    cfg["boss_attack_speed"] = _clamp_attack_speed(desired, rnd)

    # Blessings: after a defeat, ward up to three survivors; after a costly
    # victory (someone died), fully heal the most wounded survivor.
    blessings = {}
    if not won:
        for p in survivors[:3]:
            blessings[p["id"]] = "warding"
    elif survivors and any(not p.get("survived") for p in players):
        weakest = min(survivors, key=lambda p: p.get("hp_pct", 100))
        blessings[weakest["id"]] = "full_heal"

    # Offer the full card set by default (the model may narrow it).
    card_pool = list(ALL_CARD_IDS)
    return {"message": msg, "reasoning": "deterministic fallback policy",
            "rewards": rewards, "blessings": blessings, "next_round": cfg,
            "card_pool": card_pool}


# Lowest accepted value for each integer next_round field; the HP fields keep
# at least 1 so the model cannot request zero/negative-HP enemies.
_INT_FLOORS = {"boss_hp": 1, "boss_damage": 0, "minion_hp": 1, "minion_count": 0}
# boss_aggro range advertised to the model in SYSTEM_PROMPT ("float 0.7-3"),
# which is also the band the deterministic policy keeps it in.
_AGGRO_RANGE = (0.7, 3.0)


def _validate(decision: dict, summary: dict, prev_cfg: dict) -> dict:
    """Coerce a (possibly model-authored) decision into a safe shape.

    Starts from `_deterministic` and overrides each field only with model
    values that pass type/range checks, so a malformed or partial reply
    degrades field by field instead of all at once:

    - rewards: roster players only, int gold clamped to 0-300.
    - next_round ints: int()-coercible, floored per `_INT_FLOORS`.
    - spawn_interval: finite and > 0; boss_aggro: finite, clamped to 0.7-3.0.
    - boss_attack_speed: clamped into the wave's mercy band.
    - boss_pattern: a known id, read from next_round or the top level.
    - wave: finite non-negative weight per MINION_TYPES entry, sum > 0.
    - blessings: roster players with known blessing ids, at most 3.
    - card_pool: known card ids, de-duplicated in the model's order.
    - message / reasoning: strings truncated to 120 / 300 chars.
    """
    safe = _deterministic(summary, prev_cfg)  # defaults
    if not isinstance(decision, dict):
        return safe

    valid_ids = {p["id"] for p in summary.get("players") or []}

    # per-player gold
    if isinstance(decision.get("rewards"), dict):
        rewards = {}
        for pid, gold in decision["rewards"].items():
            if pid not in valid_ids:
                continue
            try:
                rewards[pid] = max(0, min(300, int(gold)))
            except (TypeError, ValueError, OverflowError):
                continue
        if rewards:
            safe["rewards"] = rewards

    # next-round config: merge model values over the deterministic defaults
    nxt = decision.get("next_round")
    if not isinstance(nxt, dict):
        nxt = {}
    merged = dict(safe["next_round"])
    for key, floor in _INT_FLOORS.items():
        if key in nxt:
            try:
                merged[key] = max(floor, int(nxt[key]))
            except (TypeError, ValueError, OverflowError):
                pass
    spawn = _finite_float(nxt.get("spawn_interval"))
    if spawn is not None and spawn > 0:
        merged["spawn_interval"] = spawn
    aggro = _finite_float(nxt.get("boss_aggro"))
    if aggro is not None:
        low, high = _AGGRO_RANGE
        merged["boss_aggro"] = min(high, max(low, aggro))
    # mercy guardrail: attack speed is clamped into the wave's allowed band
    if "boss_attack_speed" in nxt:
        merged["boss_attack_speed"] = _clamp_attack_speed(
            nxt["boss_attack_speed"], summary.get("round") or 0)
    # boss attack pattern (also accepted at the top level, even when
    # next_round itself is missing)
    pattern = nxt.get("boss_pattern", decision.get("boss_pattern"))
    if _member(pattern, ALL_PATTERN_IDS):
        merged["boss_pattern"] = pattern
    # enemy mix
    wave = nxt.get("wave")
    if isinstance(wave, dict):
        clean = {}
        for archetype in MINION_TYPES:
            weight = _finite_float(wave.get(archetype, 0.0))
            clean[archetype] = max(0.0, weight) if weight is not None else 0.0
        if sum(clean.values()) > 0:
            merged["wave"] = clean
    safe["next_round"] = merged

    # per-player blessings (cap at 3, roster + id checked)
    if isinstance(decision.get("blessings"), dict):
        blessings = {pid: b for pid, b in decision["blessings"].items()
                     if pid in valid_ids and _member(b, BLESSINGS)}
        safe["blessings"] = dict(list(blessings.items())[:3])

    # level-up card pool
    pool = decision.get("card_pool")
    if isinstance(pool, list):
        known = [c for c in pool if _member(c, ALL_CARD_IDS)]
        valid = list(dict.fromkeys(known))  # drop repeats, keep model's order
        if valid:
            safe["card_pool"] = valid

    if isinstance(decision.get("message"), str):
        safe["message"] = decision["message"][:120]
    if isinstance(decision.get("reasoning"), str):
        safe["reasoning"] = decision["reasoning"][:300]
    return safe


def _speed_was_clamped(requested, mercy_floor: float) -> bool:
    """Whether the model's requested boss_attack_speed could not be used as-is.

    True when a request was present but unusable (non-numeric, NaN/inf) or
    outside [mercy_floor, MERCY_MAX]. Two-decimal rounding or an int/str
    spelling of an in-range number is not reported as clamping.
    """
    if requested is None:
        return False
    speed = _finite_float(requested)
    return speed is None or not (mercy_floor <= speed <= MERCY_MAX)


def decide(summary: dict, prev_cfg: dict) -> dict:
    """Produce a validated decision and persist an agent trace for the round.

    The model is asked first. If it is unavailable, raises, or returns no
    parseable JSON object, `_deterministic` supplies the decision, so this
    always returns a complete decision dict. The trace records the source,
    latency, raw output, and how the mercy guardrail treated the request.
    """
    trace_id = uuid.uuid4().hex[:8]
    system, user = SYSTEM_PROMPT, _build_user_prompt(summary, prev_cfg)
    raw, source, error = "", "fallback", None
    requested_speed = None
    t0 = time.perf_counter()  # monotonic: immune to wall-clock adjustments
    try:
        raw = _run_model(system, user)
        parsed = _extract_json(raw)
        if isinstance(parsed, dict) and isinstance(parsed.get("next_round"), dict):
            requested_speed = parsed["next_round"].get("boss_attack_speed")
        if parsed:
            decision = _validate(parsed, summary, prev_cfg)
            source = "nemotron"
        else:
            decision = _deterministic(summary, prev_cfg)
            source = "fallback(parse_failed)"
    except Exception as e:  # model unavailable / GPU error / validation error
        error = str(e)
        decision = _deterministic(summary, prev_cfg)
    latency = round(time.perf_counter() - t0, 2)

    rnd = summary.get("round") or 0
    mercy_floor = _mercy_floor(rnd)
    applied_speed = decision.get("next_round", {}).get("boss_attack_speed")
    mercy = {
        "floor": mercy_floor,
        "max": MERCY_MAX,
        "requested": requested_speed,
        "applied": applied_speed,
        "clamped": _speed_was_clamped(requested_speed, mercy_floor),
        "note": "mercy floor rises with the wave — by ~wave 13 the GM can no "
                "longer slow the boss to protect the party",
    }
    trace = {
        "trace_id": trace_id,
        "round": summary.get("round"),
        "mercy": mercy,
        "ts": time.strftime("%Y-%m-%d %H:%M:%S"),
        "model": MODEL_ID,
        "source": source,
        "latency_sec": latency,
        "error": error,
        "input": {"system": system, "user": user, "round_summary": summary},
        "raw_output": raw,
        "decision": decision,
    }
    _persist(trace)
    return decision


SKILL_SYSTEM_PROMPT = (
    "You are the Game Master AI for HuggingWizards. A wizard asks you to grant a "
    "new power-up. Invent ONE balanced skill and reply with ONE JSON object only. "
    "Schema (data only — never code):\n"
    '{"name": str, "trigger": one of ' + str(sorted(skillmod.TRIGGERS))
    + ', "n": int (for every_n_attacks), "interval": float (for periodic),'
    ' "effect": one of ' + str(sorted(skillmod.EFFECTS))
    + ', "radius": float, "damage": float, "count": int, "amount": float,'
    ' "slow": float 0-1, "color": "#rrggbb"}\n'
    "Pick fields that match the chosen effect. Keep it fun but not overpowered."
)


def _mentions(text: str, stems) -> bool:
    """True if some word in *text* starts with one of *stems*.

    Anchoring at a word start avoids substring false positives such as
    "nice" -> "ice", "really" -> "ally", "toward" -> "ward", "during" -> "ring",
    while stems still match inflections ("explosion", "summoning", "fireball").
    """
    return any(re.search(r"\b" + re.escape(stem), text) for stem in stems)


def _deterministic_skill(prompt: str) -> dict:
    """Keyword-based fallback skill generator (local dev / model unavailable).

    The first matching theme wins, in this order: summon, heal, shield,
    frost/explosion (AoE), nova/burst; otherwise an AoE explosion.
    """
    t = (prompt or "").lower()
    frost = _mentions(t, ("frost", "ice", "freeze", "slow", "cold"))
    if _mentions(t, ("summon", "spirit", "minion", "ally", "pet", "wolf")):
        spec = {"name": "Summoned Spirit", "trigger": "every_n_attacks", "n": 7,
                "effect": "summon_ally", "count": 1, "color": "#a6ff8c"}
    elif _mentions(t, ("heal", "regen", "life", "restore", "vamp", "mend")):
        spec = {"name": "Mending Light", "trigger": "periodic", "interval": 6,
                "effect": "heal", "amount": 30, "color": "#6effd0"}
    elif _mentions(t, ("shield", "barrier", "ward", "protect", "absorb", "armor")):
        spec = {"name": "Arcane Bulwark", "trigger": "periodic", "interval": 8,
                "effect": "shield", "amount": 30, "color": "#9db4c9"}
    elif frost or _mentions(t, ("explos", "blast", "aoe", "area", "detonat", "fire")):
        spec = {"name": "Frost Nova" if frost else "Arcane Detonation",
                "trigger": "every_n_attacks", "n": 6, "effect": "aoe_damage",
                "radius": 140, "damage": 45, "slow": 0.4 if frost else 0.0,
                "color": "#7ee0ff" if frost else "#ff7a4a"}
    elif _mentions(t, ("nova", "spread", "shotgun", "burst", "ring", "radial")):
        spec = {"name": "Star Burst", "trigger": "every_n_attacks", "n": 5,
                "effect": "projectile_nova", "count": 10, "damage": 24,
                "color": "#ffd45e"}
    else:  # default: an AoE explosion
        spec = {"name": "Arcane Detonation", "trigger": "every_n_attacks", "n": 6,
                "effect": "aoe_damage", "radius": 140, "damage": 45, "slow": 0.0,
                "color": "#ff7a4a"}
    return spec


def generate_skill(prompt: str, context: dict | None = None) -> dict:
    """Turn a player's free-text wish into a validated, safe skill spec.

    Tries Nemotron; always falls back to deterministic keywords; the result is
    run through skills.validate_skill so it is bounded and code-free. A trace
    of the request is persisted either way.
    """
    trace_id = uuid.uuid4().hex[:8]
    user = f"Wizard's request: {prompt!r}\nContext: {json.dumps(context or {})}\nReturn the skill JSON."
    raw, source, error = "", "fallback", None
    t0 = time.perf_counter()
    spec = None
    try:
        raw = _run_model(SKILL_SYSTEM_PROMPT, user)
        parsed = _extract_json(raw)
        spec = skillmod.validate_skill(parsed) if parsed else None
        source = "nemotron" if spec else "fallback(parse_failed)"
    except Exception as e:
        error = str(e)
    if not spec:
        # Same truthiness test as the source label above, so a trace never
        # says "fallback" while returning an (empty) model spec.
        spec = skillmod.validate_skill(_deterministic_skill(prompt))
    latency = round(time.perf_counter() - t0, 2)
    _persist({
        "trace_id": trace_id,
        "round": (context or {}).get("round"),
        "ts": time.strftime("%Y-%m-%d %H:%M:%S"),
        "model": MODEL_ID,
        "source": source,
        "latency_sec": latency,
        "error": error,
        "kind": "skill_request",
        "input": {"system": SKILL_SYSTEM_PROMPT, "user": user, "prompt": prompt},
        "raw_output": raw,
        "decision": spec,
    })
    return spec


def _persist(trace: dict):
    """Keep *trace* in the in-memory ring and write it to TRACE_DIR as JSON.

    File names are ``skill_<id>.json`` for skill requests, ``round_NNN_<id>.json``
    for round decisions with an integer round, and ``round_unknown_<id>.json``
    otherwise. The trace is serialized before the file is opened, so a
    non-serializable trace never leaves a truncated file for trace_store to
    sync. Write failures are logged, never raised.
    """
    RECENT_TRACES.insert(0, trace)
    del RECENT_TRACES[_MAX_RECENT:]
    try:
        rnd = trace.get("round")
        if trace.get("kind") == "skill_request":
            prefix = "skill"
        elif isinstance(rnd, int) and not isinstance(rnd, bool):
            prefix = f"round_{rnd:03d}"
        else:
            prefix = "round_unknown"
        payload = json.dumps(trace, indent=2)
        path = os.path.join(TRACE_DIR, f"{prefix}_{trace['trace_id']}.json")
        with trace_store.lock():
            with open(path, "w", encoding="utf-8") as f:
                f.write(payload)
    except Exception as e:  # pragma: no cover
        print(f"[gamemaster] failed to write trace: {e}")