PRISM-LeWM for Franka FR3 Planar PushT, v4 (multi-step supervision)
Research preview of a LeWM-style JEPA world model trained on 200 real-robot demonstrations (51 k frames, Franka FR3, 10 Hz) using multi-step autoregressive predictor supervision. This is the first public release of LeWM trained with this scheme.
Status: deploy-marginal. Cost-surface discriminativity improved 2.17× over LeWM-default training (CV 0.13 → 0.29), and the Test A in-distribution deploy gate is the closest to PASS we've achieved (3.93× vs PASS ≤ 2×). Real-robot success rate (SR) is not yet measured. Use this as a baseline or starting point for further work, not as a production-ready policy.
TL;DR
| Item | Details |
|---|---|
| Source dataset | Rongxuan-Zhou/pusht_lewm_fr3 (200 demos, 51 492 frames @ 10 Hz, Franka FR3) |
| Action space | 2-dim (dx, dy): planar delta-EE in meters |
| Encoder | ViT-tiny from scratch (5.5 M params) |
| Predictor | LeWM-default ARPredictor (10.79 M params, depth 6 / heads 16 / mlp 2048) |
| Training novelty | 3-step autoregressive supervision (predictor sees own predictions during training; TD-MPC-style stop-gradient between steps). LeWM-default is 1-step teacher-forced. |
| Prior head | PriorHead MLP, val MSE drop 14.7 % (marginal vs 15 % gate) |
| Cost-surface CV @ H=5 | 0.29 vs sim PushT 0.13 vs LeWM-default-on-our-data 0.13 |
| Bundle size | ~ 80 MB |
What's in here
| File | Size | Role |
|---|---|---|
| `lewm_pusht_lewm_fr3_v4b.ckpt` | 72 MB | LeWM world model (JEPA encoder + multi-step-trained AR predictor) |
| `prior_head_pusht_lewm_fr3_v4b.pt` | 2 MB | PRISM prior head + StandardScaler |
| `franka_pusht_lewm_inference.py` | 13 KB | Self-contained PRISM-MPPI / vanilla-MPPI inference loop |
| `jepa.py`, `module.py` | ~14 KB | Model classes (required to unpickle) |
| `prior_head.py` | 2.4 KB | PriorHead class |
| `requirements.txt` | 0.4 KB | Runtime deps |
| `README.md` | this file | Usage + deploy guidance |
Installation
pip install huggingface_hub
python -c "from huggingface_hub import snapshot_download; \
snapshot_download(repo_id='YuhaiW/prism-jepa-franka-pusht-v4', \
local_dir='./franka_pusht_v4_bundle')"
cd franka_pusht_v4_bundle/
pip install -r requirements.txt
PyTorch ≥ 2.1 and a CUDA GPU are recommended. CPU works, but `plan()` is
~10× slower.
Quick start: running PRISM-MPPI
import numpy as np
from franka_pusht_lewm_inference import (
PrismMPPIInference,
pad_2d_to_6d_franka,
)
# PRISM-MPPI (uses the prior head to bias MPPI's initial action distribution)
planner = PrismMPPIInference(
lewm_ckpt = "lewm_pusht_lewm_fr3_v4b.ckpt",
prior_ckpt = "prior_head_pusht_lewm_fr3_v4b.pt",
use_prism = True,
device = "cuda",
)
# In your control loop (`camera`, `goal_image` and `robot` are your own objects):
obs_uint8 = camera.read_d455_agent_view()   # (224, 224, 3) uint8 RGB
goal_uint8 = goal_image                     # (224, 224, 3) uint8 RGB
actions_2d = planner.plan(obs_uint8, goal_uint8)
# → (5, 2) array of (dx, dy) deltas in meters
for a2d in actions_2d:
    a6d = pad_2d_to_6d_franka(a2d)          # (6,) dx, dy, 0, 0, 0, 0
    robot.send_delta_ee(a6d)
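For vanilla MPPI (no prior, useful for A/B comparison), pass the same arguments with `use_prism=False`:
planner = PrismMPPIInference(
    lewm_ckpt  = "lewm_pusht_lewm_fr3_v4b.ckpt",
    prior_ckpt = "prior_head_pusht_lewm_fr3_v4b.pt",
    use_prism  = False,
    device     = "cuda",
)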
How MPPI works in this code (~ 4 sentences)
For each `plan()` call:
- Encode the current `obs` and `goal` images via the JEPA encoder → latent vectors `z_t`, `z_g` (each 192-dim).
- Initialise the MPPI distribution as N(0, `var_scale`²) in StandardScaler-normalised action space. If `use_prism=True`, PoG-fuse this with the prior head's (μ_p, σ_p) → biased starting distribution.
- MPPI loop (30 iterations): sample K=128 candidate action sequences from the current N(μ, σ²); for each, autoregressively roll out the predictor for H=5 plan-steps to get a predicted `z_5`; compute `cost = ‖z_5 - z_g‖²`; reweight candidates by `softmax(-β · cost)`; update μ to the importance-weighted mean. σ is frozen across iterations; this is the PRISM-MPPI signature.
- Return the first `A_block = 5` env-step actions of the converged μ, denormalised back to raw env units (meters).
To tweak the planner: `H`, `K`, `n_iters`, `var_scale`, `temperature`,
`prior_sigma_scale` and `history_size` are all constructor arguments of
`PrismMPPIInference`. Defaults match the LeWM paper's sim conventions.
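The four steps above can be sketched in NumPy as a frozen-σ MPPI loop. This is a minimal illustration under stated assumptions, not the code in `franka_pusht_lewm_inference.py`: the predictor is replaced by a hypothetical linear `toy_rollout`, each plan step is treated as a single 2-dim action (not an `A_block` of env steps), `beta` stands for the β in `softmax(-β · cost)` (how it maps to the `temperature` argument is not stated here), "PoG-fuse" is assumed to mean a precision-weighted product of Gaussians, and StandardScaler de-normalisation is omitted.

```python
import numpy as np

def pog_fuse(mu_a, sigma_a, mu_b, sigma_b):
    """Precision-weighted product of two diagonal Gaussians (assumed meaning of PoG)."""
    prec_a, prec_b = 1.0 / sigma_a**2, 1.0 / sigma_b**2
    prec = prec_a + prec_b
    mu = (mu_a * prec_a + mu_b * prec_b) / prec
    return mu, np.sqrt(1.0 / prec)

def mppi_plan(z_t, z_g, rollout, H=5, action_dim=2, K=128, n_iters=30,
              var_scale=1.0, beta=1.0, prior=None, seed=0):
    """Frozen-sigma MPPI in normalised action space (illustrative sketch).

    rollout(z_t, actions) maps (K, H, action_dim) candidates to final latents (K, D).
    prior is an optional (mu_p, sigma_p) pair of arrays shaped (H, action_dim).
    """
    rng = np.random.default_rng(seed)
    mu = np.zeros((H, action_dim))
    sigma = np.full((H, action_dim), float(var_scale))
    if prior is not None:
        mu, sigma = pog_fuse(mu, sigma, *prior)       # biased starting distribution
    for _ in range(n_iters):
        samples = mu + sigma * rng.standard_normal((K, H, action_dim))
        cost = np.sum((rollout(z_t, samples) - z_g) ** 2, axis=-1)  # (K,)
        w = np.exp(-beta * (cost - cost.min()))      # softmax(-beta * cost), min-shifted
        w /= w.sum()
        mu = np.einsum("k,kha->ha", w, samples)      # importance-weighted mean
        # sigma is intentionally never updated: frozen across iterations
    return mu, sigma

def toy_rollout(z_t, actions):
    """Hypothetical linear 'predictor': every step adds its action to the latent."""
    return z_t + actions.sum(axis=1)

z_t, z_g = np.zeros(2), np.array([1.0, -0.5])
mu, sigma = mppi_plan(z_t, z_g, toy_rollout)
print(mu.shape, sigma[0])  # (5, 2) and the unchanged sigma
```

Because σ never shrinks, every iteration samples at the same width and only the mean moves toward low-cost candidates; with a prior, the starting μ and σ come from the fused Gaussian instead of N(0, `var_scale`²).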
Franka FR3 deployment notes
Action format
The model outputs 2-dim actions (dx, dy) in meters per 10 Hz env-step.
The 4 other dims (dz, drx, dry, drz) that some Franka teleop pipelines
record are floating-point jitter or teleop-controller drift in our dataset and were dropped at
the data level; see §"Why drop 4 action dims" below for evidence.
For Franka FR3 controllers that expect 6-dim (dx, dy, dz, drx, dry, drz),
pad with zeros:
a2d = planner.plan(obs, goal)[0]
a6d = pad_2d_to_6d_franka(a2d) # (6,) = [dx, dy, 0, 0, 0, 0]
robot.send_delta_ee(a6d)
Training distribution per env-step:
- `dx` range ±0.025 m, std 0.005 m
- `dy` range ±0.030 m, std 0.006 m
- Frame rate: 10 Hz → average speed ~5 cm/s
Suggested safety clamps
Even a stable model occasionally outputs actions several times the training range. Clamp before sending:
import numpy as np

ACTION_CLAMP_2D = np.array([0.030, 0.035])  # (dx, dy) limits in m, ~1.2× max(|training|)

def clamp_safety(a):
    # Element-wise clip to ±ACTION_CLAMP_2D; works on (2,) or (N, 2) arrays
    return np.clip(a, -ACTION_CLAMP_2D, ACTION_CLAMP_2D)

actions_2d = clamp_safety(planner.plan(obs, goal))
Also on the hardware side:
- Workspace bounding box: roughly `x ∈ [0.35, 0.81]`, `y ∈ [-0.30, 0.37]`, `z = 0.099 ± 0.005` (planar-lock height) in the robot base frame
- Operator e-stop physically reachable
- First N trials at 0.5× velocity scaling
- Time-based `MAX_STEPS` cap (≈ 5 s at 10 Hz)
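A pre-send check against the bounding box above might look like the following. `plan_stays_in_workspace` is a hypothetical helper, not part of the bundle; it assumes the (dx, dy) deltas add directly to the end-effector's base-frame (x, y), and it uses the rough box limits listed above.

```python
import numpy as np

# Rough planar workspace from this README (robot base frame, meters).
X_RANGE = (0.35, 0.81)
Y_RANGE = (-0.30, 0.37)

def plan_stays_in_workspace(ee_xy, actions_2d, margin=0.0):
    """True if every cumulative EE (x, y) along the planned deltas stays inside the box.

    margin shrinks the box on all sides (meters) for an extra safety buffer.
    """
    path = np.asarray(ee_xy, dtype=float) + np.cumsum(np.asarray(actions_2d, dtype=float), axis=0)
    x, y = path[:, 0], path[:, 1]
    inside = ((x >= X_RANGE[0] + margin) & (x <= X_RANGE[1] - margin)
              & (y >= Y_RANGE[0] + margin) & (y <= Y_RANGE[1] - margin))
    return bool(np.all(inside))

# Example: skip (or stop) if a clamped plan would leave the box.
print(plan_stays_in_workspace([0.50, 0.00], np.full((5, 2), 0.01)))
```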
Receding-horizon control loop
import time

CONTROL_DT = 0.1   # 10 Hz (matches training)
N_EXEC = 5         # = A_block; execute then replan
MAX_STEPS = 50     # ≈ 5 s safety cap

# camera, robot, preprocess_to_224 and task_complete are your own hardware/task helpers.
step = 0
while step < MAX_STEPS:
    obs = preprocess_to_224(camera.read())
    if task_complete(obs, goal_uint8):
        break
    actions_2d = clamp_safety(planner.plan(obs, goal_uint8))
    for a2d in actions_2d[:N_EXEC]:
        a6d = pad_2d_to_6d_franka(a2d)
        robot.send_delta_ee(a6d)
        time.sleep(CONTROL_DT)
        step += 1
        if step >= MAX_STEPS:
            break
robot.move_to_home()
Diagnostic numbers: full honesty
We evaluate three concentric tests:
| Test | Encoder health: LOEO R² | Predictor quality: pred/id ratio | MPPI usability (the deploy gate): cost-surface CV |
|---|---|---|---|
| Sim ref | n/a | 0.79 (h=1) | 0.13 |
| v3 baseline | +0.03 | 0.97 (h=1) | 0.13 |
| v4b (this) | (not measured) | (not measured) | 0.29 |
Cost-surface CV in detail
Tested on 12 random (z_t, z_g) training pairs, K=256 random action
candidates each. The CV (std(cost) / mean(cost)) measures how
discriminative the cost surface is across action candidates:
| | CV mean | CV max | GT rank pct |
|---|---|---|---|
| v3 baseline | 0.132 | 0.419 | 33 % |
| v4b (this) | 0.286 | 0.712 | 44 % |
| sim PushT | 0.133 | 0.367 | 59 % |
| PASS gate | ≥ 0.30 | n/a | ≤ 25 % |
v4b's CV (0.286) is 2.17× that of LeWM-default training (0.132) on the same
data, the largest movement we've seen from any single technique. It also
exceeds the sim PushT reference (0.133). But it falls 0.014 short of
our 0.30 PASS threshold, and GT rank pct worsened (33 → 44 %),
meaning the predictor's cost ranks ground-truth-like actions only
slightly better than random.
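The CV and GT-rank columns can be sketched as follows. `cost_surface_stats` is a hypothetical helper, and the GT-rank definition is an assumption: the percentage of random candidates whose cost is strictly lower than the ground-truth action's cost (so lower is better and ~50 % is chance). In practice each row of `costs` would come from rolling the predictor out for the K candidates of one (z_t, z_g) pair.

```python
import numpy as np

def cost_surface_stats(costs, gt_costs):
    """costs: (P, K) costs of K random candidates for each of P (z_t, z_g) pairs.
    gt_costs: (P,) cost of the ground-truth action sequence for each pair."""
    costs = np.asarray(costs, dtype=float)
    cv = costs.std(axis=1) / costs.mean(axis=1)  # per-pair std(cost) / mean(cost)
    # Assumed GT rank: % of candidates that beat (score lower than) the GT action.
    beaten_by = costs < np.asarray(gt_costs, dtype=float)[:, None]
    rank_pct = 100.0 * beaten_by.mean(axis=1)
    return {"cv_mean": cv.mean(), "cv_max": cv.max(), "gt_rank_pct": rank_pct.mean()}

# Usage with 12 pairs x 256 candidates, mirroring the setup above (random numbers only):
rng = np.random.default_rng(0)
print(cost_surface_stats(rng.uniform(1.0, 2.0, (12, 256)), rng.uniform(1.0, 2.0, 12)))
```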
Test A: in-distribution sanity (deploy gate)
Feed training-distribution (obs, goal) pairs to the planner; compare
output to ground-truth teleop action.
| Config | pred \|a\| | ratio | cos(pred, truth) |
|---|---|---|---|
| truth (reference) | 6.5 mm | 1.00× | +1.000 |
| v3 baseline | 33 mm | 5.08× | −0.05 |
| v4b PRISM H=5 | 26 mm | 3.93× | +0.10 |
| v4b vanilla H=5 | 32 mm | 4.92× | +0.10 |
v4b (PRISM, H=5) is the closest to truth across all configs we've tried: a ~23 % lower ratio than the v3 baseline, and the direction cosine flipped sign from −0.05 to +0.10. It is still well short of the PASS gate (ratio ≤ 2×, cos ≥ +0.3), but this is clear directional progress.
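The two Test A columns can be computed roughly as below. This assumes `ratio` is the mean predicted |a| divided by the mean ground-truth |a|, and `cos` is the per-sample cosine averaged over samples; the README does not spell out the aggregation, and `in_distribution_metrics` is a hypothetical helper.

```python
import numpy as np

def in_distribution_metrics(pred, truth, eps=1e-12):
    """pred, truth: (N, 2) planner outputs and teleop actions in meters."""
    pred = np.asarray(pred, dtype=float)
    truth = np.asarray(truth, dtype=float)
    pred_norm = np.linalg.norm(pred, axis=-1)
    truth_norm = np.linalg.norm(truth, axis=-1)
    ratio = pred_norm.mean() / max(truth_norm.mean(), eps)  # magnitude overshoot
    cos = np.sum(pred * truth, axis=-1) / np.maximum(pred_norm * truth_norm, eps)
    return {"pred_abs_mm": 1000.0 * pred_norm.mean(), "ratio": ratio, "cos": cos.mean()}

# PASS gate from the text: ratio <= 2 and cos >= +0.3
m = in_distribution_metrics([[0.002, 0.0], [0.0, -0.002]], [[0.001, 0.0], [0.0, 0.001]])
print(m, m["ratio"] <= 2.0 and m["cos"] >= 0.3)
```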
Honest summary
| Axis | v3 → v4b | PASS? |
|---|---|---|
| Cost-surface CV | 0.13 → 0.29 (+118 %) | No, marginal (need 0.30) |
| Test A ratio | 5.08× → 3.93× (−23 %) | No (need ≤ 2×) |
| Test A cos | −0.05 → +0.10 | No (need ≥ +0.3) |
| Encoder collapse | healthy | Yes |
| Predictor stability | epochs 1–2 unstable, recovered by epoch 18 | Yes (with our hyperparams) |
What this experiment shows
Multi-step autoregressive supervision genuinely improves LeWM-MPPI's cost-surface discriminativity on small real-robot data. The 2.17× CV improvement is the largest single-technique improvement we measured across v2, v3, and v4b. It is the technique LeWM-default skipped (and did not need on 2.34 M sim frames).
AR supervision is fragile on a JEPA + SIGReg + small-data setup. Our first attempt (`num_preds=5`, `lr=5e-5`, `dropout=0.1`) silently crashed at epoch 3 with diverging val/sigreg losses. The conservative recipe that works in this release: `num_preds=3`, `lr=2e-5`, `predictor.dropout=0.0`, full BN running stats, and a stop-gradient between AR rollout steps.

The gap from "discriminative cost surface" to "deploy-ready" is non-trivial. Even with a CV more than 2× sim's, our `GT_rank_pct` regressed (33 → 44 %): the cost surface gets sharper but doesn't point more strongly at the ground-truth action. Likely fixes (not in this release): a stronger / deeper prior head, more demos, and a scheduled-sampling schedule on the autoregressive rollout.
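A minimal PyTorch sketch of 3-step AR supervision with a stop-gradient between rollout steps. Assumptions, not the training code of this release: `TinyPredictor` is a hypothetical one-step, Markov stand-in for LeWM's ARPredictor (which conditions on a history of latents), the encoder and SIGReg terms are omitted, and the latent targets `z` are used as given (whether this release detaches them is not stated here).

```python
import torch
import torch.nn as nn

class TinyPredictor(nn.Module):
    """Hypothetical one-step latent predictor: z_{k+1} = z_k + f([z_k, a_k])."""
    def __init__(self, latent_dim, action_dim):
        super().__init__()
        self.net = nn.Linear(latent_dim + action_dim, latent_dim)

    def forward(self, z, a):
        return z + self.net(torch.cat([z, a], dim=-1))

def multistep_ar_loss(predictor, z, actions, num_preds=3):
    """Mean latent MSE over num_preds autoregressive steps.

    z: (B, T, D) latents with T >= num_preds + 1; actions: (B, T, A).
    Step k is fed the predictor's own (detached) output from step k-1 rather than
    the encoder latent z[:, k]; num_preds=1 reduces to 1-step teacher forcing.
    """
    z_in = z[:, 0]
    total = z.new_zeros(())
    for k in range(num_preds):
        z_pred = predictor(z_in, actions[:, k])
        total = total + ((z_pred - z[:, k + 1]) ** 2).mean()
        z_in = z_pred.detach()  # stop-gradient between rollout steps (TD-MPC style)
    return total / num_preds

torch.manual_seed(0)
pred = TinyPredictor(latent_dim=8, action_dim=2)
z, a = torch.randn(4, 4, 8), torch.randn(4, 4, 2)
loss = multistep_ar_loss(pred, z, a, num_preds=3)
loss.backward()
```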
v4 vs v3 vs v2
| | v2 (negative case) | v3 (LeWM-default scale-up) | v4 (multi-step) |
|---|---|---|---|
| Source repo | YuhaiW/prism-jepa-franka-pusht-v2 (private) | (not separately released) | this repo (public) |
| Demos | 50 | 200 | 200 |
| Frames | 18 k | 51 k | 51 k |
| Predictor training | 1-step teacher-forced + small predictor + manual normalization bug | LeWM-default 1-step | multi-step AR, num_preds=3 |
| Cost-surface CV | n/a | 0.13 | 0.29 |
| Deploy status | known-broken negative case | wandering on real robot | research preview, deploy-marginal |
Why drop 4 action dims
Recorded 6-dim Franka teleop action has only 2 meaningful dims:
| dim | std | std / dx_std | interpretation |
|---|---|---|---|
| dx | 0.0049 | 1.00 | signal |
| dy | 0.0058 | 1.17 | signal |
| dz | 0.0002 | 0.04 | floating-point jitter |
| drx | 0.0026 | 0.49 | Quest controller drift |
| dry | 0.0032 | 0.65 | Quest controller drift |
| drz | 0.0012 | 0.25 | Quest controller drift |
`dz` is essentially zero apart from floating-point jitter; `drx`/`dry`/`drz` are teleop tracking
artefacts. After per-dim StandardScaler they would all become unit-std and
look "equally important" to the model. Dropping them at the data level
lets the model focus on the actual signal.
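The argument above can be checked with a few lines of NumPy. `dim_report` and `standardize` are hypothetical helpers; `standardize` mirrors the default StandardScaler formula (per-dim mean and population std) without its zero-variance handling, and the synthetic data only borrows the std values from the table, it is not the dataset.

```python
import numpy as np

DIM_NAMES = ["dx", "dy", "dz", "drx", "dry", "drz"]

def dim_report(actions_6d):
    """Per-dim (std, std / dx_std), as in the table above."""
    std = np.asarray(actions_6d, dtype=float).std(axis=0)
    return {name: (s, s / std[0]) for name, s in zip(DIM_NAMES, std)}

def standardize(x):
    """Per-dim z-scoring: (x - mean) / std with population std (ddof=0)."""
    x = np.asarray(x, dtype=float)
    return (x - x.mean(axis=0)) / x.std(axis=0)

rng = np.random.default_rng(0)
stds = np.array([0.0049, 0.0058, 0.0002, 0.0026, 0.0032, 0.0012])
fake_6d = rng.standard_normal((1000, 6)) * stds  # synthetic, std-matched only
print(standardize(fake_6d).std(axis=0))  # every dim is now unit-std, jitter included
actions_2d = fake_6d[:, :2]  # keep only (dx, dy), as this release does
```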
Caveats
- Not validated on real-robot success rate (SR). This release is our best on offline diagnostics, not on closed-loop success.
- Single seed. Multi-seed variance unmeasured.
- Encoder is ViT-tiny trained from scratch on 51 k frames, which is known to be sensitive to deploy-scene distribution shift (lighting, table colour, camera intrinsics). For best results, match the deploy scene to the training scene as closely as possible.
- PRISM prior is moderate-quality. Val MSE drop 14.7 % (marginal vs 15 % gate). A deeper prior head would likely help further.
Citation / context
If you use this checkpoint in research, please cite the LeWM paper for the underlying architecture, and the world-model literature (Dreamer / TD-MPC / PlaNet style) for the multi-step autoregressive training trick:
- LeWM: LeWorldModel: First Stable JEPA from Pixels. HF: papers/2603.19312
- TD-MPC / TD-MPC2: Hansen et al., for the multi-horizon AR loss style
- Dreamer / DreamerV3: Hafner et al., for imagined rollouts in training
- This release is part of the PRISM-JEPA project (paper in preparation).
License
apache-2.0