Sentence Similarity
sentence-transformers
Safetensors
Transformers
bidirlm_omni
mteb
embedding
bidirectional
custom_code
Instructions to use BidirLM/BidirLM-Omni-2.5B-Embedding with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use BidirLM/BidirLM-Omni-2.5B-Embedding with sentence-transformers:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("BidirLM/BidirLM-Omni-2.5B-Embedding", trust_remote_code=True)
sentences = [
    "The weather is lovely today.",
    "It's so sunny outside!",
    "He drove to the stadium."
]
embeddings = model.encode(sentences)
similarities = model.similarity(embeddings, embeddings)
print(similarities.shape)
# [3, 3]
- Transformers
How to use BidirLM/BidirLM-Omni-2.5B-Embedding with Transformers:
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("BidirLM/BidirLM-Omni-2.5B-Embedding", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| import math | |
| from typing import Callable, Optional | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers.activations import ACT2FN | |
| from transformers.integrations import use_kernel_forward_from_hub | |
| from transformers.modeling_layers import GradientCheckpointingLayer | |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | |
| from transformers.modeling_outputs import ( | |
| BaseModelOutput, | |
| MaskedLMOutput, | |
| SequenceClassifierOutput, | |
| TokenClassifierOutput, | |
| ) | |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS | |
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel | |
| from transformers.processing_utils import Unpack | |
| from transformers.utils import auto_docstring | |
| from transformers.utils.generic import TransformersKwargs | |
| try: | |
| from .configuration_bidirlm_omni import ( | |
| BidirLMOmniAudioConfig, | |
| BidirLMOmniConfig, | |
| BidirLMOmniTextConfig, | |
| BidirLMOmniVisionConfig, | |
| ) | |
| except ImportError: | |
| from configuration_bidirlm_omni import ( | |
| BidirLMOmniAudioConfig, | |
| BidirLMOmniConfig, | |
| BidirLMOmniTextConfig, | |
| BidirLMOmniVisionConfig, | |
| ) | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # Shared utilities | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
def rotate_half(x):
    """Rotate the last dimension by half: ``[a, b] -> [-b, a]``.

    The last axis is split at ``x.shape[-1] // 2`` into a leading part ``a``
    and a trailing part ``b``. The result is ``cat(-b, a)`` along that axis,
    which is the rotation step used by rotary position embeddings.
    """
    half = x.shape[-1] // 2
    leading, trailing = x[..., :half], x[..., half:]
    return torch.cat((-trailing, leading), dim=-1)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate every key/value head ``n_rep`` times (grouped-query attention).

    Maps a 4-D ``(batch, num_key_value_heads, seqlen, head_dim)`` tensor to
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. Copies of the
    same head end up adjacent, matching ``repeat_interleave`` along dim 1.
    With ``n_rep == 1`` the input object itself is returned. The input must
    still be 4-D, because the shape is unpacked first.
    """
    bsz, kv_heads, seqlen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    widened = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seqlen, head_dim)
    return widened.reshape(bsz, kv_heads * n_rep, seqlen, head_dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Plain-PyTorch attention; the fallback backend when ``_attn_implementation == "eager"``.

    Computes ``softmax(Q @ K^T * scaling + mask) @ V``. The softmax runs in
    float32 and is cast back to ``query.dtype``.

    Args:
        module: Owning attention layer. Its ``num_key_value_groups`` sets how
            many times each K/V head is repeated, and its ``training`` flag
            gates dropout.
        query, key, value: 4-D tensors (``repeat_kv`` requires 4-D K/V),
            presumably ``(batch, heads, seq, head_dim)`` — confirm with callers.
        attention_mask: Optional additive mask. Its last axis is cropped to the
            key length before being added to the scores.
        scaling: Multiplier applied to the raw dot-product scores.
        dropout: Dropout probability applied to the attention weights.
        **kwargs: Accepted and ignored, so callers can pass backend-specific
            options (e.g. ``cu_seq_lens_q``) uniformly.

    Returns:
        ``(attn_output, attn_weights)``. ``attn_output`` has axes 1 and 2
        swapped relative to the matmul result.
    """
    groups = module.num_key_value_groups
    keys = repeat_kv(key, groups)
    values = repeat_kv(value, groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return context, probs
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotate query and key tensors with precomputed rotary ``cos``/``sin`` tables.

    ``cos`` and ``sin`` gain a singleton axis at ``unsqueeze_dim`` so they
    broadcast against ``q``/``k`` (presumably over the head axis — confirm
    with callers). Each tensor ``t`` becomes ``t * cos + rotate_half(t) * sin``.

    Returns:
        ``(q_embed, k_embed)``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
| def _get_feat_extract_output_lengths(input_lengths): | |
| # Three Conv2d layers each with kernel=3, stride=2, padding=1. | |
| # Per-layer formula: floor((L - 1) / 2) + 1 | |
| L = (input_lengths - 1) // 2 + 1 | |
| L = (L - 1) // 2 + 1 | |
| L = (L - 1) // 2 + 1 | |
| return L | |
class BidirLMOmniRMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-channel scale.

    Normalises over the last axis in float32, casts the result back to the
    input dtype, then multiplies by ``weight`` (initialised to ones).
    """

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        normed = x32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normed.to(orig_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # Audio encoder (copied from BiQwen3-SP) | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
class BidirLMOmniAudioAttention(nn.Module):
    """Bidirectional (non-causal) multi-head self-attention for the audio encoder.

    Operates on a flattened, unbatched token sequence of shape
    ``(seq_len, d_model)``. Segment boundaries in ``cu_seqlens`` are forwarded to
    the attention backend as ``cu_seq_lens_q`` / ``cu_seq_lens_k``. The ``eager``
    backend defined in this file ignores those keywords and only honours
    ``attention_mask``.
    """

    def __init__(self, config: BidirLMOmniAudioConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.num_heads = config.encoder_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        # Plain multi-head attention: one K/V head per query head.
        self.num_key_value_groups = 1
        self.config = config
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = 0.0
        self.is_causal = False

        if (self.head_dim * self.num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got embed_dim={self.embed_dim}, num_heads={self.num_heads})."
            )
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and project back to ``d_model``.

        Args:
            hidden_states: 2-D ``(seq_len, d_model)`` tokens.
            cu_seqlens: Optional cumulative segment lengths ``[0, l0, l0 + l1, ...]``,
                forwarded to varlen-capable backends. May be ``None``.
            attention_mask: Optional additive mask passed straight to the backend.
            **kwargs: Extra keywords forwarded to the attention backend.

        Returns:
            ``(seq_len, d_model)`` output of ``out_proj``.
        """
        seq_length, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
        key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
        value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)

        # (seq, heads, head_dim) -> (1, heads, seq, head_dim), the 4-D layout
        # consumed by eager_attention_forward.
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        # cu_seqlens is optional, so derive the longest segment only when
        # boundaries are actually given. Indexing None would raise TypeError.
        max_seqlen = None
        if cu_seqlens is not None:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            cu_seq_lens_q=cu_seqlens,
            cu_seq_lens_k=cu_seqlens,
            max_length_q=max_seqlen,
            max_length_k=max_seqlen,
            is_causal=False,
            **kwargs,
        )

        # Backend output is (1, seq, heads, head_dim); fold heads back into d_model.
        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        return self.out_proj(attn_output)
class BidirLMOmniAudioEncoderLayer(GradientCheckpointingLayer):
    """Pre-LayerNorm transformer block for the audio encoder.

    Computes ``x + attn(LN(x))`` followed by ``x + fc2(act(fc1(LN(x))))``. In
    float16 the output is clamped just inside the dtype range to avoid inf.
    """

    def __init__(self, config: BidirLMOmniAudioConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = BidirLMOmniAudioAttention(config)
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        # NOTE(review): dropout / activation_dropout are stored but forward()
        # never applies them.
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        # Attention sub-block (pre-norm + residual).
        attn_out = self.self_attn(
            hidden_states=self.self_attn_layer_norm(hidden_states),
            cu_seqlens=cu_seqlens,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block (pre-norm + residual).
        ffn_in = self.final_layer_norm(hidden_states)
        ffn_out = self.fc2(self.activation_fn(self.fc1(ffn_in)))
        hidden_states = hidden_states + ffn_out

        if hidden_states.dtype == torch.float16:
            limit = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = hidden_states.clamp(min=-limit, max=limit)
        return (hidden_states,)
class SinusoidsPositionEmbedding(nn.Module):
    """Fixed sinusoidal position table ``[sin(t * w_i) | cos(t * w_i)]``.

    ``channels`` must be even. The first half of the channels holds sines and
    the second half cosines, over ``channels // 2`` inverse timescales spaced
    geometrically from 1 down to ``1 / max_timescale``.

    ``forward`` always recomputes the table in float32 instead of reading the
    registered buffer. Transformers 5 casts non-persistent buffers to the model
    dtype (e.g. bfloat16), which would degrade sinusoid precision. The buffer is
    kept so the module still tracks a device.
    """

    def __init__(self, length, channels, max_timescale=10000):
        super().__init__()
        if channels % 2 != 0:
            raise ValueError("SinusoidsPositionEmbedding needs even channels input")
        self.length = length
        self.channels = channels
        self.max_timescale = max_timescale

        half = channels // 2
        increment = np.log(max_timescale) / (half - 1)
        inv_timescales = torch.exp(-increment * torch.arange(half).float())
        scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
        table = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
        self.register_buffer("positional_embedding", table, persistent=False)

    def _recompute(self, seqlen: int, device) -> torch.Tensor:
        """Build a float32 ``(seqlen, channels)`` table on ``device``."""
        half = self.channels // 2
        increment = np.log(self.max_timescale) / (half - 1)
        inv_timescales = torch.exp(-increment * torch.arange(half, dtype=torch.float32, device=device))
        positions = torch.arange(seqlen, dtype=torch.float32, device=device)
        scaled_time = positions[:, None] * inv_timescales[None, :]
        return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)

    def forward(self, seqlen: int):
        return self._recompute(seqlen, self.positional_embedding.device)
class BidirLMOmniAudioEncoder(PreTrainedModel):
    """Audio encoder: chunked 3-layer Conv2d stem + bidirectional transformer.

    Pipeline visible in ``forward``:

    1. Split each clip's frames into chunks of ``2 * n_window`` frames (the
       last chunk holds the remainder).
    2. Run three stride-2 Conv2d + GELU layers on the padded chunks, in
       mini-batches of ``conv_chunksize``.
    3. Project the conv output with ``conv_out``, add float32-recomputed
       sinusoidal positions, and drop the padded positions.
    4. Run the encoder layers over the flattened token sequence.
    5. Apply ``ln_post -> proj1 -> act -> proj2``, giving ``output_dim`` features per token.
    """

    config: BidirLMOmniAudioConfig
    main_input_name = "input_features"
    _no_split_modules = ["BidirLMOmniAudioEncoderLayer"]
    _supports_sdpa = True

    def __init__(self, config: BidirLMOmniAudioConfig):
        super().__init__(config)
        # NOTE(review): self.dropout and self.embed_scale are set here but not
        # read anywhere in forward().
        self.dropout = config.dropout
        embed_dim = config.d_model
        self.num_mel_bins = config.num_mel_bins
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
        self.n_window = config.n_window
        self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim)
        self.layers = nn.ModuleList([BidirLMOmniAudioEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.ln_post = nn.LayerNorm(config.d_model)
        self.gradient_checkpointing = False
        # Three Conv2d(kernel=3, stride=2, padding=1) layers; each roughly
        # halves both the mel and the time axes.
        self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1)
        self.conv2d2 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1)
        self.conv2d3 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1)
        # Input width = channels * (mel bins after three halvings).
        self.conv_out = nn.Linear(
            config.downsample_hidden_size * ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2),
            config.d_model,
            bias=False,
        )
        self.proj1 = nn.Linear(config.d_model, config.d_model)
        self.act = ACT2FN[config.activation_function]
        self.proj2 = nn.Linear(config.d_model, config.output_dim)
        self.n_window_infer = self.config.n_window_infer
        self.conv_chunksize = self.config.conv_chunksize
        self.post_init()

    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> Optional[torch.Tensor]:
        """Build an additive block-diagonal mask from ``cu_seqlens``.

        Positions inside the same ``[cu_seqlens[i-1], cu_seqlens[i])`` segment get
        0; everything else gets the dtype minimum. Returns ``None`` for
        ``flash_attention_2``.

        NOTE(review): forward() below never calls this helper, so encoder layers
        receive attention_mask=None.
        """
        if self.config._attn_implementation == "flash_attention_2":
            return None
        seq_length = inputs_tensor.shape[0]
        attention_mask = torch.full(
            [1, 1, seq_length, seq_length],
            torch.finfo(inputs_tensor.dtype).min,
            device=inputs_tensor.device,
            dtype=inputs_tensor.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
        return attention_mask

    def forward(self, input_features, feature_lens=None, aftercnn_lens=None):
        """Encode a batch of audio clips into a flat token sequence.

        Args:
            input_features: 2-D features that are transposed and split along
                time, presumably ``(num_mel_bins, total_frames)`` with all clips
                concatenated along the frame axis — confirm with the caller.
            feature_lens: 1-D integer tensor of per-clip frame counts. Required
                despite the ``None`` default.
            aftercnn_lens: Ignored. It is recomputed from ``feature_lens`` on
                the first line.

        Returns:
            ``BaseModelOutput`` whose ``last_hidden_state`` is
            ``(total_tokens, output_dim)``.
        """
        aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)
        # Number of 2*n_window-frame chunks per clip; every chunk starts full-size.
        chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
        chunk_lengths = torch.tensor(
            [self.n_window * 2] * chunk_num.sum(),
            dtype=torch.long,
            device=feature_lens.device,
        )
        # Index of each clip's last chunk; it gets the remainder length (or a
        # full chunk when the clip length divides evenly).
        tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
        chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
        chunk_lengths[chunk_lengths == 0] = self.n_window * 2
        chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
        # (num_chunks, mel, max_chunk_len) after padding and transposing back.
        padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2)
        feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths)
        # True for real (non-padded) time steps of each chunk after the conv stem.
        padded_mask_after_cnn = nn.utils.rnn.pad_sequence(
            [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn],
            batch_first=True,
        )
        padded_feature = padded_feature.unsqueeze(1)  # add the single Conv2d input channel
        padded_embeds = []
        # Run the conv stem in mini-batches of conv_chunksize chunks.
        for chunk in padded_feature.split(self.conv_chunksize, dim=0):
            padded_embed = F.gelu(self.conv2d1(chunk))
            padded_embed = F.gelu(self.conv2d2(padded_embed))
            padded_embed = F.gelu(self.conv2d3(padded_embed))
            padded_embeds.append(padded_embed)
        padded_embed = torch.cat(padded_embeds, dim=0)
        b, c, f, t = padded_embed.size()
        # Flatten (channels, freq) per time step, then project to d_model.
        padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f))
        # Call forward() which recomputes sinusoids in float32 (avoids bfloat16 buffer precision loss).
        positional_embedding = (
            self.positional_embedding(padded_embed.shape[1])
            .unsqueeze(0)
            .to(padded_embed.dtype)
        )
        padded_embed = padded_embed + positional_embedding
        # Drop padded positions -> flat (total_tokens, d_model).
        hidden_states = padded_embed[padded_mask_after_cnn]
        # Split each clip's tokens into attention windows of window_aftercnn
        # tokens (n_window_infer frames' worth), plus a remainder window.
        cu_chunk_lens = [0]
        window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2))
        for cnn_len in aftercnn_lens:
            cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn)
            remainder = cnn_len % window_aftercnn
            if remainder != 0:
                cu_chunk_lens += [remainder]
        # NOTE(review): cu_seqlens totals sum(aftercnn_lens), computed from whole
        # clip lengths, while hidden_states has sum(feature_lens_after_cnn) rows
        # computed per chunk. The two can differ when 2*n_window is not a
        # multiple of 8 — verify against the config.
        cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32)
        # NOTE(review): no attention_mask is passed (_prepare_attention_mask is
        # unused). The eager backend ignores cu_seqlens, so on that path every
        # token attends across all windows and all clips in the batch.
        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(hidden_states, cu_seqlens)
            hidden_states = layer_outputs[0]
        hidden_states = self.ln_post(hidden_states)
        hidden_states = self.proj1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.proj2(hidden_states)
        return BaseModelOutput(last_hidden_state=hidden_states)
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # Vision encoder (copied from BiQwen3-VL) | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
class BidirLMOmniVisionMLP(nn.Module):
    """Two-layer feed-forward block: ``fc2(act(fc1(x)))`` with biases on both layers."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
        self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        expanded = self.linear_fc1(hidden_state)
        activated = self.act_fn(expanded)
        return self.linear_fc2(activated)
class BidirLMOmniVisionPatchEmbed(nn.Module):
    """Embed flattened spatio-temporal patches with a stride==kernel Conv3d.

    Each input row is reshaped to one ``(in_channels, temporal_patch_size,
    patch_size, patch_size)`` patch and projected to ``hidden_size``. Because
    the stride equals the kernel size, this is one linear projection per patch.
    Inputs are cast to the conv weight dtype first.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size
        kernel = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel, stride=kernel, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        patches = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        patches = patches.to(dtype=self.proj.weight.dtype)
        return self.proj(patches).view(-1, self.embed_dim)
class BidirLMOmniVisionRotaryEmbedding(nn.Module):
    """Rotary frequency table for vision positions.

    ``forward(seqlen)`` returns the float32 outer product
    ``positions x inv_freq`` with shape ``(seqlen, dim // 2)`` (for even
    ``dim``). ``inv_freq`` is recomputed in float32 on every call. The
    registered buffer is only used for its device, since a model-wide dtype
    cast may have lowered its precision.
    """

    inv_freq: torch.Tensor

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.register_buffer("inv_freq", self._frequencies(dim, theta), persistent=False)

    @staticmethod
    def _frequencies(dim, theta, device=None):
        """Float32 inverse frequencies ``1 / theta ** (2i / dim)``."""
        exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim
        return 1.0 / (theta ** exponents)

    def forward(self, seqlen: int) -> torch.Tensor:
        device = self.inv_freq.device
        freqs = self._frequencies(self.dim, self.theta, device)
        positions = torch.arange(seqlen, dtype=torch.float32, device=device)
        return torch.outer(positions, freqs)
class BidirLMOmniVisionPatchMerger(nn.Module):
    """Merge ``spatial_merge_size ** 2`` neighbouring patch tokens and project them.

    Tokens are regrouped into rows of ``hidden_size * spatial_merge_size ** 2``
    features and then passed through ``fc2(GELU(fc1(.)))`` to
    ``out_hidden_size``. The LayerNorm is applied after the regrouping when
    ``use_postshuffle_norm`` is set, and per original token otherwise.
    """

    def __init__(self, config: BidirLMOmniVisionConfig, use_postshuffle_norm=False) -> None:
        super().__init__()
        merged_dim = config.hidden_size * (config.spatial_merge_size**2)
        self.hidden_size = merged_dim
        self.use_postshuffle_norm = use_postshuffle_norm
        norm_dim = merged_dim if use_postshuffle_norm else config.hidden_size
        self.norm = nn.LayerNorm(norm_dim, eps=1e-6)
        self.linear_fc1 = nn.Linear(merged_dim, merged_dim)
        self.act_fn = nn.GELU()
        self.linear_fc2 = nn.Linear(merged_dim, config.out_hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_postshuffle_norm:
            normed = self.norm(x.view(-1, self.hidden_size))
        else:
            normed = self.norm(x)
        grouped = normed.view(-1, self.hidden_size)
        return self.linear_fc2(self.act_fn(self.linear_fc1(grouped)))
def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings to vision queries/keys in float32.

    ``cos``/``sin`` gain a singleton axis at position -2 (broadcasting over the
    heads axis of ``(tokens, heads, head_dim)`` inputs, as used by the vision
    attention). The rotation ``t * cos + rotate_half(t) * sin`` is computed in
    float32, and each result is cast back to its own input dtype.
    """
    cos32 = cos.unsqueeze(-2).float()
    sin32 = sin.unsqueeze(-2).float()

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        t32 = t.float()
        return ((t32 * cos32) + (rotate_half(t32) * sin32)).to(t.dtype)

    return _rotate(q), _rotate(k)
class BidirLMOmniVisionAttention(nn.Module):
    """Non-causal multi-head self-attention over packed vision tokens with 2-D RoPE.

    Tokens from all images/frames are concatenated along dim 0, and
    ``cu_seqlens`` marks the segment boundaries. ``flash_attention_2`` receives
    the boundaries through its varlen keywords. Every other backend is called
    once per segment, so tokens never attend across segments.
    """

    def __init__(self, config: BidirLMOmniVisionConfig) -> None:
        super().__init__()
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.num_key_value_groups = 1  # no grouped-query attention
        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
        self.proj = nn.Linear(self.dim, self.dim)
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = 0.0
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        num_tokens = hidden_states.shape[0]
        # Fused projection -> (3, tokens, heads, head_dim) -> three (tokens, heads, head_dim) tensors.
        packed = self.qkv(hidden_states).reshape(num_tokens, 3, self.num_heads, -1)
        q, k, v = packed.permute(1, 0, 2, 3).unbind(0)

        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

        # (tokens, heads, head_dim) -> (1, heads, tokens, head_dim)
        q, k, v = (t.transpose(0, 1).unsqueeze(0) for t in (q, k, v))

        backend = self.config._attn_implementation
        attention_interface: Callable = (
            eager_attention_forward if backend == "eager" else ALL_ATTENTION_FUNCTIONS[backend]
        )
        dropout_p = self.attention_dropout if self.training else 0.0
        segment_lengths = cu_seqlens[1:] - cu_seqlens[:-1]

        if backend == "flash_attention_2":
            longest = segment_lengths.max()
            attn_output, _ = attention_interface(
                self,
                q,
                k,
                v,
                attention_mask=None,
                scaling=self.scaling,
                dropout=dropout_p,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=longest,
                max_length_k=longest,
                is_causal=self.is_causal,
                **kwargs,
            )
        else:
            sizes = segment_lengths.tolist()
            per_segment = []
            for q_seg, k_seg, v_seg in zip(q.split(sizes, dim=2), k.split(sizes, dim=2), v.split(sizes, dim=2)):
                seg_out = attention_interface(
                    self,
                    q_seg,
                    k_seg,
                    v_seg,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=dropout_p,
                    is_causal=self.is_causal,
                    **kwargs,
                )[0]
                per_segment.append(seg_out)
            # Segment outputs are token-major on dim 1 (eager returns
            # (1, seg_len, heads, head_dim)); re-join them in order.
            attn_output = torch.cat(per_segment, dim=1)

        return self.proj(attn_output.reshape(num_tokens, -1).contiguous())
class BidirLMOmniVisionBlock(GradientCheckpointingLayer):
    """Pre-LayerNorm vision transformer block.

    Computes ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``.
    ``attn_implementation`` is accepted for signature compatibility but unused;
    the attention backend is read from ``config._attn_implementation`` inside
    the attention module.
    """

    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.attn = BidirLMOmniVisionAttention(config=config)
        self.mlp = BidirLMOmniVisionMLP(config=config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out
        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out
class BidirLMOmniVisionModel(PreTrainedModel):
    """Vision encoder: patch embedding + interpolated learned positions + 2-D RoPE
    transformer blocks, with DeepStack taps and a final patch merger.

    ``forward`` returns the merged final features together with one merged
    feature tensor per layer index listed in ``config.deepstack_visual_indexes``.
    """

    config: BidirLMOmniVisionConfig
    _no_split_modules = ["BidirLMOmniVisionBlock"]

    def __init__(self, config, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
        self.patch_embed = BidirLMOmniVisionPatchEmbed(config=config)
        # Learned absolute positions on a square num_grid_per_side x num_grid_per_side
        # grid, bilinearly resampled to each image's size in fast_pos_embed_interpolate.
        self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
        self.num_grid_per_side = int(config.num_position_embeddings**0.5)
        head_dim = config.hidden_size // config.num_heads
        # Half of head_dim per spatial axis (row, col); see rot_pos_emb.
        self.rotary_pos_emb = BidirLMOmniVisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList([BidirLMOmniVisionBlock(config) for _ in range(config.depth)])
        self.merger = BidirLMOmniVisionPatchMerger(config=config, use_postshuffle_norm=False)
        self.deepstack_visual_indexes = config.deepstack_visual_indexes
        # One post-shuffle-norm merger per DeepStack tap, aligned with deepstack_visual_indexes.
        self.deepstack_merger_list = nn.ModuleList(
            [
                BidirLMOmniVisionPatchMerger(config=config, use_postshuffle_norm=True)
                for _ in range(len(config.deepstack_visual_indexes))
            ]
        )
        self.gradient_checkpointing = False

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Per-token 2-D rotary angles for every image/video in ``grid_thw``.

        ``grid_thw`` rows are ``(num_frames, height, width)``, presumably in
        patch units — confirm with the processor. Tokens are enumerated in
        merge-block order: block row, block column, then row and column inside
        each ``merge_size x merge_size`` block, repeated for every frame. The
        row and column angles are looked up separately and concatenated, giving
        ``(total_tokens, 2 * freq_dim)``.
        """
        merge_size = self.spatial_merge_size
        max_hw = int(grid_thw[:, 1:].max().item())
        freq_table = self.rotary_pos_emb(max_hw)
        device = freq_table.device
        total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
        pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
        offset = 0
        for num_frames, height, width in grid_thw:
            merged_h, merged_w = height // merge_size, width // merge_size
            block_rows = torch.arange(merged_h, device=device)
            block_cols = torch.arange(merged_w, device=device)
            intra_row = torch.arange(merge_size, device=device)
            intra_col = torch.arange(merge_size, device=device)
            # Absolute row/col of each patch, laid out as (block_r, block_c, in_r, in_c).
            row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
            col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
            row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
            col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
            coords = torch.stack((row_idx, col_idx), dim=-1)
            if num_frames > 1:
                # Every frame reuses the same spatial coordinates.
                coords = coords.repeat(num_frames, 1)
            num_tokens = coords.shape[0]
            pos_ids[offset : offset + num_tokens] = coords
            offset += num_tokens
        # (tokens, 2, freq_dim) -> (tokens, 2 * freq_dim)
        embeddings = freq_table[pos_ids]
        return embeddings.flatten(1)

    def fast_pos_embed_interpolate(self, grid_thw):
        """Bilinearly resample the learned position grid to each image's ``h x w``.

        For every image, the four neighbouring grid cells and their bilinear
        weights are gathered and blended. The result is tiled over the ``t``
        frames and permuted into the same merge-block token order that
        ``rot_pos_emb`` uses.

        Returns:
            ``(total_tokens, hidden_size)`` position embeddings in
            ``pos_embed``'s dtype.
        """
        grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
        # Four corners: (floor_h, floor_w), (floor_h, ceil_w), (ceil_h, floor_w), (ceil_h, ceil_w).
        idx_list = [[] for _ in range(4)]
        weight_list = [[] for _ in range(4)]
        for t, h, w in zip(grid_ts, grid_hs, grid_ws):
            h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
            w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
            h_idxs_floor = h_idxs.int()
            w_idxs_floor = w_idxs.int()
            h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            dh = h_idxs - h_idxs_floor
            dw = w_idxs - w_idxs_floor
            base_h = h_idxs_floor * self.num_grid_per_side
            base_h_ceil = h_idxs_ceil * self.num_grid_per_side
            indices = [
                (base_h[None].T + w_idxs_floor[None]).flatten(),
                (base_h[None].T + w_idxs_ceil[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
            ]
            weights = [
                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
                ((1 - dh)[None].T * dw[None]).flatten(),
                (dh[None].T * (1 - dw)[None]).flatten(),
                (dh[None].T * dw[None]).flatten(),
            ]
            for i in range(4):
                idx_list[i].extend(indices[i].tolist())
                weight_list[i].extend(weights[i].tolist())
        idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
        weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device)
        # (4, total_hw, hidden) weighted corners, summed into the bilinear blend.
        pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
        patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
        merge_size = self.config.spatial_merge_size
        patch_pos_embeds_permute = []
        for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
            pos_embed = pos_embed.repeat(t, 1)
            # Row-major (t, h, w) -> merge-block order (t, block_r, block_c, in_r, in_c).
            pos_embed = (
                pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
                .permute(0, 1, 3, 2, 4, 5)
                .flatten(0, 4)
            )
            patch_pos_embeds_permute.append(pos_embed)
        return torch.cat(patch_pos_embeds_permute)

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Encode packed pixel patches.

        Args:
            hidden_states: Flattened patches accepted by ``patch_embed``.
            grid_thw: One ``(t, h, w)`` row per image/video.
            **kwargs: Forwarded to every block's attention.

        Returns:
            ``(merged_features, deepstack_features)``. The second element is a
            list with one merged tensor per entry of ``deepstack_visual_indexes``,
            in block order.
        """
        hidden_states = self.patch_embed(hidden_states)
        hidden_states = hidden_states + self.fast_pos_embed_interpolate(grid_thw)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        seq_len = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())
        # One attention segment per frame: h*w tokens, repeated t times per item.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        deepstack_feature_lists = []
        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            if layer_num in self.deepstack_visual_indexes:
                deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](
                    hidden_states
                )
                deepstack_feature_lists.append(deepstack_feature)
        hidden_states = self.merger(hidden_states)
        return hidden_states, deepstack_feature_lists
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # Shared text encoder (supports both audio injection + DeepStack visual) | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
class BidirLMOmniTextRotaryEmbedding(nn.Module):
    """Interleaved multimodal RoPE (M-RoPE) tables for the text encoder.

    ``forward`` takes 3-axis position ids (2-D ids are broadcast to all three
    axes) and returns ``(cos, sin)`` of shape ``(batch, seq, 2 * num_freqs)``
    in ``x``'s dtype. Frequencies for axes 1 and 2 are interleaved into the
    axis-0 table at stride 3, following ``mrope_section``.
    """

    inv_freq: torch.Tensor

    def __init__(self, config: BidirLMOmniTextConfig, device=None):
        super().__init__()
        # Read rope_scaling once, tolerating configs that do not define it, so
        # the rope_type lookup and the mrope_section lookup below agree.
        rope_scaling = getattr(config, "rope_scaling", None)
        self.rope_type = rope_scaling.get("rope_type", "default") if rope_scaling is not None else "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config

        # transformers 5.x removed "default" from ROPE_INIT_FUNCTIONS, so the
        # unscaled case is computed locally.
        # NOTE(review): a rope_type unknown to this transformers build also
        # falls back silently to unscaled RoPE.
        if self.rope_type == "default" or self.rope_type not in ROPE_INIT_FUNCTIONS:
            inv_freq, self.attention_scaling = self.compute_default_rope_parameters(config, device)
        else:
            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)

        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq
        self.mrope_section = (rope_scaling or {}).get("mrope_section", [24, 20, 20])

    def compute_default_rope_parameters(self, config=None, device=None):
        """Unscaled RoPE parameters; required by transformers 5.x ``_init_weights``.

        Args:
            config: Config to read ``head_dim`` / ``rope_theta`` from. Defaults to
                ``self.config``.
            device: Optional device for the returned tensor.

        Returns:
            ``(inv_freq, 1.0)`` where ``inv_freq = 1 / rope_theta ** (2i / head_dim)``
            in float32.
        """
        config = config if config is not None else self.config
        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        rope_theta = getattr(config, "rope_theta", 10000.0)
        inv_freq = 1.0 / (
            rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim)
        )
        return inv_freq, 1.0

    def apply_interleaved_mrope(self, freqs, mrope_section):
        """Interleave axis-1/axis-2 frequencies into the axis-0 table.

        For ``dim`` in (1, 2), the frequency columns ``dim, dim + 3, ...`` below
        ``mrope_section[dim] * 3`` are taken from ``freqs[dim]``. The write goes
        into a view of ``freqs[0]``, which is returned, so ``freqs`` itself is
        modified.
        """
        freqs_t = freqs[0]
        for dim, offset in enumerate((1, 2), start=1):
            length = mrope_section[dim] * 3
            idx = slice(offset, length, 3)
            freqs_t[..., idx] = freqs[dim, ..., idx]
        return freqs_t

    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to ``x.dtype``.

        ``position_ids`` is either ``(3, batch, seq)`` or ``(batch, seq)``; the
        2-D form is broadcast to all three axes. Frequencies are computed in
        float32 with autocast disabled.
        """
        if position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            # (3, batch, num_freqs, seq) -> (3, batch, seq, num_freqs)
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class BidirLMOmniTextAttention(nn.Module):
    """Bidirectional multi-head attention (no causal mask, no KV cache).

    Queries and keys are RMS-normalised per head before rotary position
    embeddings are applied; the kernel is chosen from
    ``config._attn_implementation``.
    """

    def __init__(self, config: BidirLMOmniTextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.head_dim = head_dim
        n_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads
        self.num_key_value_groups = n_heads // n_kv_heads
        self.scaling = head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False

        d_model = config.hidden_size
        use_bias = config.attention_bias
        self.q_proj = nn.Linear(d_model, n_heads * head_dim, bias=use_bias)
        self.k_proj = nn.Linear(d_model, n_kv_heads * head_dim, bias=use_bias)
        self.v_proj = nn.Linear(d_model, n_kv_heads * head_dim, bias=use_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, d_model, bias=use_bias)
        self.q_norm = BidirLMOmniRMSNorm(head_dim, eps=config.rms_norm_eps)
        self.k_norm = BidirLMOmniRMSNorm(head_dim, eps=config.rms_norm_eps)
        self.sliding_window = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        input_shape = hidden_states.shape[:-1]
        per_head_shape = (*input_shape, -1, self.head_dim)

        def split_heads(projection, norm=None):
            # (..., seq, heads*head_dim) -> (..., heads, seq, head_dim)
            states = projection(hidden_states).view(per_head_shape)
            if norm is not None:
                states = norm(states)
            return states.transpose(1, 2)

        query_states = split_heads(self.q_proj, self.q_norm)
        key_states = split_heads(self.k_proj, self.k_norm)
        value_states = split_heads(self.v_proj)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        backend = self.config._attn_implementation
        attention_interface: Callable = (
            eager_attention_forward if backend == "eager" else ALL_ATTENTION_FUNCTIONS[backend]
        )
        dropout_p = self.attention_dropout if self.training else 0.0
        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=dropout_p,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )
        merged = attn_output.reshape(*input_shape, -1).contiguous()
        return self.o_proj(merged)
class BidirLMOmniTextMLP(nn.Module):
    """Gated feed-forward block computing ``down(act(gate(x)) * up(x))``."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        d_model, d_ff = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
class BidirLMOmniTextEncoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer block: self-attention then MLP, each wrapped in a residual add."""

    def __init__(self, config: BidirLMOmniTextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = BidirLMOmniTextAttention(config=config, layer_idx=layer_idx)
        self.mlp = BidirLMOmniTextMLP(config)
        self.input_layernorm = BidirLMOmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = BidirLMOmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        after_attn = hidden_states + attn_out
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # PreTrainedModel base | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
class BidirLMOmniPreTrainedModel(PreTrainedModel):
    """Shared base for all BidirLM-Omni models: config class, weight prefix and backend flags."""

    config: BidirLMOmniConfig
    base_model_prefix = "model"

    # Capability flags consumed by transformers.
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_attention_backend = True

    # Modules that must not be split across devices when loading with a device_map.
    _no_split_modules = [
        "BidirLMOmniTextEncoderLayer",
        "BidirLMOmniAudioEncoderLayer",
        "BidirLMOmniVisionBlock",
    ]
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # Text encoder model (with DeepStack visual injection support) | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
class BidirLMOmniTextModel(BidirLMOmniPreTrainedModel):
    """
    Bidirectional text encoder. Supports:
    - ``inputs_embeds`` that already contain audio / vision features (the omni
      wrapper scatters them into the placeholder positions with ``masked_scatter``)
    - DeepStack visual feature injection at intermediate layers
    """

    config: BidirLMOmniTextConfig
    _no_split_modules = ["BidirLMOmniTextEncoderLayer"]

    def __init__(self, config: BidirLMOmniTextConfig):
        super().__init__(config)
        self.padding_idx = getattr(config, "pad_token_id", None)
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [BidirLMOmniTextEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = BidirLMOmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = BidirLMOmniTextRotaryEmbedding(config)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _uses_flash_attention(self) -> bool:
        """True for any FlashAttention backend (``flash_attention_2``, ``flash_attention_3``, ...)."""
        return str(self.config._attn_implementation).startswith("flash_attention")

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        # DeepStack visual injection args
        visual_pos_masks: Optional[torch.Tensor] = None,
        deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        """Encode tokens (or precomputed embeddings) bidirectionally.

        Args:
            input_ids / inputs_embeds: exactly one must be given.
            attention_mask: optional 1/0 padding mask, assumed ``(batch, seq)``
                (TODO confirm for callers passing other shapes).
            position_ids: ``(batch, seq)`` or ``(3, batch, seq)`` MRoPE ids;
                defaults to ``0..seq-1`` on all three axes.
            visual_pos_masks: boolean mask of visual-token positions; required
                whenever ``deepstack_visual_embeds`` is given.
            deepstack_visual_embeds: entry ``i`` is added to the visual positions
                after layer ``i``.

        Returns:
            ``BaseModelOutput`` with only ``last_hidden_state`` set.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds.")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        batch_size, seq_len = inputs_embeds.shape[:2]
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=inputs_embeds.device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
        if position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, batch_size, -1)

        extended_attention_mask: Optional[torch.Tensor] = None
        if attention_mask is not None:
            if self._uses_flash_attention():
                # FlashAttention computes cu_seqlens from a 2-D mask internally;
                # passing a 4-D mask breaks the varlen path (for FA2 and FA3 alike).
                extended_attention_mask = attention_mask
            else:
                # Convert the 1/0 mask to an additive float mask (0.0 = attend,
                # dtype-min = ignore) so padding tokens are excluded from softmax.
                float_mask = attention_mask.to(dtype=inputs_embeds.dtype)
                extended_attention_mask = (
                    (1.0 - float_mask)[:, None, None, :] * torch.finfo(inputs_embeds.dtype).min
                )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        for layer_idx, encoder_layer in enumerate(self.layers):
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=extended_attention_mask,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            # DeepStack: add visual features at the first len(deepstack_visual_embeds) layers.
            if deepstack_visual_embeds is not None and layer_idx < len(deepstack_visual_embeds):
                hidden_states = self._deepstack_process(
                    hidden_states,
                    visual_pos_masks,
                    deepstack_visual_embeds[layer_idx],
                )
        hidden_states = self.norm(hidden_states)
        return BaseModelOutput(last_hidden_state=hidden_states)

    def _deepstack_process(
        self,
        hidden_states: torch.Tensor,
        visual_pos_masks: torch.Tensor,
        visual_embeds: torch.Tensor,
    ) -> torch.Tensor:
        """Add ``visual_embeds`` to the rows of ``hidden_states`` selected by ``visual_pos_masks``.

        ``hidden_states`` is updated in place (indexed assignment) and also returned.
        """
        visual_pos_masks = visual_pos_masks.to(hidden_states.device)
        visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
        hidden_states[visual_pos_masks, :] = hidden_states[visual_pos_masks, :].clone() + visual_embeds
        return hidden_states
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # Top-level Omni model | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
class BidirLMOmniModel(BidirLMOmniPreTrainedModel):
    """
    Audio + Vision + Text omni encoder.
    Accepts any combination of modalities: text-only, text+audio, text+vision, text+audio+vision.

    Audio features and final vision features replace their placeholder token
    embeddings via ``masked_scatter``; intermediate vision features (DeepStack)
    are added inside the text encoder.
    """

    config: BidirLMOmniConfig

    def __init__(self, config: BidirLMOmniConfig):
        super().__init__(config)
        # Flash/SDPA attention only applies to the text encoder;
        # the audio and vision towers always run eager.
        config.audio_config._attn_implementation = "eager"
        config.vision_config._attn_implementation = "eager"
        config.text_config._attn_implementation = config._attn_implementation
        self.audio_tower = BidirLMOmniAudioEncoder._from_config(config.audio_config)
        self.visual = BidirLMOmniVisionModel._from_config(config.vision_config)
        self.language_model = BidirLMOmniTextModel._from_config(config.text_config)
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    # ── Audio helpers ──────────────────────────────────────────────────

    def get_audio_features(
        self,
        input_features: torch.FloatTensor,
        feature_attention_mask: Optional[torch.LongTensor] = None,
        audio_feature_lengths: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Run the audio tower on each clip separately and concatenate the outputs along dim 0.

        Args:
            input_features: per-clip features; each clip is truncated as
                ``clip[:, :length]``, so the frame axis is assumed to be the last
                one (TODO confirm against the feature extractor).
            feature_attention_mask: per-frame validity mask; its row sums give the
                clip lengths and take precedence over ``audio_feature_lengths``.
            audio_feature_lengths: explicit clip lengths, used when no mask is given.

        If neither a mask nor lengths are given, every clip uses all of its frames.
        """
        if feature_attention_mask is not None:
            feature_lens = torch.sum(feature_attention_mask, dim=1)
        elif audio_feature_lengths is not None:
            feature_lens = audio_feature_lengths
        else:
            # No length information: assume each clip spans the full frame axis.
            feature_lens = torch.full(
                (input_features.shape[0],),
                input_features.shape[-1],
                dtype=torch.long,
                device=input_features.device,
            )
        audio_features = []
        for input_feature, feature_len in zip(input_features, feature_lens):
            audio_output = self.audio_tower(
                input_feature[:, :feature_len],
                feature_lens=feature_len.unsqueeze(0),
            )
            audio_features.append(audio_output.last_hidden_state)
        return torch.cat(audio_features, dim=0)

    def get_audio_placeholder_mask(
        self,
        input_ids: Optional[torch.LongTensor],
        inputs_embeds: torch.FloatTensor,
    ) -> torch.Tensor:
        """Boolean mask, expanded to ``inputs_embeds``' shape, marking audio placeholder tokens.

        Without ``input_ids`` the placeholders are found by exact equality with the
        audio token's embedding.
        """
        if input_ids is None:
            special_audio_mask = (
                inputs_embeds
                == self.get_input_embeddings()(
                    torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
            ).all(-1)
        else:
            special_audio_mask = input_ids == self.config.audio_token_id
        return special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

    # ── Vision helpers ─────────────────────────────────────────────────

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        """Encode images; returns per-image embeddings (split by grid size) and DeepStack features."""
        pixel_values = pixel_values.type(self.visual.dtype)
        image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        # Each image yields t*h*w / merge_size**2 tokens after spatial merging.
        split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        image_embeds = torch.split(image_embeds, split_sizes)
        return image_embeds, deepstack_image_embeds

    def get_video_features(
        self,
        pixel_values_videos: torch.FloatTensor,
        video_grid_thw: Optional[torch.LongTensor] = None,
    ):
        """Encode videos with the same path as images."""
        return self.get_image_features(pixel_values_videos, video_grid_thw)

    def get_vision_placeholder_mask(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor,
        image_features: Optional[torch.FloatTensor] = None,
        video_features: Optional[torch.FloatTensor] = None,
    ):
        """Return expanded boolean masks of image and video placeholder positions.

        Raises:
            ValueError: if given features do not fill exactly the placeholder slots.
        """
        if input_ids is None:
            img_embed = self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            vid_embed = self.get_input_embeddings()(
                torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = (inputs_embeds == img_embed).all(-1)
            special_video_mask = (inputs_embeds == vid_embed).all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id
            special_video_mask = input_ids == self.config.video_token_id
        special_image_mask_expanded = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        special_video_mask_expanded = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if image_features is not None and inputs_embeds[special_image_mask_expanded].numel() != image_features.numel():
            n_image_tokens = special_image_mask.sum()
            raise ValueError(
                f"Image features and image tokens do not match: tokens {n_image_tokens}, features {image_features.shape[0]}"
            )
        if video_features is not None and inputs_embeds[special_video_mask_expanded].numel() != video_features.numel():
            n_video_tokens = special_video_mask.sum()
            raise ValueError(
                f"Video features and video tokens do not match: tokens {n_video_tokens}, features {video_features.shape[0]}"
            )
        return special_image_mask_expanded, special_video_mask_expanded

    # ── MRoPE position ids ─────────────────────────────────────────────

    def get_rope_index(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Build 3-D MRoPE position ids. Returns (3, batch, seq_len).

        With ``input_ids`` and a grid, text runs advance all three axes together
        while each vision block gets (t, h, w) grid coordinates offset past the
        preceding text. Otherwise positions are a plain count over non-padding
        tokens (padding gets 1). Requires ``input_ids`` or ``attention_mask``.
        """
        if video_grid_thw is not None:
            # One grid row per frame (t set to 1), matching per-frame vision blocks.
            video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
            video_grid_thw[:, 0] = 1
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id

        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            position_ids = torch.ones(
                3, input_ids.shape[0], input_ids.shape[1],
                dtype=input_ids.dtype, device=input_ids.device,
            )
            image_index, video_index = 0, 0
            for i, ids in enumerate(total_input_ids):
                ids = ids[attention_mask[i] == 1]
                vision_start_indices = torch.argwhere(ids == vision_start_token_id).squeeze(1)
                vision_tokens = ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                for _ in range(image_nums + video_nums):
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        t, h, w = image_grid_thw[image_index]
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image
                    else:
                        t, h, w = video_grid_thw[video_index]
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    llm_grid_t = t.item()
                    llm_grid_h = h.item() // spatial_merge_size
                    llm_grid_w = w.item() // spatial_merge_size
                    text_len = ed - st
                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
                if st < len(input_tokens):
                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                    llm_pos_ids_list.append(torch.arange(len(input_tokens) - st).view(1, -1).expand(3, -1) + st_idx)
                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
            return position_ids

        # Text-only / audio-only path (no spatial position structure)
        if attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
        else:
            position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand(input_ids.shape[0], -1)
        return position_ids.unsqueeze(0).expand(3, -1, -1)

    # ── Forward ────────────────────────────────────────────────────────

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        # Audio inputs
        input_features: Optional[torch.FloatTensor] = None,
        feature_attention_mask: Optional[torch.LongTensor] = None,
        audio_feature_lengths: Optional[torch.LongTensor] = None,
        # Vision inputs
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        """Inject audio/vision features into the token embeddings and run the text encoder.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are
                given, or vision features do not match their placeholder tokens.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds.")
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # ── Audio injection ────────────────────────────────────────────
        if input_features is not None:
            audio_features = self.get_audio_features(
                input_features,
                feature_attention_mask=feature_attention_mask,
                audio_feature_lengths=audio_feature_lengths,
            ).to(inputs_embeds.device, inputs_embeds.dtype)
            audio_mask = self.get_audio_placeholder_mask(input_ids, inputs_embeds)
            inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)

        # ── Vision injection ───────────────────────────────────────────
        image_mask = video_mask = None
        if pixel_values is not None:
            image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw)
            image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            image_mask, _ = self.get_vision_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
        if pixel_values_videos is not None:
            video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
            video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            _, video_mask = self.get_vision_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds)
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        # ── Assemble DeepStack masks / embeds ──────────────────────────
        visual_pos_masks = deepstack_visual_embeds = None
        if image_mask is not None and video_mask is not None:
            image_token_mask = image_mask[..., 0]
            video_token_mask = video_mask[..., 0]
            visual_pos_masks = image_token_mask | video_token_mask
            # Positions of image / video tokens within the joint visual sequence.
            image_mask_joint = image_token_mask[visual_pos_masks]
            video_mask_joint = video_token_mask[visual_pos_masks]
            deepstack_visual_embeds = []
            for img_e, vid_e in zip(deepstack_image_embeds, deepstack_video_embeds):
                joint = img_e.new_zeros(visual_pos_masks.sum(), img_e.shape[-1]).to(img_e.device)
                joint[image_mask_joint] = img_e
                joint[video_mask_joint] = vid_e
                deepstack_visual_embeds.append(joint)
        elif image_mask is not None:
            visual_pos_masks = image_mask[..., 0]
            deepstack_visual_embeds = deepstack_image_embeds
        elif video_mask is not None:
            visual_pos_masks = video_mask[..., 0]
            deepstack_visual_embeds = deepstack_video_embeds

        # ── Build position ids ─────────────────────────────────────────
        # get_rope_index needs input_ids or a mask to size its output; with only
        # inputs_embeds, leave position_ids None so the text encoder uses 0..seq-1.
        if position_ids is None and (input_ids is not None or attention_mask is not None):
            position_ids = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)

        return self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            visual_pos_masks=visual_pos_masks,
            deepstack_visual_embeds=deepstack_visual_embeds,
            **kwargs,
        )
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # Masked language model head | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
class BidirLMOmniForMaskedLM(BidirLMOmniPreTrainedModel):
    """Omni encoder with a bias-free vocabulary projection for masked-token prediction.

    The projection weight is tied to the text encoder's input embedding table.
    """

    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

    def __init__(self, config: BidirLMOmniConfig):
        super().__init__(config)
        self.model = BidirLMOmniModel(config)
        text_cfg = config.text_config
        self.lm_head = nn.Linear(text_cfg.hidden_size, text_cfg.vocab_size, bias=False)
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        pixel_values_videos: Optional[torch.Tensor] = None,
        input_features: Optional[torch.FloatTensor] = None,
        feature_attention_mask: Optional[torch.LongTensor] = None,
        audio_feature_lengths: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> MaskedLMOutput:
        """Encode the inputs, score every position over the vocabulary, and optionally compute the loss."""
        modality_inputs = dict(
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            pixel_values_videos=pixel_values_videos,
            input_features=input_features,
            feature_attention_mask=feature_attention_mask,
            audio_feature_lengths=audio_feature_lengths,
        )
        encoder_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **modality_inputs,
            **kwargs,
        )
        logits = self.lm_head(encoder_output.last_hidden_state)
        if labels is None:
            loss = None
        else:
            loss = self.loss_function(logits, labels, vocab_size=self.config.text_config.vocab_size)
        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_output.hidden_states,
            attentions=encoder_output.attentions,
        )
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # Sequence classification head | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
class BidirLMOmniForSequenceClassification(BidirLMOmniPreTrainedModel):
    """Omni encoder with a pooled classification / regression head.

    ``config.clf_pooling`` chooses how token states become one prediction per sequence:
    ``"bos"`` uses the first position; ``"mean"`` averages non-padding positions
    before the head; any other value ("late") applies the head to every token and
    then averages the logits over non-padding positions.
    """

    def __init__(self, config: BidirLMOmniConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.clf_pooling = config.clf_pooling
        self.model = BidirLMOmniModel(config)
        self.dense = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
        self.activation = nn.GELU()
        self.classifier = nn.Linear(config.text_config.hidden_size, self.num_labels)
        self.post_init()

    @staticmethod
    def _masked_mean(values: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Average ``values`` over dim 1, counting only positions where ``attention_mask`` is non-zero.

        With no mask, this is a plain mean over dim 1. The per-row count is clamped
        to 1, so an all-padding row yields zeros instead of NaN.
        """
        if attention_mask is None:
            return values.mean(dim=1)
        mask = attention_mask.to(values.device)
        summed = (values * mask.unsqueeze(-1)).sum(dim=1)
        counts = mask.sum(dim=1, keepdim=True).clamp(min=1)
        return summed / counts

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        pixel_values_videos: Optional[torch.Tensor] = None,
        input_features: Optional[torch.FloatTensor] = None,
        feature_attention_mask: Optional[torch.LongTensor] = None,
        audio_feature_lengths: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> SequenceClassifierOutput:
        """Encode the inputs, pool, classify, and optionally compute the loss.

        When ``config.problem_type`` is unset it is inferred on the first call with
        labels (and stored on the config): regression when ``num_labels == 1``,
        single-label for integer labels, and multi-label otherwise.
        """
        encoder_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            pixel_values_videos=pixel_values_videos,
            input_features=input_features,
            feature_attention_mask=feature_attention_mask,
            audio_feature_lengths=audio_feature_lengths,
            **kwargs,
        )
        last_hidden_state = encoder_output.last_hidden_state
        if self.clf_pooling == "bos":
            pooled = last_hidden_state[:, 0]
        elif self.clf_pooling == "mean":
            pooled = self._masked_mean(last_hidden_state, attention_mask)
        else:  # "late": project each token, then mean-pool the logits below
            pooled = last_hidden_state
        pooled = self.dense(pooled)
        pooled = self.activation(pooled)
        logits = self.classifier(pooled)
        if self.clf_pooling not in ("bos", "mean"):
            logits = self._masked_mean(logits, attention_mask)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.squeeze(), labels.squeeze())
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_output.hidden_states,
            attentions=encoder_output.attentions,
        )
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # Token classification head | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
class BidirLMOmniForTokenClassification(BidirLMOmniPreTrainedModel):
    """Omni encoder with a per-token linear classifier."""

    def __init__(self, config: BidirLMOmniConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = BidirLMOmniModel(config)
        self.classifier = nn.Linear(config.text_config.hidden_size, self.num_labels)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        pixel_values_videos: Optional[torch.Tensor] = None,
        input_features: Optional[torch.FloatTensor] = None,
        feature_attention_mask: Optional[torch.LongTensor] = None,
        audio_feature_lengths: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> TokenClassifierOutput:
        """Score every token position; with ``labels``, also return the cross-entropy loss.

        Label positions equal to ``CrossEntropyLoss``'s default ``ignore_index``
        (-100) do not contribute to the loss.
        """
        encoder_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            pixel_values_videos=pixel_values_videos,
            input_features=input_features,
            feature_attention_mask=feature_attention_mask,
            audio_feature_lengths=audio_feature_lengths,
            **kwargs,
        )
        logits = self.classifier(encoder_output.last_hidden_state)
        loss = None
        if labels is not None:
            # Match the sequence-classification head: labels may sit on a different
            # device than the classifier output when the model is sharded.
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_output.hidden_states,
            attentions=encoder_output.attentions,
        )
# Public API of this module.
__all__ = [
    # base class
    "BidirLMOmniPreTrainedModel",
    # modality towers
    "BidirLMOmniAudioEncoder",
    "BidirLMOmniVisionModel",
    "BidirLMOmniTextModel",
    # full encoder
    "BidirLMOmniModel",
    # task heads
    "BidirLMOmniForMaskedLM",
    "BidirLMOmniForSequenceClassification",
    "BidirLMOmniForTokenClassification",
]