import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from typing import Optional, Tuple, Union, List, Dict, Any

from argus.layers import Mlp
from argus.layers import PatchEmbed
from argus.layers.block import Block
from argus.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from argus.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
from argus.heads.utils import reorder_by_reference

logger = logging.getLogger(__name__)

_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]


class Aggregator(nn.Module):
    """
    Alternating frame/global attention aggregator over a sequence of images.

    Args:
        img_size (int): Image size in pixels.
        patch_size (int): Size of each patch for PatchEmbed.
        embed_dim (int): Dimension of the token embeddings.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
        num_register_tokens (int): Number of register tokens.
        block_fn (nn.Module): The block type used for attention (Block by default).
        qkv_bias (bool): Whether to include bias in QKV projections.
        proj_bias (bool): Whether to include bias in the output projection.
        ffn_bias (bool): Whether to include bias in MLP layers.
        patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
        aa_order (Sequence[str]): The order of alternating attention, e.g. ("frame", "global").
        aa_block_size (int): How many blocks to group under each attention type before switching.
            If not necessary, set to 1.
        qk_norm (bool): Whether to apply QK normalization.
        rope_freq (int): Base frequency for rotary embedding. -1 to disable.
        init_values (float): Init scale for layer scale.
        reorder_by_learning_ref (bool): Whether to reorder features by a learned reference view index.
        ref_aa_block_num (int): Number of aa blocks for reference view learning. Each aa block
            holds aa_block_size blocks per attention type.
        save_inference_memory (bool): In eval mode, keep only the intermediates of the aa blocks
            listed in ``intermediate_layer_idx``; other entries become ``torch.tensor(0)`` placeholders.
        intermediate_layer_idx (Sequence[int]): aa-block indices whose intermediates are kept when
            ``save_inference_memory`` is active. The default matches the original depth-24 setup.
    """

    def __init__(
        self,
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=4,
        block_fn=Block,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        patch_embed="dinov2_vitl14_reg",
        aa_order=("frame", "global"),
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
        reorder_by_learning_ref=True,
        ref_aa_block_num=2,
        save_inference_memory=True,
        intermediate_layer_idx=(4, 11, 17, 23),
    ):
        super().__init__()

        # Validate before building any layers so a bad config fails fast.
        if depth % aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")

        self.reorder_by_learning_ref = reorder_by_learning_ref
        self.save_inference_memory = save_inference_memory
        self.intermediate_layer_idx = tuple(intermediate_layer_idx)

        # Also builds token_norm / covisibility_head when reorder_by_learning_ref is set.
        self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)

        # Initialize rotary position embedding if frequency > 0
        self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        self.position_getter = PositionGetter() if self.rope is not None else None

        block_kwargs = dict(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            ffn_bias=ffn_bias,
            init_values=init_values,
            qk_norm=qk_norm,
            rope=self.rope,
        )

        def make_blocks(n):
            # Build n independently-initialized attention blocks.
            return nn.ModuleList([block_fn(**block_kwargs) for _ in range(n)])

        # NOTE: construction order is kept identical to the original so seeded
        # initialisation and state_dict layout do not change.
        self.frame_blocks = make_blocks(depth)
        self.global_blocks = make_blocks(depth)

        self.depth = depth
        # Copy so instances never share (or alias a caller's) list.
        self.aa_order = list(aa_order)
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size
        self.aa_block_num = depth // aa_block_size

        # Reference Learning Network
        if self.reorder_by_learning_ref:
            self.ref_aa_block_num = ref_aa_block_num
            # Every ref aa block consumes aa_block_size blocks of each attention
            # type (see _run_blocks), so allocate that many; identical to the
            # original count when aa_block_size == 1.
            self.ref_frame_blocks = make_blocks(ref_aa_block_num * aa_block_size)
            self.ref_global_blocks = make_blocks(ref_aa_block_num * aa_block_size)

        # Note: We have two camera tokens, one for the first frame and one for the rest
        # The same applies for register tokens
        self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
        self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))
        if self.reorder_by_learning_ref:
            # describes the covisibility of the current frame with the other frames
            self.covisibility_token = nn.Parameter(torch.randn(1, 1, 1, embed_dim))

        # The patch tokens start after the camera and register tokens
        self.patch_start_idx = 1 + num_register_tokens

        # Initialize parameters with small values
        nn.init.normal_(self.camera_token, std=1e-6)
        nn.init.normal_(self.register_token, std=1e-6)
        if self.reorder_by_learning_ref:
            nn.init.normal_(self.covisibility_token, std=1e-6)

        # Register normalization constants as buffers
        for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
            self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False)

        self.use_reentrant = False  # hardcoded to False

    def __build_patch_embed__(
        self,
        patch_embed,
        img_size,
        patch_size,
        num_register_tokens,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        block_chunks=0,
        init_values=1.0,
        embed_dim=1024,
    ):
        """
        Build the patch embed layer, plus the covisibility head when reference learning is on.

        If ``patch_embed`` contains "conv", a simple PatchEmbed conv layer is used.
        Otherwise it must be one of the ViT keys below (a KeyError is raised otherwise).
        """
        if "conv" in patch_embed:
            self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
        else:
            vit_models = {
                "dinov2_vitl14_reg": vit_large,
                "dinov2_vitb14_reg": vit_base,
                "dinov2_vits14_reg": vit_small,
                "dinov2_vitg2_reg": vit_giant2,
            }

            self.patch_embed = vit_models[patch_embed](
                img_size=img_size,
                patch_size=patch_size,
                num_register_tokens=num_register_tokens,
                interpolate_antialias=interpolate_antialias,
                interpolate_offset=interpolate_offset,
                block_chunks=block_chunks,
                init_values=init_values,
            )

            # The mask token is never used by this module; delete it so it does not
            # linger as an unused parameter.
            if hasattr(self.patch_embed, "mask_token"):
                del self.patch_embed.mask_token

        # covisibility head: maps the [frame ; global] covisibility feature (2C) to one score
        if self.reorder_by_learning_ref:
            self.token_norm = nn.LayerNorm(embed_dim * 2)
            self.covisibility_head = Mlp(
                in_features=embed_dim * 2, hidden_features=embed_dim * 2 // 2, out_features=1, drop=0
            )

    def forward(
        self, images: torch.Tensor
    ) -> Tuple[List[torch.Tensor], int, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Args:
            images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width

        Returns:
            tuple:
                - output_list (list[torch.Tensor]): one entry per attention block, each the
                  concatenated frame/global tokens of shape [B, S, P, 2C]. In eval mode with
                  ``save_inference_memory`` on, entries for aa blocks not in
                  ``intermediate_layer_idx`` are ``torch.tensor(0)`` placeholders.
                - patch_start_idx (int): index at which patch tokens begin along P.
                - covisibility_scores (torch.Tensor | None): [B, S] scores, or None when
                  reference learning is disabled.
                - ref_idx (torch.Tensor | None): [B] argmax of the scores, or None.

        Raises:
            ValueError: if the channel dimension is not 3, or ``aa_order`` holds an unknown type.
        """
        B, S, C_in, H, W = images.shape

        if C_in != 3:
            raise ValueError(f"Expected 3 input channels, got {C_in}")

        # Normalize images and reshape to [B*S, C, H, W] for patch embedding
        images = (images - self._resnet_mean) / self._resnet_std
        images = images.view(B * S, C_in, H, W)
        patch_tokens = self.patch_embed(images)

        if isinstance(patch_tokens, dict):
            patch_tokens = patch_tokens["x_norm_patchtokens"]

        _, P, C = patch_tokens.shape

        covisibility_scores = None
        ref_idx = None
        if self.reorder_by_learning_ref:
            patch_tokens, covisibility_scores, ref_idx = self._learn_reference(
                patch_tokens, B, S, P, C, H, W, images.device
            )

        # Expand camera and register tokens to match batch size and sequence length
        camera_token = slice_expand_and_flatten(self.camera_token, B, S)
        register_token = slice_expand_and_flatten(self.register_token, B, S)

        # Concatenate special tokens with patch tokens
        tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)  # [BS, 1+R+HW, C]

        pos = None
        if self.rope is not None:
            pos = self._build_positions(B, S, H, W, self.patch_start_idx, images.device)

        # update P because we added special tokens
        _, P, C = tokens.shape

        frame_idx = 0
        global_idx = 0
        output_list = []

        for block_i in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    tokens, frame_idx, frame_intermediates = self._process_frame_attention(
                        tokens, B, S, P, C, frame_idx, pos=pos
                    )
                elif attn_type == "global":
                    tokens, global_idx, global_intermediates = self._process_global_attention(
                        tokens, B, S, P, C, global_idx, pos=pos
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            keep = self.training or not self.save_inference_memory or block_i in self.intermediate_layer_idx
            for frame_inter, global_inter in zip(frame_intermediates, global_intermediates):
                if keep:
                    # concat frame and global intermediates, [B x S x P x 2C]
                    output_list.append(torch.cat([frame_inter, global_inter], dim=-1))
                else:
                    # Placeholder keeps list indices aligned without holding the activation.
                    output_list.append(torch.tensor(0))

            # Drop references early so discarded activations can be freed.
            del frame_intermediates
            del global_intermediates

        return output_list, self.patch_start_idx, covisibility_scores, ref_idx

    def _build_positions(self, B, S, H, W, num_special, device):
        """
        Return RoPE positions for B*S frames with ``num_special`` leading special tokens.

        Patch positions are shifted by +1 so that the special tokens can use position 0.
        """
        pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=device)
        if num_special > 0:
            pos = pos + 1
            pos_special = torch.zeros(B * S, num_special, 2, device=device, dtype=pos.dtype)
            pos = torch.cat([pos_special, pos], dim=1)  # [BS, num_special + HW, 2]
        return pos

    def _learn_reference(self, patch_tokens, B, S, P, C, H, W, device):
        """
        Score every frame with the reference-learning network and reorder the patch tokens.

        Returns ``(reordered patch tokens [B*S, P, C], scores [B, S], ref_idx [B])``.
        The exact reordering is defined by ``reorder_by_reference`` (presumably it moves
        the chosen reference frame first — confirm against that helper).
        """
        # expand covisibility token to match batch size and sequence length
        covisibility_token = self.covisibility_token.expand(B, S, 1, C).view(B * S, 1, C).contiguous()
        tokens = torch.cat([covisibility_token, patch_tokens], dim=1)  # [BS, 1+HW, C]

        pos = None
        if self.rope is not None:
            pos = self._build_positions(B, S, H, W, 1, device)

        _, P_covis, C_covis = tokens.shape

        frame_idx = 0
        global_idx = 0
        for _ in range(self.ref_aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    tokens, frame_idx, frame_intermediates = self._ref_process_frame_attention(
                        tokens, B, S, P_covis, C_covis, frame_idx, pos=pos
                    )
                elif attn_type == "global":
                    tokens, global_idx, global_intermediates = self._ref_process_global_attention(
                        tokens, B, S, P_covis, C_covis, global_idx, pos=pos
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

        # Only the covisibility token (index 0) of the last frame/global outputs is used.
        covis_feat = torch.cat(
            [frame_intermediates[-1][:, :, 0, :], global_intermediates[-1][:, :, 0, :]], dim=-1
        )  # [B, S, 2C]
        covis_feat = self.token_norm(covis_feat)
        covisibility_scores = self.covisibility_head(covis_feat).squeeze(-1)  # [B, S]
        ref_idx = covisibility_scores.argmax(-1)  # [B]

        patch_tokens = reorder_by_reference(patch_tokens.view(B, S, P, C), ref_idx)
        patch_tokens = patch_tokens.reshape(B * S, P, C).contiguous()
        return patch_tokens, covisibility_scores, ref_idx

    def _run_blocks(self, blocks, tokens, B, S, P, C, block_idx, pos=None, global_attn=False):
        """
        Run ``aa_block_size`` consecutive blocks from ``blocks`` starting at ``block_idx``.

        Frame attention works on (B*S, P, C); global attention on (B, S*P, C).
        Returns the output tokens, the next block index, and per-block outputs as [B, S, P, C].
        Uses gradient checkpointing in training mode.
        """
        if global_attn:
            token_shape, pos_shape = (B, S * P, C), (B, S * P, 2)
        else:
            token_shape, pos_shape = (B * S, P, C), (B * S, P, 2)

        if tokens.shape != token_shape:
            tokens = tokens.view(B, S, P, C).view(*token_shape)

        if pos is not None and pos.shape != pos_shape:
            pos = pos.view(B, S, P, 2).view(*pos_shape)

        intermediates = []
        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            block = blocks[block_idx]
            if self.training:
                tokens = checkpoint(block, tokens, pos, use_reentrant=self.use_reentrant)
            else:
                tokens = block(tokens, pos=pos)
            block_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, block_idx, intermediates

    def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
        """Process main frame attention blocks. Tokens are kept in shape (B*S, P, C)."""
        return self._run_blocks(self.frame_blocks, tokens, B, S, P, C, frame_idx, pos=pos)

    def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
        """Process main global attention blocks. Tokens are kept in shape (B, S*P, C)."""
        return self._run_blocks(self.global_blocks, tokens, B, S, P, C, global_idx, pos=pos, global_attn=True)

    def _ref_process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
        """Process reference-learning frame attention blocks. Tokens are kept in shape (B*S, P, C)."""
        return self._run_blocks(self.ref_frame_blocks, tokens, B, S, P, C, frame_idx, pos=pos)

    def _ref_process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
        """Process reference-learning global attention blocks. Tokens are kept in shape (B, S*P, C)."""
        return self._run_blocks(
            self.ref_global_blocks, tokens, B, S, P, C, global_idx, pos=pos, global_attn=True
        )


def slice_expand_and_flatten(token_tensor, B, S):
    """
    Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
    1) Uses the first position (index=0) for the first frame only
    2) Uses the second position (index=1) for all remaining frames (S-1 frames)
    3) Expands both to match batch size B
    4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
       followed by (S-1) second-position tokens
    5) Flattens to (B*S, X, C) for processing

    Returns:
        torch.Tensor: Processed tokens with shape (B*S, X, C)
    """
    # Slice out the "query" tokens => shape (B, 1, ...)
    query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
    # Slice out the "other" tokens => shape (B, S-1, ...)
    others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
    # Concatenate => shape (B, S, ...)
    combined = torch.cat([query, others], dim=1)

    # Finally flatten => shape (B*S, ...); cat output is contiguous so view is safe
    combined = combined.view(B * S, *combined.shape[2:])
    return combined