---
license: apache-2.0
library_name: pytorch
tags:
- robotics
- world-model
- jepa
- lewm
- prism
- arx
- visuomotor
---

# PRISM-JEPA · red_cube (ARX-X5): LeWM world model + PRISM action prior

Complete deployable stack for goal-conditioned visuomotor planning on the ARX-X5 **"red cube"** task: the **LeWM JEPA world model** + the **PRISM goal-conditioned action prior** + **self-contained PRISM-MPPI inference code**.

## ⚠️ Status: read first

This is a **research artifact / deployment hand-off package, NOT a validated policy.**

- Trained on 201 teleop demos (`Xia-2004/red_cube`, 42,165 frames, 5-DoF).
- The world model **converges cleanly and its forward model is accurate** (rollout pred/id ≈ 0.25 < 0.5), **but its MPPI cost surface is weak/flat** on this small real-robot, small-action dataset (CV ≈ **0.14** ≪ 0.30 "discriminative" threshold). Consequence: the planner's distinctive **cost-rescoring is dormant**, so in practice **PRISM ≈ a goal-conditioned BC prior**. In offline A/B it produces ~31% more expert-like actions than vanilla LeWM-MPPI (which wanders), but adds **no measurable goal-progress** in the world model's own latent metric (paired t-test p = 0.57).
- **Has never been run on the real robot.** Treat as a starting point; add workspace/velocity safety limits and validate before any hardware run.
- Full analysis and how these numbers were obtained: project doc `docs/30_red_cube_cv_investigation_and_prism.md`.

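To make the cost-surface check in the status notes concrete, here is a minimal sketch. It assumes "CV" means the coefficient of variation (standard deviation divided by the absolute mean) of the costs the planner assigns to one batch of sampled candidates; the exact definition used for the reported 0.14 is in the project doc. The names `cost_cv` and `is_discriminative` are illustrative and not part of the bundled code.

```python
import numpy as np

def cost_cv(costs, eps=1e-8):
    """Coefficient of variation (std / |mean|) of a 1-D array of candidate costs."""
    costs = np.asarray(costs, dtype=np.float64)
    return float(costs.std() / (abs(costs.mean()) + eps))

def is_discriminative(costs, threshold=0.30):
    """True if the costs vary enough to rank candidates (threshold taken from the card)."""
    return cost_cv(costs) >= threshold

# Synthetic examples only; these are not measurements from the checkpoint.
rng = np.random.default_rng(0)
flat_costs = 1.0 + 0.05 * rng.standard_normal(128)   # nearly constant costs
spread_costs = np.linspace(0.2, 1.8, 128)            # clearly varying costs

print("flat:  ", round(cost_cv(flat_costs), 3), is_discriminative(flat_costs))
print("spread:", round(cost_cv(spread_costs), 3), is_discriminative(spread_costs))
```

When the costs of all candidates are nearly equal, cost-based reweighting has little to act on, which is consistent with the observation that PRISM behaves like its prior on this task.
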
## Contents

| file | description |
|---|---|
| `lewm_red_cube_epoch_100_object.ckpt` | LeWM world model: pickled JEPA with ViT-tiny encoder + AR transformer predictor + action encoder (~18M params) |
| `prior_head_red_cube.pt` | PRISM goal-conditioned action prior: `state_dict` + `config` + action `StandardScaler` (mean/scale) |
| `arx_inference_demo.py` | self-contained `PrismMPPIInference` (PoG-fused PRISM-MPPI; `use_prism=False` → vanilla LeWM-MPPI) |
| `jepa.py`, `module.py`, `prior_head.py` | model classes required to unpickle the ckpt and run the prior |

## Observation / action space

- **Observation:** single top-down RGB frame, **224×224×3 uint8** (RealSense `camera_third`).
- **Goal:** an RGB goal image, same format (the prior + cost are conditioned on it).
- **Action:** 5-DoF delta end-effector **`[dx, dy, dz, dyaw, d_gripper]`**, raw units, one per control tick. `plan()` returns one plan-step = `A_block = 5` ticks → shape **`(5, 5)`**.

## Dependencies

`torch`, `numpy`, `einops`, and `transformers` (the encoder inside the ckpt is a HuggingFace ViT, needed at unpickle time). The three bundled `.py` files must be importable from the working directory. (If unpickling complains about a missing class, also `pip install stable-pretraining`.)

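Since the planner expects `(224, 224, 3)` uint8 frames and emits raw per-tick deltas, a thin guard layer around it is one simple way to start on the safety limits the status notes call for. The sketch below is an assumption-laden illustration: `check_frame`, `clamp_actions`, and every value in `MAX_DELTA` are hypothetical placeholders, not measured ARX-X5 limits and not part of the bundled code.

```python
import numpy as np

# Hypothetical per-tick magnitude limits for [dx, dy, dz, dyaw, d_gripper],
# in the dataset's raw action units. Placeholder values: tune for your setup.
MAX_DELTA = np.array([0.01, 0.01, 0.01, 0.05, 0.02])

def check_frame(frame):
    """Raise ValueError unless frame is a (224, 224, 3) uint8 array; return it unchanged."""
    frame = np.asarray(frame)
    if frame.shape != (224, 224, 3) or frame.dtype != np.uint8:
        raise ValueError(f"expected (224, 224, 3) uint8, got {frame.shape} {frame.dtype}")
    return frame

def clamp_actions(actions, max_delta=MAX_DELTA):
    """Clip each tick's delta to [-max_delta, +max_delta]; reject malformed or non-finite plans."""
    actions = np.asarray(actions, dtype=np.float64)
    if actions.ndim != 2 or actions.shape[1] != max_delta.shape[0]:
        raise ValueError(f"expected shape (ticks, {max_delta.shape[0]}), got {actions.shape}")
    if not np.all(np.isfinite(actions)):
        raise ValueError("plan contains NaN or inf")
    return np.clip(actions, -max_delta, max_delta)

# Usage with a fake (5, 5) plan standing in for the planner output.
fake_plan = np.full((5, 5), 1.0)
safe_plan = clamp_actions(fake_plan)
```

A per-tick clamp bounds step size only; workspace limits (absolute end-effector position) need the robot's current pose and are not covered here.
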
## Deploy: receding-horizon control loop

```python
from arx_inference_demo import PrismMPPIInference

planner = PrismMPPIInference(
    lewm_ckpt  = "lewm_red_cube_epoch_100_object.ckpt",
    prior_ckpt = "prior_head_red_cube.pt",
    use_prism  = True,    # True = PRISM (prior ⊗ MPPI via PoG fusion); False = vanilla LeWM-MPPI
    device     = "cuda",
)

# `load_goal_image`, `camera`, `robot`, and `done` are placeholders for your own I/O and stop condition.
goal_img = load_goal_image()                 # (224,224,3) uint8: the task goal image
while not done:
    obs = camera.read()                      # (224,224,3) uint8, top-down camera_third view
    actions = planner.plan(obs, goal_img)    # (5, 5) raw [dx,dy,dz,dyaw,d_gripper]
    for a in actions:                        # receding horizon: execute the block, then replan
        robot.execute(a)                     # (or execute fewer than 5 and replan more often)
```

`plan()` runs one full PRISM-MPPI optimization and returns the first `A_block = 5` env-step actions of the optimized plan, in **raw action units** (already de-normalized).

## Key hyperparameters (`PrismMPPIInference` constructor)

| arg | default | meaning |
|---|---|---|
| `H` | 5 | planning horizon (plan-steps) |
| `A_block` | 5 | env-steps (ticks) per plan-step ("frameskip") |
| `K` | 128 | MPPI samples per iteration |
| `n_iters` | 30 | MPPI refinement iterations |
| `var_scale` | 1.0 | initial planner sampling std |
| `prior_sigma_scale` | 2.0 | multiplier on the prior σ before PoG fusion (PRISM only) |
| `temperature` | 0.5 | MPPI softmax temperature |
| `history_size` | 3 | LeWM history-window length (**must match training**) |

`H`, `A_block`, `A_raw`, and `history_size` must match the checkpoints; the constructor asserts that the prior head's config agrees. Change them only if you retrain.

## PRISM vs vanilla (A/B)

Build a second planner with `use_prism=False` for a baseline (plain LeWM-MPPI, no prior, same encoder/predictor/MPPI loop). On this task PRISM produces more expert-like actions; vanilla tends to wander because the cost surface is flat.

## Provenance

Data: [`Xia-2004/red_cube`](https://huggingface.co/datasets/Xia-2004/red_cube) (ARX-X5 left-arm teleop).
Sibling of `Xia-2004/arx-left-cube`. The world-model architecture is identical to the sim LeWM (ViT-tiny, embed_dim 192, predictor depth 6 / heads 16). Part of the PRISM-JEPA project (sister of Newt-PRISM, CoRL 2026).
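
## Illustration: PoG fusion and MPPI weighting

For readers unfamiliar with the two operations named in the hyperparameter table, the sketch below shows their textbook forms: a product of two diagonal Gaussians (precisions add, means are precision-weighted) and MPPI's softmax weighting of sampled costs. How `arx_inference_demo.py` applies them internally is not described in this card, so treat this as an assumption about the general technique; `pog_fuse` and `mppi_weights` are illustrative names.

```python
import numpy as np

def pog_fuse(mu_plan, sigma_plan, mu_prior, sigma_prior, prior_sigma_scale=2.0):
    """Product of two diagonal Gaussians; the prior std is widened by prior_sigma_scale first."""
    mu_plan = np.asarray(mu_plan, dtype=np.float64)
    mu_prior = np.asarray(mu_prior, dtype=np.float64)
    sigma_plan = np.asarray(sigma_plan, dtype=np.float64)
    sigma_prior = prior_sigma_scale * np.asarray(sigma_prior, dtype=np.float64)
    prec_plan = 1.0 / np.square(sigma_plan)      # precision = 1 / variance
    prec_prior = 1.0 / np.square(sigma_prior)
    var = 1.0 / (prec_plan + prec_prior)         # precisions add under a product
    mu = var * (prec_plan * mu_plan + prec_prior * mu_prior)
    return mu, np.sqrt(var)

def mppi_weights(costs, temperature=0.5):
    """Softmax over negative costs; lower cost gets higher weight."""
    costs = np.asarray(costs, dtype=np.float64)
    logits = -(costs - costs.min()) / temperature   # shift by the min for numerical stability
    w = np.exp(logits)
    return w / w.sum()

# Equal effective stds (1.0 and 2.0 * 0.5), so the fused mean is the midpoint.
mu, sigma = pog_fuse(mu_plan=0.0, sigma_plan=1.0, mu_prior=1.0, sigma_prior=0.5)
print(mu, sigma)
print(mppi_weights([1.00, 1.01, 0.99]))   # nearly flat costs give nearly uniform weights
```

With nearly flat costs the weights are close to uniform, so the weighted update carries little cost information and the fused sampling distribution does most of the work, which matches the behaviour reported in the status section.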