YuhaiW committed · Commit 8f017a6 · verified · 1 Parent(s): acc0ba2

red_cube PRISM-JEPA deploy bundle: LeWM WM + PRISM prior + self-contained inference + README

README.md ADDED
@@ -0,0 +1,110 @@
+ ---
+ license: apache-2.0
+ library_name: pytorch
+ tags:
+ - robotics
+ - world-model
+ - jepa
+ - lewm
+ - prism
+ - arx
+ - visuomotor
+ ---
+
+ # PRISM-JEPA · red_cube (ARX-X5) — LeWM world model + PRISM action prior
+
+ Complete deployable stack for goal-conditioned visuomotor planning on the ARX-X5
+ **"red cube"** task: the **LeWM JEPA world model** + the **PRISM goal-conditioned action
+ prior** + **self-contained PRISM-MPPI inference code**.
+
+ ## ⚠️ Status: read first
+
+ This is a **research artifact / deployment hand-off package, NOT a validated policy.**
+
+ - Trained on 201 teleop demos (`Xia-2004/red_cube`, 42,165 frames, 5-DoF).
+ - The world model **converges cleanly and its forward model is accurate** (rollout pred/id
+   ≈ 0.25 < 0.5), **but its MPPI cost surface is weak/flat** on this small real-robot,
+   small-action dataset (CV ≈ **0.14**, well below the 0.30 "discriminative" threshold).
+   Consequence: the planner's distinctive **cost-rescoring is dormant**, so in practice
+   **PRISM ≈ a goal-conditioned behavior-cloning (BC) prior**. In offline A/B it produces
+   ~31% more expert-like actions than vanilla LeWM-MPPI (which wanders), but adds **no
+   measurable goal progress** in the world model's own latent metric (paired t-test p = 0.57).
+ - **Not yet run on the real robot.** Treat this as a starting point; add workspace/velocity
+   safety limits and validate before any hardware run.
+ - Full analysis and how these numbers were obtained: project doc
+   `docs/30_red_cube_cv_investigation_and_prism.md`.
+
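Reviewer sketch: the link between a flat cost surface and dormant cost-rescoring can be illustrated in a few lines. This assumes CV means std/mean of the K candidate costs within one MPPI iteration (the README does not define it; the project doc is authoritative) and reuses the weight rule from `plan()`, w ∝ exp(-β·(cost − min cost)) with β = 1/temperature = 2. The cost values are made-up toy numbers, not measurements from this model.

```python
import math
import statistics


def cost_cv(costs):
    """Coefficient of variation (population std / mean) of candidate costs.

    Assumed definition; the README only quotes the resulting number.
    """
    return statistics.pstdev(costs) / statistics.fmean(costs)


def mppi_weights(costs, beta=2.0):
    """Softmax weights w_k proportional to exp(-beta * (c_k - min c)), as in plan()."""
    c_min = min(costs)
    unnorm = [math.exp(-beta * (c - c_min)) for c in costs]
    total = sum(unnorm)
    return [u / total for u in unnorm]


# Toy numbers only: a nearly flat cost surface vs. a sharply varying one.
flat = [0.10, 0.11, 0.12, 0.13]
sharp = [0.1, 1.0, 2.0, 4.0]

for name, costs in (("flat", flat), ("sharp", sharp)):
    w = mppi_weights(costs)
    print(f"{name}: CV={cost_cv(costs):.2f}  max weight={max(w):.2f}")
```

With the flat costs every candidate gets close to the uniform weight 1/K, so the weighted mean stays near the sampling mean and the prior dominates. Note that the weights depend on absolute cost differences scaled by β, so CV is only a proxy for how discriminative the rescoring is.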
+ ## Contents
+
+ | file | description |
+ |---|---|
+ | `lewm_red_cube_epoch_100_object.ckpt` | LeWM world model — pickled JEPA: ViT-tiny encoder + AR transformer predictor + action encoder (~18M params) |
+ | `prior_head_red_cube.pt` | PRISM goal-conditioned action prior — `state_dict` + `config` + action `StandardScaler` (mean/scale) |
+ | `arx_inference_demo.py` | self-contained `PrismMPPIInference` (PoG-fused PRISM-MPPI; `use_prism=False` → vanilla LeWM-MPPI) |
+ | `jepa.py`, `module.py`, `prior_head.py` | model classes required to unpickle the ckpt and run the prior |
+
+ ## Observation / action space
+
+ - **Observation:** single top-down RGB frame, **224×224×3 uint8** (RealSense `camera_third`).
+ - **Goal:** an RGB goal image, same format (the prior + cost are conditioned on it).
+ - **Action:** 5-DoF delta end-effector **`[dx, dy, dz, dyaw, d_gripper]`**, raw units, one per
+   control tick. `plan()` returns one plan-step = `A_block = 5` ticks → shape **`(5, 5)`**.
+
+ ## Dependencies
+
+ `torch`, `numpy`, `einops`, and `transformers` (the encoder inside the ckpt is a HuggingFace
+ ViT, needed at unpickle time). The three bundled `.py` files must be importable from the
+ working directory. (If unpickling complains about a missing class, also `pip install
+ stable-pretraining`.)
+
+ ## Deploy — receding-horizon control loop
+
+ ```python
+ from arx_inference_demo import PrismMPPIInference
+
+ planner = PrismMPPIInference(
+     lewm_ckpt = "lewm_red_cube_epoch_100_object.ckpt",
+     prior_ckpt = "prior_head_red_cube.pt",
+     use_prism = True,  # True = PRISM (prior ⊗ MPPI via PoG fusion); False = vanilla LeWM-MPPI
+     device = "cuda",
+ )
+
+ goal_img = load_goal_image()  # placeholder loader; (224,224,3) uint8 task goal image
+ while not done:  # `done`, `camera`, `robot` are placeholders for your own I/O code
+     obs = camera.read()  # (224,224,3) uint8, top-down camera_third view
+     actions = planner.plan(obs, goal_img)  # (5, 5) raw [dx,dy,dz,dyaw,d_gripper]
+     for a in actions:  # receding horizon: execute the block, then replan
+         robot.execute(a)  # (or execute fewer than 5 and replan more often)
+ ```
+
+ `plan()` runs one full PRISM-MPPI optimization and returns the first `A_block = 5` env-step
+ actions of the optimized plan, in **raw action units** (already de-normalized).
+
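Reviewer sketch: the README asks for velocity limits before any hardware run but does not show one. A minimal per-tick clamp could look like the following. The limit values and the helper names `clamp_action` / `clamp_block` are hypothetical and not part of the bundle; real limits have to come from the robot's own safety analysis, and a workspace (absolute position) check would still be needed on top.

```python
# Hypothetical per-tick limits in raw action units, ordered as
# [dx, dy, dz, dyaw, d_gripper]. These numbers are placeholders.
MAX_ABS_DELTA = [0.01, 0.01, 0.01, 0.05, 1.0]


def clamp_action(action, limits=MAX_ABS_DELTA):
    """Clip each component of one 5-DoF delta action to [-limit, +limit]."""
    if len(action) != len(limits):
        raise ValueError(f"expected {len(limits)} components, got {len(action)}")
    return [max(-lim, min(lim, float(a))) for a, lim in zip(action, limits)]


def clamp_block(actions, limits=MAX_ABS_DELTA):
    """Apply clamp_action to every tick of an (A_block, 5) plan block."""
    return [clamp_action(a, limits) for a in actions]


# Made-up example: the first tick exceeds the dx and dyaw limits.
block = [[0.03, 0.0, -0.002, 0.2, 0.5], [0.001, 0.002, 0.0, 0.0, 0.0]]
safe = clamp_block(block)
print(safe[0])
```

In the control loop one would iterate over `clamp_block(planner.plan(obs, goal_img))` instead of the raw plan, since each row of the returned array is one tick.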
+ ## Key hyperparameters (`PrismMPPIInference` constructor)
+
+ | arg | default | meaning |
+ |---|---|---|
+ | `H` | 5 | planning horizon (plan-steps) |
+ | `A_block` | 5 | env-steps (ticks) per plan-step ("frameskip") |
+ | `K` | 128 | MPPI samples per iteration |
+ | `n_iters` | 30 | MPPI refinement iterations |
+ | `var_scale` | 1.0 | initial planner sampling std |
+ | `prior_sigma_scale` | 2.0 | multiplier on the prior σ before PoG fusion (PRISM only) |
+ | `temperature` | 0.5 | MPPI softmax temperature (β = 1/temperature) |
+ | `history_size` | 3 | LeWM history-window length (**must match training**) |
+
+ `H`, `A_block`, `A_raw`, and `history_size` must match the checkpoints. The constructor asserts only
+ `H` and `A_block` and reads `A_raw` from the prior config; `history_size` is unchecked. Retrain to change them.
+
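Reviewer sketch: the defaults in the table imply a fixed compute budget per `plan()` call, which is worth knowing before choosing a control rate. The `PlannerConfig` dataclass below is a hypothetical helper written for this note (it is not in the bundle); the arithmetic follows the loop structure in `arx_inference_demo.py`, where each MPPI iteration rolls K candidates through H batched predictor calls.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PlannerConfig:
    """Hypothetical mirror of the constructor defaults, for budgeting only."""
    H: int = 5
    A_block: int = 5
    A_raw: int = 5          # 5-DoF action, per the README
    K: int = 128
    n_iters: int = 30
    temperature: float = 0.5

    @property
    def beta(self) -> float:
        # Inverse temperature used in the softmax weights.
        return 1.0 / self.temperature

    @property
    def plan_shape(self) -> tuple:
        # One candidate plan: (H * A_block env-steps, A_raw dims).
        return (self.H * self.A_block, self.A_raw)

    @property
    def rollouts_per_plan(self) -> int:
        # K sampled sequences per iteration, n_iters iterations.
        return self.K * self.n_iters

    @property
    def predictor_calls_per_plan(self) -> int:
        # Each iteration makes H batched predictor calls (batch size K).
        return self.H * self.n_iters


cfg = PlannerConfig()
print(cfg.beta, cfg.plan_shape, cfg.rollouts_per_plan, cfg.predictor_calls_per_plan)
```

Only the first `A_block` rows of the optimized plan are returned, so most of this budget is spent on steps that are replanned on the next call.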
+ ## PRISM vs vanilla (A/B)
+
+ Build a second planner with `use_prism=False` for a baseline (plain LeWM-MPPI, no prior,
+ same encoder/predictor/MPPI loop). On this task PRISM produces more expert-like actions;
+ vanilla tends to wander because the cost surface is flat.
+
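Reviewer sketch: the Status section quotes a paired t-test for the A/B comparison. Assuming each episode yields one scalar metric per planner (for example final latent distance to the goal), the paired statistic can be computed as below. The metric choice, the helper name `paired_t`, and the numbers are illustrative assumptions, not the project's evaluation code or data.

```python
import math
import statistics


def paired_t(a, b):
    """Return (t statistic, degrees of freedom) for paired samples a and b."""
    if len(a) != len(b) or len(a) < 2:
        raise ValueError("need two equal-length samples with at least 2 pairs")
    diffs = [x - y for x, y in zip(a, b)]
    mean_d = statistics.fmean(diffs)
    sd_d = statistics.stdev(diffs)  # sample std, n - 1 in the denominator
    if sd_d == 0.0:
        raise ValueError("all paired differences are identical")
    n = len(diffs)
    return mean_d / (sd_d / math.sqrt(n)), n - 1


# Toy per-episode metrics (made up): same episodes, two planners.
prism = [0.42, 0.39, 0.47, 0.40, 0.44]
vanilla = [0.41, 0.43, 0.45, 0.42, 0.43]
t_stat, dof = paired_t(prism, vanilla)
print(f"t={t_stat:.3f}, dof={dof}")
```

Turning the statistic into a p-value needs the Student t distribution, which the standard library does not provide; `scipy.stats.ttest_rel` does the whole test in one call if SciPy is available.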
+ ## Provenance
+
+ Data: [`Xia-2004/red_cube`](https://huggingface.co/datasets/Xia-2004/red_cube) (ARX-X5
+ left-arm teleop). Sibling of `Xia-2004/arx-left-cube`. World-model architecture is identical
+ to the sim LeWM (ViT-tiny, embed_dim 192, predictor depth 6 / heads 16) — part of the
+ PRISM-JEPA project (sister of Newt-PRISM, CoRL 2026).
arx_inference_demo.py ADDED
@@ -0,0 +1,333 @@
+ """arx_inference_demo.py — standalone PRISM-MPPI inference for ARX cube task.
+
+ This file is **self-contained**: it depends only on the bundled `jepa.py`,
+ `module.py`, `prior_head.py`, plus torch / numpy (and the einops / transformers
+ packages the bundled modules and checkpoint need). No `stable_worldmodel` import: the MPPI loop is re-implemented inline.
+
+ Intended use by a downstream consumer (e.g., the ARX deployment side):
+
+     from arx_inference_demo import PrismMPPIInference
+
+     planner = PrismMPPIInference(
+         lewm_ckpt = "lewm_arx.ckpt",
+         prior_ckpt = "prior_head_arx.pt",
+         device = "cuda",
+     )
+
+     # In the control loop:
+     while not done:
+         obs_uint8 = camera.read()  # (224, 224, 3) uint8 RGB
+         goal_uint8 = goal_image  # (224, 224, 3) uint8 RGB
+         actions = planner.plan(obs_uint8, goal_uint8)
+         # → (A_block, 5) float32, raw action units
+         for a in actions:
+             robot.execute(a)  # step the robot
+
+ `plan()` performs one full PRISM-MPPI optimization and returns the first
+ A_block = 5 env-step actions of the optimized plan. The caller may choose
+ to execute all 5 then replan (receding-horizon, k=A_block), or execute
+ fewer and replan more often.
+
+ PRISM-MPPI summary:
+   1. JEPA encoder turns current obs + goal image into latent embeddings z_t, z_g.
+   2. PRISM prior head maps (z_t, z_g) → (μ_p, σ_p) over the next
+      H × A_block × A_raw normalized actions.
+   3. We seed an MPPI distribution N(0, var_scale² I) and PoG-fuse with the
+      prior to get N(fused_μ, fused_σ²). The variance is FROZEN through MPPI
+      iterations (this is the PRISM-MPPI signature; see paper §3).
+   4. Each iteration samples K candidate action sequences, rolls them out via
+      the LeWM ARPredictor in latent space, computes cost = MSE(predicted
+      final z, z_g), reweights candidates by exp(-β·cost), updates the mean.
+   5. After n_iters iterations, the first A_block entries of the mean are
+      returned (denormalized to raw env action units via the saved
+      StandardScaler).
+ """
+ from __future__ import annotations
+
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ # Required for unpickling the LeWM ckpt — these modules must be importable
+ import jepa  # noqa: F401 — registers JEPA class
+ import module  # noqa: F401 — registers ARPredictor, Embedder, etc.
+ from prior_head import PriorHead
+
+ IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
+ IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
+
+
+ def _preprocess(img_uint8: np.ndarray, device: torch.device) -> torch.Tensor:
+     """uint8 (H, W, 3) → float (1, 3, 224, 224), ImageNet-normalized."""
+     assert img_uint8.shape == (224, 224, 3), \
+         f"Expected (224, 224, 3) image, got {img_uint8.shape}"
+     t = torch.from_numpy(img_uint8).permute(2, 0, 1).float().div(255.0).unsqueeze(0)
+     t = t.to(device)
+     mean = IMAGENET_MEAN.to(device)
+     std = IMAGENET_STD.to(device)
+     return (t - mean) / std
+
+
+ def _pog_fusion(mean, std, mu_p, sg_p, sigma_floor=0.05):
+     """Product-of-Gaussians fusion. Matches prism_mppi.pog_fusion."""
+     eps = 1e-8
+     tau_base = 1.0 / (std ** 2 + eps)
+     tau_p = 1.0 / (sg_p ** 2 + eps)
+     tau_c = tau_base + tau_p
+     fused_mean = (tau_base * mean + tau_p * mu_p) / tau_c
+     fused_std = (1.0 / tau_c).sqrt().clamp(min=sigma_floor)
+     return fused_mean, fused_std
+
+
+ class PrismMPPIInference:
+     """Standalone PRISM-MPPI planner for ARX cube task.
+
+     Supports two modes via the `use_prism` constructor flag — kept on a
+     single class so that PRISM and vanilla-MPPI A/B comparisons use the
+     exact same encoder, predictor, MPPI loop, and StandardScaler. The
+     only difference between the two modes is whether the PoG fusion at
+     init time uses the prior head's (μ_p, σ_p) or not. The prior-head
+     checkpoint is always loaded — its StandardScaler (action
+     normalization) is shared by both modes so the comparison is
+     apples-to-apples in raw action units.
+
+     Args (paper defaults — change only if you know what you're doing):
+         lewm_ckpt: path to lewm_arx.ckpt (pickled JEPA module)
+         prior_ckpt: path to prior_head_arx.pt (PRISM head state_dict + scaler)
+         use_prism: if True (default), inject the PRISM prior via PoG fusion.
+             If False, skip the prior — the planner becomes vanilla
+             LeWM-MPPI seeded from N(0, var_scale² I). Use this for
+             paper-grade real-robot A/B against PRISM-MPPI.
+         H: planning horizon in plan-steps (default 5)
+         A_block: env-steps per plan-step (default 5, "frameskip")
+         K: num MPPI samples per iteration (default 128)
+         n_iters: num MPPI refinement iterations (default 30)
+         var_scale: initial planner std (default 1.0)
+         temperature: MPPI softmax temperature β = 1/temperature (default 0.5)
+         sigma_floor: lower bound on fused σ (default 0.05); only used by PRISM
+         prior_sigma_scale: multiplier on prior σ_p before fusion (default 2.0,
+             matches the paper's PRISM-MPPI s=2.0 setting); only used by PRISM
+         history_size: LeWM history-window length (default 3; must match training)
+         device: 'cuda' or 'cpu'
+     """
+
+     def __init__(
+         self,
+         lewm_ckpt: str | Path,
+         prior_ckpt: str | Path,
+         use_prism: bool = True,
+         H: int = 5,
+         A_block: int = 5,
+         K: int = 128,
+         n_iters: int = 30,
+         var_scale: float = 1.0,
+         temperature: float = 0.5,
+         sigma_floor: float = 0.05,
+         prior_sigma_scale: float = 2.0,
+         history_size: int = 3,
+         device: str = "cuda",
+     ):
+         self.device = torch.device(device)
+         self.use_prism = bool(use_prism)
+         self.H = H
+         self.A_block = A_block
+         self.K = K
+         self.n_iters = n_iters
+         self.var_scale = var_scale
+         self.beta = 1.0 / temperature
+         self.sigma_floor = sigma_floor
+         self.prior_sigma_scale = prior_sigma_scale
+         self.history_size = history_size
+
+         # ---- Load LeWM (encoder + AR predictor, pickled) ----
+         print(f"[init] loading LeWM ckpt: {lewm_ckpt}")
+         self.lewm = torch.load(
+             str(lewm_ckpt), map_location=self.device, weights_only=False,
+         )
+         self.lewm.to(self.device).eval()
+         for p in self.lewm.parameters():
+             p.requires_grad_(False)
+
+         # ---- Load PRISM prior head + scaler (scaler always used; head conditionally) ----
+         print(f"[init] loading prior head + scaler: {prior_ckpt}")
+         pck = torch.load(str(prior_ckpt), map_location=self.device, weights_only=False)
+         cfg = pck["config"]
+         self.A_raw = int(cfg["A_raw"])
+         assert cfg["H"] == self.H and cfg["A_block"] == self.A_block, (
+             f"Ckpt config mismatch: H={cfg['H']} A_block={cfg['A_block']} "
+             f"vs runtime H={self.H} A_block={self.A_block}"
+         )
+
+         if self.use_prism:
+             self.head = PriorHead(**cfg).to(self.device).eval()
+             self.head.load_state_dict(pck["state_dict"])
+             for p in self.head.parameters():
+                 p.requires_grad_(False)
+         else:
+             self.head = None  # vanilla LeWM-MPPI mode: skip PoG fusion
+
+         # Action denormalization (raw_action = norm_action * scale + mean) — always loaded
+         self.scaler_mean = torch.tensor(pck["scaler_mean"], device=self.device).float()
+         self.scaler_scale = torch.tensor(pck["scaler_scale"], device=self.device).float()
+         mode_str = "PRISM-MPPI" if self.use_prism else "vanilla LeWM-MPPI (PRISM off)"
+         print(f"[init] mode = {mode_str}")
+         print(f"[init] z_dim={cfg['z_dim']} H={self.H} A_block={self.A_block} "
+               f"A_raw={self.A_raw}")
+         print(f"[init] device={self.device} K={self.K} n_iters={self.n_iters}")
+
+     @torch.no_grad()
+     def _encode(self, img_uint8: np.ndarray) -> torch.Tensor:
+         """uint8 image → (1, D) CLS embedding."""
+         x = _preprocess(img_uint8, self.device)
+         # JEPA.encode expects a dict with 'pixels' shape (B, T, C, H, W)
+         info = {"pixels": x.unsqueeze(1)}  # add T=1 dim
+         info = self.lewm.encode(info)
+         return info["emb"][:, 0]  # (1, D)
+
+     @torch.no_grad()
+     def _prior(self, z_t: torch.Tensor, z_g: torch.Tensor):
+         """PRISM head: (1, D), (1, D) → (μ, σ) of shape (1, H, A_block, A_raw)
+         in normalized action space."""
+         return self.head(z_t, z_g)
+
+     @torch.no_grad()
+     def _rollout_costs(
+         self,
+         z_t: torch.Tensor,  # (1, D)
+         z_g: torch.Tensor,  # (1, D)
+         action_candidates: torch.Tensor,  # (1, K, H*A_block, A_raw) normalized
+     ) -> torch.Tensor:  # (1, K) cost per candidate
+         """Rollout each candidate via LeWM AR predictor, compute final-z MSE to z_g."""
+         B, K, T_total, A = action_candidates.shape
+         assert T_total == self.H * self.A_block
+         D = z_t.shape[-1]
+         HS = self.history_size
+
+         # Seed embedding history with the current z_t (tile to HS length)
+         # emb: (B*K, HS, D)
+         emb = z_t.unsqueeze(1).expand(B, K, D).reshape(B * K, D)
+         emb = emb.unsqueeze(1).expand(-1, HS, -1).contiguous()
+
+         # action_seq: (B*K, T_total, A) — env-step actions; predictor consumes them block-by-block
+         act_seq = action_candidates.reshape(B * K, T_total, A)
+
+         # Group actions into plan-steps of A_block: (B*K, H, A_block * A)
+         act_plan = act_seq.reshape(B * K, self.H, self.A_block * A)
+
+         # Embed actions via the predictor's action_encoder (Embedder)
+         # act_emb: (B*K, H, action_emb_dim)
+         act_emb = self.lewm.action_encoder(act_plan)
+
+         # AR rollout
+         for t in range(self.H):
+             emb_trunc = emb[:, -HS:]  # (B*K, HS, D)
+             act_trunc = act_emb[:, max(0, t - HS + 1): t + 1]  # last HS actions seen
+             # Pad on the left if we don't have HS history of actions yet
+             if act_trunc.shape[1] < HS:
+                 pad = act_trunc[:, :1].expand(-1, HS - act_trunc.shape[1], -1)
+                 act_trunc = torch.cat([pad, act_trunc], dim=1)
+             pred = self.lewm.predict(emb_trunc, act_trunc)[:, -1:]  # (B*K, 1, D)
+             emb = torch.cat([emb, pred], dim=1)
+
+         # Final predicted embedding: emb[:, -1]
+         pred_final = emb[:, -1]  # (B*K, D)
+         goal = z_g.unsqueeze(1).expand(B, K, D).reshape(B * K, D)
+         cost = F.mse_loss(pred_final, goal, reduction="none").sum(dim=-1)  # (B*K,)
+         return cost.reshape(B, K)
+
+     @torch.no_grad()
+     def plan(self, obs_uint8: np.ndarray, goal_uint8: np.ndarray) -> np.ndarray:
+         """One MPPI optimization (PRISM or vanilla depending on `use_prism`).
+
+         Returns (A_block, A_raw) actions in raw env units.
+         """
+         # 1. Encode
+         z_t = self._encode(obs_uint8)  # (1, D)
+         z_g = self._encode(goal_uint8)  # (1, D)
+
+         # 2. Init MPPI distribution N(0, var_scale² I); var_scale is used as the std
+         shape = (1, self.H * self.A_block, self.A_raw)
+         mean = torch.zeros(shape, device=self.device)
+         std = torch.full(shape, self.var_scale, device=self.device)
+
+         # 3. (PRISM only) prior in normalized action space + PoG fusion
+         if self.use_prism:
+             mu_p, sg_p = self._prior(z_t, z_g)  # (1, H, A_block, A_raw)
+             mu_p_flat = mu_p.reshape(*shape)
+             sg_p_flat = sg_p.reshape(*shape) * self.prior_sigma_scale
+             mean, std = _pog_fusion(mean, std, mu_p_flat, sg_p_flat, self.sigma_floor)
+
+         # 4. MPPI iterations (frozen σ — PRISM-MPPI signature when use_prism=True;
+         #    matches stable_worldmodel.solver.MPPISolver default when use_prism=False)
+         for _ in range(self.n_iters):
+             noise = torch.randn(
+                 1, self.K, self.H * self.A_block, self.A_raw, device=self.device,
+             )
+             cands = mean.unsqueeze(1) + noise * std.unsqueeze(1)
+             # cands: (1, K, H*A_block, A_raw)
+
+             cost = self._rollout_costs(z_t, z_g, cands)  # (1, K)
+             log_w = -self.beta * (cost - cost.min(dim=-1, keepdim=True).values)
+             w = torch.softmax(log_w, dim=-1)  # (1, K)
+
+             # Importance-weighted mean update; std FROZEN (PRISM-MPPI)
+             mean = (w.unsqueeze(-1).unsqueeze(-1) * cands).sum(dim=1)
+             # mean: (1, H*A_block, A_raw)
+
+         # 5. First A_block actions, denormalized
+         first_block_norm = mean[0, : self.A_block]  # (A_block, A_raw)
+         first_block_raw = first_block_norm * self.scaler_scale + self.scaler_mean
+         return first_block_raw.cpu().numpy().astype(np.float32)
+
+
+ # ===========================================================================
+ # Sanity test: load + plan on a sample from the ARX h5
+ # ===========================================================================
+
+ if __name__ == "__main__":
+     import argparse
+
+     ap = argparse.ArgumentParser()
+     ap.add_argument(
+         "--lewm-ckpt", default=".stable-wm/lewm_arx_epoch_100_object.ckpt",
+     )
+     ap.add_argument("--prior-ckpt", default="prior_head_arx.pt")
+     ap.add_argument("--h5", default=".stable-wm/arx_left_cube.h5")
+     ap.add_argument("--seed", type=int, default=0)
+     ap.add_argument("--no-prism", action="store_true",
+                     help="Run vanilla LeWM-MPPI (no PRISM prior). Use for A/B comparison.")
+     args = ap.parse_args()
+
+     # Load a sample from the ARX h5 — first frame of episode 0 + its goal
+     import h5py
+     print(f"\n[demo] loading sample from {args.h5}")
+     with h5py.File(args.h5, "r") as f:
+         obs = f["pixels"][0]
+         goal = f["goal_pixels"][0]
+         ground_truth_action = f["action"][0]
+     print(f"[demo] obs.shape={obs.shape} goal.shape={goal.shape} "
+           f"obs.dtype={obs.dtype}")
+
+     # Build planner
+     print()
+     planner = PrismMPPIInference(
+         lewm_ckpt=args.lewm_ckpt,
+         prior_ckpt=args.prior_ckpt,
+         use_prism=not args.no_prism,
+         device="cuda" if torch.cuda.is_available() else "cpu",
+     )
+
+     # Plan
+     mode = "vanilla LeWM-MPPI" if args.no_prism else "PRISM-MPPI"
+     print(f"\n[demo] running {mode} on the sample obs + its goal image…")
+     import time
+     t0 = time.time()
+     actions = planner.plan(obs, goal)
+     dt = time.time() - t0
+     print(f"[demo] planned in {dt:.2f}s")
+     print(f"[demo] action sequence (A_block × A_raw): shape={actions.shape}")
+     print(f"[demo] first action: {actions[0].tolist()}")
+     print(f"[demo] ground-truth (t=0): {ground_truth_action.tolist()}")
+     print(f"[demo] |Δ|: {np.linalg.norm(actions[0] - ground_truth_action):.4f}")
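Reviewer sketch: the Product-of-Gaussians step in `_pog_fusion` is easier to check on scalars. The function below restates the same precision-weighted formulas in plain Python (no torch), so the numbers can be verified by hand. The helper name `pog_fusion_scalar` and the prior values 0.8 / 0.2 are made up for illustration; the seed N(0, 1), `prior_sigma_scale = 2.0`, and `sigma_floor = 0.05` are the file's defaults.

```python
import math


def pog_fusion_scalar(mean, std, mu_p, sg_p, sigma_floor=0.05, eps=1e-8):
    """Scalar version of _pog_fusion: fuse N(mean, std^2) with N(mu_p, sg_p^2)."""
    tau_base = 1.0 / (std ** 2 + eps)   # precision of the planner seed
    tau_p = 1.0 / (sg_p ** 2 + eps)     # precision of the prior
    tau_c = tau_base + tau_p            # precisions add under a product
    fused_mean = (tau_base * mean + tau_p * mu_p) / tau_c
    fused_std = max(math.sqrt(1.0 / tau_c), sigma_floor)
    return fused_mean, fused_std


# Seed N(0, 1); toy prior with mean 0.8 and sigma 0.2, widened by the 2.0 scale.
fused_mean, fused_std = pog_fusion_scalar(0.0, 1.0, 0.8, 0.2 * 2.0)
print(f"fused mean={fused_mean:.4f}, fused std={fused_std:.4f}")

# A very confident prior would shrink sigma below the floor, so it is clamped.
_, floored_std = pog_fusion_scalar(0.0, 1.0, 0.8, 0.01)
print(f"floored std={floored_std}")
```

Because the fused σ is then frozen for all MPPI iterations, the prior's confidence directly sets how far sampling can move away from the prior mean.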
jepa.py ADDED
@@ -0,0 +1,153 @@
+ """JEPA Implementation"""
+
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange
+ from torch import nn
+
+ def detach_clone(v):
+     return v.detach().clone() if torch.is_tensor(v) else v
+
+ class JEPA(nn.Module):
+
+     def __init__(
+         self,
+         encoder,
+         predictor,
+         action_encoder,
+         projector=None,
+         pred_proj=None,
+     ):
+         super().__init__()
+
+         self.encoder = encoder
+         self.predictor = predictor
+         self.action_encoder = action_encoder
+         self.projector = projector or nn.Identity()
+         self.pred_proj = pred_proj or nn.Identity()
+
+     def encode(self, info):
+         """Encode observations (and actions, if present) into embeddings.
+         info: dict with a 'pixels' key (B, T, C, H, W) and an optional 'action' key
+         """
+
+         pixels = info['pixels'].float()
+         b = pixels.size(0)
+         pixels = rearrange(pixels, "b t ... -> (b t) ...")  # flatten for encoding
+         output = self.encoder(pixels, interpolate_pos_encoding=True)
+         pixels_emb = output.last_hidden_state[:, 0]  # cls token
+         emb = self.projector(pixels_emb)
+         info["emb"] = rearrange(emb, "(b t) d -> b t d", b=b)
+
+         if "action" in info:
+             info["act_emb"] = self.action_encoder(info["action"])
+
+         return info
+
+     def predict(self, emb, act_emb):
+         """Predict next state embedding
+         emb: (B, T, D)
+         act_emb: (B, T, A_emb)
+         """
+         preds = self.predictor(emb, act_emb)
+         preds = self.pred_proj(rearrange(preds, "b t d -> (b t) d"))
+         preds = rearrange(preds, "(b t) d -> b t d", b=emb.size(0))
+         return preds
+
+     ####################
+     ## Inference only ##
+     ####################
+
+     def rollout(self, info, action_sequence, history_size: int = 3):
+         """Rollout the model given an initial info dict and action sequence.
+         pixels: (B, S, T, C, H, W)
+         action_sequence: (B, S, T, action_dim)
+         - S is the number of action plan samples
+         - T is the time horizon
+         """
+
+         assert "pixels" in info, "pixels not in info_dict"
+         H = info["pixels"].size(2)
+         B, S, T = action_sequence.shape[:3]
+         act_0, act_future = torch.split(action_sequence, [H, T - H], dim=2)
+         info["action"] = act_0
+         n_steps = T - H
+
+         # copy and encode initial info dict
+         _init = {k: v[:, 0] for k, v in info.items() if torch.is_tensor(v)}
+         _init = self.encode(_init)
+         emb = info["emb"] = _init["emb"].unsqueeze(1).expand(B, S, -1, -1)
+         _init = {k: detach_clone(v) for k, v in _init.items()}
+
+         # flatten batch and sample dimensions for rollout
+         emb = rearrange(emb, "b s ... -> (b s) ...").clone()
+         act = rearrange(act_0, "b s ... -> (b s) ...")
+         act_future = rearrange(act_future, "b s ... -> (b s) ...")
+
+         # rollout predictor autoregressively for n_steps
+         HS = history_size
+         for t in range(n_steps):
+             act_emb = self.action_encoder(act)
+             emb_trunc = emb[:, -HS:]  # (BS, HS, D)
+             act_trunc = act_emb[:, -HS:]  # (BS, HS, A_emb)
+             pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:]  # (BS, 1, D)
+             emb = torch.cat([emb, pred_emb], dim=1)  # (BS, T+1, D)
+
+             next_act = act_future[:, t : t + 1, :]  # (BS, 1, action_dim)
+             act = torch.cat([act, next_act], dim=1)  # (BS, T+1, action_dim)
+
+         # predict the last state
+         act_emb = self.action_encoder(act)  # (BS, T, A_emb)
+         emb_trunc = emb[:, -HS:]  # (BS, HS, D)
+         act_trunc = act_emb[:, -HS:]  # (BS, HS, A_emb)
+         pred_emb = self.predict(emb_trunc, act_trunc)[:, -1:]  # (BS, 1, D)
+         emb = torch.cat([emb, pred_emb], dim=1)
+
+         # unflatten batch and sample dimensions
+         pred_rollout = rearrange(emb, "(b s) ... -> b s ...", b=B, s=S)
+         info["predicted_emb"] = pred_rollout
+
+         return info
+
+     def criterion(self, info_dict: dict):
+         """Compute the cost between predicted embeddings and goal embeddings."""
+         pred_emb = info_dict["predicted_emb"]  # (B, S, T-1, dim)
+         goal_emb = info_dict["goal_emb"]  # (B, S, T, dim)
+
+         goal_emb = goal_emb[..., -1:, :].expand_as(pred_emb)
+
+         # return last-step cost per action candidate
+         cost = F.mse_loss(
+             pred_emb[..., -1:, :],
+             goal_emb[..., -1:, :].detach(),
+             reduction="none",
+         ).sum(dim=tuple(range(2, pred_emb.ndim)))  # (B, S)
+
+         return cost
+
+     def get_cost(self, info_dict: dict, action_candidates: torch.Tensor):
+         """Compute the cost of action candidates given an info dict with goal and initial state."""
+
+         assert "goal" in info_dict, "goal not in info_dict"
+
+         device = next(self.parameters()).device
+         for k in list(info_dict.keys()):
+             if torch.is_tensor(info_dict[k]):
+                 info_dict[k] = info_dict[k].to(device)
+
+         goal = {k: v[:, 0] for k, v in info_dict.items() if torch.is_tensor(v)}
+         goal["pixels"] = goal["goal"]
+
+         for k in info_dict:
+             if k.startswith("goal_"):
+                 goal[k[len("goal_") :]] = goal.pop(k)
+
+         goal.pop("action")
+         goal = self.encode(goal)
+
+         info_dict["goal_emb"] = goal["emb"]
+         info_dict = self.rollout(info_dict, action_candidates)
+
+         cost = self.criterion(info_dict)
+
+         return cost
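Reviewer sketch: the sliding history window in `JEPA.rollout` (`emb[:, -HS:]` and `act_emb[:, -HS:]`) is easy to misread, so here is a small index-only trace of it. The helper `rollout_windows` is hypothetical and tracks positions rather than tensors: `n_context` plays the role of the number of context frames (`H` in `rollout`) and `n_future` the number of remaining actions (`T - H`).

```python
def rollout_windows(n_context, n_future, history_size=3):
    """List the (embedding indices, action indices) seen at each predictor call.

    Mirrors the slicing in JEPA.rollout: n_future loop iterations, each of which
    appends one predicted embedding and one future action, then a final call.
    """
    emb_idx = list(range(n_context))   # embeddings known so far
    act_idx = list(range(n_context))   # actions consumed so far
    windows = []
    for t in range(n_future):
        windows.append((emb_idx[-history_size:], act_idx[-history_size:]))
        emb_idx.append(len(emb_idx))   # predicted embedding joins the history
        act_idx.append(n_context + t)  # next future action joins the history
    # Final call that predicts the last state.
    windows.append((emb_idx[-history_size:], act_idx[-history_size:]))
    return windows


for step, (e, a) in enumerate(rollout_windows(n_context=1, n_future=4)):
    print(f"call {step}: emb window {e}, action window {a}")
```

The windows grow until they reach `history_size` entries and then slide, which is why `history_size` has to match what the predictor saw during training.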
lewm_red_cube_epoch_100_object.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca701f5496e3a23b0dbae8fa9a66f5de980abe729845739f75fce3556b97fe8e
+ size 72355236
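Reviewer sketch: since the checkpoint is unpickled with `weights_only=False`, it is worth confirming that the downloaded file matches this LFS pointer before loading it. The helpers below (`parse_lfs_pointer`, `matches_pointer`) are hypothetical names written for this note; they only use `hashlib` and `pathlib` from the standard library, and the pointer text is copied from the diff above.

```python
import hashlib
from pathlib import Path


def parse_lfs_pointer(text):
    """Parse a Git LFS pointer file into {'oid': hex digest, 'size': int}."""
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    algo, _, digest = fields["oid"].partition(":")
    if algo != "sha256":
        raise ValueError(f"unsupported hash algorithm: {algo}")
    return {"oid": digest, "size": int(fields["size"])}


def matches_pointer(path, pointer, chunk_size=1 << 20):
    """Return True if the file at `path` has the pointer's size and SHA-256."""
    path = Path(path)
    if path.stat().st_size != pointer["size"]:
        return False
    h = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)  # stream in chunks to avoid loading ~72 MB at once
    return h.hexdigest() == pointer["oid"]


POINTER_TEXT = """version https://git-lfs.github.com/spec/v1
oid sha256:ca701f5496e3a23b0dbae8fa9a66f5de980abe729845739f75fce3556b97fe8e
size 72355236
"""
pointer = parse_lfs_pointer(POINTER_TEXT)
print(pointer["size"], pointer["oid"][:12])
```

A caller would then run `matches_pointer("lewm_red_cube_epoch_100_object.ckpt", pointer)` and refuse to call `torch.load` if it returns False.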
module.py ADDED
@@ -0,0 +1,306 @@
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from einops import rearrange
+
+ def modulate(x, shift, scale):
+     """AdaLN-zero modulation"""
+     return x * (1 + scale) + shift
+
+ class SIGReg(torch.nn.Module):
+     """Sketched Isotropic Gaussian Regularizer (single-GPU!)"""
+
+     def __init__(self, knots=17, num_proj=1024):
+         super().__init__()
+         self.num_proj = num_proj
+         t = torch.linspace(0, 3, knots, dtype=torch.float32)
+         dt = 3 / (knots - 1)
+         weights = torch.full((knots,), 2 * dt, dtype=torch.float32)
+         weights[[0, -1]] = dt
+         window = torch.exp(-t.square() / 2.0)
+         self.register_buffer("t", t)
+         self.register_buffer("phi", window)
+         self.register_buffer("weights", weights * window)
+
+     def forward(self, proj):
+         """
+         proj: (T, B, D)
+         """
+         # sample random projections
+         A = torch.randn(proj.size(-1), self.num_proj, device=proj.device)
+         A = A.div_(A.norm(p=2, dim=0))
+         # compute the Epps-Pulley statistic
+         x_t = (proj @ A).unsqueeze(-1) * self.t
+         err = (x_t.cos().mean(-3) - self.phi).square() + x_t.sin().mean(-3).square()
+         statistic = (err @ self.weights) * proj.size(-2)
+         return statistic.mean()  # average over projections and time
+
+ class FeedForward(nn.Module):
+     """FeedForward network used in Transformers"""
+
+     def __init__(self, dim, hidden_dim, dropout=0.0):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.LayerNorm(dim),
+             nn.Linear(dim, hidden_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(hidden_dim, dim),
+             nn.Dropout(dropout),
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ class Attention(nn.Module):
+     """Scaled dot-product attention with causal masking"""
+
+     def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
+         super().__init__()
+         inner_dim = dim_head * heads
+         project_out = not (heads == 1 and dim_head == dim)
+         self.heads = heads
+         self.scale = dim_head**-0.5
+         self.dropout = dropout
+         self.norm = nn.LayerNorm(dim)
+         self.attend = nn.Softmax(dim=-1)
+         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+         self.to_out = (
+             nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
+             if project_out
+             else nn.Identity()
+         )
+
+     def forward(self, x, causal=True):
+         """
+         x : (B, T, D)
+         """
+         x = self.norm(x)
+         drop = self.dropout if self.training else 0.0
+         qkv = self.to_qkv(x).chunk(3, dim=-1)  # three tensors, each (B, T, heads * dim_head)
+         q, k, v = (rearrange(t, "b t (h d) -> b h t d", h=self.heads) for t in qkv)  # each (B, heads, T, dim_head)
+         out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop, is_causal=causal)
+         out = rearrange(out, "b h t d -> b t (h d)")
+         return self.to_out(out)
+
+
+ class ConditionalBlock(nn.Module):
+     """Transformer block with AdaLN-zero conditioning"""
+
+     def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
+         super().__init__()
+
+         self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
+         self.mlp = FeedForward(dim, mlp_dim, dropout=dropout)
+         self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+         self.adaLN_modulation = nn.Sequential(
+             nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True)
+         )
+
+         nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
+         nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
+
+     def forward(self, x, c):
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+             self.adaLN_modulation(c).chunk(6, dim=-1)
+         )
+         x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
+         x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+         return x
+
+
+ class Block(nn.Module):
+     """Standard Transformer block"""
+
+     def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
+         super().__init__()
+
+         self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
+         self.mlp = FeedForward(dim, mlp_dim, dropout=dropout)
+         self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+
+     def forward(self, x):
+         x = x + self.attn(self.norm1(x))
+         x = x + self.mlp(self.norm2(x))
+         return x
+
+
+ class Transformer(nn.Module):
+     """Standard Transformer with support for AdaLN-zero blocks"""
+
+     def __init__(
+         self,
+         input_dim,
+         hidden_dim,
+         output_dim,
+         depth,
+         heads,
+         dim_head,
+         mlp_dim,
+         dropout=0.0,
+         block_class=Block,
+     ):
+         super().__init__()
+         self.norm = nn.LayerNorm(hidden_dim)
+         self.layers = nn.ModuleList([])
+
+         self.input_proj = (
+             nn.Linear(input_dim, hidden_dim)
+             if input_dim != hidden_dim
+             else nn.Identity()
+         )
+
+         self.cond_proj = (
+             nn.Linear(input_dim, hidden_dim)
+             if input_dim != hidden_dim
+             else nn.Identity()
+         )
+
+         self.output_proj = (
+             nn.Linear(hidden_dim, output_dim)
+             if hidden_dim != output_dim
+             else nn.Identity()
+         )
+
+         for _ in range(depth):
+             self.layers.append(
+                 block_class(hidden_dim, heads, dim_head, mlp_dim, dropout)
+             )
+
+     def forward(self, x, c=None):
+
+         if hasattr(self, "input_proj"):
+             x = self.input_proj(x)
+
+         if c is not None and hasattr(self, "cond_proj"):
+             c = self.cond_proj(c)
+
+         for block in self.layers:
+             x = block(x) if isinstance(block, Block) else block(x, c)
+         x = self.norm(x)
+
+         if hasattr(self, "output_proj"):
+             x = self.output_proj(x)
+         return x
+
189
+ class Embedder(nn.Module):
+     def __init__(
+         self,
+         input_dim=10,
+         smoothed_dim=10,
+         emb_dim=10,
+         mlp_scale=4,
+     ):
+         super().__init__()
+         self.patch_embed = nn.Conv1d(input_dim, smoothed_dim, kernel_size=1, stride=1)
+         self.embed = nn.Sequential(
+             nn.Linear(smoothed_dim, mlp_scale * emb_dim),
+             nn.SiLU(),
+             nn.Linear(mlp_scale * emb_dim, emb_dim),
+         )
+
+     def forward(self, x):
+         """
+         x: (B, T, D)
+         """
+         x = x.float()
+         x = x.permute(0, 2, 1)
+         x = self.patch_embed(x)
+         x = x.permute(0, 2, 1)
+         x = self.embed(x)
+         return x
+
+
+ class MLP(nn.Module):
+     """Simple MLP with optional normalization and activation"""
+
+     def __init__(
+         self,
+         input_dim,
+         hidden_dim,
+         output_dim=None,
+         norm_fn=nn.LayerNorm,
+         act_fn=nn.GELU,
+     ):
+         super().__init__()
+         norm_fn = norm_fn(hidden_dim) if norm_fn is not None else nn.Identity()
+         self.net = nn.Sequential(
+             nn.Linear(input_dim, hidden_dim),
+             norm_fn,
+             act_fn(),
+             nn.Linear(hidden_dim, output_dim or input_dim),
+         )
+
+     def forward(self, x):
+         """
+         x: (B*T, D)
+         """
+         return self.net(x)
+
+
+ class ActionEncoder2DWrapper(nn.Module):
+     """Slices 6-dim raw action input down to 2-dim (dx, dy) before encoding.
+
+     Accepts either (..., frameskip*6) or (..., frameskip*2) trailing dim
+     and adapts. Lives in module.py so it's importable from any script
+     that already imports `module` to use other LeWM components.
+     """
+     def __init__(self, inner: nn.Module, frameskip: int = 5):
+         super().__init__()
+         self.inner = inner
+         self.frameskip = frameskip
+
+     def forward(self, x):
+         if x.shape[-1] == self.frameskip * 6:
+             B = x.shape[:-1]
+             x = x.reshape(*B, self.frameskip, 6)
+             x = x[..., :2]
+             x = x.reshape(*B, self.frameskip * 2)
+         return self.inner(x)
+
+
+ class ARPredictor(nn.Module):
+     """Autoregressive predictor for next-step embedding prediction."""
+
+     def __init__(
+         self,
+         *,
+         num_frames,
+         depth,
+         heads,
+         mlp_dim,
+         input_dim,
+         hidden_dim,
+         output_dim=None,
+         dim_head=64,
+         dropout=0.0,
+         emb_dropout=0.0,
+     ):
+         super().__init__()
+         self.pos_embedding = nn.Parameter(torch.randn(1, num_frames, input_dim))
+         self.dropout = nn.Dropout(emb_dropout)
+         self.transformer = Transformer(
+             input_dim,
+             hidden_dim,
+             output_dim or input_dim,
+             depth,
+             heads,
+             dim_head,
+             mlp_dim,
+             dropout,
+             block_class=ConditionalBlock,
+         )
+
+     def forward(self, x, c):
+         """
+         x: (B, T, d)
+         c: (B, T, act_dim)
+         """
+         T = x.size(1)
+         x = x + self.pos_embedding[:, :T]
+         x = self.dropout(x)
+         x = self.transformer(x, c)
+         return x
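The slicing done by `ActionEncoder2DWrapper.forward` in the file above can be illustrated without PyTorch. The sketch below is only an approximation on a flat Python list: `slice_xy`, `raw_dim`, and `keep` are hypothetical names introduced here, and the real wrapper operates on tensors with arbitrary leading batch dimensions.

```python
def slice_xy(flat_action, frameskip=5, raw_dim=6, keep=2):
    """Hypothetical list-based stand-in for the wrapper's slicing step.

    A flat vector of frameskip * raw_dim values is viewed as `frameskip`
    sub-steps of `raw_dim` values each; only the first `keep` values
    (dx, dy) of every sub-step are retained. Any other length is returned
    unchanged, mirroring the wrapper's pass-through branch.
    """
    if len(flat_action) != frameskip * raw_dim:
        return list(flat_action)
    out = []
    for step in range(frameskip):
        start = step * raw_dim
        out.extend(flat_action[start:start + keep])
    return out


# 5 sub-steps of 6 values each, numbered so the kept positions are easy to spot.
raw = [float(i) for i in range(5 * 6)]
sliced = slice_xy(raw)
already_2d = slice_xy([0.0] * 10)
```

Here `sliced` holds ten values (two per sub-step), while an input that already has `frameskip * 2` entries passes through untouched.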
prior_head.py ADDED
@@ -0,0 +1,73 @@
+ """PriorHead: MLP that maps (z_t, z_g) → Gaussian over an action sequence.
+
+ Per §9 design: input is concat(z_t, z_g) ∈ R^{2D}, output is (μ, σ) over
+ H × A_block × A_raw normalized actions. σ is per-input via softplus + a floor.
+ The head's output sits in StandardScaler-normalized action space; the eval-side
+ policy is responsible for inverse-transform back to env action units.
+ """
+ from __future__ import annotations
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+
+ class PriorHead(nn.Module):
+     def __init__(
+         self,
+         z_dim: int,
+         H: int,
+         A_block: int,
+         A_raw: int,
+         hidden: int = 512,
+         sigma_floor: float = 0.05,
+     ):
+         super().__init__()
+         self.z_dim = z_dim
+         self.H = H
+         self.A_block = A_block
+         self.A_raw = A_raw
+         self.action_seq_dim = H * A_block * A_raw
+         self.sigma_floor = sigma_floor
+
+         self.mlp = nn.Sequential(
+             nn.Linear(2 * z_dim, hidden),
+             nn.GELU(),
+             nn.Linear(hidden, hidden),
+             nn.GELU(),
+             nn.Linear(hidden, 2 * self.action_seq_dim),
+         )
+
+     def forward(self, z_t: torch.Tensor, z_g: torch.Tensor):
+         """z_t, z_g: (B, D). Returns (mu, sigma) each of shape (B, H, A_block, A_raw)."""
+         x = torch.cat([z_t, z_g], dim=-1)
+         out = self.mlp(x)
+         mu_flat, log_sigma_flat = out.chunk(2, dim=-1)
+         sigma_flat = F.softplus(log_sigma_flat) + self.sigma_floor
+
+         B = mu_flat.size(0)
+         shape = (B, self.H, self.A_block, self.A_raw)
+         return mu_flat.view(shape), sigma_flat.view(shape)
+
+
+ def beta_nll_loss(
+     mu: torch.Tensor,
+     sigma: torch.Tensor,
+     target: torch.Tensor,
+     beta: float = 0.5,
+ ) -> torch.Tensor:
+     """β-NLL (Seitzer et al. 2022).
+
+     Standard Gaussian NLL per element (dropping additive constant):
+         nll_i = 0.5 (y_i − μ_i)² / σ_i² + log σ_i
+     β-NLL multiplies by stop_grad(σ_i^(2β)) before averaging:
+         L = mean_i [ stop_grad(σ_i^(2β)) · nll_i ]
+     β=0.5 is the recommended robust default: it keeps the σ-gradient alive but
+     prevents the σ-blow-up pathology of vanilla NLL when μ is hard to fit.
+     """
+     var = sigma.pow(2)
+     log_sigma = sigma.log()
+     sq_err = (target - mu).pow(2)
+     nll = 0.5 * sq_err / var + log_sigma
+     weight = sigma.detach().pow(2 * beta)
+     return (weight * nll).mean()
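To make the arithmetic in `prior_head.py` concrete, here is a minimal plain-Python sketch of the two formulas it uses: the `softplus + sigma_floor` parameterization of σ and the forward value of the β-NLL loss. The helper names (`softplus`, `sigma_from_raw`, `beta_nll`) are hypothetical and exist only for this illustration. The sketch works on lists of floats and does not model autograd: the stop-gradient changes gradients only, so in a forward-only computation the weight is an ordinary number.

```python
import math


def softplus(x):
    """Numerically stable softplus, log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigma_from_raw(raw, sigma_floor=0.05):
    """sigma = softplus(raw) + floor, so sigma never drops below the floor."""
    return softplus(raw) + sigma_floor


def beta_nll(mu, sigma, target, beta=0.5):
    """Forward value of beta-NLL over equal-length lists of floats."""
    terms = []
    for m, s, y in zip(mu, sigma, target):
        nll = 0.5 * (y - m) ** 2 / s ** 2 + math.log(s)  # per-element Gaussian NLL
        weight = s ** (2 * beta)  # treated as a constant in the tensor version
        terms.append(weight * nll)
    return sum(terms) / len(terms)


floor_case = sigma_from_raw(-50.0)   # very negative raw output: sigma sits at the floor
mid_case = sigma_from_raw(0.0)       # softplus(0) = ln 2
plain = beta_nll([0.0], [2.0], [2.0], beta=0.0)     # beta = 0 is the unweighted NLL
weighted = beta_nll([0.0], [2.0], [2.0], beta=0.5)  # same term scaled by sigma^(2*beta) = 2
```

With β = 0 the weight is 1 and the loss is the plain Gaussian NLL; with β = 0.5 each term is scaled by σ, which is what damps the incentive to inflate σ on hard-to-fit targets.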
prior_head_red_cube.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a0026eda7168dc515df27fadeb415ac2e6bd25fb98206b0db6f34fb365d6c591
+ size 2359141