"""app.py — HOI-DETR ZeroGPU demo.
On every cold start this clones (or pulls) the HOI-DETR GitHub repo so any
change pushed to GitHub is reflected automatically without touching this Space.
"""
import os
import sys
import subprocess
import tempfile
import traceback
# ── HfFolder shim (must precede `import spaces` and `import gradio`) ──────────
# gradio 4.44's oauth.py AND the `spaces` package both do
# `from huggingface_hub import HfFolder`, but the container ships
# huggingface_hub 1.x which removed HfFolder. Inject a minimal stand-in. This
# only imports huggingface_hub (no CUDA), so it is safe to run before `spaces`.
import huggingface_hub as _hfh

if not hasattr(_hfh, "HfFolder"):

    class _HfFolderShim:
        """Minimal stand-in exposing the ``HfFolder`` names that importers expect.

        Only token lookup does anything; saving and deleting are no-ops.
        """

        # No token file is tracked by this shim.
        path_token = None

        @staticmethod
        def get_token():
            """Return HF_TOKEN if set and non-empty, else HUGGING_FACE_HUB_TOKEN."""
            primary = os.environ.get("HF_TOKEN")
            if primary:
                return primary
            return os.environ.get("HUGGING_FACE_HUB_TOKEN")

        @classmethod
        def save_token(cls, token):
            """Accept *token* and discard it; nothing is persisted."""
            return None

        @classmethod
        def delete_token(cls):
            """Do nothing; there is no stored token to remove."""
            return None

    setattr(_hfh, "HfFolder", _HfFolderShim)

# Keep the module namespace clean; the patched attribute lives on the package.
del _hfh
# ZeroGPU: `spaces` MUST be imported before torch / any CUDA-related package,
# otherwise it raises "CUDA has been initialized before importing the spaces
# package". The mmcv bootstrap below imports torch, so import spaces first.
import spaces
# ── clone / update HOI-DETR from GitHub ──────────────────────────────────────
REPO = "/home/user/HOI-DETR"
REPO_URL = "https://github.com/AhmadDarKhalil/HOI-DETR.git"

if not os.path.isdir(REPO):
    # No checkout yet: shallow-clone it. A failed clone is fatal (check=True).
    subprocess.run(["git", "clone", "--depth=1", REPO_URL, REPO], check=True)
else:
    # Checkout already present: fast-forward it. A failed pull is tolerated
    # (check=False), so the existing checkout is used as-is.
    subprocess.run(["git", "-C", REPO, "pull", "--ff-only"], check=False)

# Both entries are inserted at index 0, so the demo directory ends up ahead of
# the repo root on sys.path.
sys.path.insert(0, REPO)
sys.path.insert(0, os.path.join(REPO, "demo"))  # for `import configs` inside helpers.py
# ── ensure mmcv-full (with CUDA ops) is importable ───────────────────────────
# HF's build phase can't compile mmcv (isolated pip, no Space Variables), so we
# handle it at runtime. Compiling from source takes ~9 min, so we cache the
# built wheel on the Hub keyed by torch/cuda/python and reuse it on later cold
# starts (download ~1 min). MMCV_CACHE_REPO must allow writes via HF_TOKEN.
MMCV_GIT = "git+https://github.com/open-mmlab/mmcv.git@v1.7.2"
MMCV_CACHE_REPO = os.environ.get("MMCV_CACHE_REPO", "ahmaddarkhalil/hoi-detr")


def _ensure_mmcv():
    """Make mmcv with compiled ops importable, installing it at runtime if needed.

    Strategy, in order:
      1. Return immediately if ``mmcv`` and ``mmcv.ops.RoIAlign`` already import.
      2. Download a prebuilt wheel from ``MMCV_CACHE_REPO`` on the Hub (path
         keyed by torch / CUDA / Python version) and install it.
      3. Build a wheel from ``MMCV_GIT``, install it, then upload it to the
         cache (best-effort; only attempted when a Hub token is set).

    Raises:
        subprocess.CalledProcessError: if the source build or its install fails.
        RuntimeError: if the build produces no wheel, or mmcv still cannot be
            imported after the built wheel has been installed.
    """
    import glob
    import importlib
    import shutil

    def run(cmd, env=None, check=True):
        """Echo *cmd* to the log, then run it."""
        print("[bootstrap] $", " ".join(cmd), flush=True)
        return subprocess.run(cmd, env=env, check=check)

    def have_mmcv():
        """Return True when mmcv and its compiled ops import cleanly."""
        # pip may have added or replaced files since the last attempt; drop the
        # import finders' cached directory listings so they are seen.
        importlib.invalidate_caches()
        try:
            import mmcv  # noqa: F811
            from mmcv.ops import RoIAlign  # noqa: F401 — proves CUDA ops present
        except Exception:
            # Forget any partially imported mmcv so a later check re-imports
            # from disk instead of reusing the stale module objects.
            stale = [m for m in sys.modules
                     if m == "mmcv" or m.startswith("mmcv.")]
            for mod_name in stale:
                sys.modules.pop(mod_name, None)
            return False
        print(f"[bootstrap] mmcv {mmcv.__version__} (with ops) ready",
              flush=True)
        return True

    if have_mmcv():
        return

    import torch
    tver = torch.__version__
    cuver = (torch.version.cuda or "none")
    pytag = f"cp{sys.version_info.major}{sys.version_info.minor}"
    # The wheel FILENAME must stay PEP 427-valid for pip to install it; encode
    # the torch/cuda/python cache key in the Hub PATH instead.
    wheel_basename = f"mmcv_full-1.7.2-{pytag}-{pytag}-linux_x86_64.whl"
    cache_key = f"torch{tver.replace('+', '_')}-cu{cuver.replace('.', '')}-{pytag}"
    cache_path = f"mmcv_wheels/{cache_key}/{wheel_basename}"
    # Accept either token variable, matching the HfFolder shim's lookup.
    token = (os.environ.get("HF_TOKEN")
             or os.environ.get("HUGGING_FACE_HUB_TOKEN"))
    print(f"[bootstrap] torch={tver} cuda={cuver}; cache={cache_path}", flush=True)

    # 1) Try a cached prebuilt wheel from the Hub. Any failure here (missing
    #    file, network, pip error) just falls through to the source build.
    try:
        from huggingface_hub import hf_hub_download
        whl = hf_hub_download(repo_id=MMCV_CACHE_REPO, filename=cache_path,
                              token=token)
        run([sys.executable, "-m", "pip", "install", whl])
        if have_mmcv():
            print("[bootstrap] installed cached mmcv wheel", flush=True)
            return
        print("[bootstrap] cached wheel installed but mmcv ops still fail to "
              "import; building", flush=True)
    except Exception as e:
        print(f"[bootstrap] no usable cached wheel ({e!r}); building", flush=True)

    # 2) Build from source.
    print(f"[bootstrap] system nvcc: {shutil.which('nvcc')}", flush=True)
    run(["bash", "-lc", "gcc --version | head -1 || true"], check=False)
    print(f"[bootstrap] CUDA_HOME={os.environ.get('CUDA_HOME')}", flush=True)
    env = dict(os.environ)
    env["MMCV_WITH_OPS"] = "1"
    env["FORCE_CUDA"] = "1"
    env.setdefault("TORCH_CUDA_ARCH_LIST", "12.0+PTX")  # RTX Pro 6000 = sm_120
    env.setdefault("MAX_JOBS", "4")
    outdir = "/tmp/mmcv_wheel"
    run([sys.executable, "-m", "pip", "wheel", "--no-build-isolation",
         "--no-deps", "-w", outdir, MMCV_GIT], env=env)
    wheels = sorted(glob.glob(os.path.join(outdir, "mmcv_full-*.whl")))
    if not wheels:
        # Fail with a clear message instead of an IndexError on wheels[0].
        raise RuntimeError(
            f"pip wheel finished but no mmcv_full wheel was found in {outdir}")
    built = wheels[0]
    run([sys.executable, "-m", "pip", "install", built])  # also pulls addict
    if not have_mmcv():
        raise RuntimeError("mmcv built but import still fails")
    print("[bootstrap] built mmcv from source", flush=True)

    # 3) Cache the wheel for future cold starts (best-effort).
    try:
        if token:
            from huggingface_hub import upload_file
            upload_file(path_or_fileobj=built, path_in_repo=cache_path,
                        repo_id=MMCV_CACHE_REPO, token=token,
                        commit_message="cache mmcv-full wheel")
            print(f"[bootstrap] cached wheel -> {MMCV_CACHE_REPO}/{cache_path}",
                  flush=True)
    except Exception as e:
        print(f"[bootstrap] wheel cache upload skipped ({e!r})", flush=True)
# Run the bootstrap now: the `import mmcv` below needs the package installed.
_ensure_mmcv()
# ── mmdet version-gate patch ──────────────────────────────────────────────────
# mmdet 2.25.3 asserts mmcv < 1.6.0; we run 1.7.2 (PyTorch 2.x support).
# Report "1.5.0" only while mmdet is imported for the first time, then put the
# real version string back. The statement order below is what makes this work.
import mmcv as _mmcv_mod
_real_mmcv_ver = _mmcv_mod.__version__
_mmcv_mod.__version__ = "1.5.0"
import mmdet # noqa: F401 — version gate reads "1.5.0" here, passes
_mmcv_mod.__version__ = _real_mmcv_ver
del _mmcv_mod, _real_mmcv_ver
# ─────────────────────────────────────────────────────────────────────────────
import math
import cv2
import mmcv
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download
# ── defensive gradio api_info guards (no-ops on modern gradio) ───────────────
# Insurance against gradio_client schema-walk crashes on odd JSON-schema nodes
# (seen on old gradio: "argument of type 'bool' is not iterable",
# "unhashable type: 'dict'"). Each guard is wrapped so it is a safe no-op if the
# internals differ in the installed gradio version.
try:
    import gradio_client.utils as _gcu

    if hasattr(_gcu, "get_type"):
        _orig_get_type = _gcu.get_type

        def _safe_get_type(schema):
            """Delegate dict schemas to the original; report anything else as "Any"."""
            return _orig_get_type(schema) if isinstance(schema, dict) else "Any"

        _gcu.get_type = _safe_get_type

    if hasattr(_gcu, "_json_schema_to_python_type"):
        _orig_json_to_py = _gcu._json_schema_to_python_type

        def _safe_json_to_py(schema, defs=None):
            """Report bare boolean schema nodes as "Any"; delegate everything else."""
            return "Any" if isinstance(schema, bool) else _orig_json_to_py(schema, defs)

        _gcu._json_schema_to_python_type = _safe_json_to_py
except Exception as _e:  # noqa: BLE001
    # Best-effort patch: if gradio_client's internals differ, run unpatched.
    print(f"[patch] gradio_client schema guard skipped: {_e!r}", flush=True)
try:
    import gradio.blocks as _gb

    if hasattr(_gb.Blocks, "get_api_info"):
        _orig_get_api_info = _gb.Blocks.get_api_info

        def _safe_get_api_info(self, *args, **kwargs):
            """Return the real API info, or an empty endpoint map if building it raises."""
            try:
                info = _orig_get_api_info(self, *args, **kwargs)
            except Exception as e:  # noqa: BLE001
                print(f"[patch] get_api_info suppressed: {e!r}", flush=True)
                info = {"named_endpoints": {}, "unnamed_endpoints": {}}
            return info

        _gb.Blocks.get_api_info = _safe_get_api_info
except Exception as _e:  # noqa: BLE001
    # Best-effort patch: if gradio's internals differ, run unpatched.
    print(f"[patch] get_api_info guard skipped: {_e!r}", flush=True)
# ─────────────────────────────────────────────────────────────────────────────
from mmdet.apis import init_detector
from mmdet.datasets.pipelines import Compose
from projects import * # noqa: F401,F403 — registers Co-DETR custom modules
from configs import CLASS_NAMES
from helpers import (
find_interaction_branch, run_inference,
call_interaction, compute_style, draw_ui,
)
# ── model ─────────────────────────────────────────────────────────────────────
MODEL_CONFIG = os.path.join(
    REPO,
    "projects/configs/co_dino_vit/"
    "co_dino_5scale_vit_large_coco_with_relation_only_all_losses_custom.py",
)
DEVICE = "cuda:0"
NMS_IOU = 0.5
DEFAULT_THR = 0.3

# A non-empty CKPT_PATH overrides the checkpoint; otherwise fetch it from the Hub.
CHECKPOINT = os.environ.get("CKPT_PATH")
if not CHECKPOINT:
    CHECKPOINT = hf_hub_download(
        repo_id="ahmaddarkhalil/hoi-detr", filename="epoch_5.pth"
    )

# Load at module level — ZeroGPU emulates CUDA here so init_detector works.
model = init_detector(MODEL_CONFIG, CHECKPOINT, device=DEVICE)
model.eval()

# Shared by every request: the test-time pipeline from the model's config and
# the interaction branch located on the model's query head.
test_pipeline = Compose(model.cfg.data.test.pipeline)
interaction_branch = find_interaction_branch(model.query_head)
def _save_vis(bgr_image, source_path):
    """Write *bgr_image* as ``<stem>_pred.png`` in a fresh temp dir; return the path.

    The stem is taken from the file name of *source_path*; "image" is used when
    *source_path* is empty or None.
    """
    source_name = os.path.basename(source_path or "image")
    stem, _ext = os.path.splitext(source_name)
    out_path = os.path.join(tempfile.mkdtemp(prefix="hoi_"), f"{stem}_pred.png")
    mmcv.imwrite(bgr_image, out_path)
    return out_path
def _annotate_bgr(orig_img, score_thr):
    """Run HOI detection on a BGR image array and return the annotated frame.

    Shared by the image and video paths. run_inference loads from a file, so
    the frame is staged to a temporary jpg first. The staging file is unique
    per call and removed afterwards: a single fixed path could be overwritten
    by another request (e.g. an image request arriving while a video is being
    processed) between the write and the read.

    Args:
        orig_img: BGR image array; it is not modified, a copy is drawn on.
        score_thr: detection score threshold forwarded to run_inference.

    Returns:
        A copy of *orig_img* with detections and interaction links drawn, or an
        unannotated copy when nothing was detected.
    """
    fd, tmp = tempfile.mkstemp(prefix="hoi_frame_in_", suffix=".jpg")
    os.close(fd)  # only the path is needed; mmcv writes the file itself
    try:
        mmcv.imwrite(orig_img, tmp)
        # NOTE(review): assumes run_inference has finished reading `tmp` by the
        # time it returns — confirm against helpers.py.
        dets, embeds = run_inference(
            model, test_pipeline, tmp,
            device=DEVICE, class_names=CLASS_NAMES,
            score_thr=score_thr, nms_iou=NMS_IOU,
        )
    finally:
        try:
            os.remove(tmp)
        except OSError:
            pass  # best-effort cleanup; a leftover temp file is harmless

    vis = orig_img.copy()
    if not dets:
        return vis

    def linked_pairs(sources, targets):
        """Return (source, target, prob) for every pair the branch accepts."""
        pairs = []
        for src in sources:
            for dst in targets:
                ok, prob = call_interaction(
                    interaction_branch,
                    embeds[src["query_idx"]], embeds[dst["query_idx"]],
                )
                if ok:
                    pairs.append((src, dst, prob))
        return pairs

    # class_id 0 = hand, 1 = first object, 2 = second object.
    hands = [d for d in dets if d["class_id"] == 0]
    firsts = [d for d in dets if d["class_id"] == 1]
    seconds = [d for d in dets if d["class_id"] == 2]
    hf_inters = linked_pairs(hands, firsts)
    fs_inters = linked_pairs(firsts, seconds)

    draw_ui(vis, dets, hf_inters, fs_inters, compute_style(vis.shape),
            verbose_labels=True)
    return vis
@spaces.GPU(duration=10)
def predict_image(image_path, score_thr):
    """Annotate one image file and return the path of the rendered PNG.

    Args:
        image_path: path of the input image, or None/empty when the input was
            cleared (e.g. a webcam frame that wasn't captured).
        score_thr: detection score threshold.

    Returns:
        Path of the annotated PNG, or None when there is no input.

    Raises:
        gr.Error: on any failure, so the message is shown in the UI.
    """
    # Empty/cleared input arrives as None; just clear the output instead of
    # erroring out of mmcv.imread.
    if not image_path:
        return None
    try:
        image = mmcv.imread(image_path)
        if image is None:
            raise gr.Error("Couldn't read this image — please try another file.")
        vis = _annotate_bgr(image, score_thr)
        return _save_vis(vis, image_path)
    except gr.Error:
        # Already user-facing; pass it through unwrapped, as predict_video does.
        raise
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"{type(e).__name__}: {e}") from e
# Upper bound on the number of frames annotated per video request. The cap
# keeps a single run light on the ZeroGPU quota and within the per-call budget;
# longer clips are temporally subsampled (output fps lowered to match so
# playback speed is preserved).
MAX_VIDEO_FRAMES = 100
def _open_video(video_path, work_dir):
    """Open *video_path* for reading, preferring an ffmpeg-normalized copy.

    The input is first re-encoded to a constant-rate H.264 mp4 in *work_dir*.
    Browser webcam recordings (especially desktop Safari/Chrome .webm /
    fragmented mp4) open in OpenCV but then decode to garbled frames; decoding
    through ffmpeg first makes reading reliable. Some desktop webcam recordings
    are saved with a broken container (EBML header) that ffmpeg can't demux, so
    a lenient second pass is attempted before giving up.

    Args:
        video_path: path of the input video.
        work_dir: existing directory that receives ``normalized_in.mp4``.

    Returns:
        A cv2.VideoCapture on the normalized file, or on the original file if
        normalization failed or ffmpeg could not be launched. The fallback
        capture may not be opened; the caller must release whatever is returned.
    """
    norm = os.path.join(work_dir, "normalized_in.mp4")
    encode_args = ["-c:v", "libx264", "-pix_fmt", "yuv420p", "-r", "25",
                   "-an", norm]
    attempts = [
        ["ffmpeg", "-y", "-i", video_path, *encode_args],
        # Lenient recovery for malformed/corrupt webcam webm.
        ["ffmpeg", "-y", "-err_detect", "ignore_err",
         "-fflags", "+genpts+discardcorrupt", "-i", video_path, *encode_args],
    ]
    for cmd in attempts:
        try:
            # errors="replace": ffmpeg's stderr can contain bytes that are not
            # valid text, which would otherwise raise UnicodeDecodeError.
            r = subprocess.run(cmd, capture_output=True, text=True,
                               errors="replace")
        except OSError as e:
            # ffmpeg is missing or not executable; the second attempt would
            # fail the same way, so go straight to the fallback.
            print(f"[video] could not launch ffmpeg: {e!r}", flush=True)
            break
        if (r.returncode == 0 and os.path.exists(norm)
                and os.path.getsize(norm) > 0):
            cap = cv2.VideoCapture(norm)
            if cap.isOpened():
                print("[video] reading ffmpeg-normalized input", flush=True)
                return cap
            cap.release()  # don't leak the handle of a capture we discard
        print(f"[video] normalize attempt rc={r.returncode}:\n"
              f"{r.stderr[-500:]}", flush=True)
    print("[video] all normalize attempts failed; reading original", flush=True)
    return cv2.VideoCapture(video_path)
def _write_annotated_frames(cap, video_path, frames_dir, score_thr, progress):
    """Annotate a temporal subsample of *cap* and write it as numbered PNGs.

    Args:
        cap: opened cv2.VideoCapture to read from (not released here).
        video_path: original input path, used only in the log line.
        frames_dir: existing directory that receives ``000000.png``, ...
        score_thr: detection score threshold.
        progress: callable taking (fraction, desc=...) for UI progress.

    Returns:
        (n_proc, target, out_fps): number of frames written, their common
        (width, height) or None if none were written, and the frame rate that
        keeps playback speed unchanged after subsampling.
    """
    in_fps = cap.get(cv2.CAP_PROP_FPS)
    if not in_fps or math.isnan(in_fps) or in_fps <= 0:  # 0 / NaN guard
        in_fps = 24.0
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
    step = max(1, math.ceil(total / MAX_VIDEO_FRAMES)) if total > 0 else 1
    out_fps = max(1.0, in_fps / step)
    # Frames 0, step, 2*step, ... below `total` are processed, i.e.
    # ceil(total / step) of them (floor division would undercount by one
    # whenever total is not a multiple of step).
    expected = (min(MAX_VIDEO_FRAMES, math.ceil(total / step))
                if total > 0 else None)
    print(f"[video] in={video_path} fps={in_fps:.2f} total={total} "
          f"step={step} out_fps={out_fps:.2f}", flush=True)

    target, idx, n_proc = None, 0, 0
    while n_proc < MAX_VIDEO_FRAMES:
        ok, frame = cap.read()
        if not ok:
            break
        if idx % step == 0:
            vis = _annotate_bgr(frame, score_thr)
            if target is None:
                h, w = vis.shape[:2]
                target = (w - (w % 2), h - (h % 2))  # even dims
            if (vis.shape[1], vis.shape[0]) != target:
                vis = cv2.resize(vis, target)
            cv2.imwrite(os.path.join(frames_dir, f"{n_proc:06d}.png"), vis)
            n_proc += 1
            if expected:
                # The container's frame count can be inaccurate; never report
                # more than 100 %.
                progress(min(1.0, n_proc / expected),
                         desc=f"Frame {n_proc}/{expected}")
        idx += 1
    return n_proc, target, out_fps


@spaces.GPU(duration=30)
def predict_video(video_path, score_thr, progress=gr.Progress()):
    """Annotate a video and return the path of a browser-friendly H.264 mp4.

    At most MAX_VIDEO_FRAMES frames are annotated; longer clips are subsampled
    and the output frame rate is lowered to match.

    Args:
        video_path: path of the input video, or None/empty when cleared.
        score_thr: detection score threshold.
        progress: gradio progress tracker.

    Returns:
        Path of the encoded output video, or None when there is no input.

    Raises:
        gr.Error: if the video can't be read, encoding fails, or anything else
            goes wrong (so the message is shown in the UI).
    """
    import shutil

    if not video_path:
        return None
    try:
        work_dir = tempfile.mkdtemp(prefix="hoi_vid_")
        frames_dir = os.path.join(work_dir, "frames")
        os.makedirs(frames_dir, exist_ok=True)
        try:
            cap = _open_video(video_path, work_dir)
            try:
                # Annotate frames to PNGs (all forced to one even size), then
                # let ffmpeg assemble a browser-friendly H.264/yuv420p mp4.
                n_proc, target, out_fps = _write_annotated_frames(
                    cap, video_path, frames_dir, score_thr, progress)
            finally:
                cap.release()  # also on failure, so the handle isn't leaked
            print(f"[video] processed {n_proc} frames, size={target}", flush=True)
            if n_proc == 0:
                raise gr.Error(
                    "Couldn't read this video. Some desktop browsers save webcam "
                    "recordings in a broken format we can't decode — please upload "
                    "a video file instead, or record from a phone.")
            out_path = os.path.join(work_dir, "out.mp4")
            cmd = ["ffmpeg", "-y", "-framerate", f"{out_fps:.4f}",
                   "-i", os.path.join(frames_dir, "%06d.png"),
                   "-c:v", "libx264", "-profile:v", "baseline", "-level", "3.1",
                   "-pix_fmt", "yuv420p", "-movflags", "+faststart",
                   "-an", out_path]
            # errors="replace": stderr may contain bytes that are not valid text.
            r = subprocess.run(cmd, capture_output=True, text=True,
                               errors="replace")
            print(f"[video] ffmpeg rc={r.returncode}; "
                  f"out_exists={os.path.exists(out_path)}", flush=True)
            if r.returncode != 0 or not os.path.exists(out_path):
                print(f"[video] ffmpeg failed:\n{r.stderr[-1500:]}", flush=True)
                raise gr.Error("Failed to encode the output video.")
            return out_path
        finally:
            # The intermediate PNGs are only ffmpeg's input; drop them so they
            # don't pile up in the temp dir across requests.
            shutil.rmtree(frames_dir, ignore_errors=True)
    except gr.Error:
        raise
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"{type(e).__name__}: {e}") from e
# ── UI ────────────────────────────────────────────────────────────────────────
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
_IMG_DIR = os.path.join(_APP_DIR, "examples", "images")
_VID_DIR = os.path.join(_APP_DIR, "examples", "videos")


def _logo_data_uri():
    """Return the HOI-DETR logo as a base64 ``data:`` URI, or "" if unreadable.

    Inlining the image lets it render in the header without depending on gradio
    file-serving / allowed_paths.
    """
    import base64

    logo_path = os.path.join(_APP_DIR, "assets", "logo_hoi_detr.png")
    try:
        with open(logo_path, "rb") as fh:
            raw = fh.read()
    except OSError:
        return ""
    return "data:image/png;base64," + base64.b64encode(raw).decode()


_LOGO_URI = _logo_data_uri()
def _list(d, exts):
    """Return the sorted full paths of entries in *d* whose name ends with *exts*.

    The match is done on the lowercased name; *exts* is passed straight to
    ``str.endswith`` (one suffix or a tuple of suffixes). If *d* is not a
    directory the result is an empty list.
    """
    if not os.path.isdir(d):
        return []
    matches = [
        os.path.join(d, name)
        for name in os.listdir(d)
        if name.lower().endswith(exts)
    ]
    matches.sort()
    return matches
# Example media found under the examples/ directories next to this file; each
# list is empty when its directory does not exist.
img_examples = _list(_IMG_DIR, (".jpg", ".jpeg", ".png"))
vid_examples = _list(_VID_DIR, (".mp4", ".mov", ".webm", ".avi"))
def _downscale_jpg(path, out_dir, max_side=1280, quality=85):
    """Write a smaller JPG copy of *path* into *out_dir* and return its path.

    Used so example images load fast (some originals are several MB). The
    detector resizes internally, so this doesn't affect detections.

    Args:
        path: source image file.
        out_dir: existing directory that receives ``<stem>.jpg``.
        max_side: longest side of the copy, in pixels; smaller images are only
            re-encoded, never enlarged.
        quality: JPEG quality passed to OpenCV.

    Returns:
        Path of the written copy, or *path* unchanged if the image could not be
        read, resized or written.
    """
    try:
        img = cv2.imread(path)
        if img is None:
            return path
        h, w = img.shape[:2]
        s = max_side / max(h, w)
        if s < 1:
            # Clamp to 1 px so an extreme aspect ratio can't truncate a side to 0.
            new_size = (max(1, int(w * s)), max(1, int(h * s)))
            img = cv2.resize(img, new_size)
        stem = os.path.splitext(os.path.basename(path))[0]
        p = os.path.join(out_dir, stem + ".jpg")
        # imwrite reports failure through its return value, not an exception;
        # without this check a path to a file that was never written could be
        # returned.
        if not cv2.imwrite(p, img, [cv2.IMWRITE_JPEG_QUALITY, quality]):
            print(f"[example] downscale failed for {path}: could not write {p}",
                  flush=True)
            return path
        return p
    except Exception as e:  # noqa: BLE001
        print(f"[example] downscale failed for {path}: {e!r}", flush=True)
        return path


_IMG_LITE_DIR = tempfile.mkdtemp(prefix="hoi_imgex_")
img_examples = [_downscale_jpg(p, _IMG_LITE_DIR) for p in img_examples]
def _make_poster(video_path, out_dir):
    """Save the first frame of *video_path* as a small JPG poster.

    Video examples get an explicit thumbnail this way — gradio's auto video
    posters don't render on mobile. The frame keeps its native aspect ratio and
    is downscaled so its longest side is at most 320 px; it is shown with
    object-fit: contain so the whole frame is visible (no cropping out heads).

    Args:
        video_path: source video file.
        out_dir: existing directory that receives ``<stem>_poster.jpg``.

    Returns:
        Path of the poster, or None if no frame could be read or written.
    """
    try:
        cap = cv2.VideoCapture(video_path)
        try:
            ok, frame = cap.read()
        finally:
            cap.release()  # release even if read() raises
        if not ok or frame is None:
            return None
        h, w = frame.shape[:2]
        s = 320.0 / max(h, w)
        if s < 1:
            # Clamp to 1 px so an extreme aspect ratio can't truncate a side to 0.
            new_size = (max(1, int(w * s)), max(1, int(h * s)))
            frame = cv2.resize(frame, new_size)
        stem = os.path.splitext(os.path.basename(video_path))[0]
        p = os.path.join(out_dir, f"{stem}_poster.jpg")
        # imwrite signals failure via its return value; don't hand back a path
        # to a poster that was never written.
        if not cv2.imwrite(p, frame, [cv2.IMWRITE_JPEG_QUALITY, 85]):
            print(f"[poster] failed for {video_path}: could not write {p}",
                  flush=True)
            return None
        return p
    except Exception as e:  # noqa: BLE001
        print(f"[poster] failed for {video_path}: {e!r}", flush=True)
        return None


_POSTER_DIR = tempfile.mkdtemp(prefix="hoi_posters_")
# (poster_or_video_path, video_path) per example
vid_thumbs = [(_make_poster(v, _POSTER_DIR) or v, v) for v in vid_examples]
# Keep the default gradio theme (orange primary + original font); just center
# the container and make the tab labels a bit larger / more obvious.
# Assembled from fragments so each rule can carry its own comment; the
# fragments are concatenated with no separator.
_CSS = "".join((
    ".gradio-container {max-width: 1400px !important; margin: auto;}",
    " button.tab-nav-button, .tab-nav button {font-size: 1.05rem !important;",
    " font-weight: 600 !important;}",
    # Keep the example-video thumbnails small and left-aligned on desktop.
    " #vid_examples_gallery {max-width: 520px; margin-left: 0 !important;",
    " margin-right: auto !important;}",
    " #vid_examples_gallery img {object-fit: contain !important;}",
    # Tighten the gap between the intro text and the tabs block.
    " #main_tabs {margin-top: -10px;}",
))

# Detection colors echoed in the header text (match the drawn box colors).
_C_HAND = "#DC3220"
_C_FIRST = "#FFC20A"
_C_SECOND = "#14B4D2"
with gr.Blocks(title="HOI-DETR — Hand–Object Interaction Detection",
css=_CSS) as demo:
# Single HTML block (Markdown strips inline styles) so the colored words
# render and the title/links/description spacing is tight. Academic-style
# bracketed link badges sit right under the title.
_LINK = ("display:inline-block;padding:1px 6px;border:1px solid currentColor;"
"border-radius:6px;text-decoration:none;font-weight:600;"
"font-size:0.92rem;color:#C2410C;")
_logo_img = (
""
% _LOGO_URI
) if _LOGO_URI else "🖐️"
gr.HTML(
"
" "🌐 Project Page " "💻 Code
" "Detects " "hands, the " "first object held, and the " "second object it contacts, with their " "interaction links. Works on images and video — pick a " "tab below, then try an example or upload your own.
" % (_LINK, _LINK, _C_HAND, _C_FIRST, _C_SECOND) ) with gr.Tabs(elem_id="main_tabs"): # ── Image tab ──────────────────────────────────────────────── with gr.Tab("🖼️ Image"): with gr.Row(equal_height=True): with gr.Column(): img_in = gr.Image(type="filepath", label="Input image", height=320) img_thr = gr.Slider(0.05, 1.0, value=DEFAULT_THR, step=0.05, label="Score threshold") img_btn = gr.Button("Detect", variant="primary") with gr.Column(): img_out = gr.Image(label="HOI predictions", height=320) if img_examples: gr.Examples( examples=[[p] for p in img_examples], inputs=[img_in], outputs=img_out, fn=lambda p: predict_image(p, DEFAULT_THR), cache_examples=False, examples_per_page=len(img_examples), label="Example images — click to run", ) img_btn.click(predict_image, [img_in, img_thr], img_out) # ── Video tab ──────────────────────────────────────────────── with gr.Tab("🎬 Video"): with gr.Row(equal_height=True): with gr.Column(): vid_in = gr.Video(label="Input video", height=320, sources=["upload", "webcam"]) vid_thr = gr.Slider(0.05, 1.0, value=DEFAULT_THR, step=0.05, label="Score threshold") vid_btn = gr.Button("Process video", variant="primary") gr.Markdown( f"Processes up to {MAX_VIDEO_FRAMES} frames " "(longer clips are subsampled); this can take a minute." "") with gr.Column(): # autoplay (muted, no audio) so the result plays on its own # where the browser allows it. 
vid_out = gr.Video(label="HOI predictions", height=320, autoplay=True) if vid_thumbs: gr.Markdown("**Example videos** — tap a thumbnail to load, " "then press *Process video*") vid_gallery = gr.Gallery( value=[t for t, _ in vid_thumbs], columns=2, height=200, object_fit="contain", allow_preview=False, show_label=False, elem_id="vid_examples_gallery") def _pick_video(evt: gr.SelectData): return vid_examples[evt.index] vid_gallery.select(_pick_video, None, vid_in) vid_btn.click(predict_video, [vid_in, vid_thr], vid_out) if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True, allowed_paths=[_IMG_DIR, _VID_DIR, _POSTER_DIR, _IMG_LITE_DIR])