Upload README.md with huggingface_hub
Browse files
README.md
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: other
|
| 3 |
+
tags:
|
| 4 |
+
- robotics
|
| 5 |
+
- rlt
|
| 6 |
+
- rl-token
|
| 7 |
+
- molmoact2
|
| 8 |
+
- vla
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# RLT Stage-1 RL Token Encoder (MolmoAct2 / YAM stack-cube)
|
| 12 |
+
|
| 13 |
+
Backup of the **RL Token (RLT) Stage-1 encoder** for the frozen MolmoAct2-BimanualYAM
|
| 14 |
+
stack-cube fine-tune. Faithful PyTorch port of openpi's `pi0_rl.py` (Xu et al. 2025):
|
| 15 |
+
a learned `<rl>` query compresses the VLA's `(M=690, 2560)` prefix hidden states into a
|
| 16 |
+
single **`z_rl`** token; a causal AR decoder reconstructs the prefix (per-token squared-L2,
|
| 17 |
+
stop-grad targets, α=0 / frozen VLA). `z_rl` is the state for the downstream SAC actor-critic.
|
| 18 |
+
|
| 19 |
+
## Chosen encoder
|
| 20 |
+
**`checkpoints/rl_token_encoder_ctxdrop09_best.pt`** (load `["ema"]`). Trained with the
|
| 21 |
+
openpi/paper knobs (AdamW 5e-5, 1k warmup, grad-clip 1.0, EMA 0.999, 10k steps) **plus
|
| 22 |
+
`context_dropout=0.9`** — zeroing 90% of the decoder's teacher-forced context, which fixes
|
| 23 |
+
the AR-leak that otherwise leaves `z_rl` diffuse (the bare α=0 reconstruction lets the decoder
|
| 24 |
+
ignore the token).
|
| 25 |
+
|
| 26 |
+
## Validation
|
| 27 |
+
| | baseline (α=0) | **dropout-0.9 (chosen)** |
|
| 28 |
+
|---|---|---|
|
| 29 |
+
| PCA top-10 var | 15% | **28%** |
|
| 30 |
+
| temporal smoothness (↓) | 0.72 | **0.69** |
|
| 31 |
+
| **success-vs-failure** LogReg CV acc | — | **99.2%** (silhouette 0.13) |
|
| 32 |
+
|
| 33 |
+
`z_rl` cleanly separates success (44 teleop demos) from failure (7 baseline rollouts, SR≈0)
|
| 34 |
+
in t-SNE — see `outputs/gate_success_fail.png`. Caveat: success/failure are from different
|
| 35 |
+
sessions, so part of the 99% is domain shift, not pure task semantics — strong upper bound.
|
| 36 |
+
|
| 37 |
+
## Data
|
| 38 |
+
Trained on 9,668 `(690,2560)` prefix shards from the 44 `atharva-pantheon/yam-stack-cube`
|
| 39 |
+
demos (~1.3 h teleop @ 10 Hz). Matches the RL Token paper's "small per-task demo set" (1–10 h).
|
| 40 |
+
|
| 41 |
+
## Files
|
| 42 |
+
- `code/` — `rl_token_encoder.py` (model), `train_encoder.py`, `collect_prefix.py` (demo→prefix
|
| 43 |
+
collector), `collect_fail_replay.py` (karma-rollout→prefix), `tsne_gate.py`, `gate_success_fail.py`.
|
| 44 |
+
- `checkpoints/` — `ctxdrop09_best/final` (chosen), `nodrop_best/final` (baseline), `ctxdrop05_best`.
|
| 45 |
+
- `plots/` — `tsne_final.png` (phase structure), `gate_success_fail.png` (success/fail), others.
|
| 46 |
+
|
| 47 |
+
## Use (Phase-4 actor-critic)
|
| 48 |
+
```python
|
| 49 |
+
import torch
|
| 50 |
+
from rl_token_encoder import RLTokenAutoencoder, RLTokenConfig
|
| 51 |
+
ae = RLTokenAutoencoder(RLTokenConfig(dim=2560))
|
| 52 |
+
ae.load_state_dict(torch.load("rl_token_encoder_ctxdrop09_best.pt", map_location="cpu")["ema"])
|
| 53 |
+
ae.eval()
|
| 54 |
+
z_rl = ae.encode(prefix, mask) # (b, M, 2560) -> (b, 2560); SAC state x = (z_rl, proprio)
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
**Gotcha:** validate `z_rl` via `tsne_gate.py` / `gate_success_fail.py`, NOT a first-token
|
| 58 |
+
ablation — the first prefix token is a constant special id (151645), making that test vacuous.
|