---
license: apache-2.0
library_name: pytorch
tags:
- robotics
- world-model
- jepa
- lewm
- prism
- arx
- visuomotor
---

# PRISM-JEPA · red_cube (ARX-X5): LeWM world model + PRISM action prior

Complete deployable stack for goal-conditioned visuomotor planning on the ARX-X5
**"red cube"** task: the **LeWM JEPA world model** + the **PRISM goal-conditioned action
prior** + **self-contained PRISM-MPPI inference code**.

## ⚠️ Status: read first

This is a **research artifact / deployment hand-off package, NOT a validated policy.**

- Trained on 201 teleop demos (`Xia-2004/red_cube`, 42,165 frames, 5-DoF).
- The world model **converges cleanly and its forward model is accurate** (rollout pred/id
  ≈ 0.25 < 0.5), **but its MPPI cost surface is weak/flat** on this small real-robot,
  small-action dataset (CV ≈ **0.14** ≪ 0.30 "discriminative" threshold). Consequence: the
  planner's distinctive **cost-rescoring is dormant**, so in practice **PRISM ≈ a
  goal-conditioned BC prior**. In offline A/B it produces ~31% more expert-like actions
  than vanilla LeWM-MPPI (which wanders), but adds **no measurable goal-progress** in the
  world model's own latent metric (paired t-test p = 0.57).
- **This stack has never been run on the real robot.** Treat it as a starting point; add
  workspace/velocity safety limits and validate before any hardware run.
- Full analysis & how these numbers were obtained: project doc
  `docs/30_red_cube_cv_investigation_and_prism.md`.
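
The workspace/velocity limits mentioned above are not part of this package. A minimal sketch of what such a guard could look like is given here, assuming the translation deltas and the current end-effector position share one frame and one unit. `SafetyLimits`, `safe_action`, and every number below are hypothetical placeholders, not values from the project.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class SafetyLimits:
    """Hypothetical limits for illustration only; replace with values for your cell."""
    # Per-tick magnitude caps for [dx, dy, dz, dyaw, d_gripper].
    max_delta: Tuple[float, float, float, float, float] = (0.01, 0.01, 0.01, 0.05, 1.0)
    # Axis-aligned workspace box for the end-effector position.
    xyz_min: Tuple[float, float, float] = (-0.30, -0.30, 0.00)
    xyz_max: Tuple[float, float, float] = (0.30, 0.30, 0.40)


def _clamp(value: float, lo: float, hi: float) -> float:
    return max(lo, min(hi, value))


def safe_action(action: Sequence[float], ee_xyz: Sequence[float],
                limits: SafetyLimits = SafetyLimits()) -> List[float]:
    """Clip one raw [dx, dy, dz, dyaw, d_gripper] action before execution."""
    if len(action) != 5:
        raise ValueError(f"expected a 5-DoF action, got {len(action)} values")
    # 1) Velocity guard: cap the magnitude of every component.
    out = [_clamp(float(a), -m, m) for a, m in zip(action, limits.max_delta)]
    # 2) Workspace guard: shrink the translation so the target stays in the box,
    #    then re-apply the cap in case the arm is already outside the box.
    for i in range(3):
        target = _clamp(ee_xyz[i] + out[i], limits.xyz_min[i], limits.xyz_max[i])
        out[i] = _clamp(target - ee_xyz[i], -limits.max_delta[i], limits.max_delta[i])
    return out
```

A guard like this would sit between `planner.plan()` and the robot driver, applied to each of the five actions in the returned block.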
|
|
## Contents

| file | description |
|---|---|
| `lewm_red_cube_epoch_100_object.ckpt` | LeWM world model: pickled JEPA (ViT-tiny encoder + AR transformer predictor + action encoder, ~18M params) |
| `prior_head_red_cube.pt` | PRISM goal-conditioned action prior: `state_dict` + `config` + action `StandardScaler` (mean/scale) |
| `arx_inference_demo.py` | self-contained `PrismMPPIInference` (PoG-fused PRISM-MPPI; `use_prism=False` → vanilla LeWM-MPPI) |
| `jepa.py`, `module.py`, `prior_head.py` | model classes required to unpickle the ckpt and run the prior |
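
The prior checkpoint bundles the action `StandardScaler` statistics (mean/scale). As a rough illustration of what those statistics are for, the sketch below shows standardization and its inverse, `raw = normalized * scale + mean`, on plain Python lists. The function names and the example numbers are made up for illustration; the key names and layout inside `prior_head_red_cube.pt` are not shown here.

```python
from typing import List, Sequence


def normalize(raw: Sequence[float], mean: Sequence[float],
              scale: Sequence[float]) -> List[float]:
    """Standardize one action: (raw - mean) / scale, per dimension."""
    return [(r - m) / s for r, m, s in zip(raw, mean, scale)]


def denormalize(norm: Sequence[float], mean: Sequence[float],
                scale: Sequence[float]) -> List[float]:
    """Invert the standardization: norm * scale + mean, per dimension."""
    return [n * s + m for n, m, s in zip(norm, mean, scale)]


# Made-up statistics for a 5-DoF action [dx, dy, dz, dyaw, d_gripper].
example_mean = [0.0, 0.001, -0.002, 0.0, 0.1]
example_scale = [0.004, 0.004, 0.003, 0.02, 0.5]
example_raw = [0.002, -0.001, 0.0, 0.01, 0.6]

round_trip = denormalize(normalize(example_raw, example_mean, example_scale),
                         example_mean, example_scale)
```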
|
|
## Observation / action space

- **Observation:** single top-down RGB frame, **224×224×3 uint8** (RealSense `camera_third`).
- **Goal:** an RGB goal image, same format (the prior + cost are conditioned on it).
- **Action:** 5-DoF delta end-effector **`[dx, dy, dz, dyaw, d_gripper]`**, raw units, one per
  control tick. `plan()` returns one plan-step = `A_block = 5` ticks → shape **`(5, 5)`**.
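
The shapes above can be checked cheaply at the boundary of the planner. The following is a minimal sketch, assuming the frame exposes `.shape` and `.dtype` the way a NumPy array does; `check_frame`, `unpack_block`, and `DeltaAction` are hypothetical helper names, not part of the bundled code.

```python
from typing import Iterable, List, NamedTuple

OBS_SHAPE = (224, 224, 3)   # H x W x RGB, as stated above
BLOCK_LEN = 5               # A_block ticks per plan-step


class DeltaAction(NamedTuple):
    dx: float
    dy: float
    dz: float
    dyaw: float
    d_gripper: float


def check_frame(frame) -> None:
    """Raise ValueError unless `frame` looks like a 224x224x3 uint8 image.

    Duck-typed: only `.shape` and `.dtype` are inspected.
    """
    shape = tuple(frame.shape)
    if shape != OBS_SHAPE:
        raise ValueError(f"expected shape {OBS_SHAPE}, got {shape}")
    if str(frame.dtype) != "uint8":
        raise ValueError(f"expected dtype uint8, got {frame.dtype}")


def unpack_block(block: Iterable[Iterable[float]]) -> List[DeltaAction]:
    """Turn a (5, 5) action block into five named per-tick actions."""
    rows = [DeltaAction(*(float(v) for v in row)) for row in block]
    if len(rows) != BLOCK_LEN:
        raise ValueError(f"expected {BLOCK_LEN} actions, got {len(rows)}")
    return rows
```

Naming the components makes it harder to mix up `dyaw` and `d_gripper` when wiring the output into a robot driver.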

## Dependencies

`torch`, `numpy`, `einops`, and `transformers` (the encoder inside the ckpt is a HuggingFace
ViT, needed at unpickle time). The three bundled `.py` files must be importable from the
working directory. (If unpickling complains about a missing class, also run
`pip install stable-pretraining`.)
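
Because a missing module only surfaces at unpickle time, it can help to check importability up front. A small sketch using the standard-library `importlib.util.find_spec`; the helper name `missing_modules` is hypothetical, and the module lists simply restate the requirements above.

```python
import importlib.util
from typing import Iterable, List

REQUIRED = ("torch", "numpy", "einops", "transformers")
BUNDLED = ("jepa", "module", "prior_head")  # the three bundled .py files


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found on sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED + BUNDLED)
    if missing:
        print("Not importable from this directory:", ", ".join(missing))
    else:
        print("All required modules are importable.")
```

`find_spec` only locates a module without importing it, so the check is quick even for heavy packages such as `torch`.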

## Deploy: receding-horizon control loop

```python
from arx_inference_demo import PrismMPPIInference

planner = PrismMPPIInference(
    lewm_ckpt  = "lewm_red_cube_epoch_100_object.ckpt",
    prior_ckpt = "prior_head_red_cube.pt",
    use_prism  = True,   # True = PRISM (prior fused with MPPI via PoG); False = vanilla LeWM-MPPI
    device     = "cuda",
)

# load_goal_image(), camera, robot and done are placeholders for your own I/O stack.
goal_img = load_goal_image()                # (224,224,3) uint8: the task goal image
while not done:
    obs = camera.read()                     # (224,224,3) uint8, top-down camera_third view
    actions = planner.plan(obs, goal_img)   # (5, 5) raw [dx,dy,dz,dyaw,d_gripper]
    for a in actions:                       # receding horizon: execute the block, then replan
        robot.execute(a)                    # (or execute fewer than 5 and replan more often)
```

`plan()` runs one full PRISM-MPPI optimization and returns the first `A_block = 5` env-step
actions of the optimized plan, in **raw action units** (already de-normalized).
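
Executing fewer than five actions per plan trades compute for faster reaction to new observations. One way the loop could be parameterized is sketched below, with every callable (`plan`, `read_obs`, `execute`, `is_done`) a hypothetical stand-in for the planner, camera, and robot. The `max_replans` bound is an added assumption so the loop always terminates.

```python
from typing import Any, Callable, Iterable, Sequence

Action = Sequence[float]


def run_receding_horizon(plan: Callable[[Any, Any], Iterable[Action]],
                         read_obs: Callable[[], Any],
                         execute: Callable[[Action], None],
                         is_done: Callable[[], bool],
                         goal: Any,
                         n_execute: int = 5,
                         max_replans: int = 100) -> int:
    """Plan, execute the first `n_execute` actions of the block, then replan.

    Returns the number of control ticks executed.
    """
    if not 1 <= n_execute <= 5:
        raise ValueError("n_execute must be between 1 and A_block = 5")
    ticks = 0
    for _ in range(max_replans):
        if is_done():
            break
        block = list(plan(read_obs(), goal))
        for action in block[:n_execute]:
            execute(action)
            ticks += 1
            if is_done():       # stop mid-block as soon as the task ends
                break
    return ticks
```

With the real stack, `plan` would be `planner.plan`, and `n_execute=1` would replan on every tick at roughly five times the planning cost per executed action.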
|
|
## Key hyperparameters (`PrismMPPIInference` constructor)

| arg | default | meaning |
|---|---|---|
| `H` | 5 | planning horizon (plan-steps) |
| `A_block` | 5 | env-steps (ticks) per plan-step ("frameskip") |
| `K` | 128 | MPPI samples per iteration |
| `n_iters` | 30 | MPPI refinement iterations |
| `var_scale` | 1.0 | initial planner sampling std |
| `prior_sigma_scale` | 2.0 | multiplier on the prior σ before PoG fusion (PRISM only) |
| `temperature` | 0.5 | MPPI softmax temperature |
| `history_size` | 3 | LeWM history-window length (**must match training**) |
|
|
`H`, `A_block`, `A_raw`, `history_size` must match the checkpoints: the constructor asserts
that the prior head's config agrees. Change them only if you retrain.
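
To make `temperature` and `prior_sigma_scale` more concrete, here is a one-dimensional sketch of the two textbook operations they usually parameterize: softmax weighting of sampled plans by cost, and a product-of-Gaussians (PoG) fusion of the planner distribution with a widened prior. This is an illustration under those assumptions only; the implementation in `arx_inference_demo.py` may differ (for example in cost normalization or per-dimension handling), and both function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def mppi_weights(costs: Sequence[float], temperature: float = 0.5) -> List[float]:
    """Softmax over negative costs; lower cost means higher weight.

    The minimum cost is subtracted first for numerical stability.
    """
    c_min = min(costs)
    unnorm = [math.exp(-(c - c_min) / temperature) for c in costs]
    total = sum(unnorm)
    return [u / total for u in unnorm]


def pog_fuse(mu_a: float, sigma_a: float,
             mu_b: float, sigma_b: float) -> Tuple[float, float]:
    """Product of two 1-D Gaussians: precision-weighted mean, combined precision."""
    prec_a = 1.0 / sigma_a ** 2
    prec_b = 1.0 / sigma_b ** 2
    var = 1.0 / (prec_a + prec_b)
    mu = var * (prec_a * mu_a + prec_b * mu_b)
    return mu, math.sqrt(var)


# Planner distribution N(0, 1.0) fused with a prior N(1.0, 0.5) whose sigma is
# first widened by prior_sigma_scale = 2.0, giving an effective prior sigma of 1.0.
prior_sigma_scale = 2.0
fused_mu, fused_sigma = pog_fuse(0.0, 1.0, 1.0, 0.5 * prior_sigma_scale)
```

A larger `prior_sigma_scale` lowers the prior's precision, so the fused mean moves toward the planner's own estimate. If each of the `K` samples is rolled out once per iteration, the defaults imply 128 × 30 = 3,840 rollouts per `plan()` call over a lookahead of `H × A_block` = 25 ticks.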
|
|
## PRISM vs vanilla (A/B)

Build a second planner with `use_prism=False` for a baseline (plain LeWM-MPPI, no prior,
same encoder/predictor/MPPI loop). On this task PRISM produces more expert-like actions;
vanilla tends to wander because the cost surface is flat.
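
One way to quantify a "flat" cost surface is the spread of costs across the sampled plans. The sketch below assumes that the CV quoted in the status section is the coefficient of variation (standard deviation divided by mean) of those costs; the exact definition used in the project doc may differ, and `cost_cv` is a hypothetical helper.

```python
import statistics
from typing import Sequence

DISCRIMINATIVE_CV = 0.30  # threshold quoted in the status section


def cost_cv(costs: Sequence[float]) -> float:
    """Coefficient of variation of a batch of plan costs (population std / |mean|)."""
    mean = statistics.fmean(costs)
    if mean == 0:
        raise ValueError("mean cost is zero; CV is undefined")
    return statistics.pstdev(costs) / abs(mean)


def is_discriminative(costs: Sequence[float]) -> bool:
    """True when the costs vary enough for rescoring to separate good plans from bad."""
    return cost_cv(costs) >= DISCRIMINATIVE_CV
```

When the costs of all `K` samples are nearly equal, the MPPI softmax weights become close to uniform and the update is barely steered by the cost at all.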
|
|
## Provenance

Data: [`Xia-2004/red_cube`](https://huggingface.co/datasets/Xia-2004/red_cube) (ARX-X5
left-arm teleop). Sibling of `Xia-2004/arx-left-cube`. World-model architecture is identical
to the sim LeWM (ViT-tiny, embed_dim 192, predictor depth 6 / heads 16). Part of the
PRISM-JEPA project (sister of Newt-PRISM, CoRL 2026).