AdrianLlopart commited on
Commit
60e450f
·
0 Parent(s):

Duplicate from AdrianLlopart/rskill-molmoact2-libero-nf4

Browse files
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - molmoact2
5
+ - robotics
6
+ - image-text-to-text
7
+ - libero
8
+ - bitsandbytes
9
+ - nf4
10
+ - 4-bit
11
+ base_model: allenai/MolmoAct2-LIBERO
12
+ base_model_relation: quantized
13
+ ---
14
+
15
+ > ⚠️ **NF4-quantized fork for OpenRAL.** This repository is a **4-bit bitsandbytes NF4**
16
+ > quantization (compute dtype bf16) of [`allenai/MolmoAct2-LIBERO`](https://hf.co/allenai/MolmoAct2-LIBERO), produced for OpenRAL's
17
+ > ≤8 GiB-VRAM robot deployment. The authoritative scheme is in
18
+ > [`quantization_metadata.json`](./quantization_metadata.json) (`scheme: nf4`). The Hub's
19
+ > auto-detected **8-bit** tag is approximate (it reflects the packed bitsandbytes uint8 storage);
20
+ > the applied scheme is **NF4 (4-bit)**. Load via OpenRAL's `load_prequantized_state_for_rskill`.
21
+
22
+ <img src="assets/MolmoAct2.svg" alt="MolmoAct Logo" height="50">
23
+
24
+ # **MolmoAct2-LIBERO**
25
+
26
+ MolmoAct2 is an open vision-language-action model for robot control. It builds on Molmo2-ER and attaches a flow-matching continuous action expert that conditions on the VLM key-value cache through a per-layer connection.
27
+
28
+ This checkpoint is fine-tuned on the full LIBERO training mixture, combining Spatial, Object, Goal, and Long suites. It is intended for both further fine-tuning and LIBERO policy inference.
29
+
30
+ ## Quick Links
31
+
32
+ - 📂 Models: [Models](https://huggingface.co/collections/allenai/molmoact2-models), [Finetuned Models](https://huggingface.co/collections/allenai/molmoact2-finetuned-models)
33
+ - 📂 Datasets: [MolmoAct2-BimanualYAM Dataset](https://huggingface.co/collections/allenai/molmoact2-datasets), [MolmoAct2 Datasets](https://huggingface.co/collections/allenai/molmoact2-datasets), [Molmo2-ER Datasets](https://huggingface.co/collections/allenai/molmo2-er-datasets)
34
+ - 📄 Paper: [arXiv:2605.02881](https://arxiv.org/abs/2605.02881)
35
+ - 💻 Code: [allenai/molmoact2](https://github.com/allenai/molmoact2)
36
+ - 🎥 Blog Post: [MolmoAct2](https://allenai.org/blog/molmoact2)
37
+
38
+ ## Intended Use
39
+
40
+ Use this checkpoint for LIBERO inference or for further fine-tuning. Dataset normalization metadata is stored in `norm_stats.json`. pass `norm_tag="libero"` at inference time.
41
+
42
+ Continuous action prediction is the intended and recommended inference mode. Discrete action prediction is exposed for parity and debugging, but we use continuous actions by default.
43
+
44
+ ## Install
45
+
46
+ ```bash
47
+ pip install torch transformers pillow numpy huggingface_hub
48
+ ```
49
+
50
+ ## Sample Input
51
+
52
+ This sample comes from `libero_10`, episode 0, frame 0. The LIBERO camera order is front/agent view followed by wrist view.
53
+
54
+ | Agentview RGB | Wrist RGB |
55
+ | --- | --- |
56
+ | ![Sample agentview RGB](assets/sample_agentview_rgb.png) | ![Sample wrist RGB](assets/sample_wrist_rgb.png) |
57
+
58
+ ```python
59
+ from huggingface_hub import hf_hub_download
60
+ from PIL import Image
61
+ import numpy as np
62
+
63
+ repo_id = "allenai/MolmoAct2-LIBERO"
64
+
65
+ agentview_rgb = Image.open(
66
+ hf_hub_download(repo_id, "assets/sample_agentview_rgb.png")
67
+ ).convert("RGB")
68
+ wrist_rgb = Image.open(
69
+ hf_hub_download(repo_id, "assets/sample_wrist_rgb.png")
70
+ ).convert("RGB")
71
+
72
+ task = "put the white mug on the left plate and put the yellow and white mug on the right plate"
73
+ robot_state = np.array(
74
+ [
75
+ -0.05338004603981972,
76
+ 0.007029631175100803,
77
+ 0.6783280968666077,
78
+ 3.1407692432403564,
79
+ 0.0017593271331861615,
80
+ -0.08994418382644653,
81
+ 0.03878866136074066,
82
+ -0.03878721222281456,
83
+ ],
84
+ dtype=np.float32,
85
+ )
86
+ ```
87
+
88
+ ## Continuous Actions
89
+
90
+ ```python
91
+ import numpy as np
92
+ import torch
93
+ from huggingface_hub import hf_hub_download
94
+ from PIL import Image
95
+ from transformers import AutoModelForImageTextToText, AutoProcessor
96
+
97
+ repo_id = "allenai/MolmoAct2-LIBERO"
98
+
99
+ agentview_rgb = Image.open(
100
+ hf_hub_download(repo_id, "assets/sample_agentview_rgb.png")
101
+ ).convert("RGB")
102
+ wrist_rgb = Image.open(
103
+ hf_hub_download(repo_id, "assets/sample_wrist_rgb.png")
104
+ ).convert("RGB")
105
+ task = "put the white mug on the left plate and put the yellow and white mug on the right plate"
106
+ robot_state = np.array(
107
+ [
108
+ -0.05338004603981972,
109
+ 0.007029631175100803,
110
+ 0.6783280968666077,
111
+ 3.1407692432403564,
112
+ 0.0017593271331861615,
113
+ -0.08994418382644653,
114
+ 0.03878866136074066,
115
+ -0.03878721222281456,
116
+ ],
117
+ dtype=np.float32,
118
+ )
119
+
120
+ processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
121
+ model = AutoModelForImageTextToText.from_pretrained(
122
+ repo_id,
123
+ trust_remote_code=True,
124
+ dtype=torch.float32,
125
+ ).to("cuda").eval()
126
+
127
+ out = model.predict_action(
128
+ processor=processor,
129
+ images=[agentview_rgb, wrist_rgb],
130
+ task=task,
131
+ state=robot_state,
132
+ norm_tag="libero",
133
+ inference_action_mode="continuous",
134
+ enable_depth_reasoning=False,
135
+ num_steps=10,
136
+ normalize_language=True,
137
+ enable_cuda_graph=True,
138
+ )
139
+
140
+ actions = out.actions
141
+ ```
142
+
143
+ MolmoAct2 was trained with mixed precision. For our reported experiments, we ran inference in `float32`. This path uses the most GPU memory: roughly 26GB with CUDA graph enabled, or around 24GB without CUDA graph.
144
+
145
+ If you have a GPU with less memory, you can run inference with `bfloat16` instead:
146
+
147
+ ```python
148
+ model = AutoModelForImageTextToText.from_pretrained(
149
+ repo_id,
150
+ trust_remote_code=True,
151
+ dtype=torch.bfloat16,
152
+ ).to("cuda").eval()
153
+
154
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
155
+ out = model.predict_action(...)
156
+ ```
157
+
158
+ Using `bfloat16` is much more memory efficient and can run under 16GB of GPU memory in our tests. It usually does not hurt performance much.
159
+
160
+
161
+ `images` should preserve camera order, for example `[agentview_rgb, wrist_rgb]`. Images may be PIL images or RGB arrays. `state` is the raw robot state, and actions are returned in robot scale.
162
+
163
+ `normalize_language=True` is the default. It lowercases the task string and removes trailing sentence punctuation to match training preprocessing. Set it to `False` if you need to preserve the task text exactly.
164
+
165
+ `enable_cuda_graph=True` is the default. The first few calls can be slow because the model warms up and captures CUDA graphs. run several random warm-up calls before measuring deployment latency. `num_steps` controls the continuous flow solver and defaults to the checkpoint config value, 10.
166
+
167
+ Depth reasoning is disabled for this checkpoint. Calling `enable_depth_reasoning=True` will raise an error.
168
+
169
+ ## Discrete Actions
170
+
171
+ Discrete action inference requires a caller-provided action tokenizer. It is not saved in this repository. Discrete mode decodes action tokens directly. the continuous action expert is not used.
172
+
173
+ ```python
174
+ action_tokenizer = AutoProcessor.from_pretrained(
175
+ "allenai/MolmoAct2-FAST-Tokenizer",
176
+ trust_remote_code=True,
177
+ )
178
+
179
+ out = model.predict_action(
180
+ processor=processor,
181
+ images=[agentview_rgb, wrist_rgb],
182
+ task=task,
183
+ state=robot_state,
184
+ norm_tag="libero",
185
+ inference_action_mode="discrete",
186
+ action_tokenizer=action_tokenizer,
187
+ enable_depth_reasoning=False,
188
+ )
189
+ ```
190
+
191
+ ## Model and Hardware Safety
192
+
193
+ MolmoAct2 generate robot actions from visual observations and language instructions, but their behavior may vary across embodiments, environments, and hardware configurations. Users should carefully validate model outputs before deployment, especially when operating physical robots or other actuated systems. Where possible, actions should be monitored through interpretable intermediate outputs (adaptive depth map), simulation rollouts, action limits, or other safety checks before execution on hardware. The model’s action space should be bounded by the training data, robot controller limits, and task-specific safety constraints, including limits on speed, workspace, torque, and contact force. Users should follow the hardware manufacturer’s safety guidelines, use appropriate emergency-stop mechanisms, and operate the system only in a safely configured environment with human supervision.
194
+
195
+ ## Citation
196
+
197
+ ```bibtex
198
+ @misc{fang2026molmoact2actionreasoningmodels,
199
+ title={MolmoAct2: Action Reasoning Models for Real-world Deployment},
200
+ author={Haoquan Fang and Jiafei Duan and Donovan Clay and Sam Wang and Shuo Liu and Weikai Huang and Xiang Fan and Wei-Chuan Tsai and Shirui Chen and Yi Ru Wang and Shanli Xing and Jaemin Cho and Jae Sung Park and Ainaz Eftekhar and Peter Sushko and Karen Farley and Angad Wadhwa and Cole Harrison and Winson Han and Ying-Chun Lee and Eli VanderBilt and Rose Hendrix and Suveen Ellawela and Lucas Ngoo and Joyce Chai and Zhongzheng Ren and Ali Farhadi and Dieter Fox and Ranjay Krishna},
201
+ year={2026},
202
+ eprint={2605.02881},
203
+ archivePrefix={arXiv},
204
+ primaryClass={cs.RO},
205
+ url={https://arxiv.org/abs/2605.02881},
206
+ }
207
+ ```
config.json ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "action_end_token_id": 151933,
3
+ "action_expert_config": {
4
+ "attn_dropout": 0.0,
5
+ "causal_attn": false,
6
+ "context_layer_norm": true,
7
+ "dropout": 0.0,
8
+ "ffn_multiple_of": 256,
9
+ "hidden_size": 768,
10
+ "mlp_ratio": 4.0,
11
+ "model_type": "molmoact2_action_expert",
12
+ "num_heads": 8,
13
+ "num_layers": 36,
14
+ "qk_norm": true,
15
+ "qk_norm_eps": 1e-06,
16
+ "rope": true,
17
+ "timestep_embed_dim": 256
18
+ },
19
+ "action_expert_depth_gate": false,
20
+ "action_expert_depth_gate_init_bias": -4.0,
21
+ "action_expert_depth_gate_per_layer": false,
22
+ "action_mode": "both",
23
+ "max_action_horizon": 10,
24
+ "action_output_token_id": 151931,
25
+ "action_start_token_id": 151932,
26
+ "action_token_start_id": 151934,
27
+ "adapter_config": {
28
+ "attention_dropout": 0.0,
29
+ "attn_implementation": "sdpa",
30
+ "float32_attention": true,
31
+ "head_dim": 72,
32
+ "hidden_act": "silu",
33
+ "hidden_size": 1152,
34
+ "image_feature_dropout": 0.0,
35
+ "initializer_range": 0.02,
36
+ "intermediate_size": 9728,
37
+ "model_type": "molmoact2",
38
+ "num_attention_heads": 16,
39
+ "num_key_value_heads": 16,
40
+ "pooling_attention_mask": true,
41
+ "residual_dropout": 0.0,
42
+ "text_hidden_size": 2560,
43
+ "vit_layers": [
44
+ -3,
45
+ -9
46
+ ]
47
+ },
48
+ "add_action_expert": true,
49
+ "add_control_tokens": true,
50
+ "add_setup_tokens": true,
51
+ "architectures": [
52
+ "MolmoAct2ForConditionalGeneration"
53
+ ],
54
+ "auto_map": {
55
+ "AutoConfig": "configuration_molmoact2.MolmoAct2Config",
56
+ "AutoModelForImageTextToText": "modeling_molmoact2.MolmoAct2ForConditionalGeneration"
57
+ },
58
+ "depth_end_token_id": null,
59
+ "depth_mode": 2,
60
+ "depth_output_token_id": null,
61
+ "depth_start_token_id": null,
62
+ "depth_token_start_id": null,
63
+ "dtype": "float32",
64
+ "enable_depth_reasoning": false,
65
+ "flow_matching_beta_alpha": 1.0,
66
+ "flow_matching_beta_beta": 1.5,
67
+ "flow_matching_cutoff": 1.0,
68
+ "flow_matching_num_steps": 10,
69
+ "flow_matching_time_offset": 0.001,
70
+ "flow_matching_time_scale": 0.999,
71
+ "frame_end_token_id": 154632,
72
+ "frame_start_token_id": 154631,
73
+ "image_col_id": 154627,
74
+ "image_end_token_id": 154625,
75
+ "image_high_res_id": 154626,
76
+ "image_low_res_id": 154630,
77
+ "image_patch_id": 154626,
78
+ "image_start_token_id": 154624,
79
+ "initializer_range": 0.02,
80
+ "low_res_image_start_token_id": 154628,
81
+ "mask_action_dim_padding": true,
82
+ "max_action_dim": 32,
83
+ "model_type": "molmoact2",
84
+ "n_obs_steps": 1,
85
+ "norm_stats_filename": "norm_stats.json",
86
+ "num_action_tokens": 2048,
87
+ "num_depth_codes": 100,
88
+ "num_depth_tokens": 0,
89
+ "num_state_tokens": 256,
90
+ "state_end_token_id": 151674,
91
+ "state_format": "discrete",
92
+ "state_start_token_id": 151673,
93
+ "state_token_start_id": 151675,
94
+ "text_config": {
95
+ "additional_vocab_size": 128,
96
+ "attention_dropout": 0.0,
97
+ "attn_implementation": "sdpa",
98
+ "embedding_dropout": 0.0,
99
+ "head_dim": 128,
100
+ "hidden_act": "silu",
101
+ "hidden_size": 2560,
102
+ "initializer_range": 0.02,
103
+ "intermediate_size": 9728,
104
+ "layer_norm_eps": 1e-06,
105
+ "max_position_embeddings": 16384,
106
+ "model_type": "molmoact2_text",
107
+ "norm_after": false,
108
+ "num_attention_heads": 32,
109
+ "num_hidden_layers": 36,
110
+ "num_key_value_heads": 8,
111
+ "qk_norm_type": "qwen3",
112
+ "qkv_bias": false,
113
+ "residual_dropout": 0.0,
114
+ "rope_parameters": {
115
+ "rope_theta": 5000000.0,
116
+ "rope_type": "default"
117
+ },
118
+ "rope_scaling_layers": null,
119
+ "rope_theta": 5000000.0,
120
+ "tie_word_embeddings": false,
121
+ "use_cache": true,
122
+ "use_qk_norm": true,
123
+ "vocab_size": 154624
124
+ },
125
+ "tie_word_embeddings": false,
126
+ "transformers_version": "5.3.0",
127
+ "use_frame_special_tokens": true,
128
+ "vit_config": {
129
+ "attention_dropout": 0.0,
130
+ "attn_implementation": "sdpa",
131
+ "float32_attention": true,
132
+ "head_dim": 72,
133
+ "hidden_act": "gelu_pytorch_tanh",
134
+ "hidden_size": 1152,
135
+ "image_default_input_size": [
136
+ 378,
137
+ 378
138
+ ],
139
+ "image_num_pos": 729,
140
+ "image_patch_size": 14,
141
+ "initializer_range": 0.02,
142
+ "intermediate_size": 4304,
143
+ "layer_norm_eps": 1e-06,
144
+ "model_type": "molmoact2",
145
+ "num_attention_heads": 16,
146
+ "num_hidden_layers": 27,
147
+ "num_key_value_heads": 16,
148
+ "residual_dropout": 0.0
149
+ },
150
+ "bos_token_id": 151645,
151
+ "eos_token_id": 151645,
152
+ "pad_token_id": 151643
153
+ }
configuration_molmoact2.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MolmoAct2 configuration
3
+ """
4
+
5
+ from typing import Optional, Any
6
+
7
+ from transformers import PretrainedConfig
8
+ from transformers.modeling_rope_utils import rope_config_validation
9
+ from transformers.utils import logging
10
+
11
+ logger = logging.get_logger(__name__)
12
+
13
+
14
+ class MolmoAct2VitConfig(PretrainedConfig):
15
+ r"""
16
+ This is the configuration class to store the configuration of a [`MolmoAct2VisionTransformer`].
17
+ It is used to instantiate a `MolmoAct2VisionTransformer` according to the specified arguments,
18
+ defining the model architecture.
19
+
20
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
21
+ documentation from [`PretrainedConfig`] for more information.
22
+
23
+ Example:
24
+ ```python
25
+ >>> from transformers import MolmoAct2VitConfig, MolmoAct2VisionTransformer
26
+
27
+ >>> # Initializing a MolmoAct2VitConfig
28
+ >>> configuration = MolmoAct2VitConfig()
29
+
30
+ >>> # Initializing a MolmoAct2VisionTransformer (with random weights)
31
+ >>> model = MolmoAct2VisionTransformer(configuration)
32
+
33
+ >>> # Accessing the model configuration
34
+ >>> configuration = model.config
35
+ ```"""
36
+
37
+ model_type = "molmoact2"
38
+ base_config_key = "vit_config"
39
+
40
+ def __init__(
41
+ self,
42
+ hidden_size: int = 1152,
43
+ intermediate_size: int = 4304,
44
+ num_hidden_layers: int = 27,
45
+ num_attention_heads: int = 16,
46
+ num_key_value_heads: int = 16,
47
+ head_dim: int = 72,
48
+ hidden_act: str = "gelu_pytorch_tanh",
49
+ layer_norm_eps: float = 1e-6,
50
+ image_default_input_size: tuple[int, int] = (378, 378),
51
+ image_patch_size: int = 14,
52
+ image_num_pos: int = 577,
53
+ attention_dropout: float = 0.0,
54
+ residual_dropout: float = 0.0,
55
+ initializer_range: float = 0.02,
56
+ float32_attention: bool = True,
57
+ attn_implementation: str = "eager",
58
+ **kwargs,
59
+ ):
60
+ self.attn_implementation = attn_implementation
61
+ super().__init__(
62
+ attn_implementation=attn_implementation,
63
+ **kwargs
64
+ )
65
+ self.hidden_size = hidden_size
66
+ self.intermediate_size = intermediate_size
67
+ self.num_hidden_layers = num_hidden_layers
68
+ self.num_attention_heads = num_attention_heads
69
+ self.num_key_value_heads = num_key_value_heads
70
+ self.head_dim = head_dim
71
+ self.hidden_act = hidden_act
72
+ self.layer_norm_eps = layer_norm_eps
73
+ self.image_default_input_size = image_default_input_size
74
+ self.image_patch_size = image_patch_size
75
+ self.image_num_pos = image_num_pos
76
+ self.attention_dropout = attention_dropout
77
+ self.residual_dropout = residual_dropout
78
+ self.initializer_range = initializer_range
79
+ self.float32_attention = float32_attention
80
+
81
+ @property
82
+ def image_num_patch(self):
83
+ h, w = self.image_default_input_size
84
+ return h // self.image_patch_size, w // self.image_patch_size
85
+
86
+
87
+ class MolmoAct2AdapterConfig(PretrainedConfig):
88
+ r"""
89
+ This is the configuration class to store the configuration of MolmoAct2Adapter. With MolmoAct2VitConfig,
90
+ It is used to instantiate an MolmoAct2VisionBackbone according to the specified arguments,
91
+ defining the model architecture.
92
+
93
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
94
+ documentation from [`PretrainedConfig`] for more information.
95
+
96
+ Example:
97
+
98
+ ```python
99
+ >>> from transformers import MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2VisionBackbone
100
+
101
+ >>> # Initializing a MolmoAct2VitConfig and a MolmoAct2AdapterConfig
102
+ >>> vit_config = MolmoAct2VitConfig()
103
+ >>> adapter_config = MolmoPoolingConfig()
104
+
105
+ >>> # Initializing a MolmoAct2VisionBackbone (with random weights)
106
+ >>> model = MolmoAct2VisionBackbone(vit_config, adapter_config)
107
+
108
+ >>> # Accessing the model configuration
109
+ >>> vit_configuration = model.vit_config
110
+ >>> adapter_configuration = model.adapter_config
111
+ ```"""
112
+
113
+ model_type = "molmoact2"
114
+ base_config_key = "adapter_config"
115
+
116
+ def __init__(
117
+ self,
118
+ vit_layers: tuple = (-3, -9),
119
+ pooling_attention_mask: bool = False,
120
+ hidden_size: int = 1152,
121
+ num_attention_heads: int = 16,
122
+ num_key_value_heads: int = 16,
123
+ head_dim: int = 72,
124
+ float32_attention: bool = True,
125
+ attention_dropout: float = 0.0,
126
+ residual_dropout: float = 0.0,
127
+ hidden_act: str = "silu",
128
+ intermediate_size: int = 18944,
129
+ text_hidden_size: int = 3584,
130
+ image_feature_dropout: float = 0.0,
131
+ initializer_range: float = 0.02,
132
+ attn_implementation: str = "eager",
133
+ **kwargs,
134
+ ):
135
+ self.attn_implementation = attn_implementation
136
+ super().__init__(
137
+ attn_implementation=attn_implementation,
138
+ **kwargs
139
+ )
140
+ self.vit_layers = vit_layers
141
+ self.pooling_attention_mask = pooling_attention_mask
142
+ self.hidden_size = hidden_size
143
+ self.num_attention_heads = num_attention_heads
144
+ self.num_key_value_heads = num_key_value_heads
145
+ self.head_dim = head_dim
146
+ self.float32_attention = float32_attention
147
+ self.attention_dropout = attention_dropout
148
+ self.residual_dropout = residual_dropout
149
+ self.hidden_act = hidden_act
150
+ self.intermediate_size = intermediate_size
151
+ self.text_hidden_size = text_hidden_size
152
+ self.image_feature_dropout = image_feature_dropout
153
+ self.initializer_range = initializer_range
154
+
155
+
156
+ class MolmoAct2TextConfig(PretrainedConfig):
157
+ r"""
158
+ This is the configuration class to store the configuration of a [`MolmoAct2TextModel`]. It is used to instantiate a
159
+ `MolmoAct2TextModel` according to the specified arguments, defining the model architecture.
160
+
161
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
162
+ documentation from [`PretrainedConfig`] for more information.
163
+
164
+ Example:
165
+ ```python
166
+ >>> from transformers import MolmoAct2TextConfig, MolmoAct2TextModel
167
+
168
+ >>> # Initializing a MolmoAct2TextConfig
169
+ >>> configuration = MolmoAct2TextConfig()
170
+
171
+ >>> # Initializing a MolmoAct2TextModel (with random weights)
172
+ >>> model = MolmoAct2TextModel(configuration)
173
+
174
+ >>> # Accessing the model configuration
175
+ >>> configuration = model.config
176
+ ```"""
177
+
178
+ model_type = "molmoact2_text"
179
+ base_config_key = "text_config"
180
+ keys_to_ignore_at_inference = ["past_key_values"]
181
+ base_model_tp_plan = {
182
+ "blocks.*.self_attn.att_proj": "colwise",
183
+ "blocks.*.self_attn.attn_out": "rowwise",
184
+ "blocks.*.mlp.ff_proj": "colwise",
185
+ "blocks.*.mlp.ff_out": "rowwise",
186
+ }
187
+ base_model_pp_plan = {
188
+ "wte": (["input_ids"], ["inputs_embeds"]),
189
+ "blocks": (["hidden_states", "attention_mask"], ["hidden_states"]),
190
+ "ln_f": (["hidden_states"], ["hidden_states"]),
191
+ }
192
+
193
+ def __init__(
194
+ self,
195
+ hidden_size: int = 3584,
196
+ num_attention_heads: int = 28,
197
+ num_key_value_heads: Optional[int] = 4,
198
+ head_dim: int = 128,
199
+ vocab_size: int = 152064,
200
+ additional_vocab_size: int = 128,
201
+ qkv_bias: bool = True,
202
+ num_hidden_layers: int = 48,
203
+ intermediate_size: int = 18944,
204
+ hidden_act: str = "silu",
205
+ embedding_dropout: float=0.0,
206
+ attention_dropout: float=0.0,
207
+ residual_dropout: float = 0.0,
208
+ max_position_embeddings: int = 4096,
209
+ rope_theta: float = 1000000.0,
210
+ rope_scaling: dict[str, Any] = None,
211
+ rope_scaling_layers: Optional[list[int]] = None,
212
+ use_qk_norm: bool = False,
213
+ qk_norm_type: str = "olmo",
214
+ layer_norm_eps: int = 1e-6,
215
+ norm_after: bool = False,
216
+ initializer_range: float = 0.02,
217
+ use_cache=True,
218
+ tie_word_embeddings=False,
219
+ attn_implementation: str = "eager",
220
+ **kwargs,
221
+ ):
222
+ self.attn_implementation = attn_implementation
223
+ super().__init__(
224
+ tie_word_embeddings=tie_word_embeddings,
225
+ attn_implementation=attn_implementation,
226
+ **kwargs
227
+ )
228
+ self.hidden_size = hidden_size
229
+ self.num_attention_heads = num_attention_heads
230
+ if num_key_value_heads is None:
231
+ num_key_value_heads = num_attention_heads
232
+ self.num_key_value_heads = num_key_value_heads
233
+ self.head_dim = head_dim
234
+ self.vocab_size = vocab_size
235
+ self.additional_vocab_size = additional_vocab_size
236
+ self.qkv_bias = qkv_bias
237
+ self.num_hidden_layers = num_hidden_layers
238
+ self.intermediate_size = intermediate_size
239
+ self.hidden_act = hidden_act
240
+ self.embedding_dropout = embedding_dropout
241
+ self.attention_dropout = attention_dropout
242
+ self.residual_dropout = residual_dropout
243
+ self.max_position_embeddings = max_position_embeddings
244
+ self.rope_theta = rope_theta
245
+ self.rope_scaling = rope_scaling
246
+ self.rope_scaling_layers = rope_scaling_layers
247
+ self.use_qk_norm = use_qk_norm
248
+ self.qk_norm_type = qk_norm_type
249
+ self.layer_norm_eps = layer_norm_eps
250
+ self.norm_after = norm_after
251
+ self.initializer_range = initializer_range
252
+ self.use_cache = use_cache
253
+
254
+ # Validate the correctness of rotary position embeddings parameters
255
+ rope_config_validation(self)
256
+
257
+
258
+ class MolmoAct2ActionExpertConfig(PretrainedConfig):
259
+ r"""Configuration for the MolmoAct2 modern action expert."""
260
+
261
+ model_type = "molmoact2_action_expert"
262
+ base_config_key = "action_expert_config"
263
+
264
+ def __init__(
265
+ self,
266
+ max_action_horizon: int = 32,
267
+ max_action_dim: int = 32,
268
+ hidden_size: int = 1024,
269
+ num_layers: int = 32,
270
+ num_heads: int = 16,
271
+ mlp_ratio: float = 8.0 / 3.0,
272
+ ffn_multiple_of: int = 256,
273
+ timestep_embed_dim: int = 256,
274
+ dropout: float = 0.0,
275
+ attn_dropout: float = 0.0,
276
+ context_layer_norm: bool = True,
277
+ qk_norm: bool = True,
278
+ qk_norm_eps: float = 1e-6,
279
+ rope: bool = True,
280
+ causal_attn: bool = False,
281
+ **kwargs,
282
+ ):
283
+ super().__init__(**kwargs)
284
+ self.max_action_horizon = max_action_horizon
285
+ self.max_action_dim = max_action_dim
286
+ self.hidden_size = hidden_size
287
+ self.num_layers = num_layers
288
+ self.num_heads = num_heads
289
+ self.mlp_ratio = mlp_ratio
290
+ self.ffn_multiple_of = ffn_multiple_of
291
+ self.timestep_embed_dim = timestep_embed_dim
292
+ self.dropout = dropout
293
+ self.attn_dropout = attn_dropout
294
+ self.context_layer_norm = context_layer_norm
295
+ self.qk_norm = qk_norm
296
+ self.qk_norm_eps = qk_norm_eps
297
+ self.rope = rope
298
+ self.causal_attn = causal_attn
299
+
300
+ def to_dict(self):
301
+ output = super().to_dict()
302
+ # These are derived from the parent MolmoAct2Config for HF exports. Keeping
303
+ # them out of the public nested config avoids duplicated sources of truth.
304
+ output.pop("max_action_horizon", None)
305
+ output.pop("max_action_dim", None)
306
+ return output
307
+
308
+
309
+ class MolmoAct2Config(PretrainedConfig):
310
+ r"""
311
+ This is the configuration class to store the configuration of a [`MolmoAct2ForConditionalGeneration`].
312
+ It is used to instantiate an MolmoAct2 model according to the specified arguments, defining the model architecture.
313
+
314
+ Example:
315
+
316
+ ```python
317
+ >>> from transformers import MolmoAct2Config, MolmoAct2VitConfig, MolmoAct2AdapterConfig, MolmoAct2TextConfig
318
+
319
+ >>> # Initializing a MolmoAct2VitConfig
320
+ >>> vit_config = MolmoAct2VitConfig()
321
+
322
+ >>> # Initializing a MolmoAct2AdapterConfig
323
+ >>> adapter_config = MolmoAct2AdapterConfig()
324
+
325
+ >>> # Initializing a MolmoAct2TextConfig
326
+ >>> text_config = MolmoAct2TextConfig()
327
+
328
+ >>> # Initializing a MolmoAct2Config
329
+ >>> configuration = MolmoAct2Config(
330
+ >>> vit_config=vit_config,
331
+ >>> adapter_config=adapter_config,
332
+ >>> text_config=text_config,
333
+ >>> image_start_token_id=151936,
334
+ >>> image_end_token_id=151937,
335
+ >>> image_patch_id=151938,
336
+ >>> image_col_id=151939,
337
+ >>> low_res_image_start_token_id=151940,
338
+ >>> image_low_res_id=151942,
339
+ >>> frame_start_token_id=151943,
340
+ >>> frame_end_token_id=151944,
341
+ >>> )
342
+
343
+ >>> # Initializing a model
344
+ >>> model = MolmoAct2ForConditionalGeneration(configuration)
345
+
346
+ >>> # Accessing the model configuration
347
+ >>> configuration = model.config
348
+ ```"""
349
+
350
+ model_type = "molmoact2"
351
+ sub_configs = {
352
+ "text_config": MolmoAct2TextConfig,
353
+ "vit_config": MolmoAct2VitConfig,
354
+ "adapter_config": MolmoAct2AdapterConfig,
355
+ "action_expert_config": MolmoAct2ActionExpertConfig,
356
+ }
357
+
358
+ def __init__(
359
+ self,
360
+ vit_config: MolmoAct2VitConfig = None,
361
+ adapter_config: MolmoAct2AdapterConfig = None,
362
+ text_config: MolmoAct2TextConfig = None,
363
+ action_expert_config: MolmoAct2ActionExpertConfig = None,
364
+ image_start_token_id: int = None,
365
+ low_res_image_start_token_id: int = None,
366
+ image_end_token_id: int = None,
367
+ image_low_res_id: int = None,
368
+ image_patch_id: int = None,
369
+ image_col_id: int = None,
370
+ frame_start_token_id: int = None,
371
+ frame_end_token_id: int = None,
372
+ use_frame_special_tokens: bool = True,
373
+ initializer_range: float = 0.02,
374
+ add_action_expert: bool = True,
375
+ max_action_dim: int = 32,
376
+ max_action_horizon: int = 30,
377
+ n_obs_steps: int = 30,
378
+ action_mode: str = "both",
379
+ state_format: str = "discrete",
380
+ flow_matching_num_steps: int = 10,
381
+ flow_matching_cutoff: float = 1.0,
382
+ flow_matching_time_offset: float = 0.001,
383
+ flow_matching_time_scale: float = 0.999,
384
+ flow_matching_beta_alpha: float = 1.0,
385
+ flow_matching_beta_beta: float = 1.5,
386
+ mask_action_dim_padding: bool = True,
387
+ enable_depth_reasoning: bool = False,
388
+ depth_mode: int = 2,
389
+ num_depth_codes: int = 100,
390
+ action_expert_depth_gate: bool = False,
391
+ action_expert_depth_gate_per_layer: bool = False,
392
+ action_expert_depth_gate_init_bias: float = -4.0,
393
+ action_output_token_id: int = None,
394
+ action_start_token_id: int = None,
395
+ action_end_token_id: int = None,
396
+ action_token_start_id: int = None,
397
+ num_action_tokens: int = 0,
398
+ depth_output_token_id: int = None,
399
+ depth_start_token_id: int = None,
400
+ depth_end_token_id: int = None,
401
+ depth_token_start_id: int = None,
402
+ num_depth_tokens: int = 0,
403
+ state_start_token_id: int = None,
404
+ state_end_token_id: int = None,
405
+ state_token_start_id: int = None,
406
+ num_state_tokens: int = 0,
407
+ add_setup_tokens: bool = True,
408
+ add_control_tokens: bool = True,
409
+ norm_stats_filename: str = "norm_stats.json",
410
+ **kwargs,
411
+ ):
412
+ super().__init__(**kwargs)
413
+ if vit_config is None:
414
+ self.vit_config = MolmoAct2VitConfig()
415
+ elif isinstance(vit_config, dict):
416
+ self.vit_config = MolmoAct2VitConfig(**vit_config)
417
+ else:
418
+ self.vit_config = vit_config
419
+ if adapter_config is None:
420
+ self.adapter_config = MolmoAct2AdapterConfig()
421
+ elif isinstance(adapter_config, dict):
422
+ self.adapter_config = MolmoAct2AdapterConfig(**adapter_config)
423
+ else:
424
+ self.adapter_config = adapter_config
425
+ if text_config is None:
426
+ self.text_config = MolmoAct2TextConfig()
427
+ elif isinstance(text_config, dict):
428
+ self.text_config = MolmoAct2TextConfig(**text_config)
429
+ else:
430
+ self.text_config = text_config
431
+ self.add_action_expert = bool(add_action_expert)
432
+ if not self.add_action_expert:
433
+ self.action_expert_config = None
434
+ elif action_expert_config is None:
435
+ self.action_expert_config = MolmoAct2ActionExpertConfig(
436
+ max_action_horizon=max_action_horizon,
437
+ max_action_dim=max_action_dim,
438
+ num_layers=self.text_config.num_hidden_layers,
439
+ )
440
+ elif isinstance(action_expert_config, dict):
441
+ self.action_expert_config = MolmoAct2ActionExpertConfig(**action_expert_config)
442
+ else:
443
+ self.action_expert_config = action_expert_config
444
+ if self.add_action_expert:
445
+ self.action_expert_config.max_action_dim = int(max_action_dim)
446
+ self.action_expert_config.max_action_horizon = int(max_action_horizon)
447
+ self._validate_release_action_config(
448
+ state_format=state_format,
449
+ )
450
+ self.image_start_token_id = image_start_token_id
451
+ self.low_res_image_start_token_id = low_res_image_start_token_id
452
+ self.image_end_token_id = image_end_token_id
453
+ self.image_low_res_id = image_low_res_id
454
+ self.image_high_res_id = image_patch_id
455
+ self.image_patch_id = image_patch_id
456
+ self.image_col_id = image_col_id
457
+ self.frame_start_token_id = frame_start_token_id
458
+ self.frame_end_token_id = frame_end_token_id
459
+ self.use_frame_special_tokens = use_frame_special_tokens
460
+ self.initializer_range = initializer_range
461
+ self.max_action_dim = max_action_dim
462
+ self.max_action_horizon = max_action_horizon
463
+ self.n_obs_steps = n_obs_steps
464
+ self.action_mode = action_mode
465
+ self.state_format = state_format
466
+ self.flow_matching_num_steps = flow_matching_num_steps
467
+ self.flow_matching_cutoff = flow_matching_cutoff
468
+ self.flow_matching_time_offset = flow_matching_time_offset
469
+ self.flow_matching_time_scale = flow_matching_time_scale
470
+ self.flow_matching_beta_alpha = flow_matching_beta_alpha
471
+ self.flow_matching_beta_beta = flow_matching_beta_beta
472
+ self.mask_action_dim_padding = mask_action_dim_padding
473
+ self.enable_depth_reasoning = enable_depth_reasoning
474
+ self.depth_mode = depth_mode
475
+ self.num_depth_codes = num_depth_codes
476
+ self.action_expert_depth_gate = action_expert_depth_gate
477
+ self.action_expert_depth_gate_per_layer = action_expert_depth_gate_per_layer
478
+ self.action_expert_depth_gate_init_bias = action_expert_depth_gate_init_bias
479
+ self.action_output_token_id = action_output_token_id
480
+ self.action_start_token_id = action_start_token_id
481
+ self.action_end_token_id = action_end_token_id
482
+ self.action_token_start_id = action_token_start_id
483
+ self.num_action_tokens = num_action_tokens
484
+ self.depth_output_token_id = depth_output_token_id
485
+ self.depth_start_token_id = depth_start_token_id
486
+ self.depth_end_token_id = depth_end_token_id
487
+ self.depth_token_start_id = depth_token_start_id
488
+ self.num_depth_tokens = num_depth_tokens
489
+ self.state_start_token_id = state_start_token_id
490
+ self.state_end_token_id = state_end_token_id
491
+ self.state_token_start_id = state_token_start_id
492
+ self.num_state_tokens = num_state_tokens
493
+ self.add_setup_tokens = add_setup_tokens
494
+ self.add_control_tokens = add_control_tokens
495
+ self.norm_stats_filename = norm_stats_filename
496
+
497
+ @staticmethod
498
+ def _validate_release_action_config(
499
+ *,
500
+ state_format: str,
501
+ ) -> None:
502
+ if state_format != "discrete":
503
+ raise ValueError("MolmoAct2 HF export supports only state_format='discrete'.")
504
+
505
+ @property
506
+ def image_num_patch(self):
507
+ assert self.vit_config is not None
508
+ return self.vit_config.image_num_patch
509
+
510
+ @property
511
+ def num_attention_heads(self):
512
+ return self.text_config.num_attention_heads
513
+
514
+ @property
515
+ def num_key_value_heads(self):
516
+ return self.text_config.num_key_value_heads
517
+
518
+ @property
519
+ def head_dim(self):
520
+ return self.text_config.head_dim
521
+
522
+ @property
523
+ def num_hidden_layers(self):
524
+ return self.text_config.num_hidden_layers
525
+
526
+ @property
527
+ def hidden_size(self):
528
+ return self.text_config.hidden_size
529
+
530
+ @property
531
+ def vocab_size(self):
532
+ return self.text_config.vocab_size
533
+
534
+ @property
535
+ def max_position_embeddings(self):
536
+ return self.text_config.max_position_embeddings
537
+
538
+
539
+ MolmoAct2VitConfig.register_for_auto_class()
540
+ MolmoAct2AdapterConfig.register_for_auto_class()
541
+ MolmoAct2TextConfig.register_for_auto_class()
542
+ MolmoAct2ActionExpertConfig.register_for_auto_class()
543
+ MolmoAct2Config.register_for_auto_class()
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151645,
3
+ "eos_token_id": 151645,
4
+ "pad_token_id": 151643,
5
+ "transformers_version": "5.3.0"
6
+ }
inference.py ADDED
@@ -0,0 +1,768 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Inference utilities for MolmoAct2"""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Iterable, Optional, Sequence, Tuple
5
+
6
+ import torch
7
+ from torch.nn import functional as F
8
+ from transformers.cache_utils import Cache
9
+ from transformers.configuration_utils import PretrainedConfig
10
+
11
+
12
+ @dataclass
13
+ class _ActionFlowInputs:
14
+ trajectory: torch.Tensor
15
+ context: Any
16
+ modulations: Sequence[Any]
17
+ action_dim_is_pad: Optional[torch.Tensor]
18
+
19
+
20
+ @dataclass
21
+ class _ActionFlowCudaGraph:
22
+ key: Tuple[Any, ...]
23
+ graph: torch.cuda.CUDAGraph
24
+ static_inputs: _ActionFlowInputs
25
+ output: torch.Tensor
26
+
27
+
28
+ @dataclass
29
+ class _DepthDecodeCudaGraphLayerStage:
30
+ residual: torch.Tensor
31
+ query: torch.Tensor
32
+ key: torch.Tensor
33
+ value: torch.Tensor
34
+
35
+
36
+ @dataclass
37
+ class _DepthDecodeCudaGraphPostStage:
38
+ graph: torch.cuda.CUDAGraph
39
+ attn_context: torch.Tensor
40
+
41
+
42
+ @dataclass
43
+ class _DepthDecodeCudaGraph:
44
+ cache_key: Tuple[Any, ...]
45
+ pre_graph: torch.cuda.CUDAGraph
46
+ token_ids: torch.Tensor
47
+ cos: torch.Tensor
48
+ sin: torch.Tensor
49
+ positions: torch.Tensor
50
+ stages: Sequence[_DepthDecodeCudaGraphLayerStage]
51
+ post_graphs: Sequence[_DepthDecodeCudaGraphPostStage]
52
+ output: torch.Tensor
53
+
54
+
55
+ @dataclass
56
+ class _DepthDecodeCudaGraphSpec:
57
+ eligible: bool
58
+ cache_key_prefix: Tuple[Any, ...]
59
+ num_hidden_layers: int
60
+ head_dim: int
61
+ num_attention_heads: int
62
+
63
+
64
+ def _cache_seq_len_int(past_key_values: Optional[Cache]) -> int:
65
+ if past_key_values is None:
66
+ return 0
67
+ seq_len = past_key_values.get_seq_length()
68
+ if torch.is_tensor(seq_len):
69
+ return int(seq_len.item())
70
+ return int(seq_len)
71
+
72
+
73
+ def _cache_max_len_int(past_key_values: Optional[Cache]) -> int:
74
+ if past_key_values is None:
75
+ return -1
76
+ max_len = past_key_values.get_max_cache_shape()
77
+ if torch.is_tensor(max_len):
78
+ return int(max_len.item())
79
+ return int(max_len)
80
+
81
+
82
+ def _iter_cache_key_values(
83
+ past_key_values: Cache,
84
+ ) -> Iterable[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]:
85
+ layers = getattr(past_key_values, "layers", None)
86
+ if layers is not None:
87
+ for layer in layers:
88
+ yield getattr(layer, "keys", None), getattr(layer, "values", None)
89
+ return
90
+ for layer in past_key_values:
91
+ yield layer[0], layer[1]
92
+
93
+
94
+ class _DepthDecodeStaticLayerCache:
95
+ is_compileable = False
96
+ is_sliding = False
97
+
98
+ def __init__(self, max_cache_len: int) -> None:
99
+ self.max_cache_len = int(max_cache_len)
100
+ self.cumulative_length = 0
101
+ self.keys: Optional[torch.Tensor] = None
102
+ self.values: Optional[torch.Tensor] = None
103
+
104
+ def _allocate(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
105
+ bsz, n_heads = key_states.shape[:2]
106
+ self.keys = torch.empty(
107
+ (bsz, n_heads, self.max_cache_len, key_states.shape[-1]),
108
+ dtype=key_states.dtype,
109
+ device=key_states.device,
110
+ )
111
+ self.values = torch.empty(
112
+ (bsz, n_heads, self.max_cache_len, value_states.shape[-1]),
113
+ dtype=value_states.dtype,
114
+ device=value_states.device,
115
+ )
116
+
117
+ def update(
118
+ self,
119
+ key_states: torch.Tensor,
120
+ value_states: torch.Tensor,
121
+ *args,
122
+ **kwargs,
123
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
124
+ if self.keys is None:
125
+ self._allocate(key_states, value_states)
126
+ start = self.cumulative_length
127
+ end = start + key_states.shape[-2]
128
+ if end > self.max_cache_len:
129
+ raise RuntimeError(
130
+ f"KV cache length {end} exceeds max_cache_len={self.max_cache_len}."
131
+ )
132
+ self.keys[:, :, start:end, :].copy_(key_states)
133
+ self.values[:, :, start:end, :].copy_(value_states)
134
+ self.cumulative_length = end
135
+ return self.keys[:, :, :end, :], self.values[:, :, :end, :]
136
+
137
+ def get_seq_length(self) -> int:
138
+ return self.cumulative_length
139
+
140
+ def get_max_cache_shape(self) -> int:
141
+ return -1
142
+
143
+ def reset(self) -> None:
144
+ self.cumulative_length = 0
145
+
146
+
147
+ class _DepthDecodeStaticCache(Cache):
148
+ def __init__(self, config: PretrainedConfig, max_cache_len: int) -> None:
149
+ text_config = config.get_text_config(decoder=True)
150
+ super().__init__(
151
+ layers=[
152
+ _DepthDecodeStaticLayerCache(max_cache_len=max_cache_len)
153
+ for _ in range(text_config.num_hidden_layers)
154
+ ]
155
+ )
156
+
157
+ def get_seq_length(self, layer_idx: int = 0) -> int:
158
+ return self.layers[layer_idx].get_seq_length()
159
+
160
+ def get_max_cache_shape(self, layer_idx: int = 0) -> int:
161
+ return self.layers[layer_idx].get_max_cache_shape()
162
+
163
+ def reset(self) -> None:
164
+ for layer in self.layers:
165
+ layer.reset()
166
+
167
+
168
+ class ActionCudaGraphManager:
169
+ def __init__(self, model: Any) -> None:
170
+ self.model = model
171
+ self.enabled = True
172
+ self.action_flow_graph: Optional[_ActionFlowCudaGraph] = None
173
+
174
+ def set_enabled(self, enabled: bool) -> None:
175
+ self.enabled = bool(enabled)
176
+
177
+ def can_use_action_flow(self, inputs: _ActionFlowInputs) -> bool:
178
+ action_model = self.model
179
+ if not self.enabled:
180
+ return False
181
+ if action_model.training or action_model._require_action_expert().training:
182
+ return False
183
+ if inputs.trajectory.device.type != "cuda":
184
+ return False
185
+
186
+ def all_on_cuda():
187
+ yield inputs.trajectory
188
+ for k, v in inputs.context.kv_contexts:
189
+ yield k
190
+ yield v
191
+ for t in (
192
+ inputs.context.cross_mask,
193
+ inputs.context.self_mask,
194
+ inputs.context.valid_action,
195
+ inputs.action_dim_is_pad,
196
+ ):
197
+ if t is not None:
198
+ yield t
199
+ if inputs.context.rope_cache is not None:
200
+ yield from inputs.context.rope_cache
201
+ for step in inputs.modulations:
202
+ yield step.conditioning
203
+ for block_modulation in step.block_modulations:
204
+ yield from block_modulation
205
+ yield from step.final_modulation
206
+
207
+ return all(t.device.type == "cuda" for t in all_on_cuda())
208
+
209
+ def run_action_flow(
210
+ self,
211
+ inputs: _ActionFlowInputs,
212
+ steps: int,
213
+ run_loop,
214
+ ) -> torch.Tensor:
215
+ key = _cuda_graph_key(inputs, steps)
216
+ cache = self.action_flow_graph
217
+ if cache is None or cache.key != key:
218
+ static_inputs = _clone_static_inputs(inputs)
219
+ graph, output = _capture_cuda_graph(
220
+ lambda: run_loop(static_inputs, steps),
221
+ inputs.trajectory.device,
222
+ after_warmup=lambda: static_inputs.trajectory.copy_(inputs.trajectory),
223
+ )
224
+ cache = _ActionFlowCudaGraph(
225
+ key=key,
226
+ graph=graph,
227
+ static_inputs=static_inputs,
228
+ output=output,
229
+ )
230
+ self.action_flow_graph = cache
231
+ else:
232
+ _copy_inputs_(cache.static_inputs, inputs)
233
+
234
+ cache.graph.replay()
235
+ return cache.output.clone()
236
+
237
+
238
+ class DepthDecodeCudaGraphManager:
239
+ def __init__(self, model: Any) -> None:
240
+ self.model = model
241
+ self.backbone = model.model
242
+ self.enabled = True
243
+ self.graph: Optional[_DepthDecodeCudaGraph] = None
244
+ self.graph_spec: Optional[_DepthDecodeCudaGraphSpec] = None
245
+
246
+ def set_enabled(self, enabled: bool) -> None:
247
+ self.enabled = bool(enabled)
248
+
249
+ def make_static_cache(self, max_cache_len: int) -> _DepthDecodeStaticCache:
250
+ return _DepthDecodeStaticCache(
251
+ config=self.model.config.text_config,
252
+ max_cache_len=max_cache_len,
253
+ )
254
+
255
+ def _depth_decode_spec(self) -> _DepthDecodeCudaGraphSpec:
256
+ static = self.graph_spec
257
+ if static is None:
258
+ cfg = self.backbone.transformer.config
259
+ rotary_emb = getattr(self.backbone.transformer, "rotary_emb", None)
260
+ static = _DepthDecodeCudaGraphSpec(
261
+ eligible=(
262
+ not cfg.norm_after
263
+ and cfg.rope_scaling_layers is None
264
+ and getattr(rotary_emb, "rope_type", None) == "default"
265
+ and cfg._attn_implementation == "sdpa"
266
+ ),
267
+ cache_key_prefix=(
268
+ cfg.hidden_size,
269
+ cfg.num_attention_heads,
270
+ cfg.num_key_value_heads,
271
+ cfg.head_dim,
272
+ cfg.num_hidden_layers,
273
+ cfg.use_qk_norm,
274
+ cfg.qk_norm_type,
275
+ cfg._attn_implementation,
276
+ ),
277
+ num_hidden_layers=cfg.num_hidden_layers,
278
+ head_dim=cfg.head_dim,
279
+ num_attention_heads=cfg.num_attention_heads,
280
+ )
281
+ self.graph_spec = static
282
+ return static
283
+
284
+ def can_use(
285
+ self,
286
+ next_input_ids: torch.Tensor,
287
+ *,
288
+ past_key_values: Cache,
289
+ attention_bias: torch.Tensor,
290
+ ) -> bool:
291
+ if (
292
+ not self.enabled
293
+ or self.model.training
294
+ or self.backbone.transformer.training
295
+ ):
296
+ return False
297
+ if next_input_ids.device.type != "cuda":
298
+ return False
299
+ if (
300
+ next_input_ids.ndim != 2
301
+ or next_input_ids.shape[0] != 1
302
+ or next_input_ids.shape[1] != 1
303
+ ):
304
+ return False
305
+ if not isinstance(past_key_values, _DepthDecodeStaticCache):
306
+ return False
307
+ if (
308
+ not torch.is_tensor(attention_bias)
309
+ or attention_bias.device != next_input_ids.device
310
+ ):
311
+ return False
312
+ return self._depth_decode_spec().eligible
313
+
314
+ def _depth_decode_key(
315
+ self,
316
+ next_input_ids: torch.Tensor,
317
+ attention_bias: torch.Tensor,
318
+ ) -> Tuple[Any, ...]:
319
+ device = next_input_ids.device
320
+ return (
321
+ self._depth_decode_spec().cache_key_prefix,
322
+ device.type,
323
+ device.index,
324
+ self.model.lm_head.weight.dtype,
325
+ attention_bias.shape[-1],
326
+ )
327
+
328
+ def _select_depth_decode_rope(
329
+ self, cos: torch.Tensor, sin: torch.Tensor, *, past_length: int
330
+ ) -> None:
331
+ emb = self.backbone.transformer.rotary_emb
332
+ cos.copy_(emb._pos_cos_cache[0, :, past_length : past_length + 1, :])
333
+ sin.copy_(emb._pos_sin_cache[0, :, past_length : past_length + 1, :])
334
+
335
+ def _depth_decode_pre_layer(
336
+ self,
337
+ layer_idx: int,
338
+ hidden_states: torch.Tensor,
339
+ cos: torch.Tensor,
340
+ sin: torch.Tensor,
341
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
342
+ block = self.backbone.transformer.blocks[layer_idx]
343
+ attention = block.self_attn
344
+ residual = hidden_states
345
+ hidden_states = block.attn_norm(hidden_states)
346
+
347
+ input_shape = hidden_states.shape[:-1]
348
+ hidden_shape = (*input_shape, -1, attention.head_dim)
349
+ qkv = attention.att_proj(hidden_states)
350
+ query_states, key_states, value_states = qkv.split(attention.fused_dims, dim=-1)
351
+ value_states = value_states.view(hidden_shape)
352
+
353
+ apply_qk_norm = attention.q_norm is not None and attention.k_norm is not None
354
+ norm_after_view = apply_qk_norm and attention.qk_norm_type == "qwen3"
355
+
356
+ if apply_qk_norm and not norm_after_view:
357
+ query_states = attention.q_norm(query_states)
358
+ key_states = attention.k_norm(key_states)
359
+
360
+ query_states = query_states.view(hidden_shape)
361
+ key_states = key_states.view(hidden_shape)
362
+
363
+ if norm_after_view:
364
+ query_states = attention.q_norm(query_states)
365
+ key_states = attention.k_norm(key_states)
366
+
367
+ query_states = query_states.transpose(1, 2)
368
+ key_states = key_states.transpose(1, 2)
369
+ value_states = value_states.transpose(1, 2)
370
+ query_states, key_states = _apply_rotary_pos_emb(
371
+ query_states, key_states, cos, sin
372
+ )
373
+ return residual, query_states, key_states, value_states
374
+
375
+ def _depth_decode_pre0(
376
+ self,
377
+ token_ids: torch.Tensor,
378
+ cos: torch.Tensor,
379
+ sin: torch.Tensor,
380
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
381
+ inputs_embeds = self.model._embed_base_tokens(token_ids)
382
+ return self._depth_decode_pre_layer(0, inputs_embeds, cos, sin)
383
+
384
+ def _depth_decode_post_layer(
385
+ self,
386
+ layer_idx: int,
387
+ residual: torch.Tensor,
388
+ attn_context: torch.Tensor,
389
+ ) -> torch.Tensor:
390
+ block = self.backbone.transformer.blocks[layer_idx]
391
+ attention = block.self_attn
392
+ input_shape = residual.shape[:-1]
393
+ attn_output = attn_context.reshape(*input_shape, -1).contiguous()
394
+ attn_output = attention.attn_out(attn_output)
395
+ hidden_states = residual + block.dropout(attn_output)
396
+
397
+ residual = hidden_states
398
+ hidden_states = block.ff_norm(hidden_states)
399
+ hidden_states = block.mlp(hidden_states)
400
+ hidden_states = residual + block.dropout(hidden_states)
401
+ return hidden_states
402
+
403
+ def _depth_decode_post_and_pre_next(
404
+ self,
405
+ layer_idx: int,
406
+ residual: torch.Tensor,
407
+ attn_context: torch.Tensor,
408
+ cos: torch.Tensor,
409
+ sin: torch.Tensor,
410
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
411
+ hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
412
+ return self._depth_decode_pre_layer(layer_idx + 1, hidden_states, cos, sin)
413
+
414
+ def _depth_decode_last_post(
415
+ self,
416
+ layer_idx: int,
417
+ residual: torch.Tensor,
418
+ attn_context: torch.Tensor,
419
+ ) -> torch.Tensor:
420
+ hidden_states = self._depth_decode_post_layer(layer_idx, residual, attn_context)
421
+ return self.backbone.transformer.ln_f(hidden_states)
422
+
423
+ def _build_depth_decode_graph(
424
+ self,
425
+ next_input_ids: torch.Tensor,
426
+ *,
427
+ past_length: int,
428
+ attention_bias: torch.Tensor,
429
+ ) -> _DepthDecodeCudaGraph:
430
+ text_config = self.backbone.transformer.config
431
+ device = next_input_ids.device
432
+ dtype = self.model.lm_head.weight.dtype
433
+ static = self._depth_decode_spec()
434
+ num_layers = static.num_hidden_layers
435
+ head_dim = static.head_dim
436
+ max_cache_len = int(attention_bias.shape[-1])
437
+ max_rope_len = max(int(text_config.max_position_embeddings or 0), max_cache_len)
438
+ self.backbone.transformer.prepare_rope_cache(
439
+ device=device, max_seq_len=max_rope_len
440
+ )
441
+
442
+ token_ids = torch.empty((1, 1), device=device, dtype=torch.long)
443
+ cos = torch.empty((1, 1, head_dim), device=device, dtype=dtype)
444
+ sin = torch.empty_like(cos)
445
+ positions = torch.arange(max_cache_len, device=device, dtype=torch.long)
446
+ context_shape = (1, 1, static.num_attention_heads, head_dim)
447
+
448
+ token_ids.copy_(next_input_ids)
449
+ self._select_depth_decode_rope(cos, sin, past_length=past_length)
450
+
451
+ pre_graph, pre_output = _capture_cuda_graph(
452
+ lambda: self._depth_decode_pre0(token_ids, cos, sin),
453
+ device,
454
+ )
455
+ stages = [_DepthDecodeCudaGraphLayerStage(*pre_output)]
456
+ post_graphs = []
457
+ for layer_idx in range(num_layers - 1):
458
+ stage = stages[-1]
459
+ attn_context = torch.empty(context_shape, device=device, dtype=dtype)
460
+ graph, output = _capture_cuda_graph(
461
+ lambda layer_idx=layer_idx, stage=stage, attn_context=attn_context: (
462
+ self._depth_decode_post_and_pre_next(
463
+ layer_idx,
464
+ stage.residual,
465
+ attn_context,
466
+ cos,
467
+ sin,
468
+ )
469
+ ),
470
+ device,
471
+ )
472
+ post_graphs.append(
473
+ _DepthDecodeCudaGraphPostStage(graph=graph, attn_context=attn_context)
474
+ )
475
+ stages.append(_DepthDecodeCudaGraphLayerStage(*output))
476
+
477
+ last_stage = stages[-1]
478
+ last_attn_context = torch.empty(context_shape, device=device, dtype=dtype)
479
+ last_graph, last_output = _capture_cuda_graph(
480
+ lambda: self._depth_decode_last_post(
481
+ num_layers - 1,
482
+ last_stage.residual,
483
+ last_attn_context,
484
+ ),
485
+ device,
486
+ )
487
+ post_graphs.append(
488
+ _DepthDecodeCudaGraphPostStage(
489
+ graph=last_graph, attn_context=last_attn_context
490
+ )
491
+ )
492
+ return _DepthDecodeCudaGraph(
493
+ cache_key=self._depth_decode_key(next_input_ids, attention_bias),
494
+ pre_graph=pre_graph,
495
+ token_ids=token_ids,
496
+ cos=cos,
497
+ sin=sin,
498
+ positions=positions,
499
+ stages=tuple(stages),
500
+ post_graphs=tuple(post_graphs),
501
+ output=last_output,
502
+ )
503
+
504
+ def _get_depth_decode_graph(
505
+ self,
506
+ next_input_ids: torch.Tensor,
507
+ *,
508
+ past_length: int,
509
+ attention_bias: torch.Tensor,
510
+ ) -> _DepthDecodeCudaGraph:
511
+ key = self._depth_decode_key(next_input_ids, attention_bias)
512
+ decode_graph = self.graph
513
+ if decode_graph is None or decode_graph.cache_key != key:
514
+ decode_graph = self._build_depth_decode_graph(
515
+ next_input_ids,
516
+ past_length=past_length,
517
+ attention_bias=attention_bias,
518
+ )
519
+ self.graph = decode_graph
520
+ else:
521
+ decode_graph.token_ids.copy_(next_input_ids)
522
+ self._select_depth_decode_rope(
523
+ decode_graph.cos, decode_graph.sin, past_length=past_length
524
+ )
525
+ return decode_graph
526
+
527
+ def _run_depth_decode_attention_core(
528
+ self,
529
+ layer_idx: int,
530
+ stage: _DepthDecodeCudaGraphLayerStage,
531
+ *,
532
+ past_key_values: Cache,
533
+ attention_bias: torch.Tensor,
534
+ cache_position: torch.Tensor,
535
+ cos: torch.Tensor,
536
+ sin: torch.Tensor,
537
+ ) -> torch.Tensor:
538
+ attention = self.backbone.transformer.blocks[layer_idx].self_attn
539
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
540
+ key_states, value_states = past_key_values.update(
541
+ stage.key,
542
+ stage.value,
543
+ layer_idx,
544
+ cache_kwargs,
545
+ )
546
+ key_states = _repeat_kv(key_states, attention.num_key_value_groups)
547
+ value_states = _repeat_kv(value_states, attention.num_key_value_groups)
548
+ attn_output = F.scaled_dot_product_attention(
549
+ stage.query,
550
+ key_states,
551
+ value_states,
552
+ attn_mask=attention_bias,
553
+ dropout_p=0.0,
554
+ is_causal=False,
555
+ )
556
+ return attn_output.transpose(1, 2)
557
+
558
+ def run(
559
+ self,
560
+ next_input_ids: torch.Tensor,
561
+ *,
562
+ past_key_values: Cache,
563
+ attention_bias: torch.Tensor,
564
+ past_length: int,
565
+ ) -> Tuple[torch.Tensor, Cache]:
566
+ end = past_length + 1
567
+ decode_graph = self._get_depth_decode_graph(
568
+ next_input_ids,
569
+ past_length=past_length,
570
+ attention_bias=attention_bias,
571
+ )
572
+ cache_position = decode_graph.positions[past_length:end]
573
+ attention_bias_q = attention_bias[:, :, past_length:end, :end]
574
+
575
+ decode_graph.pre_graph.replay()
576
+
577
+ for layer_idx, post_graph in enumerate(decode_graph.post_graphs):
578
+ attn_context = self._run_depth_decode_attention_core(
579
+ layer_idx,
580
+ decode_graph.stages[layer_idx],
581
+ past_key_values=past_key_values,
582
+ attention_bias=attention_bias_q,
583
+ cache_position=cache_position,
584
+ cos=decode_graph.cos,
585
+ sin=decode_graph.sin,
586
+ )
587
+ post_graph.attn_context.copy_(attn_context)
588
+ post_graph.graph.replay()
589
+
590
+ return decode_graph.output, past_key_values
591
+
592
+
593
+ def _cuda_graph_tensor_signature(
594
+ tensor: Optional[torch.Tensor],
595
+ ) -> Optional[Tuple[Any, ...]]:
596
+ if tensor is None:
597
+ return None
598
+ return (
599
+ tuple(tensor.shape),
600
+ tuple(tensor.stride()),
601
+ str(tensor.dtype),
602
+ str(tensor.device),
603
+ )
604
+
605
+
606
+ def _cuda_graph_context_signature(context: Any) -> Tuple[Any, ...]:
607
+ sig = _cuda_graph_tensor_signature
608
+ return (
609
+ tuple((sig(k), sig(v)) for k, v in context.kv_contexts),
610
+ sig(context.cross_mask),
611
+ sig(context.self_mask),
612
+ sig(context.valid_action),
613
+ None
614
+ if context.rope_cache is None
615
+ else tuple(sig(t) for t in context.rope_cache),
616
+ )
617
+
618
+
619
+ def _cuda_graph_modulation_signature(modulations: Sequence[Any]) -> Tuple[Any, ...]:
620
+ sig = _cuda_graph_tensor_signature
621
+ return tuple(
622
+ (
623
+ sig(step.conditioning),
624
+ tuple(
625
+ tuple(sig(t) for t in block_modulation)
626
+ for block_modulation in step.block_modulations
627
+ ),
628
+ tuple(sig(t) for t in step.final_modulation),
629
+ )
630
+ for step in modulations
631
+ )
632
+
633
+
634
+ def _cuda_graph_key(inputs: _ActionFlowInputs, steps: int) -> Tuple[Any, ...]:
635
+ sig = _cuda_graph_tensor_signature
636
+ return (
637
+ sig(inputs.trajectory),
638
+ _cuda_graph_context_signature(inputs.context),
639
+ _cuda_graph_modulation_signature(inputs.modulations),
640
+ sig(inputs.action_dim_is_pad),
641
+ int(steps),
642
+ )
643
+
644
+
645
+ def _clone_static_tensor(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
646
+ if tensor is None:
647
+ return None
648
+ static = torch.empty_strided(
649
+ tuple(tensor.shape),
650
+ tuple(tensor.stride()),
651
+ device=tensor.device,
652
+ dtype=tensor.dtype,
653
+ )
654
+ static.copy_(tensor)
655
+ return static
656
+
657
+
658
+ def _clone_static_context(context: Any) -> Any:
659
+ rope_cache = None
660
+ if context.rope_cache is not None:
661
+ rope_cache = tuple(_clone_static_tensor(t) for t in context.rope_cache)
662
+ return context.__class__(
663
+ kv_contexts=tuple(
664
+ (_clone_static_tensor(k), _clone_static_tensor(v))
665
+ for k, v in context.kv_contexts
666
+ ),
667
+ cross_mask=_clone_static_tensor(context.cross_mask),
668
+ self_mask=_clone_static_tensor(context.self_mask),
669
+ valid_action=_clone_static_tensor(context.valid_action),
670
+ rope_cache=rope_cache,
671
+ )
672
+
673
+
674
+ def _clone_static_modulations(modulations: Sequence[Any]) -> Sequence[Any]:
675
+ return tuple(
676
+ step.__class__(
677
+ conditioning=_clone_static_tensor(step.conditioning),
678
+ block_modulations=tuple(
679
+ tuple(_clone_static_tensor(t) for t in block_modulation)
680
+ for block_modulation in step.block_modulations
681
+ ),
682
+ final_modulation=tuple(
683
+ _clone_static_tensor(t) for t in step.final_modulation
684
+ ),
685
+ )
686
+ for step in modulations
687
+ )
688
+
689
+
690
+ def _clone_static_inputs(inputs: _ActionFlowInputs) -> _ActionFlowInputs:
691
+ return _ActionFlowInputs(
692
+ trajectory=_clone_static_tensor(inputs.trajectory),
693
+ context=_clone_static_context(inputs.context),
694
+ modulations=_clone_static_modulations(inputs.modulations),
695
+ action_dim_is_pad=_clone_static_tensor(inputs.action_dim_is_pad),
696
+ )
697
+
698
+
699
+ def _copy_context_(dst: Any, src: Any) -> None:
700
+ for (dst_k, dst_v), (src_k, src_v) in zip(dst.kv_contexts, src.kv_contexts):
701
+ dst_k.copy_(src_k)
702
+ dst_v.copy_(src_v)
703
+ if src.cross_mask is not None:
704
+ dst.cross_mask.copy_(src.cross_mask)
705
+ if src.self_mask is not None:
706
+ dst.self_mask.copy_(src.self_mask)
707
+ if src.valid_action is not None:
708
+ dst.valid_action.copy_(src.valid_action)
709
+ if src.rope_cache is not None:
710
+ for dst_tensor, src_tensor in zip(dst.rope_cache, src.rope_cache):
711
+ dst_tensor.copy_(src_tensor)
712
+
713
+
714
+ def _copy_inputs_(dst: _ActionFlowInputs, src: _ActionFlowInputs) -> None:
715
+ dst.trajectory.copy_(src.trajectory)
716
+ _copy_context_(dst.context, src.context)
717
+ if src.action_dim_is_pad is not None:
718
+ dst.action_dim_is_pad.copy_(src.action_dim_is_pad)
719
+
720
+
721
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
722
+ x1 = x[..., : x.shape[-1] // 2]
723
+ x2 = x[..., x.shape[-1] // 2 :]
724
+ return torch.cat((-x2, x1), dim=-1)
725
+
726
+
727
+ def _apply_rotary_pos_emb(
728
+ q: torch.Tensor,
729
+ k: torch.Tensor,
730
+ cos: torch.Tensor,
731
+ sin: torch.Tensor,
732
+ unsqueeze_dim: int = 1,
733
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
734
+ cos = cos.unsqueeze(unsqueeze_dim)
735
+ sin = sin.unsqueeze(unsqueeze_dim)
736
+ q_embed = (q * cos) + (_rotate_half(q) * sin)
737
+ k_embed = (k * cos) + (_rotate_half(k) * sin)
738
+ return q_embed, k_embed
739
+
740
+
741
+ def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
742
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
743
+ if n_rep == 1:
744
+ return hidden_states
745
+ hidden_states = hidden_states[:, :, None, :, :].expand(
746
+ batch, num_key_value_heads, n_rep, slen, head_dim
747
+ )
748
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
749
+
750
+
751
+ def _capture_cuda_graph(
752
+ fn,
753
+ device: torch.device,
754
+ *,
755
+ after_warmup=None,
756
+ ) -> Tuple[torch.cuda.CUDAGraph, Any]:
757
+ warmup_stream = torch.cuda.Stream(device=device)
758
+ warmup_stream.wait_stream(torch.cuda.current_stream(device))
759
+ with torch.cuda.stream(warmup_stream):
760
+ fn()
761
+ torch.cuda.current_stream(device).wait_stream(warmup_stream)
762
+ if after_warmup is not None:
763
+ after_warmup()
764
+
765
+ graph = torch.cuda.CUDAGraph()
766
+ with torch.cuda.graph(graph):
767
+ output = fn()
768
+ return graph, output
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07fe33e9759258a1c37b54e7a5f0d78c53a270dc8ddc7ecb139736ae1e9315da
3
+ size 4183452356
modeling_molmoact2.py ADDED
The diff for this file is too large to render. See raw diff
 
norm_stats.json ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "format": "molmoact2_norm_stats.v1",
3
+ "norm_mode": "q01_q99",
4
+ "metadata_by_tag": {
5
+ "libero": {
6
+ "action_key": "action",
7
+ "state_key": "observation.state",
8
+ "camera_keys": [
9
+ "observation.images.image",
10
+ "observation.images.wrist_image"
11
+ ],
12
+ "normalize_gripper": false,
13
+ "action_horizon": 10,
14
+ "n_action_steps": 10,
15
+ "setup_type": "single franka robotic arm in libero",
16
+ "control_mode": "delta end-effector pose",
17
+ "action_stats": {
18
+ "min": [
19
+ -0.9375,
20
+ -0.9375,
21
+ -0.9375,
22
+ -0.2582142949104309,
23
+ -0.375,
24
+ -0.3675000071525574,
25
+ -1.0
26
+ ],
27
+ "max": [
28
+ 0.9375,
29
+ 0.9375,
30
+ 0.9375,
31
+ 0.3557142913341522,
32
+ 0.375,
33
+ 0.375,
34
+ 1.0
35
+ ],
36
+ "mean": [
37
+ 0.06278156570450202,
38
+ 0.08684081017968912,
39
+ -0.09037305936952836,
40
+ 0.0005407430783705141,
41
+ 0.0056433796450358715,
42
+ -0.005229098518603562,
43
+ -0.04964072167678376
44
+ ],
45
+ "std": [
46
+ 0.3355237114945633,
47
+ 0.3784469867268323,
48
+ 0.44472859911256607,
49
+ 0.03924354049229973,
50
+ 0.06339296407444922,
51
+ 0.07797027713976648,
52
+ 0.9987671529022402
53
+ ],
54
+ "count": [
55
+ 273465.0
56
+ ],
57
+ "q01": [
58
+ -0.6792031928846481,
59
+ -0.7736573115323259,
60
+ -0.8728073904104404,
61
+ -0.10277447185825356,
62
+ -0.15509810617083444,
63
+ -0.20289961475228455,
64
+ -1.0
65
+ ],
66
+ "q10": [
67
+ -0.328718721971874,
68
+ -0.3626162647358338,
69
+ -0.6610056625361599,
70
+ -0.03907064459203904,
71
+ -0.06428551162168497,
72
+ -0.07928202560631951,
73
+ -1.0
74
+ ],
75
+ "q50": [
76
+ 0.015333975787982875,
77
+ 0.006437010746251905,
78
+ -0.07265095199149316,
79
+ -1.701317418858285e-05,
80
+ 0.00021801956089207239,
81
+ -5.852172701796134e-05,
82
+ -0.12287333595187695
83
+ ],
84
+ "q90": [
85
+ 0.5238177265233007,
86
+ 0.671417970219526,
87
+ 0.5384412174699407,
88
+ 0.040331002487738146,
89
+ 0.08240652401791884,
90
+ 0.0690125677722944,
91
+ 0.9999141552827842
92
+ ],
93
+ "q99": [
94
+ 0.8536542808794264,
95
+ 0.8637811051429717,
96
+ 0.9363295547540081,
97
+ 0.13045695485814487,
98
+ 0.18015313802054606,
99
+ 0.24129727661704234,
100
+ 0.9999914155282784
101
+ ],
102
+ "names": [
103
+ "x",
104
+ "y",
105
+ "z",
106
+ "roll",
107
+ "pitch",
108
+ "yaw",
109
+ "gripper"
110
+ ],
111
+ "mask": [
112
+ true,
113
+ true,
114
+ true,
115
+ true,
116
+ true,
117
+ true,
118
+ false
119
+ ]
120
+ },
121
+ "state_stats": {
122
+ "min": [
123
+ -0.4828203022480011,
124
+ -0.3255046010017395,
125
+ 0.008128180168569088,
126
+ 0.35277295112609863,
127
+ -3.641430377960205,
128
+ -1.842738389968872,
129
+ -0.0013586411951109767,
130
+ -0.042040832340717316
131
+ ],
132
+ "max": [
133
+ 0.21031762659549713,
134
+ 0.39128610491752625,
135
+ 1.3660105466842651,
136
+ 3.6714255809783936,
137
+ 3.560650587081909,
138
+ 1.386339545249939,
139
+ 0.04233968257904053,
140
+ 0.0013633022317662835
141
+ ],
142
+ "mean": [
143
+ -0.04651878279191748,
144
+ 0.034409066787269356,
145
+ 0.7645525031210381,
146
+ 2.9722094975655056,
147
+ -0.22046978549041713,
148
+ -0.1255794031738752,
149
+ 0.026914253269017054,
150
+ -0.027190783616938205
151
+ ],
152
+ "std": [
153
+ 0.10494395508556839,
154
+ 0.1517661933220375,
155
+ 0.378516707505034,
156
+ 0.34427344187858827,
157
+ 0.9069468516043042,
158
+ 0.32539190149967406,
159
+ 0.01417590382231912,
160
+ 0.014058894296088888
161
+ ],
162
+ "count": [
163
+ 273465.0
164
+ ],
165
+ "q01": [
166
+ -0.31479429659059555,
167
+ -0.26691552643710226,
168
+ 0.5194626050191016,
169
+ 2.159994551314992,
170
+ -1.801294177865994,
171
+ -0.8949778881389838,
172
+ 0.003382730811955442,
173
+ -0.04008920533069468
174
+ ],
175
+ "q10": [
176
+ -0.18409729127502492,
177
+ -0.158759498072202,
178
+ 0.5694822295083012,
179
+ 2.501970046458546,
180
+ -1.1889107640062022,
181
+ -0.5297043790093273,
182
+ 0.007573322430226042,
183
+ -0.039827946964434036
184
+ ],
185
+ "q50": [
186
+ -0.02822545357081922,
187
+ 0.029718887641213443,
188
+ 0.7185643731428462,
189
+ 3.0915725099012166,
190
+ -0.12491069931831773,
191
+ -0.08338984738533357,
192
+ 0.030648370056451133,
193
+ -0.031519123023466586
194
+ ],
195
+ "q90": [
196
+ 0.06725052913150302,
197
+ 0.23387160335018267,
198
+ 0.9599947530498419,
199
+ 3.1743361507512997,
200
+ 0.5456820212337484,
201
+ 0.20414514594693875,
202
+ 0.03985537019679712,
203
+ -0.008040434619037518
204
+ ],
205
+ "q99": [
206
+ 0.1222615490116252,
207
+ 0.3140223876046953,
208
+ 1.042961724319958,
209
+ 3.277638017923068,
210
+ 1.724488202195691,
211
+ 0.5659922739094448,
212
+ 0.04009682017699841,
213
+ -0.003493522538066522
214
+ ],
215
+ "names": [
216
+ "x",
217
+ "y",
218
+ "z",
219
+ "rx",
220
+ "ry",
221
+ "rz",
222
+ "rw",
223
+ "gripper"
224
+ ],
225
+ "mask": [
226
+ true,
227
+ true,
228
+ true,
229
+ true,
230
+ true,
231
+ true,
232
+ true,
233
+ false
234
+ ]
235
+ }
236
+ }
237
+ }
238
+ }
processing_molmoact2.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Processor class for MolmoAct2.
3
+ """
4
+ from typing import Optional, Union
5
+ import dataclasses
6
+
7
+ import numpy as np
8
+
9
+ from transformers.image_utils import ImageInput
10
+ from transformers.video_utils import VideoInput
11
+ from transformers.processing_utils import (
12
+ Unpack,
13
+ ProcessingKwargs,
14
+ ProcessorMixin,
15
+ )
16
+ from transformers.feature_extraction_utils import BatchFeature
17
+ from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
18
+ from transformers.utils import logging
19
+
20
+ from transformers import AutoTokenizer
21
+ from .image_processing_molmoact2 import MolmoAct2ImagesKwargs, MolmoAct2ImageProcessor
22
+ from .video_processing_molmoact2 import MolmoAct2VideoProcessorKwargs, MolmoAct2VideoProcessor
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ # Special tokens, these should be present in any tokenizer we use since the preprocessor uses them
29
+ IMAGE_PATCH_TOKEN = f"<im_patch>" # Where to insert high-res tokens
30
+ IMAGE_LOW_RES_TOKEN = f"<im_low>" # Where to insert low-res tokens
31
+ IM_START_TOKEN = f"<im_start>"
32
+ LOW_RES_IMAGE_START_TOKEN = f"<low_res_im_start>"
33
+ FRAME_START_TOKEN = f"<frame_start>"
34
+ IM_END_TOKEN = f"<im_end>"
35
+ FRAME_END_TOKEN= f"<frame_end>"
36
+ IM_COL_TOKEN = f"<im_col>"
37
+ IMAGE_PROMPT = "<|image|>"
38
+ VIDEO_PROMPT = "<|video|>"
39
+
40
+ IMAGE_TOKENS = [
41
+ IMAGE_PATCH_TOKEN,
42
+ IM_COL_TOKEN,
43
+ IM_START_TOKEN,
44
+ LOW_RES_IMAGE_START_TOKEN,
45
+ FRAME_START_TOKEN,
46
+ IM_END_TOKEN,
47
+ FRAME_END_TOKEN,
48
+ IMAGE_LOW_RES_TOKEN,
49
+ ]
50
+
51
+
52
+ class MolmoAct2ProcessorKwargs(ProcessingKwargs, total=False):
53
+ """MolmoAct2 processor kwargs"""
54
+ images_kwargs: MolmoAct2ImagesKwargs
55
+ videos_kwargs: MolmoAct2VideoProcessorKwargs
56
+ _defaults = {
57
+ "text_kwargs": {
58
+ "padding": False,
59
+ "return_mm_token_type_ids": True,
60
+ },
61
+ "videos_kwargs": {"return_metadata": True},
62
+ }
63
+
64
+
65
+ class MolmoAct2Processor(ProcessorMixin):
66
+ attributes = ["image_processor", "video_processor", "tokenizer"]
67
+ optional_attributes = [
68
+ "chat_template",
69
+ "time_mode",
70
+ "image_use_col_tokens",
71
+ "use_single_crop_col_tokens",
72
+ "use_single_crop_start_token",
73
+ "video_use_col_tokens",
74
+ "use_frame_special_tokens",
75
+ ]
76
+ image_processor_class = "AutoImageProcessor"
77
+ video_processor_class = "AutoVideoProcessor"
78
+ tokenizer_class = "AutoTokenizer"
79
+
80
+ def __init__(
81
+ self,
82
+ image_processor: MolmoAct2ImageProcessor = None,
83
+ video_processor: MolmoAct2VideoProcessor = None,
84
+ tokenizer: AutoTokenizer = None,
85
+ chat_template: Optional[str] = None,
86
+ image_use_col_tokens: Optional[bool] = True,
87
+ use_single_crop_col_tokens: Optional[bool] = None,
88
+ use_single_crop_start_token: Optional[bool] = True,
89
+ video_use_col_tokens: Optional[bool] = False,
90
+ use_frame_special_tokens: Optional[bool] = True,
91
+ **kwargs
92
+ ) -> None:
93
+ super().__init__(
94
+ image_processor,
95
+ video_processor,
96
+ tokenizer,
97
+ chat_template=chat_template,
98
+ )
99
+ self.image_use_col_tokens = image_use_col_tokens
100
+ self.use_single_crop_col_tokens = use_single_crop_col_tokens
101
+ self.use_single_crop_start_token = use_single_crop_start_token
102
+ self.video_use_col_tokens = video_use_col_tokens
103
+ self.use_frame_special_tokens = use_frame_special_tokens
104
+
105
+ self.image_placeholder_token = IMAGE_PROMPT
106
+ self.video_placeholder_token = VIDEO_PROMPT
107
+ self.image_token_ids = [
108
+ tokenizer.convert_tokens_to_ids(token)
109
+ for token in IMAGE_TOKENS
110
+ ]
111
+
112
+ def get_image_tokens(self, image_grid: np.ndarray):
113
+ resized_h, resized_w, height, width = image_grid
114
+ if int(height) == 0 or int(width) == 0:
115
+ per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
116
+ use_single_crop_col_tokens = (
117
+ self.image_use_col_tokens
118
+ if self.use_single_crop_col_tokens is None
119
+ else self.use_single_crop_col_tokens
120
+ )
121
+ if use_single_crop_col_tokens:
122
+ per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
123
+ joint = [
124
+ [IM_START_TOKEN],
125
+ np.tile(per_row, [resized_h]),
126
+ [IM_END_TOKEN],
127
+ ]
128
+ return np.concatenate(joint)
129
+ per_row = np.full(width, IMAGE_PATCH_TOKEN)
130
+ if self.image_use_col_tokens:
131
+ per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
132
+ joint = [
133
+ [IM_START_TOKEN],
134
+ np.tile(per_row, [height]),
135
+ [IM_END_TOKEN],
136
+ ]
137
+ per_row = np.full(resized_w, IMAGE_PATCH_TOKEN)
138
+ use_single_crop_col_tokens = (
139
+ self.image_use_col_tokens
140
+ if self.use_single_crop_col_tokens is None
141
+ else self.use_single_crop_col_tokens
142
+ )
143
+ image_start_token = (
144
+ LOW_RES_IMAGE_START_TOKEN
145
+ if self.use_single_crop_start_token
146
+ else IM_START_TOKEN
147
+ )
148
+ if use_single_crop_col_tokens:
149
+ per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
150
+ joint = [
151
+ [image_start_token],
152
+ np.tile(per_row, [resized_h]),
153
+ [IM_END_TOKEN],
154
+ ] + joint
155
+
156
+ return np.concatenate(joint)
157
+
158
+ def get_video_string(
159
+ self,
160
+ video_grid: np.ndarray,
161
+ timestamps: np.ndarray,
162
+ ):
163
+ if self.use_frame_special_tokens:
164
+ start_token_id = FRAME_START_TOKEN
165
+ end_token_id = FRAME_END_TOKEN
166
+ else:
167
+ start_token_id = IM_START_TOKEN
168
+ end_token_id = IM_END_TOKEN
169
+
170
+ num_frames, h, w = video_grid
171
+ video_string: str = ""
172
+ for frame_idx, frame_time in enumerate(timestamps):
173
+ # `per-frame-compact` time mode
174
+ prev_space = " " if frame_idx > 0 else ""
175
+ frame_prefix = prev_space + f"{frame_time:.1f} " # explicit whitespace before/after image tokens
176
+
177
+ video_string += frame_prefix
178
+ per_row = np.full(w, IMAGE_PATCH_TOKEN)
179
+ if self.video_use_col_tokens:
180
+ per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
181
+ extra_tokens = np.tile(per_row, [h])
182
+ video_tokens = [
183
+ [start_token_id],
184
+ extra_tokens,
185
+ [end_token_id],
186
+ ]
187
+ video_string += "".join(np.concatenate(video_tokens, 0))
188
+
189
+ return video_string
190
+
191
+ def insert_bos(
192
+ self,
193
+ input_ids: np.ndarray,
194
+ attention_mask: np.ndarray,
195
+ bos_token_id: int,
196
+ pad_token_id: int,
197
+ ):
198
+ """
199
+ Args:
200
+ input_ids: [B, S] array with left padding
201
+ attention_mask: [B, S] array (0 for pad, 1 for valid)
202
+ bos_token_id: int
203
+ pad_token_id: int
204
+ Returns:
205
+ input_ids_out: [B, S] or [B, S+1] array with bos inserted if needed
206
+ attention_mask_out: same shape as input_ids_out
207
+ """
208
+
209
+ need_to_expand = len(input_ids.shape) == 1
210
+ if need_to_expand:
211
+ input_ids = input_ids[None, :]
212
+ attention_mask = attention_mask[None, :]
213
+
214
+ B, S = input_ids.shape
215
+
216
+ # Handle zero-length sequence
217
+ if S == 0:
218
+ new_input_ids = np.full((B, 1), bos_token_id, dtype=input_ids.dtype)
219
+ new_attention_mask = np.ones((B, 1), dtype=attention_mask.dtype)
220
+ if need_to_expand:
221
+ new_input_ids = new_input_ids[0]
222
+ new_attention_mask = new_attention_mask[0]
223
+ return new_input_ids, new_attention_mask
224
+
225
+ first_valid_index = (attention_mask == 1).argmax(axis=-1) # [B]
226
+ bos_already_present = np.all(input_ids[np.arange(B), first_valid_index] == bos_token_id)
227
+
228
+ if bos_already_present:
229
+ if need_to_expand:
230
+ input_ids = input_ids[0]
231
+ attention_mask = attention_mask[0]
232
+ return input_ids, attention_mask
233
+ else:
234
+ new_input_ids = np.full((B, S+1), pad_token_id, dtype=input_ids.dtype)
235
+ new_attention_mask = np.zeros((B, S+1), dtype=attention_mask.dtype)
236
+
237
+ src_idx = np.tile(np.arange(S), (B, 1)) # [B, S]
238
+ valid_mask = src_idx >= first_valid_index[:, None] # [B, S]
239
+ tgt_idx = src_idx + 1 # shit right
240
+ batch_idx = np.tile(np.arange(B)[:, None], (1, S)) # [B, S]
241
+
242
+ # flatten valid_positions
243
+ flat_vals = input_ids[valid_mask]
244
+ flat_batch = batch_idx[valid_mask]
245
+ flat_tgt = tgt_idx[valid_mask]
246
+
247
+ new_input_ids[flat_batch, flat_tgt] = flat_vals
248
+ new_attention_mask[flat_batch, flat_tgt] = 1
249
+
250
+ insert_pos = first_valid_index
251
+ new_input_ids[np.arange(B), insert_pos] = bos_token_id
252
+ new_attention_mask[np.arange(B), insert_pos] = 1
253
+
254
+ if need_to_expand:
255
+ new_input_ids = new_input_ids[0]
256
+ new_attention_mask = new_attention_mask[0]
257
+
258
+ return new_input_ids, new_attention_mask
259
+
260
+ def __call__(
261
+ self,
262
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
263
+ images: ImageInput = None,
264
+ videos: VideoInput = None,
265
+ **kwargs: Unpack[MolmoAct2ProcessorKwargs],
266
+ ) -> BatchFeature:
267
+ """
268
+
269
+ Args:
270
+ text (`str`, `list[str]`, `list[list[str]]`):
271
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
272
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
273
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
274
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
275
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
276
+ tensor. Both channels-first and channels-last formats are supported.
277
+ videos (`dict[str, Any]` or `list[dict[str, Any]]`):
278
+ The video or batch of videos to be prepared. Each video can be a dictionary with the following keys:
279
+ - `"frames"`: `np.ndarray` of shape (T, H, W, 3)
280
+ - `"timestamps"`: `np.ndarray` of shape (T,)
281
+ - `"sampled_fps"`: `float` (optional)
282
+ - `"sampling_augmentation"`: `str` (optional)
283
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
284
+ If set, will return tensors of a particular framework. Acceptable values are:
285
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
286
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
287
+ - `'np'`: Return NumPy `np.ndarray` objects.
288
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
289
+
290
+ Returns:
291
+ `BatchFeature`: A [`BatchFeature`] with the following fields:
292
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
293
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
294
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`).
295
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
296
+ - **image_token_pooling** -- Indices of the patches in `image_grids` to pool for each token in `image_tokens`.
297
+ Returned when `images` is not `None`.
298
+ - **image_grids** -- Grids of images. Returned when `images` is not `None`.
299
+ - **image_num_crops** -- Number of crops for each image. Returned when `images` is not `None`.
300
+ - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
301
+ - **video_token_pooling** -- Indices of the patches in `video_grids` to pool for each token in `video_tokens`.
302
+ Returned when `videos` is not `None`.
303
+ - **video_grids** -- Grids of videos. Returned when `videos` is not `None`.
304
+ """
305
+
306
+ output_kwargs = self._merge_kwargs(
307
+ MolmoAct2ProcessorKwargs,
308
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
309
+ **kwargs,
310
+ )
311
+
312
+ if images is not None:
313
+ image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
314
+ image_grids = image_inputs["image_grids"]
315
+ else:
316
+ image_inputs = {}
317
+ image_grids = None
318
+
319
+ if videos is not None:
320
+ videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
321
+ video_grids = videos_inputs["video_grids"]
322
+ # If user has not requested video metadata, pop it
323
+ if "return_metadata" not in kwargs:
324
+ video_metadata = videos_inputs.pop("video_metadata")
325
+ else:
326
+ video_metadata = videos_inputs["video_metadata"]
327
+ else:
328
+ videos_inputs = {}
329
+ video_grids = None
330
+
331
+ if not isinstance(text, list):
332
+ text = [text]
333
+
334
+ text = text.copy() # below lines change text in-place
335
+
336
+ if image_grids is not None:
337
+ index = 0
338
+ for i in range(len(text)):
339
+ num_images = text[i].count(self.image_placeholder_token)
340
+ image_grids_i = image_grids[index:index+num_images]
341
+ for image_grid in image_grids_i:
342
+ image_tokens = self.get_image_tokens(image_grid)
343
+ image_string = "".join(image_tokens)
344
+ text[i] = text[i].replace(self.image_placeholder_token, image_string, 1)
345
+ index += num_images
346
+
347
+ if video_grids is not None:
348
+ index = 0
349
+ for i in range(len(text)):
350
+ num_videos = text[i].count(self.video_placeholder_token)
351
+ assert num_videos in {0, 1}, "At most one video is supported for now"
352
+ video_grids_i = video_grids[index:index+num_videos]
353
+ metadata_i = video_metadata[index:index+num_videos]
354
+ for video_grid, metadata in zip(video_grids_i, metadata_i):
355
+ video_string = self.get_video_string(
356
+ video_grid,
357
+ metadata.timestamps,
358
+ )
359
+ text[i] = text[i].replace(self.video_placeholder_token, video_string, 1)
360
+ index += num_videos
361
+
362
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
363
+ return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
364
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
365
+
366
+ input_ids = text_inputs["input_ids"]
367
+ attention_mask = text_inputs["attention_mask"]
368
+
369
+ input_ids = np.array(input_ids)
370
+ attention_mask = np.array(attention_mask)
371
+
372
+ bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
373
+ input_ids, attention_mask = self.insert_bos(
374
+ input_ids, attention_mask, bos, self.tokenizer.pad_token_id
375
+ )
376
+
377
+ if return_mm_token_type_ids:
378
+ image_tokens = np.array(self.image_token_ids).astype(input_ids.dtype)
379
+ token_type_ids = np.any(input_ids[:, :, None] == image_tokens[None, None, :], axis=-1)
380
+ text_inputs["token_type_ids"] = token_type_ids.tolist()
381
+
382
+ text_inputs["input_ids"] = input_ids.tolist()
383
+ text_inputs["attention_mask"] = attention_mask.tolist()
384
+
385
+ return BatchFeature(
386
+ data={**text_inputs, **image_inputs, **videos_inputs},
387
+ tensor_type=return_tensors,
388
+ )
389
+
390
+ def post_process_image_text_to_text(
391
+ self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
392
+ ):
393
+ """
394
+ Post-process the output of the model to decode the text.
395
+
396
+ Args:
397
+ generated_outputs (`torch.Tensor` or `np.ndarray`):
398
+ The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
399
+ or `(sequence_length,)`.
400
+ skip_special_tokens (`bool`, *optional*, defaults to `True`):
401
+ Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
402
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
403
+ Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
404
+ **kwargs:
405
+ Additional arguments to be passed to the tokenizer's `batch_decode method`.
406
+
407
+ Returns:
408
+ `list[str]`: The decoded text.
409
+ """
410
+ return self.tokenizer.batch_decode(
411
+ generated_outputs,
412
+ skip_special_tokens=skip_special_tokens,
413
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
414
+ **kwargs,
415
+ )
416
+
417
+
418
+ MolmoAct2Processor.register_for_auto_class()
processor_config.json ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
4
+ },
5
+ "image_processor": {
6
+ "auto_map": {
7
+ "AutoImageProcessor": "image_processing_molmoact2.MolmoAct2ImageProcessor",
8
+ "AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
9
+ },
10
+ "crop_mode": "resize",
11
+ "do_convert_rgb": true,
12
+ "image_mean": [
13
+ 0.5,
14
+ 0.5,
15
+ 0.5
16
+ ],
17
+ "image_processor_type": "MolmoAct2ImageProcessor",
18
+ "image_std": [
19
+ 0.5,
20
+ 0.5,
21
+ 0.5
22
+ ],
23
+ "max_crops": 8,
24
+ "overlap_margins": [
25
+ 4,
26
+ 4
27
+ ],
28
+ "patch_size": 14,
29
+ "pooling_size": [
30
+ 2,
31
+ 2
32
+ ],
33
+ "resample": 2,
34
+ "size": {
35
+ "height": 378,
36
+ "width": 378
37
+ }
38
+ },
39
+ "image_use_col_tokens": true,
40
+ "processor_class": "MolmoAct2Processor",
41
+ "use_frame_special_tokens": true,
42
+ "use_single_crop_col_tokens": false,
43
+ "use_single_crop_start_token": true,
44
+ "video_processor": {
45
+ "auto_map": {
46
+ "AutoProcessor": "processing_molmoact2.MolmoAct2Processor",
47
+ "AutoVideoProcessor": "video_processing_molmoact2.MolmoAct2VideoProcessor"
48
+ },
49
+ "data_format": "channels_first",
50
+ "default_to_square": true,
51
+ "do_convert_rgb": true,
52
+ "do_normalize": true,
53
+ "do_rescale": true,
54
+ "do_resize": true,
55
+ "do_sample_frames": true,
56
+ "frame_sample_mode": "uniform_last_frame",
57
+ "image_mean": [
58
+ 0.5,
59
+ 0.5,
60
+ 0.5
61
+ ],
62
+ "image_std": [
63
+ 0.5,
64
+ 0.5,
65
+ 0.5
66
+ ],
67
+ "max_fps": 2.0,
68
+ "num_frames": 8,
69
+ "patch_size": 14,
70
+ "pooling_size": [
71
+ 3,
72
+ 3
73
+ ],
74
+ "resample": 2,
75
+ "rescale_factor": 0.00392156862745098,
76
+ "return_metadata": false,
77
+ "sampling_fps": 2,
78
+ "size": {
79
+ "height": 378,
80
+ "width": 378
81
+ },
82
+ "video_processor_type": "MolmoAct2VideoProcessor"
83
+ },
84
+ "video_use_col_tokens": false
85
+ }
quantization_metadata.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "source_repo": "allenai/MolmoAct2-LIBERO",
3
+ "source_revision": "0d24a92bd1faf321ef497c3bbd5681af97c65aa2",
4
+ "policy_class": "transformers:AutoModelForImageTextToText",
5
+ "quantization": {
6
+ "scheme": "nf4",
7
+ "backend": "bitsandbytes",
8
+ "compute_dtype": "bfloat16",
9
+ "min_params_to_quantize": 4000000,
10
+ "rule": "Linear modules with >=4_000_000 weight elements rewritten to bnb.nn.Linear4bit; smaller heads kept in compute_dtype (bfloat16).",
11
+ "runtime_status": "loader-backed (install_prequantized_linears)"
12
+ },
13
+ "dropped_state_entries": []
14
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5395aefc9b1b7f0385d8c86a2f1775e5af81bdfbf9f2d97827ea37921d9f862
3
+ size 11983605
tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_molmoact2.MolmoAct2Processor"
5
+ },
6
+ "backend": "tokenizers",
7
+ "bos_token": "<|im_end|>",
8
+ "clean_up_tokenization_spaces": false,
9
+ "eos_token": "<|im_end|>",
10
+ "errors": "replace",
11
+ "extra_special_tokens": [
12
+ "<im_start>",
13
+ "<im_end>",
14
+ "<im_patch>",
15
+ "<im_col>",
16
+ "<low_res_im_start>",
17
+ "<|image|>",
18
+ "<im_low>",
19
+ "<frame_start>",
20
+ "<frame_end>",
21
+ "<|video|>",
22
+ "<|points|>",
23
+ "<|token_index|>",
24
+ "<|vit_index|>",
25
+ "<|vit_loc|>"
26
+ ],
27
+ "is_local": false,
28
+ "model_max_length": 1010000,
29
+ "pad_token": "<|endoftext|>",
30
+ "processor_class": "MolmoAct2Processor",
31
+ "split_special_tokens": false,
32
+ "tokenizer_class": "Qwen2Tokenizer",
33
+ "unk_token": null
34
+ }