import torch
from sentence_transformers.base.modules import Transformer


# By default the processor emits float32 tensors (pixel_values, input_features),
# whereas this model's weights are bfloat16. PyTorch performs no implicit dtype
# casting, so this subclass converts every floating-point feature tensor to the
# model dtype ahead of the forward pass to avoid dtype mismatch errors.
class BidirLMOmniTransformer(Transformer):
    """Transformer module that aligns floating-point inputs with the model dtype."""

    def forward(self, features, **kwargs):
        """Cast floating-point feature tensors to the model dtype, then delegate.

        The target dtype is read from the first parameter yielded by
        ``self.model.parameters()``. Entries of ``features`` are replaced in
        place; anything that is not a floating-point ``torch.Tensor`` (or that
        already has the target dtype) is left untouched.

        Args:
            features: Mapping of feature names to values; tensor values with a
                floating-point dtype are cast in place.
            **kwargs: Forwarded unchanged to the parent ``forward``.

        Returns:
            Whatever the parent class's ``forward`` returns for the updated
            ``features``.
        """
        target_dtype = next(self.model.parameters()).dtype

        # Only existing keys are reassigned, so the mapping's size never
        # changes while it is being iterated.
        for name, tensor in features.items():
            if not isinstance(tensor, torch.Tensor):
                continue
            # Integer/bool tensors are skipped; only floating-point tensors
            # are cast.
            if not tensor.is_floating_point():
                continue
            if tensor.dtype == target_dtype:
                continue
            features[name] = tensor.to(dtype=target_dtype)

        return super().forward(features, **kwargs)