Initial release: LeWM + PRISM prior head for Cube
- README.md +78 -0
- lewm_object.ckpt +3 -0
- prior_head_cube.pt +3 -0
README.md
ADDED
@@ -0,0 +1,78 @@
---
license: mit
library_name: pytorch
pipeline_tag: robotics
tags:
- robotics
- jepa
- world-model
- visual-manipulation
- prism
- mppi
- planning
---

# PRISM-JEPA · Cube (OGBench `cube-single`)

A JEPA visual world model paired with a learned **action prior head** that
biases an MPPI planner via closed-form Product-of-Gaussians fusion. Together
they form the **PRISM-MPPI** pipeline from the PRISM paper, evaluated on
OGBench's `cube-single` task.

| File | Size | Role |
|---|---|---|
| `lewm_object.ckpt` | 72.3 MB | LeWM JEPA model (encoder + AR predictor), pickled `swm.World` object |
| `prior_head_cube.pt` | 2.4 MB | PRISM action prior head: 3-layer MLP, trained with β-NLL (β = 0.5) on cube demos |

## Headline result (K = 128, mean ± std over seeds {0, 1, 42})

| Method                         | Success rate (%) |
|--------------------------------|-----------------:|
| Vanilla MPPI                   | 44.0             |
| BC-only (prior mean → planner) | 66.0             |
| **PRISM-MPPI (s = 1)**         | **79.3 ± 6.1**   |

PRISM-MPPI's only hyperparameter is the prior scale `s`; we use `s = 1` (no
inflation). See Sec. 4.4 of the paper for the full `s`-sweep.

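As an illustration of the Product-of-Gaussians fusion mentioned above, here is a minimal sketch in plain Python. It is not taken from the `prism-jepa` code: the function name `fuse_gaussians` and its arguments are hypothetical, and it assumes that the two fused distributions are the planner's sampling Gaussian and the prior head's Gaussian, and that `s` multiplies the prior's standard deviation (so `s = 1` leaves the prior unchanged). The closed form itself is the standard precision-weighted product of two diagonal Gaussians.

```python
import math


def fuse_gaussians(mu_plan, sigma_plan, mu_prior, sigma_prior, s=1.0):
    """Per-coordinate product of two Gaussians, renormalized to a Gaussian.

    ASSUMPTION: the prior scale `s` multiplies the prior's standard deviation;
    the actual PRISM implementation may apply it differently.
    """
    fused_mu, fused_sigma = [], []
    for m_q, s_q, m_p, s_p in zip(mu_plan, sigma_plan, mu_prior, sigma_prior):
        prec_q = 1.0 / (s_q ** 2)        # planner precision
        prec_p = 1.0 / ((s * s_p) ** 2)  # (scaled) prior precision
        prec = prec_q + prec_p           # precisions add under a product
        fused_mu.append((prec_q * m_q + prec_p * m_p) / prec)
        fused_sigma.append(math.sqrt(1.0 / prec))
    return fused_mu, fused_sigma


# Two action coordinates: the tighter prior (sigma = 0.5) pulls the mean harder.
mu, sigma = fuse_gaussians([0.0, 0.0], [1.0, 1.0], [1.0, -1.0], [1.0, 0.5])
print(mu, sigma)
```

Under this assumption, a very large `s` makes the prior precision vanish, so the fused distribution falls back to the planner's own Gaussian.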
## Usage

```bash
# 1. Clone the code repo and set up the env (see the repo README):
git clone git@github.com:YuhaiW/prism-jepa.git
cd prism-jepa && source .venv/bin/activate
export STABLEWM_HOME=~/.stable-wm

# 2. Download these weights:
hf download YuhaiW/prism-jepa-cube --local-dir ./hf_cube
mkdir -p "$STABLEWM_HOME/cube"
mv hf_cube/lewm_object.ckpt "$STABLEWM_HOME/cube/"
mv hf_cube/prior_head_cube.pt .

# 3. Run PRISM-MPPI:
python eval_prism_head.py --config-name=cube policy=cube/lewm solver=mppi \
    +head.injection_mode=pog +head.sigma_scale=1.0 \
    +head.ckpt=prior_head_cube.pt eval.num_eval=50

# Vanilla MPPI baseline (no prior):
python eval_prism_head.py --config-name=cube policy=cube/lewm solver=mppi \
    +head.injection_mode=none eval.num_eval=50
```

## Training (summary)

- **`lewm_object.ckpt`**: LeWM JEPA trained from scratch on OGBench
  `cube_single_expert` (upstream recipe in `train.py`).
- **`prior_head_cube.pt`**: 3-layer MLP mapping JEPA embedding `z_t` to a
  per-coordinate Gaussian `(μ_p, σ_p)` over the next action block. β-NLL loss
  (β = 0.5), `sigma_floor = 0.05`, Adam, 50 epochs. See `train_prior_head.py`.

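To make the β-NLL objective concrete, the following forward-only sketch computes its value for a per-coordinate Gaussian. It is an illustration, not the code in `train_prior_head.py`: the function name `beta_nll` is hypothetical, and treating `sigma_floor` as a lower clamp on the predicted standard deviation is an assumption. In the usual formulation of β-NLL, the weight σ^(2β) is treated as a constant (stop-gradient) during backpropagation; plain Python has no autograd, so only the loss value is shown.

```python
import math


def beta_nll(mu, sigma, target, beta=0.5, sigma_floor=0.05):
    """Mean beta-weighted Gaussian NLL over coordinates (forward value only).

    ASSUMPTION: sigma_floor is applied as a lower clamp on sigma.
    In an autograd framework the weight var**beta would be detached.
    """
    total = 0.0
    for m, s, y in zip(mu, sigma, target):
        s = max(s, sigma_floor)               # clamp the predicted std
        var = s * s
        # Gaussian negative log-likelihood without the constant 0.5*log(2*pi)
        nll = 0.5 * math.log(var) + 0.5 * (y - m) ** 2 / var
        total += (var ** beta) * nll          # weight sigma^(2*beta)
    return total / len(mu)


# beta = 0 recovers the plain Gaussian NLL; beta = 0.5 reweights by sigma.
print(beta_nll([0.0, 0.2], [1.0, 0.5], [1.0, 0.0], beta=0.5))
```

With β = 0 the weight is 1 for every coordinate, so the loss reduces to the ordinary Gaussian NLL; larger β gives low-variance predictions less influence relative to high-variance ones.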
## Citation

_BibTeX to be added._

## License

MIT for the PRISM head and integration code. The vendored LeWM code in
`prism-jepa` inherits its upstream license; see
[le-wm](https://github.com/lucas-maes/le-wm).
lewm_object.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:82d37a9d9338d8c23005017ab5c1ff91c8b5e3fd51fafbd620af8457c381d125
size 72344949
prior_head_cube.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0bbfacb047d7ea68370d07a56185099807cc1a9536034fbe53cdbfb3f6d78dec
size 2358901
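The `oid sha256:` and `size` fields in the two LFS pointers above can be used to check that a download is intact. The sketch below is one way to do that with the Python standard library; the helper names `sha256_of` and `verify` are hypothetical, and the `./hf_cube` directory is assumed to be the download location used in the README's usage section.

```python
import hashlib
from pathlib import Path

# Digests and byte sizes copied from the Git LFS pointers in this commit.
EXPECTED = {
    "lewm_object.ckpt": (
        "82d37a9d9338d8c23005017ab5c1ff91c8b5e3fd51fafbd620af8457c381d125",
        72344949,
    ),
    "prior_head_cube.pt": (
        "0bbfacb047d7ea68370d07a56185099807cc1a9536034fbe53cdbfb3f6d78dec",
        2358901,
    ),
}


def sha256_of(path, chunk_size=1 << 20):
    """Stream the file in chunks so large checkpoints need little memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected_oid, expected_size):
    """Return True only if both the byte size and the SHA-256 digest match."""
    p = Path(path)
    return p.stat().st_size == expected_size and sha256_of(p) == expected_oid


if __name__ == "__main__":
    download_dir = Path("./hf_cube")  # ASSUMPTION: matches --local-dir above
    for name, (oid, size) in EXPECTED.items():
        target = download_dir / name
        if not target.exists():
            print(f"{name}: not found in {download_dir}, skipping")
            continue
        print(f"{name}: {'OK' if verify(target, oid, size) else 'MISMATCH'}")
```

Checking the size first is cheap and catches truncated downloads before the full hash is computed.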