BidirLM-Omni-2.5B-Embedding / modeling_bidirlm_omni.py
Nicolas-BZRD's picture
BidirLM-Omni-2.5B-Embedding-v2
4d8a7d3
Raw
History Blame Contribute Delete
72.3 kB
import math
from typing import Callable, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers.integrations import use_kernel_forward_from_hub
from transformers.modeling_layers import GradientCheckpointingLayer
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
BaseModelOutput,
MaskedLMOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import auto_docstring
from transformers.utils.generic import TransformersKwargs
try:
from .configuration_bidirlm_omni import (
BidirLMOmniAudioConfig,
BidirLMOmniConfig,
BidirLMOmniTextConfig,
BidirLMOmniVisionConfig,
)
except ImportError:
from configuration_bidirlm_omni import (
BidirLMOmniAudioConfig,
BidirLMOmniConfig,
BidirLMOmniTextConfig,
BidirLMOmniVisionConfig,
)
# ═══════════════════════════════════════════════════════════════════════════
# Shared utilities
# ═══════════════════════════════════════════════════════════════════════════
def rotate_half(x):
    """Rotate the last dimension by half: ``[a, b] -> [-b, a]``.

    The last dimension is split at ``x.shape[-1] // 2``; the second part is
    negated and placed in front of the first part.
    """
    split_at = x.shape[-1] // 2
    first_part, second_part = x[..., :split_at], x[..., split_at:]
    return torch.cat((second_part.neg(), first_part), dim=-1)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate each key/value head ``n_rep`` times along the head axis.

    Takes a 4-D ``(batch, num_key_value_heads, seq_len, head_dim)`` tensor and
    returns ``(batch, num_key_value_heads * n_rep, seq_len, head_dim)``, with
    every head repeated consecutively (same result as
    ``torch.repeat_interleave(hidden_states, n_rep, dim=1)``). When
    ``n_rep == 1`` the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    widened = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return widened.reshape(bsz, kv_heads * n_rep, seq_len, dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (non-fused) scaled dot-product attention.

    ``key``/``value`` are expanded with ``repeat_kv`` by
    ``module.num_key_value_groups``. ``attention_mask``, when given, is added to
    the scores after being sliced to the key length. Softmax runs in float32 and
    is cast back to ``query.dtype``. Dropout is applied only while
    ``module.training`` is true. Extra ``kwargs`` (for example ``cu_seq_lens_*``
    from callers) are accepted and ignored.

    Returns:
        ``(attn_output, attn_weights)``. ``attn_output`` has axes 1 and 2
        swapped relative to the matmul result (heads-first -> sequence-first)
        and is made contiguous.
    """
    functional = nn.functional
    n_groups = module.num_key_value_groups
    expanded_keys = repeat_kv(key, n_groups)
    expanded_values = repeat_kv(value, n_groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        kv_len = expanded_keys.shape[-2]
        mask_slice = attention_mask[:, :, :, :kv_len]
        scores = scores + mask_slice

    probs = functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, expanded_values)
    context = context.transpose(1, 2).contiguous()
    return context, probs
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply rotary position embeddings to query and key tensors.

    ``cos`` and ``sin`` receive a singleton axis at ``unsqueeze_dim`` so they
    broadcast against ``q``/``k``. With the default of 1 and the text
    attention's ``(batch, heads, seq, head_dim)`` layout, that is the heads
    axis. Each tensor becomes ``t * cos + rotate_half(t) * sin``.

    Returns:
        ``(q_embed, k_embed)``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(tensor):
        return (tensor * cos_b) + (rotate_half(tensor) * sin_b)

    return _rotate(q), _rotate(k)
def _get_feat_extract_output_lengths(input_lengths):
# Three Conv2d layers each with kernel=3, stride=2, padding=1.
# Per-layer formula: floor((L - 1) / 2) + 1
L = (input_lengths - 1) // 2 + 1
L = (L - 1) // 2 + 1
L = (L - 1) // 2 + 1
return L
@use_kernel_forward_from_hub("RMSNorm")
class BidirLMOmniRMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned scale.

    The statistics are computed in float32. The normalized tensor is cast back
    to the input dtype before being multiplied by ``weight``, which is
    initialised to ones.
    """

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        source_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(source_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
# ═══════════════════════════════════════════════════════════════════════════
# Audio encoder (copied from BiQwen3-SP)
# ═══════════════════════════════════════════════════════════════════════════
class BidirLMOmniAudioAttention(nn.Module):
    """Bidirectional multi-head self-attention used by the audio encoder.

    Works on a flat, un-batched ``(seq_len, d_model)`` sequence. ``cu_seqlens``
    holds cumulative segment boundaries (``[0, l0, l0 + l1, ...]``) and is
    forwarded to the attention backend as ``cu_seq_lens_q``/``cu_seq_lens_k``.
    Backends that ignore those kwargs (such as ``eager_attention_forward``) rely
    on ``attention_mask`` to keep segments apart.
    """

    def __init__(self, config: BidirLMOmniAudioConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.num_heads = config.encoder_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        # K/V have as many heads as Q (no grouped-query sharing).
        self.num_key_value_groups = 1
        self.config = config
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = 0.0
        self.is_causal = False
        if (self.head_dim * self.num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got embed_dim={self.embed_dim}, num_heads={self.num_heads})."
            )
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Attend within the segments described by ``cu_seqlens``.

        Args:
            hidden_states: ``(seq_len, d_model)`` packed tokens.
            cu_seqlens: cumulative segment boundaries. ``None`` means the whole
                input forms one segment.
            attention_mask: optional additive mask passed to the backend.

        Returns:
            ``(seq_len, d_model)`` tensor after the output projection.
        """
        seq_length, _ = hidden_states.size()
        if cu_seqlens is None:
            # The signature allows None; interpret it as one segment spanning
            # every token instead of failing on ``None[1:]`` below.
            cu_seqlens = torch.tensor([0, seq_length], dtype=torch.int32, device=hidden_states.device)

        # (seq, d_model) -> (1, heads, seq, head_dim)
        query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
        key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
        value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            cu_seq_lens_q=cu_seqlens,
            cu_seq_lens_k=cu_seqlens,
            max_length_q=max_seqlen,
            max_length_k=max_seqlen,
            is_causal=False,
            **kwargs,
        )
        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        return self.out_proj(attn_output)
class BidirLMOmniAudioEncoderLayer(GradientCheckpointingLayer):
    """Pre-norm audio Transformer block.

    Computes ``x + attn(LN(x))``, then ``x + fc2(act(fc1(LN(x))))``. When the
    activations are float16, the output is clamped just inside the float16
    range. ``dropout`` and ``activation_dropout`` are stored from the config
    but not applied in ``forward``.
    """

    def __init__(self, config: BidirLMOmniAudioConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = BidirLMOmniAudioAttention(config)
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        attended = self.self_attn(
            hidden_states=self.self_attn_layer_norm(hidden_states),
            cu_seqlens=cu_seqlens,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = hidden_states + attended

        expanded = self.activation_fn(self.fc1(self.final_layer_norm(hidden_states)))
        hidden_states = hidden_states + self.fc2(expanded)

        if hidden_states.dtype == torch.float16:
            # Keep fp16 activations finite by clipping just below the largest value.
            limit = torch.finfo(torch.float16).max - 1000
            hidden_states = hidden_states.clamp(min=-limit, max=limit)
        return (hidden_states,)
class SinusoidsPositionEmbedding(nn.Module):
    """Fixed sinusoidal position table: ``channels // 2`` sine columns, then cosines.

    ``forward(seqlen)`` rebuilds the first ``seqlen`` rows in float32 on every
    call instead of reading the ``positional_embedding`` buffer. The comments
    in the original note that the non-persistent buffer can be cast to a
    lower-precision model dtype on load, which would change the numerics.
    """

    def __init__(self, length, channels, max_timescale=10000):
        super().__init__()
        if channels % 2 != 0:
            raise ValueError("SinusoidsPositionEmbedding needs even channels input")
        # Scalars kept so ``forward`` can recompute the table in float32.
        self.length = length
        self.channels = channels
        self.max_timescale = max_timescale

        # The buffer's values are not read by ``forward``; only its device is.
        half = channels // 2
        step = np.log(max_timescale) / (half - 1)
        inv_timescales = torch.exp(-step * torch.arange(half).float())
        angles = torch.arange(length)[:, None] * inv_timescales[None, :]
        table = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        self.register_buffer("positional_embedding", table, persistent=False)

    def _recompute(self, seqlen: int, device) -> torch.Tensor:
        half = self.channels // 2
        step = np.log(self.max_timescale) / (half - 1)
        channel_idx = torch.arange(half, dtype=torch.float32, device=device)
        inv_timescales = torch.exp(-step * channel_idx)
        positions = torch.arange(seqlen, dtype=torch.float32, device=device)
        angles = positions[:, None] * inv_timescales[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    def forward(self, seqlen: int):
        return self._recompute(seqlen, self.positional_embedding.device)
class BidirLMOmniAudioEncoder(PreTrainedModel):
    """Audio tower: three stride-2 Conv2d layers over mel features, then a Transformer.

    Each clip's frames are cut into chunks of ``2 * n_window`` frames that are
    convolved independently. The resulting tokens of all clips are packed into
    one un-batched sequence. Self-attention is limited to windows of
    ``n_window_infer // (2 * n_window)`` chunks, computed per clip. The windows
    are described by ``cu_seqlens`` and, for non-flash backends, by an additive
    block-diagonal mask.
    """

    config: BidirLMOmniAudioConfig
    main_input_name = "input_features"
    _no_split_modules = ["BidirLMOmniAudioEncoderLayer"]
    _supports_sdpa = True

    def __init__(self, config: BidirLMOmniAudioConfig):
        super().__init__(config)
        self.dropout = config.dropout
        embed_dim = config.d_model
        self.num_mel_bins = config.num_mel_bins
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
        self.n_window = config.n_window
        self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim)
        self.layers = nn.ModuleList([BidirLMOmniAudioEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.ln_post = nn.LayerNorm(config.d_model)
        self.gradient_checkpointing = False
        self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1)
        self.conv2d2 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1)
        self.conv2d3 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1)
        # Mel axis after three stride-2 convolutions, times the channel count.
        self.conv_out = nn.Linear(
            config.downsample_hidden_size * ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2),
            config.d_model,
            bias=False,
        )
        self.proj1 = nn.Linear(config.d_model, config.d_model)
        self.act = ACT2FN[config.activation_function]
        self.proj2 = nn.Linear(config.d_model, config.output_dim)
        self.n_window_infer = self.config.n_window_infer
        self.conv_chunksize = self.config.conv_chunksize
        self.post_init()

    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> Optional[torch.Tensor]:
        """Build a ``(1, 1, seq, seq)`` additive block-diagonal mask from ``cu_seqlens``.

        Entries inside a segment are 0; all others are the dtype's minimum.
        Returns ``None`` for ``flash_attention_2``, which receives ``cu_seqlens``
        through the attention kwargs instead.
        """
        if self.config._attn_implementation == "flash_attention_2":
            return None
        seq_length = inputs_tensor.shape[0]
        attention_mask = torch.full(
            [1, 1, seq_length, seq_length],
            torch.finfo(inputs_tensor.dtype).min,
            device=inputs_tensor.device,
            dtype=inputs_tensor.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
        return attention_mask

    def forward(self, input_features, feature_lens=None, aftercnn_lens=None):
        """Encode audio clips whose frames are concatenated along the time axis.

        Args:
            input_features: 2-D mel features, transposed and split along dim 1
                per clip. Presumably ``(num_mel_bins, total_frames)``; TODO
                confirm against the caller.
            feature_lens: 1-D integer tensor with the frame count of each clip.
            aftercnn_lens: ignored; recomputed from ``feature_lens``. Kept for
                API compatibility.

        Returns:
            ``BaseModelOutput`` whose ``last_hidden_state`` is
            ``(total_tokens, output_dim)``, with all clips' tokens concatenated.
        """
        aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)

        # Split every clip into chunks of ``2 * n_window`` frames. A clip's last
        # chunk holds the remainder, or a full chunk when the length divides evenly.
        chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
        chunk_lengths = torch.tensor(
            [self.n_window * 2] * chunk_num.sum(),
            dtype=torch.long,
            device=feature_lens.device,
        )
        tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
        chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
        chunk_lengths[chunk_lengths == 0] = self.n_window * 2

        chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
        padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2)
        feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths)
        padded_mask_after_cnn = nn.utils.rnn.pad_sequence(
            [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn],
            batch_first=True,
        )
        padded_feature = padded_feature.unsqueeze(1)

        # Convolve in groups of ``conv_chunksize`` chunks to bound peak memory.
        padded_embeds = []
        for chunk in padded_feature.split(self.conv_chunksize, dim=0):
            padded_embed = F.gelu(self.conv2d1(chunk))
            padded_embed = F.gelu(self.conv2d2(padded_embed))
            padded_embed = F.gelu(self.conv2d3(padded_embed))
            padded_embeds.append(padded_embed)
        padded_embed = torch.cat(padded_embeds, dim=0)
        b, c, f, t = padded_embed.size()
        padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f))

        # Call forward() which recomputes sinusoids in float32 (avoids bfloat16 buffer precision loss).
        positional_embedding = (
            self.positional_embedding(padded_embed.shape[1])
            .unsqueeze(0)
            .to(padded_embed.dtype)
        )
        padded_embed = padded_embed + positional_embedding
        # Drop padding positions -> packed (total_tokens, d_model) sequence.
        hidden_states = padded_embed[padded_mask_after_cnn]

        # Attention windows: full windows of ``window_aftercnn`` tokens plus a
        # remainder window per clip, so no window spans two clips.
        cu_chunk_lens = [0]
        window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2))
        for cnn_len in aftercnn_lens:
            cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn)
            remainder = cnn_len % window_aftercnn
            if remainder != 0:
                cu_chunk_lens += [remainder]
        cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32)

        # Without this mask, backends that ignore ``cu_seq_lens_*`` (eager, sdpa)
        # would attend across windows and across clips. For flash_attention_2
        # the helper returns None and ``cu_seqlens`` alone restricts attention.
        attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(hidden_states, cu_seqlens, attention_mask=attention_mask)
            hidden_states = layer_outputs[0]

        hidden_states = self.ln_post(hidden_states)
        hidden_states = self.proj1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.proj2(hidden_states)
        return BaseModelOutput(last_hidden_state=hidden_states)
# ═══════════════════════════════════════════════════════════════════════════
# Vision encoder (copied from BiQwen3-VL)
# ═══════════════════════════════════════════════════════════════════════════
class BidirLMOmniVisionMLP(nn.Module):
    """Two-layer feed-forward block with biases: ``fc2(act(fc1(x)))``."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
        self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        widened = self.linear_fc1(hidden_state)
        activated = self.act_fn(widened)
        return self.linear_fc2(activated)
class BidirLMOmniVisionPatchEmbed(nn.Module):
    """Project flattened spatio-temporal patches with a non-overlapping Conv3d.

    The input is viewed as
    ``(-1, in_channels, temporal_patch_size, patch_size, patch_size)`` and cast
    to the convolution weight's dtype. Because kernel size equals stride, each
    patch yields one ``hidden_size`` vector. Output: ``(num_patches, hidden_size)``.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size
        kernel = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel, stride=kernel, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        patch_shape = (-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
        patches = hidden_states.view(*patch_shape).to(dtype=self.proj.weight.dtype)
        return self.proj(patches).view(-1, self.embed_dim)
class BidirLMOmniVisionRotaryEmbedding(nn.Module):
    """Rotary frequency table for the vision tower.

    ``forward(seqlen)`` returns ``(seqlen, dim // 2)`` angles ``pos * inv_freq``,
    with ``inv_freq`` recomputed in float32 on every call. The ``inv_freq``
    buffer is only used to pick the device.
    """

    inv_freq: torch.Tensor

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.register_buffer("inv_freq", self._inverse_frequencies(dim, theta), persistent=False)

    @staticmethod
    def _inverse_frequencies(dim, theta, device=None):
        # theta ** (-2i / dim) for i in [0, dim / 2), computed in float32.
        exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim
        return 1.0 / (theta ** exponents)

    def forward(self, seqlen: int) -> torch.Tensor:
        device = self.inv_freq.device
        inv_freq = self._inverse_frequencies(self.dim, self.theta, device)
        positions = torch.arange(seqlen, dtype=torch.float32, device=device)
        return torch.outer(positions, inv_freq)
class BidirLMOmniVisionPatchMerger(nn.Module):
    """Fold groups of ``spatial_merge_size ** 2`` consecutive patch rows into one token.

    With ``use_postshuffle_norm=False``, LayerNorm (width ``hidden_size``) is
    applied per patch before the rows are regrouped. With ``True``, it is
    applied to the regrouped ``hidden_size * merge**2`` vectors. An MLP then
    maps each merged vector to ``out_hidden_size``.
    """

    def __init__(self, config: BidirLMOmniVisionConfig, use_postshuffle_norm=False) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
        self.use_postshuffle_norm = use_postshuffle_norm
        norm_width = self.hidden_size if use_postshuffle_norm else config.hidden_size
        self.norm = nn.LayerNorm(norm_width, eps=1e-6)
        self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.act_fn = nn.GELU()
        self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_postshuffle_norm:
            normed = self.norm(x.view(-1, self.hidden_size))
        else:
            normed = self.norm(x)
        merged = normed.view(-1, self.hidden_size)
        return self.linear_fc2(self.act_fn(self.linear_fc1(merged)))
def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings to vision queries/keys in float32.

    ``cos``/``sin`` get a singleton axis at position -2, so they broadcast
    across the heads axis of ``(seq, heads, head_dim)`` inputs. The results are
    cast back to the original dtypes of ``q`` and ``k``.
    """
    cos_f = cos.unsqueeze(-2).float()
    sin_f = sin.unsqueeze(-2).float()

    def _rotate(tensor):
        as_fp32 = tensor.float()
        return ((as_fp32 * cos_f) + (rotate_half(as_fp32) * sin_f)).to(tensor.dtype)

    return _rotate(q), _rotate(k)
class BidirLMOmniVisionAttention(nn.Module):
    """Bidirectional multi-head self-attention over packed vision patches.

    ``hidden_states`` is a flat ``(seq_len, hidden_size)`` sequence and
    ``cu_seqlens`` marks independent segments. For ``flash_attention_2`` the
    boundaries are passed as ``cu_seq_lens_*``. Every other backend is called
    once per segment, so tokens never attend across segment boundaries.
    """

    def __init__(self, config: BidirLMOmniVisionConfig) -> None:
        super().__init__()
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.num_key_value_groups = 1
        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
        self.proj = nn.Linear(self.dim, self.dim)
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = 0.0
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        fused = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1)
        q, k, v = fused.permute(1, 0, 2, 3).unbind(0)

        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

        # (seq, heads, head_dim) -> (1, heads, seq, head_dim)
        q, k, v = (t.transpose(0, 1).unsqueeze(0) for t in (q, k, v))

        backend = self.config._attn_implementation
        attend: Callable = eager_attention_forward if backend == "eager" else ALL_ATTENTION_FUNCTIONS[backend]
        dropout_p = self.attention_dropout if self.training else 0.0

        if backend == "flash_attention_2":
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
            attn_output, _ = attend(
                self,
                q,
                k,
                v,
                attention_mask=None,
                scaling=self.scaling,
                dropout=dropout_p,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=self.is_causal,
                **kwargs,
            )
        else:
            segment_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            segments = zip(
                q.split(segment_lengths, dim=2),
                k.split(segment_lengths, dim=2),
                v.split(segment_lengths, dim=2),
            )
            per_segment_outputs = []
            for q_seg, k_seg, v_seg in segments:
                result = attend(
                    self,
                    q_seg,
                    k_seg,
                    v_seg,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=dropout_p,
                    is_causal=self.is_causal,
                    **kwargs,
                )
                per_segment_outputs.append(result[0])
            attn_output = torch.cat(per_segment_outputs, dim=1)

        flat = attn_output.reshape(seq_length, -1).contiguous()
        return self.proj(flat)
class BidirLMOmniVisionBlock(GradientCheckpointingLayer):
    """Pre-norm vision Transformer block: ``x + attn(norm1(x))``, then ``x + mlp(norm2(x))``.

    ``attn_implementation`` is accepted but not used here. The attention module
    reads the backend from ``config._attn_implementation``.
    """

    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.attn = BidirLMOmniVisionAttention(config=config)
        self.mlp = BidirLMOmniVisionMLP(config=config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        attn_delta = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_delta
        mlp_delta = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_delta
class BidirLMOmniVisionModel(PreTrainedModel):
    """Vision tower: patch embedding, Transformer blocks, and a spatial patch merger.

    ``forward`` returns the merged visual tokens together with a list of
    DeepStack features. Those are taken after each block whose index appears in
    ``config.deepstack_visual_indexes``, and each is passed through its own
    post-shuffle-norm merger.
    """

    config: BidirLMOmniVisionConfig
    _no_split_modules = ["BidirLMOmniVisionBlock"]

    def __init__(self, config, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        # Number of patches folded into one merged token.
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
        self.patch_embed = BidirLMOmniVisionPatchEmbed(config=config)
        # Learned absolute positions on a square grid of side ``num_grid_per_side``.
        # They are bilinearly resampled to each image's patch grid in
        # ``fast_pos_embed_interpolate``.
        self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
        self.num_grid_per_side = int(config.num_position_embeddings**0.5)
        head_dim = config.hidden_size // config.num_heads
        # Rotary table over head_dim // 2 channels. Row and column coordinates each
        # index it in ``rot_pos_emb``, and ``forward`` duplicates the concatenated
        # result so it spans head_dim.
        self.rotary_pos_emb = BidirLMOmniVisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList([BidirLMOmniVisionBlock(config) for _ in range(config.depth)])
        self.merger = BidirLMOmniVisionPatchMerger(config=config, use_postshuffle_norm=False)
        self.deepstack_visual_indexes = config.deepstack_visual_indexes
        # One merger per DeepStack tap, aligned with ``deepstack_visual_indexes`` order.
        self.deepstack_merger_list = nn.ModuleList(
            [
                BidirLMOmniVisionPatchMerger(config=config, use_postshuffle_norm=True)
                for _ in range(len(config.deepstack_visual_indexes))
            ]
        )
        self.gradient_checkpointing = False

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Return per-token 2-D rotary angles, shape ``(total_tokens, 2 * freq_dim)``.

        ``grid_thw`` rows are ``(frames, height, width)`` in patches. Tokens are
        enumerated merge block by merge block (row-major over blocks), row-major
        within each ``merge_size x merge_size`` block, and the same layout is
        repeated for every frame.
        """
        merge_size = self.spatial_merge_size
        # One table large enough for the tallest/widest grid in the batch.
        max_hw = int(grid_thw[:, 1:].max().item())
        freq_table = self.rotary_pos_emb(max_hw)
        device = freq_table.device
        total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
        # (row, col) patch coordinates for every token.
        pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
        offset = 0
        for num_frames, height, width in grid_thw:
            merged_h, merged_w = height // merge_size, width // merge_size
            block_rows = torch.arange(merged_h, device=device)
            block_cols = torch.arange(merged_w, device=device)
            intra_row = torch.arange(merge_size, device=device)
            intra_col = torch.arange(merge_size, device=device)
            # Broadcast to (merged_h, merged_w, merge_size, merge_size) so the
            # flattened order matches the merger's grouping of consecutive patches.
            row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
            col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
            row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
            col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
            coords = torch.stack((row_idx, col_idx), dim=-1)
            if num_frames > 1:
                coords = coords.repeat(num_frames, 1)
            num_tokens = coords.shape[0]
            pos_ids[offset : offset + num_tokens] = coords
            offset += num_tokens
        # (tokens, 2, freq_dim) -> (tokens, 2 * freq_dim): row angles then column angles.
        embeddings = freq_table[pos_ids]
        return embeddings.flatten(1)

    def fast_pos_embed_interpolate(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Bilinearly resample the learned position grid to each image's patch grid.

        For every image, ``h x w`` sample points are spread evenly over
        ``[0, num_grid_per_side - 1]`` on each axis. Each point mixes its four
        neighbouring ``pos_embed`` entries. The result is repeated for each frame
        and reordered into the same merge-block order as ``rot_pos_emb``.
        Returns ``(total_tokens, hidden_size)``.
        """
        grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
        # Four corner lists: (floor_h, floor_w), (floor_h, ceil_w), (ceil_h, floor_w), (ceil_h, ceil_w).
        idx_list = [[] for _ in range(4)]
        weight_list = [[] for _ in range(4)]
        for t, h, w in zip(grid_ts, grid_hs, grid_ws):
            h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
            w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
            h_idxs_floor = h_idxs.int()
            w_idxs_floor = w_idxs.int()
            h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            # Fractional offsets used as bilinear weights.
            dh = h_idxs - h_idxs_floor
            dw = w_idxs - w_idxs_floor
            # Row offsets into the flattened (grid x grid) embedding table.
            base_h = h_idxs_floor * self.num_grid_per_side
            base_h_ceil = h_idxs_ceil * self.num_grid_per_side
            indices = [
                (base_h[None].T + w_idxs_floor[None]).flatten(),
                (base_h[None].T + w_idxs_ceil[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
            ]
            weights = [
                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
                ((1 - dh)[None].T * dw[None]).flatten(),
                (dh[None].T * (1 - dw)[None]).flatten(),
                (dh[None].T * dw[None]).flatten(),
            ]
            # NOTE(review): accumulated via Python lists (.tolist()); fine for the
            # image sizes seen here but a device round-trip per image.
            for i in range(4):
                idx_list[i].extend(indices[i].tolist())
                weight_list[i].extend(weights[i].tolist())
        idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
        weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device)
        # (4, total_hw, hidden) weighted corners, summed into one embedding per patch.
        pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
        patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
        merge_size = self.config.spatial_merge_size
        patch_pos_embeds_permute = []
        for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
            # Same spatial embedding for every frame, then row-major -> merge-block order.
            pos_embed = pos_embed.repeat(t, 1)
            pos_embed = (
                pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
                .permute(0, 1, 3, 2, 4, 5)
                .flatten(0, 4)
            )
            patch_pos_embeds_permute.append(pos_embed)
        return torch.cat(patch_pos_embeds_permute)

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Encode packed image/video patches.

        Args:
            hidden_states: flattened patches for all images, consumed by ``patch_embed``.
            grid_thw: ``(num_images, 3)`` tensor of ``(frames, height, width)`` in patches.

        Returns:
            ``(merged_tokens, deepstack_features)``. The second element is a list
            with one merged tensor per entry of ``deepstack_visual_indexes``, in
            block order.
        """
        hidden_states = self.patch_embed(hidden_states)
        hidden_states = hidden_states + self.fast_pos_embed_interpolate(grid_thw)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        seq_len = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())
        # One attention segment per frame: h * w tokens, repeated ``t`` times per image.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        deepstack_feature_lists = []
        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            if layer_num in self.deepstack_visual_indexes:
                deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](
                    hidden_states
                )
                deepstack_feature_lists.append(deepstack_feature)
        hidden_states = self.merger(hidden_states)
        return hidden_states, deepstack_feature_lists
# ═══════════════════════════════════════════════════════════════════════════
# Shared text encoder (supports both audio injection + DeepStack visual)
# ═══════════════════════════════════════════════════════════════════════════
class BidirLMOmniTextRotaryEmbedding(nn.Module):
    """Interleaved multimodal rotary embedding for the text encoder.

    Produces ``(cos, sin)`` from three position streams, at indices 0/1/2 of
    ``position_ids`` (presumably temporal/height/width). ``mrope_section[1]`` and
    ``mrope_section[2]`` set how many frequency slots streams 1 and 2 take over.
    Those slots are interleaved with stride 3 into stream 0's frequencies; see
    ``apply_interleaved_mrope``.
    """

    inv_freq: torch.Tensor

    def __init__(self, config: BidirLMOmniTextConfig, device=None):
        super().__init__()
        # ``rope_scaling`` may be absent or None; both mean "no scaling". Read it
        # once so the rope_type and mrope_section lookups agree.
        rope_scaling = getattr(config, "rope_scaling", None) or {}
        self.rope_type = rope_scaling.get("rope_type", "default")
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        # transformers 5.x removed "default" from ROPE_INIT_FUNCTIONS. For plain
        # RoPE, and for types unknown to the installed version, compute inv_freq locally.
        if self.rope_type == "default" or self.rope_type not in ROPE_INIT_FUNCTIONS:
            inv_freq, self.attention_scaling = self.compute_default_rope_parameters(config, device)
        else:
            self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq
        self.mrope_section = rope_scaling.get("mrope_section", [24, 20, 20])

    def compute_default_rope_parameters(self, config=None, device=None):
        """Required by transformers 5.x _init_weights when rope_type='default'.

        Returns ``(inv_freq, 1.0)``, where ``inv_freq = rope_theta ** (-2i / head_dim)``
        in float32. ``device`` is optional (added so ``__init__`` can reuse this).
        """
        config = config or self.config
        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        rope_theta = getattr(config, "rope_theta", 10000.0)
        inv_freq = 1.0 / (
            rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) / head_dim)
        )
        return inv_freq, 1.0

    def apply_interleaved_mrope(self, freqs, mrope_section):
        """Merge three frequency streams into one, interleaved with stride 3.

        ``freqs`` is ``(3, ..., F)``. Slots ``1, 4, 7, ...`` below
        ``3 * mrope_section[1]`` take stream 1, and slots ``2, 5, 8, ...`` below
        ``3 * mrope_section[2]`` take stream 2. All other slots keep stream 0.
        Writes into ``freqs[0]`` in place and returns it.
        """
        freqs_t = freqs[0]
        for dim, offset in enumerate((1, 2), start=1):
            length = mrope_section[dim] * 3
            idx = slice(offset, length, 3)
            freqs_t[..., idx] = freqs[dim, ..., idx]
        return freqs_t

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` of shape ``(batch, seq, head_dim)`` in ``x.dtype``.

        ``position_ids`` is either ``(3, batch, seq)`` or ``(batch, seq)``; the
        latter is broadcast to all three streams.
        """
        if position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        # Autocast disabled so the angle computation stays in float32.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class BidirLMOmniTextAttention(nn.Module):
    """Bidirectional multi-head attention without a causal mask or KV cache.

    Uses grouped-query attention: ``num_key_value_heads`` K/V heads are shared
    by ``num_attention_heads`` query heads. Queries and keys get a per-head
    RMSNorm and then the caller-supplied rotary ``(cos, sin)``.
    """

    def __init__(self, config: BidirLMOmniTextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False
        q_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        use_bias = config.attention_bias
        self.q_proj = nn.Linear(config.hidden_size, q_width, bias=use_bias)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(q_width, config.hidden_size, bias=use_bias)
        self.q_norm = BidirLMOmniRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = BidirLMOmniRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.sliding_window = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        input_shape = hidden_states.shape[:-1]
        per_head_shape = (*input_shape, -1, self.head_dim)

        def _split_heads(projection, norm=None):
            # (..., seq, heads * head_dim) -> (..., heads, seq, head_dim)
            heads = projection(hidden_states).view(per_head_shape)
            if norm is not None:
                heads = norm(heads)
            return heads.transpose(1, 2)

        query_states = _split_heads(self.q_proj, self.q_norm)
        key_states = _split_heads(self.k_proj, self.k_norm)
        value_states = _split_heads(self.v_proj)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        backend = self.config._attn_implementation
        attend: Callable = eager_attention_forward if backend == "eager" else ALL_ATTENTION_FUNCTIONS[backend]
        attn_output, _ = attend(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )
        merged = attn_output.reshape(*input_shape, -1).contiguous()
        return self.o_proj(merged)
class BidirLMOmniTextMLP(nn.Module):
    """Gated feed-forward block without biases: ``down(act(gate(x)) * up(x))``."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class BidirLMOmniTextEncoderLayer(GradientCheckpointingLayer):
    """Pre-norm text encoder block.

    Computes ``x + attn(RMSNorm(x))``, then ``x + mlp(RMSNorm(x))``, with
    separate norms for the two sub-layers.
    """

    def __init__(self, config: BidirLMOmniTextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = BidirLMOmniTextAttention(config=config, layer_idx=layer_idx)
        self.mlp = BidirLMOmniTextMLP(config)
        self.input_layernorm = BidirLMOmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = BidirLMOmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        attn_delta = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_delta
        mlp_delta = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_delta
# ═══════════════════════════════════════════════════════════════════════════
# PreTrainedModel base
# ═══════════════════════════════════════════════════════════════════════════
@auto_docstring
class BidirLMOmniPreTrainedModel(PreTrainedModel):
    # Shared base class for every BidirLM-Omni module: binds the composite config
    # type and the "model" weight prefix, and declares Transformers capabilities.
    config: BidirLMOmniConfig
    base_model_prefix = "model"

    # Attention back-ends this model family declares support for.
    _supports_attention_backend = True
    _supports_flash_attn = True
    _supports_sdpa = True

    # Training / device-placement behaviour: gradient checkpointing is allowed, and
    # these layer classes must not be split across devices when sharding.
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "BidirLMOmniTextEncoderLayer",
        "BidirLMOmniAudioEncoderLayer",
        "BidirLMOmniVisionBlock",
    ]
# ═══════════════════════════════════════════════════════════════════════════
# Text encoder model (with DeepStack visual injection support)
# ═══════════════════════════════════════════════════════════════════════════
class BidirLMOmniTextModel(BidirLMOmniPreTrainedModel):
    """
    Bidirectional text encoder. Supports:
    - pre-merged ``inputs_embeds`` (e.g. with audio/visual features already scattered in by the caller)
    - DeepStack visual feature injection at intermediate layers
    """

    config: BidirLMOmniTextConfig
    _no_split_modules = ["BidirLMOmniTextEncoderLayer"]

    def __init__(self, config: BidirLMOmniTextConfig):
        super().__init__(config)
        self.padding_idx = getattr(config, "pad_token_id", None)
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [BidirLMOmniTextEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = BidirLMOmniRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = BidirLMOmniTextRotaryEmbedding(config)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding table."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding table."""
        self.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        # DeepStack visual injection args
        visual_pos_masks: Optional[torch.Tensor] = None,
        deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        """Encode a sequence with full (bidirectional) self-attention.

        Args:
            input_ids: Token ids; mutually exclusive with ``inputs_embeds``.
            attention_mask: 1/0 padding mask of shape (batch, seq). Passed as-is to
                flash-attention-2, otherwise converted to an additive float mask.
            position_ids: ``(seq,)``, ``(batch, seq)`` or MRoPE ``(3, batch, seq)``.
                Lower-rank inputs are broadcast to ``(3, batch, seq)``; defaults to
                ``arange(seq)`` for every row.
            inputs_embeds: Pre-computed embeddings; mutually exclusive with ``input_ids``.
            visual_pos_masks: Boolean mask selecting the visual positions that receive
                DeepStack features.
            deepstack_visual_embeds: Per-layer visual features; entry ``k`` is added at the
                visual positions after encoder layer ``k``.

        Returns:
            ``BaseModelOutput`` whose ``last_hidden_state`` is the final RMS-normed output.

        Raises:
            ValueError: if not exactly one of ``input_ids`` / ``inputs_embeds`` is given, or
                if DeepStack features are given without ``visual_pos_masks``.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds.")
        if (
            deepstack_visual_embeds is not None
            and len(deepstack_visual_embeds) > 0
            and visual_pos_masks is None
        ):
            raise ValueError("deepstack_visual_embeds requires visual_pos_masks.")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        batch_size, seq_len = inputs_embeds.shape[:2]

        # Normalise position ids to the (3, batch, seq) MRoPE layout.
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=inputs_embeds.device)
        if position_ids.ndim == 1:
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
        if position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, batch_size, -1)

        extended_attention_mask: Optional[torch.Tensor] = None
        if attention_mask is not None:
            if self.config._attn_implementation == "flash_attention_2":
                # Flash attention computes cu_seqlens from a 2D mask internally;
                # passing a 4D mask breaks the varlen path.
                extended_attention_mask = attention_mask
            else:
                # Additive float mask (0.0 = attend, dtype-min = ignore), broadcast over
                # heads and query positions so padding keys are excluded from softmax.
                float_mask = attention_mask.to(dtype=inputs_embeds.dtype)
                extended_attention_mask = (
                    (1.0 - float_mask)[:, None, None, :] * torch.finfo(inputs_embeds.dtype).min
                )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        for layer_idx, encoder_layer in enumerate(self.layers):
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=extended_attention_mask,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            # DeepStack: add visual features at intermediate layers
            if deepstack_visual_embeds is not None and layer_idx < len(deepstack_visual_embeds):
                hidden_states = self._deepstack_process(
                    hidden_states,
                    visual_pos_masks,
                    deepstack_visual_embeds[layer_idx],
                )
        hidden_states = self.norm(hidden_states)
        return BaseModelOutput(last_hidden_state=hidden_states)

    def _deepstack_process(
        self,
        hidden_states: torch.Tensor,
        visual_pos_masks: torch.Tensor,
        visual_embeds: torch.Tensor,
    ) -> torch.Tensor:
        """Add ``visual_embeds`` to ``hidden_states`` at the positions selected by the mask.

        Both inputs are moved to ``hidden_states``' device (and dtype for the embeddings)
        first. ``hidden_states`` is updated in place and returned.
        """
        visual_pos_masks = visual_pos_masks.to(hidden_states.device)
        visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
        hidden_states[visual_pos_masks, :] = hidden_states[visual_pos_masks, :].clone() + visual_embeds
        return hidden_states
# ═══════════════════════════════════════════════════════════════════════════
# Top-level Omni model
# ═══════════════════════════════════════════════════════════════════════════
@auto_docstring(
    custom_intro="Multimodal encoder combining audio tower, vision tower, and shared bidirectional text encoder."
)
class BidirLMOmniModel(BidirLMOmniPreTrainedModel):
    """
    Audio + Vision + Text omni encoder.
    Accepts any combination of modalities: text-only, text+audio, text+vision, text+audio+vision.
    """

    config: BidirLMOmniConfig

    def __init__(self, config: BidirLMOmniConfig):
        super().__init__(config)
        # Flash/SDPA attention only applies to the text encoder;
        # audio and vision towers always run eager (no causal masking or varlen path needed).
        config.audio_config._attn_implementation = "eager"
        config.vision_config._attn_implementation = "eager"
        config.text_config._attn_implementation = config._attn_implementation
        self.audio_tower = BidirLMOmniAudioEncoder._from_config(config.audio_config)
        self.visual = BidirLMOmniVisionModel._from_config(config.vision_config)
        self.language_model = BidirLMOmniTextModel._from_config(config.text_config)
        self.post_init()

    def get_input_embeddings(self):
        """Return the text encoder's token embedding table."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the text encoder's token embedding table."""
        self.language_model.set_input_embeddings(value)

    # ── Audio helpers ──────────────────────────────────────────────────
    def get_audio_features(
        self,
        input_features: torch.FloatTensor,
        feature_attention_mask: Optional[torch.LongTensor] = None,
        audio_feature_lengths: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Encode each audio clip separately and concatenate the results.

        Each item of ``input_features`` (iterated over dim 0) is truncated along its
        dim 1 to its valid length and passed to the audio tower on its own.

        Valid lengths come from the row sums of ``feature_attention_mask`` when given,
        otherwise from ``audio_feature_lengths``. If neither is given, every clip is
        treated as full length (``input_features.shape[2]``).

        NOTE(review): presumably each item is (n_mels, frames); confirm with the processor.

        Returns:
            The per-clip ``last_hidden_state`` tensors concatenated along dim 0.
        """
        if feature_attention_mask is not None:
            feature_lens = torch.sum(feature_attention_mask, dim=1)
        elif audio_feature_lengths is not None:
            feature_lens = audio_feature_lengths
        else:
            # No mask and no explicit lengths: use every frame of every clip.
            feature_lens = torch.full(
                (input_features.shape[0],),
                input_features.shape[2],
                dtype=torch.long,
                device=input_features.device,
            )
        audio_features = []
        for input_feature, feature_len in zip(input_features, feature_lens):
            audio_output = self.audio_tower(
                input_feature[:, :feature_len],
                feature_lens=feature_len.unsqueeze(0),
            )
            audio_features.append(audio_output.last_hidden_state)
        return torch.cat(audio_features, dim=0)

    def get_audio_placeholder_mask(
        self,
        input_ids: Optional[torch.LongTensor],
        inputs_embeds: torch.FloatTensor,
        audio_features: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Boolean mask of audio placeholder positions, expanded to ``inputs_embeds``' shape.

        Placeholders are found by token id when ``input_ids`` is given. Otherwise they
        are found by exact equality with the ``audio_token_id`` embedding.

        Raises:
            ValueError: if ``audio_features`` is given and its element count does not
                match the number of masked embedding elements.
        """
        if input_ids is None:
            special_audio_mask = (
                inputs_embeds
                == self.get_input_embeddings()(
                    torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
            ).all(-1)
        else:
            special_audio_mask = input_ids == self.config.audio_token_id
        special_audio_mask_expanded = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        # masked_scatter silently ignores surplus source elements, so check counts explicitly.
        if audio_features is not None and inputs_embeds[special_audio_mask_expanded].numel() != audio_features.numel():
            n_audio_tokens = special_audio_mask.sum()
            raise ValueError(
                f"Audio features and audio tokens do not match: tokens {n_audio_tokens}, features {audio_features.shape[0]}"
            )
        return special_audio_mask_expanded

    # ── Vision helpers ─────────────────────────────────────────────────
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        """Run the vision tower and split its output per image.

        Returns:
            ``(image_embeds, deepstack_image_embeds)``. ``image_embeds`` is a tuple with one
            tensor per ``image_grid_thw`` row, of length ``t*h*w // spatial_merge_size**2``.
            ``deepstack_image_embeds`` is the vision tower's second output, unchanged.
        """
        pixel_values = pixel_values.type(self.visual.dtype)
        image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        image_embeds = torch.split(image_embeds, split_sizes)
        return image_embeds, deepstack_image_embeds

    def get_video_features(
        self,
        pixel_values_videos: torch.FloatTensor,
        video_grid_thw: Optional[torch.LongTensor] = None,
    ):
        """Encode video frames with the same path as images."""
        return self.get_image_features(pixel_values_videos, video_grid_thw)

    def get_vision_placeholder_mask(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor,
        image_features: Optional[torch.FloatTensor] = None,
        video_features: Optional[torch.FloatTensor] = None,
    ):
        """Return ``(image_mask, video_mask)`` expanded to ``inputs_embeds``' shape.

        Raises:
            ValueError: if given features do not match the placeholder element count.
        """
        if input_ids is None:
            img_embed = self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            vid_embed = self.get_input_embeddings()(
                torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = (inputs_embeds == img_embed).all(-1)
            special_video_mask = (inputs_embeds == vid_embed).all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id
            special_video_mask = input_ids == self.config.video_token_id
        special_image_mask_expanded = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        special_video_mask_expanded = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if image_features is not None and inputs_embeds[special_image_mask_expanded].numel() != image_features.numel():
            n_image_tokens = special_image_mask.sum()
            raise ValueError(
                f"Image features and image tokens do not match: tokens {n_image_tokens}, features {image_features.shape[0]}"
            )
        if video_features is not None and inputs_embeds[special_video_mask_expanded].numel() != video_features.numel():
            n_video_tokens = special_video_mask.sum()
            raise ValueError(
                f"Video features and video tokens do not match: tokens {n_video_tokens}, features {video_features.shape[0]}"
            )
        return special_image_mask_expanded, special_video_mask_expanded

    # ── MRoPE position ids ─────────────────────────────────────────────
    def get_rope_index(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Build 3-D MRoPE position ids. Returns (3, batch, seq_len).

        With ``input_ids`` and at least one grid:
        - text tokens advance all three axes together;
        - each image/video placeholder run gets (t, h, w) grid positions offset past the
          preceding text;
        - positions where ``attention_mask != 1`` are left at 1.

        Otherwise (text/audio only), positions are ``cumsum(attention_mask) - 1`` with
        padding set to 1, or ``arange(seq_len)`` when no mask is given. Either way they are
        broadcast to all three axes. This fallback needs ``input_ids`` or ``attention_mask``.

        Each video grid row is split into ``t`` rows with ``t`` set to 1 before processing.
        """
        if video_grid_thw is not None:
            video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
            video_grid_thw[:, 0] = 1
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            # Index position_ids (on input_ids' device) with a mask on the same device.
            attention_mask = attention_mask.to(total_input_ids.device)
            position_ids = torch.ones(
                3, input_ids.shape[0], input_ids.shape[1],
                dtype=input_ids.dtype, device=input_ids.device,
            )
            image_index, video_index = 0, 0
            for i, ids in enumerate(total_input_ids):
                ids = ids[attention_mask[i] == 1]
                vision_start_indices = torch.argwhere(ids == vision_start_token_id).squeeze(1)
                vision_tokens = ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                for _ in range(image_nums + video_nums):
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        t, h, w = image_grid_thw[image_index]
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image
                    else:
                        t, h, w = video_grid_thw[video_index]
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    llm_grid_t = t.item()
                    llm_grid_h = h.item() // spatial_merge_size
                    llm_grid_w = w.item() // spatial_merge_size
                    text_len = ed - st
                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
                if st < len(input_tokens):
                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                    llm_pos_ids_list.append(torch.arange(len(input_tokens) - st).view(1, -1).expand(3, -1) + st_idx)
                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
            return position_ids
        # Text-only / audio-only path (no spatial position structure)
        if attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
        else:
            position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand(input_ids.shape[0], -1)
        return position_ids.unsqueeze(0).expand(3, -1, -1)

    # ── Forward ────────────────────────────────────────────────────────
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        # Audio inputs
        input_features: Optional[torch.FloatTensor] = None,
        feature_attention_mask: Optional[torch.LongTensor] = None,
        audio_feature_lengths: Optional[torch.LongTensor] = None,
        # Vision inputs
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        """Scatter audio/vision features into the token embeddings and run the text encoder.

        Steps:
        - Audio, image and video features replace the embeddings at their placeholder tokens.
        - Vision DeepStack features are collected for injection inside the text encoder.
        - When ``position_ids`` is not given, MRoPE positions are derived via
          ``get_rope_index``, provided ``input_ids`` or ``attention_mask`` is available.

        Raises:
            ValueError: if not exactly one of ``input_ids`` / ``inputs_embeds`` is given, or
                if placeholder-token counts do not match the encoded feature counts.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds.")
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # ── Audio injection ────────────────────────────────────────────
        if input_features is not None:
            audio_features = self.get_audio_features(
                input_features,
                feature_attention_mask=feature_attention_mask,
                audio_feature_lengths=audio_feature_lengths,
            ).to(inputs_embeds.device, inputs_embeds.dtype)
            audio_mask = self.get_audio_placeholder_mask(input_ids, inputs_embeds, audio_features=audio_features)
            inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)

        # ── Vision injection ───────────────────────────────────────────
        image_mask = video_mask = None
        if pixel_values is not None:
            image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw)
            image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            image_mask, _ = self.get_vision_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
        if pixel_values_videos is not None:
            video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
            video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            _, video_mask = self.get_vision_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds)
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        # ── Assemble DeepStack masks / embeds ──────────────────────────
        visual_pos_masks = deepstack_visual_embeds = None
        if image_mask is not None and video_mask is not None:
            # Interleave image and video DeepStack features in sequence order over the
            # union of visual positions.
            im = image_mask[..., 0]
            vm = video_mask[..., 0]
            visual_pos_masks = im | vm
            image_mask_joint = im[visual_pos_masks]
            video_mask_joint = vm[visual_pos_masks]
            deepstack_visual_embeds = []
            for img_e, vid_e in zip(deepstack_image_embeds, deepstack_video_embeds):
                joint = img_e.new_zeros(visual_pos_masks.sum(), img_e.shape[-1]).to(img_e.device)
                joint[image_mask_joint] = img_e
                joint[video_mask_joint] = vid_e
                deepstack_visual_embeds.append(joint)
        elif image_mask is not None:
            visual_pos_masks = image_mask[..., 0]
            deepstack_visual_embeds = deepstack_image_embeds
        elif video_mask is not None:
            visual_pos_masks = video_mask[..., 0]
            deepstack_visual_embeds = deepstack_video_embeds

        # ── Build position ids ─────────────────────────────────────────
        # get_rope_index needs input_ids or attention_mask. With neither (inputs_embeds-only),
        # leave position_ids as None so the text encoder builds arange(seq_len) itself.
        if position_ids is None and (input_ids is not None or attention_mask is not None):
            position_ids = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)

        return self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            visual_pos_masks=visual_pos_masks,
            deepstack_visual_embeds=deepstack_visual_embeds,
            **kwargs,
        )
# ═══════════════════════════════════════════════════════════════════════════
# Masked language model head
# ═══════════════════════════════════════════════════════════════════════════
@auto_docstring
class BidirLMOmniForMaskedLM(BidirLMOmniPreTrainedModel):
    # Omni encoder with a bias-free vocabulary projection on top. The projection
    # weight is tied to the text token-embedding matrix.
    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

    def __init__(self, config: BidirLMOmniConfig):
        super().__init__(config)
        self.model = BidirLMOmniModel(config)
        text_cfg = config.text_config
        self.lm_head = nn.Linear(text_cfg.hidden_size, text_cfg.vocab_size, bias=False)
        self.post_init()

    def get_output_embeddings(self):
        """Return the vocabulary projection."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the vocabulary projection."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        pixel_values_videos: Optional[torch.Tensor] = None,
        input_features: Optional[torch.FloatTensor] = None,
        feature_attention_mask: Optional[torch.LongTensor] = None,
        audio_feature_lengths: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> MaskedLMOutput:
        """Predict vocabulary logits at every position.

        All multimodal inputs are forwarded to the omni encoder. When ``labels`` is given,
        the loss is computed by ``self.loss_function`` with the text vocabulary size.
        """
        outputs = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            input_features=input_features,
            feature_attention_mask=feature_attention_mask,
            audio_feature_lengths=audio_feature_lengths,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            **kwargs,
        )
        logits = self.lm_head(outputs.last_hidden_state)
        if labels is None:
            loss = None
        else:
            loss = self.loss_function(logits, labels, vocab_size=self.config.text_config.vocab_size)
        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# ═══════════════════════════════════════════════════════════════════════════
# Sequence classification head
# ═══════════════════════════════════════════════════════════════════════════
@auto_docstring
class BidirLMOmniForSequenceClassification(BidirLMOmniPreTrainedModel):
    # Pooling strategies accepted in config.clf_pooling.
    _POOLING_MODES = ("bos", "mean", "late")

    def __init__(self, config: BidirLMOmniConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.clf_pooling = config.clf_pooling
        self.model = BidirLMOmniModel(config)
        self.dense = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
        self.activation = nn.GELU()
        self.classifier = nn.Linear(config.text_config.hidden_size, self.num_labels)
        self.post_init()

    @staticmethod
    def _masked_mean(values: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Average ``values`` over dim 1, counting only positions where ``attention_mask`` is non-zero.

        With no mask, this is a plain mean over dim 1. The mask is moved to ``values``' device.
        Rows with no attended positions yield zeros instead of NaN (the count is clamped to >= 1).
        """
        if attention_mask is None:
            return values.mean(dim=1)
        mask = attention_mask.to(values.device).unsqueeze(-1)
        summed = (values * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1)
        return summed / counts

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        pixel_values_videos: Optional[torch.Tensor] = None,
        input_features: Optional[torch.FloatTensor] = None,
        feature_attention_mask: Optional[torch.LongTensor] = None,
        audio_feature_lengths: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> SequenceClassifierOutput:
        """Classify (or regress) each sequence from pooled encoder states.

        ``config.clf_pooling`` selects the pooling:
            - ``"bos"``: hidden state at position 0, then ``dense -> GELU -> classifier``.
            - ``"mean"``: masked mean of hidden states, then ``dense -> GELU -> classifier``.
            - ``"late"``: ``dense -> GELU -> classifier`` per token, then masked mean of logits.

        The loss follows ``config.problem_type``, which is inferred and stored on the config
        if unset: MSE for regression, cross-entropy for single-label, BCE-with-logits for
        multi-label.

        Raises:
            ValueError: if ``config.clf_pooling`` is not one of the modes above.
        """
        # An unknown mode previously fell through to the per-token path and returned
        # (batch, seq, num_labels) logits; fail fast instead.
        if self.clf_pooling not in self._POOLING_MODES:
            raise ValueError(
                f"Unsupported clf_pooling {self.clf_pooling!r}; expected one of {self._POOLING_MODES}."
            )
        encoder_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            pixel_values_videos=pixel_values_videos,
            input_features=input_features,
            feature_attention_mask=feature_attention_mask,
            audio_feature_lengths=audio_feature_lengths,
            **kwargs,
        )
        last_hidden_state = encoder_output.last_hidden_state
        if self.clf_pooling == "bos":
            pooled = last_hidden_state[:, 0]
        elif self.clf_pooling == "mean":
            pooled = self._masked_mean(last_hidden_state, attention_mask)
        else:  # "late" — project each token then mean-pool the logits
            pooled = last_hidden_state
        pooled = self.dense(pooled)
        pooled = self.activation(pooled)
        logits = self.classifier(pooled)
        if self.clf_pooling == "late":
            logits = self._masked_mean(logits, attention_mask)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.squeeze(), labels.squeeze())
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_output.hidden_states,
            attentions=encoder_output.attentions,
        )
# ═══════════════════════════════════════════════════════════════════════════
# Token classification head
# ═══════════════════════════════════════════════════════════════════════════
@auto_docstring
class BidirLMOmniForTokenClassification(BidirLMOmniPreTrainedModel):
    def __init__(self, config: BidirLMOmniConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = BidirLMOmniModel(config)
        self.classifier = nn.Linear(config.text_config.hidden_size, self.num_labels)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        pixel_values_videos: Optional[torch.Tensor] = None,
        input_features: Optional[torch.FloatTensor] = None,
        feature_attention_mask: Optional[torch.LongTensor] = None,
        audio_feature_lengths: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> TokenClassifierOutput:
        """Predict ``num_labels`` logits for every position.

        When ``labels`` is given, the loss is token-level cross-entropy over all positions.
        Positions labelled -100 (``CrossEntropyLoss``'s default ``ignore_index``) are ignored.
        """
        encoder_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            pixel_values_videos=pixel_values_videos,
            input_features=input_features,
            feature_attention_mask=feature_attention_mask,
            audio_feature_lengths=audio_feature_lengths,
            **kwargs,
        )
        logits = self.classifier(encoder_output.last_hidden_state)
        loss = None
        if labels is not None:
            # Labels can be on a different device than the classifier output (e.g. when
            # sharded with device_map); align them as the sequence-classification head does.
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_output.hidden_states,
            attentions=encoder_output.attentions,
        )
# Public API of this module: the shared base class, the audio / vision / text
# encoders, the combined omni encoder, and the task-specific heads.
__all__ = [
    "BidirLMOmniPreTrainedModel",
    "BidirLMOmniAudioEncoder",
    "BidirLMOmniVisionModel",
    "BidirLMOmniTextModel",
    "BidirLMOmniModel",
    "BidirLMOmniForMaskedLM",
    "BidirLMOmniForSequenceClassification",
    "BidirLMOmniForTokenClassification",
]