import math from typing import Callable, Optional import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from transformers.activations import ACT2FN from transformers.integrations import use_kernel_forward_from_hub from transformers.modeling_layers import GradientCheckpointingLayer from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers.modeling_outputs import ( BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput, ) from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from transformers.processing_utils import Unpack from transformers.utils import auto_docstring from transformers.utils.generic import TransformersKwargs try: from .configuration_bidirlm_omni import ( BidirLMOmniAudioConfig, BidirLMOmniConfig, BidirLMOmniTextConfig, BidirLMOmniVisionConfig, ) except ImportError: from configuration_bidirlm_omni import ( BidirLMOmniAudioConfig, BidirLMOmniConfig, BidirLMOmniTextConfig, BidirLMOmniVisionConfig, ) # ═══════════════════════════════════════════════════════════════════════════ # Shared utilities # ═══════════════════════════════════════════════════════════════════════════ def rotate_half(x): x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def _get_feat_extract_output_lengths(input_lengths): # Three Conv2d layers each with kernel=3, stride=2, padding=1. # Per-layer formula: floor((L - 1) / 2) + 1 L = (input_lengths - 1) // 2 + 1 L = (L - 1) // 2 + 1 L = (L - 1) // 2 + 1 return L @use_kernel_forward_from_hub("RMSNorm") class BidirLMOmniRMSNorm(nn.Module): def __init__(self, hidden_size, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" # ═══════════════════════════════════════════════════════════════════════════ # Audio encoder (copied from BiQwen3-SP) # ═══════════════════════════════════════════════════════════════════════════ class BidirLMOmniAudioAttention(nn.Module): def __init__(self, config: BidirLMOmniAudioConfig): super().__init__() self.embed_dim = config.d_model self.num_heads = config.encoder_attention_heads self.head_dim = self.embed_dim // self.num_heads self.num_key_value_groups = 1 self.config = config self.scaling = self.head_dim**-0.5 self.attention_dropout = 0.0 self.is_causal = False if (self.head_dim * self.num_heads) != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads " f"(got embed_dim={self.embed_dim}, num_heads={self.num_heads})." ) self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) def forward( self, hidden_states: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: seq_length, _ = hidden_states.size() query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1) key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1) value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1) query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, _ = attention_interface( self, query_states, key_states, value_states, attention_mask=attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, cu_seq_lens_q=cu_seqlens, cu_seq_lens_k=cu_seqlens, max_length_q=max_seqlen, max_length_k=max_seqlen, is_causal=False, **kwargs, ) attn_output = attn_output.reshape(seq_length, -1).contiguous() return self.out_proj(attn_output) class BidirLMOmniAudioEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: BidirLMOmniAudioConfig): super().__init__() self.embed_dim = config.d_model self.self_attn = BidirLMOmniAudioAttention(config) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> tuple[torch.Tensor]: residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn( hidden_states=hidden_states, cu_seqlens=cu_seqlens, attention_mask=attention_mask, **kwargs, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.fc2(hidden_states) hidden_states = residual + hidden_states if hidden_states.dtype == torch.float16: clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) return (hidden_states,) class SinusoidsPositionEmbedding(nn.Module): def __init__(self, length, channels, max_timescale=10000): super().__init__() if channels % 2 != 0: raise ValueError("SinusoidsPositionEmbedding needs even channels input") # Store scalars so forward can recompute in float32. # tf5 casts persistent=False buffers to the model dtype (bfloat16), # which degrades sinusoidal precision and diverges from tf4 numerics. self.length = length self.channels = channels self.max_timescale = max_timescale # Register a dummy buffer only for device tracking. log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float()) scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] self.register_buffer( "positional_embedding", torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1), persistent=False, ) def _recompute(self, seqlen: int, device) -> torch.Tensor: log_timescale_increment = np.log(self.max_timescale) / (self.channels // 2 - 1) inv_timescales = torch.exp( -log_timescale_increment * torch.arange(self.channels // 2, dtype=torch.float32, device=device) ) scaled_time = torch.arange(seqlen, dtype=torch.float32, device=device)[:, None] * inv_timescales[None, :] return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) def forward(self, seqlen: int): # Recompute in float32 every call — do NOT use the stored buffer whose # values may have been cast to bfloat16 by from_pretrained. return self._recompute(seqlen, self.positional_embedding.device) class BidirLMOmniAudioEncoder(PreTrainedModel): config: BidirLMOmniAudioConfig main_input_name = "input_features" _no_split_modules = ["BidirLMOmniAudioEncoderLayer"] _supports_sdpa = True def __init__(self, config: BidirLMOmniAudioConfig): super().__init__(config) self.dropout = config.dropout embed_dim = config.d_model self.num_mel_bins = config.num_mel_bins self.max_source_positions = config.max_source_positions self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 self.n_window = config.n_window self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim) self.layers = nn.ModuleList([BidirLMOmniAudioEncoderLayer(config) for _ in range(config.encoder_layers)]) self.ln_post = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1) self.conv2d2 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1) self.conv2d3 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1) self.conv_out = nn.Linear( config.downsample_hidden_size * ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2), config.d_model, bias=False, ) self.proj1 = nn.Linear(config.d_model, config.d_model) self.act = ACT2FN[config.activation_function] self.proj2 = nn.Linear(config.d_model, config.output_dim) self.n_window_infer = self.config.n_window_infer self.conv_chunksize = self.config.conv_chunksize self.post_init() def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> Optional[torch.Tensor]: if self.config._attn_implementation == "flash_attention_2": return None seq_length = inputs_tensor.shape[0] attention_mask = torch.full( [1, 1, seq_length, seq_length], torch.finfo(inputs_tensor.dtype).min, device=inputs_tensor.device, dtype=inputs_tensor.dtype, ) for i in range(1, len(cu_seqlens)): attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 return attention_mask def forward(self, input_features, feature_lens=None, aftercnn_lens=None): aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() chunk_lengths = torch.tensor( [self.n_window * 2] * chunk_num.sum(), dtype=torch.long, device=feature_lens.device, ) tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) chunk_lengths[chunk_lengths == 0] = self.n_window * 2 chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) padded_mask_after_cnn = nn.utils.rnn.pad_sequence( [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn], batch_first=True, ) padded_feature = padded_feature.unsqueeze(1) padded_embeds = [] for chunk in padded_feature.split(self.conv_chunksize, dim=0): padded_embed = F.gelu(self.conv2d1(chunk)) padded_embed = F.gelu(self.conv2d2(padded_embed)) padded_embed = F.gelu(self.conv2d3(padded_embed)) padded_embeds.append(padded_embed) padded_embed = torch.cat(padded_embeds, dim=0) b, c, f, t = padded_embed.size() padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)) # Call forward() which recomputes sinusoids in float32 (avoids bfloat16 buffer precision loss). positional_embedding = ( self.positional_embedding(padded_embed.shape[1]) .unsqueeze(0) .to(padded_embed.dtype) ) padded_embed = padded_embed + positional_embedding hidden_states = padded_embed[padded_mask_after_cnn] cu_chunk_lens = [0] window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2)) for cnn_len in aftercnn_lens: cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn) remainder = cnn_len % window_aftercnn if remainder != 0: cu_chunk_lens += [remainder] cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32) for encoder_layer in self.layers: layer_outputs = encoder_layer(hidden_states, cu_seqlens) hidden_states = layer_outputs[0] hidden_states = self.ln_post(hidden_states) hidden_states = self.proj1(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.proj2(hidden_states) return BaseModelOutput(last_hidden_state=hidden_states) # ═══════════════════════════════════════════════════════════════════════════ # Vision encoder (copied from BiQwen3-VL) # ═══════════════════════════════════════════════════════════════════════════ class BidirLMOmniVisionMLP(nn.Module): def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_state): return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) class BidirLMOmniVisionPatchEmbed(nn.Module): def __init__(self, config) -> None: super().__init__() self.patch_size = config.patch_size self.temporal_patch_size = config.temporal_patch_size self.in_channels = config.in_channels self.embed_dim = config.hidden_size kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: target_dtype = self.proj.weight.dtype hidden_states = hidden_states.view( -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size ) hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) return hidden_states class BidirLMOmniVisionRotaryEmbedding(nn.Module): inv_freq: torch.Tensor def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() # Store theta/dim so forward can recompute inv_freq in float32. # We still register a buffer for device tracking, but don't use its values. self.dim = dim self.theta = theta inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) def forward(self, seqlen: int) -> torch.Tensor: # Recompute inv_freq in float32 — matches training precision. inv_freq = 1.0 / ( self.theta ** ( torch.arange(0, self.dim, 2, dtype=torch.float32, device=self.inv_freq.device) / self.dim ) ) seq = torch.arange(seqlen, dtype=torch.float32, device=self.inv_freq.device) return torch.outer(seq, inv_freq) class BidirLMOmniVisionPatchMerger(nn.Module): def __init__(self, config: BidirLMOmniVisionConfig, use_postshuffle_norm=False) -> None: super().__init__() self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) self.use_postshuffle_norm = use_postshuffle_norm self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6) self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size) self.act_fn = nn.GELU() self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size) x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) return x def apply_rotary_pos_emb_vision( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: orig_q_dtype = q.dtype orig_k_dtype = k.dtype q, k = q.float(), k.float() cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed.to(orig_q_dtype), k_embed.to(orig_k_dtype) class BidirLMOmniVisionAttention(nn.Module): def __init__(self, config: BidirLMOmniVisionConfig) -> None: super().__init__() self.dim = config.hidden_size self.num_heads = config.num_heads self.head_dim = self.dim // self.num_heads self.num_key_value_groups = 1 self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) self.proj = nn.Linear(self.dim, self.dim) self.scaling = self.head_dim**-0.5 self.config = config self.attention_dropout = 0.0 self.is_causal = False def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> torch.Tensor: seq_length = hidden_states.shape[0] query_states, key_states, value_states = ( self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) ) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) query_states = query_states.transpose(0, 1).unsqueeze(0) key_states = key_states.transpose(0, 1).unsqueeze(0) value_states = value_states.transpose(0, 1).unsqueeze(0) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] if self.config._attn_implementation == "flash_attention_2": max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() attn_output, _ = attention_interface( self, query_states, key_states, value_states, attention_mask=None, scaling=self.scaling, dropout=0.0 if not self.training else self.attention_dropout, cu_seq_lens_q=cu_seqlens, cu_seq_lens_k=cu_seqlens, max_length_q=max_seqlen, max_length_k=max_seqlen, is_causal=self.is_causal, **kwargs, ) else: lengths = cu_seqlens[1:] - cu_seqlens[:-1] splits = [ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) ] attn_outputs = [ attention_interface( self, q, k, v, attention_mask=None, scaling=self.scaling, dropout=0.0 if not self.training else self.attention_dropout, is_causal=self.is_causal, **kwargs, )[0] for q, k, v in zip(*splits) ] attn_output = torch.cat(attn_outputs, dim=1) attn_output = attn_output.reshape(seq_length, -1).contiguous() attn_output = self.proj(attn_output) return attn_output class BidirLMOmniVisionBlock(GradientCheckpointingLayer): def __init__(self, config, attn_implementation: str = "sdpa") -> None: super().__init__() self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) self.attn = BidirLMOmniVisionAttention(config=config) self.mlp = BidirLMOmniVisionMLP(config=config) def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, **kwargs, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states class BidirLMOmniVisionModel(PreTrainedModel): config: BidirLMOmniVisionConfig _no_split_modules = ["BidirLMOmniVisionBlock"] def __init__(self, config, *inputs, **kwargs) -> None: super().__init__(config, *inputs, **kwargs) self.spatial_merge_size = config.spatial_merge_size self.patch_size = config.patch_size self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size self.patch_embed = BidirLMOmniVisionPatchEmbed(config=config) self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size) self.num_grid_per_side = int(config.num_position_embeddings**0.5) head_dim = config.hidden_size // config.num_heads self.rotary_pos_emb = BidirLMOmniVisionRotaryEmbedding(head_dim // 2) self.blocks = nn.ModuleList([BidirLMOmniVisionBlock(config) for _ in range(config.depth)]) self.merger = BidirLMOmniVisionPatchMerger(config=config, use_postshuffle_norm=False) self.deepstack_visual_indexes = config.deepstack_visual_indexes self.deepstack_merger_list = nn.ModuleList( [ BidirLMOmniVisionPatchMerger(config=config, use_postshuffle_norm=True) for _ in range(len(config.deepstack_visual_indexes)) ] ) self.gradient_checkpointing = False def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: merge_size = self.spatial_merge_size max_hw = int(grid_thw[:, 1:].max().item()) freq_table = self.rotary_pos_emb(max_hw) device = freq_table.device total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) offset = 0 for num_frames, height, width in grid_thw: merged_h, merged_w = height // merge_size, width // merge_size block_rows = torch.arange(merged_h, device=device) block_cols = torch.arange(merged_w, device=device) intra_row = torch.arange(merge_size, device=device) intra_col = torch.arange(merge_size, device=device) row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) coords = torch.stack((row_idx, col_idx), dim=-1) if num_frames > 1: coords = coords.repeat(num_frames, 1) num_tokens = coords.shape[0] pos_ids[offset : offset + num_tokens] = coords offset += num_tokens embeddings = freq_table[pos_ids] return embeddings.flatten(1) def fast_pos_embed_interpolate(self, grid_thw): grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] idx_list = [[] for _ in range(4)] weight_list = [[] for _ in range(4)] for t, h, w in zip(grid_ts, grid_hs, grid_ws): h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) h_idxs_floor = h_idxs.int() w_idxs_floor = w_idxs.int() h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) dh = h_idxs - h_idxs_floor dw = w_idxs - w_idxs_floor base_h = h_idxs_floor * self.num_grid_per_side base_h_ceil = h_idxs_ceil * self.num_grid_per_side indices = [ (base_h[None].T + w_idxs_floor[None]).flatten(), (base_h[None].T + w_idxs_ceil[None]).flatten(), (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), ] weights = [ ((1 - dh)[None].T * (1 - dw)[None]).flatten(), ((1 - dh)[None].T * dw[None]).flatten(), (dh[None].T * (1 - dw)[None]).flatten(), (dh[None].T * dw[None]).flatten(), ] for i in range(4): idx_list[i].extend(indices[i].tolist()) weight_list[i].extend(weights[i].tolist()) idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device) weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device) pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) merge_size = self.config.spatial_merge_size patch_pos_embeds_permute = [] for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): pos_embed = pos_embed.repeat(t, 1) pos_embed = ( pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) .permute(0, 1, 3, 2, 4, 5) .flatten(0, 4) ) patch_pos_embeds_permute.append(pos_embed) return torch.cat(patch_pos_embeds_permute) def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: hidden_states = self.patch_embed(hidden_states) hidden_states = hidden_states + self.fast_pos_embed_interpolate(grid_thw) rotary_pos_emb = self.rot_pos_emb(grid_thw) seq_len = hidden_states.shape[0] hidden_states = hidden_states.reshape(seq_len, -1) rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) position_embeddings = (emb.cos(), emb.sin()) cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( dim=0, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) deepstack_feature_lists = [] for layer_num, blk in enumerate(self.blocks): hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, **kwargs, ) if layer_num in self.deepstack_visual_indexes: deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)]( hidden_states ) deepstack_feature_lists.append(deepstack_feature) hidden_states = self.merger(hidden_states) return hidden_states, deepstack_feature_lists # ═══════════════════════════════════════════════════════════════════════════ # Shared text encoder (supports both audio injection + DeepStack visual) # ═══════════════════════════════════════════════════════════════════════════ class BidirLMOmniTextRotaryEmbedding(nn.Module): inv_freq: torch.Tensor def __init__(self, config: BidirLMOmniTextConfig, device=None): super().__init__() self.rope_type = ( config.rope_scaling.get("rope_type", "default") if hasattr(config, "rope_scaling") and config.rope_scaling is not None else "default" ) self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config # transformers 5.x removed "default" from ROPE_INIT_FUNCTIONS. # For rope_type="default" (standard inv_freq, no scaling) we compute directly. if self.rope_type == "default" or self.rope_type not in ROPE_INIT_FUNCTIONS: head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) rope_theta = getattr(config, "rope_theta", 10000.0) inv_freq = 1.0 / ( rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim) ) self.attention_scaling = 1.0 else: self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq self.mrope_section = (config.rope_scaling or {}).get("mrope_section", [24, 20, 20]) def compute_default_rope_parameters(self, config=None): """Required by transformers 5.x _init_weights when rope_type='default'.""" config = config or self.config head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) rope_theta = getattr(config, "rope_theta", 10000.0) inv_freq = 1.0 / ( rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim) ) return inv_freq, 1.0 def apply_interleaved_mrope(self, freqs, mrope_section): freqs_t = freqs[0] for dim, offset in enumerate((1, 2), start=1): length = mrope_section[dim] * 3 idx = slice(offset, length, 3) freqs_t[..., idx] = freqs[dim, ..., idx] return freqs_t @torch.no_grad() def forward(self, x, position_ids): if position_ids.ndim == 2: position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) position_ids_expanded = position_ids[:, :, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class BidirLMOmniTextAttention(nn.Module): """Bidirectional multi-head attention (no causal mask, no KV cache).""" def __init__(self, config: BidirLMOmniTextConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = False self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias) self.q_norm = BidirLMOmniRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = BidirLMOmniRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.sliding_window = None def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, _ = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=self.sliding_window, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() return self.o_proj(attn_output) class BidirLMOmniTextMLP(nn.Module): def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) class BidirLMOmniTextEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: BidirLMOmniTextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = BidirLMOmniTextAttention(config=config, layer_idx=layer_idx) self.mlp = BidirLMOmniTextMLP(config) self.input_layernorm = BidirLMOmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = BidirLMOmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: residual = hidden_states hidden_states = self.self_attn( hidden_states=self.input_layernorm(hidden_states), attention_mask=attention_mask, position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.mlp(self.post_attention_layernorm(hidden_states)) hidden_states = residual + hidden_states return hidden_states # ═══════════════════════════════════════════════════════════════════════════ # PreTrainedModel base # ═══════════════════════════════════════════════════════════════════════════ @auto_docstring class BidirLMOmniPreTrainedModel(PreTrainedModel): config: BidirLMOmniConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = [ "BidirLMOmniTextEncoderLayer", "BidirLMOmniAudioEncoderLayer", "BidirLMOmniVisionBlock", ] _supports_flash_attn = True _supports_sdpa = True _supports_attention_backend = True # ═══════════════════════════════════════════════════════════════════════════ # Text encoder model (with DeepStack visual injection support) # ═══════════════════════════════════════════════════════════════════════════ class BidirLMOmniTextModel(BidirLMOmniPreTrainedModel): """ Bidirectional text encoder. Supports: - audio feature injection via ``masked_scatter`` - DeepStack visual feature injection at intermediate layers """ config: BidirLMOmniTextConfig _no_split_modules = ["BidirLMOmniTextEncoderLayer"] def __init__(self, config: BidirLMOmniTextConfig): super().__init__(config) self.padding_idx = getattr(config, "pad_token_id", None) self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [BidirLMOmniTextEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = BidirLMOmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = BidirLMOmniTextRotaryEmbedding(config) self.gradient_checkpointing = False self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, # DeepStack visual injection args visual_pos_masks: Optional[torch.Tensor] = None, deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("Specify exactly one of input_ids or inputs_embeds.") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) batch_size, seq_len = inputs_embeds.shape[:2] if position_ids is None: position_ids = torch.arange(seq_len, device=inputs_embeds.device) position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) if position_ids.ndim == 2: position_ids = position_ids[None, ...].expand(3, batch_size, -1) extended_attention_mask: Optional[torch.Tensor] = None if attention_mask is not None: if self.config._attn_implementation == "flash_attention_2": # Flash attention computes cu_seqlens from a 2D mask internally; # passing a 4D mask breaks the varlen path. extended_attention_mask = attention_mask else: # Convert 1/0 mask to additive float mask (0.0 = attend, -inf = ignore). # The old boolean expand (True→+1, False→+0) added to attn_weights was NOT # a real mask: padding tokens still participated in softmax, corrupting # embeddings of shorter sequences when batched with longer ones. float_mask = attention_mask.to(dtype=inputs_embeds.dtype) extended_attention_mask = ( (1.0 - float_mask)[:, None, None, :] * torch.finfo(inputs_embeds.dtype).min ) hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) for layer_idx, encoder_layer in enumerate(self.layers): hidden_states = encoder_layer( hidden_states, attention_mask=extended_attention_mask, position_embeddings=position_embeddings, **kwargs, ) # DeepStack: add visual features at intermediate layers if deepstack_visual_embeds is not None and layer_idx < len(deepstack_visual_embeds): hidden_states = self._deepstack_process( hidden_states, visual_pos_masks, deepstack_visual_embeds[layer_idx], ) hidden_states = self.norm(hidden_states) return BaseModelOutput(last_hidden_state=hidden_states) def _deepstack_process( self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor, ) -> torch.Tensor: visual_pos_masks = visual_pos_masks.to(hidden_states.device) visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) hidden_states[visual_pos_masks, :] = hidden_states[visual_pos_masks, :].clone() + visual_embeds return hidden_states # ═══════════════════════════════════════════════════════════════════════════ # Top-level Omni model # ═══════════════════════════════════════════════════════════════════════════ @auto_docstring( custom_intro="Multimodal encoder combining audio tower, vision tower, and shared bidirectional text encoder." ) class BidirLMOmniModel(BidirLMOmniPreTrainedModel): """ Audio + Vision + Text omni encoder. Accepts any combination of modalities: text-only, text+audio, text+vision, text+audio+vision. """ config: BidirLMOmniConfig def __init__(self, config: BidirLMOmniConfig): super().__init__(config) # Flash/SDPA attention only applies to the text encoder; # audio and vision towers always run eager (no causal masking or varlen path needed). config.audio_config._attn_implementation = "eager" config.vision_config._attn_implementation = "eager" config.text_config._attn_implementation = config._attn_implementation self.audio_tower = BidirLMOmniAudioEncoder._from_config(config.audio_config) self.visual = BidirLMOmniVisionModel._from_config(config.vision_config) self.language_model = BidirLMOmniTextModel._from_config(config.text_config) self.post_init() def get_input_embeddings(self): return self.language_model.get_input_embeddings() def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) # ── Audio helpers ────────────────────────────────────────────────── def get_audio_features( self, input_features: torch.FloatTensor, feature_attention_mask: Optional[torch.LongTensor] = None, audio_feature_lengths: Optional[torch.LongTensor] = None, ) -> torch.Tensor: if feature_attention_mask is not None: audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) audio_features = [] for input_feature, feature_len in zip(input_features, feature_lens): audio_output = self.audio_tower( input_feature[:, :feature_len], feature_lens=feature_len.unsqueeze(0), ) audio_features.append(audio_output.last_hidden_state) return torch.cat(audio_features, dim=0) def get_audio_placeholder_mask( self, input_ids: Optional[torch.LongTensor], inputs_embeds: torch.FloatTensor, ) -> torch.Tensor: if input_ids is None: special_audio_mask = ( inputs_embeds == self.get_input_embeddings()( torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) ) ).all(-1) else: special_audio_mask = input_ids == self.config.audio_token_id return special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) # ── Vision helpers ───────────────────────────────────────────────── def get_image_features( self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None, ): pixel_values = pixel_values.type(self.visual.dtype) image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() image_embeds = torch.split(image_embeds, split_sizes) return image_embeds, deepstack_image_embeds def get_video_features( self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None, ): return self.get_image_features(pixel_values_videos, video_grid_thw) def get_vision_placeholder_mask( self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: Optional[torch.FloatTensor] = None, video_features: Optional[torch.FloatTensor] = None, ): if input_ids is None: img_embed = self.get_input_embeddings()( torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) vid_embed = self.get_input_embeddings()( torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) ) special_image_mask = (inputs_embeds == img_embed).all(-1) special_video_mask = (inputs_embeds == vid_embed).all(-1) else: special_image_mask = input_ids == self.config.image_token_id special_video_mask = input_ids == self.config.video_token_id special_image_mask_expanded = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) special_video_mask_expanded = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) if image_features is not None and inputs_embeds[special_image_mask_expanded].numel() != image_features.numel(): n_image_tokens = special_image_mask.sum() raise ValueError( f"Image features and image tokens do not match: tokens {n_image_tokens}, features {image_features.shape[0]}" ) if video_features is not None and inputs_embeds[special_video_mask_expanded].numel() != video_features.numel(): n_video_tokens = special_video_mask.sum() raise ValueError( f"Video features and video tokens do not match: tokens {n_video_tokens}, features {video_features.shape[0]}" ) return special_image_mask_expanded, special_video_mask_expanded # ── MRoPE position ids ───────────────────────────────────────────── def get_rope_index( self, input_ids: Optional[torch.LongTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Build 3-D MRoPE position ids. Returns (3, batch, seq_len).""" if video_grid_thw is not None: video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) video_grid_thw[:, 0] = 1 spatial_merge_size = self.config.vision_config.spatial_merge_size image_token_id = self.config.image_token_id video_token_id = self.config.video_token_id vision_start_token_id = self.config.vision_start_token_id if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): total_input_ids = input_ids if attention_mask is None: attention_mask = torch.ones_like(total_input_ids) position_ids = torch.ones( 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device, ) image_index, video_index = 0, 0 for i, ids in enumerate(total_input_ids): ids = ids[attention_mask[i] == 1] vision_start_indices = torch.argwhere(ids == vision_start_token_id).squeeze(1) vision_tokens = ids[vision_start_indices + 1] image_nums = (vision_tokens == image_token_id).sum() video_nums = (vision_tokens == video_token_id).sum() input_tokens = ids.tolist() llm_pos_ids_list: list = [] st = 0 remain_images, remain_videos = image_nums, video_nums for _ in range(image_nums + video_nums): ed_image = input_tokens.index(image_token_id, st) if image_token_id in input_tokens and remain_images > 0 else len(input_tokens) + 1 ed_video = input_tokens.index(video_token_id, st) if video_token_id in input_tokens and remain_videos > 0 else len(input_tokens) + 1 if ed_image < ed_video: t, h, w = image_grid_thw[image_index] image_index += 1; remain_images -= 1; ed = ed_image else: t, h, w = video_grid_thw[video_index] video_index += 1; remain_videos -= 1; ed = ed_video llm_grid_t = t.item() llm_grid_h = h.item() // spatial_merge_size llm_grid_w = w.item() // spatial_merge_size text_len = ed - st st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) st = ed + llm_grid_t * llm_grid_h * llm_grid_w if st < len(input_tokens): st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0 llm_pos_ids_list.append(torch.arange(len(input_tokens) - st).view(1, -1).expand(3, -1) + st_idx) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) return position_ids # Text-only / audio-only path (no spatial position structure) if attention_mask is not None: position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) else: position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand(input_ids.shape[0], -1) return position_ids.unsqueeze(0).expand(3, -1, -1) # ── Forward ──────────────────────────────────────────────────────── def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, # Audio inputs input_features: Optional[torch.FloatTensor] = None, feature_attention_mask: Optional[torch.LongTensor] = None, audio_feature_lengths: Optional[torch.LongTensor] = None, # Vision inputs pixel_values: Optional[torch.Tensor] = None, pixel_values_videos: Optional[torch.FloatTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutput: if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("Specify exactly one of input_ids or inputs_embeds.") if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) # ── Audio injection ──────────────────────────────────────────── if input_features is not None: audio_features = self.get_audio_features( input_features, feature_attention_mask=feature_attention_mask, audio_feature_lengths=audio_feature_lengths, ).to(inputs_embeds.device, inputs_embeds.dtype) audio_mask = self.get_audio_placeholder_mask(input_ids, inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) # ── Vision injection ─────────────────────────────────────────── image_mask = video_mask = None if pixel_values is not None: image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw) image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) image_mask, _ = self.get_vision_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) _, video_mask = self.get_vision_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # ── Assemble DeepStack masks / embeds ────────────────────────── visual_pos_masks = deepstack_visual_embeds = None if image_mask is not None and video_mask is not None: im = image_mask[..., 0]; vm = video_mask[..., 0] visual_pos_masks = im | vm image_mask_joint = im[visual_pos_masks] video_mask_joint = vm[visual_pos_masks] deepstack_visual_embeds = [] for img_e, vid_e in zip(deepstack_image_embeds, deepstack_video_embeds): joint = img_e.new_zeros(visual_pos_masks.sum(), img_e.shape[-1]).to(img_e.device) joint[image_mask_joint] = img_e joint[video_mask_joint] = vid_e deepstack_visual_embeds.append(joint) elif image_mask is not None: visual_pos_masks = image_mask[..., 0] deepstack_visual_embeds = deepstack_image_embeds elif video_mask is not None: visual_pos_masks = video_mask[..., 0] deepstack_visual_embeds = deepstack_video_embeds # ── Build position ids ───────────────────────────────────────── if position_ids is None: position_ids = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask) return self.language_model( input_ids=None, position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, visual_pos_masks=visual_pos_masks, deepstack_visual_embeds=deepstack_visual_embeds, **kwargs, ) # ═══════════════════════════════════════════════════════════════════════════ # Masked language model head # ═══════════════════════════════════════════════════════════════════════════ @auto_docstring class BidirLMOmniForMaskedLM(BidirLMOmniPreTrainedModel): _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: BidirLMOmniConfig): super().__init__(config) self.model = BidirLMOmniModel(config) self.lm_head = nn.Linear( config.text_config.hidden_size, config.text_config.vocab_size, bias=False, ) self.post_init() def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.Tensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, pixel_values_videos: Optional[torch.Tensor] = None, input_features: Optional[torch.FloatTensor] = None, feature_attention_mask: Optional[torch.LongTensor] = None, audio_feature_lengths: Optional[torch.LongTensor] = None, **kwargs, ) -> MaskedLMOutput: encoder_output = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, pixel_values=pixel_values, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, pixel_values_videos=pixel_values_videos, input_features=input_features, feature_attention_mask=feature_attention_mask, audio_feature_lengths=audio_feature_lengths, **kwargs, ) logits = self.lm_head(encoder_output.last_hidden_state) loss = None if labels is not None: loss = self.loss_function(logits, labels, vocab_size=self.config.text_config.vocab_size) return MaskedLMOutput( loss=loss, logits=logits, hidden_states=encoder_output.hidden_states, attentions=encoder_output.attentions, ) # ═══════════════════════════════════════════════════════════════════════════ # Sequence classification head # ═══════════════════════════════════════════════════════════════════════════ @auto_docstring class BidirLMOmniForSequenceClassification(BidirLMOmniPreTrainedModel): def __init__(self, config: BidirLMOmniConfig): super().__init__(config) self.num_labels = config.num_labels self.clf_pooling = config.clf_pooling self.model = BidirLMOmniModel(config) self.dense = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size) self.activation = nn.GELU() self.classifier = nn.Linear(config.text_config.hidden_size, self.num_labels) self.post_init() def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.Tensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, pixel_values_videos: Optional[torch.Tensor] = None, input_features: Optional[torch.FloatTensor] = None, feature_attention_mask: Optional[torch.LongTensor] = None, audio_feature_lengths: Optional[torch.LongTensor] = None, **kwargs, ) -> SequenceClassifierOutput: encoder_output = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, pixel_values=pixel_values, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, pixel_values_videos=pixel_values_videos, input_features=input_features, feature_attention_mask=feature_attention_mask, audio_feature_lengths=audio_feature_lengths, **kwargs, ) last_hidden_state = encoder_output.last_hidden_state if self.clf_pooling == "bos": pooled = last_hidden_state[:, 0] elif self.clf_pooling == "mean": if attention_mask is None: pooled = last_hidden_state.mean(dim=1) else: pooled = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) pooled = pooled / attention_mask.sum(dim=1, keepdim=True) else: # "late" — project each token then mean-pool pooled = last_hidden_state pooled = self.dense(pooled) pooled = self.activation(pooled) logits = self.classifier(pooled) if self.clf_pooling == "late": if attention_mask is None: logits = logits.mean(dim=1) else: logits = (logits * attention_mask.unsqueeze(-1)).sum(dim=1) logits = logits / attention_mask.sum(dim=1, keepdim=True) loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() loss = loss_fct(logits.squeeze(), labels.squeeze()) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(logits, labels) return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=encoder_output.hidden_states, attentions=encoder_output.attentions, ) # ═══════════════════════════════════════════════════════════════════════════ # Token classification head # ═══════════════════════════════════════════════════════════════════════════ @auto_docstring class BidirLMOmniForTokenClassification(BidirLMOmniPreTrainedModel): def __init__(self, config: BidirLMOmniConfig): super().__init__(config) self.num_labels = config.num_labels self.model = BidirLMOmniModel(config) self.classifier = nn.Linear(config.text_config.hidden_size, self.num_labels) self.post_init() def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.Tensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, pixel_values_videos: Optional[torch.Tensor] = None, input_features: Optional[torch.FloatTensor] = None, feature_attention_mask: Optional[torch.LongTensor] = None, audio_feature_lengths: Optional[torch.LongTensor] = None, **kwargs, ) -> TokenClassifierOutput: encoder_output = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, inputs_embeds=inputs_embeds, pixel_values=pixel_values, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, pixel_values_videos=pixel_values_videos, input_features=input_features, feature_attention_mask=feature_attention_mask, audio_feature_lengths=audio_feature_lengths, **kwargs, ) logits = self.classifier(encoder_output.last_hidden_state) loss = None if labels is not None: loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) return TokenClassifierOutput( loss=loss, logits=logits, hidden_states=encoder_output.hidden_states, attentions=encoder_output.attentions, ) __all__ = [ "BidirLMOmniPreTrainedModel", "BidirLMOmniAudioEncoder", "BidirLMOmniVisionModel", "BidirLMOmniTextModel", "BidirLMOmniModel", "BidirLMOmniForMaskedLM", "BidirLMOmniForSequenceClassification", "BidirLMOmniForTokenClassification", ]