Spaces:
Running on Zero
Running on Zero
"""Nemotron-4B Game Master.

At each round boundary the engine hands us a `round_summary`. We ask
NVIDIA's Nemotron-Mini-4B-Instruct to return a JSON decision covering:
  - rewards: {player_id: gold}
  - blessings: {player_id: blessing_id} (optional, up to 3 wizards)
  - next_round: {boss_hp, boss_damage, minion_hp, minion_count, spawn_interval,
    boss_aggro, boss_attack_speed, boss_pattern, wave}
  - card_pool: level-up card ids players may be offered
  - message: short flavor line shown to players
  - reasoning: why (kept in the trace)
The same model also turns a wizard's free-text request into a skill spec
(`generate_skill`).

This is *burst* GPU work that fits ZeroGPU's `@spaces.GPU` model perfectly: we
grab the GPU for one short inference per call, then release it. If the model
or GPU is unavailable (e.g. local dev / CI) we fall back to deterministic logic
so the game always keeps running.

Every decision and skill request is written to `traces/` as a self-contained
agent trace.
"""
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
| import time | |
| import uuid | |
| from game.engine import ALL_CARD_IDS, ALL_PATTERN_IDS, BLESSINGS, MINION_TYPES | |
| from game import skills as skillmod | |
| from game import trace_store | |
# Hugging Face model id loaded by _load_model and recorded in every trace.
MODEL_ID = "nvidia/Nemotron-Mini-4B-Instruct"
# Trace files go to a `traces/` directory inside the parent of this file's
# directory; it is created at import time so _persist can always write.
TRACE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "traces")
os.makedirs(TRACE_DIR, exist_ok=True)
# Background sync of every trace file to a HF dataset repo (no-op if not configured).
# Started only after TRACE_DIR exists.
trace_store.start(TRACE_DIR)
# Model globals, filled once by _ensure_model: eagerly at import time on
# Spaces, otherwise on first use.
_tokenizer = None
_model = None
# Set once a load fails or is disabled (HUGGINGWIZARDS_NO_LLM) so
# _ensure_model stops retrying and callers use the deterministic fallback.
_load_failed = False
# Most-recent-first list of traces for the dashboard; _persist trims it to
# _MAX_RECENT entries.
RECENT_TRACES: list[dict] = []
_MAX_RECENT = 25
# True when running inside a Hugging Face Space (ZeroGPU or otherwise).
_ON_SPACES = bool(os.environ.get("SPACE_ID"))
| try: # `spaces` only exists on HF infra; degrade gracefully elsewhere. | |
| import spaces # type: ignore | |
| def _gpu(fn): | |
| return spaces.GPU(duration=60)(fn) | |
| except Exception: # pragma: no cover - local dev path | |
| def _gpu(fn): | |
| return fn | |
# System prompt for the per-round Game Master decision made in `decide`.
# It describes the JSON schema that `_validate` later coerces and interpolates
# the engine's enemy archetypes, boss patterns, blessings, and card ids so the
# model is shown the same ids the validator checks against. The
# boss_attack_speed band it mentions is enforced by `_clamp_attack_speed`,
# and the per-round floor is supplied in the user prompt.
SYSTEM_PROMPT = (
    "You are the Game Master / Director AI for HuggingWizards, a co-op pixel "
    "survivors-arena where wizards fight waves of enemies and a boss. After each "
    "wave you direct: rewards, the next wave's difficulty, the enemy mix, and the "
    "pool of level-up cards players may be offered. Reply with ONE JSON object and "
    "nothing else. Schema:\n"
    '{"message": str (<=120 chars, in-character narration),'
    ' "reasoning": str (one sentence),'
    ' "rewards": {player_id: int gold 0-300},'
    ' "blessings": {player_id: blessing_id} (optional, bless 0-3 wizards),'
    ' "next_round": {"boss_hp": int, "boss_damage": int, "minion_hp": int,'
    ' "minion_count": int, "spawn_interval": float, "boss_aggro": float 0.7-3,'
    ' "boss_attack_speed": float 0.5-2.0, "boss_pattern": pattern_id,'
    ' "wave": {"grunt": float, "fast": float, "tank": float}},'
    ' "card_pool": [card_id, ...]}\n'
    f"Enemy archetypes (wave weights, must sum > 0): {list(MINION_TYPES)}.\n"
    f"Boss attack patterns: {ALL_PATTERN_IDS}. Switch the pattern between rounds "
    "to keep fights fresh — sniper punishes kiting, artillery punishes camping, "
    "swarm floods the arena, berserker charges relentlessly.\n"
    f"Blessings: {BLESSINGS}. Bless wizards who earned it — surviving a brutal "
    "wave, clutch plays, or to help a struggling wizard back on their feet "
    "(extra_life / full_heal). Auras last one wave.\n"
    f"Valid card_pool ids (offer 4-8): {ALL_CARD_IDS}.\n"
    "boss_attack_speed sets how fast the boss attacks next wave (1.0 = normal). "
    "Check every player's hp_pct and lives_left: if the party is badly hurt, "
    "slow the boss (<1.0) so they can recover; if they are healthy, speed it up "
    "(>1.0). MERCY DECAYS: each round the user prompt states the minimum you may "
    "set — in later waves you can no longer slow the boss to protect them. "
    "Values outside the allowed range are clamped.\n"
    "Reward players for damage dealt, kills, and surviving. Scale difficulty up "
    "after a victory and ease it slightly after a defeat. Introduce tougher enemy "
    "archetypes as rounds progress. Curate cards to keep the run fun and winnable "
    "for the number of players."
)
def _build_user_prompt(summary: dict, prev_cfg: dict) -> str:
    """Assemble the per-round user message that accompanies SYSTEM_PROMPT.

    The message carries, one section per line group:
    the current next-round config (compact JSON), the pretty-printed summary
    of the round that just ended, the boss_attack_speed band for this round
    (mercy floor .. MERCY_MAX), and the closing instruction.
    A missing/None round is treated as round 0 for the mercy floor.
    """
    wave_no = summary.get("round") or 0
    sections = [
        "Current next-round config (you may adjust): " + json.dumps(prev_cfg),
        "Round that just ended:\n" + json.dumps(summary, indent=2),
        "Mercy floor this round: boss_attack_speed must be between "
        f"{_mercy_floor(wave_no)} and {MERCY_MAX}.",
        "Return the JSON decision now.",
    ]
    return "\n".join(sections)
def _load_model():
    """Load tokenizer + model, optionally quantized.

    HUGGINGWIZARDS_QUANT = none | 4bit | 8bit | auto (default).
    `auto` picks 4-bit on small local GPUs (<12 GB VRAM) and bf16 everywhere
    else — ZeroGPU's H200 runs the 4B model comfortably in bf16, where it is
    also faster per token than bitsandbytes 4-bit.

    bitsandbytes quantization is only attempted with CUDA. A 4bit/8bit
    request without a GPU, or an unrecognised value, loads the unquantized
    model instead. `quant` is normalised to "none" in that case (with a
    notice), so the startup log reports the mode actually used rather than
    the one requested.

    Returns:
        (tokenizer, model) — the model in eval mode, on the GPU when CUDA is
        available.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    has_cuda = torch.cuda.is_available()
    quant = os.environ.get("HUGGINGWIZARDS_QUANT", "auto").strip().lower()
    if quant == "auto":
        if has_cuda and not _ON_SPACES:
            vram = torch.cuda.get_device_properties(0).total_memory
            quant = "4bit" if vram < 12 * 1024**3 else "none"
        else:
            quant = "none"

    # Normalise to the mode that will really be loaded.
    if quant not in ("none", "4bit", "8bit"):
        print(f"[gamemaster] unknown HUGGINGWIZARDS_QUANT={quant!r}; loading unquantized")
        quant = "none"
    elif quant != "none" and not has_cuda:
        print(f"[gamemaster] quant={quant} needs CUDA; loading unquantized")
        quant = "none"

    if quant != "none":
        from transformers import BitsAndBytesConfig
        qcfg = (
            BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                               bnb_4bit_compute_dtype=torch.bfloat16)
            if quant == "4bit"
            else BitsAndBytesConfig(load_in_8bit=True)
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, quantization_config=qcfg, device_map={"": 0}
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, dtype=torch.bfloat16 if has_cuda else torch.float32
        )
        if has_cuda:
            model = model.to("cuda")  # ZeroGPU-safe (device_map="auto" is not)
    model.eval()
    print(f"[gamemaster] loaded {MODEL_ID} (quant={quant}, cuda={has_cuda})")
    return tokenizer, model
def _ensure_model():
    """Make sure the model is loaded; return True when it is usable.

    The outcome is remembered: once the model is loaded, or once loading
    failed or was disabled via HUGGINGWIZARDS_NO_LLM (which sets
    `_load_failed`), later calls return immediately without retrying.
    """
    global _tokenizer, _model, _load_failed
    settled = _model is not None or _load_failed
    if settled:
        return _model is not None
    if os.environ.get("HUGGINGWIZARDS_NO_LLM"):
        # Deterministic fallback forced (local dev / CI).
        _load_failed = True
        return False
    try:
        _tokenizer, _model = _load_model()
    except Exception as exc:  # pragma: no cover
        print(f"[gamemaster] model load failed, using fallback: {exc}")
        _load_failed = True
        return False
    else:
        return True
| # ---- mercy guardrail ------------------------------------------------------- | |
| # The GM may slow the boss's attack speed to help wounded parties, but its | |
| # willingness to help decays as waves progress: the allowed floor rises from | |
| # 0.5 toward 1.0 (no mercy) by ~wave 13. The ceiling is always 2.0. | |
| MERCY_MAX = 2.0 | |
| def _mercy_floor(rnd: int) -> float: | |
| return round(min(1.0, 0.5 + 0.04 * max(0, int(rnd or 0))), 2) | |
| def _clamp_attack_speed(value, rnd: int) -> float: | |
| try: | |
| v = float(value) | |
| except Exception: | |
| v = 1.0 | |
| return round(max(_mercy_floor(rnd), min(MERCY_MAX, v)), 2) | |
@_gpu
def _run_model(system: str, user: str) -> str:
    """Run one short chat generation and return the decoded reply text.

    Wrapped with `_gpu`: when the `spaces` package is importable this
    registers the function as a ZeroGPU function (`spaces.GPU(duration=60)`),
    so a GPU is attached only for the duration of each call. Otherwise `_gpu`
    returns it unchanged. A registered GPU function is also what ZeroGPU's
    startup probe looks for ("No @spaces.GPU function detected during
    startup").

    Raises:
        RuntimeError: "model unavailable" when `_ensure_model()` reports the
            model could not be loaded (callers then use deterministic logic).
    """
    if not _ensure_model():
        raise RuntimeError("model unavailable")
    import torch

    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    # return_dict=True gives a BatchEncoding (input_ids + attention_mask) on
    # every transformers version — newer releases return it by default, and
    # passing it positionally to generate() crashes on `.shape`.
    enc = _tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
    ).to(_model.device)
    with torch.no_grad():
        out = _model.generate(
            **enc, max_new_tokens=400, do_sample=True, temperature=0.7, top_p=0.9,
            pad_token_id=_tokenizer.eos_token_id,
        )
    # generate() returns prompt + completion; decode only the new tokens.
    n_in = enc["input_ids"].shape[1]
    return _tokenizer.decode(out[0][n_in:], skip_special_tokens=True)
# On Spaces, load the model at startup rather than lazily. ZeroGPU replays
# the .to("cuda") once a real GPU is attached, so the @spaces.GPU window is
# spent on inference only — lazy-loading inside the first GPU call would blow
# the 60 s duration on the weight download.
#
# This MUST come after the @spaces.GPU function above is defined: ZeroGPU's
# startup probe needs to see a registered GPU function, and the multi-minute
# weight download would otherwise delay that registration
# ("No @spaces.GPU function detected during startup").
# NOTE(review): the function meant to be registered is presumably
# `_run_model` wrapped with `_gpu`; confirm it actually carries the decorator.
#
# Locally (no SPACE_ID) we stay lazy so dev/CI never downloads weights, and
# HUGGINGWIZARDS_NO_LLM skips the eager load entirely.
if _ON_SPACES and not os.environ.get("HUGGINGWIZARDS_NO_LLM"):
    _ensure_model()
| def _extract_json(text: str) -> dict | None: | |
| m = re.search(r"\{.*\}", text, re.DOTALL) | |
| if not m: | |
| return None | |
| try: | |
| return json.loads(m.group(0)) | |
| except Exception: | |
| return None | |
def _deterministic(summary: dict, prev_cfg: dict) -> dict:
    """Rule-based Game Master policy; also the default `_validate` builds on.

    - Gold: each player gets 20 + 1 per 25 damage + 5 per kill, +25 if they
      survived, capped at 300.
    - Difficulty: after a victory the next round scales up (boss hp/damage,
      minion count/hp, faster spawns, more aggro); after a defeat it eases off.
    - Enemy mix and boss pattern follow the round number.
    - Boss attack speed follows the survivors' average hp_pct, clamped to the
      round's mercy band.
    - Blessings: after a defeat, up to 3 survivors are warded. After a
      victory in which someone fell, the most wounded survivor gets a full
      heal.
    - Every card id is offered.
    """
    roster = summary.get("players", [])
    victory = summary.get("result") == "victory"

    payouts = {}
    for wiz in roster:
        earned = 20 + int(wiz.get("damage_dealt", 0) / 25) + wiz.get("kills", 0) * 5
        if wiz.get("survived"):
            earned += 25
        payouts[wiz["id"]] = min(300, earned)

    nxt = dict(prev_cfg)
    party = max(1, len(roster))
    wave_no = int(summary.get("round") or 1)

    if victory:
        nxt.update({
            "boss_hp": int(prev_cfg["boss_hp"] * 1.35) + 200 * party,
            "boss_damage": min(80, prev_cfg["boss_damage"] + 2),
            "minion_count": min(40, prev_cfg["minion_count"] + 2),
            "minion_hp": int(prev_cfg["minion_hp"] * 1.15),
            "spawn_interval": max(1.5, prev_cfg["spawn_interval"] - 0.3),
            "boss_aggro": min(3.0, round(prev_cfg.get("boss_aggro", 1.0) + 0.2, 2)),
        })
        line = "Impressive. The next horde will not fall so easily."
    else:
        nxt.update({
            "boss_hp": max(300, int(prev_cfg["boss_hp"] * 0.85)),
            "boss_damage": max(4, prev_cfg["boss_damage"] - 1),
            "minion_count": max(0, prev_cfg["minion_count"] - 1),
            "spawn_interval": min(8.0, prev_cfg["spawn_interval"] + 0.5),
            "boss_aggro": max(0.7, round(prev_cfg.get("boss_aggro", 1.0) - 0.1, 2)),
        })
        line = "Rest, wizards. The arena bends slightly in your favor."

    # Grunts always; fast enemies join from round 2, tanks from round 4.
    nxt["wave"] = {
        "grunt": 1.0,
        "fast": round(min(0.8, max(0.0, (wave_no - 1) * 0.25)), 2),
        "tank": round(min(0.6, max(0.0, (wave_no - 3) * 0.2)), 2),
    }
    # Cycle through the boss patterns by round number.
    nxt["boss_pattern"] = ALL_PATTERN_IDS[wave_no % len(ALL_PATTERN_IDS)]

    # Slow the boss for a battered party, speed it up for a healthy winning one.
    standing = [wiz for wiz in roster if wiz.get("survived")]
    health = [wiz.get("hp_pct", 100) for wiz in standing]
    mean_hp = sum(health) / len(health) if health else 0
    if mean_hp < 35:
        target_speed = 0.7
    elif mean_hp < 60:
        target_speed = 0.85
    elif victory:
        target_speed = 1.2
    else:
        target_speed = 1.0
    nxt["boss_attack_speed"] = _clamp_attack_speed(target_speed, wave_no)

    boons = {}
    if not victory:
        for wiz in standing[:3]:
            boons[wiz["id"]] = "warding"
    elif standing and not all(wiz.get("survived") for wiz in roster):
        frailest = min(standing, key=lambda wiz: wiz.get("hp_pct", 100))
        boons[frailest["id"]] = "full_heal"

    return {"message": line, "reasoning": "deterministic fallback policy",
            "rewards": payouts, "blessings": boons,
            "next_round": nxt, "card_pool": list(ALL_CARD_IDS)}
| def _validate(decision: dict, summary: dict, prev_cfg: dict) -> dict: | |
| """Coerce a (possibly model-authored) decision into a safe shape.""" | |
| safe = _deterministic(summary, prev_cfg) # defaults | |
| if not isinstance(decision, dict): | |
| return safe | |
| valid_ids = {p["id"] for p in summary.get("players", [])} | |
| if isinstance(decision.get("rewards"), dict): | |
| rewards = {} | |
| for k, v in decision["rewards"].items(): | |
| if k in valid_ids: | |
| try: | |
| rewards[k] = max(0, min(300, int(v))) | |
| except Exception: | |
| pass | |
| if rewards: | |
| safe["rewards"] = rewards | |
| nxt = decision.get("next_round") | |
| if isinstance(nxt, dict): | |
| merged = dict(safe["next_round"]) | |
| for key in ("boss_hp", "boss_damage", "minion_hp", "minion_count"): | |
| if key in nxt: | |
| try: | |
| merged[key] = int(nxt[key]) | |
| except Exception: | |
| pass | |
| for key in ("spawn_interval", "boss_aggro"): | |
| if key in nxt: | |
| try: | |
| merged[key] = float(nxt[key]) | |
| except Exception: | |
| pass | |
| # mercy guardrail: attack speed is clamped into the wave's allowed band | |
| if "boss_attack_speed" in nxt: | |
| merged["boss_attack_speed"] = _clamp_attack_speed( | |
| nxt.get("boss_attack_speed"), summary.get("round") or 0) | |
| # boss attack pattern (also accepted at the top level) | |
| pattern = nxt.get("boss_pattern", decision.get("boss_pattern")) | |
| if pattern in ALL_PATTERN_IDS: | |
| merged["boss_pattern"] = pattern | |
| # enemy mix | |
| wave = nxt.get("wave") | |
| if isinstance(wave, dict): | |
| clean = {} | |
| for k in MINION_TYPES: | |
| try: | |
| clean[k] = max(0.0, float(wave.get(k, 0.0))) | |
| except Exception: | |
| clean[k] = 0.0 | |
| if sum(clean.values()) > 0: | |
| merged["wave"] = clean | |
| safe["next_round"] = merged | |
| # per-player blessings (cap at 3, roster + id checked) | |
| if isinstance(decision.get("blessings"), dict): | |
| blessings = {k: v for k, v in decision["blessings"].items() | |
| if k in valid_ids and v in BLESSINGS} | |
| safe["blessings"] = dict(list(blessings.items())[:3]) | |
| # level-up card pool | |
| pool = decision.get("card_pool") | |
| if isinstance(pool, list): | |
| valid = [c for c in pool if c in ALL_CARD_IDS] | |
| if valid: | |
| safe["card_pool"] = valid | |
| if isinstance(decision.get("message"), str): | |
| safe["message"] = decision["message"][:120] | |
| if isinstance(decision.get("reasoning"), str): | |
| safe["reasoning"] = decision["reasoning"][:300] | |
| return safe | |
| def decide(summary: dict, prev_cfg: dict) -> dict: | |
| """Produce a validated decision and persist an agent trace for the round.""" | |
| trace_id = uuid.uuid4().hex[:8] | |
| system, user = SYSTEM_PROMPT, _build_user_prompt(summary, prev_cfg) | |
| raw, source, error = "", "fallback", None | |
| requested_speed = None | |
| t0 = time.time() | |
| try: | |
| raw = _run_model(system, user) | |
| parsed = _extract_json(raw) | |
| if isinstance(parsed, dict) and isinstance(parsed.get("next_round"), dict): | |
| requested_speed = parsed["next_round"].get("boss_attack_speed") | |
| decision = _validate(parsed, summary, prev_cfg) if parsed else _deterministic(summary, prev_cfg) | |
| source = "nemotron" if parsed else "fallback(parse_failed)" | |
| except Exception as e: | |
| error = str(e) | |
| decision = _deterministic(summary, prev_cfg) | |
| latency = round(time.time() - t0, 2) | |
| rnd = summary.get("round") or 0 | |
| applied_speed = decision.get("next_round", {}).get("boss_attack_speed") | |
| mercy = { | |
| "floor": _mercy_floor(rnd), "max": MERCY_MAX, | |
| "requested": requested_speed, "applied": applied_speed, | |
| "clamped": requested_speed is not None and requested_speed != applied_speed, | |
| "note": "mercy floor rises with the wave — by ~wave 13 the GM can no " | |
| "longer slow the boss to protect the party", | |
| } | |
| trace = { | |
| "trace_id": trace_id, | |
| "round": summary.get("round"), | |
| "mercy": mercy, | |
| "ts": time.strftime("%Y-%m-%d %H:%M:%S"), | |
| "model": MODEL_ID, | |
| "source": source, | |
| "latency_sec": latency, | |
| "error": error, | |
| "input": {"system": system, "user": user, "round_summary": summary}, | |
| "raw_output": raw, | |
| "decision": decision, | |
| } | |
| _persist(trace) | |
| return decision | |
# System prompt for `generate_skill`. Allowed trigger/effect names come from
# skills.TRIGGERS / skills.EFFECTS, sorted so their listing order is
# deterministic. Whatever the model replies is still passed through
# skills.validate_skill before use.
SKILL_SYSTEM_PROMPT = (
    "You are the Game Master AI for HuggingWizards. A wizard asks you to grant a "
    "new power-up. Invent ONE balanced skill and reply with ONE JSON object only. "
    "Schema (data only — never code):\n"
    '{"name": str, "trigger": one of '
    + str(sorted(skillmod.TRIGGERS))
    + ', "n": int (for every_n_attacks), "interval": float (for periodic),'
    ' "effect": one of '
    + str(sorted(skillmod.EFFECTS))
    + ', "radius": float, "damage": float, "count": int, "amount": float,'
    ' "slow": float 0-1, "color": "#rrggbb"}\n'
    "Pick fields that match the chosen effect. Keep it fun but not overpowered."
)
| def _deterministic_skill(prompt: str) -> dict: | |
| """Keyword-based fallback skill generator (local dev / model unavailable).""" | |
| t = (prompt or "").lower() | |
| frost = any(w in t for w in ("frost", "ice", "freeze", "slow", "cold")) | |
| if any(w in t for w in ("summon", "spirit", "minion", "ally", "pet", "wolf")): | |
| spec = {"name": "Summoned Spirit", "trigger": "every_n_attacks", "n": 7, | |
| "effect": "summon_ally", "count": 1, "color": "#a6ff8c"} | |
| elif any(w in t for w in ("heal", "regen", "life", "restore", "vamp", "mend")): | |
| spec = {"name": "Mending Light", "trigger": "periodic", "interval": 6, | |
| "effect": "heal", "amount": 30, "color": "#6effd0"} | |
| elif any(w in t for w in ("shield", "barrier", "ward", "protect", "absorb", "armor")): | |
| spec = {"name": "Arcane Bulwark", "trigger": "periodic", "interval": 8, | |
| "effect": "shield", "amount": 30, "color": "#9db4c9"} | |
| elif frost or any(w in t for w in ("explos", "blast", "aoe", "area", "detonat", "fire")): | |
| spec = {"name": "Frost Nova" if frost else "Arcane Detonation", | |
| "trigger": "every_n_attacks", "n": 6, "effect": "aoe_damage", | |
| "radius": 140, "damage": 45, "slow": 0.4 if frost else 0.0, | |
| "color": "#7ee0ff" if frost else "#ff7a4a"} | |
| elif any(w in t for w in ("nova", "spread", "shotgun", "burst", "ring", "radial")): | |
| spec = {"name": "Star Burst", "trigger": "every_n_attacks", "n": 5, | |
| "effect": "projectile_nova", "count": 10, "damage": 24, "color": "#ffd45e"} | |
| else: # default: an AoE explosion | |
| spec = {"name": "Arcane Detonation", "trigger": "every_n_attacks", "n": 6, | |
| "effect": "aoe_damage", "radius": 140, "damage": 45, "slow": 0.0, | |
| "color": "#ff7a4a"} | |
| return spec | |
def generate_skill(prompt: str, context: dict | None = None) -> dict:
    """Turn a player's free-text wish into a validated, safe skill spec.

    Nemotron is asked first. When it is unavailable, raises, or returns
    nothing skills.validate_skill accepts, the keyword-based
    `_deterministic_skill` is used instead. Either way the spec comes from
    skills.validate_skill, so it is bounded and code-free. A trace of kind
    "skill_request" is persisted for every call.
    """
    ctx = context or {}
    trace_id = uuid.uuid4().hex[:8]
    user = f"Wizard's request: {prompt!r}\nContext: {json.dumps(ctx)}\nReturn the skill JSON."
    raw = ""
    source = "fallback"
    error = None
    spec = None
    started = time.time()
    try:
        raw = _run_model(SKILL_SYSTEM_PROMPT, user)
        parsed = _extract_json(raw)
        if parsed:
            spec = skillmod.validate_skill(parsed)
        source = "nemotron" if spec else "fallback(parse_failed)"
    except Exception as exc:
        error = str(exc)
    if spec is None:
        spec = skillmod.validate_skill(_deterministic_skill(prompt))
    elapsed = round(time.time() - started, 2)

    record = {
        "trace_id": trace_id,
        "round": ctx.get("round"),
        "ts": time.strftime("%Y-%m-%d %H:%M:%S"),
        "model": MODEL_ID,
        "source": source,
        "latency_sec": elapsed,
        "error": error,
        "kind": "skill_request",
        "input": {"system": SKILL_SYSTEM_PROMPT, "user": user, "prompt": prompt},
        "raw_output": raw,
        "decision": spec,
    }
    _persist(record)
    return spec
def _persist(trace: dict):
    """Record a trace in memory and write it to TRACE_DIR as JSON.

    The trace is prepended to RECENT_TRACES, which is trimmed to _MAX_RECENT
    entries. It is written as `round_NNN_<trace_id>.json` when trace["round"]
    is an int, else `skill_<trace_id>.json`, while holding trace_store.lock().

    The JSON text is produced before the file is opened. `json.dump` writes
    incrementally, so a non-serialisable trace would otherwise leave a
    truncated file on disk for the background sync to upload. Failures are
    logged, never raised — persisting a trace is best-effort.
    """
    RECENT_TRACES.insert(0, trace)
    del RECENT_TRACES[_MAX_RECENT:]
    try:
        rnd = trace.get("round")
        prefix = f"round_{rnd:03d}" if isinstance(rnd, int) else "skill"
        path = os.path.join(TRACE_DIR, f"{prefix}_{trace['trace_id']}.json")
        payload = json.dumps(trace, indent=2)
        with trace_store.lock():
            with open(path, "w", encoding="utf-8") as f:
                f.write(payload)
    except Exception as e:  # pragma: no cover
        print(f"[gamemaster] failed to write trace: {e}")