ibrahimkettaneh committed on
Commit
81386ff
·
verified ·
1 Parent(s): de08b8e

Upload vision_encoder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vision_encoder.py +452 -0
vision_encoder.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers.activations import ACT2FN
7
+
8
+
9
+ from .configuration_step3p7 import StepRoboticsVisionEncoderConfig
10
+
11
+
12
+
13
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map every adjacent channel pair ``(a, b)`` of the last dim to ``(-b, a)``.

    The last dimension is viewed as consecutive pairs, each pair is rotated,
    and the pairs are flattened back, so the result has the same shape as
    ``x``. This is the pair rotation used by rotary position embeddings.
    """
    pairs = x.reshape((*x.shape[:-1], -1, 2))
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    # Merge the (num_pairs, 2) trailing dims back into a single channel dim.
    return rotated.flatten(start_dim=-2)
19
+
20
+
21
def apply_rotary_emb(freqs: torch.Tensor,
                     t: torch.Tensor,
                     start_index: int = 0,
                     scale: float = 1.0,
                     seq_dim: int = -2) -> torch.Tensor:
    """Apply rotary position embeddings to a channel window of ``t``.

    Channels ``[start_index, start_index + rot_dim)`` of the last dimension
    are rotated, where ``rot_dim = freqs.shape[-1]``; channels outside that
    window are passed through unchanged.

    Args:
        freqs: Rotation angles. Its last dimension defines ``rot_dim`` and it
            must broadcast against the rotated window of ``t``.
        t: Tensor whose last dimension holds the channels to rotate.
        start_index: First channel of the rotated window.
        scale: Multiplier applied to both the cosine and the sine term.
        seq_dim: Dimension of ``t`` that holds the sequence. Only consulted
            when ``t`` is 3-D, to keep the trailing ``seq_len`` rows of
            ``freqs``.

    Returns:
        A tensor with the same shape and dtype as ``t``.

    Raises:
        AssertionError: If the rotated window does not fit inside the last
            dimension of ``t``.
    """
    dtype = t.dtype

    if t.ndim == 3:
        seq_len = t.shape[seq_dim]
        freqs = freqs[-seq_len:]

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim
    # Check the END of the window, not just its width: with a non-zero
    # start_index the slice below would otherwise be silently truncated and
    # fail later with an unrelated broadcasting error.
    assert end_index <= t.shape[-1], (
        f"feature dimension {t.shape[-1]} is too small for rot_dim {rot_dim}"
        f" starting at index {start_index}")

    t_left = t[..., :start_index]
    t_mid = t[..., start_index:end_index]
    t_right = t[..., end_index:]

    t_mid = (t_mid * freqs.cos() * scale) + (rotate_half(t_mid) * freqs.sin() * scale)
    out = torch.cat((t_left, t_mid, t_right), dim=-1)
    # The multiplication may have promoted the dtype (e.g. half -> float).
    return out.type(dtype)
46
+
47
+
48
class EncoderRope2D(nn.Module):
    """2D rotary positional embedding with a precomputed angle cache.

    At construction the rotation angles for every cell of a
    ``max_grid_height x max_grid_width`` grid are computed once and stored in
    the non-persistent buffer ``freqs_cache`` of shape
    ``(1, 1, [1 +] max_grid_height * max_grid_width, C)``. ``C`` is built from
    the width angles followed by the height angles and equals ``dim`` when
    ``dim`` is a multiple of 4. With ``use_cls_token`` a leading all-zero row
    (no rotation) is added for the class token and grid coordinates start at 1.

    ``max_freq`` and ``num_freqs`` are stored on the module but are not used
    by any computation in this class.
    """

    def __init__(
        self,
        dim: int,
        max_grid_height: int,
        max_grid_width: int,
        use_cls_token: bool = False,
        theta: Union[int, float] = 10000,
        max_freq: int = 10,
        num_freqs: int = 1,
        theta_rescale_factor: float = 1.0,
    ):
        super().__init__()
        self.dim = dim
        self.max_grid_height = max_grid_height
        self.max_grid_width = max_grid_width
        self.use_cls_token = use_cls_token
        # NOTE(review): dim == 2 makes this exponent divide by zero.
        self.theta = theta * theta_rescale_factor**(dim / (dim - 2))
        self.max_freq = max_freq
        self.num_freqs = num_freqs
        cache = self._compute_2d_freqs()
        self.register_buffer("freqs_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float],
                          dim: int) -> torch.Tensor:
        """Return ``1 / base**(i / dim)`` for ``i = 0, 2, 4, ...`` below ``dim``."""
        exponents = torch.arange(0, dim, 2)[:(dim // 2)].float() / dim
        return 1.0 / (base**exponents)

    def _compute_freqs(self, t: torch.Tensor, inv_freq: torch.Tensor):
        """Outer product of positions and inverse frequencies, each angle repeated twice."""
        freqs = torch.einsum("..., f -> ... f", t.type(inv_freq.dtype),
                             inv_freq)
        # Duplicate every angle so both channels of a rotated pair share it.
        return freqs.repeat_interleave(2, dim=-1)

    def _compute_2d_freqs(self) -> torch.Tensor:
        """Build the angle table for the full ``max_grid_height x max_grid_width`` grid."""
        grid_h_range = torch.arange(self.max_grid_height, dtype=torch.float)
        grid_w_range = torch.arange(self.max_grid_width, dtype=torch.float)
        if self.use_cls_token:
            # Shift patch coordinates by one when a class token is present.
            grid_h_range += 1
            grid_w_range += 1

        inv_freq = self._compute_inv_freq(self.theta, self.dim // 2)
        grid_shape = (self.max_grid_height, self.max_grid_width, -1)
        freqs_h = self._compute_freqs(grid_h_range, inv_freq)[:, None].expand(*grid_shape)
        freqs_w = self._compute_freqs(grid_w_range, inv_freq)[None, :].expand(*grid_shape)

        # Width angles first, then height angles; rows are flattened row-major.
        freqs = torch.cat([freqs_w, freqs_h], dim=-1).reshape(
            self.max_grid_height * self.max_grid_width, -1)
        if self.use_cls_token:
            freqs = torch.cat([torch.zeros(1, freqs.shape[-1]), freqs], dim=0)
        return freqs[None, None, ...]

    def _select_freqs(self, grid_hw: tuple[int, int],
                      device: torch.device) -> torch.Tensor:
        """Return the cached angles for the top-left ``grid_hw`` sub-grid.

        Args:
            grid_hw: ``(height, width)`` of the patch grid being encoded.
            device: Device on which the gather indices are created.

        Returns:
            ``freqs_cache`` itself when the grid matches the cached grid,
            otherwise the rows of the cache for the requested cells (plus the
            class-token row when ``use_cls_token`` is set).

        Raises:
            ValueError: If the requested grid exceeds the cached grid in
                either dimension.
        """
        grid_h, grid_w = grid_hw
        if grid_h == self.max_grid_height and grid_w == self.max_grid_width:
            return self.freqs_cache

        if grid_h > self.max_grid_height or grid_w > self.max_grid_width:
            # A wider grid would wrap into the next cached row and silently
            # pick wrong positions, so fail loudly instead.
            raise ValueError(
                f"grid {grid_h}x{grid_w} exceeds the cached RoPE grid "
                f"{self.max_grid_height}x{self.max_grid_width}")

        rows = torch.arange(grid_h, device=device).view(-1, 1)
        cols = torch.arange(grid_w, device=device).view(1, -1)
        # Row-major index of each requested cell inside the cached grid.
        positions = (rows * self.max_grid_width + cols).reshape(-1).to(
            torch.long)
        if self.use_cls_token:
            # The class token is row 0 of the cache. The index must be an
            # integer tensor: a float zero would promote the concatenated
            # index to float and index_select would reject it.
            cls_position = torch.zeros(1, dtype=torch.long, device=device)
            positions = torch.cat([cls_position, positions + 1], dim=0)
        return self.freqs_cache.index_select(2, positions)

    def forward(self, q: torch.Tensor, k: torch.Tensor,
                grid_hw: tuple[int, int]):
        """Rotate queries and keys according to their 2D grid positions.

        Args:
            q: Query tensor; its trailing ``(seq, channels)`` dims must
                broadcast against the ``(1, 1, seq, C)`` angle table.
            k: Key tensor with the same layout as ``q``.
            grid_hw: ``(height, width)`` of the patch grid for this input.

        Returns:
            The rotated ``(q, k)`` pair.
        """
        freqs = self._select_freqs(grid_hw, q.device)
        q = apply_rotary_emb(freqs, q)
        k = apply_rotary_emb(freqs, k)
        return q, k
121
+
122
+
123
class EncoderLayerScale(nn.Module):
    """Learnable per-channel multiplier for a residual branch.

    Holds one parameter ``gamma`` of shape ``(dim,)``, with every entry
    initialised to ``init_values``.
    """

    def __init__(self, dim: int, init_values: float):
        super().__init__()
        initial_gamma = torch.full((dim,), init_values)
        self.gamma = nn.Parameter(initial_gamma)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Scale ``hidden_states`` channel-wise; ``gamma`` broadcasts over the leading dims."""
        return torch.mul(hidden_states, self.gamma)
132
+
133
+
134
class EncoderMLP(nn.Module):
    """Two-layer feed-forward network: linear, activation, linear.

    ``hidden_act`` is looked up in ``ACT2FN`` to obtain the activation applied
    between the expanding (``c_fc``) and contracting (``c_proj``) projections.
    """

    def __init__(self, hidden_size: int, intermediate_size: int,
                 hidden_act: str):
        super().__init__()
        self.c_fc = nn.Linear(hidden_size, intermediate_size, bias=True)
        self.act_fn = ACT2FN[hidden_act]
        self.c_proj = nn.Linear(intermediate_size, hidden_size, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up, apply the activation, and project back to ``hidden_size``."""
        expanded = self.c_fc(hidden_states)
        activated = self.act_fn(expanded)
        return self.c_proj(activated)
148
+
149
+
150
+ class EncoderVisionAttention(nn.Module):
151
+ """Multi-head self attention with optional 2D RoPE."""
152
+
153
+ def __init__(
154
+ self,
155
+ hidden_size: int,
156
+ num_heads: int,
157
+ max_grid_height: int,
158
+ max_grid_width: int,
159
+ use_cls_token: bool = False,
160
+ use_rope2d: bool = True,
161
+ rope_theta: Union[int, float] = 10000,
162
+ rope_max_freq: int = 10,
163
+ rope_num_freqs: int = 1,
164
+ rope_theta_rescale_factor: float = 1.0,
165
+ rope_freqs_for: Literal["lang", "pixel", "constant"] = "lang",
166
+ ):
167
+ super().__init__()
168
+ if hidden_size % num_heads != 0:
169
+ raise ValueError(
170
+ f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})."
171
+ )
172
+ self.num_heads = num_heads
173
+ self.head_dim = hidden_size // num_heads
174
+ self.scale = self.head_dim**-0.5
175
+ self.in_proj_weight = nn.Parameter(torch.zeros(hidden_size * 3, hidden_size))
176
+ self.in_proj_bias = nn.Parameter(torch.zeros(hidden_size * 3))
177
+ self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
178
+
179
+ self.rope = None
180
+ if use_rope2d:
181
+ self.rope = EncoderRope2D(
182
+ dim=self.head_dim,
183
+ max_grid_height=max_grid_height,
184
+ max_grid_width=max_grid_width,
185
+ use_cls_token=use_cls_token,
186
+ theta=rope_theta,
187
+ max_freq=rope_max_freq,
188
+ num_freqs=rope_num_freqs,
189
+ theta_rescale_factor=rope_theta_rescale_factor,
190
+ )
191
+
192
+ def forward(self, hidden_states: torch.Tensor, grid_hw: tuple[int, int]) -> torch.Tensor:
193
+ bsz, seq_len, _ = hidden_states.shape
194
+ qkv = F.linear(
195
+ hidden_states,
196
+ self.in_proj_weight,
197
+ self.in_proj_bias,
198
+ )
199
+ q, k, v = qkv.chunk(3, dim=-1)
200
+
201
+ q = q.view(bsz, seq_len, self.num_heads,
202
+ self.head_dim).transpose(1, 2)
203
+ k = k.view(bsz, seq_len, self.num_heads,
204
+ self.head_dim).transpose(1, 2)
205
+ if self.rope is not None:
206
+ q, k = self.rope(q, k, grid_hw=grid_hw)
207
+ v = v.view(bsz, seq_len, self.num_heads,
208
+ self.head_dim).transpose(1, 2)
209
+
210
+ attn_output = F.scaled_dot_product_attention(
211
+ q, k, v, is_causal=False, scale=self.scale)
212
+ attn_output = attn_output.transpose(1, 2).reshape(
213
+ bsz, seq_len, self.num_heads * self.head_dim)
214
+ return self.out_proj(attn_output)
215
+
216
+
217
class EncoderVisionBlock(nn.Module):
    """A single pre-norm Vision Transformer block (self-attention + MLP).

    Each sub-layer is applied as ``x + ls(sublayer(ln(x)))``. ``ls`` is an
    ``EncoderLayerScale`` when ``ls_init_value`` is given and an identity
    otherwise.

    Args:
        hidden_size: Channel width of the block.
        num_heads: Number of attention heads.
        mlp_ratio: MLP inner width as a multiple of ``hidden_size`` (truncated
            to an int).
        hidden_act: Activation name passed to ``EncoderMLP``.
        layer_norm_eps: Epsilon for both layer norms.
        ls_init_value: Initial layer-scale value, or ``None`` to disable
            layer scale.
        max_grid_height: Forwarded to the attention module.
        max_grid_width: Forwarded to the attention module.
        use_cls_token: Forwarded to the attention module.
        use_rope2d: Forwarded to the attention module.
        rope_kwargs: Extra keyword arguments for the attention module.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        hidden_act: str,
        layer_norm_eps: float,
        ls_init_value: Optional[float] = None,
        max_grid_height: Optional[int] = None,
        max_grid_width: Optional[int] = None,
        use_cls_token: bool = False,
        use_rope2d: bool = True,
        rope_kwargs: Optional[dict] = None,
    ):
        super().__init__()
        rope_kwargs = rope_kwargs or {}
        self.attn = EncoderVisionAttention(
            hidden_size,
            num_heads,
            max_grid_height=max_grid_height,
            max_grid_width=max_grid_width,
            use_cls_token=use_cls_token,
            use_rope2d=use_rope2d,
            **rope_kwargs,
        )
        self.ln_1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

        intermediate = int(hidden_size * mlp_ratio)
        self.mlp = EncoderMLP(hidden_size, intermediate, hidden_act)

        # ls_init_value defaults to None; building EncoderLayerScale with a
        # None fill value would raise, so fall back to a no-op in that case.
        if ls_init_value is None:
            self.ls_1 = nn.Identity()
            self.ls_2 = nn.Identity()
        else:
            self.ls_1 = EncoderLayerScale(hidden_size, ls_init_value)
            self.ls_2 = EncoderLayerScale(hidden_size, ls_init_value)

    def forward(self, hidden_states: torch.Tensor,
                grid_hw: tuple[int, int]) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, each with a residual connection.

        Args:
            hidden_states: Input token tensor.
            grid_hw: ``(height, width)`` of the patch grid, forwarded to the
                attention module.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        hidden_states = self.attn(hidden_states, grid_hw=grid_hw)
        hidden_states = residual + self.ls_1(hidden_states)

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.ls_2(hidden_states)
        return hidden_states
267
+
268
+
269
class EncoderVisionTransformer(nn.Module):
    """Sequential stack of ``depth`` identical ``EncoderVisionBlock`` modules.

    All constructor arguments other than ``depth`` are forwarded unchanged to
    every block. The blocks are stored in ``resblocks`` and the block count in
    ``layers``.
    """

    def __init__(
        self,
        embed_dim: int,
        depth: int,
        num_heads: int,
        mlp_ratio: float,
        hidden_act: str,
        layer_norm_eps: float,
        ls_init_value: Optional[float] = None,
        max_grid_height: Optional[int] = None,
        max_grid_width: Optional[int] = None,
        use_cls_token: bool = False,
        use_rope2d: bool = True,
        rope_kwargs: Optional[dict] = None,
    ):
        super().__init__()
        self.layers = depth
        rope_kwargs = rope_kwargs or {}

        # Keyword options shared by every block in the stack.
        block_options = dict(
            max_grid_height=max_grid_height,
            max_grid_width=max_grid_width,
            use_cls_token=use_cls_token,
            use_rope2d=use_rope2d,
            ls_init_value=ls_init_value,
            rope_kwargs=rope_kwargs,
        )
        self.resblocks = nn.ModuleList(
            EncoderVisionBlock(embed_dim, num_heads, mlp_ratio, hidden_act,
                               layer_norm_eps, **block_options)
            for _ in range(depth))

    def forward(self,
                hidden_states: torch.Tensor,
                grid_hw: tuple[int, int]) -> torch.Tensor:
        """Feed ``hidden_states`` through every block in order and return the result."""
        output = hidden_states
        for layer in self.resblocks:
            output = layer(output, grid_hw=grid_hw)
        return output
308
+
309
+
310
class StepRoboticsVisionEncoder(nn.Module):
    """ViT-style image encoder configured by ``StepRoboticsVisionEncoderConfig``.

    Images are cut into patches by a strided convolution, optionally given a
    class token and absolute position embeddings, and passed through a stack
    of transformer blocks. Two stride-2 convolutions (``vit_downsampler1`` and
    ``vit_downsampler2``) are also created here; ``forward`` does not call
    them.
    """

    def __init__(self, config: StepRoboticsVisionEncoderConfig):
        super().__init__()
        self.config = config

        # Mirror config fields under the attribute names used in this module.
        # Fields read with getattr are optional and fall back to the default.
        self.hidden_size = config.width
        self.num_heads = config.heads
        self.num_hidden_layers = config.layers
        self.patch_size = config.patch_size
        self.image_size = config.image_size
        self.use_cls_token = getattr(config, "use_cls_token", False)
        self.use_rope2d = getattr(config, "use_rope2d", True)
        self.use_abs_posemb = getattr(config, "use_abs_posemb", True)
        self.layer_norm_eps = config.layer_norm_eps
        self.mlp_ratio = getattr(config, "mlp_ratio", 8960 / 1536)
        self.ls_init_value = getattr(config, "ls_init_value", None)
        self.hidden_act = config.hidden_act
        self.use_ln_pre = getattr(config, "use_ln_pre", False)
        self.use_ln_post = getattr(config, "use_ln_post", True)

        width = self.hidden_size

        # Patch embedding: one output position per patch_size x patch_size tile.
        self.conv1 = nn.Conv2d(in_channels=config.num_channels,
                               out_channels=width,
                               kernel_size=self.patch_size,
                               stride=self.patch_size,
                               bias=False)

        if self.use_ln_pre:
            self.ln_pre = nn.LayerNorm(width, eps=self.layer_norm_eps)
        else:
            self.ln_pre = nn.Identity()
        if self.use_ln_post:
            self.ln_post = nn.LayerNorm(width, eps=self.layer_norm_eps)
        else:
            self.ln_post = nn.Identity()

        # Patch grid of the nominal image size; used as the RoPE cache size.
        grid_size = self.image_size // self.patch_size
        self.base_grid = (grid_size, grid_size)

        if self.use_cls_token:
            self.class_embedding = nn.Parameter(
                torch.randn(width) * (width**-0.5))
        else:
            self.class_embedding = None

        if self.use_abs_posemb:
            self.posemb_grid_size = self.image_size // self.patch_size
            num_positions = int(self.use_cls_token) + self.posemb_grid_size**2
            self.positional_embedding = nn.Parameter(
                (width**-0.5) * torch.randn(num_positions, width))

        rope_options = {
            "rope_theta": getattr(config, "rope_theta", 10000),
            "rope_max_freq": getattr(config, "rope_max_freq", 10),
            "rope_num_freqs": getattr(config, "rope_num_freqs", 1),
            "rope_theta_rescale_factor":
            getattr(config, "rope_theta_rescale_factor", 1.0),
            "rope_freqs_for": getattr(config, "rope_freqs_for", "lang"),
        }
        self.transformer = EncoderVisionTransformer(
            embed_dim=width,
            depth=self.num_hidden_layers,
            num_heads=self.num_heads,
            mlp_ratio=self.mlp_ratio,
            hidden_act=self.hidden_act,
            layer_norm_eps=self.layer_norm_eps,
            ls_init_value=self.ls_init_value,
            max_grid_height=self.base_grid[0],
            max_grid_width=self.base_grid[1],
            use_cls_token=self.use_cls_token,
            use_rope2d=self.use_rope2d,
            rope_kwargs=rope_options,
        )

        # Stride-2 convolutions that double the channel count at each step.
        self.vit_downsampler1 = nn.Conv2d(width,
                                          width * 2,
                                          kernel_size=3,
                                          stride=2,
                                          padding=1)
        self.vit_downsampler2 = nn.Conv2d(width * 2,
                                          width * 4,
                                          kernel_size=3,
                                          stride=2,
                                          padding=1)

    def sample_abs_posemb(self, grid_h: int, grid_w: int):
        """Return absolute position embeddings for a ``grid_h x grid_w`` patch grid.

        When the requested grid equals the trained grid the stored table is
        returned as is. Otherwise the patch rows are bilinearly resized to the
        requested grid; the class-token row, if any, is kept unchanged and put
        back in front.

        Returns:
            Tensor of shape ``(1, [1 +] grid_h * grid_w, hidden_size)``.
        """
        trained = self.posemb_grid_size
        if grid_h == trained and grid_w == trained:
            return self.positional_embedding[None, ...]

        table = self.positional_embedding
        cls_row = None
        if self.use_cls_token:
            cls_row, table = table[:1], table[1:]

        # (N, D) -> (1, D, G, G) so the table can be resized like an image.
        as_image = table.reshape(1, trained, trained, -1)
        as_image = as_image.permute(0, 3, 1, 2).contiguous()
        resized = F.interpolate(as_image,
                                size=(grid_h, grid_w),
                                mode="bilinear",
                                align_corners=False)
        flat = resized.permute(0, 2, 3, 1).reshape(-1, self.hidden_size)

        if cls_row is not None:
            flat = torch.cat([cls_row, flat], dim=0)
        return flat.unsqueeze(0)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images into patch tokens.

        Args:
            pixel_values: Image tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, Gh * Gw, D) with ``Gh = H // patch_size`` and
            ``Gw = W // patch_size``. When a class token is used it is removed
            before returning.
        """
        bsz, _, height, width = pixel_values.shape
        grid_h = height // self.patch_size
        grid_w = width // self.patch_size

        patches = self.conv1(pixel_values)  # (B, D, Gh, Gw)
        tokens = patches.flatten(2).transpose(1, 2)  # (B, Gh*Gw, D)

        if self.use_cls_token:
            cls_token = self.class_embedding.view(1, 1, -1).expand(bsz, -1, -1)
            tokens = torch.cat([cls_token, tokens], dim=1)

        if self.use_abs_posemb:
            tokens = tokens + self.sample_abs_posemb(grid_h, grid_w)

        tokens = self.ln_pre(tokens)
        tokens = self.transformer(tokens, grid_hw=(grid_h, grid_w))

        if self.use_ln_post:
            tokens = self.ln_post(tokens)

        if self.use_cls_token:
            # Drop the leading class token so only patch tokens are returned.
            tokens = tokens[:, 1:, :]

        return tokens