"""Gradio Space demo for the Marlin-2B video VLM: dense captions and event search."""

import os
import time
from typing import Any

# NOTE(review): `spaces` (Hugging Face ZeroGPU) is imported before torch here,
# presumably because ZeroGPU expects to be imported first — confirm before reordering.
import spaces

# These defaults are set before gradio/torch/transformers are imported, presumably
# because some of them are read at import time. `setdefault` lets values already
# present in the environment (e.g. Space settings) take precedence.
os.environ.setdefault("GRADIO_ANALYTICS_ENABLED", "False")
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
os.environ.setdefault("HF_MODULES_CACHE", "/tmp/hf_modules")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
os.environ.setdefault("FORCE_QWENVL_VIDEO_READER", "torchcodec")
os.environ.setdefault("VIDEO_MAX_PIXELS", "200704")
os.environ.setdefault("FPS", "2.0")
os.environ.setdefault("FPS_MAX_FRAMES", "240")
os.environ.setdefault("FPS_MIN_FRAMES", "4")

import gradio as gr  # noqa: E402
import torch  # noqa: E402
from PIL import Image, ImageDraw, ImageFont  # noqa: E402
from transformers import AutoModelForCausalLM  # noqa: E402

MODEL_ID = "NemoStation/Marlin-2B"
MODEL_LINK = (
    "[Marlin-2B](https://huggingface.co/NemoStation/Marlin-2B) is a 2B video VLM "
    "for dense clip captions and timestamped event search."
)

torch.set_grad_enabled(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Each row follows the `analyze` input order:
# [video, task, event_query, max_caption_tokens, temperature].
EXAMPLES = [
    ["examples/bowery_waltz.mp4", "Caption", "", 768, 0.0],
    ["examples/shove_it.mp4", "Find", "the skateboard spins under the rider's feet", 256, 0.0],
    ["examples/coyote_run.mp4", "Find", "the coyote runs across the scene", 256, 0.0],
]

# Bar colours for the caption events drawn on the timeline; its length also caps
# how many events are drawn.
_EVENT_COLORS = ("#2563eb", "#16a34a", "#7c3aed")


def _load_model():
    """Load Marlin-2B in bfloat16 onto the GPU and return it with a status message.

    Returns:
        ``(model, status_markdown)``. ``model`` is ``None`` when loading is skipped
        via the ``MARLIN_SKIP_LOAD=1`` environment variable.
    """
    if os.getenv("MARLIN_SKIP_LOAD") == "1":
        return None, "Model loading skipped by MARLIN_SKIP_LOAD=1."

    started = time.perf_counter()
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        dtype=torch.bfloat16,
        attn_implementation="sdpa",
        low_cpu_mem_usage=True,
    )
    model = model.to("cuda").eval()
    # Access the processor once at startup; presumably it is lazily built by the
    # remote code, so this moves that cost out of the first request.
    _ = model.processor
    elapsed = time.perf_counter() - started
    print(f"Loaded {MODEL_ID} in {elapsed:.1f}s.", flush=True)
    return model, MODEL_LINK


try:
    MARLIN, LOAD_STATUS = _load_model()
except Exception as exc:  # Keep the UI alive so logs and API expose the startup failure.
    MARLIN = None
    LOAD_STATUS = f"Model failed to load: {type(exc).__name__}: {exc}"
    print(LOAD_STATUS, flush=True)


def _video_path(video: Any) -> str:
    """Extract a local file path from the value produced by the Gradio Video input.

    Accepts a path string, a dict carrying a ``path`` or ``name`` key, or a
    non-empty tuple/list whose first item is the path.

    Raises:
        gr.Error: if no video was given or no path could be found.
    """
    if video is None:
        raise gr.Error("Upload a video or choose an example first.")
    if isinstance(video, str):
        return video
    if isinstance(video, dict):
        for key in ("path", "name"):
            value = video.get(key)
            if value:
                return str(value)
    if isinstance(video, (tuple, list)) and video:
        return str(video[0])
    raise gr.Error("Could not read the provided video path.")


def _duration_seconds(path: str) -> float | None:
    """Return the video length in seconds, or ``None`` when it cannot be determined.

    Best-effort probe with PyAV: the container duration is preferred, falling back
    to the first video stream's duration. Any error (including PyAV being
    unavailable) yields ``None`` so the caller can still run inference.
    """
    try:
        import av

        with av.open(path) as container:
            if container.duration:
                return float(container.duration / av.time_base)
            stream = next((s for s in container.streams if s.type == "video"), None)
            if stream and stream.duration and stream.time_base:
                return float(stream.duration * stream.time_base)
    except Exception:
        return None
    return None


def _event_rows(events: list[dict[str, Any]]) -> list[list[Any]]:
    """Convert caption events into ``[start, end, description]`` table rows.

    Start/end are rendered with one decimal place; a missing description becomes "".
    """
    return [
        [
            f"{float(event['start']):.1f}",
            f"{float(event['end']):.1f}",
            event.get("description", ""),
        ]
        for event in events
    ]


def _coerce_span(span: Any) -> tuple[float, float] | None:
    """Return ``span`` as an ``(start, end)`` pair of floats, or ``None`` if unusable.

    A missing, wrongly sized, string, or non-numeric span yields ``None`` so the
    Find output reports "No valid span parsed." instead of failing while formatting.
    """
    if span is None or isinstance(span, (str, bytes)):
        return None
    try:
        start, end = span
        return float(start), float(end)
    except (TypeError, ValueError):
        return None


def _timeline(
    events: list[dict[str, Any]],
    span: tuple[float, float] | None,
    duration: float | None,
) -> Image.Image:
    """Render the Find span and the first caption events on a horizontal time axis.

    Args:
        events: Caption events with ``start``/``end`` in seconds and an optional
            ``description``. Only the first ``len(_EVENT_COLORS)`` are drawn, but
            all of them extend the axis.
        span: ``(start, end)`` seconds from Find mode, or ``None``.
        duration: Video length in seconds, or ``None`` if unknown.

    Returns:
        An RGB image 960px wide and at least 220px tall. It grows taller when a
        span and events are drawn together, so no bar falls off the canvas.
    """
    width = 960
    left, right = 86, 910
    axis_y = 64
    bar_h = 20
    first_row_y = 104
    span_row_step, event_row_step = 42, 38
    font = ImageFont.load_default()

    shown = events[: len(_EVENT_COLORS)]

    # Axis runs over [0, max_time]: the longest of the video duration, any event
    # end, and the span end, and never shorter than one second.
    ends = [float(event["end"]) for event in events]
    if span:
        ends.append(float(span[1]))
    max_time = max(1.0, duration or 1.0, *ends)

    def to_x(t: float) -> int:
        """Map seconds to an x pixel, clamped so bars stay on the axis."""
        t = min(max(t, 0.0), max_time)
        return left + int((right - left) * t / max_time)

    # Size the canvas to the rows drawn: a fixed 220px would clip the third
    # event bar whenever a span row is drawn above it.
    rows_bottom = first_row_y + (span_row_step if span else 0) + event_row_step * len(shown)
    height = max(220, rows_bottom)

    img = Image.new("RGB", (width, height), "#f8fafc")
    draw = ImageDraw.Draw(img)

    draw.line((left, axis_y, right, axis_y), fill="#334155", width=2)
    # Short axes get one decimal so neighbouring ticks don't round to the same label.
    tick_decimals = 0 if max_time >= 10 else 1
    for i in range(6):
        t = max_time * i / 5
        x = left + int((right - left) * i / 5)
        draw.line((x, axis_y - 6, x, axis_y + 6), fill="#334155", width=1)
        draw.text((x - 12, axis_y + 12), f"{t:.{tick_decimals}f}s", fill="#334155", font=font)

    y = first_row_y
    if span:
        x0, x1 = to_x(float(span[0])), to_x(float(span[1]))
        # Keep very short spans visible with a 4px minimum width.
        draw.rounded_rectangle((x0, y, max(x0 + 4, x1), y + bar_h), radius=5, fill="#ef4444")
        draw.text((left, y - 18), "Find span", fill="#991b1b", font=font)
        y += span_row_step

    for index, (event, color) in enumerate(zip(shown, _EVENT_COLORS), start=1):
        x0, x1 = to_x(float(event["start"])), to_x(float(event["end"]))
        draw.rounded_rectangle((x0, y, max(x0 + 4, x1), y + bar_h), radius=5, fill=color)
        label = str(event.get("description", ""))[:92]
        draw.text((left, y - 16), f"Event {index}: {label}", fill="#0f172a", font=font)
        y += event_row_step

    if not events and not span:
        draw.text((left, 120), "No timestamped events or span parsed yet.", fill="#64748b", font=font)
    return img


def _status_line(task: str, elapsed: float, video_duration: float | None, detail: str) -> str:
    """Build the one-line run summary shown in the Status box."""
    video_part = f"{video_duration:.1f}s video" if video_duration else "video duration unavailable"
    return f"{task} completed in {elapsed:.1f}s on GPU | {video_part} | {detail}"


def _estimate_duration(video, task, event_query, max_caption_tokens, temperature, *args, **kwargs):
    """Return the GPU time budget (seconds) requested via ``spaces.GPU`` for one call.

    Receives the same arguments as ``analyze``. Find gets a flat 75s. Caption
    starts at 90s, adds 1s per 12 tokens above 512, and is capped at 180s.
    """
    if task == "Find":
        return 75
    tokens = int(max_caption_tokens or 768)
    return min(180, 90 + max(0, tokens - 512) // 12)


@spaces.GPU(duration=_estimate_duration)
def analyze(video, task: str, event_query: str, max_caption_tokens: int, temperature: float):
    """Run Marlin-2B Caption or Find on a video and format the results for the UI.

    Returns:
        A 6-tuple in UI output order:
        ``(scene, event_rows, span_text, timeline_image, raw_output, status)``.
        Find mode leaves ``scene`` and ``event_rows`` empty. Caption mode leaves
        ``span_text`` empty.

    Raises:
        gr.Error: if the model is not loaded, no video is provided, or Find mode
            has an empty query.
    """
    if MARLIN is None:
        raise gr.Error(LOAD_STATUS)

    path = _video_path(video)
    task_name = task or "Caption"
    do_sample = float(temperature or 0.0) > 0
    # Greedy runs (temperature 0) still pass 1.0 rather than 0.0; presumably the
    # value is ignored when do_sample is False — confirm against the remote code.
    gen_temperature = float(temperature or 1.0)

    if task_name == "Find":
        query = (event_query or "").strip()
        if not query:
            raise gr.Error("Enter an event to find in the video.")
        video_duration = _duration_seconds(path)
        started = time.perf_counter()
        result = MARLIN.find(
            path,
            event=query,
            max_new_tokens=64,
            do_sample=do_sample,
            temperature=gen_temperature,
        )
        elapsed = time.perf_counter() - started
        span = _coerce_span(result.get("span"))
        span_text = f"{span[0]:.1f}s to {span[1]:.1f}s" if span else "No valid span parsed."
        timeline = _timeline([], span, video_duration)
        status = _status_line(task_name, elapsed, video_duration, f"format_ok={result.get('format_ok')}")
        return "", [], span_text, timeline, result.get("raw", ""), status

    video_duration = _duration_seconds(path)
    started = time.perf_counter()
    result = MARLIN.caption(
        path,
        max_new_tokens=int(max_caption_tokens or 768),
        do_sample=do_sample,
        temperature=gen_temperature,
    )
    elapsed = time.perf_counter() - started
    events = list(result.get("events") or [])
    rows = _event_rows(events)
    timeline = _timeline(events, None, video_duration)
    status = _status_line(task_name, elapsed, video_duration, f"{len(events)} events parsed")
    return result.get("scene", ""), rows, "", timeline, result.get("caption", ""), status


css = """
main, .gradio-container, .contain {
    max-width: 1280px !important;
    margin-inline: auto !important;
}
.output-markdown textarea {
    font-family: ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
}
"""

with gr.Blocks(title="Marlin-2B Video Understanding", fill_width=True) as demo:
    gr.Markdown("# Marlin-2B Video Understanding")
    gr.Markdown(LOAD_STATUS)

    with gr.Row(equal_height=True):
        with gr.Column(scale=5, min_width=320):
            video_in = gr.Video(label="Video", sources=["upload"], height=330)
            with gr.Row():
                task_in = gr.Radio(["Caption", "Find"], value="Caption", label="Task")
                temperature_in = gr.Slider(0.0, 0.8, value=0.0, step=0.1, label="Temperature")
            event_in = gr.Textbox(
                label="Event to find",
                placeholder="the skateboard spins under the rider's feet",
                lines=2,
                visible=False,
            )
            max_tokens_in = gr.Slider(
                128,
                2048,
                value=768,
                step=128,
                label="Caption token cap",
            )
            run_btn = gr.Button("Analyze Video", variant="primary")

        with gr.Column(scale=7, min_width=360):
            timeline_out = gr.Image(type="pil", label="Timeline", height=240)
            span_out = gr.Textbox(label="Find span", interactive=False)
            # `output-markdown` connects these textareas to the matching rule in `css`.
            scene_out = gr.Textbox(
                label="Scene",
                lines=5,
                interactive=False,
                elem_classes=["output-markdown"],
            )
            events_out = gr.Dataframe(
                headers=["Start", "End", "Description"],
                datatype=["str", "str", "str"],
                label="Caption events",
                row_count=(0, "dynamic"),
                wrap=True,
                interactive=False,
            )
            raw_out = gr.Textbox(
                label="Raw model output",
                lines=8,
                interactive=False,
                elem_classes=["output-markdown"],
            )
            status_out = gr.Textbox(label="Status", interactive=False)

    run_btn.click(
        analyze,
        inputs=[video_in, task_in, event_in, max_tokens_in, temperature_in],
        outputs=[scene_out, events_out, span_out, timeline_out, raw_out, status_out],
        api_name="analyze",
    )
    # The event query box is only relevant (and shown) for the Find task.
    task_in.change(
        lambda task: gr.update(visible=task == "Find"),
        inputs=task_in,
        outputs=event_in,
        queue=False,
    )

    gr.Examples(
        examples=EXAMPLES,
        inputs=[video_in, task_in, event_in, max_tokens_in, temperature_in],
        outputs=[scene_out, events_out, span_out, timeline_out, raw_out, status_out],
        fn=analyze,
        label="Examples",
        examples_per_page=3,
        cache_examples=True,
        cache_mode="lazy",
    )

demo.queue(max_size=12, default_concurrency_limit=1).launch(ssr_mode=False, css=css)