---
license: apache-2.0
library_name: pytorch
tags:
- robotics
- world-model
- jepa
- lewm
- prism
- arx
- visuomotor
---
# PRISM-JEPA · red_cube (ARX-X5): LeWM world model + PRISM action prior
Complete deployable stack for goal-conditioned visuomotor planning on the ARX-X5
**"red cube"** task: the **LeWM JEPA world model** + the **PRISM goal-conditioned action
prior** + **self-contained PRISM-MPPI inference code**.
## ⚠️ Status: read first
This is a **research artifact / deployment hand-off package, NOT a validated policy.**
- Trained on 201 teleop demos (`Xia-2004/red_cube`, 42,165 frames, 5-DoF).
- The world model **converges cleanly and its forward model is accurate** (rollout pred/id
≈ 0.25 < 0.5), **but its MPPI cost surface is weak/flat** on this small-real-robot,
small-action data (CV ≈ **0.14** ≪ 0.30 "discriminative" threshold). Consequence: the
planner's distinctive **cost-rescoring is dormant**, so in practice **PRISM ≈ a
goal-conditioned BC prior**: in offline A/B it produces ~31% more expert-like actions
than vanilla LeWM-MPPI (which wanders), but adds **no measurable goal-progress** in the
world model's own latent metric (paired t-test p = 0.57).
- **This stack has never been run on the real robot.** Treat it as a starting point; add
workspace/velocity safety limits and validate before any hardware run.
- Full analysis & how these numbers were obtained: project doc
`docs/30_red_cube_cv_investigation_and_prism.md`.
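
To make the 0.30 threshold above concrete, here is a minimal sketch of a flatness check. It assumes that "CV" means the coefficient of variation (standard deviation divided by mean) of the costs assigned to one batch of sampled candidates. That reading, the function name `cost_cv`, and the example numbers are illustrative assumptions; the project doc remains the authority on how the reported value was computed.

```python
import numpy as np

def cost_cv(costs):
    """Coefficient of variation (std / |mean|) of a batch of candidate costs.

    ASSUMPTION: this is one plausible definition of the "CV" quoted above,
    not necessarily the one used in the project's analysis.
    """
    costs = np.asarray(costs, dtype=np.float64)
    mean = costs.mean()
    if mean == 0.0:
        raise ValueError("mean cost is zero; CV is undefined")
    return float(costs.std() / abs(mean))

# Illustrative costs only (not measured data).
flat_costs = [1.00, 1.05, 0.95, 1.02, 0.98]    # candidates are hard to tell apart
spread_costs = [0.2, 1.0, 1.8, 0.5, 1.5]       # candidates are clearly separated

for name, costs in (("flat", flat_costs), ("spread", spread_costs)):
    cv = cost_cv(costs)
    verdict = "discriminative" if cv >= 0.30 else "weak/flat"
    print(f"{name}: CV = {cv:.3f} ({verdict})")
```

A low CV means the sampled candidates all score nearly the same, so cost-based reweighting has little to work with.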
## Contents
| file | description |
|---|---|
| `lewm_red_cube_epoch_100_object.ckpt` | LeWM world model: pickled JEPA with ViT-tiny encoder + AR transformer predictor + action encoder (~18M params) |
| `prior_head_red_cube.pt` | PRISM goal-conditioned action prior: `state_dict` + `config` + action `StandardScaler` (mean/scale) |
| `arx_inference_demo.py` | self-contained `PrismMPPIInference` (PoG-fused PRISM-MPPI; `use_prism=False` → vanilla LeWM-MPPI) |
| `jepa.py`, `module.py`, `prior_head.py` | model classes required to unpickle the ckpt and run the prior |
## Observation / action space
- **Observation:** single top-down RGB frame, **224×224×3 uint8** (RealSense `camera_third`).
- **Goal:** an RGB goal image, same format (the prior + cost are conditioned on it).
- **Action:** 5-DoF delta end-effector **`[dx, dy, dz, dyaw, d_gripper]`**, raw units, one per
control tick. `plan()` returns one plan-step = `A_block = 5` ticks → shape **`(5, 5)`**.
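
The shapes and dtypes above are easy to get wrong at the camera or robot boundary, so a small guard can help. The following is a sketch only: `check_image` and `check_actions` are hypothetical helpers that are not part of the bundle, and they assume the inputs are NumPy-compatible arrays on the CPU.

```python
import numpy as np

IMG_SHAPE = (224, 224, 3)   # height, width, RGB channels, as listed above
ACTION_SHAPE = (5, 5)       # (A_block ticks, 5-DoF action)

def check_image(img, name="obs"):
    """Return img as an array, or raise if it is not a 224x224x3 uint8 frame."""
    img = np.asarray(img)
    if img.shape != IMG_SHAPE or img.dtype != np.uint8:
        raise ValueError(
            f"{name}: expected shape {IMG_SHAPE} uint8, got {img.shape} {img.dtype}"
        )
    return img

def check_actions(actions):
    """Return actions as float32, or raise on a wrong shape or non-finite values."""
    actions = np.asarray(actions, dtype=np.float32)
    if actions.shape != ACTION_SHAPE:
        raise ValueError(f"actions: expected shape {ACTION_SHAPE}, got {actions.shape}")
    if not np.isfinite(actions).all():
        raise ValueError("actions contain NaN or inf")
    return actions
```

Calling `check_image` on both the observation and the goal image before `plan()`, and `check_actions` on its result, catches format mismatches before anything reaches the arm.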
## Dependencies
`torch`, `numpy`, `einops`, and `transformers` (the encoder inside the ckpt is a HuggingFace
ViT, needed at unpickle time). The three bundled `.py` files must be importable from the
working directory. (If unpickling complains about a missing class, also `pip install
stable-pretraining`.)
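
Before unpickling, it can be useful to confirm that everything listed above is importable. This is a minimal pre-flight sketch using only the standard library; `missing_modules` is a hypothetical helper, and the two tuples simply restate the names from this section.

```python
import importlib.util

REQUIRED = ("torch", "numpy", "einops", "transformers")   # pip packages listed above
BUNDLED = ("jepa", "module", "prior_head")                # .py files shipped with the bundle

def missing_modules(names):
    """Return the subset of top-level module names that cannot be located."""
    return [name for name in names if importlib.util.find_spec(name) is None]

if __name__ == "__main__":
    missing = missing_modules(REQUIRED + BUNDLED)
    if missing:
        print("Not importable from this working directory:", ", ".join(missing))
    else:
        print("All required modules were found.")
```

Run it from the directory that holds the bundled `.py` files, since they are resolved relative to the working directory.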
## Deploy β€” receding-horizon control loop
```python
from arx_inference_demo import PrismMPPIInference

planner = PrismMPPIInference(
    lewm_ckpt = "lewm_red_cube_epoch_100_object.ckpt",
    prior_ckpt = "prior_head_red_cube.pt",
    use_prism = True,  # True = PRISM (prior ⊗ MPPI via PoG fusion); False = vanilla LeWM-MPPI
    device = "cuda",
)

goal_img = load_goal_image()  # (224,224,3) uint8: the task goal image
while not done:
    obs = camera.read()  # (224,224,3) uint8, top-down camera_third view
    actions = planner.plan(obs, goal_img)  # (5, 5) raw [dx,dy,dz,dyaw,d_gripper]
    for a in actions:  # receding horizon: execute the block, then replan
        robot.execute(a)  # (or execute fewer than 5 and replan more often)
```
`plan()` runs one full PRISM-MPPI optimization and returns the first `A_block = 5` env-step
actions of the optimized plan, in **raw action units** (already de-normalized).
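
Because the returned actions are raw deltas and the status notes call for safety limits before any hardware run, one simple option is to clamp each tick's delta before execution. The sketch below is an assumption-laden illustration: `clamp_actions` is a hypothetical helper, and the values in `MAX_DELTA` are placeholders that must be replaced with limits derived from your own robot and workspace.

```python
import numpy as np

# HYPOTHETICAL per-tick magnitude limits for [dx, dy, dz, dyaw, d_gripper].
# Placeholder numbers only; the bundle does not specify units or safe ranges.
MAX_DELTA = np.array([0.01, 0.01, 0.01, 0.05, 0.10])

def clamp_actions(actions, max_delta=MAX_DELTA):
    """Clip every action component to [-max_delta, +max_delta], per dimension."""
    actions = np.asarray(actions, dtype=np.float64)
    max_delta = np.asarray(max_delta, dtype=np.float64)
    if actions.ndim != 2 or actions.shape[1] != max_delta.shape[0]:
        raise ValueError(
            f"expected shape (ticks, {max_delta.shape[0]}), got {actions.shape}"
        )
    return np.clip(actions, -max_delta, max_delta)
```

In the loop above this would wrap the planner output, for example `for a in clamp_actions(actions): robot.execute(a)`. Clamping per-tick deltas bounds velocity only; an absolute workspace limit would need the robot's current pose as well.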
## Key hyperparameters (`PrismMPPIInference` constructor)
| arg | default | meaning |
|---|---|---|
| `H` | 5 | planning horizon (plan-steps) |
| `A_block` | 5 | env-steps (ticks) per plan-step ("frameskip") |
| `K` | 128 | MPPI samples per iteration |
| `n_iters` | 30 | MPPI refinement iterations |
| `var_scale` | 1.0 | initial planner sampling std |
| `prior_sigma_scale` | 2.0 | multiplier on the prior σ before PoG fusion (PRISM only) |
| `temperature` | 0.5 | MPPI softmax temperature |
| `history_size` | 3 | LeWM history-window length (**must match training**) |
`H`, `A_block`, `A_raw`, `history_size` must match the checkpoints; the constructor asserts
that the prior head's config agrees. Change them only if you retrain.
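
To illustrate what `prior_sigma_scale` does, here is a sketch of the textbook product-of-Gaussians (PoG) fusion of two diagonal Gaussians, with the prior's σ widened before fusing. It assumes independent per-dimension Gaussians; `pog_fuse` is a hypothetical function, and the fusion code in `arx_inference_demo.py` may differ in detail.

```python
import numpy as np

def pog_fuse(mu_plan, sigma_plan, mu_prior, sigma_prior, prior_sigma_scale=2.0):
    """Fuse planner and prior Gaussians by precision weighting.

    ASSUMPTION: diagonal Gaussians; the prior std is multiplied by
    prior_sigma_scale before fusion, as the table above describes.
    """
    mu_plan = np.asarray(mu_plan, dtype=np.float64)
    mu_prior = np.asarray(mu_prior, dtype=np.float64)
    sigma_plan = np.asarray(sigma_plan, dtype=np.float64)
    sigma_prior = prior_sigma_scale * np.asarray(sigma_prior, dtype=np.float64)

    prec_plan = 1.0 / np.square(sigma_plan)     # precision = 1 / variance
    prec_prior = 1.0 / np.square(sigma_prior)

    var = 1.0 / (prec_plan + prec_prior)        # fused variance
    mu = var * (prec_plan * mu_plan + prec_prior * mu_prior)
    return mu, np.sqrt(var)
```

Under this formula a larger `prior_sigma_scale` lowers the prior's precision, so the fused mean stays closer to the planner's own distribution; a smaller value pulls it toward the prior.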
## PRISM vs vanilla (A/B)
Build a second planner with `use_prism=False` for a baseline (plain LeWM-MPPI, no prior,
same encoder/predictor/MPPI loop). On this task PRISM produces more expert-like actions;
vanilla tends to wander because the cost surface is flat.
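
The link between a flat cost surface and wandering can be seen in the standard MPPI weighting step, where each sampled candidate gets a softmax weight from its cost. The sketch below uses the generic exponential weighting with the default `temperature = 0.5`; `mppi_weights` and `effective_sample_size` are hypothetical helpers, the costs are illustrative, and the bundle's planner may normalize costs differently.

```python
import numpy as np

def mppi_weights(costs, temperature=0.5):
    """Softmax weights over candidates: lower cost means higher weight."""
    costs = np.asarray(costs, dtype=np.float64)
    logits = -(costs - costs.min()) / temperature   # shift by the min for stability
    weights = np.exp(logits)
    return weights / weights.sum()

def effective_sample_size(weights):
    """1 / sum(w^2): equals the number of candidates when weights are uniform."""
    weights = np.asarray(weights, dtype=np.float64)
    return float(1.0 / np.sum(np.square(weights)))

# Illustrative costs only (not measured data).
flat = mppi_weights([1.00, 1.01, 0.99, 1.00])
peaked = mppi_weights([0.0, 10.0, 10.0, 10.0])
print("flat ESS:  ", round(effective_sample_size(flat), 2))
print("peaked ESS:", round(effective_sample_size(peaked), 2))
```

With nearly equal costs the weights are close to uniform, so the update is roughly an average of random samples and carries little goal signal. That is the regime in which a prior, rather than the cost, ends up steering the plan.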
## Provenance
Data: [`Xia-2004/red_cube`](https://huggingface.co/datasets/Xia-2004/red_cube) (ARX-X5
left-arm teleop). Sibling of `Xia-2004/arx-left-cube`. World-model architecture is identical
to the sim LeWM (ViT-tiny, embed_dim 192, predictor depth 6 / heads 16). Part of the
PRISM-JEPA project (sister of Newt-PRISM, CoRL 2026).