"""RF-DETR Realtime Demo — gradio.Server + WebRTC edition.

Per-session GPU pipeline:
  - One @spaces.GPU(duration=60) **generator** holds the GPU window for the
    whole session and runs in ZeroGPU's worker subprocess.
  - The parent feeds (frame, params) tuples to the child via a per-session
    multiprocessing.Queue (proper IPC across the fork boundary).
  - The generator yields annotated frames back, one per IPC roundtrip.
  - A worker thread in the parent iterates `for annotated in gen:` and drops
    each frame into a per-session thread-local queue.Queue.
  - `AnnotatedTrack.recv()` (the outbound WebRTC video track) pulls from that
    thread-local queue.

Multi-user safe: every state object lives inside a per-session dataclass
keyed by session_id (a fresh uuid4 unless the client supplies its own); no
per-session mutable state is shared across session boundaries.
"""

import asyncio
import contextvars
import fractions
import io
import json
import multiprocessing
import os
import queue
import tempfile
import threading
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional

import httpx
import numpy as np
import supervision as sv
import torch
from aiortc import (
    RTCConfiguration,
    RTCDataChannel,
    RTCIceServer,
    RTCPeerConnection,
    RTCSessionDescription,
    VideoStreamTrack,
)
from aiortc.mediastreams import MediaStreamError
from av import VideoFrame
from fastapi import File, Form, UploadFile
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from gradio import Server
from PIL import Image

# `spaces` only exists on Hugging Face ZeroGPU hardware. Everywhere else we
# install a stand-in whose `GPU` decorator hands the function back untouched,
# so the rest of the module can use `@spaces.GPU(...)` unconditionally.
try:
    import spaces

    IS_ZERO_GPU = True
    print("ZeroGPU environment detected")
except ImportError:
    IS_ZERO_GPU = False
    print("Running without ZeroGPU")

    class _NoSpaces:
        """Minimal replacement for the `spaces` module: `GPU` is a no-op."""

        @staticmethod
        def GPU(*args, **kwargs):
            # Bare usage (`@spaces.GPU`): the decorated function arrives as
            # the first positional argument, so return it unchanged.
            bare_target = args[0] if args else None
            if callable(bare_target):
                return bare_target

            # Parametrised usage (`@spaces.GPU(duration=...)`): return an
            # identity decorator.
            return lambda func: func

    spaces = _NoSpaces()  # type: ignore

from transformers import (
    AutoImageProcessor,
    AutoModelForInstanceSegmentation,
    AutoModelForObjectDetection,
)

# --- Configuration ---
# Webcam frames are shrunk so their short edge is at most this many pixels.
WEBCAM_MAX_SHORT_EDGE = 384
# Duration (seconds) requested from @spaces.GPU for one webcam session.
SESSION_GPU_BUDGET_SECONDS = 60

# Integer task codes carried in webcam session params.
TASK_DETECTION = 0
TASK_SEGMENTATION = 1

VIDEO_CLOCK_RATE = 90000
VIDEO_TIME_BASE = fractions.Fraction(1, VIDEO_CLOCK_RATE)


def _get_device() -> str:
    """Return the torch device string to use: "cuda", else "mps", else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


DEVICE = _get_device()
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
if DEVICE == "cuda":
    torch.backends.cudnn.benchmark = True

BOX_ANNOTATOR = sv.BoxAnnotator(thickness=2)
LABEL_ANNOTATOR = sv.LabelAnnotator(text_scale=0.6, text_thickness=1, text_padding=4)
MASK_ANNOTATOR = sv.MaskAnnotator(opacity=0.35)

print("Loading RF-DETR models...")

DET_PROCESSOR = AutoImageProcessor.from_pretrained(
    "Roboflow/rf-detr-medium",
    revision="refs/pr/1",
    device=DEVICE,
)
DET_MODEL = AutoModelForObjectDetection.from_pretrained(
    "Roboflow/rf-detr-medium",
    torch_dtype=DTYPE,
    attn_implementation="sdpa",
).to(DEVICE).eval()

WEBCAM_SEG_PROCESSOR = AutoImageProcessor.from_pretrained(
    "Roboflow/rf-detr-seg-medium",
    revision="refs/pr/1",
    device=DEVICE,
)
WEBCAM_SEG_MODEL = AutoModelForInstanceSegmentation.from_pretrained(
    "Roboflow/rf-detr-seg-medium",
    torch_dtype=DTYPE,
    attn_implementation="sdpa",
).to(DEVICE).eval()

UPLOAD_SEG_PROCESSOR = AutoImageProcessor.from_pretrained(
    "Roboflow/rf-detr-seg-medium",
    device=DEVICE,
)
UPLOAD_SEG_MODEL = AutoModelForInstanceSegmentation.from_pretrained(
    "Roboflow/rf-detr-seg-medium",
    torch_dtype=DTYPE,
    attn_implementation="sdpa",
).to(DEVICE).eval()

DET_ID2LABEL = DET_MODEL.config.id2label
WEBCAM_SEG_ID2LABEL = WEBCAM_SEG_MODEL.config.id2label
UPLOAD_SEG_ID2LABEL = UPLOAD_SEG_MODEL.config.id2label

print("Models loaded.")


def _downscale(image_array: np.ndarray, max_short_edge: int) -> np.ndarray:
    """Shrink `image_array` so its short edge is at most `max_short_edge`.

    The aspect ratio is kept (dimensions are truncated to ints). Images that
    are already small enough are returned as-is, without copying.
    """
    h, w = image_array.shape[:2]
    short_edge = min(h, w)
    if short_edge <= max_short_edge:
        return image_array
    scale = max_short_edge / short_edge
    new_w, new_h = int(w * scale), int(h * scale)
    return np.array(Image.fromarray(image_array).resize((new_w, new_h), Image.BILINEAR))


def _run_detection(image_array: np.ndarray, score_threshold: float) -> np.ndarray:
    """Run the detection model on one image and draw boxes + labels.

    Args:
        image_array: Image array handed to the processor and annotators
            (callers in this file pass RGB frames).
        score_threshold: Threshold forwarded to the processor's
            post-processing step.

    Returns:
        An annotated copy of the image, or the input array itself (not a
        copy) when nothing was detected.
    """
    inputs = DET_PROCESSOR(images=image_array, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        outputs = DET_MODEL(**inputs)

    h, w = image_array.shape[:2]
    results = DET_PROCESSOR.post_process_object_detection(
        outputs,
        threshold=score_threshold,
        target_sizes=torch.tensor([[h, w]], device=DEVICE),
    )[0]

    detections = sv.Detections.from_transformers(
        transformers_results=results,
        id2label=DET_ID2LABEL,
    )
    if len(detections) == 0:
        return image_array

    labels = [
        f"{DET_ID2LABEL[cid]} {c:.2f}"
        for cid, c in zip(detections.class_id, detections.confidence)
    ]
    scene = BOX_ANNOTATOR.annotate(scene=image_array.copy(), detections=detections)
    return LABEL_ANNOTATOR.annotate(scene=scene, detections=detections, labels=labels)


def _run_segmentation(
    image_array: np.ndarray,
    processor,
    model,
    id2label: dict,
    score_threshold: float,
    mask_threshold: float,
) -> np.ndarray:
    """Run an instance-segmentation model on one image and draw the result.

    Args:
        image_array: Image array handed to the processor and annotators.
        processor: Image processor matching `model`; used for both
            pre-processing and `post_process_instance_segmentation`.
        model: Segmentation model to run.
        id2label: Mapping from class id to display label.
        score_threshold: Score threshold forwarded to post-processing.
        mask_threshold: Mask threshold forwarded to post-processing.

    Returns:
        An annotated copy (masks, boxes, labels), or the input array itself
        when no segment with at least one pixel was produced.
    """
    inputs = processor(images=image_array, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        outputs = model(**inputs)

    h, w = image_array.shape[:2]
    result = processor.post_process_instance_segmentation(
        outputs,
        threshold=score_threshold,
        mask_threshold=mask_threshold,
        target_sizes=[(h, w)],
    )[0]

    segments = result["segments_info"]
    if not segments:
        return image_array

    segmentation_map = result["segmentation"].cpu().numpy()

    # Pre-allocate for every reported segment; `keep` counts the ones that
    # actually own pixels in the segmentation map.
    n = len(segments)
    masks = np.empty((n, h, w), dtype=bool)
    confidences = np.empty(n, dtype=np.float32)
    class_ids = np.empty(n, dtype=int)
    boxes = np.empty((n, 4), dtype=np.float32)

    keep = 0
    for seg in segments:
        binary_mask = segmentation_map == seg["id"]
        ys, xs = np.where(binary_mask)
        if ys.size == 0:
            continue
        masks[keep] = binary_mask
        confidences[keep] = seg["score"]
        class_ids[keep] = seg["label_id"]
        # Bounding box derived from the extent of the mask pixels.
        boxes[keep] = [xs.min(), ys.min(), xs.max(), ys.max()]
        keep += 1

    if keep == 0:
        return image_array

    detections = sv.Detections(
        xyxy=boxes[:keep],
        mask=masks[:keep],
        confidence=confidences[:keep],
        class_id=class_ids[:keep],
    )
    labels = [
        f"{id2label[cid]} {c:.2f}"
        for cid, c in zip(detections.class_id, detections.confidence)
    ]
    scene = MASK_ANNOTATOR.annotate(scene=image_array.copy(), detections=detections)
    scene = BOX_ANNOTATOR.annotate(scene=scene, detections=detections)
    return LABEL_ANNOTATOR.annotate(scene=scene, detections=detections, labels=labels)


# --- Per-session GPU generator ---
# multiprocessing.Queue / Event cannot be pickled across spaces' IPC; they
# must be inherited via fork. So we build a fresh @spaces.GPU generator per
# session, with the queue + event captured by closure. spaces forks the
# worker subprocess on first call, and the closure is inherited.
def make_session_gpu_generator(input_mp_queue, stop_mp_event):
    """Build the @spaces.GPU generator function for one webcam session.

    Args:
        input_mp_queue: Queue of `(array, task, score_th, mask_th)` jobs;
            a `None` item is the shutdown sentinel.
        stop_mp_event: Event that ends the generator once set.

    Returns:
        A zero-argument generator function yielding one annotated frame per
        job taken from the queue.
    """

    @spaces.GPU(duration=SESSION_GPU_BUDGET_SECONDS)
    def gpu_streaming_generator():
        while not stop_mp_event.is_set():
            try:
                # Short timeout so the stop event is re-checked regularly.
                job = input_mp_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            except Exception:
                # Any other queue failure ends the session.
                return
            if job is None:
                return
            arr, task, score_th, mask_th = job
            if task == TASK_DETECTION:
                annotated = _run_detection(arr, score_th)
            else:
                annotated = _run_segmentation(
                    arr,
                    WEBCAM_SEG_PROCESSOR,
                    WEBCAM_SEG_MODEL,
                    WEBCAM_SEG_ID2LABEL,
                    score_th,
                    mask_th,
                )
            yield annotated

    return gpu_streaming_generator


def webcam_worker(session: "Session"):
    """Consumes the GPU generator and pushes annotated frames into the
    per-session output queue so the outbound WebRTC track can pick them up.

    Runs in a plain thread. Whatever ends the loop (stop request, generator
    exhaustion, or an error), both stop events are set and the client is
    notified on a best-effort basis.
    """
    try:
        gen_fn = make_session_gpu_generator(session.input_mp_queue, session.stop_mp_event)
        for annotated in gen_fn():
            if session.stop_event.is_set():
                break
            # Drop stale outputs so the encoder always sees the freshest frame.
            while not session.frame_out.empty():
                try:
                    session.frame_out.get_nowait()
                except queue.Empty:
                    break
            try:
                session.frame_out.put_nowait(annotated)
            except queue.Full:
                pass
    except Exception as e:
        print(f"[{session.sid[:8]}] worker error: {e!r}")
    finally:
        session.stop_event.set()
        session.stop_mp_event.set()
        # Best-effort notify the client that the session ended.
        _notify_session_ended(session)


def _notify_session_ended(session: "Session"):
    """Send a `session_ended` message over the session's data channel.

    Does nothing when no data channel is attached. The send is scheduled on
    the session's event loop via `call_soon_threadsafe`, and every failure is
    swallowed because this is only a courtesy notification.
    """
    ch = session.data_channel
    if ch is None:
        return
    msg = json.dumps({"type": "session_ended", "reason": "stopped"})

    def _send():
        try:
            if ch.readyState == "open":
                ch.send(msg)
        except Exception:
            pass

    try:
        session.loop.call_soon_threadsafe(_send)
    except Exception:
        pass


# --- Cloudflare TURN credentials via HF's hosted proxy ---
CLOUDFLARE_FASTRTC_TURN_URL = "https://turn.fastrtc.org/credentials"
_FALLBACK_STUN_URL = "stun:stun.cloudflare.com:3478"
_FALLBACK_STUN = [RTCIceServer(urls=[_FALLBACK_STUN_URL])]


async def _fetch_turn_payload(log_suffix: str = ""):
    """Fetch the TURN credential JSON from the credentials endpoint.

    Shared by the server-side peer connection setup and the browser-facing
    ICE endpoint so both use the same request.

    Args:
        log_suffix: Text appended to the failure log line.

    Returns:
        The decoded JSON payload, or `None` when `HF_TOKEN` is unset/blank or
        the request fails for any reason.
    """
    hf_token = os.getenv("HF_TOKEN", "").strip()
    if not hf_token:
        return None
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(
                CLOUDFLARE_FASTRTC_TURN_URL,
                headers={"Authorization": f"Bearer {hf_token}"},
                params={"ttl": 600},
            )
            resp.raise_for_status()
            return resp.json()
    except Exception as e:
        print(f"[turn] credential fetch failed: {e!r}{log_suffix}")
        return None


async def _fetch_ice_servers() -> List[RTCIceServer]:
    """Return ICE servers for the server-side RTCPeerConnection.

    Falls back to the public STUN server when credentials are unavailable or
    the payload contains no usable entry.
    """
    payload = await _fetch_turn_payload(log_suffix="; falling back to STUN")
    # Guard the payload shape: a non-dict response must not crash the offer.
    if not isinstance(payload, dict):
        return list(_FALLBACK_STUN)

    out: List[RTCIceServer] = []
    for s in payload.get("iceServers") or []:
        if not isinstance(s, dict):
            continue
        urls = s.get("urls")
        if not urls:
            continue
        out.append(RTCIceServer(
            urls=urls,
            username=s.get("username"),
            credential=s.get("credential"),
        ))
    return out or list(_FALLBACK_STUN)


# --- Per-session state ---
@dataclass
class Session:
    """All state belonging to one webcam WebRTC session."""

    sid: str
    pc: RTCPeerConnection
    # Event loop the session was created on; used for thread-safe callbacks.
    loop: asyncio.AbstractEventLoop
    # Jobs for the GPU generator (frames + params); None is the stop sentinel.
    input_mp_queue: "multiprocessing.Queue"
    # Annotated frames waiting to be sent by the outbound track.
    frame_out: "queue.Queue"
    # Stop flag checked by the threads / coroutines in this process.
    stop_event: threading.Event
    # Stop flag checked inside the GPU generator.
    stop_mp_event: "multiprocessing.synchronize.Event"
    # Current task / thresholds; guarded by params_lock.
    params: dict
    params_lock: threading.Lock
    worker_thread: Optional[threading.Thread] = None
    input_pump_task: Optional[asyncio.Task] = None
    data_channel: Optional[RTCDataChannel] = None


_sessions: Dict[str, Session] = {}
_sessions_lock = threading.Lock()


# --- Pump inbound webcam track into the IPC queue ---
async def _pump_input(track: VideoStreamTrack, session: Session):
    """Forward frames from the inbound track to the session's job queue.

    Each frame is converted to an RGB array, downscaled, and bundled with a
    snapshot of the current params. Runs until the session is stopped or the
    track stops delivering frames.
    """
    while not session.stop_event.is_set():
        try:
            frame = await track.recv()
        except MediaStreamError:
            return
        except Exception as e:
            print(f"[{session.sid[:8]}] input pump error: {e!r}")
            return

        try:
            arr = frame.to_ndarray(format="rgb24")
        except Exception:
            # Skip frames that cannot be converted.
            continue
        arr = _downscale(arr, WEBCAM_MAX_SHORT_EDGE)

        with session.params_lock:
            job = (
                arr,
                int(session.params["task"]),
                float(session.params["score_threshold"]),
                float(session.params["mask_threshold"]),
            )

        # Drop any stale item so the child always processes the freshest frame.
        try:
            while not session.input_mp_queue.empty():
                session.input_mp_queue.get_nowait()
        except Exception:
            pass
        try:
            session.input_mp_queue.put_nowait(job)
        except Exception:
            pass


# --- Outbound annotated track ---
class AnnotatedTrack(VideoStreamTrack):
    """Outbound video track that serves the session's annotated frames."""

    kind = "video"

    def __init__(self, session: Session):
        super().__init__()
        self.session = session
        # time.monotonic() at the first delivered frame; origin for pts.
        self._start: Optional[float] = None
        # Most recent frame, re-sent while no new one is available.
        self._last: Optional[np.ndarray] = None

    async def recv(self):
        """Return the next frame to send.

        Waits for a frame from `session.frame_out`; once at least one frame
        has been seen, the last one is repeated whenever the queue stays
        empty for 0.1 s.

        Raises:
            MediaStreamError: If the session's stop event is set.
        """
        loop = asyncio.get_running_loop()
        while True:
            if self.session.stop_event.is_set():
                raise MediaStreamError
            try:
                # The blocking queue read runs in the default executor.
                arr = await loop.run_in_executor(
                    None,
                    lambda: self.session.frame_out.get(timeout=0.1),
                )
                self._last = arr
                break
            except queue.Empty:
                if self._last is not None:
                    arr = self._last
                    break
                continue

        if self._start is None:
            self._start = time.monotonic()
        # Presentation timestamp: elapsed time expressed in clock ticks.
        pts = int((time.monotonic() - self._start) * VIDEO_CLOCK_RATE)

        new_frame = VideoFrame.from_ndarray(arr, format="rgb24")
        new_frame.pts = pts
        new_frame.time_base = VIDEO_TIME_BASE
        return new_frame


# --- App / routes ---
app = Server()


@app.get("/", response_class=HTMLResponse)
async def homepage():
    """Serve `index.html` located next to this file."""
    p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    with open(p, "r", encoding="utf-8") as f:
        return f.read()


def _parse_session_params(params_in) -> dict:
    """Build the initial session params dict from the client's `params`.

    Missing keys get defaults (detection, 0.4, 0.5).

    Raises:
        TypeError: If `params_in` is truthy but not a dict, or a value has
            an unconvertible type.
        ValueError: If a value cannot be parsed as a number.
    """
    params_in = params_in or {}
    if not isinstance(params_in, dict):
        raise TypeError("params must be an object")
    return {
        "task": int(params_in.get("task", TASK_DETECTION)),
        "score_threshold": float(params_in.get("score_threshold", 0.4)),
        "mask_threshold": float(params_in.get("mask_threshold", 0.5)),
    }


@app.post("/api/webrtc/offer")
async def webrtc_offer(body: dict):
    """Handle a WebRTC offer: create the session and return the SDP answer.

    Reads `session_id`, `sdp`, `type` and `params` from the body.

    Returns:
        JSON with `session_id`, the answer `sdp`/`type` and `budget_seconds`.
        Status 400 with an `error` field when the SDP is missing, the params
        are malformed, or negotiation fails.
    """
    body = body or {}
    # Coerce to str: the id is sliced for log prefixes and used as a dict key.
    sid = str(body.get("session_id") or uuid.uuid4())
    sdp = body.get("sdp")
    sdp_type = body.get("type") or "offer"

    if not sdp:
        return JSONResponse({"error": "missing sdp"}, status_code=400)

    # Validate before allocating anything, so bad input is a 400 not a 500.
    try:
        params = _parse_session_params(body.get("params"))
    except (TypeError, ValueError):
        return JSONResponse({"error": "invalid params"}, status_code=400)

    # Tear down any prior session with the same id (e.g. page reload).
    with _sessions_lock:
        prior = _sessions.pop(sid, None)
    if prior is not None:
        await _teardown(prior)

    ice_servers = await _fetch_ice_servers()
    pc = RTCPeerConnection(configuration=RTCConfiguration(iceServers=ice_servers))
    loop = asyncio.get_running_loop()

    session = Session(
        sid=sid,
        pc=pc,
        loop=loop,
        input_mp_queue=multiprocessing.Queue(maxsize=2),
        frame_out=queue.Queue(maxsize=2),
        stop_event=threading.Event(),
        stop_mp_event=multiprocessing.Event(),
        params=params,
        params_lock=threading.Lock(),
    )
    with _sessions_lock:
        _sessions[sid] = session

    # Outbound annotated track; added before setRemoteDescription so it's
    # negotiated in the answer SDP.
    pc.addTrack(AnnotatedTrack(session))

    @pc.on("track")
    def on_track(track):
        if track.kind != "video":
            return
        session.input_pump_task = asyncio.create_task(_pump_input(track, session))

        @track.on("ended")
        async def _ended():
            session.stop_event.set()
            session.stop_mp_event.set()

    @pc.on("datachannel")
    def on_dc(channel: RTCDataChannel):
        if channel.label != "control":
            return
        session.data_channel = channel

        @channel.on("message")
        def on_msg(raw):
            # Only JSON text messages of type "params" are handled.
            if not isinstance(raw, str):
                return
            try:
                msg = json.loads(raw)
            except Exception:
                return
            if not isinstance(msg, dict) or msg.get("type") != "params":
                return
            p = msg.get("params") or {}
            with session.params_lock:
                for key, conv in (
                    ("task", int),
                    ("score_threshold", float),
                    ("mask_threshold", float),
                ):
                    if key in p:
                        try:
                            session.params[key] = conv(p[key])
                        except (TypeError, ValueError):
                            # Keep the previous value on a bad update.
                            pass

    @pc.on("connectionstatechange")
    async def on_state():
        print(f"[{sid[:8]}] pc state: {pc.connectionState}")
        if pc.connectionState in ("failed", "closed", "disconnected"):
            await _teardown(session)

    try:
        await pc.setRemoteDescription(RTCSessionDescription(sdp=sdp, type=sdp_type))
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)
    except Exception as e:
        # Don't leave a half-built session registered with an open peer
        # connection when the offer cannot be negotiated.
        print(f"[{sid[:8]}] negotiation failed: {e!r}")
        await _teardown(session)
        return JSONResponse({"error": "invalid offer"}, status_code=400)

    # Start the worker thread (drives the @spaces.GPU generator).
    # contextvars.copy_context() carries the request's ZeroGPU LocalContext
    # into the worker thread so @spaces.GPU can allocate against the right
    # caller.
    ctx = contextvars.copy_context()
    worker = threading.Thread(
        target=ctx.run,
        args=(webcam_worker, session),
        daemon=True,
    )
    session.worker_thread = worker
    worker.start()

    return JSONResponse({
        "session_id": sid,
        "sdp": pc.localDescription.sdp,
        "type": pc.localDescription.type,
        "budget_seconds": SESSION_GPU_BUDGET_SECONDS,
    })


async def _teardown(session: Session):
    """Stop a session: set its stop flags, unregister it, close its peer.

    Safe to call more than once for the same session; queue and
    peer-connection errors are ignored.
    """
    session.stop_event.set()
    session.stop_mp_event.set()
    with _sessions_lock:
        # Remove the entry only if it is still *this* session. A late
        # connection-state callback from a replaced session must not
        # unregister the newer session that reuses the same id.
        if _sessions.get(session.sid) is session:
            del _sessions[session.sid]
    try:
        session.input_mp_queue.put_nowait(None)  # sentinel
    except Exception:
        pass
    try:
        await session.pc.close()
    except Exception:
        pass


@app.get("/api/webrtc/ice-servers")
async def webrtc_ice_servers():
    """Return ICE servers (incl. Cloudflare TURN) in browser RTCConfiguration
    shape. The browser needs TURN to relay through when neither side has a
    public IP — without this it'll only have STUN and ICE will fail on most
    real-world networks.

    Falls back to the public STUN server when credentials are unavailable.
    """
    payload = await _fetch_turn_payload()
    if payload is None:
        return JSONResponse({"iceServers": [
            {"urls": [_FALLBACK_STUN_URL]},
        ]})
    return JSONResponse(payload)


@app.post("/api/webrtc/stop")
async def webrtc_stop(body: dict):
    """Stop the session named by `session_id`; unknown ids are ignored."""
    sid = (body or {}).get("session_id", "")
    with _sessions_lock:
        s = _sessions.pop(sid, None)
    if s is not None:
        await _teardown(s)
    return JSONResponse({"status": "stopped"})


# --- Upload endpoints (image + video) ---
@spaces.GPU(duration=30)
def _gpu_infer_image(
    arr: np.ndarray,
    task: str,
    score_threshold: float,
    mask_threshold: float,
) -> np.ndarray:
    """Annotate one uploaded image.

    `task == "detection"` runs detection; any other value runs segmentation
    with the upload segmentation model.
    """
    if task == "detection":
        return _run_detection(arr, score_threshold)
    return _run_segmentation(
        arr,
        UPLOAD_SEG_PROCESSOR,
        UPLOAD_SEG_MODEL,
        UPLOAD_SEG_ID2LABEL,
        score_threshold,
        mask_threshold,
    )


@app.post("/api/infer_image")
async def infer_image(
    image: UploadFile = File(...),
    task: str = Form("detection"),
    score_threshold: float = Form(0.4),
    mask_threshold: float = Form(0.5),
):
    """Annotate an uploaded image and return it as PNG.

    Returns:
        An `image/png` response, or status 400 with an `error` field when
        the upload cannot be decoded as an image.
    """
    from fastapi.responses import Response

    data = await image.read()
    try:
        img = Image.open(io.BytesIO(data)).convert("RGB")
    except (OSError, ValueError):
        # PIL's "cannot identify image file" error is an OSError subclass.
        return JSONResponse({"error": "invalid image"}, status_code=400)
    arr = np.array(img)

    # Inference blocks; run it in a thread so the event loop (which also
    # serves every live WebRTC session) keeps running. asyncio.to_thread
    # propagates the current contextvars context to the thread.
    annotated = await asyncio.to_thread(
        _gpu_infer_image, arr, task, float(score_threshold), float(mask_threshold),
    )

    # Encode in memory instead of leaving a temp directory behind per request.
    buf = io.BytesIO()
    Image.fromarray(annotated).save(buf, format="PNG")
    return Response(content=buf.getvalue(), media_type="image/png")


@spaces.GPU(duration=60)
def _gpu_process_video(
    in_path: str,
    out_path: str,
    task: str,
    score_threshold: float,
    mask_threshold: float,
):
    """Annotate every frame of `in_path` and write the result to `out_path`.

    `task == "detection"` runs detection; any other value runs segmentation.
    The output is written with the "mp4v" fourcc at the source frame size.

    Raises:
        RuntimeError: If the input video or the output writer cannot be
            opened.
    """
    import cv2

    cap = cv2.VideoCapture(in_path)
    writer = None
    try:
        if not cap.isOpened():
            raise RuntimeError("Could not open video")
        # Keep fractional rates (e.g. 29.97) instead of truncating to int,
        # which would change the playback duration. Missing, zero, NaN or
        # infinite values fall back to 30.
        fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
        if not 0.0 < fps < float("inf"):
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        if not writer.isOpened():
            raise RuntimeError("Could not open video writer")

        while True:
            ok, frame = cap.read()
            if not ok:
                break
            # OpenCV delivers BGR; the models/annotators are fed RGB.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if task == "detection":
                ann = _run_detection(frame_rgb, score_threshold)
            else:
                ann = _run_segmentation(
                    frame_rgb,
                    UPLOAD_SEG_PROCESSOR,
                    UPLOAD_SEG_MODEL,
                    UPLOAD_SEG_ID2LABEL,
                    score_threshold,
                    mask_threshold,
                )
            writer.write(cv2.cvtColor(ann, cv2.COLOR_RGB2BGR))
    finally:
        cap.release()
        if writer is not None:
            writer.release()


@app.post("/api/infer_video")
async def infer_video(
    video: UploadFile = File(...),
    task: str = Form("detection"),
    score_threshold: float = Form(0.4),
    mask_threshold: float = Form(0.5),
):
    """Annotate an uploaded video and return it as `video/mp4`.

    The upload and the result live in per-request temp directories that are
    removed once they are no longer needed.
    """
    import shutil

    from fastapi import BackgroundTasks

    # Never join the client-supplied filename into a path (it may contain
    # "../" or an absolute path). Only a plain alphanumeric extension is
    # kept, as a container hint for the decoder.
    suffix = Path(video.filename or "").suffix
    if not suffix[1:].isalnum():
        suffix = ".mp4"

    in_dir = tempfile.mkdtemp(prefix="rf-detr-vid-in-")
    out_dir = tempfile.mkdtemp(prefix="rf-detr-vid-")
    in_path = str(Path(in_dir) / f"input{suffix}")
    out_path = str(Path(out_dir) / "annotated.mp4")

    succeeded = False
    try:
        with open(in_path, "wb") as f:
            f.write(await video.read())
        # Processing blocks for the whole video; keep it off the event loop.
        await asyncio.to_thread(
            _gpu_process_video,
            in_path, out_path, task, float(score_threshold), float(mask_threshold),
        )
        succeeded = True
    finally:
        # The input is not needed after processing, whatever the outcome.
        shutil.rmtree(in_dir, ignore_errors=True)
        if not succeeded:
            shutil.rmtree(out_dir, ignore_errors=True)

    # The output must outlive this function; remove it after the response
    # has been sent.
    cleanup = BackgroundTasks()
    cleanup.add_task(shutil.rmtree, out_dir, ignore_errors=True)
    return FileResponse(out_path, media_type="video/mp4", background=cleanup)


if IS_ZERO_GPU:
    spaces.GPU(lambda: None)


if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860, show_error=True)