PRISM-LeWM for Franka FR3 Planar PushT, v4 (multi-step supervision)

Research preview of a LeWM-style JEPA world model trained on 200 real-robot demonstrations (51 k frames, Franka FR3, 10 Hz) using multi-step autoregressive predictor supervision. This is the first public release of LeWM trained with this scheme.

Status: deploy-marginal. Cost-surface discriminativity improved 2.17× over LeWM-default training (CV 0.13 → 0.29), and the Test A in-distribution deploy gate is the closest to PASS we've achieved (3.93× vs PASS ≤ 2×). Real-robot SR is not yet measured. Use as a baseline / starting point for further work, not as a production-ready policy.

TL;DR

  Source dataset          Rongxuan-Zhou/pusht_lewm_fr3 (200 demos, 51 492 frames @ 10 Hz, Franka FR3)
  Action space            2-dim (dx, dy): planar delta-EE in meters
  Encoder                 ViT-tiny from scratch (5.5 M params)
  Predictor               LeWM-default ARPredictor (10.79 M params, depth 6 / heads 16 / mlp 2048)
  Training novelty        3-step autoregressive supervision (predictor sees own predictions during training; TD-MPC-style stop-gradient between steps). LeWM-default is 1-step teacher-forced.
  Prior head              PriorHead MLP, val MSE drop 14.7 % (marginal vs 15 % gate)
  Cost-surface CV @ H=5   0.29 vs sim PushT 0.13 vs LeWM-default-on-our-data 0.13
  Bundle size             ~ 80 MB

What's in here

  File                               Size       Role
  lewm_pusht_lewm_fr3_v4b.ckpt       72 MB      LeWM world model (JEPA encoder + multi-step-trained AR predictor)
  prior_head_pusht_lewm_fr3_v4b.pt   2 MB       PRISM prior head + StandardScaler
  franka_pusht_lewm_inference.py     13 KB      Self-contained PRISM-MPPI / vanilla-MPPI inference loop
  jepa.py, module.py                 ~ 14 KB    Model classes (required to unpickle)
  prior_head.py                      2.4 KB     PriorHead class
  requirements.txt                   0.4 KB     Runtime deps
  README.md                          this file  Usage + deploy guidance

Installation

pip install huggingface_hub
python -c "from huggingface_hub import snapshot_download; \
    snapshot_download(repo_id='YuhaiW/prism-jepa-franka-pusht-v4', \
                      local_dir='./franka_pusht_v4_bundle')"
cd franka_pusht_v4_bundle/
pip install -r requirements.txt

PyTorch ≥ 2.1 + a CUDA GPU recommended. CPU works but plan() is ~ 10× slower.

Quick start: running PRISM-MPPI

import numpy as np
from franka_pusht_lewm_inference import (
    PrismMPPIInference,
    pad_2d_to_6d_franka,
)

# PRISM-MPPI (uses the prior head to bias MPPI's initial action distribution)
planner = PrismMPPIInference(
    lewm_ckpt  = "lewm_pusht_lewm_fr3_v4b.ckpt",
    prior_ckpt = "prior_head_pusht_lewm_fr3_v4b.pt",
    use_prism  = True,
    device     = "cuda",
)

# In your control loop:
obs_uint8  = camera.read_d455_agent_view()    # (224, 224, 3) uint8 RGB
goal_uint8 = goal_image                        # (224, 224, 3) uint8 RGB
actions_2d = planner.plan(obs_uint8, goal_uint8)
# → (5, 2) array of (dx, dy) deltas in meters

for a2d in actions_2d:
    a6d = pad_2d_to_6d_franka(a2d)             # (6,)  dx, dy, 0, 0, 0, 0
    robot.send_delta_ee(a6d)

For vanilla MPPI (no prior, useful for A/B comparison):

planner = PrismMPPIInference(..., use_prism=False, ...)

How MPPI works in this code (~ 4 sentences)

For each plan() call:

  1. Encode the current obs and goal images via the JEPA encoder → latent vectors z_t, z_g (each 192-dim).
  2. Init MPPI distribution as N(0, var_scale²) in StandardScaler-normalized action space. If use_prism=True, PoG-fuse this with the prior head's (μ_p, σ_p) → biased starting distribution.
  3. MPPI loop (30 iterations): sample K=128 candidate action sequences from the current N(μ, σ²); for each, autoregressively roll out the predictor for H=5 plan-steps to get a predicted z_5; compute cost = ‖z_5 - z_g‖²; reweight candidates by softmax(-β · cost); update μ to the importance-weighted mean. σ is frozen across iterations; this is the PRISM-MPPI signature.
  4. Return the first A_block = 5 env-step actions of the converged μ, denormalised back to raw env units (meters).
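
The four steps above can be sketched in plain NumPy. This is an illustrative toy, not the code in franka_pusht_lewm_inference.py: the names pog_fuse, mppi_plan and toy_rollout are hypothetical, "PoG" is assumed to mean a product of diagonal Gaussians, the mapping from the temperature argument to β is not reproduced, and a linear toy model stands in for the JEPA predictor.

```python
import numpy as np

def pog_fuse(mu_a, sigma_a, mu_b, sigma_b):
    """Product of two diagonal Gaussians (assumed meaning of 'PoG-fuse')."""
    precision = 1.0 / sigma_a**2 + 1.0 / sigma_b**2
    mu = (mu_a / sigma_a**2 + mu_b / sigma_b**2) / precision
    return mu, np.sqrt(1.0 / precision)

def mppi_plan(z_t, z_g, rollout, H=5, K=128, n_iters=30, var_scale=1.0,
              beta=1.0, prior=None, rng=None):
    """Return the (H, 2) mean action sequence in normalized units.

    rollout(z, actions) maps a latent and (K, H, 2) actions to (K, latent_dim)
    final latents; here it is a stand-in for the autoregressive predictor.
    prior is an optional (mu_p, sigma_p) pair, each of shape (H, 2).
    """
    rng = np.random.default_rng(0) if rng is None else rng
    mu = np.zeros((H, 2))
    sigma = np.full((H, 2), var_scale)            # N(0, var_scale^2)
    if prior is not None:                         # the use_prism=True path
        mu, sigma = pog_fuse(mu, sigma, *prior)
    for _ in range(n_iters):
        actions = mu + sigma * rng.standard_normal((K, H, 2))
        z_final = rollout(z_t, actions)
        cost = np.sum((z_final - z_g) ** 2, axis=-1)
        weights = np.exp(-beta * (cost - cost.min()))   # stable softmax
        weights /= weights.sum()
        mu = np.einsum("k,kha->ha", weights, actions)   # weighted mean
        # sigma is deliberately not updated (frozen across iterations)
    return mu

def toy_rollout(z, actions):
    """Toy dynamics: the first two latent dims move by the summed actions."""
    z_final = np.repeat(z[None, :], actions.shape[0], axis=0)
    z_final[:, :2] += actions.sum(axis=1)
    return z_final

z_t = np.zeros(4)
z_g = np.array([1.0, -1.0, 0.0, 0.0])
plan = mppi_plan(z_t, z_g, toy_rollout)           # shape (5, 2)
```

In this toy the summed plan moves the latent close to the goal; denormalising the result back to meters (step 4) is left out because it depends on the StandardScaler stored with the prior head.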

If you'd like to tweak: H, K, n_iters, var_scale, temperature, prior_sigma_scale, history_size are all constructor args of PrismMPPIInference. Defaults match the LeWM-sim paper conventions.

Franka FR3 deployment notes

Action format

The model outputs 2-dim actions (dx, dy) in meters per 10 Hz env-step. The 4 other dims (dz, drx, dry, drz) that some Franka teleop pipelines record are floating-point jitter or teleop-controller drift in our dataset and were dropped at the data level; see §"Why drop 4 action dims" below for evidence.

For Franka FR3 controllers that expect 6-dim (dx, dy, dz, drx, dry, drz), pad with zeros:

a2d = planner.plan(obs, goal)[0]
a6d = pad_2d_to_6d_franka(a2d)              # (6,) = [dx, dy, 0, 0, 0, 0]
robot.send_delta_ee(a6d)

Training distribution per env-step:

  • dx range ±0.025 m, std 0.005 m
  • dy range ±0.030 m, std 0.006 m
  • Frame rate: 10 Hz → average speed ~ 5 cm/s

Suggested safety clamps

Even a stable model occasionally outputs actions a few times the training range. Clamp before sending:

ACTION_CLAMP_2D = np.array([0.030, 0.035])   # ~1.2× max(|training|)
def clamp_safety(a):
    return np.clip(a, -ACTION_CLAMP_2D, +ACTION_CLAMP_2D)

actions_2d = clamp_safety(planner.plan(obs, goal))

Also on the hardware side:

  • Workspace bounding box: roughly x ∈ [0.35, 0.81] m, y ∈ [-0.30, 0.37] m, z = 0.099 ± 0.005 m (planar-lock height) in robot base frame
  • Operator e-stop physically reachable
  • First N trials at 0.5× velocity scaling
  • Time-based MAX_STEPS cap (≈ 5 s at 10 Hz)
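
The workspace bounding box from the list above could be enforced in software as well. The following is a minimal sketch, assuming the current end-effector position ee_xy (robot base frame, meters) is available from your robot interface; limit_to_workspace is a hypothetical helper and is not part of the bundle.

```python
import numpy as np

# Approximate planar workspace limits quoted above (robot base frame, meters).
X_LIMITS = (0.35, 0.81)
Y_LIMITS = (-0.30, 0.37)

def limit_to_workspace(ee_xy, a2d):
    """Shrink a planar delta so that ee_xy + delta stays inside the box."""
    ee_xy = np.asarray(ee_xy, dtype=float)
    lower = np.array([X_LIMITS[0], Y_LIMITS[0]])
    upper = np.array([X_LIMITS[1], Y_LIMITS[1]])
    target = np.clip(ee_xy + np.asarray(a2d, dtype=float), lower, upper)
    return target - ee_xy

# A 3 cm step in x from x = 0.80 is cut to 1 cm so the target stays at 0.81.
safe_delta = limit_to_workspace([0.80, 0.0], [0.03, 0.01])
```

This complements the per-step magnitude clamp: the clamp bounds how far one step can move, while the box check bounds where the end effector can end up.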

Receding-horizon control loop

import time

CONTROL_DT = 0.1           # 10 Hz (matches training)
N_EXEC = 5                  # = A_block; execute then replan
MAX_STEPS = 50              # ≈ 5 s safety cap

step = 0
while step < MAX_STEPS:
    obs = preprocess_to_224(camera.read())
    if task_complete(obs, goal_uint8):
        break
    actions_2d = clamp_safety(planner.plan(obs, goal_uint8))
    for a2d in actions_2d[:N_EXEC]:
        a6d = pad_2d_to_6d_franka(a2d)
        robot.send_delta_ee(a6d)
        time.sleep(CONTROL_DT)
        step += 1
        if step >= MAX_STEPS: break
robot.move_to_home()

Diagnostic numbers: full honesty

We evaluate three concentric tests:

                  encoder         predictor       MPPI usability
                  health          quality         (the deploy gate)
                  ─────────       ──────────      ───────────────
  Test           LOEO R²          pred/id ratio   cost-surface CV
  Sim ref        n/a              0.79 (h=1)      0.13
  v3 baseline    +0.03            0.97 (h=1)      0.13
  v4b (this)     (not measured)   (not measured)  ★ 0.29 ★

Cost-surface CV in detail

Tested on 12 random (z_t, z_g) training pairs, K=256 random action candidates each. The CV (std(cost) / mean(cost)) measures how discriminative the cost surface is across action candidates:

                CV mean   CV max    GT rank pct
  v3 baseline   0.132     0.419     33 %
  v4b (this)    0.286     0.712     44 %
  sim PushT     0.133     0.367     59 %
  PASS gate     ≥ 0.30    n/a       ≤ 25 %

v4b's CV (0.29) is 2.17× that of LeWM-default training (0.13) on the same data, the largest movement we've seen from any single technique. It exceeds the sim PushT reference (0.13). But it falls 0.014 short of our 0.30 PASS threshold, and GT_rank_pct worsened (33 → 44 %), meaning the predictor's cost ranks ground-truth-like actions only slightly better than random.
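
For readers who want to reproduce these two statistics on their own rollouts, here is a minimal sketch. cost_surface_cv follows the std(cost) / mean(cost) definition above, with the population standard deviation assumed. gt_rank_pct is our guess at the definition (the percentage of random candidates that score a lower cost than the ground-truth action, so 0 % is best and about 50 % is chance); both function names are hypothetical.

```python
import numpy as np

def cost_surface_cv(costs):
    """Coefficient of variation of the costs over K action candidates."""
    costs = np.asarray(costs, dtype=float)
    return costs.std() / costs.mean()

def gt_rank_pct(candidate_costs, gt_cost):
    """Percentage of candidates whose cost is lower than the ground truth's."""
    candidate_costs = np.asarray(candidate_costs, dtype=float)
    return 100.0 * np.mean(candidate_costs < gt_cost)

flat_cv = cost_surface_cv([2.0, 2.0, 2.0, 2.0])      # flat surface, CV = 0
spread_cv = cost_surface_cv([1.0, 2.0, 3.0, 4.0])    # more discriminative
rank = gt_rank_pct([1.0, 2.0, 3.0, 4.0], gt_cost=2.5)
```

A flat cost surface gives a CV of 0, which leaves MPPI's softmax reweighting with nothing to prefer; a higher CV only helps if the low-cost region also contains the ground-truth action, which is what the rank percentile checks.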

Test A: in-distribution sanity (deploy gate)

Feed training-distribution (obs, goal) pairs to the planner; compare output to ground-truth teleop action.

                       pred |a|     ratio       cos(pred, truth)
  truth (reference)    6.5 mm       1.00×       +1.000
  v3 baseline          33 mm        5.08×       −0.05
  v4b PRISM H=5        26 mm        3.93×       +0.10
  v4b vanilla H=5      32 mm        4.92×       +0.10

v4b is the closest to truth across all configs we've tried: a 23 % lower ratio than the v3 baseline, with the direction cosine flipping sign from -0.05 to +0.10. That is still well short of the PASS gate (≤ 2× ratio, ≥ +0.3 cos), but it is clear directional progress.
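
The two Test A columns can be computed along these lines. This is a sketch under assumptions: the ratio is taken as mean predicted magnitude over mean ground-truth magnitude, and the cosine is averaged over pairs; the exact aggregation used for the table above may differ, and magnitude_ratio_and_cosine is a hypothetical name.

```python
import numpy as np

def magnitude_ratio_and_cosine(pred, truth, eps=1e-12):
    """Compare planned actions with teleop actions, both of shape (N, 2)."""
    pred = np.asarray(pred, dtype=float)
    truth = np.asarray(truth, dtype=float)
    pred_norm = np.linalg.norm(pred, axis=-1)
    truth_norm = np.linalg.norm(truth, axis=-1)
    ratio = pred_norm.mean() / truth_norm.mean()
    cosine = np.sum(pred * truth, axis=-1) / (pred_norm * truth_norm + eps)
    return ratio, cosine.mean()

# Predictions twice as large as the truth but pointing the same way.
ratio, cos = magnitude_ratio_and_cosine(
    pred=[[0.02, 0.0], [0.0, 0.02]],
    truth=[[0.01, 0.0], [0.0, 0.01]],
)
```

The two numbers are kept separate on purpose: the ratio flags over-sized actions (a safety concern on hardware), while the cosine flags actions that point the wrong way regardless of size.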

Honest summary

  Axis                  v3 → v4b                               PASS?
  Cost-surface CV       0.13 → 0.29 (+118 %)                   ✗ marginal (need 0.30)
  Test A ratio          5.08× → 3.93× (-23 %)                  ✗ (need ≤ 2×)
  Test A cos            −0.05 → +0.10                          ✗ (need ≥ +0.3)
  Encoder collapse      healthy                                ✓
  Predictor stability   epochs 1–2 unstable, recovered by 18   ✓ (with our hyperparams)

What this experiment shows

  1. Multi-step autoregressive supervision genuinely improves LeWM-MPPI's cost-surface discriminativity on small real-robot data. The 2.17× CV improvement is the largest single-technique improvement we measured across v2, v3, and v4b. It's the technique LeWM-default skipped (and didn't need on 2.34 M sim frames).

  2. AR supervision is fragile on a JEPA + SIGReg + small-data setup. Our first attempt (num_preds=5, lr=5e-5, dropout=0.1) silently crashed at epoch 3 with diverging val/sigreg. The conservative recipe that works in this release: num_preds=3, lr=2e-5, predictor.dropout=0.0, full BN running stats, stop-gradient between AR rollout steps.

  3. The gap from "discriminative cost surface" to "deploy-ready" is non-trivial. Even with CV 2× sim's, our GT_rank_pct regressed (33 → 44 %): the cost surface gets sharper but doesn't point more strongly at the ground-truth action. Likely fixes (not in this release): stronger / deeper prior head, more demos, scheduled-sampling schedule on the autoregressive rollout.
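
One plausible way to write the multi-step objective from point 2 is given below. This is our reading of "3-step autoregressive supervision with stop-gradient between steps", not a transcription of the training code: f_θ denotes the predictor, z_{t+k} the encoder latent of frame t+k, sg(·) the stop-gradient operator, and N = num_preds = 3. The SIGReg term and any frame history fed to the predictor are left out for brevity.

```latex
\hat{z}_{t+1} = f_\theta(z_t,\, a_t), \qquad
\hat{z}_{t+k} = f_\theta\big(\operatorname{sg}(\hat{z}_{t+k-1}),\, a_{t+k-1}\big), \quad k = 2, \dots, N,
\qquad
\mathcal{L}_{\text{pred}} = \frac{1}{N} \sum_{k=1}^{N} \big\lVert \hat{z}_{t+k} - z_{t+k} \big\rVert^{2}
```

With N = 1 this reduces to the 1-step teacher-forced loss of LeWM-default. For N > 1 the predictor is also scored on inputs it produced itself, which is the situation it faces during the H = 5 MPPI rollout.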

v4 vs v3 vs v2

                       v2 (negative case)                                                   v3 (LeWM-default scale-up)   v4 (multi-step)
  Source repo          YuhaiW/prism-jepa-franka-pusht-v2 (private)                          (not separately released)    this repo (public)
  Demos                50                                                                   200                          200
  Frames               18 k                                                                 51 k                         51 k
  Predictor training   1-step teacher-forced + small predictor + manual normalization bug   LeWM-default 1-step          multi-step AR, num_preds=3
  Cost-surface CV      n/a                                                                  0.13                         0.29
  Deploy status        known-broken negative case                                           wandering on real robot      research preview, deploy-marginal

Why drop 4 action dims

The recorded 6-dim Franka teleop action has only 2 meaningful dims:

  dim   std      std / dx_std   interpretation
  dx    0.0049   1.00           signal
  dy    0.0058   1.17           signal
  dz    0.0002   0.04           floating-point jitter
  drx   0.0026   0.49           Quest controller drift
  dry   0.0032   0.65           Quest controller drift
  drz   0.0012   0.25           Quest controller drift

dz is essentially zero (noise only); drx/dry/drz are teleop tracking artefacts. After per-dim StandardScaler normalization they would all become unit-std and look "equally important" to the model. Dropping them at the data level lets the model focus on the actual signal.
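
The scaling effect is easy to see on synthetic data. The sketch below draws Gaussian stand-in actions with the per-dim stds from the table above (not the real dataset) and applies the same per-dimension standardization a StandardScaler performs.

```python
import numpy as np

rng = np.random.default_rng(0)
# Per-dim stds from the table above, in order dx, dy, dz, drx, dry, drz.
raw_std = np.array([0.0049, 0.0058, 0.0002, 0.0026, 0.0032, 0.0012])
actions = rng.standard_normal((10_000, 6)) * raw_std   # synthetic stand-in

# Per-dimension standardization: subtract the mean, divide by the std.
scaled = (actions - actions.mean(axis=0)) / actions.std(axis=0)
scaled_std = scaled.std(axis=0)   # all six dims now have unit std

# Dropping the four non-signal dims before scaling keeps only dx, dy.
actions_2d = actions[:, :2]
```

After scaling, the dz jitter has the same spread as dx and dy, so the model has no way to tell from magnitude alone which dims carry signal.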

Caveats

  • Not validated on real-robot SR. The release is our best on offline diagnostics, not a demonstration of closed-loop success.
  • Single seed. Multi-seed variance unmeasured.
  • Encoder is ViT-tiny from scratch on 51 k frames, which is known to be sensitive to deploy-scene distribution shift (lighting, table colour, camera intrinsics). For best results, match deploy scene to training scene as closely as possible.
  • PRISM prior is moderate-quality. Val MSE drop 14.7 % (marginal vs 15 % gate). A deeper prior head would likely help further.

Citation / context

If you use this ckpt in research, please cite the LeWM paper for the underlying architecture and the multi-step autoregressive training trick from world-model literature (Dreamer / TD-MPC / PlaNet style):

  • LeWM: LeWorldModel: First Stable JEPA from Pixels. HF: papers/2603.19312
  • TD-MPC / TD-MPC2: Hansen et al., for the multi-horizon AR loss style
  • Dreamer / DreamerV3: Hafner et al., for imagined rollouts in training
  • This release is part of the PRISM-JEPA project (paper in preparation).

License

apache-2.0
