---
license: mit
tags:
- robotics
- ogbench
- jepa
- world-model
- prism-mppi
- model-predictive-control
- manipulation
library_name: pytorch
pipeline_tag: robotics
---

# PRISM-JEPA: OGBench Cube (sim)

JEPA world model + PRISM action prior for the **OGBench `cube-single`** manipulation task. These are the exact weights used to produce the headline **PRISM-MPPI** number in the paper.

**Project page:** [yuhaiw.github.io/PRISM_web](https://yuhaiw.github.io/PRISM_web/)

**Sister repo for PushT:** [`YuhaiW/prism-jepa-pusht`](https://huggingface.co/YuhaiW/prism-jepa-pusht)

**Code:** [`YuhaiW/prism-jepa`](https://github.com/YuhaiW/prism-jepa)

---

## Headline result (mean ± std over 3 seeds {0, 1, 42}, K = 128)

| Metric | Vanilla MPPI | BC-only | **PRISM-MPPI (s = 1)** |
|:------------:|:------------:|:-------:|:----------------------:|
| **Cube success rate (%)** | 44.0 | 66.0 | **79.3 ± 6.1** |

`s = 1` is the only PRISM-specific hyperparameter; see paper §4.4 for the sigma-scale sweep.

---

## Bundle

| File | Size | Role |
|------|------|------|
| `lewm_object.ckpt` | ~72 MB | Pickled LeWM (frozen JEPA encoder + AR predictor) |
| `prior_head_cube.pt` | ~2 MB | PRISM prior head (3-layer MLP, β-NLL β=0.5, σ-floor 0.05) |
| `jepa.py`, `module.py` | ~10 KB | Model classes (needed to unpickle the LeWM ckpt) |
| `prior_head.py` | ~3 KB | `PriorHead` class |
| `requirements.txt` | <1 KB | Pinned runtime dependencies |
| `README.md` | n/a | This file |

---

## Reproduce the paper result

```bash
# 1. Clone the eval/training code
git clone https://github.com/YuhaiW/prism-jepa.git
cd prism-jepa
uv venv --python=3.10 && source .venv/bin/activate
# Quote the extra so shells such as zsh do not glob the brackets
uv pip install "stable-worldmodel[train]"
uv pip install opencv-python pygame mujoco pymunk scikit-image hdf5plugin
export STABLEWM_HOME="$PWD/.stable-wm"

# 2. Pull the weights from this repo
# (uv venv does not seed pip by default, so install through uv)
uv pip install huggingface_hub
hf download YuhaiW/prism-jepa-cube --local-dir ./hf_cube
mkdir -p "$STABLEWM_HOME/cube"
mv hf_cube/lewm_object.ckpt "$STABLEWM_HOME/cube/"
mv hf_cube/prior_head_cube.pt .

# 3. Run PRISM-MPPI (paper main result)
python eval_prism_head.py --config-name=cube policy=cube/lewm solver=mppi \
    +head.injection_mode=pog +head.sigma_scale=1.0 \
    +head.ckpt=prior_head_cube.pt \
    solver.num_samples=128 eval.num_eval=50 seed=0
# repeat with seed=1, seed=42 to reproduce the mean (~79%)
```

The eval also needs the OGBench cube dataset (used for normalization stats at eval time). See the upstream LeWM collection [`quentinll/lewm`](https://huggingface.co/collections/quentinll/lewm) and drop `cube_single_expert.h5` under `$STABLEWM_HOME/ogbench/`.

### Vanilla MPPI baseline (no prior)

```bash
python eval_prism_head.py --config-name=cube policy=cube/lewm solver=mppi \
    +head.injection_mode=none solver.num_samples=128 eval.num_eval=50 seed=0
```

---

## Training recipe

The world model was trained from scratch on OGBench `cube-single-expert` following the upstream LeWM recipe (`python train.py data=cube`). The prior head was then trained with the world model frozen:

```bash
python train_prior_head.py task=cube epochs=50 batch_size=512
```

The head is trained with a β-NLL loss (β = 0.5), σ floored at 0.05, AdamW, and a cosine LR schedule. Training takes ~30 min on a single RTX 5090.

---

## How PRISM-MPPI works (one paragraph)

A standard MPPI planner samples action sequences from `N(0, σ_π²)` and scores them by `‖ẑ_{t+H} − z_g‖²` in JEPA latent space. PRISM trains a lightweight prior head `g_φ(z_t, z_g) → (μ_p, σ_p)` from offline demonstrations, then fuses it with the planner's default sampling distribution at the initial step via the closed-form Product-of-Gaussians:

```
σ_init² = ((s·σ_p)⁻² + σ_π⁻²)⁻¹
μ_init  = σ_init² · μ_p / (s·σ_p)²
```

The MPPI cost stays purely visual (embedding MSE to goal): no reward, no Q-shortcut. PRISM only reshapes where samples are drawn from, not how they are scored, which is why the eval-time goal mismatch that hurts pure BC-style policies does **not** hurt PRISM-MPPI.

---

## Citation

_BibTeX TBA; paper under review._

## License

MIT. World-model code vendored from [LeWM](https://github.com/lucas-maes/le-wm) retains its upstream MIT copyright.
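
## Sketch: Product-of-Gaussians fusion

The fusion formula in "How PRISM-MPPI works" can be illustrated in a few lines of NumPy. This is a minimal sketch under stated assumptions, not the repository's implementation: the function names `pog_fuse` and `sample_initial_actions`, the array shapes, and the elementwise (diagonal Gaussian) treatment are hypothetical and chosen only for illustration. It assumes the planner default is zero-mean, as in `N(0, σ_π²)`.

```python
import numpy as np


def pog_fuse(mu_p, sigma_p, sigma_pi, s=1.0):
    """Fuse the prior N(mu_p, (s*sigma_p)^2) with the planner default N(0, sigma_pi^2).

    Hypothetical helper: applies the closed-form Product-of-Gaussians
    elementwise and returns (mu_init, sigma_init).
    """
    mu_p = np.asarray(mu_p, dtype=np.float64)
    scaled_sigma_p = s * np.asarray(sigma_p, dtype=np.float64)
    sigma_pi = np.asarray(sigma_pi, dtype=np.float64)

    prec_prior = 1.0 / scaled_sigma_p**2   # (s * sigma_p)^-2
    prec_planner = 1.0 / sigma_pi**2       # sigma_pi^-2

    var_init = 1.0 / (prec_prior + prec_planner)
    # The planner mean is zero, so only the prior term contributes to the mean.
    mu_init = var_init * mu_p * prec_prior
    return mu_init, np.sqrt(var_init)


def sample_initial_actions(mu_init, sigma_init, num_samples, rng):
    """Draw num_samples candidates from the fused Gaussian (hypothetical helper)."""
    noise = rng.standard_normal((num_samples,) + np.shape(mu_init))
    return mu_init + sigma_init * noise


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Illustrative values only: a 2-D action with a confident first dimension.
    mu_init, sigma_init = pog_fuse(mu_p=[0.5, -0.2], sigma_p=[0.1, 0.4], sigma_pi=1.0, s=1.0)
    candidates = sample_initial_actions(mu_init, sigma_init, num_samples=128, rng=rng)
    print(mu_init, sigma_init, candidates.shape)
```

Two properties follow directly from the formula: the fused standard deviation is never larger than either input, and as `s` grows the fused distribution falls back to the planner default `N(0, σ_π²)`, which is why `s` acts as a trust knob on the prior.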