# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Open-vocabulary CLIP alignment for inference.

Self-contained extract of the training repo's ``CLIPAlignmentEval``: encode class names
into a text-embedding matrix (optionally with prompt ensembling) and score per-query CLIP
features against it via cosine similarity.
"""

import logging
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

log = logging.getLogger(__name__)


class CLIPAlignmentEval(nn.Module):
    """Cosine-similarity classifier between per-query CLIP features and text embeddings.

    Args:
        normalize_input: L2-normalize the query features before the cosine product.
            For SpaceFormer set this to ``False`` — the clip head output is already
            compared directly (matches the official eval recipe).
    """

    def __init__(self, normalize_input: bool = False):
        super().__init__()
        self.normalize_input = normalize_input
        # Plain attribute (not a registered buffer): ``module.to(...)`` does not move it.
        self.emb_target: Optional[torch.Tensor] = None  # [C, D] L2-normalized

    def set_target_embedding(self, text_embeddings: torch.Tensor) -> None:
        """Store ``text_embeddings`` as the target matrix, cast to float32.

        No normalization is applied here; pass rows that are already L2-normalized
        if the scores are meant to be cosine similarities.
        """
        self.emb_target = text_embeddings.float()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x``, L2-normalized along its last (feature) axis if enabled.

        When ``normalize_input`` is ``False`` the input is returned unchanged.
        """
        if self.normalize_input:
            # Last axis rather than a hard-coded ``dim=1``: identical for [Q, D]
            # input, and also valid for a single [D] query or batched [..., Q, D].
            return F.normalize(x, p=2, dim=-1)
        return x

    def predict(self, x: torch.Tensor, return_logit: bool = False) -> torch.Tensor:
        """Score features ``x`` [Q, D] against the text embeddings -> [Q, C].

        Args:
            x: Query features whose last axis is the embedding dimension ``D``.
                Leading axes are preserved, so ``[D]`` and ``[..., Q, D]`` also work.
            return_logit: If ``True`` return the similarity scores (last axis of
                size ``C``); otherwise return the arg-max class index per query.

        Returns:
            The score tensor, or the integer class indices taken over the class axis.

        Raises:
            AssertionError: If no target embedding has been set yet.
        """
        assert self.emb_target is not None, "call prepare_target_embedding() first"
        pred = self.forward(x)
        # Align device as well as dtype: the target is not a buffer, so it may
        # still live on whatever device it was prepared on.
        target = self.emb_target.to(device=pred.device, dtype=pred.dtype)
        logit = torch.matmul(pred, target.t())
        if return_logit:
            return logit
        return logit.argmax(dim=-1)

    @staticmethod
    def _render_prompts(class_names: List[str], template: str) -> List[str]:
        """Render every class name through ``template``.

        Both the positional ``{}`` and the named ``{c}`` placeholder are accepted.
        Any name containing the substring ``"other"`` is replaced by the bare
        token ``"other"`` instead of being templated.
        """
        return [
            "other" if "other" in name else template.format(name, c=name)
            for name in class_names
        ]

    @torch.inference_mode()
    def prepare_target_embedding(
        self,
        class_names: List[str],
        clip_encoder: nn.Module,
        device: torch.device,
        use_prompt: bool = False,
        prompt_template: Optional[str] = None,
        prompt_templates: Optional[List[str]] = None,
    ) -> None:
        """Encode ``class_names`` into the [C, D] target matrix.

        Three mutually exclusive prompting modes (first non-empty wins):

        - ``prompt_templates``: prompt ensembling — render each class under every
          template, per-row L2-normalize, mean, re-normalize. This is the
          recommended eval-time free win.
        - ``prompt_template``: a single format string.
        - ``use_prompt``: the OpenScene default ``"a {} in a scene"``.

        Templates may use either the ``{}`` or the ``{c}`` placeholder in every
        mode. If no mode is selected the class names are encoded as given.

        The token ``"other"`` is always encoded bare (no template) so the
        background/void class stays neutral. The check is a substring match, so
        any class name containing ``"other"`` is encoded as ``"other"``.

        Args:
            class_names: Class names, one per row of the target matrix.
            clip_encoder: Text encoder, called as ``clip_encoder(texts, normalize=True)``
                and expected to return one embedding row per text.
            device: Device the resulting target matrix is moved to.
            use_prompt: Enable the default ``"a {} in a scene"`` template.
            prompt_template: Single template; ``None`` or ``""`` disables this mode.
            prompt_templates: Templates to ensemble; ``None`` or empty disables this mode.
        """
        log.info("Preparing CLIP target embedding for %d classes", len(class_names))
        if prompt_templates:
            log.info("Prompt ensembling over %d templates", len(prompt_templates))
            ensembled = None
            for template in prompt_templates:
                rendered = self._render_prompts(class_names, template)
                # Re-normalize in float32 so every template contributes a unit-norm row.
                emb = F.normalize(clip_encoder(rendered, normalize=True).float(), p=2, dim=-1)
                ensembled = emb if ensembled is None else ensembled + emb
            text_embedding = F.normalize(ensembled / float(len(prompt_templates)), p=2, dim=-1)
        elif prompt_template:
            rendered = self._render_prompts(class_names, prompt_template)
            text_embedding = clip_encoder(rendered, normalize=True)
        elif use_prompt:
            rendered = self._render_prompts(class_names, "a {} in a scene")
            text_embedding = clip_encoder(rendered, normalize=True)
        else:
            text_embedding = clip_encoder(class_names, normalize=True)
        self.set_target_embedding(text_embedding.to(device))