"""LeHome Challenge 2026 - single-model residual policy (v4 global).

Loads a SmolVLA backbone (frozen) and applies a state-only residual MLP
averaged from 40 Seen garments under sparse-reward residual RL.
Inference is deterministic: no exploration noise, no online updates.

Required environment variables:
    LEHOME_VLA_POLICY_PATH      SmolVLA backbone directory
    LEHOME_VLA_DATASET_ROOT     LeRobot dataset root
    LEHOME_RESIDUAL_CHECKPOINT  path to residual_averaged.pt

Optional environment variables:
    LEHOME_RESIDUAL_SCALE       default 0.03
    LEHOME_TASK_DESCRIPTION     default "fold the garment on the table"
    LEHOME_POLICY_DEVICE        default cuda:0 (cpu when CUDA is unavailable)

An environment variable set to the empty string is treated as unset.
"""

from __future__ import annotations

import os
from pathlib import Path
from typing import Dict, Optional

import numpy as np
import torch
from torch import nn

from .base_policy import BasePolicy
from .lerobot_policy import LeRobotPolicy
from .registry import PolicyRegistry


def _require_existing_path(path: str, what: str, env_var: str) -> None:
    """Raise FileNotFoundError unless *path* is non-empty and exists on disk."""
    if not path or not Path(path).exists():
        raise FileNotFoundError(f"{what} not found (set {env_var}): {path}")


class _ResidualActor(nn.Module):
    """MLP mapping a state vector to a residual action.

    Architecture: one ``Linear`` + ``ReLU`` pair per entry of *hidden_dims*,
    followed by a final ``Linear`` to *action_dim*. The layer order (and hence
    the ``network.<idx>`` state-dict keys) must match the trained checkpoint.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dims=(256, 256)):
        super().__init__()
        dims = [state_dim, *hidden_dims, action_dim]
        layers: list[nn.Module] = []
        # Hidden layers: consecutive (in, out) pairs up to the last hidden width.
        for in_dim, out_dim in zip(dims[:-2], dims[1:-1]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.ReLU())
        # Output layer: no activation, the residual is unbounded.
        layers.append(nn.Linear(dims[-2], dims[-1]))
        self.network = nn.Sequential(*layers)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return the residual action(s) for a batch of states."""
        return self.network(state)


@PolicyRegistry.register("residual_v4_global")
class ResidualV4GlobalPolicy(BasePolicy):
    """SmolVLA + state-only residual (single model, deterministic).

    The final action is ``base_action + residual_scale * actor(state)``, where
    ``base_action`` comes from the wrapped :class:`LeRobotPolicy` and ``actor``
    is the residual MLP loaded from ``LEHOME_RESIDUAL_CHECKPOINT``.
    """

    def __init__(
        self,
        policy_path: Optional[str] = None,
        model_path: Optional[str] = None,
        dataset_root: Optional[str] = None,
        task_description: str = "fold the garment on the table",
        device: str = "cuda",
        **kwargs,
    ):
        """Build the backbone policy and load the residual actor.

        Environment variables take precedence over constructor arguments.
        The ``device`` argument is accepted for interface compatibility but is
        not consulted: the device comes from ``LEHOME_POLICY_DEVICE`` or is
        auto-detected (``cuda:0`` if available, else ``cpu``).

        Raises:
            FileNotFoundError: if the backbone directory, the dataset root, or
                the residual checkpoint is unset or does not exist.
            ValueError: if ``LEHOME_RESIDUAL_SCALE`` is not a valid float.
        """
        super().__init__(**kwargs)

        # `or` everywhere: an empty env var falls through to the next option
        # instead of being used as an (invalid) empty value.
        backbone_path = (
            os.environ.get("LEHOME_VLA_POLICY_PATH") or policy_path or model_path or ""
        )
        ds_root = os.environ.get("LEHOME_VLA_DATASET_ROOT") or dataset_root or ""
        residual_ckpt = os.environ.get("LEHOME_RESIDUAL_CHECKPOINT") or ""
        residual_scale = float(os.environ.get("LEHOME_RESIDUAL_SCALE") or "0.03")
        task_desc = os.environ.get("LEHOME_TASK_DESCRIPTION") or task_description
        policy_device = os.environ.get("LEHOME_POLICY_DEVICE") or (
            "cuda:0" if torch.cuda.is_available() else "cpu"
        )

        _require_existing_path(backbone_path, "VLA backbone", "LEHOME_VLA_POLICY_PATH")
        _require_existing_path(ds_root, "Dataset root", "LEHOME_VLA_DATASET_ROOT")
        _require_existing_path(
            residual_ckpt, "Residual checkpoint", "LEHOME_RESIDUAL_CHECKPOINT"
        )

        self.policy_device = torch.device(policy_device)
        self.residual_scale = residual_scale

        self.base_policy = LeRobotPolicy(
            policy_path=backbone_path,
            dataset_root=ds_root,
            task_description=task_desc,
            device=str(self.policy_device),
        )

        # NOTE(security): depending on the installed torch version, torch.load
        # may unpickle arbitrary objects; only load trusted checkpoints.
        payload = torch.load(residual_ckpt, map_location="cpu")
        self.state_dim = int(payload["state_dim"])
        self.action_dim = int(payload["action_dim"])
        hidden_dims = tuple(payload["hidden_dims"])

        self.actor = _ResidualActor(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            hidden_dims=hidden_dims,
        ).to(self.policy_device)
        self.actor.load_state_dict(payload["model_state_dict"])
        self.actor.eval()

        print(
            f"[ResidualV4GlobalPolicy] backbone={backbone_path} "
            f"residual={residual_ckpt} scale={self.residual_scale} "
            f"state_dim={self.state_dim} action_dim={self.action_dim} hidden={hidden_dims}"
        )

    def reset(self):
        """Reset the backbone policy; the residual actor itself is stateless."""
        self.base_policy.reset()

    def select_action(self, observation: Dict[str, np.ndarray]) -> np.ndarray:
        """Return the backbone action plus the scaled residual as float32.

        ``observation["observation.state"]`` feeds the residual actor; its last
        dimension must equal the checkpoint's ``state_dim``. If the key is
        missing (or ``None``), an all-zero state is used, so the residual
        reduces to the actor's output for a zero input.

        Raises:
            ValueError: if the state's last dimension does not match
                ``state_dim``.
        """
        base_action = np.asarray(
            self.base_policy.select_action(observation), dtype=np.float32
        )

        raw_state = observation.get("observation.state")
        if raw_state is None:
            # Allocate the fallback only when it is actually needed.
            state_np = np.zeros(self.state_dim, dtype=np.float32)
        else:
            state_np = np.asarray(raw_state, dtype=np.float32)
        if state_np.shape[-1:] != (self.state_dim,):
            raise ValueError(
                f"observation.state has shape {state_np.shape}; expected last "
                f"dimension {self.state_dim} (from the residual checkpoint)"
            )

        with torch.inference_mode():
            state = torch.from_numpy(state_np).to(self.policy_device)
            # Add a batch axis for the MLP, then drop it again.
            residual = (
                self.actor(state.unsqueeze(0)).squeeze(0).cpu().numpy().astype(np.float32)
            )

        return (base_action + self.residual_scale * residual).astype(np.float32, copy=False)