Commit d575951 (verified), committed by YuhaiW
1 parent: a1a3bf5

Initial release: LeWM + PRISM prior head for Cube

Files changed (3)
  1. README.md +78 -0
  2. lewm_object.ckpt +3 -0
  3. prior_head_cube.pt +3 -0
README.md ADDED
@@ -0,0 +1,78 @@
+ ---
+ license: mit
+ library_name: pytorch
+ pipeline_tag: robotics
+ tags:
+ - robotics
+ - jepa
+ - world-model
+ - visual-manipulation
+ - prism
+ - mppi
+ - planning
+ ---
+
+ # PRISM-JEPA · Cube (OGBench `cube-single`)
+
+ A JEPA visual world model paired with a learned **action prior head** that
+ biases an MPPI planner via closed-form Product-of-Gaussians fusion. Together
+ they form the **PRISM-MPPI** pipeline from the PRISM paper, evaluated on
+ OGBench's `cube-single` task.
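Editor's sketch (not part of the commit): the closed-form Product-of-Gaussians fusion mentioned above can be illustrated with the standard precision-weighted product of two per-coordinate Gaussians. The function name `fuse_gaussians`, the argument names, and the way the prior scale `s` enters (multiplying the prior standard deviation, so that `s = 1` means no inflation) are assumptions for illustration; the repository's actual implementation may differ.

```python
def fuse_gaussians(mu_plan, sigma_plan, mu_prior, sigma_prior, s=1.0):
    """Precision-weighted product of two 1-D Gaussians, per action coordinate.

    ASSUMPTION: `s` multiplies the prior standard deviation (s = 1 leaves
    the prior unchanged, s > 1 weakens its influence).
    """
    fused_mu, fused_sigma = [], []
    for m_q, s_q, m_p, s_p in zip(mu_plan, sigma_plan, mu_prior, sigma_prior):
        prec_q = 1.0 / (s_q * s_q)        # precision of the planner's Gaussian
        prec_p = 1.0 / ((s * s_p) ** 2)   # precision of the (scaled) prior
        prec = prec_q + prec_p            # precisions add under a product
        fused_mu.append((prec_q * m_q + prec_p * m_p) / prec)
        fused_sigma.append(prec ** -0.5)
    return fused_mu, fused_sigma


# One coordinate: planner N(0, 1^2), prior N(1, 0.5^2).
mu, sigma = fuse_gaussians([0.0], [1.0], [1.0], [0.5], s=1.0)
print(mu, sigma)
```

The fused mean is pulled toward whichever Gaussian is more confident, and the fused standard deviation is never larger than either input, which is why a sharp prior can bias the sampling distribution strongly.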
+
+ | File | Size | Role |
+ |---|---|---|
+ | `lewm_object.ckpt` | 69 MB | LeWM JEPA model (encoder + AR predictor), pickled `swm.World` object |
+ | `prior_head_cube.pt` | 2.3 MB | PRISM action prior head: 3-layer MLP, trained with β-NLL (β = 0.5) on cube demos |
+
+ ## Headline result (K = 128, mean ± std over seeds {0, 1, 42})
+
+ | Method | SR (%) |
+ |-----------------------------------------|--------:|
+ | Vanilla MPPI | 44.0 |
+ | BC-only (prior mean → planner) | 66.0 |
+ | **PRISM-MPPI (s = 1)** | **79.3 ± 6.1** |
+
+ PRISM-MPPI's only hyperparameter is the prior scale `s`; we use `s = 1` (no
+ inflation). See the paper Sec. 4.4 for the full `s`-sweep.
+
+ ## Usage
+
+ ```bash
+ # 1. Clone the code repo and set up the env (see the repo README):
+ git clone git@github.com:YuhaiW/prism-jepa.git
+ cd prism-jepa && source .venv/bin/activate
+ export STABLEWM_HOME=~/.stable-wm
+
+ # 2. Download these weights:
+ hf download YuhaiW/prism-jepa-cube --local-dir ./hf_cube
+ mkdir -p $STABLEWM_HOME/cube
+ mv hf_cube/lewm_object.ckpt $STABLEWM_HOME/cube/
+ mv hf_cube/prior_head_cube.pt .
+
+ # 3. Run PRISM-MPPI:
+ python eval_prism_head.py --config-name=cube policy=cube/lewm solver=mppi \
+ +head.injection_mode=pog +head.sigma_scale=1.0 \
+ +head.ckpt=prior_head_cube.pt eval.num_eval=50
+
+ # Vanilla MPPI baseline (no prior):
+ python eval_prism_head.py --config-name=cube policy=cube/lewm solver=mppi \
+ +head.injection_mode=none eval.num_eval=50
+ ```
+
+ ## Training (summary)
+
+ - **`lewm_object.ckpt`**: LeWM JEPA trained from scratch on OGBench
+ `cube_single_expert` (upstream recipe in `train.py`).
+ - **`prior_head_cube.pt`**: 3-layer MLP mapping JEPA embedding `z_t` to a
+ per-coordinate Gaussian `(μ_p, σ_p)` over the next action block. β-NLL loss
+ (β = 0.5), `sigma_floor = 0.05`, Adam, 50 epochs. See `train_prior_head.py`.
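Editor's sketch (not part of the commit): the β-NLL objective named above weights the per-coordinate Gaussian negative log-likelihood by `σ^(2β)`. The snippet below computes only the forward value in plain Python; in an autodiff framework the weight is treated as a constant (stop-gradient). The function name `beta_nll` and the choice to apply `sigma_floor` by clamping the standard deviation from below are assumptions, not taken from `train_prior_head.py`.

```python
import math


def beta_nll(mu, sigma, target, beta=0.5, sigma_floor=0.05):
    """Mean beta-NLL over action coordinates (forward value only)."""
    total = 0.0
    for m, s, y in zip(mu, sigma, target):
        s = max(s, sigma_floor)  # ASSUMPTION: floor applied as a lower clamp on std
        var = s * s
        # Gaussian NLL without the constant 0.5 * log(2 * pi) term.
        nll = 0.5 * math.log(var) + (y - m) ** 2 / (2.0 * var)
        # Weight var**beta; a training framework would detach this factor.
        total += (var ** beta) * nll
    return total / len(mu)


# beta = 0 recovers the plain NLL; beta = 1 makes the mean's gradient behave
# like that of a squared-error loss.
print(beta_nll([0.0, 0.2], [1.0, 0.3], [1.0, 0.1]))
```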
+
+ ## Citation
+
+ _BibTeX to be added._
+
+ ## License
+
+ MIT for the PRISM head + integration code. The vendored LeWM code in
+ `prism-jepa` inherits its upstream license; see
+ [le-wm](https://github.com/lucas-maes/le-wm).
lewm_object.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:82d37a9d9338d8c23005017ab5c1ff91c8b5e3fd51fafbd620af8457c381d125
+ size 72344949
prior_head_cube.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0bbfacb047d7ea68370d07a56185099807cc1a9536034fbe53cdbfb3f6d78dec
+ size 2358901
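
Editor's sketch (not part of the commit): each Git LFS pointer above records the SHA-256 digest and byte size of the real file, so a downloaded checkpoint can be checked against them. The helper names `sha256_of` and `matches_pointer` are hypothetical; the digests and sizes in `EXPECTED` are copied from the two pointer files in this commit.

```python
import hashlib
from pathlib import Path

# Digest and size per file, copied from the LFS pointers in this commit.
EXPECTED = {
    "lewm_object.ckpt": (
        "82d37a9d9338d8c23005017ab5c1ff91c8b5e3fd51fafbd620af8457c381d125",
        72344949,
    ),
    "prior_head_cube.pt": (
        "0bbfacb047d7ea68370d07a56185099807cc1a9536034fbe53cdbfb3f6d78dec",
        2358901,
    ),
}


def sha256_of(path, chunk_size=1 << 20):
    """Stream the file in chunks so large checkpoints are not loaded at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_pointer(path, expected_sha256, expected_size):
    """True if the file's size and SHA-256 both match the LFS pointer values."""
    path = Path(path)
    return (
        path.stat().st_size == expected_size
        and sha256_of(path) == expected_sha256
    )
```

For example, `matches_pointer("prior_head_cube.pt", *EXPECTED["prior_head_cube.pt"])` should return `True` for an intact download placed in the working directory.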