atharva-pantheon commited on
Commit
af282a6
·
verified ·
1 Parent(s): 77604fb

Upload README.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +58 -0
README.md ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: other
3
+ tags:
4
+ - robotics
5
+ - rlt
6
+ - rl-token
7
+ - molmoact2
8
+ - vla
9
+ ---
10
+
11
+ # RLT Stage-1 RL Token Encoder (MolmoAct2 / YAM stack-cube)
12
+
13
+ Backup of the **RL Token (RLT) Stage-1 encoder** for the frozen MolmoAct2-BimanualYAM
14
+ stack-cube fine-tune. Faithful PyTorch port of openpi's `pi0_rl.py` (Xu et al. 2025):
15
+ a learned `<rl>` query compresses the VLA's `(M=690, 2560)` prefix hidden states into a
16
+ single **`z_rl`** token; a causal AR decoder reconstructs the prefix (per-token squared-L2,
17
+ stop-grad targets, α=0 / frozen VLA). `z_rl` is the state for the downstream SAC actor-critic.
18
+
19
+ ## Chosen encoder
20
+ **`checkpoints/rl_token_encoder_ctxdrop09_best.pt`** (load `["ema"]`). Trained with the
21
+ openpi/paper knobs (AdamW 5e-5, 1k warmup, grad-clip 1.0, EMA 0.999, 10k steps) **plus
22
+ `context_dropout=0.9`** — zeroing 90% of the decoder's teacher-forced context, which fixes
23
+ the AR-leak that otherwise leaves `z_rl` diffuse (the bare α=0 reconstruction lets the decoder
24
+ ignore the token).
25
+
26
+ ## Validation
27
+ | | baseline (α=0) | **dropout-0.9 (chosen)** |
28
+ |---|---|---|
29
+ | PCA top-10 var | 15% | **28%** |
30
+ | temporal smoothness (↓) | 0.72 | **0.69** |
31
+ | **success-vs-failure** LogReg CV acc | — | **99.2%** (silhouette 0.13) |
32
+
33
+ `z_rl` cleanly separates success (44 teleop demos) from failure (7 baseline rollouts, SR≈0)
34
+ in t-SNE — see `outputs/gate_success_fail.png`. Caveat: success/failure are from different
35
+ sessions, so part of the 99% is domain shift, not pure task semantics — strong upper bound.
36
+
37
+ ## Data
38
+ Trained on 9,668 `(690,2560)` prefix shards from the 44 `atharva-pantheon/yam-stack-cube`
39
+ demos (~1.3 h teleop @ 10 Hz). Matches the RL Token paper's "small per-task demo set" (1–10 h).
40
+
41
+ ## Files
42
+ - `code/` — `rl_token_encoder.py` (model), `train_encoder.py`, `collect_prefix.py` (demo→prefix
43
+ collector), `collect_fail_replay.py` (karma-rollout→prefix), `tsne_gate.py`, `gate_success_fail.py`.
44
+ - `checkpoints/` — `ctxdrop09_best/final` (chosen), `nodrop_best/final` (baseline), `ctxdrop05_best`.
45
+ - `plots/` — `tsne_final.png` (phase structure), `gate_success_fail.png` (success/fail), others.
46
+
47
+ ## Use (Phase-4 actor-critic)
48
+ ```python
49
+ import torch
50
+ from rl_token_encoder import RLTokenAutoencoder, RLTokenConfig
51
+ ae = RLTokenAutoencoder(RLTokenConfig(dim=2560))
52
+ ae.load_state_dict(torch.load("rl_token_encoder_ctxdrop09_best.pt", map_location="cpu")["ema"])
53
+ ae.eval()
54
+ z_rl = ae.encode(prefix, mask) # (b, M, 2560) -> (b, 2560); SAC state x = (z_rl, proprio)
55
+ ```
56
+
57
+ **Gotcha:** validate `z_rl` via `tsne_gate.py` / `gate_success_fail.py`, NOT a first-token
58
+ ablation — the first prefix token is a constant special id (151645), making that test vacuous.