# Adapted from Meta's code base: https://github.com/facebookresearch/sam2

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
from functools import partial
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from iopath.common.file_io import g_pathmgr

from sam2.modeling.backbones.utils import (
    PatchEmbed,
    window_partition,
    window_unpartition,
)
from sam2.modeling.sam2_utils import DropPath, MLP
from transformers import PretrainedConfig, PreTrainedModel


def do_pool(
    x: torch.Tensor,
    pool: Optional[nn.Module],
    norm: Optional[nn.Module] = None,
) -> torch.Tensor:
    """Apply ``pool`` (and optionally ``norm``) to a channels-last tensor.

    Args:
        x: Input laid out as (B, H, W, C).
        pool: Module applied to the channels-first view of ``x``. When it is
            ``None`` the input is returned untouched (``norm`` is skipped too).
        norm: Optional module applied to the pooled, channels-last result.

    Returns:
        The pooled tensor in (B, H', W', C) layout, or ``x`` itself when
        ``pool`` is ``None``.
    """
    if pool is None:
        return x
    # (B, H, W, C) -> (B, C, H, W)
    x = x.permute(0, 3, 1, 2)
    x = pool(x)
    # (B, C, H', W') -> (B, H', W', C)
    x = x.permute(0, 2, 3, 1)
    # Explicit None check: container modules define __len__, so an empty one
    # would be falsy under a plain truthiness test.
    if norm is not None:
        x = norm(x)
    return x


def enhanced_scaled_dot_product_attention(query, key, value, chunk_size: int = 2**15):
    """Scaled dot-product attention that splits very large batches.

    The leading (batch) dimension of ``query`` is compared against
    ``chunk_size``. When it is larger, attention is computed independently on
    consecutive slices of at most ``chunk_size`` entries along dim 0 and the
    partial outputs are concatenated back together along dim 0.

    NOTE(review): the original author reports that batch sizes around 2**16
    can trigger CUDA launch/memory errors in the fused attention kernel; the
    default of 2**15 stays below that. Confirm the exact limit for your
    PyTorch/CUDA combination.

    Args:
        query: Query tensor; its dim 0 is treated as the batch dimension.
        key: Key tensor, sliced along dim 0 in lockstep with ``query``.
        value: Value tensor, sliced along dim 0 in lockstep with ``query``.
        chunk_size: Maximum number of batch entries per attention call.

    Returns:
        The attention output, identical in layout to what a single
        ``F.scaled_dot_product_attention`` call would return.

    Raises:
        ValueError: If ``chunk_size`` is not a positive integer.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    batch_size = query.shape[0]
    if batch_size <= chunk_size:
        return F.scaled_dot_product_attention(query, key, value)

    outputs = [
        F.scaled_dot_product_attention(
            query[start : start + chunk_size],
            key[start : start + chunk_size],
            value[start : start + chunk_size],
        )
        for start in range(0, batch_size, chunk_size)
    ]
    return torch.cat(outputs, dim=0)


class MultiScaleAttention(nn.Module):
    """Multi-head self-attention with optional pooling of the queries."""

    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        q_pool: nn.Module = None,
    ):
        """
        Args:
            dim: Channel count of the input.
            dim_out: Channel count of the output (and of each of q, k, v).
            num_heads: Number of attention heads.
            q_pool: Optional pooling module applied to the queries only.
        """
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out
        self.num_heads = num_heads
        self.q_pool = q_pool
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the H*W positions of a (B, H, W, C) input.

        Returns a (B, H', W', dim_out) tensor, where H', W' are the pooled
        spatial sizes when ``q_pool`` is set and H, W otherwise.
        """
        B, H, W, _ = x.shape
        # qkv with shape (B, H * W, 3, nHead, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
        # q, k, v with shape (B, H * W, nheads, C)
        q, k, v = torch.unbind(qkv, 2)

        # Q pooling (for downsample at stage changes)
        if self.q_pool is not None:
            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
            H, W = q.shape[1:3]  # downsampled shape
            q = q.reshape(B, H * W, self.num_heads, -1)

        # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
        x = enhanced_scaled_dot_product_attention(
            query=q.transpose(1, 2),
            key=k.transpose(1, 2),
            value=v.transpose(1, 2),
        )
        # Transpose back
        x = x.transpose(1, 2)
        x = x.reshape(B, H, W, -1)

        x = self.proj(x)

        return x


class MultiScaleBlock(nn.Module):
    """Transformer block: (windowed) attention with optional Q pooling + MLP."""

    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: Union[nn.Module, str] = "LayerNorm",
        q_stride: Tuple[int, int] = None,
        act_layer: nn.Module = nn.GELU,
        window_size: int = 0,
    ):
        """
        Args:
            dim: Input channel count.
            dim_out: Output channel count; a linear projection is added on the
                skip path when it differs from ``dim``.
            num_heads: Number of attention heads.
            mlp_ratio: Hidden width of the MLP as a multiple of ``dim_out``.
            drop_path: Stochastic-depth rate; 0 disables it.
            norm_layer: Norm constructor, or the name of one in ``torch.nn``
                (instantiated with ``eps=1e-6``).
            q_stride: Kernel/stride of the max-pool applied to the queries and
                to the skip path; ``None`` disables pooling.
            act_layer: Activation passed to the MLP.
            window_size: Attention window size; 0 means no window partition.
        """
        super().__init__()

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)

        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)

        self.window_size = window_size

        self.pool, self.q_stride = None, q_stride
        if self.q_stride:
            self.pool = nn.MaxPool2d(
                kernel_size=q_stride, stride=q_stride, ceil_mode=False
            )

        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            num_heads=num_heads,
            q_pool=self.pool,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = MLP(
            dim_out,
            int(dim_out * mlp_ratio),
            dim_out,
            num_layers=2,
            activation=act_layer,
        )

        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block on a (B, H, W, C) tensor and return the result."""
        shortcut = x  # B, H, W, C
        x = self.norm1(x)

        # Skip connection
        if self.dim != self.dim_out:
            shortcut = do_pool(self.proj(x), self.pool)

        # Window partition
        window_size = self.window_size
        if window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, window_size)

        # Window Attention + Q Pooling (if stage change)
        x = self.attn(x)
        # The window bookkeeping below is only meaningful when the input was
        # actually partitioned. Without this guard a block that pools but uses
        # global attention (window_size == 0) would evaluate `H % 0`.
        if self.q_stride and self.window_size > 0:
            # Shapes have changed due to Q pooling
            window_size = self.window_size // self.q_stride[0]
            H, W = shortcut.shape[1:3]

            pad_h = (window_size - H % window_size) % window_size
            pad_w = (window_size - W % window_size) % window_size
            pad_hw = (H + pad_h, W + pad_w)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class Hiera(nn.Module):
    """
    Reference: https://arxiv.org/abs/2306.00989
    """

    def __init__(
        self,
        embed_dim: int = 96,  # initial embed dim
        num_heads: int = 1,  # initial number of heads
        drop_path_rate: float = 0.0,  # stochastic depth
        q_pool: int = 3,  # number of q_pool stages
        q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
        dim_mul: float = 2.0,  # dim_mul factor at stage shift
        head_mul: float = 2.0,  # head_mul factor at stage shift
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        # window size per stage, when not using global att.
        window_spec: Tuple[int, ...] = (
            8,
            4,
            14,
            7,
        ),
        # global attn in these blocks
        global_att_blocks: Tuple[int, ...] = (
            12,
            16,
            20,
        ),
        weights_path=None,
        return_interm_layers=True,  # return feats from every stage
    ):
        super().__init__()

        assert len(stages) == len(window_spec)
        self.window_spec = window_spec

        depth = sum(stages)
        self.q_stride = q_stride
        # Index of the last block of every stage.
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
        assert 0 <= q_pool <= len(self.stage_ends[:-1])
        # Q pooling happens in the first block of the stages that follow the
        # first `q_pool` stage ends.
        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
        self.return_interm_layers = return_interm_layers

        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
        )
        # Which blocks have global att?
        self.global_att_blocks = global_att_blocks

        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
        )
        self.pos_embed_window = nn.Parameter(
            torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
        )

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        cur_stage = 1
        self.blocks = nn.ModuleList()

        for i in range(depth):
            dim_out = embed_dim
            # lags by a block, so first block of
            # next stage uses an initial window size
            # of previous stage and final window size of current stage
            window_size = self.window_spec[cur_stage - 1]

            if self.global_att_blocks is not None:
                window_size = 0 if i in self.global_att_blocks else window_size

            if i - 1 in self.stage_ends:
                dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                cur_stage += 1

            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
            )

            embed_dim = dim_out
            self.blocks.append(block)

        # Output channel counts of the stage-end blocks, last stage first.
        self.channel_list = (
            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
            if return_interm_layers
            else [self.blocks[-1].dim_out]
        )

        if weights_path is not None:
            with g_pathmgr.open(weights_path, "rb") as f:
                # NOTE(review): torch.load unpickles the checkpoint; only point
                # weights_path at trusted files (or consider weights_only=True
                # if the supported PyTorch versions allow it).
                chkpt = torch.load(f, map_location="cpu")
            res = self.load_state_dict(chkpt, strict=False)
            logging.info(f"loading Hiera: {res}")

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        """Build the positional embedding for an ``hw``-sized feature map.

        The background embedding is bicubically resized to ``hw`` and the
        window embedding is tiled on top of it; the result is returned
        channels-last.
        """
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        pos_embed = pos_embed + window_embed.tile(
            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Return the feature maps collected at stage ends.

        With ``return_interm_layers`` every stage end contributes one entry;
        otherwise only the final block does. Each entry is permuted from the
        internal (B, H, W, C) layout to (B, C, H, W).
        """
        x = self.patch_embed(x)
        # x: (B, H, W, C)

        # Add pos embed
        x = x + self._get_pos_embed(x.shape[1:3])

        outputs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (i == self.stage_ends[-1]) or (
                i in self.stage_ends and self.return_interm_layers
            ):
                feats = x.permute(0, 3, 1, 2)
                outputs.append(feats)

        return outputs

    def get_layer_id(self, layer_name):
        """Map a parameter name to a layer index.

        Positional and patch embeddings map to 0, ``blocks.<k>.*`` maps to
        ``k + 1``, and everything else (including ``rel_pos`` parameters) maps
        to ``num_layers + 1``.
        """
        # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
        num_layers = self.get_num_layers()

        if layer_name.find("rel_pos") != -1:
            return num_layers + 1
        elif layer_name.find("pos_embed") != -1:
            return 0
        elif layer_name.find("patch_embed") != -1:
            return 0
        elif layer_name.find("blocks") != -1:
            return int(layer_name.split("blocks")[1].split(".")[1]) + 1
        else:
            return num_layers + 1

    def get_num_layers(self) -> int:
        """Return the number of transformer blocks."""
        return len(self.blocks)


class HieraConfig(PretrainedConfig):
    """Hugging Face configuration holding the ``Hiera`` constructor arguments."""

    model_type = "hiera"

    def __init__(
        self,
        embed_dim=96,
        num_heads=1,
        drop_path_rate=0.0,
        q_pool=3,
        q_stride=(2, 2),
        stages=(2, 3, 16, 3),
        dim_mul=2.0,
        head_mul=2.0,
        window_pos_embed_bkg_spatial_size=(14, 14),
        window_spec=(8, 4, 14, 7),
        global_att_blocks=(12, 16, 20),
        weights_path=None,
        return_interm_layers=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.drop_path_rate = drop_path_rate
        self.q_pool = q_pool
        self.q_stride = q_stride
        self.stages = stages
        self.dim_mul = dim_mul
        self.head_mul = head_mul
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.window_spec = window_spec
        self.global_att_blocks = global_att_blocks
        self.weights_path = weights_path
        self.return_interm_layers = return_interm_layers

    @classmethod
    def from_json_file(cls, json_file):
        """Build a config from a JSON file whose keys are constructor kwargs."""
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(json_file, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        return cls(**config_dict)


class HieraVisionModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper exposing a ``Hiera`` trunk as ``self.hiera``."""

    config_class = HieraConfig
    _no_split_modules = ["Hiera"]

    def __init__(self, config, weights_path=None):
        """
        Args:
            config: A ``HieraConfig`` providing the trunk hyper-parameters.
            weights_path: Optional checkpoint path forwarded to ``Hiera``.
                NOTE(review): ``config.weights_path`` is not consulted here;
                confirm whether that is intentional.
        """
        super().__init__(config)
        self.hiera = Hiera(
            embed_dim=config.embed_dim,
            num_heads=config.num_heads,
            drop_path_rate=config.drop_path_rate,
            q_pool=config.q_pool,
            q_stride=config.q_stride,
            stages=config.stages,
            dim_mul=config.dim_mul,
            head_mul=config.head_mul,
            window_pos_embed_bkg_spatial_size=config.window_pos_embed_bkg_spatial_size,
            window_spec=config.window_spec,
            global_att_blocks=config.global_att_blocks,
            return_interm_layers=config.return_interm_layers,
            weights_path=weights_path,
        )

    def forward(self, x):
        """Delegate to the wrapped ``Hiera`` trunk."""
        return self.hiera(x)


def _compare_with_sam2_trunk() -> None:
    """Print the trunk parameters and diff them against SAM 2.1's trunk."""
    model = HieraVisionModel.from_pretrained("nkkbr/hiera-base-plus-in-sam2.1")
    model = model.hiera
    for name, param in model.named_parameters():
        print(f"{name:50} {param.shape}")

    # Check whether the weights are consistent with the hiera module in
    # sam2.1-hiera-base-plus.
    from sam2.sam2_image_predictor import SAM2ImagePredictor

    predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-base-plus")
    # Build the reference state dict once instead of once per parameter.
    reference_state = predictor.model.image_encoder.trunk.state_dict()

    for name, param in model.named_parameters():
        reference = reference_state.get(name)
        if reference is None:
            print(f"The parameter {name} is missing from the SAM 2.1 trunk.")
            continue
        # The two models may live on different devices; torch.equal needs both
        # operands on the same one, so compare on CPU.
        if not torch.equal(param.detach().cpu(), reference.detach().cpu()):
            print(f"The parameter {name} has different weights in the two models.")

    print("Comparison complete!")


if __name__ == "__main__":
    _compare_with_sam2_trunk()