"""
Our Transformer-based model for the AudioSet dataset.

The model is heavily inspired by the Llama-3 model.
Reference: https://github.com/meta-llama/llama3/blob/main/llama/model.py
"""

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .configuration_biome import BioMEConfig


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute the complex rotation table used by the rotary embeddings.

    Args:
        dim: Size of the feature dimension that will be rotated. Features are
            rotated in pairs, so ``dim // 2`` frequencies are produced.
        end: Number of positions to precompute.
        theta: Base of the geometric progression of frequencies.

    Returns:
        A complex64 tensor of shape ``(end, dim // 2)`` whose entry ``[t, i]``
        is the unit complex number with angle ``t * theta ** (-2 * i / dim)``.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so that it broadcasts against ``x``.

    Args:
        freqs_cis: Tensor of shape ``(x.shape[1], x.shape[-1])``.
        x: Tensor with at least two dimensions; dimension 1 is the one
            ``freqs_cis`` is aligned with, together with the last dimension.

    Returns:
        A view of ``freqs_cis`` with the same rank as ``x``, where every
        dimension except dimension 1 and the last one has size 1.

    Raises:
        AssertionError: If ``x`` has fewer than two dimensions or the shape of
            ``freqs_cis`` does not match ``(x.shape[1], x.shape[-1])``.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim, "x must have at least two dimensions"
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
        f"freqs_cis has shape {tuple(freqs_cis.shape)}, "
        f"expected {(x.shape[1], x.shape[-1])}"
    )
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to query and key tensors.

    Args:
        xq: Query tensor. The code flattens from dimension 3 onwards after the
            rotation, so a 4-D layout such as
            ``(batch, seq_len, num_heads, dim_per_head)`` is expected; the
            last dimension must be even.
        xk: Key tensor with the same sequence length and last dimension as
            ``xq`` (the number of heads may differ).
        freqs_cis: Complex rotation table of shape
            ``(seq_len, dim_per_head // 2)``.

    Returns:
        The rotated ``(xq, xk)`` pair, each cast back to its input dtype.
    """
    # Pair up consecutive features and treat each pair as one complex number.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # Complex multiplication performs the rotation; flatten(3) merges the
    # (pairs, 2) trailing dimensions back into a single feature dimension.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class GroupedQueryAttention(nn.Module):
    """
    A multi-head grouped-query self-attention implementation.

    Paper: 'GQA: Training Generalized Multi-Query Transformer Models from
    Multi-Head Checkpoints' (https://arxiv.org/pdf/2305.13245)

    Code heavily inspired by:
        - https://github.com/meta-llama/llama3/blob/main/llama/model.py
        - https://docs.pytorch.org/torchtune/0.4/_modules/torchtune/modules/attention.html

    Args:
        dim (int): Input and output feature dimension. Must be divisible by
            ``num_q_heads``. Default: 512
        num_q_heads (int): Number of query heads. Default: 16
        num_kv_heads (int): Number of key/value heads. ``num_q_heads`` must be
            divisible by it. Default: 4
        dropout (float): Attention dropout probability, applied only in
            training mode. Default: 0
        bias (bool): Use bias in the projections. Default: True
        device (torch.device, optional): Device for parameters
        dtype (torch.dtype, optional): Data type for parameters

    Shape:
        - Input: (B, L, dim)
        - Output: (B, L, dim)

        where B is batch size and L is sequence length
    """

    def __init__(
        self,
        dim: int = 512,
        num_q_heads: int = 16,
        num_kv_heads: int = 4,
        dropout: float = 0.0,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.dropout = dropout
        self.bias = bias
        factory_kwargs = {"device": device, "dtype": dtype}

        assert dim % num_q_heads == 0, "Embedding dim is not divisible by nheads"
        # Grouped-query attention shares each key/value head between an equal
        # number of query heads; fail early instead of deep inside SDPA.
        if num_kv_heads <= 0 or num_q_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_q_heads ({num_q_heads}) must be a positive multiple of "
                f"num_kv_heads ({num_kv_heads})"
            )
        self.dim_per_head = dim // num_q_heads

        self.q_proj = nn.Linear(
            self.dim, num_q_heads * self.dim_per_head, bias=bias, **factory_kwargs
        )
        self.k_proj = nn.Linear(
            self.dim, num_kv_heads * self.dim_per_head, bias=bias, **factory_kwargs
        )
        self.v_proj = nn.Linear(
            self.dim, num_kv_heads * self.dim_per_head, bias=bias, **factory_kwargs
        )
        self.out_proj = nn.Linear(
            num_q_heads * self.dim_per_head, self.dim, bias=bias, **factory_kwargs
        )

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim)
            start_pos (int): Currently unused by this method; kept so the
                call signature stays stable. The caller is responsible for
                passing the ``freqs_cis`` rows matching the positions of ``x``.
            freqs_cis (torch.Tensor): Complex rotary table of shape
                (seq_len, dim_per_head // 2)
            attn_mask (torch.Tensor): Attention mask forwarded unchanged to
                ``F.scaled_dot_product_attention``
            is_causal (bool): If True, applies a causal mask to prevent
                attending to future positions.

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, dim)
        """
        bsz, seqlen, _ = x.shape

        # Step 1: Apply projections
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Step 2: Split the heads before the scaled dot-product attention
        xq = xq.view(bsz, seqlen, self.num_q_heads, self.dim_per_head)
        xk = xk.view(bsz, seqlen, self.num_kv_heads, self.dim_per_head)
        xv = xv.view(bsz, seqlen, self.num_kv_heads, self.dim_per_head)

        # Step 3: Apply rotary embeddings
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # Step 4: Apply scaled dot-product attention
        # Note: torch sdpa expects (batch_size, num_heads, seq_len, dim_per_head).
        # `enable_gqa` lets SDPA share key/value heads across query heads.
        # NOTE(review): `enable_gqa` needs a recent PyTorch (2.5+) - confirm
        # the project's minimum supported version.
        attn_output = (
            F.scaled_dot_product_attention(
                xq.transpose(1, 2),
                xk.transpose(1, 2),
                xv.transpose(1, 2),
                attn_mask=attn_mask,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=is_causal,
                enable_gqa=True,
            )
            .transpose(1, 2)
            .flatten(-2)  # (B, nheads, L, dim_per_head) -> (B, L, nheads * dim_per_head)
        )

        return self.out_proj(attn_output)


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Args:
        dim: Size of the last (feature) dimension.
        eps: Constant added to the mean square before the reciprocal square
            root, for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` in float32, cast back to its dtype, then apply the gain."""
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class FiLM(nn.Module):
    """
    A Feature-wise Linear Modulation layer from
    'FiLM: Visual Reasoning with a General Conditioning Layer'.

    A single linear layer maps the context to a scale (gamma) and a shift
    (beta) per feature, which are applied as ``gamma * x + beta``.

    Args:
        d_model: Feature dimension of the activations being modulated.
        context_dim: Feature dimension of the conditioning context.
    """

    def __init__(self, d_model: int, context_dim: int):
        super().__init__()
        self.d_model = d_model
        self.context_dim = context_dim
        self.shared_modulator = nn.Linear(context_dim, 2 * d_model)

    def forward(self, x: torch.Tensor, ctx: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Activations in the Transformer, of shape
                (B, T, d_model).
            ctx (torch.Tensor): Side-channel information of shape (B, F) or
                (B, T, F). If 3-dimensional, the sequence dimension T must
                match (or broadcast against) the one of ``x``.

        Returns:
            torch.Tensor: The modulated activations, of shape (B, T, d_model).
        """
        params = self.shared_modulator(ctx)
        if params.dim() == 2:
            # Per-sample context: add a singleton sequence axis so the same
            # modulation is broadcast over every time step.
            params = params.unsqueeze(1)
        # A 3-D context keeps its own sequence axis. Flattening it into the
        # feature axis (as `view(B, 1, -1)` would) breaks broadcasting for T > 1.
        gammas, betas = params.chunk(2, dim=-1)
        return (gammas * x) + betas


class TransformerFFN(nn.Module):
    """SiLU-gated feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        dim: Input and output feature dimension.
        hidden_dim: Feature dimension of the gated hidden representation.
        bias: Whether the three linear layers use a bias term.
    """

    def __init__(self, dim: int, hidden_dim: int, bias: bool = False):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transformation to ``x``."""
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerEncoderLayer(nn.Module):
    """Pre-norm Transformer layer: attention, optional FiLM conditioning, FFN.

    Args:
        config: Model configuration. This layer reads ``use_context``,
            ``hidden_size``, ``ctx_hidden_size``, ``norm_eps``,
            ``num_query_heads``, ``num_kv_heads``, ``dropout``, ``bias`` and
            ``ffn_hidden_size`` from it.
    """

    def __init__(self, config: BioMEConfig):
        super().__init__()

        self.use_context = config.use_context
        if self.use_context:
            self.film = FiLM(
                d_model=config.hidden_size, context_dim=config.ctx_hidden_size
            )
            self.film_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
            self.film_norm_ctx = RMSNorm(config.ctx_hidden_size, eps=config.norm_eps)

        self.attention = GroupedQueryAttention(
            dim=config.hidden_size,
            num_q_heads=config.num_query_heads,
            num_kv_heads=config.num_kv_heads,
            dropout=config.dropout,
            bias=config.bias,
        )
        self.feed_forward = TransformerFFN(
            dim=config.hidden_size,
            hidden_dim=config.ffn_hidden_size,
        )
        self.attention_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        ctx: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
    ):
        """Run the layer on ``x``.

        Args:
            x: Input activations; the attention module unpacks them as
                (batch_size, seq_len, hidden_size).
            start_pos: Forwarded to the attention module.
            freqs_cis: Rotary-embedding table forwarded to the attention module.
            ctx: Conditioning context for the FiLM layer. Required when the
                layer was built with ``config.use_context`` set; ignored
                otherwise.
            padding_mask: Optional mask used to index ``x``; the selected
                entries are set to zero before the layer is applied.
                Presumably a boolean (batch_size, seq_len) tensor that is True
                at padded positions - verify against the caller.

        Returns:
            The output activations, with the same shape as ``x``.

        Raises:
            ValueError: If the layer uses context but ``ctx`` is None.
        """
        if self.use_context and ctx is None:
            raise ValueError("ctx must be provided when config.use_context is True")

        if padding_mask is not None:
            # Mask a copy: writing into `x` directly would silently modify the
            # caller's tensor and can break autograd on leaf tensors.
            x = x.clone()
            x[padding_mask] = 0

        # NOTE(review): the padding mask is not forwarded to the attention
        # module, so padded positions can still be attended to as keys -
        # confirm whether this is intended.
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis)
        if self.use_context:
            # FiLM replaces the hidden state (no residual connection here).
            h = self.film(self.film_norm(h), self.film_norm_ctx(ctx))
        out = h + self.feed_forward(self.ffn_norm(h))
        return out