""" Annotated example: running inference with an Owl IDM model. The InferencePipeline handles: - Loading model weights (local or from Hugging Face Hub) - Sliding window inference over arbitrary-length videos - Log1p scaling reversal for mouse outputs - Optional torch.compile for faster repeated inference Example usage (local): pipeline = InferencePipeline.from_pretrained( config_path="configs/vpt_simple.yml", checkpoint_path="checkpoints/simpler_vpt/ema/step_50000.pt" ) Example usage (HF Hub): pipeline = InferencePipeline.from_pretrained("username/owl-idm-vpt-v0") # video: [b, n, c, h, w] tensor normalized to range [-1, 1] button_preds, mouse_preds = pipeline(video) # button_preds: [b, n, n_buttons] bool — one entry per configured button # mouse_preds: [b, n, 2] float — (dx, dy) in raw pixel space """ import torch import os from tqdm import tqdm from owl_idms.configs import load_config, get_button_labels, get_n_buttons from owl_idms.models import get_model_cls class InferencePipeline: """ Inference pipeline for IDM models. Implements sliding window inference: for each frame i in the input video, a window of `window_length` frames centered on i is fed to the model, which predicts the controls active at that frame. Edge frames are padded by repeating the first/last frame. """ def __init__(self, model, config, device='cuda', compile_model=True): """ Args: model: The IDM model (VPT_IDM or similar) config: Full OmegaConf config (must have .train and .model sections) device: Device to run inference on compile_model: Whether to torch.compile (faster after warmup, slower first call) """ self.config = config self.device = device self.window_length = config.train.window_length self.use_log1p_scaling = getattr(config.train, 'use_log1p_scaling', True) self.button_labels = get_button_labels(config.model) self.model = model.to(device=device, dtype=torch.bfloat16) self.model.eval() if compile_model: print("Compiling model for inference...") self.model = torch.compile(self.model, mode='max-autotune') print("Model compiled!") @classmethod def from_pretrained(cls, model_id_or_path, checkpoint_path=None, device='cuda', compile_model=True, token=None): """ Load a pretrained model from local files or Hugging Face Hub. Args: model_id_or_path: HF Hub repo ID (e.g. "username/owl-idm-vpt-v0") OR local path to a config YAML file checkpoint_path: Path to .pt checkpoint (only needed for local loading) device: Device to run on compile_model: Whether to torch.compile the model token: HF API token (for private repos) Examples: # From HF Hub pipeline = InferencePipeline.from_pretrained("username/owl-idm-vpt-v0") # From local files pipeline = InferencePipeline.from_pretrained( "configs/vpt_simple.yml", checkpoint_path="checkpoints/simpler_vpt/ema/step_50000.pt" ) """ is_local = os.path.exists(model_id_or_path) or model_id_or_path.endswith('.yml') if is_local: if checkpoint_path is None: raise ValueError("checkpoint_path is required when loading from local files") config_path = model_id_or_path print(f"Loading from local files: {config_path}, {checkpoint_path}") else: try: from huggingface_hub import hf_hub_download except ImportError: raise ImportError("Install huggingface_hub: pip install huggingface_hub") print(f"Loading from Hugging Face Hub: {model_id_or_path}") config_path = hf_hub_download(repo_id=model_id_or_path, filename="config.yml", token=token) checkpoint_path = hf_hub_download(repo_id=model_id_or_path, filename="model.pt", token=token) config = load_config(config_path) model_cls = get_model_cls(config.model.model_id) model = model_cls(config.model) # Checkpoints saved via upload_to_hf.py contain raw EMA weights (just state_dict) checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True) model.load_state_dict(checkpoint) print(f"Loaded checkpoint from {checkpoint_path}") return cls(model, config, device=device, compile_model=compile_model) @torch.no_grad() def __call__(self, videos, window_size=None, show_progress=True): """ Run sliding window inference on a batch of videos. Args: videos: [b, n, c, h, w] float tensor, normalized to [-1, 1] window_size: Override the window size from config (optional) show_progress: Show a tqdm progress bar Returns: button_preds: [b, n, n_buttons] bool — True = button pressed mouse_preds: [b, n, 2] float — (dx, dy) mouse delta in pixels Button order matches the `buttons` list in the config YAML. Use pipeline.button_labels to get the label for each index. """ if window_size is None: window_size = self.window_length b, n, c, h, w = videos.shape videos = videos.to(device=self.device, dtype=torch.bfloat16) # Pad start/end by repeating edge frames so every frame gets a full window middle_idx = (window_size - 1) // 2 pad_start = middle_idx pad_end = window_size - 1 - middle_idx padded = torch.cat([ videos[:, 0:1].expand(-1, pad_start, -1, -1, -1), videos, videos[:, -1:].expand(-1, pad_end, -1, -1, -1), ], dim=1) button_preds = [] mouse_preds = [] iterator = tqdm(range(n), desc="Running inference") if show_progress else range(n) for i in iterator: window = padded[:, i:i + window_size] # [b, window_size, c, h, w] # Model is always in eval mode here; returns middle-frame predictions # button_logits: [b, n_buttons], mouse_pred: [b, 2] button_logits, mouse_pred = self.model(window) button_preds.append(button_logits.clone()) mouse_preds.append(mouse_pred.clone()) # [n, b, ...] -> [b, n, ...] button_preds = torch.stack(button_preds, dim=1) mouse_preds = torch.stack(mouse_preds, dim=1) # Threshold logits to get boolean button states button_preds = torch.sigmoid(button_preds) > 0.5 # Mouse was predicted in log1p space during training; invert that here if self.use_log1p_scaling: mouse_preds = torch.sign(mouse_preds) * torch.expm1(torch.abs(mouse_preds)) return button_preds, mouse_preds if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="Run Owl IDM inference") parser.add_argument("--config", type=str, required=True) parser.add_argument("--checkpoint", type=str, required=True) parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--no-compile", action="store_true") args = parser.parse_args() pipeline = InferencePipeline.from_pretrained( args.config, args.checkpoint, device=args.device, compile_model=not args.no_compile ) print(f"\nPipeline ready!") print(f" Window length: {pipeline.window_length}") print(f" Buttons ({len(pipeline.button_labels)}): {pipeline.button_labels}") print(f" Log1p scaling: {pipeline.use_log1p_scaling}") print(f" Device: {pipeline.device}")