| """arx_inference_demo.py — standalone PRISM-MPPI inference for ARX cube task. |
| |
| This file is **self-contained**: it depends only on the bundled |
| `jepa.py`, `module.py`, `prior_head.py`, plus standard torch / numpy. |
| No `stable_worldmodel` import — the MPPI loop is re-implemented inline. |
| |
| Intended use by a downstream consumer (e.g., the ARX deployment side): |
| |
| from arx_inference_demo import PrismMPPIInference |
| |
| planner = PrismMPPIInference( |
| lewm_ckpt = "lewm_arx.ckpt", |
| prior_ckpt = "prior_head_arx.pt", |
| device = "cuda", |
| ) |
| |
| # In the control loop: |
| while not done: |
| obs_uint8 = camera.read() # (224, 224, 3) uint8 RGB |
| goal_uint8 = goal_image # (224, 224, 3) uint8 RGB |
| actions = planner.plan(obs_uint8, goal_uint8) |
| # → (A_block, 5) float32, raw action units |
| for a in actions: |
| robot.execute(a) # step the robot |
| |
| `plan()` performs one full PRISM-MPPI optimization and returns the first |
| A_block = 5 env-step actions of the optimized plan. The caller may choose |
| to execute all 5 then replan (receding-horizon, k=A_block), or execute |
| fewer and replan more often. |
| |
| PRISM-MPPI summary: |
| 1. JEPA encoder turns current obs + goal image into latent embeddings z_t, z_g. |
| 2. PRISM prior head maps (z_t, z_g) → (μ_p, σ_p) over the next |
| H × A_block × A_raw normalized actions. |
| 3. We seed an MPPI distribution N(0, var_scale² I) (var_scale is the initial |
| std) and, in PRISM mode, PoG-fuse it with the prior to get N(fused_μ, fused_σ²). |
| The std is FROZEN through MPPI iterations (this is the PRISM-MPPI signature; |
| see paper §3). |
| 4. Each iteration samples K candidate action sequences, rolls them out via |
| the LeWM ARPredictor in latent space, computes cost = squared error between |
| the predicted final z and z_g (summed over the latent dimension), reweights |
| candidates by softmax(-β·cost), and updates the mean. |
| 5. After n_iters iterations, the first A_block entries of the mean are |
| returned (denormalized to raw env action units via the saved |
| StandardScaler). |
| """ |
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
| |
| import jepa |
| import module |
| from prior_head import PriorHead |
|
|
# ImageNet per-channel statistics, shaped (1, 3, 1, 1) so they broadcast over
# a (1, 3, H, W) image tensor.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def _preprocess(img_uint8: np.ndarray, device: torch.device) -> torch.Tensor:
    """uint8 (H, W, 3) → float (1, 3, 224, 224), ImageNet-normalized.

    Args:
        img_uint8: image of shape (224, 224, 3); pixel values are divided by
            255 before normalization. Non-contiguous or negative-stride views
            (e.g. a channel flip ``img[..., ::-1]``) are accepted.
        device: device the returned tensor is placed on.

    Returns:
        Float tensor of shape (1, 3, 224, 224) equal to
        ``(img / 255 - IMAGENET_MEAN) / IMAGENET_STD``.

    Raises:
        ValueError: if the image does not have shape (224, 224, 3).
    """
    # torch.from_numpy rejects arrays with negative strides; make the memory
    # layout contiguous first (no copy when it already is).
    img = np.ascontiguousarray(img_uint8)
    # Explicit raise instead of `assert`, so the check survives `python -O`.
    if img.shape != (224, 224, 3):
        raise ValueError(f"Expected (224, 224, 3) image, got {img.shape}")
    # HWC → CHW, scale to [0, 1], add the batch dimension.
    t = torch.from_numpy(img).permute(2, 0, 1).float().div(255.0).unsqueeze(0)
    t = t.to(device)
    mean = IMAGENET_MEAN.to(device)
    std = IMAGENET_STD.to(device)
    return (t - mean) / std
|
|
|
|
def _pog_fusion(mean, std, mu_p, sg_p, sigma_floor=0.05):
    """Fuse the planner Gaussian with the prior Gaussian (product of Gaussians).

    Each Gaussian contributes a precision 1 / (σ² + 1e-8). The fused mean is
    the precision-weighted average of ``mean`` and ``mu_p``; the fused std is
    the square root of the inverse summed precision, clamped from below at
    ``sigma_floor``. Matches prism_mppi.pog_fusion.

    Returns:
        Tuple ``(fused_mean, fused_std)``.
    """
    eps = 1e-8
    precision_planner = 1.0 / (std ** 2 + eps)
    precision_prior = 1.0 / (sg_p ** 2 + eps)
    precision_total = precision_planner + precision_prior

    weighted_sum = precision_planner * mean + precision_prior * mu_p
    fused_mean = weighted_sum / precision_total

    fused_var = 1.0 / precision_total
    fused_std = torch.clamp(torch.sqrt(fused_var), min=sigma_floor)
    return fused_mean, fused_std
|
|
|
|
| class PrismMPPIInference: |
| """Standalone PRISM-MPPI planner for ARX cube task. |
| |
| Supports two modes via the `use_prism` constructor flag — kept on a |
| single class so that PRISM and vanilla-MPPI A/B comparisons use the |
| exact same encoder, predictor, MPPI loop, and StandardScaler. The |
| only difference between the two modes is whether the PoG fusion at |
| init time uses the prior head's (μ_p, σ_p) or not. The prior-head |
| checkpoint is always loaded — its StandardScaler (action |
| normalization) is shared by both modes so the comparison is |
| apples-to-apples in raw action units. |
| |
| Args (paper defaults — change only if you know what you're doing): |
| lewm_ckpt: path to lewm_arx.ckpt (pickled JEPA module) |
| prior_ckpt: path to prior_head_arx.pt (PRISM head state_dict + scaler) |
| use_prism: if True (default), inject the PRISM prior via PoG fusion. |
| If False, skip the prior — the planner becomes vanilla |
| LeWM-MPPI from N(0, var_scale) seed. Use this for paper- |
| grade real-robot A/B against PRISM-MPPI. |
| H: planning horizon in plan-steps (default 5) |
| A_block: env-steps per plan-step (default 5, "frameskip") |
| K: num MPPI samples per iteration (default 128) |
| n_iters: num MPPI refinement iterations (default 30) |
| var_scale: initial planner std (default 1.0) |
| temperature: MPPI softmax temperature β = 1/temperature (default 0.5) |
| sigma_floor: lower bound on fused σ (default 0.05); only used by PRISM |
| prior_sigma_scale: multiplier on prior σ_p before fusion (default 2.0, |
| matches the paper's PRISM-MPPI s=2.0 setting); only used by PRISM |
| history_size: LeWM history-window length (default 3; must match training) |
| device: 'cuda' or 'cpu' |
| """ |
|
|
| def __init__( |
| self, |
| lewm_ckpt: str | Path, |
| prior_ckpt: str | Path, |
| use_prism: bool = True, |
| H: int = 5, |
| A_block: int = 5, |
| K: int = 128, |
| n_iters: int = 30, |
| var_scale: float = 1.0, |
| temperature: float = 0.5, |
| sigma_floor: float = 0.05, |
| prior_sigma_scale: float = 2.0, |
| history_size: int = 3, |
| device: str = "cuda", |
| ): |
| self.device = torch.device(device) |
| self.use_prism = bool(use_prism) |
| self.H = H |
| self.A_block = A_block |
| self.K = K |
| self.n_iters = n_iters |
| self.var_scale = var_scale |
| self.beta = 1.0 / temperature |
| self.sigma_floor = sigma_floor |
| self.prior_sigma_scale = prior_sigma_scale |
| self.history_size = history_size |
|
|
| |
| print(f"[init] loading LeWM ckpt: {lewm_ckpt}") |
| self.lewm = torch.load( |
| str(lewm_ckpt), map_location=self.device, weights_only=False, |
| ) |
| self.lewm.to(self.device).eval() |
| for p in self.lewm.parameters(): |
| p.requires_grad_(False) |
|
|
| |
| print(f"[init] loading prior head + scaler: {prior_ckpt}") |
| pck = torch.load(str(prior_ckpt), map_location=self.device, weights_only=False) |
| cfg = pck["config"] |
| self.A_raw = int(cfg["A_raw"]) |
| assert cfg["H"] == self.H and cfg["A_block"] == self.A_block, ( |
| f"Ckpt config mismatch: H={cfg['H']} A_block={cfg['A_block']} " |
| f"vs runtime H={self.H} A_block={self.A_block}" |
| ) |
|
|
| if self.use_prism: |
| self.head = PriorHead(**cfg).to(self.device).eval() |
| self.head.load_state_dict(pck["state_dict"]) |
| for p in self.head.parameters(): |
| p.requires_grad_(False) |
| else: |
| self.head = None |
|
|
| |
| self.scaler_mean = torch.tensor(pck["scaler_mean"], device=self.device).float() |
| self.scaler_scale = torch.tensor(pck["scaler_scale"], device=self.device).float() |
| mode_str = "PRISM-MPPI" if self.use_prism else "vanilla LeWM-MPPI (PRISM off)" |
| print(f"[init] mode = {mode_str}") |
| print(f"[init] z_dim={cfg['z_dim']} H={self.H} A_block={self.A_block} " |
| f"A_raw={self.A_raw}") |
| print(f"[init] device={self.device} K={self.K} n_iters={self.n_iters}") |
|
|
| @torch.no_grad() |
| def _encode(self, img_uint8: np.ndarray) -> torch.Tensor: |
| """uint8 image → (1, D) CLS embedding.""" |
| x = _preprocess(img_uint8, self.device) |
| |
| info = {"pixels": x.unsqueeze(1)} |
| info = self.lewm.encode(info) |
| return info["emb"][:, 0] |
|
|
| @torch.no_grad() |
| def _prior(self, z_t: torch.Tensor, z_g: torch.Tensor): |
| """PRISM head: (1, D), (1, D) → (μ, σ) of shape (1, H, A_block, A_raw) |
| in normalized action space.""" |
| return self.head(z_t, z_g) |
|
|
| @torch.no_grad() |
| def _rollout_costs( |
| self, |
| z_t: torch.Tensor, |
| z_g: torch.Tensor, |
| action_candidates: torch.Tensor, |
| ) -> torch.Tensor: |
| """Rollout each candidate via LeWM AR predictor, compute final-z MSE to z_g.""" |
| B, K, T_total, A = action_candidates.shape |
| assert T_total == self.H * self.A_block |
| D = z_t.shape[-1] |
| HS = self.history_size |
|
|
| |
| |
| emb = z_t.unsqueeze(1).expand(B, K, D).reshape(B * K, D) |
| emb = emb.unsqueeze(1).expand(-1, HS, -1).contiguous() |
|
|
| |
| act_seq = action_candidates.reshape(B * K, T_total, A) |
|
|
| |
| act_plan = act_seq.reshape(B * K, self.H, self.A_block * A) |
|
|
| |
| |
| act_emb = self.lewm.action_encoder(act_plan) |
|
|
| |
| for t in range(self.H): |
| emb_trunc = emb[:, -HS:] |
| act_trunc = act_emb[:, max(0, t - HS + 1): t + 1] |
| |
| if act_trunc.shape[1] < HS: |
| pad = act_trunc[:, :1].expand(-1, HS - act_trunc.shape[1], -1) |
| act_trunc = torch.cat([pad, act_trunc], dim=1) |
| pred = self.lewm.predict(emb_trunc, act_trunc)[:, -1:] |
| emb = torch.cat([emb, pred], dim=1) |
|
|
| |
| pred_final = emb[:, -1] |
| goal = z_g.unsqueeze(1).expand(B, K, D).reshape(B * K, D) |
| cost = F.mse_loss(pred_final, goal, reduction="none").sum(dim=-1) |
| return cost.reshape(B, K) |
|
|
| @torch.no_grad() |
| def plan(self, obs_uint8: np.ndarray, goal_uint8: np.ndarray) -> np.ndarray: |
| """One MPPI optimization (PRISM or vanilla depending on `use_prism`). |
| |
| Returns (A_block, A_raw) actions in raw env units. |
| """ |
| |
| z_t = self._encode(obs_uint8) |
| z_g = self._encode(goal_uint8) |
|
|
| |
| shape = (1, self.H * self.A_block, self.A_raw) |
| mean = torch.zeros(shape, device=self.device) |
| std = torch.full(shape, self.var_scale, device=self.device) |
|
|
| |
| if self.use_prism: |
| mu_p, sg_p = self._prior(z_t, z_g) |
| mu_p_flat = mu_p.reshape(*shape) |
| sg_p_flat = sg_p.reshape(*shape) * self.prior_sigma_scale |
| mean, std = _pog_fusion(mean, std, mu_p_flat, sg_p_flat, self.sigma_floor) |
|
|
| |
| |
| for it in range(self.n_iters): |
| noise = torch.randn( |
| 1, self.K, self.H * self.A_block, self.A_raw, device=self.device, |
| ) |
| cands = mean.unsqueeze(1) + noise * std.unsqueeze(1) |
| |
|
|
| cost = self._rollout_costs(z_t, z_g, cands) |
| log_w = -self.beta * (cost - cost.min(dim=-1, keepdim=True).values) |
| w = torch.softmax(log_w, dim=-1) |
|
|
| |
| mean = (w.unsqueeze(-1).unsqueeze(-1) * cands).sum(dim=1) |
| |
|
|
| |
| first_block_norm = mean[0, : self.A_block] |
| first_block_raw = first_block_norm * self.scaler_scale + self.scaler_mean |
| return first_block_raw.cpu().numpy().astype(np.float32) |
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Demo: plan once on the first sample of an HDF5 dataset and compare the
    # first planned action against the recorded one.
    import argparse
    import time

    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--lewm-ckpt", default=".stable-wm/lewm_arx_epoch_100_object.ckpt",
    )
    ap.add_argument("--prior-ckpt", default="prior_head_arx.pt")
    ap.add_argument("--h5", default=".stable-wm/arx_left_cube.h5")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--no-prism", action="store_true",
                    help="Run vanilla LeWM-MPPI (no PRISM prior). Use for A/B comparison.")
    args = ap.parse_args()

    # Honor --seed: the planner draws its candidates with torch.randn, so the
    # RNGs must be seeded for the demo output to be reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Imported after argument parsing so `--help` works without h5py installed.
    import h5py
    print(f"\n[demo] loading sample from {args.h5}")
    with h5py.File(args.h5, "r") as f:
        obs = f["pixels"][0]
        goal = f["goal_pixels"][0]
        ground_truth_action = f["action"][0]
    print(f"[demo] obs.shape={obs.shape} goal.shape={goal.shape} "
          f"obs.dtype={obs.dtype}")

    print()
    planner = PrismMPPIInference(
        lewm_ckpt=args.lewm_ckpt,
        prior_ckpt=args.prior_ckpt,
        use_prism=not args.no_prism,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    mode = "vanilla LeWM-MPPI" if args.no_prism else "PRISM-MPPI"
    print(f"\n[demo] running {mode} on the sample obs + its goal image…")
    # perf_counter is monotonic, unlike time.time(), so the measured interval
    # cannot be skewed by wall-clock adjustments.
    t0 = time.perf_counter()
    actions = planner.plan(obs, goal)
    dt = time.perf_counter() - t0
    print(f"[demo] planned in {dt:.2f}s")
    print(f"[demo] action sequence (A_block × A_raw): shape={actions.shape}")
    print(f"[demo] first action: {actions[0].tolist()}")
    print(f"[demo] ground-truth (t=0): {ground_truth_action.tolist()}")
    print(f"[demo] |Δ|: {np.linalg.norm(actions[0] - ground_truth_action):.4f}")
|
|