HuggingWizards / game /gamemaster.py
Quazim0t0's picture
Update game/gamemaster.py
7f5d4f2 verified
Raw
History Blame
21.5 kB
"""Nemotron-4B Game Master.
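At each round boundary the engine hands us a `round_summary`. We ask
NVIDIA's Nemotron-Mini-4B-Instruct to return a JSON decision covering:
- rewards: {player_id: gold}
- blessings: {player_id: blessing_id} (optional, at most 3 wizards)
- next_round: {boss_hp, boss_damage, minion_hp, minion_count, spawn_interval,
  boss_aggro, boss_attack_speed, boss_pattern, wave}
- card_pool: [card_id, ...] offered at level-up
- message: short flavor line shown to players
- reasoning: why (kept in the trace)
The same model also turns a player's free-text wish into a validated skill
spec (`generate_skill`).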
This is *burst* GPU work that fits ZeroGPU's `@spaces.GPU` model perfectly: we
grab the GPU for one short inference per round, then release it. If the model
or GPU is unavailable (e.g. local dev / CI) we fall back to deterministic logic
so the game always keeps running.
Every decision is written to `traces/` as a self-contained agent trace.
"""
from __future__ import annotations

import json
import math
import os
import re
import time
import uuid

from game.engine import ALL_CARD_IDS, ALL_PATTERN_IDS, BLESSINGS, MINION_TYPES
from game import skills as skillmod
from game import trace_store
# Hugging Face Hub id of the model used for both round decisions and skill
# requests.
MODEL_ID = "nvidia/Nemotron-Mini-4B-Instruct"
# Trace files go in a `traces` folder next to this file's parent directory
# (i.e. the parent of the directory containing gamemaster.py); created at
# import time.
TRACE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "traces")
os.makedirs(TRACE_DIR, exist_ok=True)
# Background sync of every trace file to a HF dataset repo (no-op if not configured).
trace_store.start(TRACE_DIR)
# Model globals, filled once by _ensure_model(): at import time on Spaces,
# otherwise on the first GPU call.
_tokenizer = None
_model = None
# Set when loading fails or is disabled; _ensure_model() never retries after.
_load_failed = False
# In-memory ring of recent traces for the dashboard (newest first, trimmed to
# _MAX_RECENT entries by _persist).
RECENT_TRACES: list[dict] = []
_MAX_RECENT = 25
# True when running inside a Hugging Face Space (ZeroGPU or otherwise).
_ON_SPACES = bool(os.environ.get("SPACE_ID"))
try: # `spaces` only exists on HF infra; degrade gracefully elsewhere.
import spaces # type: ignore
def _gpu(fn):
return spaces.GPU(duration=60)(fn)
except Exception: # pragma: no cover - local dev path
def _gpu(fn):
return fn
# System prompt for end-of-round decisions. Built once at import time; it
# interpolates the same engine collections (MINION_TYPES, ALL_PATTERN_IDS,
# BLESSINGS, ALL_CARD_IDS) that _validate() checks model output against.
# The per-round mercy floor is not included here; the user prompt carries it.
SYSTEM_PROMPT = (
    "You are the Game Master / Director AI for HuggingWizards, a co-op pixel "
    "survivors-arena where wizards fight waves of enemies and a boss. After each "
    "wave you direct: rewards, the next wave's difficulty, the enemy mix, and the "
    "pool of level-up cards players may be offered. Reply with ONE JSON object and "
    "nothing else. Schema:\n"
    '{"message": str (<=120 chars, in-character narration),'
    ' "reasoning": str (one sentence),'
    ' "rewards": {player_id: int gold 0-300},'
    ' "blessings": {player_id: blessing_id} (optional, bless 0-3 wizards),'
    ' "next_round": {"boss_hp": int, "boss_damage": int, "minion_hp": int,'
    ' "minion_count": int, "spawn_interval": float, "boss_aggro": float 0.7-3,'
    ' "boss_attack_speed": float 0.5-2.0, "boss_pattern": pattern_id,'
    ' "wave": {"grunt": float, "fast": float, "tank": float}},'
    ' "card_pool": [card_id, ...]}\n'
    f"Enemy archetypes (wave weights, must sum > 0): {list(MINION_TYPES)}.\n"
    f"Boss attack patterns: {ALL_PATTERN_IDS}. Switch the pattern between rounds "
    "to keep fights fresh — sniper punishes kiting, artillery punishes camping, "
    "swarm floods the arena, berserker charges relentlessly.\n"
    f"Blessings: {BLESSINGS}. Bless wizards who earned it — surviving a brutal "
    "wave, clutch plays, or to help a struggling wizard back on their feet "
    "(extra_life / full_heal). Auras last one wave.\n"
    f"Valid card_pool ids (offer 4-8): {ALL_CARD_IDS}.\n"
    "boss_attack_speed sets how fast the boss attacks next wave (1.0 = normal). "
    "Check every player's hp_pct and lives_left: if the party is badly hurt, "
    "slow the boss (<1.0) so they can recover; if they are healthy, speed it up "
    "(>1.0). MERCY DECAYS: each round the user prompt states the minimum you may "
    "set — in later waves you can no longer slow the boss to protect them. "
    "Values outside the allowed range are clamped.\n"
    "Reward players for damage dealt, kills, and surviving. Scale difficulty up "
    "after a victory and ease it slightly after a defeat. Introduce tougher enemy "
    "archetypes as rounds progress. Curate cards to keep the run fun and winnable "
    "for the number of players."
)
def _build_user_prompt(summary: dict, prev_cfg: dict) -> str:
    """Render the per-round user message sent alongside SYSTEM_PROMPT.

    The message contains the current next-round config as JSON and the
    round summary pretty-printed. It also states the boss_attack_speed
    band: the floor comes from ``_mercy_floor`` of the summary's round
    (0 when missing) and the ceiling is MERCY_MAX.
    """
    sections = [
        "Current next-round config (you may adjust): " + json.dumps(prev_cfg),
        "Round that just ended:\n" + json.dumps(summary, indent=2),
        "Mercy floor this round: boss_attack_speed must be between "
        f"{_mercy_floor(summary.get('round') or 0)} and {MERCY_MAX}.",
        "Return the JSON decision now.",
    ]
    return "\n".join(sections)
def _load_model():
    """Load tokenizer + model, optionally quantized.

    HUGGINGWIZARDS_QUANT = none | 4bit | 8bit | auto (default).
    `auto` picks 4-bit on small local GPUs (<12 GB VRAM) and bf16 everywhere
    else — ZeroGPU's H200 runs the 4B model comfortably in bf16, where it is
    also faster per token than bitsandbytes 4-bit.

    bitsandbytes quantization is only attempted when CUDA is available. With
    no CUDA, or with an unrecognised value, the model is loaded unquantized.
    The log line then reports ``quant=none`` rather than the requested value.

    Returns:
        ``(tokenizer, model)`` with the model switched to eval mode.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    has_cuda = torch.cuda.is_available()
    quant = os.environ.get("HUGGINGWIZARDS_QUANT", "auto").strip().lower()
    if quant == "auto":
        if has_cuda and not _ON_SPACES:
            vram = torch.cuda.get_device_properties(0).total_memory
            quant = "4bit" if vram < 12 * 1024**3 else "none"
        else:
            quant = "none"
    if quant not in ("4bit", "8bit") or not has_cuda:
        # Anything we cannot quantize loads in full precision; normalise the
        # label so the log below reflects what was actually loaded.
        quant = "none"

    if quant == "none":
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, dtype=torch.bfloat16 if has_cuda else torch.float32
        )
        if has_cuda:
            model = model.to("cuda")  # ZeroGPU-safe (device_map="auto" is not)
    else:
        from transformers import BitsAndBytesConfig

        qcfg = (
            BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                               bnb_4bit_compute_dtype=torch.bfloat16)
            if quant == "4bit"
            else BitsAndBytesConfig(load_in_8bit=True)
        )
        # device_map places the quantized weights on GPU 0 directly; no .to().
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, quantization_config=qcfg, device_map={"": 0}
        )
    model.eval()
    print(f"[gamemaster] loaded {MODEL_ID} (quant={quant}, cuda={has_cuda})")
    return tokenizer, model
def _ensure_model():
    """Load the model on first use and report whether it is usable.

    A load is attempted at most once. If it fails, or is disabled via the
    HUGGINGWIZARDS_NO_LLM env var, ``_load_failed`` is set and every later
    call returns False without retrying.
    """
    global _tokenizer, _model, _load_failed
    if _model is not None:
        return True
    if _load_failed:
        return False
    if os.environ.get("HUGGINGWIZARDS_NO_LLM"):
        # Force the deterministic fallback (local dev / CI).
        _load_failed = True
        return False
    try:
        _tokenizer, _model = _load_model()
    except Exception as exc:  # pragma: no cover
        print(f"[gamemaster] model load failed, using fallback: {exc}")
        _load_failed = True
        return False
    return True
# ---- mercy guardrail -------------------------------------------------------
# The GM may slow the boss's attack speed to help wounded parties, but its
# willingness to help decays as waves progress: the allowed floor rises from
# 0.5 by 0.04 per wave and reaches 1.0 (no mercy) by ~wave 13. The ceiling
# is always 2.0.
# Ceiling for boss_attack_speed; both the model's value (_validate) and the
# fallback policy (_deterministic) go through _clamp_attack_speed.
MERCY_MAX = 2.0
def _mercy_floor(rnd: int) -> float:
return round(min(1.0, 0.5 + 0.04 * max(0, int(rnd or 0))), 2)
def _clamp_attack_speed(value, rnd: int) -> float:
    """Coerce *value* into the boss attack-speed band for wave *rnd*.

    Anything that cannot be converted to a float, and NaN, falls back to
    the neutral 1.0 before clamping. +/-inf simply clamp to the band edges.

    Returns:
        A float rounded to 2 decimals in ``[_mercy_floor(rnd), MERCY_MAX]``.
    """
    try:
        v = float(value)
    except (TypeError, ValueError, OverflowError):
        v = 1.0
    if math.isnan(v):
        # NaN (which json.loads happily produces) compares False against
        # everything, so min/max below would turn it into MERCY_MAX — the
        # harshest setting — instead of a neutral default.
        v = 1.0
    return round(max(_mercy_floor(rnd), min(MERCY_MAX, v)), 2)
@_gpu
def _run_model(system: str, user: str) -> str:
    """Run one sampled chat generation and return only the new text.

    Raises RuntimeError("model unavailable") if _ensure_model() reports that
    no model could be loaded. Generation is capped at 400 new tokens.
    """
    if not _ensure_model():
        raise RuntimeError("model unavailable")
    import torch

    chat = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    # return_dict=True gives a BatchEncoding (input_ids + attention_mask) on
    # every transformers version — newer releases return it by default, and
    # passing it positionally to generate() crashes on `.shape`.
    inputs = _tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    inputs = inputs.to(_model.device)
    prompt_len = inputs["input_ids"].shape[1]
    sampling = dict(
        max_new_tokens=400,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=_tokenizer.eos_token_id,
    )
    with torch.no_grad():
        generated = _model.generate(**inputs, **sampling)
    # Drop the echoed prompt tokens; decode only what the model produced.
    return _tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
# On Spaces, load the model at startup (ZeroGPU replays the .to("cuda") once a
# real GPU is attached) so the @spaces.GPU window is spent on inference only —
# lazy-loading inside the first GPU call would blow the 60 s duration on the
# weight download. This MUST come after the @spaces.GPU function above is
# defined: ZeroGPU's startup probe needs to see a registered GPU function, and
# the multi-minute weight download would otherwise delay that registration
# ("No @spaces.GPU function detected during startup"). Locally we stay lazy so
# dev/CI never downloads weights.
if _ON_SPACES and not os.environ.get("HUGGINGWIZARDS_NO_LLM"):
    # Load errors are caught inside _ensure_model (sets _load_failed), so a
    # failed download does not stop the import; decisions then fall back.
    _ensure_model()
def _extract_json(text: str) -> dict | None:
m = re.search(r"\{.*\}", text, re.DOTALL)
if not m:
return None
try:
return json.loads(m.group(0))
except Exception:
return None
def _deterministic(summary: dict, prev_cfg: dict) -> dict:
    """Fallback Game Master logic — also a sane validation target.

    Builds a complete decision from the round summary alone. It is used when
    the model is unavailable or its output cannot be parsed, and as the
    default that _validate() overrides field by field.

    Args:
        summary: round summary. Reads "result" ("victory" counts as a win),
            "round", and per-player "id", "damage_dealt", "kills",
            "survived", "hp_pct".
        prev_cfg: current next-round config. Must provide "boss_hp",
            "boss_damage", "minion_count", "minion_hp" and "spawn_interval";
            "boss_aggro" defaults to 1.0.

    Returns:
        dict with "message", "reasoning", "rewards", "blessings",
        "next_round" and "card_pool".
    """
    players = summary.get("players") or []
    won = summary.get("result") == "victory"

    # Gold: 20 base, +1 per 25 damage (truncated), +5 per kill, +25 for
    # surviving; kept within the schema's 0-300 int range.
    rewards = {}
    for p in players:
        damage = p.get("damage_dealt") or 0  # `or 0` also tolerates None
        kills = p.get("kills") or 0
        gold = 20 + int(damage / 25) + kills * 5
        if p.get("survived"):
            gold += 25
        rewards[p["id"]] = max(0, min(300, int(gold)))

    cfg = dict(prev_cfg)
    n_players = max(1, len(players))
    rnd = int(summary.get("round") or 1)
    if won:
        cfg["boss_hp"] = int(prev_cfg["boss_hp"] * 1.35) + 200 * n_players
        cfg["boss_damage"] = min(80, prev_cfg["boss_damage"] + 2)
        cfg["minion_count"] = min(40, prev_cfg["minion_count"] + 2)
        cfg["minion_hp"] = int(prev_cfg["minion_hp"] * 1.15)
        cfg["spawn_interval"] = max(1.5, prev_cfg["spawn_interval"] - 0.3)
        cfg["boss_aggro"] = min(3.0, round(prev_cfg.get("boss_aggro", 1.0) + 0.2, 2))
        msg = "Impressive. The next horde will not fall so easily."
    else:
        cfg["boss_hp"] = max(300, int(prev_cfg["boss_hp"] * 0.85))
        cfg["boss_damage"] = max(4, prev_cfg["boss_damage"] - 1)
        cfg["minion_count"] = max(0, prev_cfg["minion_count"] - 1)
        cfg["spawn_interval"] = min(8.0, prev_cfg["spawn_interval"] + 0.5)
        cfg["boss_aggro"] = max(0.7, round(prev_cfg.get("boss_aggro", 1.0) - 0.1, 2))
        msg = "Rest, wizards. The arena bends slightly in your favor."

    # Enemy mix escalates with the round: grunts always, fast from r2, tanks r4+.
    cfg["wave"] = {
        "grunt": 1.0,
        "fast": round(min(0.8, max(0.0, (rnd - 1) * 0.25)), 2),
        "tank": round(min(0.6, max(0.0, (rnd - 3) * 0.2)), 2),
    }
    # Rotate the boss's attack pattern so consecutive waves feel different.
    cfg["boss_pattern"] = ALL_PATTERN_IDS[rnd % len(ALL_PATTERN_IDS)]

    survivors = [p for p in players if p.get("survived")]

    # Health-responsive attack speed: slow the boss for a wounded party, speed
    # it up for a healthy one — always within the wave's mercy floor. A
    # wipe (no survivors) counts as 0% hp.
    hps = [p.get("hp_pct", 100) for p in survivors]
    avg_hp = sum(hps) / len(hps) if hps else 0
    desired = 0.7 if avg_hp < 35 else 0.85 if avg_hp < 60 else (1.2 if won else 1.0)
    cfg["boss_attack_speed"] = _clamp_attack_speed(desired, rnd)

    # Blessings: after a defeat, ward up to three survivors; after a hard-won
    # victory (someone died), fully heal the most wounded survivor.
    blessings = {}
    if not won:
        for p in survivors[:3]:
            blessings[p["id"]] = "warding"
    elif survivors and any(not p.get("survived") for p in players):
        weakest = min(survivors, key=lambda p: p.get("hp_pct", 100))
        blessings[weakest["id"]] = "full_heal"

    # Offer the full card set by default (the model may narrow it).
    card_pool = list(ALL_CARD_IDS)
    return {"message": msg, "reasoning": "deterministic fallback policy",
            "rewards": rewards, "blessings": blessings,
            "next_round": cfg, "card_pool": card_pool}
def _contains(pool, value) -> bool:
    """Membership test that treats unhashable JSON values as non-members."""
    try:
        return value in pool
    except TypeError:
        # e.g. a list/dict emitted by the model checked against a set/dict
        return False


def _finite_float(value):
    """Return ``float(value)`` when it parses to a finite number, else None."""
    try:
        v = float(value)
    except (TypeError, ValueError, OverflowError):
        return None
    return v if math.isfinite(v) else None


def _validate(decision: dict, summary: dict, prev_cfg: dict) -> dict:
    """Coerce a (possibly model-authored) decision into a safe shape.

    Starts from _deterministic() and overrides each field only with model
    values that pass type and range checks:

    - rewards: roster ids only, int 0-300.
    - next_round: non-negative ints; finite, positive spawn_interval;
      finite boss_aggro clamped to the prompt's 0.7-3 band; attack speed
      through the mercy guardrail.
    - boss_pattern, wave, blessings (at most 3), card_pool: only ids known
      to the engine.
    - message and reasoning: truncated to 120 and 300 characters.

    Anything malformed keeps the deterministic default instead of raising.

    Returns:
        A decision dict with the same keys as _deterministic()'s result.
    """
    safe = _deterministic(summary, prev_cfg)  # defaults
    if not isinstance(decision, dict):
        return safe
    valid_ids = {p["id"] for p in summary.get("players") or []}

    raw_rewards = decision.get("rewards")
    if isinstance(raw_rewards, dict):
        rewards = {}
        for pid, amount in raw_rewards.items():
            if pid not in valid_ids:
                continue
            try:
                rewards[pid] = max(0, min(300, int(amount)))
            except (TypeError, ValueError, OverflowError):
                continue  # NaN/inf/non-numeric: skip this entry
        if rewards:
            safe["rewards"] = rewards

    nxt = decision.get("next_round")
    if isinstance(nxt, dict):
        merged = dict(safe["next_round"])
        for key in ("boss_hp", "boss_damage", "minion_hp", "minion_count"):
            if key not in nxt:
                continue
            try:
                value = int(nxt[key])
            except (TypeError, ValueError, OverflowError):
                continue
            # The fallback policy never produces negatives; reject them too.
            if value >= 0:
                merged[key] = value
        if "spawn_interval" in nxt:
            interval = _finite_float(nxt["spawn_interval"])
            if interval is not None and interval > 0:
                merged["spawn_interval"] = interval
        if "boss_aggro" in nxt:
            aggro = _finite_float(nxt["boss_aggro"])
            if aggro is not None:
                # Same 0.7-3 band the schema advertises and _deterministic uses.
                merged["boss_aggro"] = max(0.7, min(3.0, aggro))
        # mercy guardrail: attack speed is clamped into the wave's allowed band
        if "boss_attack_speed" in nxt:
            merged["boss_attack_speed"] = _clamp_attack_speed(
                nxt.get("boss_attack_speed"), summary.get("round") or 0)
        # boss attack pattern (a top-level "boss_pattern" is also accepted,
        # as long as next_round itself is present)
        pattern = nxt.get("boss_pattern", decision.get("boss_pattern"))
        if _contains(ALL_PATTERN_IDS, pattern):
            merged["boss_pattern"] = pattern
        # enemy mix: every archetype gets a finite, non-negative weight
        wave = nxt.get("wave")
        if isinstance(wave, dict):
            clean = {}
            for kind in MINION_TYPES:
                weight = _finite_float(wave.get(kind, 0.0))
                clean[kind] = max(0.0, weight) if weight is not None else 0.0
            if sum(clean.values()) > 0:
                merged["wave"] = clean
        safe["next_round"] = merged

    # per-player blessings (cap at 3, roster + id checked)
    raw_blessings = decision.get("blessings")
    if isinstance(raw_blessings, dict):
        blessings = {pid: b for pid, b in raw_blessings.items()
                     if pid in valid_ids and _contains(BLESSINGS, b)}
        safe["blessings"] = dict(list(blessings.items())[:3])

    # level-up card pool
    pool = decision.get("card_pool")
    if isinstance(pool, list):
        valid = [c for c in pool if _contains(ALL_CARD_IDS, c)]
        if valid:
            safe["card_pool"] = valid

    if isinstance(decision.get("message"), str):
        safe["message"] = decision["message"][:120]
    if isinstance(decision.get("reasoning"), str):
        safe["reasoning"] = decision["reasoning"][:300]
    return safe
def _was_clamped(requested, applied) -> bool:
    """True when the mercy guardrail changed the model's requested speed.

    Compares numerically at the 2-decimal precision _clamp_attack_speed
    rounds to. A string "1.2", or 1.2001, against an applied 1.2 is
    therefore not reported as a clamp. An unparseable request is reported
    as clamped, since it was replaced.
    """
    if requested is None:
        return False
    try:
        wanted = round(float(requested), 2)
    except (TypeError, ValueError, OverflowError):
        return True
    return wanted != applied  # NaN never equals, so it counts as clamped


def decide(summary: dict, prev_cfg: dict) -> dict:
    """Produce a validated decision and persist an agent trace for the round.

    Asks the model first. Any exception (model unavailable, generation or
    validation failure) falls back to _deterministic() and is recorded in
    the trace's "error" field; unparseable output falls back as well.

    Returns:
        The decision dict that was applied (same shape as _deterministic()).
    """
    trace_id = uuid.uuid4().hex[:8]
    system, user = SYSTEM_PROMPT, _build_user_prompt(summary, prev_cfg)
    raw, source, error = "", "fallback", None
    requested_speed = None
    t0 = time.perf_counter()  # monotonic: immune to wall-clock adjustments
    try:
        raw = _run_model(system, user)
        parsed = _extract_json(raw)
        if isinstance(parsed, dict) and isinstance(parsed.get("next_round"), dict):
            requested_speed = parsed["next_round"].get("boss_attack_speed")
        decision = _validate(parsed, summary, prev_cfg) if parsed else _deterministic(summary, prev_cfg)
        source = "nemotron" if parsed else "fallback(parse_failed)"
    except Exception as e:  # top-level boundary: the game must keep running
        error = str(e)
        decision = _deterministic(summary, prev_cfg)
    latency = round(time.perf_counter() - t0, 2)

    rnd = summary.get("round") or 0
    applied_speed = decision.get("next_round", {}).get("boss_attack_speed")
    mercy = {
        "floor": _mercy_floor(rnd), "max": MERCY_MAX,
        "requested": requested_speed, "applied": applied_speed,
        "clamped": _was_clamped(requested_speed, applied_speed),
        "note": "mercy floor rises with the wave — by ~wave 13 the GM can no "
                "longer slow the boss to protect the party",
    }
    trace = {
        "trace_id": trace_id,
        "round": summary.get("round"),
        "mercy": mercy,
        "ts": time.strftime("%Y-%m-%d %H:%M:%S"),
        "model": MODEL_ID,
        "source": source,
        "latency_sec": latency,
        "error": error,
        "input": {"system": system, "user": user, "round_summary": summary},
        "raw_output": raw,
        "decision": decision,
    }
    _persist(trace)
    return decision
# System prompt for player skill requests. The trigger/effect option lists
# come from game.skills (sorted, so the prompt text is stable across runs).
# Whatever the model returns is still passed through skills.validate_skill in
# generate_skill().
SKILL_SYSTEM_PROMPT = (
    "You are the Game Master AI for HuggingWizards. A wizard asks you to grant a "
    "new power-up. Invent ONE balanced skill and reply with ONE JSON object only. "
    "Schema (data only — never code):\n"
    '{"name": str, "trigger": one of '
    + str(sorted(skillmod.TRIGGERS))
    + ', "n": int (for every_n_attacks), "interval": float (for periodic),'
    ' "effect": one of '
    + str(sorted(skillmod.EFFECTS))
    + ', "radius": float, "damage": float, "count": int, "amount": float,'
    ' "slow": float 0-1, "color": "#rrggbb"}\n'
    "Pick fields that match the chosen effect. Keep it fun but not overpowered."
)
def _deterministic_skill(prompt: str) -> dict:
"""Keyword-based fallback skill generator (local dev / model unavailable)."""
t = (prompt or "").lower()
frost = any(w in t for w in ("frost", "ice", "freeze", "slow", "cold"))
if any(w in t for w in ("summon", "spirit", "minion", "ally", "pet", "wolf")):
spec = {"name": "Summoned Spirit", "trigger": "every_n_attacks", "n": 7,
"effect": "summon_ally", "count": 1, "color": "#a6ff8c"}
elif any(w in t for w in ("heal", "regen", "life", "restore", "vamp", "mend")):
spec = {"name": "Mending Light", "trigger": "periodic", "interval": 6,
"effect": "heal", "amount": 30, "color": "#6effd0"}
elif any(w in t for w in ("shield", "barrier", "ward", "protect", "absorb", "armor")):
spec = {"name": "Arcane Bulwark", "trigger": "periodic", "interval": 8,
"effect": "shield", "amount": 30, "color": "#9db4c9"}
elif frost or any(w in t for w in ("explos", "blast", "aoe", "area", "detonat", "fire")):
spec = {"name": "Frost Nova" if frost else "Arcane Detonation",
"trigger": "every_n_attacks", "n": 6, "effect": "aoe_damage",
"radius": 140, "damage": 45, "slow": 0.4 if frost else 0.0,
"color": "#7ee0ff" if frost else "#ff7a4a"}
elif any(w in t for w in ("nova", "spread", "shotgun", "burst", "ring", "radial")):
spec = {"name": "Star Burst", "trigger": "every_n_attacks", "n": 5,
"effect": "projectile_nova", "count": 10, "damage": 24, "color": "#ffd45e"}
else: # default: an AoE explosion
spec = {"name": "Arcane Detonation", "trigger": "every_n_attacks", "n": 6,
"effect": "aoe_damage", "radius": 140, "damage": 45, "slow": 0.0,
"color": "#ff7a4a"}
return spec
def generate_skill(prompt: str, context: dict | None = None) -> dict:
    """Turn a player's free-text wish into a validated, safe skill spec.

    Tries Nemotron; always falls back to deterministic keywords; the result is
    run through skills.validate_skill so it is bounded and code-free.

    Args:
        prompt: the wizard's request; embedded via ``repr`` in the user message.
        context: optional JSON-serialisable extras sent to the model. Its
            "round" value, if any, is recorded in the trace.

    Returns:
        The validated skill spec. A "skill_request" trace is persisted
        whether the model or the fallback produced it.
    """
    ctx = context or {}
    trace_id = uuid.uuid4().hex[:8]
    user = f"Wizard's request: {prompt!r}\nContext: {json.dumps(ctx)}\nReturn the skill JSON."
    raw, source, error = "", "fallback", None
    t0 = time.perf_counter()  # monotonic: immune to wall-clock adjustments
    spec = None
    try:
        raw = _run_model(SKILL_SYSTEM_PROMPT, user)
        parsed = _extract_json(raw)
        spec = skillmod.validate_skill(parsed) if parsed else None
        source = "nemotron" if spec else "fallback(parse_failed)"
    except Exception as e:  # model unavailable / generation / validation error
        error = str(e)
    if not spec:
        # Same truthiness test as the `source` label above: an empty spec is
        # treated as a failed parse, so the trace never claims a fallback
        # that did not happen.
        spec = skillmod.validate_skill(_deterministic_skill(prompt))
    latency = round(time.perf_counter() - t0, 2)
    _persist({
        "trace_id": trace_id, "round": ctx.get("round"),
        "ts": time.strftime("%Y-%m-%d %H:%M:%S"), "model": MODEL_ID,
        "source": source, "latency_sec": latency, "error": error,
        "kind": "skill_request",
        "input": {"system": SKILL_SYSTEM_PROMPT, "user": user, "prompt": prompt},
        "raw_output": raw, "decision": spec,
    })
    return spec
def _persist(trace: dict):
    """Record *trace* in RECENT_TRACES and write it as JSON under TRACE_DIR.

    The in-memory ring keeps the newest _MAX_RECENT traces, newest first.
    File name is ``round_NNN_<trace_id>.json`` when ``trace["round"]`` is an
    int, otherwise ``skill_<trace_id>.json``. Write failures are printed,
    never raised.
    """
    RECENT_TRACES.insert(0, trace)
    del RECENT_TRACES[_MAX_RECENT:]
    try:
        rnd = trace.get("round")
        prefix = f"round_{rnd:03d}" if isinstance(rnd, int) else "skill"
        path = os.path.join(TRACE_DIR, f"{prefix}_{trace['trace_id']}.json")
        # Serialize before opening the file: json.dump streams chunks into an
        # already-truncated file, so an unserializable value would leave a
        # corrupt partial trace behind for the background sync to pick up.
        payload = json.dumps(trace, indent=2)
        with trace_store.lock():
            with open(path, "w", encoding="utf-8") as f:
                f.write(payload)
    except Exception as e:  # pragma: no cover
        print(f"[gamemaster] failed to write trace: {e}")