import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import logging
from types import SimpleNamespace as Namespace


class DINOv2FeatureExtractor(nn.Module):
    """Wrap a timm DINOv2 (with registers) backbone as an embedding extractor.

    The backbone is created with ``num_classes=0`` so no classifier head is
    attached. All backbone parameters are frozen at construction time except
    the last ``num_of_layers_to_unfreeze`` transformer blocks and the final
    norm layer. ``forward`` returns L2-normalized pooled features.

    Args:
        image_size: Value passed to timm as ``img_size`` (518 is the default
            for DINOv2 models).
        model_type: timm model identifier passed to ``timm.create_model``.
        num_of_layers_to_unfreeze: Number of trailing transformer blocks
            left trainable. Values ``<= 0`` leave every block frozen (the
            final norm layer is still trainable).
    """

    def __init__(
        self,
        image_size=518,  # Default for DINOv2 models
        model_type="vit_base_patch14_reg4_dinov2.lvd142m",
        num_of_layers_to_unfreeze=1,
    ):
        super().__init__()

        # Initialize backbone with registers; num_classes=0 drops the
        # classification head so the model yields features only.
        self.backbone = timm.create_model(
            model_type, pretrained=True, num_classes=0, img_size=image_size
        )

        # Store configuration parameters
        self.model_type = model_type
        self.num_channels = self.backbone.embed_dim
        self.image_size = image_size
        self.num_of_layers_to_unfreeze = num_of_layers_to_unfreeze

        # Apply the freezing policy. Previously this method was defined but
        # never invoked, so ``num_of_layers_to_unfreeze`` had no effect.
        # The call is idempotent, so callers that also invoke it are safe.
        self._freeze_parameters()

    def _freeze_parameters(self):
        """Freeze the backbone except the last N blocks and the norm layer.

        N is ``self.num_of_layers_to_unfreeze``. The number of trainable
        backbone parameters is logged at INFO level afterwards.
        """
        # First freeze everything
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Unfreeze the last N blocks. The ``> 0`` guard matters: a slice of
        # ``[-0:]`` would select *every* block rather than none.
        if self.num_of_layers_to_unfreeze > 0:
            trailing_blocks = self.backbone.blocks[
                -self.num_of_layers_to_unfreeze:
            ]
            for block in trailing_blocks:
                for param in block.parameters():
                    param.requires_grad = True

        # Unfreeze norm layer
        for param in self.backbone.norm.parameters():
            param.requires_grad = True

        # Count trainable parameters for backbone
        trainable = sum(
            p.numel() for p in self.backbone.parameters() if p.requires_grad
        )
        logging.info(
            "Number of trainable parameters backbone: %s", f"{trainable:,}"
        )

    def forward(self, x):
        """Return L2-normalized pooled backbone features for an image batch.

        Args:
            x: 4-D image tensor; presumably ``(batch, channels, height,
                width)`` with spatial size matching ``image_size`` — confirm
                against the caller.

        Returns:
            The backbone's pre-logits pooled features, L2-normalized along
            the last dimension (presumably ``(batch, num_channels)``).

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"Expected a 4-D image tensor, got shape {tuple(x.shape)}"
            )

        # Run the transformer first: ``forward_head`` pools token features,
        # so it must receive the output of ``forward_features`` rather than
        # the raw image tensor.
        tokens = self.backbone.forward_features(x)

        # Default behavior: extract features using the backbone's pooling
        features = self.backbone.forward_head(tokens, pre_logits=True)

        # L2 normalization
        return F.normalize(features, p=2, dim=-1)