Instructions to use 229nagibator229/theia-tiny-patch16-224-cdiv with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use 229nagibator229/theia-tiny-patch16-224-cdiv with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("feature-extraction", model="229nagibator229/theia-tiny-patch16-224-cdiv", trust_remote_code=True)

# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("229nagibator229/theia-tiny-patch16-224-cdiv", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| # Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved. | |
| import math | |
| from itertools import chain | |
| from typing import Any, Optional | |
| from omegaconf import OmegaConf | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn.functional import interpolate | |
| from einops.layers.torch import Rearrange | |
| from transformers import PretrainedConfig, PreTrainedModel | |
| from transformers import AutoConfig, AutoModel, AutoProcessor, AutoImageProcessor | |
| from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTModel | |
def handle_feature_output(
    x: torch.Tensor, feature_reduce_method: Optional[str] = None, num_discard_tokens: int = 0
) -> torch.Tensor:
    """Select or pool tokens from a transformer output sequence.

    Args:
        x (torch.Tensor): feature to reduce. Indexed as [B, T, C] where the token
            axis T is laid out as
            [B, 1+H*W+N, C] with both CLS and N trailing register tokens,
            [B, 1+H*W, C] for a standard model (N=0).
            Note that every non-`identity`/`cls` branch drops index 0 of the token
            axis, i.e. it treats the first token as CLS.
        feature_reduce_method (Optional[str]): how to reduce the tokens. Options:
            - `mean_pooling`: mean over the spatial tokens, output shape = [B, C].
            - `max_pooling`: max over the spatial tokens, output shape = [B, C].
            - `cls`: first token only, output shape = [B, C].
            - `identity`: `x` returned untouched, output shape = input shape.
            - `None`: the spatial tokens, output shape = [B, H*W, C].
        num_discard_tokens (int): number of tokens to drop from the END of the
            sequence (e.g. register tokens) before selecting spatial tokens.

    Returns:
        torch.Tensor: selected feature tokens.

    Raises:
        NotImplementedError: if `feature_reduce_method` is not one of the options above.
    """
    method = feature_reduce_method

    # Branches that do not need the spatial slice.
    if method == "identity":
        return x
    if method == "cls":
        return x[:, 0]  # [B, C]

    # Reject unknown methods before touching `x`.
    if method is not None and method != "mean_pooling" and method != "max_pooling":
        raise NotImplementedError(f"feature_reduce_method {feature_reduce_method} it not implemented.")

    # Drop the leading token and the trailing `num_discard_tokens` tokens.
    last = x.size(1) - num_discard_tokens
    spatial_tokens = x[:, 1:last]

    if method is None:
        return spatial_tokens
    if method == "mean_pooling":
        return spatial_tokens.mean(dim=1)  # [B, C]
    return spatial_tokens.amax(dim=1)  # [B, C]
| # Modified from huggingface transformers ViTEmbeddings | |
| # Original Copyright 2021 The HuggingFace Inc. team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
class ViTEmbeddingsNoCLS(ViTEmbeddings):
    """ViT Embedding Module without CLS token.

    The `position_embeddings` table inherited from `ViTEmbeddings` is kept as is, so
    its index 0 is still the CLS position. This module only ever adds the patch
    positions (`position_embeddings[:, 1:]`) to the patch embeddings and never
    prepends a CLS token.
    """

    def __init__(self, config: AutoConfig, use_mask_token: bool = False):
        """Initialization.

        Args:
            config (AutoConfig): config for ViT.
            use_mask_token (bool, optional): whether to use mask token. Defaults to False.
        """
        super().__init__(config, use_mask_token=use_mask_token)
        # Discard the CLS token created by the parent; forward() never uses it.
        self.cls_token = None

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Interpolate the pre-trained patch position encodings to a new resolution.

        This allows using the model on higher resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

        Args:
            embeddings (torch.Tensor): patch embeddings (no CLS token), indexed as [B, num_patches, C].
            height (int): input image height in pixels.
            width (int): input image width in pixels.

        Returns:
            torch.Tensor: position encodings for the patch tokens only, with one entry
                per patch of the (height, width) input.
        """
        num_patches = embeddings.shape[1]
        num_positions = self.position_embeddings.shape[1] - 1
        patch_pos_embed = self.position_embeddings[:, 1:]
        if num_patches == num_positions and height == width:
            # `embeddings` has no CLS token, so only the patch positions may be returned.
            # Returning the full table here would be one token longer than `embeddings`
            # and break the addition in forward().
            return patch_pos_embed
        dim = embeddings.shape[-1]
        h0 = height // self.config.patch_size
        w0 = width // self.config.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
            mode="bicubic",
            align_corners=False,
        )
        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        """Embed an image batch into patch tokens (no CLS token).

        Args:
            pixel_values (torch.Tensor): images, unpacked as [B, channels, height, width].
            bool_masked_pos (Optional[torch.BoolTensor]): per-patch mask; masked patches
                are replaced by `self.mask_token`. Defaults to None.
            interpolate_pos_encoding (bool): whether to interpolate the position
                encodings to the input resolution. Defaults to False.

        Returns:
            torch.Tensor: patch embeddings with position encodings added and dropout applied.
        """
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        if bool_masked_pos is not None:
            seq_length = embeddings.shape[1]
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            # skip index 0, which is the CLS position
            embeddings = embeddings + self.position_embeddings[:, 1:]

        embeddings = self.dropout(embeddings)
        return embeddings
| # modified from huggingface transformers ViTModel | |
class ViTModelNoCLS(ViTModel):
    """ViT Model without CLS token.

    Identical to `ViTModel` except that its embedding module is replaced by
    `ViTEmbeddingsNoCLS` after the parent constructor has run.
    """

    def __init__(self, config: AutoConfig, add_pooling_layer: bool = True, use_mask_token: bool = False) -> None:
        """Build the parent ViT model, then swap in the CLS-free embeddings.

        Args:
            config (AutoConfig): config for ViT.
            add_pooling_layer (bool): forwarded to `ViTModel`. Defaults to True.
            use_mask_token (bool): whether to use mask token. Defaults to False.
        """
        super().__init__(config, add_pooling_layer, use_mask_token)
        self.embeddings = ViTEmbeddingsNoCLS(config, use_mask_token=use_mask_token)
        self.no_cls = True

    def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None:
        """Initialize the weights of a single sub-module.

        Linear / Conv2d weights and the position embeddings of a `ViTEmbeddings`
        module are drawn from a truncated normal; LayerNorm is reset to weight 1, bias 0.
        """

        def _reinit_trunc_normal(param: torch.Tensor) -> None:
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            fresh = nn.init.trunc_normal_(
                param.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            )
            param.data = fresh.to(param.dtype)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            _reinit_trunc_normal(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, ViTEmbeddings):
            # No CLS token to initialize here, only the position table.
            _reinit_trunc_normal(module.position_embeddings)
| # modified from huggingface transformers ViTEmbeddings | |
class ViTEmbeddingsReg(ViTEmbeddings):
    """
    ViT Embedding Module with register tokens. https://openreview.net/forum?id=2dnO3LLiJ1

    The token sequence produced by forward() is laid out as
    [CLS, patch tokens..., register tokens...]; the register tokens have their own
    learned position embeddings (`reg_pos_embed`) appended after the inherited
    `position_embeddings`.
    """

    def __init__(self, config: AutoConfig, use_mask_token: bool = False, num_reg_tokens: int = 7):
        """Initialization.

        Args:
            config (AutoConfig): config for ViT.
            use_mask_token (bool, optional): whether to use mask token. Defaults to False.
            num_reg_tokens (int, optional): number of register tokens. Defaults to 7.
        """
        super().__init__(config, use_mask_token=use_mask_token)
        self.reg_token = nn.Parameter(torch.randn(1, num_reg_tokens, config.hidden_size))
        self.num_reg_tokens = num_reg_tokens
        self.reg_pos_embed = nn.Parameter(torch.randn(1, num_reg_tokens, config.hidden_size))

        # Re-draw both parameters from a truncated normal, going through fp32 and
        # casting back to the parameter dtype.
        self.reg_pos_embed.data = nn.init.trunc_normal_(
            self.reg_pos_embed.data.to(torch.float32),
            mean=0.0,
            std=self.config.initializer_range,
        ).to(self.reg_pos_embed.dtype)
        self.reg_token.data = nn.init.trunc_normal_(
            self.reg_token.data.to(torch.float32),
            mean=0.0,
            std=self.config.initializer_range,
        ).to(self.reg_token.dtype)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Interpolate the pre-trained position encodings to a new resolution.

        This allows using the model on higher resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

        Args:
            embeddings (torch.Tensor): full token sequence [CLS, patches, registers],
                indexed as [B, 1 + num_patches + num_reg_tokens, C].
            height (int): input image height in pixels.
            width (int): input image width in pixels.

        Returns:
            torch.Tensor: position encodings for CLS, patch and register tokens,
                concatenated along the token axis in that order.
        """
        num_patches = embeddings.shape[1] - 1 - self.num_reg_tokens
        num_positions = self.position_embeddings.shape[1] - 1
        if num_patches == num_positions and height == width:
            # `embeddings` already contains the register tokens, so their positions must
            # be appended; `position_embeddings` alone is `num_reg_tokens` entries short.
            return torch.cat((self.position_embeddings, self.reg_pos_embed), dim=1)
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        reg_pos_embed = self.reg_pos_embed
        dim = embeddings.shape[-1]
        h0 = height // self.config.patch_size
        w0 = width // self.config.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
            mode="bicubic",
            align_corners=False,
        )
        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        # Only the patch positions are interpolated; CLS and register positions are reused.
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed, reg_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        """Embed an image batch into [CLS, patch tokens, register tokens].

        Args:
            pixel_values (torch.Tensor): images, unpacked as [B, channels, height, width].
            bool_masked_pos (Optional[torch.BoolTensor]): per-patch mask; masked patches
                are replaced by `self.mask_token`. Defaults to None.
            interpolate_pos_encoding (bool): whether to interpolate the position
                encodings to the input resolution. Defaults to False.

        Returns:
            torch.Tensor: token embeddings with position encodings added and dropout applied.
        """
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        if bool_masked_pos is not None:
            seq_length = embeddings.shape[1]
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # add the [CLS] token in front of, and the register tokens after, the patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        reg_tokens = self.reg_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings, reg_tokens), dim=1)

        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + torch.cat([self.position_embeddings, self.reg_pos_embed], dim=1)

        embeddings = self.dropout(embeddings)
        return embeddings
| # modified from huggingface transformers ViTModel | |
class ViTModelReg(ViTModel):
    """ViT Model with register tokens. https://openreview.net/forum?id=2dnO3LLiJ1

    Identical to `ViTModel` except that its embedding module is replaced by
    `ViTEmbeddingsReg` after the parent constructor has run.
    """

    def __init__(
        self, config: AutoConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, num_reg_tokens: int = 7
    ):
        """Build the parent ViT model, then swap in the register-token embeddings.

        Args:
            config (AutoConfig): config for ViT.
            add_pooling_layer (bool): forwarded to `ViTModel`. Defaults to True.
            use_mask_token (bool): whether to use mask token. Defaults to False.
            num_reg_tokens (int): number of register tokens. Defaults to 7.
        """
        super().__init__(config, add_pooling_layer, use_mask_token)
        self.embeddings = ViTEmbeddingsReg(config, use_mask_token=use_mask_token, num_reg_tokens=num_reg_tokens)
        self.num_reg_tokens = num_reg_tokens

    def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None:
        """Initialize the weights of a single sub-module.

        Linear / Conv2d weights, and the position embeddings and CLS token of a
        `ViTEmbeddings` module, are drawn from a truncated normal; LayerNorm is reset
        to weight 1, bias 0.
        """

        def _reinit_trunc_normal(param: torch.Tensor) -> None:
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            fresh = nn.init.trunc_normal_(
                param.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            )
            param.data = fresh.to(param.dtype)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            _reinit_trunc_normal(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, ViTEmbeddings):
            # Same order as before: position table first, then the CLS token.
            _reinit_trunc_normal(module.position_embeddings)
            _reinit_trunc_normal(module.cls_token)
class DeiT(nn.Module):
    """DeiT model.

    Paper: Training data-efficient image transformers & distillation through attention
        https://arxiv.org/abs/2012.12877
    Huggingface Reference: https://huggingface.co/docs/transformers/en/model_doc/deit

    Attributes:
        image_size (int): side length of the (square) dummy image used by `get_feature_size`.
        model (nn.Module): the Hugging Face model; its pooler is replaced by `nn.Identity`.
        processor: the Hugging Face processor loaded for `model_name`.
    """

    def __init__(
        self,
        model_name: str = "facebook/deit-small-patch16-224",
        pretrained: bool = False,
        image_size: int = 224,
    ):
        """Initialization.

        Args:
            model_name (str): name of the model on the Hugging Face hub.
            pretrained (bool): whether to use pretrained weights. Defaults to False.
            image_size (int): size of the image. Assume a square image. Defaults to 224.
        """
        super().__init__()
        self.image_size = image_size
        if pretrained:
            self.model = AutoModel.from_pretrained(model_name)
        else:
            # Only the config is needed for a randomly initialized model, so avoid
            # downloading and instantiating the pretrained weights.
            deit_config = AutoConfig.from_pretrained(model_name)
            self.model = AutoModel.from_config(deit_config)

        self.model.pooler = nn.Identity()
        self.processor = AutoProcessor.from_pretrained(model_name)

    def get_feature_size(
        self,
        keep_spatial: bool = False,
        return_torch_size: bool = False,
    ) -> torch.Size | tuple[int, ...]:
        """Get the size of the feature.

        Runs a dummy all-zero image through the model and inspects the output.

        Args:
            keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False.
            return_torch_size (bool): if true, return torch.Size type. Defaults to False.

        Returns:
            torch.Size | tuple[int, ...]: returned feature shape, (C, H*W), or
                (C, H, W) when `keep_spatial` is true.

        Raises:
            ValueError: if `keep_spatial` is true and the number of spatial tokens is
                not a perfect square.
        """
        with torch.inference_mode():
            image_size = (self.image_size, self.image_size)
            x = torch.zeros((1, *image_size, 3), dtype=torch.uint8)
            y = self.forward(x)[:, 1:]  # for getting feature size, discard cls token
            size = y.size()[1:][::-1]  # (num_tokens, C) -> (C, num_tokens)
            if keep_spatial:
                num_tokens = size[-1]
                side = math.isqrt(num_tokens)
                if side * side != num_tokens:
                    raise ValueError(f"Number of spatial tokens ({num_tokens}) is not a perfect square.")
                size = (size[0], side, side)
                if return_torch_size:
                    size = torch.Size(size)
            return size

    def forward(
        self,
        x: torch.Tensor,
        do_resize: bool = True,
        interpolate_pos_encoding: Optional[bool] = None,
        do_rescale: bool = True,
        do_normalize: bool = True,
    ) -> torch.Tensor:
        """Forward pass of the model

        Args:
            x (torch.Tensor): model input, handed to `self.processor` as is.

            - arguments for self.processor. Details can be find at
                https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor
            do_resize (bool): if do resizing in processor. Defaults to True.
            do_rescale (bool): if do rescaling (0-255 -> 0-1) in processor. Defaults to True.
            do_normalize (bool): if do normalize in processor. Defaults to True.

            - argument for forward
            interpolate_pos_encoding (bool): if interpolate the positional embedding. Defaults to None.

        Returns:
            torch.Tensor: model output (`last_hidden_state`).
        """
        inputs = self.processor(
            x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
        ).to(self.model.device)
        y = self.model(**inputs, interpolate_pos_encoding=interpolate_pos_encoding)
        return y.last_hidden_state
class DeiTNoCLS(nn.Module):
    """Modified DeiT model without CLS token.

    Attributes:
        image_size (int): side length of the (square) dummy image used by `get_feature_size`.
        model (ViTModelNoCLS): the CLS-free ViT; its pooler is replaced by `nn.Identity`.
        processor: the Hugging Face processor of the underlying pretrained model.
        no_cls (bool): always True; marks that the output has no CLS token.
    """

    def __init__(
        self, model_name: str = "nocls-facebook/deit-small-patch16-224", pretrained: bool = False, image_size: int = 224
    ):
        """Initialization.

        Args:
            model_name (str): name of the model; `"nocls-"` is stripped to obtain the
                Hugging Face hub name of the underlying model.
            pretrained (bool): whether to use pretrained weights. Defaults to False.
            image_size (int): size of the image. Assume a square image. Defaults to 224.
        """
        super().__init__()
        self.image_size = image_size
        pretrained_model_name = model_name.replace("nocls-", "")
        deit_config = AutoConfig.from_pretrained(pretrained_model_name)
        self.model = ViTModelNoCLS(deit_config)
        if pretrained:
            pretrained_model = AutoModel.from_pretrained(pretrained_model_name)
            own_keys = self.model.state_dict().keys()
            pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in own_keys}
            # The keys are relative to the inner model (no "model." prefix), so they must
            # be loaded into `self.model`. Loading into `self` with strict=False would
            # silently match nothing and leave the weights randomly initialized.
            self.model.load_state_dict(pretrained_dict, strict=False)
            del pretrained_model, pretrained_dict

        self.model.pooler = nn.Identity()
        self.processor = AutoProcessor.from_pretrained(pretrained_model_name)
        self.no_cls = True

    def get_feature_size(
        self,
        keep_spatial: bool = False,
        return_torch_size: bool = False,
    ) -> torch.Size | tuple[int, ...]:
        """Get the size of the feature.

        Runs a dummy all-zero image through the model and inspects the output.

        Args:
            keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False.
            return_torch_size (bool): if true, return torch.Size type. Defaults to False.

        Returns:
            torch.Size | tuple[int, ...]: returned feature shape, (C, H*W), or
                (C, H, W) when `keep_spatial` is true.

        Raises:
            ValueError: if `keep_spatial` is true and the number of spatial tokens is
                not a perfect square.
        """
        with torch.inference_mode():
            image_size = (self.image_size, self.image_size)
            x = torch.zeros((1, *image_size, 3), dtype=torch.uint8)
            y = self.forward(x)  # no CLS token to discard
            size = y.size()[1:][::-1]  # (num_tokens, C) -> (C, num_tokens)
            if keep_spatial:
                num_tokens = size[-1]
                side = math.isqrt(num_tokens)
                if side * side != num_tokens:
                    raise ValueError(f"Number of spatial tokens ({num_tokens}) is not a perfect square.")
                size = (size[0], side, side)
                if return_torch_size:
                    size = torch.Size(size)
            return size

    def forward(
        self,
        x: torch.Tensor,
        do_resize: bool = True,
        interpolate_pos_encoding: Optional[bool] = None,
        do_rescale: bool = True,
        do_normalize: bool = True,
    ) -> torch.Tensor:
        """Forward pass of the model

        Args:
            x (torch.Tensor): model input, handed to `self.processor` as is.

            - arguments for self.processor. Details can be find at
                https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor
            do_resize (bool): if do resizing in processor. Defaults to True.
            do_rescale (bool): if do rescaling (0-255 -> 0-1) in processor. Defaults to True.
            do_normalize (bool): if do normalize in processor. Defaults to True.

            - argument for forward
            interpolate_pos_encoding (bool): if interpolate the positional embedding. Defaults to None.

        Returns:
            torch.Tensor: model output (`last_hidden_state`).
        """
        inputs = self.processor(
            x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
        ).to(self.model.device)
        y = self.model(**inputs, interpolate_pos_encoding=interpolate_pos_encoding)
        return y.last_hidden_state
class DeiTReg(nn.Module):
    """Modified DeiT model with register tokens.

    Attributes:
        image_size (int): side length of the (square) dummy image used by `get_feature_size`.
        model (ViTModelReg): the ViT with register tokens; its pooler is replaced by `nn.Identity`.
        processor: the Hugging Face processor of the underlying pretrained model.
        num_reg_tokens (int): number of register tokens appended after the patch tokens.
    """

    def __init__(
        self,
        model_name: str = "reg-facebook/deit-small-patch16-224",
        pretrained: bool = False,
        image_size: int = 224,
        num_reg_tokens: int = 7,
    ):
        """Initialization.

        Args:
            model_name (str): name of the model; `"reg-"` is stripped to obtain the
                Hugging Face hub name of the underlying model.
            pretrained (bool): whether to use pretrained weights. Defaults to False.
            image_size (int): size of the image. Assume a square image. Defaults to 224.
            num_reg_tokens (int): number of register tokens. Defaults to 7.
        """
        super().__init__()
        self.image_size = image_size
        pretrained_model_name = model_name.replace("reg-", "")
        deit_config = AutoConfig.from_pretrained(pretrained_model_name)
        self.model = ViTModelReg(deit_config, num_reg_tokens=num_reg_tokens)
        if pretrained:
            pretrained_model = AutoModel.from_pretrained(pretrained_model_name)
            own_keys = self.model.state_dict().keys()
            pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in own_keys}
            # The keys are relative to the inner model (no "model." prefix), so they must
            # be loaded into `self.model`. Loading into `self` with strict=False would
            # silently match nothing and leave the weights randomly initialized.
            self.model.load_state_dict(pretrained_dict, strict=False)
            del pretrained_model, pretrained_dict

        self.model.pooler = nn.Identity()
        self.processor = AutoProcessor.from_pretrained(pretrained_model_name)
        self.num_reg_tokens = num_reg_tokens

    def get_feature_size(
        self,
        keep_spatial: bool = False,
        return_torch_size: bool = False,
    ) -> torch.Size | tuple[int, ...]:
        """Get the size of the feature.

        Runs a dummy all-zero image through the model and inspects the output with
        the CLS and register tokens removed.

        Args:
            keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False.
            return_torch_size (bool): if true, return torch.Size type. Defaults to False.

        Returns:
            torch.Size | tuple[int, ...]: returned feature shape, (C, H*W), or
                (C, H, W) when `keep_spatial` is true.

        Raises:
            ValueError: if `keep_spatial` is true and the number of spatial tokens is
                not a perfect square.
        """
        with torch.inference_mode():
            image_size = (self.image_size, self.image_size)
            x = torch.zeros((1, *image_size, 3), dtype=torch.uint8)
            y = self.forward(x)
            # Drop the leading CLS token and the trailing register tokens. An explicit end
            # index is used because `-0` would produce an empty slice when there are no
            # register tokens.
            y = y[:, 1 : y.size(1) - self.num_reg_tokens]
            size = y.size()[1:][::-1]  # (num_tokens, C) -> (C, num_tokens)
            if keep_spatial:
                num_tokens = size[-1]
                side = math.isqrt(num_tokens)
                if side * side != num_tokens:
                    raise ValueError(f"Number of spatial tokens ({num_tokens}) is not a perfect square.")
                size = (size[0], side, side)
                if return_torch_size:
                    size = torch.Size(size)
            return size

    def forward(
        self,
        x: torch.Tensor,
        do_resize: bool = True,
        interpolate_pos_encoding: Optional[bool] = None,
        do_rescale: bool = True,
        do_normalize: bool = True,
    ) -> torch.Tensor:
        """Forward pass of the model

        Args:
            x (torch.Tensor): model input, handed to `self.processor` as is.

            - arguments for self.processor. Details can be find at
                https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor
            do_resize (bool): if do resizing in processor. Defaults to True.
            do_rescale (bool): if do rescaling (0-255 -> 0-1) in processor. Defaults to True.
            do_normalize (bool): if do normalize in processor. Defaults to True.

            - argument for forward
            interpolate_pos_encoding (bool): if interpolate the positional embedding. Defaults to None.

        Returns:
            torch.Tensor: model output (`last_hidden_state`), including CLS and register tokens.
        """
        inputs = self.processor(
            x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
        ).to(self.model.device)
        y = self.model(**inputs, interpolate_pos_encoding=interpolate_pos_encoding)
        return y.last_hidden_state
def build_backbone(model_name: str, pretrained: bool = False, image_size: int = 224, **kwargs: Any) -> nn.Module:
    """Build the backbone visual encoder of robot vision foundation model.

    The backbone class is picked by substring match on `model_name`, checked in this
    order: `"reg"` -> `DeiTReg`, `"nocls"` -> `DeiTNoCLS`, `"deit"` -> `DeiT`.

    Args:
        model_name (str): name of the model.
        pretrained (bool): whether to use pretrained weights. Defaults to False.
        image_size (int): size of the image. Assume a square image. Defaults to 224
        kwargs (Any): any kwargs specific to some models. For example,
            `num_reg_tokens` for `DeiTReg` when `"reg"` in `model_name`.
            They are not forwarded to the plain `DeiT` backbone.

    Returns:
        nn.Module: backbone network.

    Raises:
        NotImplementedError: if `model_name` matches none of the known backbones.
    """
    common = {"model_name": model_name, "pretrained": pretrained, "image_size": image_size}

    if "reg" in model_name:
        return DeiTReg(**common, **kwargs)
    if "nocls" in model_name:
        return DeiTNoCLS(**common, **kwargs)
    if "deit" in model_name:
        return DeiT(**common)

    raise NotImplementedError(f"Requested {model_name} is not implemented.")
| class Interpolation(nn.Module): | |
| """Interpolation nn.Module wrap for nn.functional.interpolate. | |
| Attributes: | |
| target_size (tuple[int, int] | torch.Size): target spatial size of this interpolation. | |
| """ | |
| def __init__(self, target_size: tuple[int, int] | torch.Size) -> None: | |
| super().__init__() | |
| self.target_size = target_size | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| """Very simple forward pass to call interpolate().""" | |
| return interpolate(x, self.target_size) | |
| class LinearAdapterHead(nn.Module): | |
| """Adapter head contains a single linear layer.""" | |
| def __init__( | |
| self, source_size: tuple[int, ...] | torch.Size, target_size: tuple[int, ...] | torch.Size | |
| ): | |
| """Initialization function for LinearAdapterHead. | |
| Args: | |
| source_size (tuple[int, ...] | torch.Size): the size of the source feature. | |
| target_size (tuple[int, ...] | torch.Size): the size of the target feature. | |
| num_layer (int): number of MLP layers (One linear layer if num_layer = 1). | |
| """ | |
| super().__init__() | |
| self.source_size = source_size | |
| self.target_size = target_size | |
| source_channel_size = self.source_size[0] | |
| target_channel_size = self.target_size[0] | |
| self.adapter = nn.Sequential( | |
| nn.Linear(source_channel_size, target_channel_size), | |
| ) | |
| def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor: | |
| """Forward pass for the adapter. """ | |
| assert backbone_no_cls == False | |
| # x: [B, (1+H*W), C] | |
| # LinearAdapterHead is used only when there is cls token in the backbone. | |
| x = x[:, 0] | |
| x = self.adapter(x) | |
| return x # [B, (H*W), C] | |
| class MLPAdapterHead(nn.Module): | |
| """MLP Adapter module. | |
| Transforms features in shape source size [B, (H_s*W_s), C_s] to target size [B, (H_t*W_t), C_t]. | |
| Will first do interpolation to match the spatial size [H_t, W_t], | |
| followed by MLP to project to the target channel dimension [C_t]. | |
| Attributes: | |
| source_size (tuple[int, ...] | torch.Size): the size of the source feature. [C, H, W] | |
| target_size (tuple[int, ...] | torch.Size): the size of the target feature. [C, H, W] | |
| adapter (nn.Module): the adapter module. | |
| interpolation (nn.Module): interpolation to adjust sizes before MLP. | |
| """ | |
| def __init__( | |
| self, source_size: tuple[int, ...] | torch.Size, target_size: tuple[int, ...] | torch.Size, num_layer: int | |
| ): | |
| """Initialization function for MLPAdapter. | |
| Args: | |
| source_size (tuple[int, ...] | torch.Size): the size of the source feature. | |
| target_size (tuple[int, ...] | torch.Size): the size of the target feature. | |
| num_layer (int): number of MLP layers (One linear layer if num_layer = 1). | |
| """ | |
| super().__init__() | |
| assert num_layer >= 1, f"`num_layer` in {self._get_name()} should >= 1. Got {num_layer}" | |
| self.source_size = source_size | |
| self.target_size = target_size | |
| source_channel_size = self.source_size[0] | |
| target_channel_size = self.target_size[0] | |
| self.interpolation = nn.Sequential( | |
| nn.Identity(), | |
| ) | |
| if self.source_size[1] != self.target_size[1]: | |
| self.interpolation = nn.Sequential( | |
| Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), | |
| Interpolation(self.target_size[1:]), | |
| Rearrange("b c h w-> b (h w) c"), | |
| ) | |
| if num_layer == 1: | |
| self.adapter = nn.Sequential( | |
| nn.Linear(source_channel_size, target_channel_size), | |
| ) | |
| elif num_layer >= 2: | |
| hidden_dim = source_channel_size * 2 | |
| self.adapter = nn.Sequential( | |
| nn.Linear(source_channel_size, hidden_dim), | |
| *list( | |
| chain.from_iterable([[nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)] for _ in range(num_layer - 2)]) | |
| ), | |
| nn.ReLU(), | |
| nn.Linear(hidden_dim, target_channel_size), | |
| ) | |
| def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor: | |
| """Forward pass for the adapter. First interpolation then MLP.""" | |
| # x: [B, (1)+H*W, C] | |
| if not backbone_no_cls: | |
| x = x[:, 1:] | |
| # x: [B, (H*W), C] | |
| x = self.interpolation(x) | |
| x = self.adapter(x) | |
| return x # [B, (H*W), C] | |
| class ConvAdapterHead(nn.Module): | |
| """Convolutional Adapter module. | |
| Transforms features in shape source size [B, (H_s*W_s), C_s] to target size [B, (H_t*W_t), C_t]. | |
| Uses CNN to map channel and spatial sizes jointly. | |
| Note: only work for (16, 16), (any, any), any <= 14, and (64, 64) spatial sizes for now. | |
| Attributes: | |
| source_size (tuple[int, ...] | torch.Size): the size of the source feature. | |
| target_size (tuple[int, ...] | torch.Size): the size of the target feature. | |
| adapter (nn.Module): the adapter module. | |
| interpolation (nn.Module): interpolation to adjust sizes before MLP. | |
| """ | |
| def __init__( | |
| self, | |
| source_size: tuple[int, ...] | torch.Size, | |
| target_size: tuple[int, ...] | torch.Size, | |
| ): | |
| """Initialization function for ConvAdapter. | |
| Args: | |
| source_size (tuple[int, ...] | torch.Size): the size of the source feature. | |
| target_size (tuple[int, ...] | torch.Size): the size of the target feature. | |
| """ | |
| super().__init__() | |
| self.source_size = source_size | |
| self.target_size = target_size | |
| hidden_dim = self.source_size[0] * 2 | |
| source_channel_size = self.source_size[0] | |
| target_channel_size = self.target_size[0] | |
| if self.source_size[1] < 12: | |
| raise NotImplementedError("feature spatial size smaller than 12x12 is not supported.") | |
| elif self.source_size[1] < 16: # pad (any, any), any <= 14 to (16, 16) | |
| self.pad = nn.Sequential( | |
| Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), | |
| nn.ConvTranspose2d( | |
| source_channel_size, | |
| source_channel_size, | |
| kernel_size=3, | |
| stride=1, | |
| output_padding=14 - self.source_size[1], | |
| ), | |
| ) | |
| self.source_size = (self.source_size[0], 16, 16) | |
| elif self.source_size[1] == 16 or self.source_size[1] == 64: # do nothing for (16, 16) and (64, 64) | |
| self.pad = nn.Sequential( | |
| Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), | |
| ) | |
| else: | |
| raise NotImplementedError("feature spatial size (>=16x16) other than 16x16 and 64x64 is not supported.") | |
| if self.source_size[1] < self.target_size[1]: # (16, 16) / (14, 14) to (64, 64) | |
| self.adapter = nn.Sequential( | |
| nn.LayerNorm(self.source_size), | |
| nn.ConvTranspose2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), # 31 | |
| nn.ReLU(), | |
| nn.LayerNorm([hidden_dim, 31, 31]), | |
| nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, output_padding=1), # 64 | |
| nn.ReLU(), | |
| nn.LayerNorm([hidden_dim, 64, 64]), | |
| nn.ConvTranspose2d(hidden_dim, target_channel_size, kernel_size=3, stride=1, padding=1), # 64 | |
| Rearrange("b c h w-> b (h w) c"), | |
| ) | |
| elif self.source_size[1] == self.target_size[1]: # (16, 16) to (16, 16) | |
| self.adapter = nn.Sequential( | |
| nn.LayerNorm(self.source_size), | |
| nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, padding=1), # 16 | |
| nn.ReLU(), | |
| nn.LayerNorm([hidden_dim, *self.source_size[1:]]), | |
| nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), # 16 | |
| nn.ReLU(), | |
| nn.LayerNorm([hidden_dim, *self.source_size[1:]]), | |
| nn.Conv2d(hidden_dim, target_channel_size, kernel_size=3, padding=1), # 16 | |
| Rearrange("b c h w-> b (h w) c"), | |
| ) | |
| else: # (64, 64) to (16, 16) | |
| self.adapter = nn.Sequential( | |
| nn.LayerNorm(self.source_size), | |
| nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), # 32 | |
| nn.ReLU(), | |
| nn.LayerNorm([hidden_dim, 32, 32]), | |
| nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1), # 16 | |
| nn.ReLU(), | |
| nn.LayerNorm([hidden_dim, 16, 16]), | |
| nn.Conv2d(hidden_dim, target_channel_size, kernel_size=3, padding=1), # 16 | |
| Rearrange("b c h w-> b (h w) c"), | |
| ) | |
def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor:
    """Forward pass for ConvAdapter.

    Args:
        x (torch.Tensor): input tokens, [B, (1)+H*W, C]. The first token along dim 1
            is dropped unless `backbone_no_cls` is True.
        backbone_no_cls (bool): when True, `x` is used as-is (no leading token is removed).

    Returns:
        torch.Tensor: result of `self.adapter(self.pad(tokens))`, [B, (H*W), C].
    """
    # Keep only the spatial tokens: [B, H*W, C].
    spatial_tokens = x if backbone_no_cls else x[:, 1:]
    return self.adapter(self.pad(spatial_tokens))
| class LightConvAdapterHead(nn.Module): | |
| """Light Convolutional Adapter module. | |
| Transforms features from source size in [B, (H_s*W_s), C_s] to target size [B, (H_t*W_t), C_t]. | |
| Uses CNN to map channel and spatial sizes jointly. | |
| Note: only work for source sizes (H_s, W_s): (16, 16), (any, any), 12 <= any <= 14, | |
| and target sizes (H_t, W_t): (16, 16) and (64, 64) for now. | |
| Attributes: | |
| source_size (tuple[int, ...] | torch.Size): the size of the source feature, | |
| channel first (C, H, W). | |
| target_size (tuple[int, ...] | torch.Size): the size of the target feature, | |
| channel first (C, H, W). | |
| adapter (nn.Module): the adapter module. | |
| interpolation (nn.Module): interpolation to adjust sizes before MLP. | |
| """ | |
| def __init__( | |
| self, | |
| source_size: tuple[int, ...] | torch.Size, | |
| target_size: tuple[int, ...] | torch.Size, | |
| hidden_size_factor: int | float = 1.0, | |
| ): | |
| """Initialization function for ConvAdapter. | |
| Args: | |
| source_size (tuple[int, ...] | torch.Size): the size of the source feature. | |
| target_size (tuple[int, ...] | torch.Size): the size of the target feature. | |
| hidden_size_factor (int | float): the size of hidden dim of feature translator | |
| as a factor of input feature hidden dim. | |
| """ | |
| super().__init__() | |
| if source_size[1] != source_size[2] or target_size[1] != target_size[2]: | |
| raise NotImplementedError( | |
| "Currently does not support non-square feature maps like source size" | |
| "{source_size} and target size {target_size}." | |
| ) | |
| self.source_size = source_size | |
| self.target_size = target_size | |
| self.hidden_size_factor = hidden_size_factor | |
| hidden_dim = int(self.source_size[0] * hidden_size_factor) | |
| source_channel_size = self.source_size[0] | |
| target_channel_size = self.target_size[0] | |
| if self.source_size[1] < 12: | |
| raise NotImplementedError("feature spatial size smaller than 12x12 is not supported.") | |
| elif self.source_size[1] < 16 and self.target_size[1] >= 16: # pad (any, any), any <= 14 to (16, 16) | |
| self.pad = nn.Sequential( | |
| Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), | |
| nn.ConvTranspose2d( | |
| source_channel_size, | |
| source_channel_size, | |
| kernel_size=3, | |
| stride=1, | |
| output_padding=14 - self.source_size[1], | |
| ), | |
| ) | |
| self.source_size = (self.source_size[0], 16, 16) | |
| elif (self.source_size[1] == 16 or self.source_size[1] == 64) or \ | |
| (self.source_size[1] == 14 and self.target_size[1] == 14): | |
| # no padding for (16, 16), (64, 64) and (14, 14) <-> (14, 14) | |
| self.pad = nn.Sequential( | |
| Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), | |
| ) | |
| elif self.target_size[1] < 14: | |
| self.pad = nn.Sequential( | |
| Rearrange("b (h w) c-> b c h w", h=self.source_size[1], w=self.source_size[2]), | |
| ) | |
| else: | |
| raise NotImplementedError("feature spatial size larger than 16x16 (other than 64x64) is not supported.") | |
| if self.source_size[1] == 16 and self.target_size[1] == 64: # (16, 16) to (64, 64) | |
| self.adapter = nn.Sequential( | |
| nn.LayerNorm(self.source_size), | |
| nn.ConvTranspose2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), # 31 | |
| nn.ReLU(), | |
| nn.LayerNorm([hidden_dim, 31, 31]), | |
| nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, output_padding=1), # 64 | |
| nn.ReLU(), | |
| nn.LayerNorm([hidden_dim, 64, 64]), | |
| Rearrange("b c h w-> b (h w) c"), | |
| nn.Linear(hidden_dim, target_channel_size), | |
| ) | |
| elif self.source_size[1] == self.target_size[1]: # (16, 16) to (16, 16) | |
| self.adapter = nn.Sequential( | |
| nn.LayerNorm(self.source_size), | |
| nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, padding=1), # 16 | |
| nn.ReLU(), | |
| nn.LayerNorm([hidden_dim, *self.source_size[1:]]), | |
| nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), # 16 | |
| nn.ReLU(), | |
| nn.LayerNorm([hidden_dim, *self.source_size[1:]]), | |
| Rearrange("b c h w-> b (h w) c"), | |
| nn.Linear(hidden_dim, target_channel_size), | |
| ) | |
| elif self.source_size[1] == 64 and self.target_size[1] == 16: # (64, 64) to (16, 16) | |
| self.adapter = nn.Sequential( | |
| nn.LayerNorm(self.source_size), | |
| nn.Conv2d(source_channel_size, hidden_dim, kernel_size=3, stride=2, padding=1), # 32 | |
| nn.ReLU(), | |
| nn.LayerNorm([hidden_dim, 32, 32]), | |
| nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1), # 16 | |
| nn.ReLU(), | |
| nn.LayerNorm([hidden_dim, 16, 16]), | |
| Rearrange("b c h w-> b (h w) c"), | |
| nn.Linear(hidden_dim, target_channel_size), | |
| ) | |
| elif self.target_size[1] == 7: | |
| self.adapter = nn.Sequential( | |
| nn.LayerNorm(self.source_size), | |
| nn.Conv2d(source_channel_size, hidden_dim, kernel_size=4, stride=2, padding=1), #14x14 -> 7x7 | |
| nn.ReLU(), | |
| nn.LayerNorm([hidden_dim, 7, 7]), | |
| Rearrange("b c h w-> b (h w) c"), | |
| nn.Linear(hidden_dim, target_channel_size) | |
| ) | |
| else: | |
| NotImplementedError(f"{self.source_size} to {self.target_size} is not supported.") | |
| def forward(self, x: torch.Tensor, backbone_no_cls: bool = False) -> torch.Tensor: | |
| """Forward pass for ConvAdapter""" | |
| # x: [B, (1)+H*W, C] | |
| if not backbone_no_cls: | |
| x = x[:, 1:] | |
| x = self.pad(x) | |
| x = self.adapter(x) | |
| return x # [B, H*W, C] | |
| class FeatureTranslator(nn.Module): | |
| """Base class for the feature translator. | |
| The flow is backbone_adapter -> translator_stem -> translator_heads. | |
| Attributes: | |
| backbone_feature_size (torch.Size): the size of features of the backbone. | |
| target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models. | |
| translator_hidden_size (int): the hidden dim of the translator. Defaults to 2048. | |
| target_model_names (list[str]): convenient attribute to hold all the names of the target models. | |
| backbone_adapter (nn.Module): the adapter to map channel dim of backbone to the translator hidden dim. | |
| translator_stem (nn.Module): the shared stem for all target models. | |
| translator_heads (nn.ModuleDict): specific heads for different target models. | |
| """ | |
| def __init__( | |
| self, | |
| backbone_feature_size: torch.Size, | |
| target_feature_sizes: dict[str, torch.Size | tuple[int, ...]], | |
| translator_hidden_size: int = 1024, | |
| ) -> None: | |
| """Initalization function for FeatureTranslator. | |
| Args: | |
| backbone_feature_size (torch.Size): the size of features of the backbone. | |
| target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models. | |
| translator_hidden_size (int): the hidden dim of the translator. Defaults to 2048. | |
| """ | |
| super().__init__() | |
| self.backbone_feature_size = backbone_feature_size # (C, H, W) | |
| self.target_feature_sizes = target_feature_sizes # [(C, H, W)] | |
| self.translator_hidden_size = translator_hidden_size # C | |
| self.target_model_names = list(target_feature_sizes.keys()) | |
| self.legit_target_model_name_map: dict[str, str] = {t: t.replace(".", "_") for t in self.target_model_names} | |
| self.translator_heads: nn.ModuleDict = None | |
| self.backbone_adapter = nn.Sequential( | |
| nn.LayerNorm(self.backbone_feature_size[0]), # do a pre-norm | |
| nn.Linear( | |
| self.backbone_feature_size[0], # C in [C,H,W] | |
| self.translator_hidden_size, | |
| ), | |
| ) | |
| self.translator_stem: nn.Module = nn.Identity() | |
| self.build_translator_heads() | |
| def build_translator_heads(self) -> None: | |
| """Build translator heads to match the dimension of each target feature set. | |
| Example: | |
| translator_heads: dict[str, nn.Module] = ... | |
| self.translator_heads = nn.ModuleDict(translator_heads) | |
| """ | |
| raise NotImplementedError("build_translator_heads() should be overridden") | |
| def forward( | |
| self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, backbone_no_cls: bool = False | |
| ) -> torch.Tensor: | |
| """Forward pass for a base feature translator. | |
| Args: | |
| x (torch.Tensor): input features from the backbone. [B, (1)+H*W, C]. | |
| (1) means optional CLS token. If `backbone_no_cls==True`, then [B, H*W, C]. | |
| target_model_names (Optional[list[str]]): names of the target models. | |
| backbone_no_cls (bool): indicate backbone has cls token or not. | |
| Can use it to customize whether to drop cls. | |
| Returns: | |
| dict[str, torch.Tensor]: predicted features for target models. | |
| """ | |
| # x: [B, (1)+H*W, C] | |
| x = self.backbone_adapter(x) | |
| x = self.translator_stem(x) | |
| target_model_names = target_model_names if target_model_names is not None else self.target_model_names | |
| features = {t: self.translator_heads[self.legit_target_model_name_map[t]](x, backbone_no_cls=backbone_no_cls) for t in target_model_names} | |
| return features | |
class MLPFeatureTranslator(FeatureTranslator):
    """Feature translator whose per-target heads are `MLPAdapterHead` modules."""

    def __init__(
        self,
        backbone_feature_size: torch.Size,
        target_feature_sizes: dict[str, torch.Size | tuple[int, ...]],
        translator_hidden_size: int = 1024,
        translator_n_layer: int = 3,
    ) -> None:
        """Initialization function for MLPFeatureTranslator.

        Args:
            backbone_feature_size (torch.Size): the size of features of the backbone.
            target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models.
            translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
            translator_n_layer (int): number of MLP layers. Defaults to 3.
        """
        # Must be set before super().__init__(), which calls build_translator_heads().
        self.translator_n_layer = translator_n_layer

        super().__init__(
            backbone_feature_size=backbone_feature_size,
            target_feature_sizes=target_feature_sizes,
            translator_hidden_size=translator_hidden_size,
        )

    def build_translator_heads(self) -> None:
        """Build MLP translator heads to match the dimension of each target feature set.

        The result is stored in `self.translator_heads`; nothing is returned.
        """
        # Heads consume the adapted features: translator hidden dim with the backbone's spatial size.
        source_size = (self.translator_hidden_size, *self.backbone_feature_size[1:])
        self.translator_heads = nn.ModuleDict(
            {
                self.legit_target_model_name_map[target_model]: MLPAdapterHead(
                    source_size=source_size, target_size=target_size, num_layer=self.translator_n_layer
                )
                for target_model, target_size in self.target_feature_sizes.items()
            }
        )
class ConvFeatureTranslator(FeatureTranslator):
    """Feature translator whose per-target heads are `ConvAdapterHead` modules."""

    def __init__(
        self,
        backbone_feature_size: torch.Size,
        target_feature_sizes: dict[str, torch.Size | tuple[int, ...]],
        translator_hidden_size: int = 1024,
    ) -> None:
        """Initialization function for ConvFeatureTranslator.

        Args:
            backbone_feature_size (torch.Size): the size of features of the backbone.
            target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models.
            translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
        """
        super().__init__(
            backbone_feature_size=backbone_feature_size,
            target_feature_sizes=target_feature_sizes,
            translator_hidden_size=translator_hidden_size,
        )

    def build_translator_heads(self) -> None:
        """Build translator heads to match the dimension of each target feature set.

        The result is stored in `self.translator_heads`; nothing is returned.
        """
        # Heads consume the adapted features: translator hidden dim with the backbone's spatial size.
        source_size = (self.translator_hidden_size, *self.backbone_feature_size[1:])
        self.translator_heads = nn.ModuleDict(
            {
                self.legit_target_model_name_map[target_model]: ConvAdapterHead(
                    source_size=source_size, target_size=target_size
                )
                for target_model, target_size in self.target_feature_sizes.items()
            }
        )
class LightConvFeatureTranslator(FeatureTranslator):
    """Smaller translator compared to ConvFeatureTranslator.

    Heads operate directly on the backbone features: the `backbone_adapter` built by the
    base class is replaced with `nn.Identity()`.
    """

    def __init__(
        self,
        backbone_feature_size: torch.Size,
        target_feature_sizes: dict[str, torch.Size | tuple[int, ...]],
        translator_hidden_size: int = 1024,
        hidden_size_factor: int | float = 1.0,
    ) -> None:
        """Initialization function for LightConvFeatureTranslator.

        Args:
            backbone_feature_size (torch.Size): the size of features of the backbone.
            target_feature_sizes (dict[str, torch.Size | tuple[int, ...]]): the sizes of features of target models.
            translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
            hidden_size_factor (int | float): the size of hidden dim of feature translator
                as a factor of input feature hidden dim. Defaults to 1.0
        """
        # Must be set before super().__init__(), which calls build_translator_heads().
        self.hidden_size_factor = hidden_size_factor
        super().__init__(
            backbone_feature_size=backbone_feature_size,
            target_feature_sizes=target_feature_sizes,
            translator_hidden_size=translator_hidden_size,
        )
        # Heads are built on `backbone_feature_size` directly, so no channel adapter is applied.
        self.backbone_adapter = nn.Identity()

    def build_translator_heads(self) -> None:
        """Build translator heads to match the dimension of each target feature set.

        Targets whose name contains "_cls" get a `LinearAdapterHead`; every other target
        gets a `LightConvAdapterHead`. The result is stored in `self.translator_heads`;
        nothing is returned.
        """
        translator_heads = {}
        for target_model, target_size in self.target_feature_sizes.items():
            if "_cls" in target_model:
                head = LinearAdapterHead(source_size=self.backbone_feature_size, target_size=target_size)
            else:
                head = LightConvAdapterHead(
                    source_size=self.backbone_feature_size,
                    target_size=target_size,
                    hidden_size_factor=self.hidden_size_factor,
                )
            translator_heads[self.legit_target_model_name_map[target_model]] = head
        self.translator_heads = nn.ModuleDict(translator_heads)
class TransformerFreatureTranslator(FeatureTranslator):
    """Feature translator that uses a Transformer decoder as the shared stem.

    Learned decode tokens plus one learned embedding per target model are decoded against
    the adapted backbone features; the result is passed to a per-target `MLPAdapterHead`.
    """

    def __init__(
        self,
        backbone_feature_size: torch.Size,
        target_feature_sizes: dict[str, torch.Size | tuple[int, int]],
        translator_hidden_size: int = 1024,
        translator_n_layers: int = 2,
        translator_n_heads: int = 8,
        translator_activation: str = "gelu",
    ) -> None:
        """Initialization function for TransformerFreatureTranslator.

        Args:
            backbone_feature_size (torch.Size): the size of features of the backbone.
            target_feature_sizes (dict[str, torch.Size | tuple[int, int]]): the sizes of features of target models.
            translator_hidden_size (int): the hidden dim of the translator. Defaults to 1024.
            translator_n_layers (int): number of Transformer decoder layers. Defaults to 2.
            translator_n_heads (int): number of attention heads. Defaults to 8.
            translator_activation (str): activation of the decoder layers. Defaults to "gelu".
        """
        super().__init__(
            backbone_feature_size=backbone_feature_size,
            target_feature_sizes=target_feature_sizes,
            translator_hidden_size=translator_hidden_size,
        )

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=translator_hidden_size,
            nhead=translator_n_heads,
            dim_feedforward=translator_hidden_size * 2,
            activation=translator_activation,
            batch_first=True,
            norm_first=True,
        )
        self.translator_stem = nn.TransformerDecoder(decoder_layer, num_layers=translator_n_layers)

        # One learned query token per spatial location of the backbone feature map.
        num_decode_tokens = math.prod(self.backbone_feature_size[1:])
        self.decode_tokens = nn.Parameter(torch.randn((1, num_decode_tokens, translator_hidden_size)))

        # One learned embedding token per target model, keyed by the sanitized model name.
        self.target_model_emb = nn.ParameterDict(
            {
                self.legit_target_model_name_map[t]: torch.randn(1, 1, translator_hidden_size)
                for t in self.target_model_names
            }
        )

    def build_translator_heads(self) -> None:
        """Build Transformer translator heads to match the dimension of each target feature set."""
        source_size = (self.translator_hidden_size, *self.backbone_feature_size[1:])
        self.translator_heads = nn.ModuleDict(
            {
                self.legit_target_model_name_map[target_model]: MLPAdapterHead(
                    source_size=source_size,
                    target_size=target_size,
                    num_layer=2,
                )
                for target_model, target_size in self.target_feature_sizes.items()
            }
        )

    def forward(
        self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, backbone_no_cls: bool = False
    ) -> dict[str, torch.Tensor]:
        """Forward pass for the Transformer feature translator.

        Args:
            x (torch.Tensor): input features from the backbone.
            target_model_names (Optional[list[str]]): names of the target models.
                Defaults to all target models known to this translator.
            backbone_no_cls (bool): indicate backbone has cls token or not.
                When False, the first token of `x` is dropped before decoding.

        Returns:
            dict[str, torch.Tensor]: predicted features for target models.
        """
        if not backbone_no_cls:
            x = x[:, 1:]
        x = self.backbone_adapter(x)
        batch_size = x.size(0)

        if target_model_names is None:
            target_model_names = self.target_model_names

        features = {}
        for t in target_model_names:
            head_key = self.legit_target_model_name_map[t]
            queries = torch.cat(
                [
                    self.decode_tokens.repeat(batch_size, 1, 1),
                    self.target_model_emb[head_key].repeat(batch_size, 1, 1),
                ],
                dim=1,
            )
            # NOTE(review): the target-model embedding is appended LAST, yet the slice below
            # removes the FIRST token (a decode token). Left unchanged on purpose, since any
            # trained weights depend on this layout - confirm whether `[:, :-1]` was intended.
            feature = self.translator_stem(queries, memory=x)[:, 1:, ...]
            # NOTE(review): heads are called without `backbone_no_cls`, so they use their own
            # default - verify against MLPAdapterHead.forward.
            features[t] = self.translator_heads[head_key](feature)
        return features
| def build_feature_translator(translator_type: str, **kwargs: Any) -> FeatureTranslator: | |
| """Handy function to build feature translators given the type | |
| Args: | |
| translator_type (str): the type of the translator, | |
| one in `"mlp"`, `"conv"`, `"lconv"`, `"transformer"` (or `"trans"`). | |
| At the moment we are actively using `"lconv"`. | |
| Returns: | |
| FeatureTranslator: the corresponding FeatureTranslator | |
| """ | |
| if translator_type == "mlp": | |
| return MLPFeatureTranslator(**kwargs) | |
| elif translator_type == "conv": | |
| return ConvFeatureTranslator(**kwargs) | |
| elif translator_type == "lconv": | |
| return LightConvFeatureTranslator(**kwargs) | |
| elif translator_type == "transformer" or translator_type == "trans": | |
| return TransformerFreatureTranslator(**kwargs) | |
| else: | |
| raise NotImplementedError(f"Requested {translator_type} is not implemented yet.") | |
class TheiaConfig(PretrainedConfig):
    """Configuration for `TheiaModel`.

    Every argument is stored on the instance under the same name, except the misspelled
    `iamge_size`, which is stored as `image_size`.
    """

    def __init__(
        self,
        backbone: str | nn.Module = "facebook/deit-tiny-patch16-224",
        pretrained: bool = False,
        target_feature_sizes: Optional[dict[str, torch.Size | tuple[int, ...]]] = None,
        translator_type: str = "lconv",
        translator_hidden_size_factor: float | int = 1.0,
        target_loss_weights: Optional[dict[str, float]] = None,
        feature_reduce_method: Optional[str] = None,
        feature_neck: bool = False,
        feature_neck_hidden_dim: int = 256,
        forward_neck: bool = False,
        feature_neck_nonlinearity: str = "relu",
        iamge_size: int = 224,
        num_reg_tokens: int = 0,
        **kwargs: Any
    ):
        """Initialization function for TheiaConfig.

        Args:
            backbone (str | nn.Module): backbone handed to `build_backbone` by `TheiaModel`.
            pretrained (bool): handed to `build_backbone` together with `backbone`.
            target_feature_sizes (Optional[dict[str, torch.Size | tuple[int, ...]]]): feature sizes
                of the target models; when empty or None, `TheiaModel` builds no translator.
            translator_type (str): type name understood by `build_feature_translator`.
            translator_hidden_size_factor (float | int): hidden size factor for the translator.
            target_loss_weights (Optional[dict[str, float]]): per-target loss weights.
            feature_reduce_method (Optional[str]): how to reduce the output feature.
            feature_neck (bool): whether `TheiaModel` creates the convolutional neck.
            feature_neck_hidden_dim (int): channel width of the neck.
            forward_neck (bool): whether `TheiaModel.forward` returns the neck output.
            feature_neck_nonlinearity (str): `"relu"` selects ReLU in the neck; any other value selects Tanh.
            iamge_size (int): input image size. The misspelled name is kept so existing
                keyword callers keep working; an `image_size` keyword takes precedence.
            num_reg_tokens (int): number of register tokens.
            **kwargs (Any): forwarded to `PretrainedConfig.__init__`.
        """
        self.backbone = backbone
        self.pretrained = pretrained
        self.target_feature_sizes = target_feature_sizes
        self.translator_type = translator_type
        self.translator_hidden_size_factor = translator_hidden_size_factor
        self.target_loss_weights = target_loss_weights
        self.feature_reduce_method = feature_reduce_method
        self.feature_neck = feature_neck
        self.feature_neck_hidden_dim = feature_neck_hidden_dim
        self.forward_neck = forward_neck
        self.feature_neck_nonlinearity = feature_neck_nonlinearity
        # Fixed: this used to be hard-coded to 224, silently ignoring `iamge_size`.
        # A correctly spelled `image_size` keyword (e.g. from a saved config.json) wins.
        self.image_size = kwargs.pop("image_size", iamge_size)
        self.num_reg_tokens = num_reg_tokens
        super().__init__(**kwargs)
class TheiaModel(PreTrainedModel):
    """Vision backbone with optional feature translators and an optional compact feature neck.

    Attributes:
        backbone (nn.Module): the encoder built by `build_backbone`.
        translator (Optional[FeatureTranslator]): maps backbone features to target model
            features; None when the config has no `target_feature_sizes`.
        neck (Optional[nn.Module]): convolutional neck producing a compact vector;
            None when `config.feature_neck` is False.
        forward_neck (bool): whether `forward()` returns the neck output instead of translated features.
    """

    config_class = TheiaConfig

    def __init__(self, config: TheiaConfig):
        """Build backbone, translator, neck and loss modules from `config`.

        Args:
            config (TheiaConfig): model configuration.
        """
        super().__init__(config)

        self.target_feature_sizes = config.target_feature_sizes
        self.preprocessor = None
        self.pretrained = config.pretrained

        # backbone
        self.image_size = config.image_size
        backbone_kwargs: dict[str, Any] = {"image_size": config.image_size}
        if "reg" in config.backbone:
            # Only backbones whose name contains "reg" receive the register-token count.
            backbone_kwargs["num_reg_tokens"] = config.num_reg_tokens
        self.backbone: nn.Module = build_backbone(config.backbone, config.pretrained, **backbone_kwargs)

        # handle output feature (feature reduce)
        self.feature_reduce_method = config.feature_reduce_method
        # Presence of the attribute (not its value) marks a backbone without CLS token.
        self.no_cls = hasattr(self.backbone, "no_cls")
        self.num_reg_tokens = getattr(self.backbone, "num_reg_tokens", 0)

        # translator
        backbone_feature_size = self.backbone.get_feature_size(keep_spatial=True)
        if self.target_feature_sizes:
            translator_kwargs: dict[str, Any] = {
                "backbone_feature_size": backbone_feature_size,
                "target_feature_sizes": config.target_feature_sizes,
            }
            # Fixed: `hidden_size_factor` is accepted only by the "lconv" translator; passing it
            # unconditionally made every other translator type fail with a TypeError.
            if config.translator_type == "lconv":
                translator_kwargs["hidden_size_factor"] = config.translator_hidden_size_factor
            self.translator = build_feature_translator(config.translator_type, **translator_kwargs)
        else:
            self.translator = None

        self.feature_neck = config.feature_neck
        self.feature_neck_hidden_dim = config.feature_neck_hidden_dim
        self.forward_neck = config.forward_neck
        if self.feature_neck:
            backbone_config = self.backbone.model.config
            num_tokens_edge = backbone_config.image_size // backbone_config.patch_size
            neck_dim = self.feature_neck_hidden_dim
            use_relu = config.feature_neck_nonlinearity == "relu"

            def make_activation() -> nn.Module:
                # A fresh activation module per position keeps the Sequential layout unchanged.
                return nn.ReLU() if use_relu else nn.Tanh()

            self.neck = nn.Sequential(
                Rearrange("b (h w) c -> b c h w", h=num_tokens_edge, w=num_tokens_edge),
                nn.Conv2d(backbone_config.hidden_size, neck_dim, kernel_size=4, stride=2, padding=1),  # 14x14 -> 7x7
                make_activation(),  # just to keep the same as super class
                nn.Conv2d(neck_dim, neck_dim, kernel_size=3, stride=2),  # 7x7 -> 3x3
                make_activation(),
                nn.Conv2d(neck_dim, neck_dim, kernel_size=3, stride=1),  # 3x3 -> 1x1
                make_activation(),
                nn.Flatten(),
            )
        else:
            self.neck = None

        # loss
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.SmoothL1Loss()
        self.cos_loss = nn.CosineEmbeddingLoss()
        self.cos_target = torch.ones((1), dtype=torch.int, requires_grad=False)
        self.target_loss_weights = config.target_loss_weights

    def load_pretrained_weights(self, checkpoint_path: str) -> None:
        """
        Load weights from `checkpoint_path` manually.

        Keys of the checkpoint that do not exist in this model are ignored, and keys of this
        model that are missing in the checkpoint keep their current values. Does nothing when
        `checkpoint_path` is empty.

        Args:
            checkpoint_path (str): path to the weights.
        """
        if not checkpoint_path:
            return
        # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted sources.
        weights_dict = torch.load(checkpoint_path, map_location="cpu")
        # Collect this model's keys once instead of rebuilding the state dict for every checkpoint key.
        own_keys = set(self.state_dict())
        # Filter out unnecessary keys
        pretrained_dict = {k: v for k, v in weights_dict.items() if k in own_keys}
        self.load_state_dict(pretrained_dict, strict=False)

    def freeze_translator(self) -> None:
        """Freeze feature translators `self.translator`."""
        if self.translator is not None:
            self.translator.requires_grad_(False)

    def freeze_backbone(self) -> None:
        """Freeze backbone (encoder) `self.backbone`."""
        self.freeze_encoder()

    def freeze_encoder(self) -> None:
        """Freeze backbone (encoder) `self.backbone`."""
        self.backbone.requires_grad_(False)

    def freeze_neck(self) -> None:
        """Freeze feature neck `self.neck`."""
        if self.neck is not None:
            self.neck.requires_grad_(False)

    def freeze_everything(self) -> None:
        """Freeze all parameters in the model."""
        self.freeze_translator()
        self.freeze_neck()
        self.freeze_encoder()

    def unfreeze_translator(self) -> None:
        """Set parameters in feature translators `self.translator` trainable."""
        if self.translator is not None:
            self.translator.requires_grad_(True)

    def unfreeze_backbone(self) -> None:
        """Set parameters in backbone (encoder) `self.backbone` trainable."""
        self.unfreeze_encoder()

    def unfreeze_encoder(self) -> None:
        """Set parameters in backbone (encoder) `self.backbone` trainable."""
        self.backbone.requires_grad_(True)

    def unfreeze_neck(self) -> None:
        """Set parameters in feature neck `self.neck` trainable."""
        if self.neck is not None:
            self.neck.requires_grad_(True)

    def unfreeze_everything(self) -> None:
        """Set all parameters trainable."""
        self.unfreeze_translator()
        self.unfreeze_neck()
        self.unfreeze_encoder()

    def set_forward_neck(self, forward_neck: bool = True) -> None:
        """
        Set `self.forward_neck` to `forward_neck` value.

        Args:
            forward_neck (bool): whether forward the feature through the random initialized neck.
                If set to True, the output from `self.forward()` will be in shape
                [batch_size, self.config.feature_neck_hidden_dim]
        """
        self.forward_neck = forward_neck

    def forward_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Forward RVFM feature only (before translators).

        Args:
            x (torch.Tensor): input image. By default it accepts images
                in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8.
            kwargs (Any): kwargs including mainly those for huggingface preprocessor:
                `do_resize` (bool) defaults to True.
                `interpolate_pos_encoding` (Optional[bool]) defaults to None.
                `do_rescale` (bool) defaults to True.
                `do_normalize` (bool) defaults to True.

        Returns:
            torch.Tensor: RVFM feature.
        """
        feature = self.backbone(x, **kwargs)
        # [B, 1+H*W+N, C] if including both CLS and register tokens.
        # [B, 1+H*W, C] for standard model (N=0).
        # [B, H*W, C] for model without CLS.
        return handle_feature_output(feature, num_discard_tokens=self.num_reg_tokens)

    def forward(
        self, x: torch.Tensor, target_model_names: Optional[list[str]] = None, **kwargs: Any
    ) -> dict[str, torch.Tensor] | torch.Tensor:
        """Forward pass of Robot Vision Foundation Model.

        Args:
            x (torch.Tensor): input image. By default it accepts images
                in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8.
            target_model_names (Optional[list[str]]): names of the target foundation models.
            kwargs (Any): kwargs including mainly those for huggingface preprocessor:
                `do_resize` (bool) defaults to True.
                `interpolate_pos_encoding` (Optional[bool]) defaults to None.
                `do_rescale` (bool) defaults to True.
                `do_normalize` (bool) defaults to True.

        Returns:
            if `self.forward_neck`:
                torch.Tensor: compact vector feature passed through the neck. [B, C_neck]
            else:
                dict[str, torch.Tensor]: features that match to each foundation model.
                    Each feature is in [B, (H*W), C] or [B, C].
        """
        if self.forward_neck:
            # Fixed: preprocessing kwargs were silently dropped on the neck path.
            return self.neck(self.forward_feature(x, **kwargs))

        x = self.backbone(x, **kwargs)
        if self.num_reg_tokens > 0:
            # Register tokens sit at the end of the sequence; strip them.
            x = x[:, : -self.num_reg_tokens]  # [B, (1)+H*W, C]
        # each is [B, H*W, C] or [B, C]
        return self.translator(x, target_model_names, backbone_no_cls=self.no_cls)

    def get_loss(self, pred_features: dict[str, torch.Tensor], y: dict[str, torch.Tensor]) -> dict[str, Any]:
        """Get loss terms given predictions and targets.

        Args:
            pred_features (dict[str, torch.Tensor]): predictions.
            y (dict[str, torch.Tensor]): targets, keyed like `pred_features`.

        Returns:
            dict[str, Any]: weighted totals under `"mse_loss"`, `"cos_loss"` and `"l1_loss"`,
                plus per-target float values under `"mse_losses_per_model"`,
                `"cos_losses_per_model"` and `"l1_losses_per_model"`.

        Raises:
            KeyError: if `self.target_loss_weights` is a mapping without an entry for a predicted target.
        """
        mse_loss_avg, cos_loss_avg, l1_loss_avg = 0, 0, 0
        mse_losses_per_model = {}
        cos_losses_per_model = {}
        l1_losses_per_model = {}
        num_targets = len(pred_features)
        loss_weights = self.target_loss_weights

        for t, pred in pred_features.items():
            target = y[t]

            # Fixed: the whole weights mapping used to be multiplied with the loss tensor.
            # A plain number is still accepted and applied to every target.
            if not loss_weights:
                weight = 1.0 / num_targets
            elif isinstance(loss_weights, (int, float)):
                weight = loss_weights
            else:
                weight = loss_weights[t]

            # mse loss
            mse_loss = self.mse_loss(pred, target)

            # l1 loss
            l1_loss = self.l1_loss(pred, target)

            # cos loss on flattened, L2-normalized features; all-ones labels ask for similarity
            pred_norm = F.normalize(pred.flatten(start_dim=1), dim=1, p=2)
            target_norm = F.normalize(target.flatten(start_dim=1), dim=1, p=2)
            cos_labels = self.cos_target.repeat(pred.size(0)).to(pred.device)
            cos_loss = self.cos_loss(pred_norm, target_norm, cos_labels)

            mse_loss_avg += mse_loss * weight
            cos_loss_avg += cos_loss / num_targets  # balance cos by default for meaningful eval
            l1_loss_avg += l1_loss * weight

            mse_losses_per_model[t] = mse_loss.item()
            cos_losses_per_model[t] = cos_loss.item()
            l1_losses_per_model[t] = l1_loss.item()

        return {
            "mse_loss": mse_loss_avg,
            "cos_loss": cos_loss_avg,
            "l1_loss": l1_loss_avg,
            "mse_losses_per_model": mse_losses_per_model,
            "cos_losses_per_model": cos_losses_per_model,
            "l1_losses_per_model": l1_losses_per_model,
        }