red_cube PRISM-JEPA deploy bundle: LeWM WM + PRISM prior + self-contained inference + README
Browse files- README.md +110 -0
- arx_inference_demo.py +333 -0
- jepa.py +153 -0
- lewm_red_cube_epoch_100_object.ckpt +3 -0
- module.py +306 -0
- prior_head.py +73 -0
- prior_head_red_cube.pt +3 -0
README.md
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
library_name: pytorch
|
| 4 |
+
tags:
|
| 5 |
+
- robotics
|
| 6 |
+
- world-model
|
| 7 |
+
- jepa
|
| 8 |
+
- lewm
|
| 9 |
+
- prism
|
| 10 |
+
- arx
|
| 11 |
+
- visuomotor
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# PRISM-JEPA · red_cube (ARX-X5) — LeWM world model + PRISM action prior
|
| 15 |
+
|
| 16 |
+
Complete deployable stack for goal-conditioned visuomotor planning on the ARX-X5
|
| 17 |
+
**"red cube"** task: the **LeWM JEPA world model** + the **PRISM goal-conditioned action
|
| 18 |
+
prior** + **self-contained PRISM-MPPI inference code**.
|
| 19 |
+
|
| 20 |
+
## ⚠️ Status — read first
|
| 21 |
+
|
| 22 |
+
This is a **research artifact / deployment hand-off package, NOT a validated policy.**
|
| 23 |
+
|
| 24 |
+
- Trained on 201 teleop demos (`Xia-2004/red_cube`, 42,165 frames, 5-DoF).
|
| 25 |
+
- The world model **converges cleanly and its forward model is accurate** (rollout pred/id
|
| 26 |
+
≈ 0.25 < 0.5), **but its MPPI cost surface is weak/flat** on this small-real-robot,
|
| 27 |
+
small-action data (CV ≈ **0.14** ≪ 0.30 "discriminative" threshold). Consequence: the
|
| 28 |
+
planner's distinctive **cost-rescoring is dormant**, so in practice **PRISM ≈ a
|
| 29 |
+
goal-conditioned BC prior** — in offline A/B it produces ~31% more expert-like actions
|
| 30 |
+
than vanilla LeWM-MPPI (which wanders), but adds **no measurable goal-progress** in the
|
| 31 |
+
world model's own latent metric (paired t-test p = 0.57).
|
| 32 |
+
- **Never run on the real robot.** Treat as a starting point; add workspace/velocity safety
|
| 33 |
+
limits and validate before any hardware run.
|
| 34 |
+
- Full analysis & how these numbers were obtained: project doc
|
| 35 |
+
`docs/30_red_cube_cv_investigation_and_prism.md`.
|
| 36 |
+
|
| 37 |
+
## Contents
|
| 38 |
+
|
| 39 |
+
| file | description |
|
| 40 |
+
|---|---|
|
| 41 |
+
| `lewm_red_cube_epoch_100_object.ckpt` | LeWM world model — pickled JEPA: ViT-tiny encoder + AR transformer predictor + action encoder (~18M params) |
|
| 42 |
+
| `prior_head_red_cube.pt` | PRISM goal-conditioned action prior — `state_dict` + `config` + action `StandardScaler` (mean/scale) |
|
| 43 |
+
| `arx_inference_demo.py` | self-contained `PrismMPPIInference` (PoG-fused PRISM-MPPI; `use_prism=False` → vanilla LeWM-MPPI) |
|
| 44 |
+
| `jepa.py`, `module.py`, `prior_head.py` | model classes required to unpickle the ckpt and run the prior |
|
| 45 |
+
|
| 46 |
+
## Observation / action space
|
| 47 |
+
|
| 48 |
+
- **Observation:** single top-down RGB frame, **224×224×3 uint8** (RealSense `camera_third`).
|
| 49 |
+
- **Goal:** an RGB goal image, same format (the prior + cost are conditioned on it).
|
| 50 |
+
- **Action:** 5-DoF delta end-effector **`[dx, dy, dz, dyaw, d_gripper]`**, raw units, one per
|
| 51 |
+
control tick. `plan()` returns one plan-step = `A_block = 5` ticks → shape **`(5, 5)`**.
|
| 52 |
+
|
| 53 |
+
## Dependencies
|
| 54 |
+
|
| 55 |
+
`torch`, `numpy`, `einops`, and `transformers` (the encoder inside the ckpt is a HuggingFace
|
| 56 |
+
ViT, needed at unpickle time). The three bundled `.py` files must be importable from the
|
| 57 |
+
working directory. (If unpickling complains about a missing class, also `pip install
|
| 58 |
+
stable-pretraining`.)
|
| 59 |
+
|
| 60 |
+
## Deploy — receding-horizon control loop
|
| 61 |
+
|
| 62 |
+
```python
|
| 63 |
+
from arx_inference_demo import PrismMPPIInference
|
| 64 |
+
|
| 65 |
+
planner = PrismMPPIInference(
|
| 66 |
+
lewm_ckpt = "lewm_red_cube_epoch_100_object.ckpt",
|
| 67 |
+
prior_ckpt = "prior_head_red_cube.pt",
|
| 68 |
+
use_prism = True, # True = PRISM (prior ⊗ MPPI via PoG fusion); False = vanilla LeWM-MPPI
|
| 69 |
+
device = "cuda",
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
goal_img = load_goal_image() # (224,224,3) uint8 — the task goal image
|
| 73 |
+
while not done:
|
| 74 |
+
obs = camera.read() # (224,224,3) uint8, top-down camera_third view
|
| 75 |
+
actions = planner.plan(obs, goal_img) # (5, 5) raw [dx,dy,dz,dyaw,d_gripper]
|
| 76 |
+
for a in actions: # receding horizon: execute the block, then replan
|
| 77 |
+
robot.execute(a) # (or execute fewer than 5 and replan more often)
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
`plan()` runs one full PRISM-MPPI optimization and returns the first `A_block = 5` env-step
|
| 81 |
+
actions of the optimized plan, in **raw action units** (already de-normalized).
|
| 82 |
+
|
| 83 |
+
## Key hyperparameters (`PrismMPPIInference` constructor)
|
| 84 |
+
|
| 85 |
+
| arg | default | meaning |
|
| 86 |
+
|---|---|---|
|
| 87 |
+
| `H` | 5 | planning horizon (plan-steps) |
|
| 88 |
+
| `A_block` | 5 | env-steps (ticks) per plan-step ("frameskip") |
|
| 89 |
+
| `K` | 128 | MPPI samples per iteration |
|
| 90 |
+
| `n_iters` | 30 | MPPI refinement iterations |
|
| 91 |
+
| `var_scale` | 1.0 | initial planner sampling std |
|
| 92 |
+
| `prior_sigma_scale` | 2.0 | multiplier on the prior σ before PoG fusion (PRISM only) |
|
| 93 |
+
| `temperature` | 0.5 | MPPI softmax temperature |
|
| 94 |
+
| `history_size` | 3 | LeWM history-window length (**must match training**) |
|
| 95 |
+
|
| 96 |
+
`H`, `A_block`, `A_raw`, `history_size` must match the checkpoints — the constructor asserts
|
| 97 |
+
the prior head's config agrees. Change them only if you retrain.
|
| 98 |
+
|
| 99 |
+
## PRISM vs vanilla (A/B)
|
| 100 |
+
|
| 101 |
+
Build a second planner with `use_prism=False` for a baseline (plain LeWM-MPPI, no prior,
|
| 102 |
+
same encoder/predictor/MPPI loop). On this task PRISM produces more expert-like actions;
|
| 103 |
+
vanilla tends to wander because the cost surface is flat.
|
| 104 |
+
|
| 105 |
+
## Provenance
|
| 106 |
+
|
| 107 |
+
Data: [`Xia-2004/red_cube`](https://huggingface.co/datasets/Xia-2004/red_cube) (ARX-X5
|
| 108 |
+
left-arm teleop). Sibling of `Xia-2004/arx-left-cube`. World-model architecture is identical
|
| 109 |
+
to the sim LeWM (ViT-tiny, embed_dim 192, predictor depth 6 / heads 16) — part of the
|
| 110 |
+
PRISM-JEPA project (sister of Newt-PRISM, CoRL 2026).
|
arx_inference_demo.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""arx_inference_demo.py — standalone PRISM-MPPI inference for ARX cube task.
|
| 2 |
+
|
| 3 |
+
This file is **self-contained**: it depends only on the bundled
|
| 4 |
+
`jepa.py`, `module.py`, `prior_head.py`, plus standard torch / numpy.
|
| 5 |
+
No `stable_worldmodel` import — the MPPI loop is re-implemented inline.
|
| 6 |
+
|
| 7 |
+
Intended use by a downstream consumer (e.g., the ARX deployment side):
|
| 8 |
+
|
| 9 |
+
from arx_inference_demo import PrismMPPIInference
|
| 10 |
+
|
| 11 |
+
planner = PrismMPPIInference(
|
| 12 |
+
lewm_ckpt = "lewm_arx.ckpt",
|
| 13 |
+
prior_ckpt = "prior_head_arx.pt",
|
| 14 |
+
device = "cuda",
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
# In the control loop:
|
| 18 |
+
while not done:
|
| 19 |
+
obs_uint8 = camera.read() # (224, 224, 3) uint8 RGB
|
| 20 |
+
goal_uint8 = goal_image # (224, 224, 3) uint8 RGB
|
| 21 |
+
actions = planner.plan(obs_uint8, goal_uint8)
|
| 22 |
+
# → (A_block, 5) float32, raw action units
|
| 23 |
+
for a in actions:
|
| 24 |
+
robot.execute(a) # step the robot
|
| 25 |
+
|
| 26 |
+
`plan()` performs one full PRISM-MPPI optimization and returns the first
|
| 27 |
+
A_block = 5 env-step actions of the optimized plan. The caller may choose
|
| 28 |
+
to execute all 5 then replan (receding-horizon, k=A_block), or execute
|
| 29 |
+
fewer and replan more often.
|
| 30 |
+
|
| 31 |
+
PRISM-MPPI summary:
|
| 32 |
+
1. JEPA encoder turns current obs + goal image into latent embeddings z_t, z_g.
|
| 33 |
+
2. PRISM prior head maps (z_t, z_g) → (μ_p, σ_p) over the next
|
| 34 |
+
H × A_block × A_raw normalized actions.
|
| 35 |
+
3. We seed an MPPI distribution N(0, var_scale I) and PoG-fuse with the
|
| 36 |
+
prior to get N(fused_μ, fused_σ²). The variance is FROZEN through MPPI
|
| 37 |
+
iterations (this is the PRISM-MPPI signature; see paper §3).
|
| 38 |
+
4. Each iteration samples K candidate action sequences, rolls them out via
|
| 39 |
+
the LeWM ARPredictor in latent space, computes cost = MSE(predicted
|
| 40 |
+
final z, z_g), reweights candidates by exp(-β·cost), updates the mean.
|
| 41 |
+
5. After n_iters iterations, the first A_block entries of the mean are
|
| 42 |
+
returned (denormalized to raw env action units via the saved
|
| 43 |
+
StandardScaler).
|
| 44 |
+
"""
|
| 45 |
+
from __future__ import annotations
|
| 46 |
+
|
| 47 |
+
from pathlib import Path
|
| 48 |
+
|
| 49 |
+
import numpy as np
|
| 50 |
+
import torch
|
| 51 |
+
import torch.nn.functional as F
|
| 52 |
+
|
| 53 |
+
# Required for unpickling the LeWM ckpt — these modules must be importable
|
| 54 |
+
import jepa # noqa: F401 — registers JEPA class
|
| 55 |
+
import module # noqa: F401 — registers ARPredictor, Embedder, etc.
|
| 56 |
+
from prior_head import PriorHead
|
| 57 |
+
|
| 58 |
+
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
| 59 |
+
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _preprocess(img_uint8: np.ndarray, device: torch.device) -> torch.Tensor:
|
| 63 |
+
"""uint8 (H, W, 3) → float (1, 3, 224, 224), ImageNet-normalized."""
|
| 64 |
+
assert img_uint8.shape == (224, 224, 3), \
|
| 65 |
+
f"Expected (224, 224, 3) image, got {img_uint8.shape}"
|
| 66 |
+
t = torch.from_numpy(img_uint8).permute(2, 0, 1).float().div(255.0).unsqueeze(0)
|
| 67 |
+
t = t.to(device)
|
| 68 |
+
mean = IMAGENET_MEAN.to(device)
|
| 69 |
+
std = IMAGENET_STD.to(device)
|
| 70 |
+
return (t - mean) / std
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _pog_fusion(mean, std, mu_p, sg_p, sigma_floor=0.05):
|
| 74 |
+
"""Product-of-Gaussians fusion. Matches prism_mppi.pog_fusion."""
|
| 75 |
+
eps = 1e-8
|
| 76 |
+
tau_base = 1.0 / (std ** 2 + eps)
|
| 77 |
+
tau_p = 1.0 / (sg_p ** 2 + eps)
|
| 78 |
+
tau_c = tau_base + tau_p
|
| 79 |
+
fused_mean = (tau_base * mean + tau_p * mu_p) / tau_c
|
| 80 |
+
fused_std = (1.0 / tau_c).sqrt().clamp(min=sigma_floor)
|
| 81 |
+
return fused_mean, fused_std
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class PrismMPPIInference:
|
| 85 |
+
"""Standalone PRISM-MPPI planner for ARX cube task.
|
| 86 |
+
|
| 87 |
+
Supports two modes via the `use_prism` constructor flag — kept on a
|
| 88 |
+
single class so that PRISM and vanilla-MPPI A/B comparisons use the
|
| 89 |
+
exact same encoder, predictor, MPPI loop, and StandardScaler. The
|
| 90 |
+
only difference between the two modes is whether the PoG fusion at
|
| 91 |
+
init time uses the prior head's (μ_p, σ_p) or not. The prior-head
|
| 92 |
+
checkpoint is always loaded — its StandardScaler (action
|
| 93 |
+
normalization) is shared by both modes so the comparison is
|
| 94 |
+
apples-to-apples in raw action units.
|
| 95 |
+
|
| 96 |
+
Args (paper defaults — change only if you know what you're doing):
|
| 97 |
+
lewm_ckpt: path to lewm_arx.ckpt (pickled JEPA module)
|
| 98 |
+
prior_ckpt: path to prior_head_arx.pt (PRISM head state_dict + scaler)
|
| 99 |
+
use_prism: if True (default), inject the PRISM prior via PoG fusion.
|
| 100 |
+
If False, skip the prior — the planner becomes vanilla
|
| 101 |
+
LeWM-MPPI from N(0, var_scale) seed. Use this for paper-
|
| 102 |
+
grade real-robot A/B against PRISM-MPPI.
|
| 103 |
+
H: planning horizon in plan-steps (default 5)
|
| 104 |
+
A_block: env-steps per plan-step (default 5, "frameskip")
|
| 105 |
+
K: num MPPI samples per iteration (default 128)
|
| 106 |
+
n_iters: num MPPI refinement iterations (default 30)
|
| 107 |
+
var_scale: initial planner std (default 1.0)
|
| 108 |
+
temperature: MPPI softmax temperature β = 1/temperature (default 0.5)
|
| 109 |
+
sigma_floor: lower bound on fused σ (default 0.05); only used by PRISM
|
| 110 |
+
prior_sigma_scale: multiplier on prior σ_p before fusion (default 2.0,
|
| 111 |
+
matches the paper's PRISM-MPPI s=2.0 setting); only used by PRISM
|
| 112 |
+
history_size: LeWM history-window length (default 3; must match training)
|
| 113 |
+
device: 'cuda' or 'cpu'
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
def __init__(
|
| 117 |
+
self,
|
| 118 |
+
lewm_ckpt: str | Path,
|
| 119 |
+
prior_ckpt: str | Path,
|
| 120 |
+
use_prism: bool = True,
|
| 121 |
+
H: int = 5,
|
| 122 |
+
A_block: int = 5,
|
| 123 |
+
K: int = 128,
|
| 124 |
+
n_iters: int = 30,
|
| 125 |
+
var_scale: float = 1.0,
|
| 126 |
+
temperature: float = 0.5,
|
| 127 |
+
sigma_floor: float = 0.05,
|
| 128 |
+
prior_sigma_scale: float = 2.0,
|
| 129 |
+
history_size: int = 3,
|
| 130 |
+
device: str = "cuda",
|
| 131 |
+
):
|
| 132 |
+
self.device = torch.device(device)
|
| 133 |
+
self.use_prism = bool(use_prism)
|
| 134 |
+
self.H = H
|
| 135 |
+
self.A_block = A_block
|
| 136 |
+
self.K = K
|
| 137 |
+
self.n_iters = n_iters
|
| 138 |
+
self.var_scale = var_scale
|
| 139 |
+
self.beta = 1.0 / temperature
|
| 140 |
+
self.sigma_floor = sigma_floor
|
| 141 |
+
self.prior_sigma_scale = prior_sigma_scale
|
| 142 |
+
self.history_size = history_size
|
| 143 |
+
|
| 144 |
+
# ---- Load LeWM (encoder + AR predictor, pickled) ----
|
| 145 |
+
print(f"[init] loading LeWM ckpt: {lewm_ckpt}")
|
| 146 |
+
self.lewm = torch.load(
|
| 147 |
+
str(lewm_ckpt), map_location=self.device, weights_only=False,
|
| 148 |
+
)
|
| 149 |
+
self.lewm.to(self.device).eval()
|
| 150 |
+
for p in self.lewm.parameters():
|
| 151 |
+
p.requires_grad_(False)
|
| 152 |
+
|
| 153 |
+
# ---- Load PRISM prior head + scaler (scaler always used; head conditionally) ----
|
| 154 |
+
print(f"[init] loading prior head + scaler: {prior_ckpt}")
|
| 155 |
+
pck = torch.load(str(prior_ckpt), map_location=self.device, weights_only=False)
|
| 156 |
+
cfg = pck["config"]
|
| 157 |
+
self.A_raw = int(cfg["A_raw"])
|
| 158 |
+
assert cfg["H"] == self.H and cfg["A_block"] == self.A_block, (
|
| 159 |
+
f"Ckpt config mismatch: H={cfg['H']} A_block={cfg['A_block']} "
|
| 160 |
+
f"vs runtime H={self.H} A_block={self.A_block}"
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
if self.use_prism:
|
| 164 |
+
self.head = PriorHead(**cfg).to(self.device).eval()
|
| 165 |
+
self.head.load_state_dict(pck["state_dict"])
|
| 166 |
+
for p in self.head.parameters():
|
| 167 |
+
p.requires_grad_(False)
|
| 168 |
+
else:
|
| 169 |
+
self.head = None # vanilla LeWM-MPPI mode: skip PoG fusion
|
| 170 |
+
|
| 171 |
+
# Action denormalization (raw_action = norm_action * scale + mean) — always loaded
|
| 172 |
+
self.scaler_mean = torch.tensor(pck["scaler_mean"], device=self.device).float()
|
| 173 |
+
self.scaler_scale = torch.tensor(pck["scaler_scale"], device=self.device).float()
|
| 174 |
+
mode_str = "PRISM-MPPI" if self.use_prism else "vanilla LeWM-MPPI (PRISM off)"
|
| 175 |
+
print(f"[init] mode = {mode_str}")
|
| 176 |
+
print(f"[init] z_dim={cfg['z_dim']} H={self.H} A_block={self.A_block} "
|
| 177 |
+
f"A_raw={self.A_raw}")
|
| 178 |
+
print(f"[init] device={self.device} K={self.K} n_iters={self.n_iters}")
|
| 179 |
+
|
| 180 |
+
@torch.no_grad()
|
| 181 |
+
def _encode(self, img_uint8: np.ndarray) -> torch.Tensor:
|
| 182 |
+
"""uint8 image → (1, D) CLS embedding."""
|
| 183 |
+
x = _preprocess(img_uint8, self.device)
|
| 184 |
+
# JEPA.encode expects a dict with 'pixels' shape (B, T, C, H, W)
|
| 185 |
+
info = {"pixels": x.unsqueeze(1)} # add T=1 dim
|
| 186 |
+
info = self.lewm.encode(info)
|
| 187 |
+
return info["emb"][:, 0] # (1, D)
|
| 188 |
+
|
| 189 |
+
@torch.no_grad()
|
| 190 |
+
def _prior(self, z_t: torch.Tensor, z_g: torch.Tensor):
|
| 191 |
+
"""PRISM head: (1, D), (1, D) → (μ, σ) of shape (1, H, A_block, A_raw)
|
| 192 |
+
in normalized action space."""
|
| 193 |
+
return self.head(z_t, z_g)
|
| 194 |
+
|
| 195 |
+
@torch.no_grad()
|
| 196 |
+
def _rollout_costs(
|
| 197 |
+
self,
|
| 198 |
+
z_t: torch.Tensor, # (1, D)
|
| 199 |
+
z_g: torch.Tensor, # (1, D)
|
| 200 |
+
action_candidates: torch.Tensor, # (1, K, H*A_block, A_raw) normalized
|
| 201 |
+
) -> torch.Tensor: # (1, K) cost per candidate
|
| 202 |
+
"""Rollout each candidate via LeWM AR predictor, compute final-z MSE to z_g."""
|
| 203 |
+
B, K, T_total, A = action_candidates.shape
|
| 204 |
+
assert T_total == self.H * self.A_block
|
| 205 |
+
D = z_t.shape[-1]
|
| 206 |
+
HS = self.history_size
|
| 207 |
+
|
| 208 |
+
# Seed embedding history with the current z_t (tile to HS length)
|
| 209 |
+
# emb: (B*K, HS, D)
|
| 210 |
+
emb = z_t.unsqueeze(1).expand(B, K, D).reshape(B * K, D)
|
| 211 |
+
emb = emb.unsqueeze(1).expand(-1, HS, -1).contiguous()
|
| 212 |
+
|
| 213 |
+
# action_seq: (B*K, T_total, A) — env-step actions; predictor consumes them block-by-block
|
| 214 |
+
act_seq = action_candidates.reshape(B * K, T_total, A)
|
| 215 |
+
|
| 216 |
+
# Group actions into plan-steps of A_block: (B*K, H, A_block * A)
|
| 217 |
+
act_plan = act_seq.reshape(B * K, self.H, self.A_block * A)
|
| 218 |
+
|
| 219 |
+
# Embed actions via the predictor's action_encoder (Embedder)
|
| 220 |
+
# act_emb: (B*K, H, action_emb_dim)
|
| 221 |
+
act_emb = self.lewm.action_encoder(act_plan)
|
| 222 |
+
|
| 223 |
+
# AR rollout
|
| 224 |
+
for t in range(self.H):
|
| 225 |
+
emb_trunc = emb[:, -HS:] # (B*K, HS, D)
|
| 226 |
+
act_trunc = act_emb[:, max(0, t - HS + 1): t + 1] # last HS actions seen
|
| 227 |
+
# Pad on the left if we don't have HS history of actions yet
|
| 228 |
+
if act_trunc.shape[1] < HS:
|
| 229 |
+
pad = act_trunc[:, :1].expand(-1, HS - act_trunc.shape[1], -1)
|
| 230 |
+
act_trunc = torch.cat([pad, act_trunc], dim=1)
|
| 231 |
+
pred = self.lewm.predict(emb_trunc, act_trunc)[:, -1:] # (B*K, 1, D)
|
| 232 |
+
emb = torch.cat([emb, pred], dim=1)
|
| 233 |
+
|
| 234 |
+
# Final predicted embedding: emb[:, -1]
|
| 235 |
+
pred_final = emb[:, -1] # (B*K, D)
|
| 236 |
+
goal = z_g.unsqueeze(1).expand(B, K, D).reshape(B * K, D)
|
| 237 |
+
cost = F.mse_loss(pred_final, goal, reduction="none").sum(dim=-1) # (B*K,)
|
| 238 |
+
return cost.reshape(B, K)
|
| 239 |
+
|
| 240 |
+
@torch.no_grad()
|
| 241 |
+
def plan(self, obs_uint8: np.ndarray, goal_uint8: np.ndarray) -> np.ndarray:
|
| 242 |
+
"""One MPPI optimization (PRISM or vanilla depending on `use_prism`).
|
| 243 |
+
|
| 244 |
+
Returns (A_block, A_raw) actions in raw env units.
|
| 245 |
+
"""
|
| 246 |
+
# 1. Encode
|
| 247 |
+
z_t = self._encode(obs_uint8) # (1, D)
|
| 248 |
+
z_g = self._encode(goal_uint8) # (1, D)
|
| 249 |
+
|
| 250 |
+
# 2. Init MPPI distribution N(0, var_scale)
|
| 251 |
+
shape = (1, self.H * self.A_block, self.A_raw)
|
| 252 |
+
mean = torch.zeros(shape, device=self.device)
|
| 253 |
+
std = torch.full(shape, self.var_scale, device=self.device)
|
| 254 |
+
|
| 255 |
+
# 3. (PRISM only) prior in normalized action space + PoG fusion
|
| 256 |
+
if self.use_prism:
|
| 257 |
+
mu_p, sg_p = self._prior(z_t, z_g) # (1, H, A_block, A_raw)
|
| 258 |
+
mu_p_flat = mu_p.reshape(*shape)
|
| 259 |
+
sg_p_flat = sg_p.reshape(*shape) * self.prior_sigma_scale
|
| 260 |
+
mean, std = _pog_fusion(mean, std, mu_p_flat, sg_p_flat, self.sigma_floor)
|
| 261 |
+
|
| 262 |
+
# 4. MPPI iterations (frozen σ — PRISM-MPPI signature when use_prism=True;
|
| 263 |
+
# matches stable_worldmodel.solver.MPPISolver default when use_prism=False)
|
| 264 |
+
for it in range(self.n_iters):
|
| 265 |
+
noise = torch.randn(
|
| 266 |
+
1, self.K, self.H * self.A_block, self.A_raw, device=self.device,
|
| 267 |
+
)
|
| 268 |
+
cands = mean.unsqueeze(1) + noise * std.unsqueeze(1)
|
| 269 |
+
# cands: (1, K, H*A_block, A_raw)
|
| 270 |
+
|
| 271 |
+
cost = self._rollout_costs(z_t, z_g, cands) # (1, K)
|
| 272 |
+
log_w = -self.beta * (cost - cost.min(dim=-1, keepdim=True).values)
|
| 273 |
+
w = torch.softmax(log_w, dim=-1) # (1, K)
|
| 274 |
+
|
| 275 |
+
# Importance-weighted mean update; std FROZEN (PRISM-MPPI)
|
| 276 |
+
mean = (w.unsqueeze(-1).unsqueeze(-1) * cands).sum(dim=1)
|
| 277 |
+
# mean: (1, H*A_block, A_raw)
|
| 278 |
+
|
| 279 |
+
# 5. First A_block actions, denormalized
|
| 280 |
+
first_block_norm = mean[0, : self.A_block] # (A_block, A_raw)
|
| 281 |
+
first_block_raw = first_block_norm * self.scaler_scale + self.scaler_mean
|
| 282 |
+
return first_block_raw.cpu().numpy().astype(np.float32)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
# ===========================================================================
|
| 286 |
+
# Sanity test: load + plan on a sample from the ARX h5
|
| 287 |
+
# ===========================================================================
|
| 288 |
+
|
| 289 |
+
if __name__ == "__main__":
|
| 290 |
+
import argparse
|
| 291 |
+
|
| 292 |
+
ap = argparse.ArgumentParser()
|
| 293 |
+
ap.add_argument(
|
| 294 |
+
"--lewm-ckpt", default=".stable-wm/lewm_arx_epoch_100_object.ckpt",
|
| 295 |
+
)
|
| 296 |
+
ap.add_argument("--prior-ckpt", default="prior_head_arx.pt")
|
| 297 |
+
ap.add_argument("--h5", default=".stable-wm/arx_left_cube.h5")
|
| 298 |
+
ap.add_argument("--seed", type=int, default=0)
|
| 299 |
+
ap.add_argument("--no-prism", action="store_true",
|
| 300 |
+
help="Run vanilla LeWM-MPPI (no PRISM prior). Use for A/B comparison.")
|
| 301 |
+
args = ap.parse_args()
|
| 302 |
+
|
| 303 |
+
# Load a sample from the ARX h5 — first frame of episode 0 + its goal
|
| 304 |
+
import h5py
|
| 305 |
+
print(f"\n[demo] loading sample from {args.h5}")
|
| 306 |
+
with h5py.File(args.h5, "r") as f:
|
| 307 |
+
obs = f["pixels"][0]
|
| 308 |
+
goal = f["goal_pixels"][0]
|
| 309 |
+
ground_truth_action = f["action"][0]
|
| 310 |
+
print(f"[demo] obs.shape={obs.shape} goal.shape={goal.shape} "
|
| 311 |
+
f"obs.dtype={obs.dtype}")
|
| 312 |
+
|
| 313 |
+
# Build planner
|
| 314 |
+
print()
|
| 315 |
+
planner = PrismMPPIInference(
|
| 316 |
+
lewm_ckpt=args.lewm_ckpt,
|
| 317 |
+
prior_ckpt=args.prior_ckpt,
|
| 318 |
+
use_prism=not args.no_prism,
|
| 319 |
+
device="cuda" if torch.cuda.is_available() else "cpu",
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
# Plan
|
| 323 |
+
mode = "vanilla LeWM-MPPI" if args.no_prism else "PRISM-MPPI"
|
| 324 |
+
print(f"\n[demo] running {mode} on the sample obs + its goal image…")
|
| 325 |
+
import time
|
| 326 |
+
t0 = time.time()
|
| 327 |
+
actions = planner.plan(obs, goal)
|
| 328 |
+
dt = time.time() - t0
|
| 329 |
+
print(f"[demo] planned in {dt:.2f}s")
|
| 330 |
+
print(f"[demo] action sequence (A_block × A_raw): shape={actions.shape}")
|
| 331 |
+
print(f"[demo] first action: {actions[0].tolist()}")
|
| 332 |
+
print(f"[demo] ground-truth (t=0): {ground_truth_action.tolist()}")
|
| 333 |
+
print(f"[demo] |Δ|: {np.linalg.norm(actions[0] - ground_truth_action):.4f}")
|
jepa.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""JEPA Implementation"""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
def detach_clone(v):
|
| 9 |
+
return v.detach().clone() if torch.is_tensor(v) else v
|
| 10 |
+
|
| 11 |
+
class JEPA(nn.Module):
|
| 12 |
+
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
encoder,
|
| 16 |
+
predictor,
|
| 17 |
+
action_encoder,
|
| 18 |
+
projector=None,
|
| 19 |
+
pred_proj=None,
|
| 20 |
+
):
|
| 21 |
+
super().__init__()
|
| 22 |
+
|
| 23 |
+
self.encoder = encoder
|
| 24 |
+
self.predictor = predictor
|
| 25 |
+
self.action_encoder = action_encoder
|
| 26 |
+
self.projector = projector or nn.Identity()
|
| 27 |
+
self.pred_proj = pred_proj or nn.Identity()
|
| 28 |
+
|
| 29 |
+
def encode(self, info):
|
| 30 |
+
"""Encode observations and actions into embeddings.
|
| 31 |
+
info: dict with pixels and action keys
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
pixels = info['pixels'].float()
|
| 35 |
+
b = pixels.size(0)
|
| 36 |
+
pixels = rearrange(pixels, "b t ... -> (b t) ...") # flatten for encoding
|
| 37 |
+
output = self.encoder(pixels, interpolate_pos_encoding=True)
|
| 38 |
+
pixels_emb = output.last_hidden_state[:, 0] # cls token
|
| 39 |
+
emb = self.projector(pixels_emb)
|
| 40 |
+
info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b)
|
| 41 |
+
|
| 42 |
+
if "action" in info:
|
| 43 |
+
info["act_emb"] = self.action_encoder(info["action"])
|
| 44 |
+
|
| 45 |
+
return info
|
| 46 |
+
|
| 47 |
+
def predict(self, emb, act_emb):
|
| 48 |
+
"""Predict next state embedding
|
| 49 |
+
emb: (B, T, D)
|
| 50 |
+
act_emb: (B, T, A_emb)
|
| 51 |
+
"""
|
| 52 |
+
preds = self.predictor(emb, act_emb)
|
| 53 |
+
preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d"))
|
| 54 |
+
preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0))
|
| 55 |
+
return preds
|
| 56 |
+
|
| 57 |
+
####################
|
| 58 |
+
## Inference only ##
|
| 59 |
+
####################
|
| 60 |
+
|
| 61 |
+
def rollout(self, info, action_sequence, history_size: int = 3):
|
| 62 |
+
"""Rollout the model given an initial info dict and action sequence.
|
| 63 |
+
pixels: (B, S, T, C, H, W)
|
| 64 |
+
action_sequence: (B, S, T, action_dim)
|
| 65 |
+
- S is the number of action plan samples
|
| 66 |
+
- T is the time horizon
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
assert "pixels" in info, "pixels not in info_dict"
|
| 70 |
+
H = info["pixels"].size(2)
|
| 71 |
+
B, S, T = action_sequence.shape[:3]
|
| 72 |
+
act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2)
|
| 73 |
+
info["action"] = act_0
|
| 74 |
+
n_steps = T - H
|
| 75 |
+
|
| 76 |
+
# copy and encode initial info dict
|
| 77 |
+
_init = {k: v[:, 0] for k, v in info.items() if torch.is_tensor(v)}
|
| 78 |
+
_init = self.encode(_init)
|
| 79 |
+
emb = info["emb"] = _init["emb"].unsqueeze(1).expand(B, S, -1, -1)
|
| 80 |
+
_init = {k: detach_clone(v) for k, v in _init.items()}
|
| 81 |
+
|
| 82 |
+
# flatten batch and sample dimensions for rollout
|
| 83 |
+
emb = rearrange(emb, "b s ... -> (b s) ...").clone()
|
| 84 |
+
act = rearrange(act_0, "b s ... -> (b s) ...")
|
| 85 |
+
act_future = rearrange(act_future, "b s ... -> (b s) ...")
|
| 86 |
+
|
| 87 |
+
# rollout predictor autoregressively for n_steps
|
| 88 |
+
HS = history_size
|
| 89 |
+
for t in range(n_steps):
|
| 90 |
+
act_emb = self.action_encoder(act)
|
| 91 |
+
emb_trunc = emb[:, -HS:] # (BS, HS, D)
|
| 92 |
+
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
|
| 93 |
+
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
|
| 94 |
+
emb = torch.cat([emb, pred_emb], dim=1) # (BS, T+1, D)
|
| 95 |
+
|
| 96 |
+
next_act = act_future[:, t : t + 1, :] # (BS, 1, action_dim)
|
| 97 |
+
act = torch.cat([act, next_act], dim=1) # (BS, T+1, action_dim)
|
| 98 |
+
|
| 99 |
+
# predict the last state
|
| 100 |
+
act_emb = self.action_encoder(act) # (BS, T, A_emb)
|
| 101 |
+
emb_trunc = emb[:, -HS:] # (BS, HS, D)
|
| 102 |
+
act_trunc = act_emb[:, -HS:] # (BS, HS, A_emb)
|
| 103 |
+
pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:] # (BS, 1, D)
|
| 104 |
+
emb = torch.cat([emb, pred_emb], dim=1)
|
| 105 |
+
|
| 106 |
+
# unflatten batch and sample dimensions
|
| 107 |
+
pred_rollout = rearrange(emb, "(b s) ... -> b s ...", b=B, s=S)
|
| 108 |
+
info["predicted_emb"] = pred_rollout
|
| 109 |
+
|
| 110 |
+
return info
|
| 111 |
+
|
| 112 |
+
def criterion(self, info_dict: dict):
|
| 113 |
+
"""Compute the cost between predicted embeddings and goal embeddings."""
|
| 114 |
+
pred_emb = info_dict["predicted_emb"] # (B,S, T-1, dim)
|
| 115 |
+
goal_emb = info_dict["goal_emb"] # (B, S, T, dim)
|
| 116 |
+
|
| 117 |
+
goal_emb = goal_emb[..., -1:, :].expand_as(pred_emb)
|
| 118 |
+
|
| 119 |
+
# return last-step cost per action candidate
|
| 120 |
+
cost = F.mse_loss(
|
| 121 |
+
pred_emb[..., -1:, :],
|
| 122 |
+
goal_emb[..., -1:, :].detach(),
|
| 123 |
+
reduction="none",
|
| 124 |
+
).sum(dim=tuple(range(2, pred_emb.ndim))) # (B, S)
|
| 125 |
+
|
| 126 |
+
return cost
|
| 127 |
+
|
| 128 |
+
def get_cost(self, info_dict: dict, action_candidates: torch.Tensor):
|
| 129 |
+
""" Compute the cost of action candidates given an info dict with goal and initial state."""
|
| 130 |
+
|
| 131 |
+
assert "goal" in info_dict, "goal not in info_dict"
|
| 132 |
+
|
| 133 |
+
device = next(self.parameters()).device
|
| 134 |
+
for k in list(info_dict.keys()):
|
| 135 |
+
if torch.is_tensor(info_dict[k]):
|
| 136 |
+
info_dict[k] = info_dict[k].to(device)
|
| 137 |
+
|
| 138 |
+
goal = {k: v[:, 0] for k, v in info_dict.items() if torch.is_tensor(v)}
|
| 139 |
+
goal["pixels"] = goal["goal"]
|
| 140 |
+
|
| 141 |
+
for k in info_dict:
|
| 142 |
+
if k.startswith("goal_"):
|
| 143 |
+
goal[k[len("goal_") :]] = goal.pop(k)
|
| 144 |
+
|
| 145 |
+
goal.pop("action")
|
| 146 |
+
goal = self.encode(goal)
|
| 147 |
+
|
| 148 |
+
info_dict["goal_emb"] = goal["emb"]
|
| 149 |
+
info_dict = self.rollout(info_dict, action_candidates)
|
| 150 |
+
|
| 151 |
+
cost = self.criterion(info_dict)
|
| 152 |
+
|
| 153 |
+
return cost
|
lewm_red_cube_epoch_100_object.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ca701f5496e3a23b0dbae8fa9a66f5de980abe729845739f75fce3556b97fe8e
|
| 3 |
+
size 72355236
|
module.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
|
| 6 |
+
def modulate(x, shift, scale):
|
| 7 |
+
"""AdaLN-zero modulation"""
|
| 8 |
+
return x * (1 + scale) + shift
|
| 9 |
+
|
| 10 |
+
class SIGReg(torch.nn.Module):
|
| 11 |
+
"""Sketch Isotropic Gaussian Regularizer (single-GPU!)"""
|
| 12 |
+
|
| 13 |
+
def __init__(self, knots=17, num_proj=1024):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.num_proj = num_proj
|
| 16 |
+
t = torch.linspace(0, 3, knots, dtype=torch.float32)
|
| 17 |
+
dt = 3 / (knots - 1)
|
| 18 |
+
weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
|
| 19 |
+
weights[[0, -1]] = dt
|
| 20 |
+
window = torch.exp(-t.square() / 2.0)
|
| 21 |
+
self.register_buffer("t", t)
|
| 22 |
+
self.register_buffer("phi", window)
|
| 23 |
+
self.register_buffer("weights", weights * window)
|
| 24 |
+
|
| 25 |
+
def forward(self, proj):
|
| 26 |
+
"""
|
| 27 |
+
proj: (T, B, D)
|
| 28 |
+
"""
|
| 29 |
+
# sample random projections
|
| 30 |
+
A = torch.randn(proj.size(-1), self.num_proj, device=proj.device)
|
| 31 |
+
A = A.div_(A.norm(p=2, dim=0))
|
| 32 |
+
# compute the epps-pulley statistic
|
| 33 |
+
x_t = (proj @ A).unsqueeze(-1) * self.t
|
| 34 |
+
err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square()
|
| 35 |
+
statistic = (err @ self.weights) * proj.size(-2)
|
| 36 |
+
return statistic.mean() # average over projections and time
|
| 37 |
+
|
| 38 |
+
class FeedForward(nn.Module):
|
| 39 |
+
"""FeedForward network used in Transformers"""
|
| 40 |
+
|
| 41 |
+
def __init__(self, dim, hidden_dim, dropout=0.0):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.net = nn.Sequential(
|
| 44 |
+
nn.LayerNorm(dim),
|
| 45 |
+
nn.Linear(dim, hidden_dim),
|
| 46 |
+
nn.GELU(),
|
| 47 |
+
nn.Dropout(dropout),
|
| 48 |
+
nn.Linear(hidden_dim, dim),
|
| 49 |
+
nn.Dropout(dropout),
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
def forward(self, x):
|
| 53 |
+
return self.net(x)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class Attention(nn.Module):
|
| 57 |
+
"""Scaled dot-product attention with causal masking"""
|
| 58 |
+
|
| 59 |
+
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
|
| 60 |
+
super().__init__()
|
| 61 |
+
inner_dim = dim_head * heads
|
| 62 |
+
project_out = not (heads == 1 and dim_head == dim)
|
| 63 |
+
self.heads = heads
|
| 64 |
+
self.scale = dim_head**-0.5
|
| 65 |
+
self.dropout = dropout
|
| 66 |
+
self.norm = nn.LayerNorm(dim)
|
| 67 |
+
self.attend = nn.Softmax(dim=-1)
|
| 68 |
+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
| 69 |
+
self.to_out = (
|
| 70 |
+
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
|
| 71 |
+
if project_out
|
| 72 |
+
else nn.Identity()
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def forward(self, x, causal=True):
|
| 76 |
+
"""
|
| 77 |
+
x : (B, T, D)
|
| 78 |
+
"""
|
| 79 |
+
x = self.norm(x)
|
| 80 |
+
drop = self.dropout if self.training else 0.0
|
| 81 |
+
qkv = self.to_qkv(x).chunk(3, dim=-1) # q, k, v: (B, heads, T, dim_head)
|
| 82 |
+
q, k, v = (rearrange(t, "b t (h d) -> b h t d", h=self.heads) for t in qkv)
|
| 83 |
+
out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop, is_causal=causal)
|
| 84 |
+
out = rearrange(out, "b h t d -> b t (h d)")
|
| 85 |
+
return self.to_out(out)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class ConditionalBlock(nn.Module):
|
| 89 |
+
"""Transformer block with AdaLN-zero conditioning"""
|
| 90 |
+
|
| 91 |
+
def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
|
| 92 |
+
super().__init__()
|
| 93 |
+
|
| 94 |
+
self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
|
| 95 |
+
self.mlp = FeedForward(dim, mlp_dim, dropout=dropout)
|
| 96 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 97 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 98 |
+
self.adaLN_modulation = nn.Sequential(
|
| 99 |
+
nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True)
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
|
| 103 |
+
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
|
| 104 |
+
|
| 105 |
+
def forward(self, x, c):
|
| 106 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
| 107 |
+
self.adaLN_modulation(c).chunk(6, dim=-1)
|
| 108 |
+
)
|
| 109 |
+
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
|
| 110 |
+
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
| 111 |
+
return x
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class Block(nn.Module):
|
| 115 |
+
"""Standard Transformer block"""
|
| 116 |
+
|
| 117 |
+
def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
|
| 118 |
+
super().__init__()
|
| 119 |
+
|
| 120 |
+
self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
|
| 121 |
+
self.mlp = FeedForward(dim, mlp_dim, dropout=dropout)
|
| 122 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 123 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 124 |
+
|
| 125 |
+
def forward(self, x):
|
| 126 |
+
x = x + self.attn(self.norm1(x))
|
| 127 |
+
x = x + self.mlp(self.norm2(x))
|
| 128 |
+
return x
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class Transformer(nn.Module):
|
| 132 |
+
"""Standard Transformer with support for AdaLN-zero blocks"""
|
| 133 |
+
|
| 134 |
+
def __init__(
|
| 135 |
+
self,
|
| 136 |
+
input_dim,
|
| 137 |
+
hidden_dim,
|
| 138 |
+
output_dim,
|
| 139 |
+
depth,
|
| 140 |
+
heads,
|
| 141 |
+
dim_head,
|
| 142 |
+
mlp_dim,
|
| 143 |
+
dropout=0.0,
|
| 144 |
+
block_class=Block,
|
| 145 |
+
):
|
| 146 |
+
super().__init__()
|
| 147 |
+
self.norm = nn.LayerNorm(hidden_dim)
|
| 148 |
+
self.layers = nn.ModuleList([])
|
| 149 |
+
|
| 150 |
+
self.input_proj = (
|
| 151 |
+
nn.Linear(input_dim, hidden_dim)
|
| 152 |
+
if input_dim != hidden_dim
|
| 153 |
+
else nn.Identity()
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
self.cond_proj = (
|
| 157 |
+
nn.Linear(input_dim, hidden_dim)
|
| 158 |
+
if input_dim != hidden_dim
|
| 159 |
+
else nn.Identity()
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
self.output_proj = (
|
| 163 |
+
nn.Linear(hidden_dim, output_dim)
|
| 164 |
+
if hidden_dim != output_dim
|
| 165 |
+
else nn.Identity()
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
for _ in range(depth):
|
| 169 |
+
self.layers.append(
|
| 170 |
+
block_class(hidden_dim, heads, dim_head, mlp_dim, dropout)
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
def forward(self, x, c=None):
|
| 174 |
+
|
| 175 |
+
if hasattr(self, "input_proj"):
|
| 176 |
+
x = self.input_proj(x)
|
| 177 |
+
|
| 178 |
+
if c is not None and hasattr(self, "cond_proj"):
|
| 179 |
+
c = self.cond_proj(c)
|
| 180 |
+
|
| 181 |
+
for block in self.layers:
|
| 182 |
+
x = block(x) if isinstance(block, Block) else block(x, c)
|
| 183 |
+
x = self.norm(x)
|
| 184 |
+
|
| 185 |
+
if hasattr(self, "output_proj"):
|
| 186 |
+
x = self.output_proj(x)
|
| 187 |
+
return x
|
| 188 |
+
|
| 189 |
+
class Embedder(nn.Module):
|
| 190 |
+
def __init__(
|
| 191 |
+
self,
|
| 192 |
+
input_dim=10,
|
| 193 |
+
smoothed_dim=10,
|
| 194 |
+
emb_dim=10,
|
| 195 |
+
mlp_scale=4,
|
| 196 |
+
):
|
| 197 |
+
super().__init__()
|
| 198 |
+
self.patch_embed = nn.Conv1d(input_dim, smoothed_dim, kernel_size=1, stride=1)
|
| 199 |
+
self.embed = nn.Sequential(
|
| 200 |
+
nn.Linear(smoothed_dim, mlp_scale * emb_dim),
|
| 201 |
+
nn.SiLU(),
|
| 202 |
+
nn.Linear(mlp_scale * emb_dim, emb_dim),
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
def forward(self, x):
|
| 206 |
+
"""
|
| 207 |
+
x: (B, T, D)
|
| 208 |
+
"""
|
| 209 |
+
x = x.float()
|
| 210 |
+
x = x.permute(0, 2, 1)
|
| 211 |
+
x = self.patch_embed(x)
|
| 212 |
+
x = x.permute(0, 2, 1)
|
| 213 |
+
x = self.embed(x)
|
| 214 |
+
return x
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class MLP(nn.Module):
|
| 218 |
+
"""Simple MLP with optional normalization and activation"""
|
| 219 |
+
|
| 220 |
+
def __init__(
|
| 221 |
+
self,
|
| 222 |
+
input_dim,
|
| 223 |
+
hidden_dim,
|
| 224 |
+
output_dim=None,
|
| 225 |
+
norm_fn=nn.LayerNorm,
|
| 226 |
+
act_fn=nn.GELU,
|
| 227 |
+
):
|
| 228 |
+
super().__init__()
|
| 229 |
+
norm_fn = norm_fn(hidden_dim) if norm_fn is not None else nn.Identity()
|
| 230 |
+
self.net = nn.Sequential(
|
| 231 |
+
nn.Linear(input_dim, hidden_dim),
|
| 232 |
+
norm_fn,
|
| 233 |
+
act_fn(),
|
| 234 |
+
nn.Linear(hidden_dim, output_dim or input_dim),
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
def forward(self, x):
|
| 238 |
+
"""
|
| 239 |
+
x: (B*T, D)
|
| 240 |
+
"""
|
| 241 |
+
return self.net(x)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class ActionEncoder2DWrapper(nn.Module):
|
| 245 |
+
"""Slices 6-dim raw action input down to 2-dim (dx, dy) before encoding.
|
| 246 |
+
|
| 247 |
+
Accepts either (..., frameskip*6) or (..., frameskip*2) trailing dim
|
| 248 |
+
and adapts. Lives in module.py so it's importable from any script
|
| 249 |
+
that already imports `module` to use other LeWM components.
|
| 250 |
+
"""
|
| 251 |
+
def __init__(self, inner: nn.Module, frameskip: int = 5):
|
| 252 |
+
super().__init__()
|
| 253 |
+
self.inner = inner
|
| 254 |
+
self.frameskip = frameskip
|
| 255 |
+
|
| 256 |
+
def forward(self, x):
|
| 257 |
+
if x.shape[-1] == self.frameskip * 6:
|
| 258 |
+
B = x.shape[:-1]
|
| 259 |
+
x = x.reshape(*B, self.frameskip, 6)
|
| 260 |
+
x = x[..., :2]
|
| 261 |
+
x = x.reshape(*B, self.frameskip * 2)
|
| 262 |
+
return self.inner(x)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class ARPredictor(nn.Module):
|
| 266 |
+
"""Autoregressive predictor for next-step embedding prediction."""
|
| 267 |
+
|
| 268 |
+
def __init__(
|
| 269 |
+
self,
|
| 270 |
+
*,
|
| 271 |
+
num_frames,
|
| 272 |
+
depth,
|
| 273 |
+
heads,
|
| 274 |
+
mlp_dim,
|
| 275 |
+
input_dim,
|
| 276 |
+
hidden_dim,
|
| 277 |
+
output_dim=None,
|
| 278 |
+
dim_head=64,
|
| 279 |
+
dropout=0.0,
|
| 280 |
+
emb_dropout=0.0,
|
| 281 |
+
):
|
| 282 |
+
super().__init__()
|
| 283 |
+
self.pos_embedding = nn.Parameter(torch.randn(1, num_frames, input_dim))
|
| 284 |
+
self.dropout = nn.Dropout(emb_dropout)
|
| 285 |
+
self.transformer = Transformer(
|
| 286 |
+
input_dim,
|
| 287 |
+
hidden_dim,
|
| 288 |
+
output_dim or input_dim,
|
| 289 |
+
depth,
|
| 290 |
+
heads,
|
| 291 |
+
dim_head,
|
| 292 |
+
mlp_dim,
|
| 293 |
+
dropout,
|
| 294 |
+
block_class=ConditionalBlock,
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
def forward(self, x, c):
|
| 298 |
+
"""
|
| 299 |
+
x: (B, T, d)
|
| 300 |
+
c: (B, T, act_dim)
|
| 301 |
+
"""
|
| 302 |
+
T = x.size(1)
|
| 303 |
+
x = x + self.pos_embedding[:, :T]
|
| 304 |
+
x = self.dropout(x)
|
| 305 |
+
x = self.transformer(x, c)
|
| 306 |
+
return x
|
prior_head.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PriorHead — MLP that maps (z_t, z_g) → Gaussian over an action sequence.
|
| 2 |
+
|
| 3 |
+
Per §9 design: input is concat(z_t, z_g) ∈ R^{2D}, output is (μ, σ) over
|
| 4 |
+
H × A_block × A_raw normalized actions. σ is per-input via softplus + a floor.
|
| 5 |
+
The head's output sits in StandardScaler-normalized action space; the eval-side
|
| 6 |
+
policy is responsible for inverse-transform back to env action units.
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch import nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class PriorHead(nn.Module):
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
z_dim: int,
|
| 19 |
+
H: int,
|
| 20 |
+
A_block: int,
|
| 21 |
+
A_raw: int,
|
| 22 |
+
hidden: int = 512,
|
| 23 |
+
sigma_floor: float = 0.05,
|
| 24 |
+
):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.z_dim = z_dim
|
| 27 |
+
self.H = H
|
| 28 |
+
self.A_block = A_block
|
| 29 |
+
self.A_raw = A_raw
|
| 30 |
+
self.action_seq_dim = H * A_block * A_raw
|
| 31 |
+
self.sigma_floor = sigma_floor
|
| 32 |
+
|
| 33 |
+
self.mlp = nn.Sequential(
|
| 34 |
+
nn.Linear(2 * z_dim, hidden),
|
| 35 |
+
nn.GELU(),
|
| 36 |
+
nn.Linear(hidden, hidden),
|
| 37 |
+
nn.GELU(),
|
| 38 |
+
nn.Linear(hidden, 2 * self.action_seq_dim),
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
def forward(self, z_t: torch.Tensor, z_g: torch.Tensor):
|
| 42 |
+
"""z_t, z_g: (B, D). Returns (mu, sigma) each of shape (B, H, A_block, A_raw)."""
|
| 43 |
+
x = torch.cat([z_t, z_g], dim=-1)
|
| 44 |
+
out = self.mlp(x)
|
| 45 |
+
mu_flat, log_sigma_flat = out.chunk(2, dim=-1)
|
| 46 |
+
sigma_flat = F.softplus(log_sigma_flat) + self.sigma_floor
|
| 47 |
+
|
| 48 |
+
B = mu_flat.size(0)
|
| 49 |
+
shape = (B, self.H, self.A_block, self.A_raw)
|
| 50 |
+
return mu_flat.view(shape), sigma_flat.view(shape)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def beta_nll_loss(
|
| 54 |
+
mu: torch.Tensor,
|
| 55 |
+
sigma: torch.Tensor,
|
| 56 |
+
target: torch.Tensor,
|
| 57 |
+
beta: float = 0.5,
|
| 58 |
+
) -> torch.Tensor:
|
| 59 |
+
"""β-NLL (Seitzer et al. 2022).
|
| 60 |
+
|
| 61 |
+
Standard Gaussian NLL per element (dropping additive constant):
|
| 62 |
+
nll_i = 0.5 (y_i − μ_i)² / σ_i² + log σ_i
|
| 63 |
+
β-NLL multiplies by stop_grad(σ_i^(2β)) before averaging:
|
| 64 |
+
L = mean_i [ stop_grad(σ_i^{2β}) · nll_i ]
|
| 65 |
+
β=0.5 is the recommended robust default — keeps σ-gradient alive but
|
| 66 |
+
prevents the σ-blow-up pathology of vanilla NLL when μ is hard to fit.
|
| 67 |
+
"""
|
| 68 |
+
var = sigma.pow(2)
|
| 69 |
+
log_sigma = sigma.log()
|
| 70 |
+
sq_err = (target - mu).pow(2)
|
| 71 |
+
nll = 0.5 * sq_err / var + log_sigma
|
| 72 |
+
weight = sigma.detach().pow(2 * beta)
|
| 73 |
+
return (weight * nll).mean()
|
prior_head_red_cube.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a0026eda7168dc515df27fadeb415ac2e6bd25fb98206b0db6f34fb365d6c591
|
| 3 |
+
size 2359141
|