"""Modal backend for AI Prof — the brain (Nemotron 3 Nano) served over vLLM. OpenAI-compatible endpoint that the Gradio app points ``BRAIN_BASE_URL`` at. The GPU container scales to zero when idle. Nemotron 3 Nano is a hybrid Mamba-2 + MoE *reasoning* model. On first run vLLM JIT-compiles CUDA kernels — notably FlashInfer's CUTLASS fused-MoE kernel (slow, minutes). We persist those compile caches to Volumes and warm ONCE, so every later cold start reuses them (~tens of seconds) instead of recompiling on an expensive GPU. Cost discipline: ``max_containers=1`` so a request burst can never fan out into multiple GPUs, and warming is a single controlled ``modal run`` — never a curl storm against the live endpoint. Bring-up: modal run modal_app.py::download_model # 1. pull weights to a Volume (CPU, cheap) modal run modal_app.py::warm # 2. ONE GPU run: compile + cache kernels modal deploy modal_app.py # 3. serve; cold starts now reuse the cache # the printed *.modal.run URL + "/v1" is BRAIN_BASE_URL """ from __future__ import annotations import json import subprocess import time import urllib.request import modal # --- what to serve ----------------------------------------------------------- MODEL_NAME = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16" SERVED_NAME = "nemotron-3-nano" # must match BRAIN_MODEL in the app's .env GPU = "H100" MAX_MODEL_LEN = 16384 # ample for slide reading + outline + history; small KV cache VLLM_PORT = 8000 MINUTES = 60 app = modal.App("ai-prof-brain") # Persistent caches: weights, vLLM torch.compile artifacts, and FlashInfer's JIT # kernels. Mounting the FlashInfer cache is what stops the CUTLASS MoE kernel from # recompiling on every cold start. hf_cache = modal.Volume.from_name("ai-prof-hf-cache", create_if_missing=True) vllm_cache = modal.Volume.from_name("ai-prof-vllm-cache", create_if_missing=True) flashinfer_cache = modal.Volume.from_name("ai-prof-flashinfer-cache", create_if_missing=True) triton_cache = modal.Volume.from_name("ai-prof-triton-cache", create_if_missing=True) VOLUMES = { "/root/.cache/huggingface": hf_cache, "/root/.cache/vllm": vllm_cache, "/root/.cache/flashinfer": flashinfer_cache, "/root/.triton": triton_cache, # Mamba2 SSD Triton kernel cache (~11 min compile) } # CUDA *devel* base ships nvcc — Nemotron-H's Mamba-2 kernels JIT-compile at init. vllm_image = ( modal.Image.from_registry("nvidia/cuda:12.8.1-devel-ubuntu22.04", add_python="3.12") .entrypoint([]) .pip_install("vllm>=0.12.0", "huggingface_hub[hf_transfer]>=0.27") .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) ) def _vllm_cmd() -> list[str]: # No --enforce-eager: we WANT torch.compile + CUDA-graph capture so the # artifacts get cached to the vLLM volume (FAST_BOOT=False pattern). return [ "vllm", "serve", MODEL_NAME, "--served-model-name", SERVED_NAME, "--host", "0.0.0.0", "--port", str(VLLM_PORT), "--max-model-len", str(MAX_MODEL_LEN), "--max-num-seqs", "8", "--tensor-parallel-size", "1", # Modal Volumes mount as 9P, which vLLM does not recognize as a network # filesystem. Force parallel prefetch instead of reading 13 shards # serially; the serial path takes roughly ten minutes for this checkpoint. "--safetensors-load-strategy", "prefetch", "--trust-remote-code", "--reasoning-parser", "nemotron_v3", ] def _wait_healthy(timeout_s: int = 25 * MINUTES) -> None: base = f"http://127.0.0.1:{VLLM_PORT}" deadline = time.time() + timeout_s while time.time() < deadline: try: urllib.request.urlopen(f"{base}/health", timeout=5) return except Exception: time.sleep(5) raise TimeoutError("vLLM did not become healthy in time") @app.function( image=vllm_image, volumes={"/root/.cache/huggingface": hf_cache}, timeout=60 * MINUTES, ) def download_model(model_name: str = MODEL_NAME) -> None: """Pull weights to the Volume on CPU (no GPU billed during the big download).""" from huggingface_hub import snapshot_download print(f"Downloading {model_name} -> /root/.cache/huggingface ...") snapshot_download(model_name, ignore_patterns=["*.pt", "*.pth"]) hf_cache.commit() print("Done.") @app.function(image=vllm_image, gpu=GPU, volumes=VOLUMES, timeout=60 * MINUTES) def warm() -> None: """One controlled GPU run: boot vLLM, fire a single request to trigger every kernel compile, then commit the compile caches. Run with ``modal run``.""" proc = subprocess.Popen(_vllm_cmd()) try: print("Waiting for vLLM to compile + become healthy (first time is slow)...") _wait_healthy() req = urllib.request.Request( f"http://127.0.0.1:{VLLM_PORT}/v1/chat/completions", data=json.dumps( { "model": SERVED_NAME, "messages": [{"role": "user", "content": "Say hello."}], "max_tokens": 32, } ).encode(), headers={"Content-Type": "application/json"}, ) print("Response:", urllib.request.urlopen(req, timeout=120).read().decode()[:400]) finally: proc.terminate() try: proc.wait(timeout=30) except Exception: proc.kill() vllm_cache.commit() flashinfer_cache.commit() triton_cache.commit() print("Warm complete — compile caches committed. Cold starts will now be fast.") @app.function( image=vllm_image, gpu=GPU, volumes=VOLUMES, scaledown_window=20 * MINUTES, # keep warm between requests; cold start is ~2 min timeout=60 * MINUTES, max_containers=1, # cost guard: a request burst can never fan out into many GPUs ) @modal.concurrent(max_inputs=8) # vLLM batches concurrent requests on the one replica @modal.web_server(port=VLLM_PORT, startup_timeout=30 * MINUTES) def serve() -> None: print("Launching:", " ".join(_vllm_cmd())) subprocess.Popen(_vllm_cmd())