import math
from typing import Optional

import torch
import torch.nn as nn
import torchaudio.compliance.kaldi as ta_kaldi

from .biome_modules import RMSNorm
from .biome_modules import TransformerEncoderLayer
from .biome_modules import precompute_freqs_cis
from .configuration_biome import BioMEConfig


class BioME(nn.Module):
    """Transformer encoder over patchified log-mel filterbank features.

    The forward pass converts raw waveforms to Kaldi-style filterbanks,
    splits them into patches with a strided ``Conv2d``, normalizes and
    (optionally) projects the patches, and feeds them through a stack of
    ``TransformerEncoderLayer`` modules.  A modulation-spectrum summary of
    the waveform is computed in parallel and handed to every layer as
    ``ctx``.
    """

    def __init__(self, cfg: BioMEConfig):
        """Build all sub-modules and initialise the transformer weights.

        Args:
            cfg: Model configuration.  The fields read by this class are
                visible at their points of use (``embed_dim``,
                ``hidden_size``, ``num_layers``, ``input_patch_size``, ...).
        """
        super().__init__()
        self.cfg = cfg
        self.n_layers = cfg.num_layers

        # Non-overlapping patches: kernel size and stride are identical.
        self.patch_embedding = nn.Conv2d(
            1,
            cfg.embed_dim,
            kernel_size=cfg.input_patch_size,
            stride=cfg.input_patch_size,
            bias=False,
        )
        self.dropout_input = nn.Dropout(cfg.dropout)

        # Only project when the patch width differs from the model width.
        self.post_extract_proj = (
            nn.Linear(cfg.embed_dim, cfg.hidden_size)
            if cfg.embed_dim != cfg.hidden_size
            else nn.Identity()
        )

        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(cfg) for _ in range(cfg.num_layers)]
        )
        self.feature_norm = RMSNorm(cfg.embed_dim, eps=cfg.norm_eps)

        # Plain attribute (not a registered buffer); ``forward`` moves it to
        # the input device on demand.
        self.freqs_cis = self._get_freqs_cis()
        # Not read or written anywhere else in this class.
        self.modulation_cache = {}

        # Weights initialization.  The gain name suggests a DeepNorm-style
        # scheme: q/k keep gain 1, every other projection is scaled by
        # (8 * num_layers) ** -0.25.  The per-layer order below is kept
        # fixed so that RNG consumption is reproducible.
        deep_norm_beta = math.pow(8 * cfg.num_layers, -1 / 4)
        for layer in self.layers:
            attention = layer.attention
            feed_forward = layer.feed_forward
            nn.init.xavier_normal_(attention.k_proj.weight, gain=1)
            nn.init.xavier_normal_(attention.v_proj.weight, gain=deep_norm_beta)
            nn.init.xavier_normal_(attention.q_proj.weight, gain=1)
            nn.init.xavier_normal_(attention.out_proj.weight, gain=deep_norm_beta)
            nn.init.xavier_normal_(feed_forward.w1.weight, gain=deep_norm_beta)
            nn.init.xavier_normal_(feed_forward.w2.weight, gain=deep_norm_beta)
            nn.init.xavier_normal_(feed_forward.w3.weight, gain=deep_norm_beta)

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Downsample a per-step padding mask to one flag per patch.

        Args:
            features: Patch features; only ``features.size(1)`` (the number
                of patches) is used.
            padding_mask: Boolean mask whose second dimension is the
                un-patchified length.

        Returns:
            A mask with ``features.size(1)`` entries per batch item.  A patch
            is ``True`` only if every step that falls into it is ``True``.
        """
        n_patches = features.size(1)

        # Drop trailing steps that do not fill a whole patch.
        extra = padding_mask.size(1) % n_patches
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]

        per_patch = padding_mask.reshape(padding_mask.size(0), n_patches, -1)
        return per_patch.all(-1)

    def forward(
        self,
        wavs: torch.Tensor,
        start_pos: int,
        padding_mask: Optional[torch.Tensor] = None,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
        apply_mask: bool = False,
    ):
        """Encode a batch of waveforms.

        Args:
            wavs: Batch of waveforms; must be 2-D because
                ``get_modulation_spectrum`` unpacks ``wavs.shape`` into two
                values (presumably ``(batch, samples)``).
            start_pos: Offset into the precomputed positional table.
            padding_mask: Optional boolean mask; reduced to patch resolution
                via ``forward_padding_mask`` when given.
            fbank_mean: Mean used to normalise the filterbank features.
            fbank_std: Standard deviation used to normalise the filterbank
                features.
            apply_mask: When true, randomly zero out patches with probability
                ``cfg.mlm_mask_prob``.

        Returns:
            A 6-tuple ``(x, layer_results, padding_mask, ids_restore,
            kept_mask, patch_padding_mask)`` where ``x`` is the output of the
            last layer, ``layer_results`` holds the output of every layer,
            ``padding_mask`` is the patch-level mask (or ``None``),
            ``ids_restore`` is always ``None``, ``kept_mask`` marks the
            patches that were *not* zeroed (``None`` unless ``apply_mask``),
            and ``patch_padding_mask`` is a clone of the patch-level mask.
        """
        # 1. Get input features
        fbank = self.wav_to_fbank(wavs, fbank_mean=fbank_mean, fbank_std=fbank_std)
        ctx = self.get_modulation_spectrum(wavs)  # Side-channel (MSAB) features

        # 2. Patchfy the input
        features = self.feature_patchfy(fbank)

        patch_padding_mask = None
        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)
            patch_padding_mask = padding_mask.clone()

        ids_restore, kept_mask = None, None
        if apply_mask:
            batch_size, n_patches = features.shape[:2]
            u = torch.rand(batch_size, n_patches, device=features.device)
            to_mask = u < self.cfg.mlm_mask_prob
            kept_mask = ~to_mask
            # Masked patches are zeroed before the projection.
            features = features.masked_fill(to_mask.unsqueeze(-1), 0.0)

        features = self.post_extract_proj(features)
        seqlen = features.shape[1]

        # 3. Apply positional encoding
        if self.freqs_cis.device.type == "meta":
            # A meta tensor carries no data, so rebuild the table.
            self.freqs_cis = self._get_freqs_cis()
        self.freqs_cis = self.freqs_cis.to(features.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        # 4. Apply transformer layers
        x = self.dropout_input(features)
        layer_results = []
        for layer in self.layers:
            # The table is already sliced at ``start_pos``; layers get 0.
            x = layer(
                x, start_pos=0, freqs_cis=freqs_cis, ctx=ctx, padding_mask=padding_mask
            )
            layer_results.append(x)

        # 5. Apply post-processing
        return x, layer_results, padding_mask, ids_restore, kept_mask, patch_padding_mask

    def wav_to_fbank(
        self,
        source: torch.Tensor,
        fbank_mean: float = -4.268,
        fbank_std: float = 4.569,
    ):
        """Compute normalised Kaldi filterbank features for each waveform.

        Note that the default statistics here differ from the defaults of
        ``forward``, which always passes its own values explicitly.

        Args:
            source: Iterable batch of waveforms; each item is given a leading
                dimension and scaled by ``2**15`` before the Kaldi call.
            fbank_mean: Value subtracted from the features.
            fbank_std: The features are divided by ``2 * fbank_std``.

        Returns:
            The stacked, normalised filterbank tensor (batch first).
        """
        fbanks = [
            ta_kaldi.fbank(
                waveform.unsqueeze(0) * 2**15,
                num_mel_bins=self.cfg.n_mels,
                sample_frequency=self.cfg.sample_rate,
                frame_length=self.cfg.frame_length,
                frame_shift=self.cfg.frame_shift,
                use_energy=False,
                window_type="hanning",
                dither=0.0,
            )
            for waveform in source
        ]
        fbank = torch.stack(fbanks, dim=0)
        return (fbank - fbank_mean) / (2 * fbank_std)

    def feature_patchfy(self, rep: torch.Tensor) -> torch.Tensor:
        """Patchify the feature representation.

        Adds a channel dimension, applies the patch convolution, flattens
        everything after the channel dimension into one patch axis, moves
        the channels last and applies ``feature_norm``.
        """
        rep = rep.unsqueeze(1)
        features = self.patch_embedding(rep)
        features = features.reshape(features.shape[0], features.shape[1], -1)
        features = features.transpose(1, 2)
        return self.feature_norm(features)

    def _get_freqs_cis(self):
        """Build the positional table for ``2 * cfg.max_seq_len`` positions."""
        return precompute_freqs_cis(
            self.cfg.hidden_size // self.cfg.num_query_heads,
            self.cfg.max_seq_len * 2,
            self.cfg.rope_theta,
        )

    @torch.no_grad()
    def normalize_fft(self, spec_data, window, n_samples, n_fft, fs):
        """Turn a one-sided spectrum into a scaled power spectral density.

        The input tensor is left untouched; a new tensor is returned.

        Args:
            spec_data: Spectrum with at least three dimensions; dimension 1
                is treated as the frequency axis for the one-sided doubling.
            window: Window that was applied before the transform.
            n_samples: Divisor used for the window RMS.
            n_fft: Transform length.
            fs: Sampling frequency of the transformed axis.

        Returns:
            ``(f_delta, psd)`` where ``f_delta = fs / n_fft`` and ``psd`` is
            the normalised power spectrum.
        """
        # Normalizations
        win_rms = torch.sqrt(window.pow(2.0).sum() / n_samples)

        # Compute the power spectrogram,
        # same as X_pwr = abs(np.multiply(Xt, np.conj(Xt)))
        power = (spec_data / win_rms).abs().pow(2.0)
        power *= 1.0 / n_fft**2  # make it orthonormal

        if n_fft % 2 != 0:
            # double all frequency components except DC component
            power[:, 1:, :] *= 2
        else:
            # double all frequency components except DC and fs/2 components
            power[:, 1:-1, :] *= 2

        f_delta = fs / n_fft
        power = torch.divide(power, f_delta)  # scale by frequency delta
        return f_delta, power

    @torch.no_grad()
    def get_modulation_spectrum(self, wavs: torch.Tensor):
        """Compute a modulation-spectrum summary of each waveform.

        A power spectrogram is computed with an STFT, then a second FFT is
        taken along the frame axis.  The result is averaged once over
        dimension 1 and once over dimension 2, and both averages are
        concatenated along dimension 1.

        Args:
            wavs: 2-D tensor of waveforms.

        Returns:
            A 2-D tensor holding the two concatenated averages per item.
        """
        # number of samples and number of channels
        _, n_samples = wavs.shape

        # Step 1: compute STFT spectrogram
        window = torch.hamming_window(
            self.cfg.mss_win_size, periodic=True, device=wavs.device
        )
        spec_data = torch.stft(
            wavs,
            n_fft=self.cfg.mss_n_fft1,
            win_length=self.cfg.mss_win_size,
            hop_length=self.cfg.mss_win_shift,
            window=window,
            return_complex=True,
            onesided=True,
        )
        # We add pad while old code remove the last window if necessary
        _, _, n_windows = spec_data.shape

        # Normalizations
        # NOTE(review): n_samples is the full waveform length, not the
        # window length -- confirm this is the intended RMS divisor.
        _, spec_data = self.normalize_fft(
            spec_data, window, n_samples, self.cfg.mss_n_fft1, self.cfg.sample_rate
        )

        # Step 2: Modulation Features
        # modulation sampling frequency
        fs_mod = 1 / (self.cfg.mss_win_shift / self.cfg.sample_rate)
        n_fft2 = self.cfg.mss_n_fft2
        if n_fft2 is None:
            n_fft2 = n_windows

        # the AM analysis is made in the Amplitude derived from the Power Spectrogram
        window = torch.hamming_window(n_windows, periodic=True, device=wavs.device)
        spec_data = torch.multiply(spec_data, window)
        mod_psd = torch.fft.rfft(spec_data, n=n_fft2, dim=2)
        # NOTE(review): this FFT runs along dim 2, but normalize_fft applies
        # its one-sided doubling along dim 1 -- confirm this is intended.
        _, mod_psd = self.normalize_fft(mod_psd, window, n_samples, n_fft2, fs_mod)

        return torch.cat([mod_psd.mean(dim=1), mod_psd.mean(dim=2)], dim=1)