import logging
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel


class DINOv3FeatureExtractor(nn.Module):
    """Global image descriptor built on a pretrained DINOv3 backbone.

    The backbone is loaded with ``AutoModel.from_pretrained``. ``forward``
    returns the L2-normalized first (CLS) token of the backbone's last
    hidden state.

    Args:
        image_size: Stored on the instance as ``self.image_size``; it is not
            otherwise used by this class.
        model_type: Model identifier passed to ``AutoModel.from_pretrained``.
        num_of_layers_to_unfreeze: Fine-tuning policy for the backbone.
            ``None`` (default) leaves ``requires_grad`` exactly as loaded.
            A non-negative int freezes the backbone and then re-enables
            gradients for that many trailing transformer blocks (and for the
            final norm when the count is non-zero).

    Raises:
        ValueError: If ``num_of_layers_to_unfreeze`` is negative, or if
            freezing is requested and the backbone has no ``layer`` attribute.
    """

    def __init__(
        self,
        image_size=512,
        model_type="facebook/dinov3-vitb16-pretrain-lvd1689m",
        num_of_layers_to_unfreeze=None,
    ):
        super().__init__()

        # Validate before the (potentially slow) checkpoint load.
        if num_of_layers_to_unfreeze is not None and num_of_layers_to_unfreeze < 0:
            raise ValueError(
                "num_of_layers_to_unfreeze must be a non-negative integer or None, "
                f"got {num_of_layers_to_unfreeze!r}"
            )

        # Store configuration parameters
        self.model_type = model_type
        self.image_size = image_size
        # Read by _freeze_parameters(); must exist before that method runs.
        self.num_of_layers_to_unfreeze = num_of_layers_to_unfreeze

        self.backbone = AutoModel.from_pretrained(model_type)

        # Both attributes expose the backbone's hidden size.
        hidden_size = self.backbone.config.hidden_size
        self.num_channels = hidden_size
        self.desc_dim = hidden_size

        # Only apply the freezing policy when the caller asked for one, so the
        # default construction behaves as it always has.
        if num_of_layers_to_unfreeze is not None:
            self._freeze_parameters()

    def _freeze_parameters(self):
        """Freeze the backbone except its last N layers and final norm.

        N is ``self.num_of_layers_to_unfreeze``; ``None`` is treated as 0,
        i.e. the whole backbone stays frozen. The number of trainable
        backbone parameters is logged at INFO level.

        Raises:
            ValueError: If the backbone has no ``layer`` attribute.
        """
        num_to_unfreeze = self.num_of_layers_to_unfreeze
        if num_to_unfreeze is None:
            num_to_unfreeze = 0

        # Freeze all parameters first
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Unfreeze last N transformer blocks
        if not hasattr(self.backbone, "layer"):
            raise ValueError("Could not find transformer layers (expected self.backbone.layer)")

        layers = self.backbone.layer  # ModuleList of DINOv3ViTLayer
        start_idx = max(0, len(layers) - num_to_unfreeze)
        for block in layers[start_idx:]:
            for param in block.parameters():
                param.requires_grad = True

        # Unfreeze final norm layer, but only when something else is trainable
        if num_to_unfreeze != 0 and hasattr(self.backbone, "norm"):
            for param in self.backbone.norm.parameters():
                param.requires_grad = True

        # Count trainable parameters
        trainable_params = sum(
            p.numel() for p in self.backbone.parameters() if p.requires_grad
        )
        logging.info(f"Number of trainable parameters in backbone: {trainable_params:,}")

    def forward(self, x):
        """Return the L2-normalized CLS descriptor for ``x``.

        Args:
            x: Input passed straight to the backbone (presumably a batch of
                preprocessed images; confirm the expected layout against the
                caller).

        Returns:
            ``outputs.last_hidden_state[:, 0]`` normalized to unit L2 norm
            along the last dimension.

        Raises:
            AssertionError: If the backbone config reports zero register
                tokens (or has no ``num_register_tokens`` field), which is not
                a valid DINOv3 configuration.
        """
        outputs = self.backbone(x)
        last_hidden_state = outputs.last_hidden_state

        # Explicit raise instead of `assert` so the check survives `python -O`;
        # the exception type and message are unchanged.
        num_register_tokens = getattr(self.backbone.config, "num_register_tokens", 0)
        if num_register_tokens == 0:
            raise AssertionError("Error number of register tokens cannot be 0 for Dinov3")

        # Default behavior: use CLS token (index 0) with normalization
        cls_token = last_hidden_state[:, 0]
        return F.normalize(cls_token, p=2, dim=-1)