"""Gradio server app exposing the Marlin-2B video model for captioning and event search."""

import json
import math
import os
import time
from pathlib import Path
from typing import Any

try:
    import spaces
except ImportError:
    # The ``spaces`` package is optional. When it is missing, substitute an
    # object whose ``GPU(...)`` decorator factory returns functions unchanged
    # so the ``@spaces.GPU(...)`` decorations below still work.
    class _SpacesFallback:
        """No-op stand-in for the ``spaces`` module."""

        @staticmethod
        def GPU(*args, **kwargs):
            """Accept any decorator arguments and return an identity decorator."""

            def _decorator(fn):
                return fn

            return _decorator

    spaces = _SpacesFallback()

# Default runtime configuration. These are applied before gradio/torch are
# imported below, and ``setdefault`` never overrides a value that is already
# present in the environment.
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
os.environ.setdefault("HF_MODULES_CACHE", "/tmp/hf_modules")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
os.environ.setdefault("FORCE_QWENVL_VIDEO_READER", "torchcodec")
os.environ.setdefault("VIDEO_MAX_PIXELS", "200704")
os.environ.setdefault("FPS", "2.0")
os.environ.setdefault("FPS_MAX_FRAMES", "240")
os.environ.setdefault("FPS_MIN_FRAMES", "4")

import gradio as gr
import torch
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from gradio import Server
from gradio_client.data_classes import FileData

APP_ROOT = Path(__file__).resolve().parent
INDEX_PATH = APP_ROOT / "index.html"
EXAMPLES_DIR = APP_ROOT / "examples"
# Cache files stored next to the app; consulted first when reading.
EXAMPLE_STATIC_CACHE_DIR = APP_ROOT / "example_cache"
# Cache directory this process writes to.
EXAMPLE_CACHE_DIR = Path("/tmp/marlin_example_cache")
MODEL_ID = "NemoStation/Marlin-2B"
MODEL_DESCRIPTION = (
    "Marlin-2B is a 2B video VLM for dense clip captions and timestamped event search."
)

# Bundled examples. "file" is relative to EXAMPLES_DIR; the remaining keys are
# the arguments passed to the analysis call for that example.
EXAMPLES = [
    {
        "title": "Bowery Waltz",
        "file": "bowery_waltz.mp4",
        "task": "Caption",
        "event": "",
        "tokens": 768,
        "temperature": 0.0,
    },
    {
        "title": "Skateboard shove-it",
        "file": "shove_it.mp4",
        "task": "Caption",
        "event": "",
        "tokens": 768,
        "temperature": 0.0,
    },
    {
        "title": "Coyote run",
        "file": "coyote_run.mp4",
        "task": "Caption",
        "event": "",
        "tokens": 768,
        "temperature": 0.0,
    },
]

# Inference only: disable autograd globally and allow TF32 kernels.
torch.set_grad_enabled(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


def _load_model():
    """Load the model onto the GPU.

    Returns:
        A ``(model, status)`` tuple. ``model`` is ``None`` and ``status``
        explains why when ``MARLIN_SKIP_LOAD=1`` is set; otherwise ``model``
        is the loaded model in eval mode and ``status`` is
        ``MODEL_DESCRIPTION``.
    """
    if os.getenv("MARLIN_SKIP_LOAD") == "1":
        return None, "Model loading skipped by MARLIN_SKIP_LOAD=1."

    # Imported here so that skip mode never needs transformers.
    from transformers import AutoModelForCausalLM

    started = time.perf_counter()
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        dtype=torch.bfloat16,
        attn_implementation="sdpa",
        low_cpu_mem_usage=True,
    )
    model = model.to("cuda").eval()
    # Access the attribute once at startup.
    # NOTE(review): presumably this forces a lazily-built processor — confirm.
    _ = model.processor
    elapsed = time.perf_counter() - started
    print(f"Loaded {MODEL_ID} in {elapsed:.1f}s.", flush=True)
    return model, MODEL_DESCRIPTION


# A load failure must not prevent the server from starting: keep the error
# text so that requests can report it instead.
try:
    MARLIN, LOAD_STATUS = _load_model()
except Exception as exc:
    MARLIN = None
    LOAD_STATUS = f"Model failed to load: {type(exc).__name__}: {exc}"
    print(LOAD_STATUS, flush=True)


def _video_path(video: Any) -> str:
    """Extract a filesystem path from the various shapes a video input can take.

    Accepts a plain string, an object with a ``path`` attribute, a dict with a
    ``"path"`` or ``"name"`` entry, or a non-empty tuple/list whose first item
    is the path.

    Raises:
        ValueError: If ``video`` is ``None`` or no path can be extracted.
    """
    if video is None:
        raise ValueError("Upload a video or choose an example first.")
    if isinstance(video, str):
        return video

    # An empty or None ``path`` attribute is treated as "no path" rather than
    # being stringified (which would yield the bogus path "None").
    attr_path = getattr(video, "path", None)
    if attr_path:
        return str(attr_path)

    if isinstance(video, dict):
        for key in ("path", "name"):
            value = video.get(key)
            if value:
                return str(value)

    if isinstance(video, (tuple, list)) and video:
        return str(video[0])

    raise ValueError("Could not read the provided video path.")


def _duration_seconds(path: str) -> float | None:
    """Return the duration of the video at ``path`` in seconds, if it can be read.

    Uses the container duration when available and otherwise the first video
    stream's duration. Best-effort by design: any failure (including ``av``
    not being installed) yields ``None``.
    """
    try:
        import av

        with av.open(path) as container:
            if container.duration:
                return float(container.duration / av.time_base)
            stream = next((s for s in container.streams if s.type == "video"), None)
            if stream and stream.duration and stream.time_base:
                return float(stream.duration * stream.time_base)
    except Exception:
        return None
    return None


def _clean_events(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Normalise raw caption events into ``{"start", "end", "description"}`` dicts.

    Events are dropped when ``start``/``end`` are missing or not convertible to
    float, when either bound is NaN or infinite, or when ``end <= start``.
    A missing or ``None`` description becomes the empty string.
    """
    cleaned = []
    for event in events:
        try:
            start = float(event["start"])
            end = float(event["end"])
        except (KeyError, TypeError, ValueError, OverflowError):
            continue
        # NaN/inf would pass the ordering check below (comparisons with NaN
        # are False) and then serialise as non-standard JSON tokens.
        if not (math.isfinite(start) and math.isfinite(end)):
            continue
        if end <= start:
            continue
        description = event.get("description")
        cleaned.append(
            {
                "start": start,
                "end": end,
                "description": "" if description is None else str(description).strip(),
            }
        )
    return cleaned


def _status_line(task: str, elapsed: float, video_duration: float | None, detail: str) -> str:
    """Build the one-line human-readable status string returned with each result."""
    video_part = f"{video_duration:.1f}s video" if video_duration else "video duration unavailable"
    return f"{task} completed in {elapsed:.1f}s on GPU | {video_part} | {detail}"


def _estimate_duration(video, task, event_query, max_caption_tokens, temperature, *args, **kwargs):
    """Return the GPU duration (in seconds) to request for a call.

    Takes the same arguments as ``_analyze_gpu`` because it is passed as the
    ``duration`` callable of ``spaces.GPU``. "Find" requests a flat 75 s;
    captioning requests 90 s plus one second per 12 tokens above 512, capped
    at 180 s.
    """
    if task == "Find":
        return 75
    tokens = int(max_caption_tokens or 768)
    return min(180, 90 + max(0, tokens - 512) // 12)


def _run_model(video_path: str, task: str, event_query: str, max_caption_tokens: int, temperature: float):
    """Run a "Find" or "Caption" request against the loaded model.

    Args:
        video_path: Path of the video file to analyse.
        task: ``"Find"`` for event search; anything else (including empty)
            is handled as captioning.
        event_query: Event description to search for (only used by "Find").
        max_caption_tokens: Generation budget for captioning; falsy means 768.
        temperature: Sampling temperature; values ``<= 0`` (or falsy) select
            greedy decoding.

    Returns:
        A JSON-serialisable result dict with the keys ``ok``, ``task``,
        ``scene``, ``events``, ``span``, ``raw``, ``status``, ``duration``
        and ``timings``.

    Raises:
        RuntimeError: If the model is not loaded (message is ``LOAD_STATUS``).
        ValueError: If ``task`` is "Find" and ``event_query`` is blank.
    """
    if MARLIN is None:
        raise RuntimeError(LOAD_STATUS)

    task_name = task or "Caption"

    # Derive the sampling switch and the temperature from a single value so
    # they cannot disagree: whenever sampling is off, a neutral 1.0 is passed
    # instead of a zero or negative temperature.
    requested_temperature = float(temperature or 0.0)
    do_sample = requested_temperature > 0
    sampling_temperature = requested_temperature if do_sample else 1.0

    video_duration = _duration_seconds(video_path)
    started = time.perf_counter()

    if task_name == "Find":
        query = (event_query or "").strip()
        if not query:
            raise ValueError("Enter an event to find in the video.")
        result = MARLIN.find(
            video_path,
            event=query,
            max_new_tokens=64,
            do_sample=do_sample,
            temperature=sampling_temperature,
        )
        elapsed = time.perf_counter() - started

        span = result.get("span")
        span_payload = None
        events = []
        if span:
            start, end = float(span[0]), float(span[1])
            span_payload = {"start": start, "end": end, "label": query}
            events = [{"start": start, "end": end, "description": query}]
        return {
            "ok": True,
            "task": "Find",
            "scene": "",
            "events": events,
            "span": span_payload,
            "raw": result.get("raw", ""),
            "status": _status_line("Find", elapsed, video_duration, f"format_ok={result.get('format_ok')}"),
            "duration": video_duration,
            "timings": {"gpu_seconds": elapsed},
        }

    result = MARLIN.caption(
        video_path,
        max_new_tokens=int(max_caption_tokens or 768),
        do_sample=do_sample,
        temperature=sampling_temperature,
    )
    elapsed = time.perf_counter() - started
    events = _clean_events(list(result.get("events") or []))
    return {
        "ok": True,
        "task": "Caption",
        "scene": result.get("scene", ""),
        "events": events,
        "span": None,
        "raw": result.get("caption", ""),
        "status": _status_line("Caption", elapsed, video_duration, f"{len(events)} events parsed"),
        "duration": video_duration,
        "timings": {"gpu_seconds": elapsed},
    }


@spaces.GPU(duration=_estimate_duration)
def _analyze_gpu(video, task: str, event_query: str, max_caption_tokens: int, temperature: float):
    """GPU-decorated entry point: resolve the video path and run the model."""
    return _run_model(_video_path(video), task, event_query, max_caption_tokens, temperature)


@spaces.GPU(duration=1)
def _zerogpu_probe() -> str:
    """Trivial GPU-decorated function that returns ``"ready"``."""
    return "ready"


def _static_example_cache_path(index: int) -> Path:
    """Return the path of the app-bundled cache file for example ``index``."""
    return EXAMPLE_STATIC_CACHE_DIR / f"v2_example_{index}.json"


def _example_cache_path(index: int) -> Path:
    """Return the path of the runtime (writable) cache file for example ``index``."""
    return EXAMPLE_CACHE_DIR / f"v2_example_{index}.json"


def _read_example_cache(index: int) -> dict[str, Any] | None:
    """Return the cached result for example ``index``, or ``None`` if unavailable.

    The bundled cache is consulted before the runtime cache. Files that are
    missing, unreadable, not valid JSON, or whose top-level value is not a
    JSON object are skipped.
    """
    for path in (_static_example_cache_path(index), _example_cache_path(index)):
        try:
            cached = json.loads(path.read_text(encoding="utf-8"))
        except Exception:
            # Deliberately best-effort: a bad cache file means "not cached".
            continue
        # Callers index into the result by key, so anything but a dict is unusable.
        if isinstance(cached, dict):
            return cached
    return None


def _write_example_cache(index: int, payload: dict[str, Any]) -> None:
    """Persist ``payload`` as the runtime cache entry for example ``index``.

    The JSON is written to a sibling temporary file and then renamed over the
    destination so that a concurrent reader never sees a half-written file.

    Raises:
        OSError: If the cache directory or file cannot be written.
    """
    EXAMPLE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    target = _example_cache_path(index)
    scratch = target.with_name(target.name + ".tmp")
    scratch.write_text(json.dumps(payload), encoding="utf-8")
    scratch.replace(target)


app = Server(title="Marlin-2B - Video Understanding")


@app.get("/", response_class=HTMLResponse)
async def homepage() -> str:
    """Serve the contents of ``index.html`` as the landing page."""
    return INDEX_PATH.read_text(encoding="utf-8")


@app.get("/health")
async def health() -> JSONResponse:
    """Report whether the model is loaded, plus basic model/app metadata."""
    return JSONResponse(
        {
            "ok": MARLIN is not None,
            "model": MODEL_ID,
            "description": MODEL_DESCRIPTION,
            "examples": len(EXAMPLES),
        }
    )


@app.get("/examples.json")
async def examples_json() -> JSONResponse:
    """List the bundled examples with their URLs and cache state."""

    def _example_payload(index: int, example: dict[str, Any]) -> dict[str, Any]:
        # Start from the example definition and add its URL and cache state.
        payload = {
            **example,
            "url": f"/examples/{example['file']}",
            "cached": _read_example_cache(index) is not None,
        }
        # A same-named .jpg / .webm next to the video is advertised only when
        # the file actually exists.
        thumbnail = Path(example["file"]).with_suffix(".jpg").name
        if (EXAMPLES_DIR / thumbnail).exists():
            payload["thumbnail"] = f"/examples/{thumbnail}"
        playback = Path(example["file"]).with_suffix(".webm").name
        if (EXAMPLES_DIR / playback).exists():
            payload["playback_url"] = f"/examples/{playback}"
        return payload

    return JSONResponse(
        {
            "examples": [
                _example_payload(i, example) for i, example in enumerate(EXAMPLES)
            ]
        }
    )


@app.api(name="analyze", concurrency_limit=1, time_limit=180)
def analyze_api(
    video: FileData,
    task: str = "Caption",
    event_query: str = "",
    max_caption_tokens: int = 768,
    temperature: float = 0.0,
) -> str:
    """Analyse an uploaded video and return the result as a JSON string.

    Failures are logged and reported as ``{"ok": false, "error": ...}``
    instead of being raised.
    """
    try:
        return json.dumps(_analyze_gpu(video, task, event_query, max_caption_tokens, temperature))
    except Exception as exc:
        print(f"[analyze] {type(exc).__name__}: {exc}", flush=True)
        return json.dumps({"ok": False, "error": str(exc)})


@app.api(name="analyze_example", concurrency_limit=1, time_limit=180)
def analyze_example_api(index: int) -> str:
    """Analyse bundled example ``index`` (served from cache when possible).

    Returns a JSON string. The result carries ``"cached": true`` when it came
    from a cache file and ``"cached": false`` when it was freshly computed.
    Failures are logged and reported as ``{"ok": false, "error": ...}``.
    """
    try:
        example_index = int(index)
        if example_index < 0 or example_index >= len(EXAMPLES):
            raise ValueError("Unknown example.")

        cached = _read_example_cache(example_index)
        if cached is not None:
            cached["cached"] = True
            return json.dumps(cached)

        example = EXAMPLES[example_index]
        payload = _analyze_gpu(
            str(EXAMPLES_DIR / example["file"]),
            example["task"],
            example["event"],
            int(example["tokens"]),
            float(example["temperature"]),
        )
        payload["cached"] = False

        # Caching is an optimisation only: a failed write must not throw away
        # a result that was computed successfully.
        try:
            _write_example_cache(example_index, payload)
        except OSError as cache_exc:
            print(
                f"[analyze_example] cache write failed: {type(cache_exc).__name__}: {cache_exc}",
                flush=True,
            )
        return json.dumps(payload)
    except Exception as exc:
        print(f"[analyze_example] {type(exc).__name__}: {exc}", flush=True)
        return json.dumps({"ok": False, "error": str(exc)})


# Serve the example media (videos, thumbnails, playback files) statically.
app.mount("/examples", StaticFiles(directory=str(EXAMPLES_DIR)), name="examples")

demo = app

if __name__ == "__main__":
    demo.launch(show_error=True, ssr_mode=False)