|
|
|
|
| import math
|
| from abc import ABCMeta, abstractmethod
|
| from typing import Union
|
|
|
| import dacvae
|
| import torch
|
|
|
| from sam_audio.model.config import DACVAEConfig
|
|
|
|
|
class Encoder(torch.nn.Module, metaclass=ABCMeta):
    """Abstract ``torch.nn.Module`` that maps a waveform tensor to a tensor.

    ``forward`` is declared abstract, so this class (and any subclass that
    does not override ``forward``) cannot be instantiated.
    """

    @abstractmethod
    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Encode ``waveform`` and return the resulting tensor."""
|
|
|
|
|
class Codec(Encoder):
    """Abstract encoder that additionally decodes and converts indices.

    Concrete subclasses must implement ``decode`` and both index-conversion
    methods below, on top of whatever ``Encoder`` already requires.
    """

    @abstractmethod
    def decode(self, encoded_frames: torch.Tensor) -> torch.Tensor:
        """Decode ``encoded_frames`` and return the resulting tensor."""

    @abstractmethod
    def wav_idx_to_feature_idx(
        self, wav_idx: Union[torch.Tensor, int], sample_rate=None
    ) -> Union[torch.Tensor, int]:
        """Convert a waveform index (int or tensor) into a feature index.

        ``sample_rate`` is optional; how ``None`` is interpreted is left to
        the implementing subclass.
        """

    @abstractmethod
    def feature_idx_to_wav_idx(
        self, feature_idx: Union[torch.Tensor, int], sample_rate=None
    ) -> Union[torch.Tensor, int]:
        """Convert a feature index (int or tensor) into a waveform index.

        ``sample_rate`` is optional; how ``None`` is interpreted is left to
        the implementing subclass.
        """

    @staticmethod
    def cast_to_int(
        x: Union[int, torch.Tensor],
    ) -> Union[int, torch.Tensor]:
        """Convert ``x`` to an integer value.

        Tensors are converted with ``Tensor.int()``; every other input is
        passed to the builtin ``int()``.
        """
        return x.int() if isinstance(x, torch.Tensor) else int(x)
|
|
|
|
|
class DACVAEEncoder(Encoder):
    """Encoder backed by the ``encoder`` and ``quantizer`` of a ``dacvae.DACVAE``."""

    def __init__(self, config: DACVAEConfig) -> None:
        """Build the underlying model from ``config`` and keep its sub-modules.

        Args:
            config: Supplies every constructor argument of ``dacvae.DACVAE``
                used here, plus ``hop_length`` and ``sample_rate``.
        """
        super().__init__()
        # Each constructor argument is read from the config attribute of the
        # same name, in this order.
        arg_names = (
            "encoder_dim",
            "encoder_rates",
            "latent_dim",
            "decoder_dim",
            "decoder_rates",
            "n_codebooks",
            "codebook_size",
            "codebook_dim",
            "quantizer_dropout",
            "sample_rate",
        )
        model_kwargs = {name: getattr(config, name) for name in arg_names}
        backbone = dacvae.DACVAE(**model_kwargs).eval()
        self._setup_model(backbone)
        self.hop_length = config.hop_length
        self.sample_rate = config.sample_rate

    def _setup_model(self, model):
        """Attach the sub-modules of ``model`` that this class uses."""
        self.encoder, self.quantizer = model.encoder, model.quantizer

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Encode ``waveform`` without tracking gradients.

        The waveform is padded via ``_pad``, run through ``self.encoder`` and
        ``self.quantizer.in_proj``; the projection is split in two along
        dim 1 and only the first half is returned.
        """
        with torch.no_grad():
            hidden = self.encoder(self._pad(waveform))
            projected = self.quantizer.in_proj(hidden)
            # The second half of the split is not used by this method.
            mean, _ = projected.chunk(2, dim=1)
        return mean

    def _pad(self, wavs):
        """Right-pad the last dimension to a multiple of ``self.hop_length``.

        Uses "reflect" padding; input whose length is already a multiple of
        the hop length is returned unchanged.
        """
        remainder = wavs.size(-1) % self.hop_length
        if not remainder:
            return wavs
        missing = self.hop_length - remainder
        return torch.nn.functional.pad(wavs, (0, missing), "reflect")
|
|
|
|
|
class DACVAE(DACVAEEncoder, Codec):
    """Codec that adds the model's decoder and index conversions to the encoder."""

    def _setup_model(self, model):
        """Attach the encoder-side sub-modules, then the decoder of ``model``."""
        super()._setup_model(model)
        self.decoder = model.decoder

    def decode(self, encoded_frames: torch.Tensor) -> torch.Tensor:
        """Project ``encoded_frames`` with ``quantizer.out_proj`` and decode them."""
        emb = self.quantizer.out_proj(encoded_frames)
        return self.decoder(emb)

    def feature_idx_to_wav_idx(self, feature_idx, sample_rate=None):
        """Convert a feature index into a waveform sample index.

        Args:
            feature_idx: Feature index, as an int or a tensor.
            sample_rate: Sample rate of the waveform being indexed; defaults
                to ``self.sample_rate`` when ``None``.

        Returns:
            ``feature_idx * hop_length`` scaled by
            ``sample_rate / self.sample_rate`` and converted with
            ``cast_to_int`` (which truncates toward zero for plain numbers).
        """
        if sample_rate is None:
            sample_rate = self.sample_rate
        orig_freq = sample_rate
        new_freq = self.sample_rate
        wav_chunklen = feature_idx * self.hop_length * (orig_freq / new_freq)
        return self.cast_to_int(wav_chunklen)

    def wav_idx_to_feature_idx(self, wav_idx, sample_rate=None):
        """Convert a waveform sample index into a feature index.

        Args:
            wav_idx: Waveform index, as an int or a tensor.
            sample_rate: Sample rate of the waveform being indexed; defaults
                to ``self.sample_rate`` when ``None``.

        Returns:
            The index rescaled to ``self.sample_rate`` (rounded up), divided
            by ``hop_length`` (rounded up again) and converted with
            ``cast_to_int``.
        """
        if sample_rate is None:
            sample_rate = self.sample_rate
        orig_freq = sample_rate
        new_freq = self.sample_rate

        if torch.is_tensor(wav_idx):
            ceil = torch.ceil
            # Multiplying a narrow integer tensor (e.g. the int32 produced by
            # cast_to_int) by the sample rate wraps around silently, because
            # a Python scalar does not widen the tensor's dtype. Widen to
            # int64 first; int64 and floating inputs behave as before.
            if not (wav_idx.is_floating_point() or wav_idx.is_complex()):
                wav_idx = wav_idx.long()
        else:
            ceil = math.ceil

        target_length = ceil(new_freq * wav_idx / orig_freq)
        res = ceil(target_length / self.hop_length)
        return self.cast_to_int(res)
|
|
|