|
|
|
|
| from typing import Optional, Tuple
|
|
|
| import numpy as np
|
| from core.audio_visual_encoder.config import TransformerConfig as PEAVTransformerConfig
|
| from transformers import ModernBertConfig
|
|
|
|
|
class DACVAEConfig:
    """Hyper-parameter container for the DAC-VAE audio codec.

    Every constructor argument is stored on the instance under an attribute
    of the same name.

    Args:
        encoder_dim: Stored as ``self.encoder_dim``.
        encoder_rates: Sequence of integer rates; their product is exposed as
            ``hop_length``. Defaults to ``[2, 8, 10, 12]`` when omitted.
        latent_dim: Stored as ``self.latent_dim``.
        decoder_dim: Stored as ``self.decoder_dim``.
        decoder_rates: Sequence of integer rates. Defaults to
            ``[12, 10, 8, 2]`` when omitted.
        n_codebooks: Stored as ``self.n_codebooks``.
        codebook_size: Stored as ``self.codebook_size``.
        codebook_dim: Stored as ``self.codebook_dim``.
        quantizer_dropout: Stored as ``self.quantizer_dropout``.
        sample_rate: Stored as ``self.sample_rate`` (presumably in Hz --
            confirm against the codec implementation).
        mean: Stored as ``self.mean``.
        std: Stored as ``self.std``.
    """

    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: Optional[list[int]] = None,
        latent_dim: int = 1024,
        decoder_dim: int = 1536,
        decoder_rates: Optional[list[int]] = None,
        n_codebooks: int = 16,
        codebook_size: int = 1024,
        codebook_dim: int = 128,
        quantizer_dropout: bool = False,
        sample_rate: int = 48_000,
        mean: float = 0.0,
        std: float = 1.0,
    ):
        self.encoder_dim = encoder_dim
        # A list literal used as a default is created once and shared by every
        # call, so the defaults are built per instance here. Caller-supplied
        # sequences are copied so later edits to them (or to this config)
        # cannot leak across objects.
        self.encoder_rates = (
            [2, 8, 10, 12] if encoder_rates is None else list(encoder_rates)
        )
        self.latent_dim = latent_dim
        self.decoder_dim = decoder_dim
        self.decoder_rates = (
            [12, 10, 8, 2] if decoder_rates is None else list(decoder_rates)
        )
        self.n_codebooks = n_codebooks
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer_dropout = quantizer_dropout
        self.sample_rate = sample_rate
        self.mean = mean
        self.std = std

    @property
    def hop_length(self):
        """Product of all ``encoder_rates``, returned as a Python ``int``."""
        return int(np.prod(self.encoder_rates))
|
|
|
|
|
class TextEncoderConfig:
    """Base settings object for text-encoder configurations.

    Args:
        dim: Stored unchanged as ``self.dim`` (default 768). Presumably the
            encoder's feature width -- confirm against the encoder code.
    """

    def __init__(self, dim: int = 768):
        # `dim` is the only field defined at this level.
        self.dim = dim
|
|
|
|
|
class T5EncoderConfig(TextEncoderConfig):
    """Text-encoder settings for a T5 model.

    Args:
        name: Stored as ``self.name`` (default ``"t5-base"``); presumably the
            model identifier handed to the loader -- confirm at the call site.
        max_length: Stored as ``self.max_length``; may be ``None``.
        pad_mode: Stored as ``self.pad_mode`` (default ``"longest"``).
        dim: Forwarded to ``TextEncoderConfig.__init__``.
    """

    def __init__(
        self,
        name: str = "t5-base",
        max_length: Optional[int] = 512,
        pad_mode: str = "longest",
        dim: int = 768,
    ):
        # The parent records `dim`; the T5-specific fields follow.
        super().__init__(dim=dim)
        self.name, self.max_length, self.pad_mode = name, max_length, pad_mode
|
|
|
|
|
class VisionEncoderConfig:
    """Base settings object for vision-encoder configurations.

    Args:
        dim: Stored as ``self.dim`` (default 1024).
        batch_size: Stored as ``self.batch_size`` (default 300).
    """

    def __init__(self, dim: int = 1024, batch_size: int = 300):
        # Both values are kept verbatim; no validation happens here.
        self.dim, self.batch_size = dim, batch_size
|
|
|
|
|
class PerceptionEncoderConfig(VisionEncoderConfig):
    """Vision-encoder settings for a Perception Encoder model.

    Args:
        dim: Forwarded to ``VisionEncoderConfig.__init__``.
        batch_size: Forwarded to ``VisionEncoderConfig.__init__``.
        name: Stored as ``self.name`` (default ``"PE-Core-L14-336"``).
        normalize_feature: Stored as ``self.normalize_feature``.
        interpolation_mode: Stored as ``self.interpolation_mode``
            (default ``"BICUBIC"``).
        image_size: Stored as ``self.image_size`` (default 336).
    """

    def __init__(
        self,
        dim: int = 1024,
        batch_size: int = 300,
        name: str = "PE-Core-L14-336",
        normalize_feature: bool = True,
        interpolation_mode: str = "BICUBIC",
        image_size: int = 336,
    ):
        # Shared fields are owned by the base class.
        super().__init__(dim=dim, batch_size=batch_size)

        # Fields specific to this encoder.
        self.name, self.normalize_feature = name, normalize_feature
        self.interpolation_mode, self.image_size = interpolation_mode, image_size
|
|
|
|
|
class TransformerConfig:
    """Plain container for the transformer's hyper-parameters.

    Each constructor argument is stored, unvalidated, on the instance under
    an attribute of the same name. ``in_channels`` is the only argument that
    defaults to ``None``; this class does not resolve it to a concrete value.
    """

    def __init__(
        self,
        dim: int = 2048,
        n_heads: int = 16,
        n_layers: int = 16,
        dropout: float = 0.1,
        norm_eps: float = 1.0e-05,
        qk_norm: bool = True,
        fc_bias: bool = False,
        ffn_exp: int = 4,
        ffn_dim_multiplier: int = 1,
        multiple_of: int = 64,
        non_linearity: str = "swiglu",
        use_rope: bool = True,
        max_positions: int = 10000,
        frequency_embedding_dim: int = 256,
        timestep_non_linearity: str = "swiglu",
        t_block_non_linearity: str = "silu",
        t_block_bias: bool = True,
        context_dim: int = 2048,
        context_non_linearity: str = "swiglu",
        context_embedder_dropout: float = 0.0,
        context_norm: bool = False,
        out_channels: int = 256,
        in_channels: Optional[int] = None,
    ):
        # Size and regularisation settings.
        self.dim = dim
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.dropout = dropout
        self.norm_eps = norm_eps
        self.qk_norm = qk_norm
        self.fc_bias = fc_bias

        # Feed-forward settings.
        self.ffn_exp = ffn_exp
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.multiple_of = multiple_of
        self.non_linearity = non_linearity

        # Positional settings.
        self.use_rope = use_rope
        self.max_positions = max_positions

        # Timestep-related settings.
        self.frequency_embedding_dim = frequency_embedding_dim
        self.timestep_non_linearity = timestep_non_linearity
        self.t_block_non_linearity = t_block_non_linearity
        self.t_block_bias = t_block_bias

        # Context-related settings.
        self.context_dim = context_dim
        self.context_non_linearity = context_non_linearity
        self.context_embedder_dropout = context_embedder_dropout
        self.context_norm = context_norm

        # Channel counts.
        self.out_channels = out_channels
        self.in_channels = in_channels
|
|
|
|
|
class RankerConfig:
    """Common base class for ranker configurations.

    The ``kind`` attribute is a string tag. It is only declared here; no
    value is assigned on this base class.
    """

    # Annotation only -- concrete subclasses are expected to assign the value.
    kind: str
|
|
|
|
|
class ImageBindRankerConfig(RankerConfig):
    """Ranker configuration tagged with ``kind == "imagebind"``.

    Args:
        checkpoint: Stored as ``self.checkpoint``; ``None`` by default.
            How ``None`` is interpreted is up to the consumer of this config.
    """

    # Tag used to select this class when parsing a config dict.
    kind: str = "imagebind"

    def __init__(self, checkpoint: Optional[str] = None):
        # Kept verbatim; nothing is loaded or checked here.
        self.checkpoint = checkpoint
|
|
|
|
|
class ClapRankerConfig(RankerConfig):
    """Ranker configuration tagged with ``kind == "clap"``.

    Args:
        checkpoint: Stored as ``self.checkpoint``; ``None`` by default.
            How ``None`` is interpreted is up to the consumer of this config.
    """

    # Tag used to select this class when parsing a config dict.
    kind: str = "clap"

    def __init__(self, checkpoint: Optional[str] = None):
        # Kept verbatim; nothing is loaded or checked here.
        self.checkpoint = checkpoint
|
|
|
|
|
class JudgeRankerConfig(RankerConfig):
    """Ranker configuration tagged with ``kind == "judge"``.

    Args:
        checkpoint_or_model_id: Stored as ``self.checkpoint_or_model_id``.
            Defaults to ``"facebook/sam-audio-judge"``.
    """

    # Tag used to select this class when parsing a config dict.
    kind: str = "judge"

    def __init__(self, checkpoint_or_model_id: str = "facebook/sam-audio-judge"):
        # Kept verbatim; resolving it to a model is left to the consumer.
        self.checkpoint_or_model_id = checkpoint_or_model_id
|
|
|
|
|
class SoundActivityRankerConfig(RankerConfig):
    """Ranker configuration tagged with ``kind == "sound_activity"``.

    Args:
        threshold_mode: Stored as ``self.threshold_mode``
            (default ``"rel_to_max"``).
        sil_threshold: Stored as ``self.sil_threshold`` (default -40;
            presumably a silence level in dB -- confirm with the ranker).
        metric: Stored as ``self.metric`` (default ``"iou"``).
    """

    # Tag used to select this class when parsing a config dict.
    kind: str = "sound_activity"

    def __init__(
        self,
        threshold_mode: str = "rel_to_max",
        sil_threshold: float = -40,
        metric: str = "iou",
    ):
        # All three settings are stored without validation.
        self.threshold_mode, self.sil_threshold, self.metric = (
            threshold_mode,
            sil_threshold,
            metric,
        )
|
|
|
|
|
class EnsembleRankerConfig(RankerConfig):
    """Ranker configuration tagged with ``kind == "ensemble"``.

    Args:
        rankers: Mapping from a name to a ``(ranker_config, float)`` pair,
            stored by reference as ``self.rankers``. The float is presumably
            the member's weight in the ensemble -- confirm with the ranker.
    """

    # Tag used to select this class when parsing a config dict.
    kind: str = "ensemble"

    def __init__(self, rankers: dict[str, Tuple[RankerConfig, float]]):
        # Required argument: an ensemble has no default member set.
        self.rankers = rankers
|
|
|
|
|
| def parse_ranker_config(config_dict: dict):
|
| kind = config_dict.pop("kind")
|
| match kind:
|
| case ImageBindRankerConfig.kind:
|
| return ImageBindRankerConfig(**config_dict)
|
| case ClapRankerConfig.kind:
|
| return ClapRankerConfig(**config_dict)
|
| case JudgeRankerConfig.kind:
|
| return JudgeRankerConfig(**config_dict)
|
| case SoundActivityRankerConfig.kind:
|
| return SoundActivityRankerConfig(**config_dict)
|
| case EnsembleRankerConfig.kind:
|
| return EnsembleRankerConfig(
|
| {
|
| k: (parse_ranker_config(v), w)
|
| for k, (v, w) in config_dict["rankers"].items()
|
| }
|
| )
|
|
|
|
|
class SAMAudioConfig:
    """Top-level configuration that assembles the model's sub-configs.

    Args:
        in_channels: Stored as ``self.in_channels``.
        audio_codec: Keyword dict for ``DACVAEConfig``; a falsy value means
            all defaults.
        text_encoder: Keyword dict for ``T5EncoderConfig``; a falsy value
            means all defaults.
        vision_encoder: Keyword dict for ``PerceptionEncoderConfig``; a
            falsy value means all defaults.
        transformer: Keyword dict for ``TransformerConfig``; a falsy value
            means all defaults.
        num_anchors: Stored as ``self.num_anchors``.
        anchor_embedding_dim: Stored as ``self.anchor_embedding_dim``.
        visual_ranker: ``None`` or a dict passed to ``parse_ranker_config``.
        text_ranker: ``None`` or a dict passed to ``parse_ranker_config``.
        span_predictor: Stored as ``self.span_predictor``; may be ``None``.
    """

    def __init__(
        self,
        in_channels: int = 768,
        audio_codec=None,
        text_encoder=None,
        vision_encoder=None,
        transformer=None,
        num_anchors: int = 3,
        anchor_embedding_dim: int = 128,
        visual_ranker=None,
        text_ranker=None,
        span_predictor: Optional[str] = "pe-a-frame-large",
    ):
        self.in_channels = in_channels

        # Each sub-config is built from its keyword dict; a missing or empty
        # dict yields that class's defaults.
        self.audio_codec = DACVAEConfig(**(audio_codec if audio_codec else {}))
        self.text_encoder = T5EncoderConfig(
            **(text_encoder if text_encoder else {})
        )
        self.vision_encoder = PerceptionEncoderConfig(
            **(vision_encoder if vision_encoder else {})
        )
        self.transformer = TransformerConfig(
            **(transformer if transformer else {})
        )

        self.num_anchors = num_anchors
        self.anchor_embedding_dim = anchor_embedding_dim

        # Rankers are optional: None is kept as None, anything else is parsed.
        if visual_ranker is None:
            self.visual_ranker = None
        else:
            self.visual_ranker = parse_ranker_config(visual_ranker)

        if text_ranker is None:
            self.text_ranker = None
        else:
            self.text_ranker = parse_ranker_config(text_ranker)

        self.span_predictor = span_predictor
|
|
|
|
|
class SAMAudioJudgeConfig:
    """Configuration for the SAM-Audio judge model.

    Each of the four sub-config arguments may be given as:

    * ``None`` or an empty dict -- the sub-config is built with its defaults;
    * a dict of keyword arguments -- passed to the sub-config's constructor;
    * an already-built instance of the sub-config class -- used as-is.

    Args:
        audio_codec: ``DACVAEConfig`` instance or keyword dict.
        transformer: ``PEAVTransformerConfig`` instance or keyword dict.
        text_model: ``ModernBertConfig`` instance or keyword dict.
        finetune_transformer: ``PEAVTransformerConfig`` instance or keyword
            dict.
        nth_text_layer: Stored as ``self.nth_text_layer`` (default 22).
        bottleneck_dim: Stored as ``self.bottleneck_dim`` (default 256).
    """

    def __init__(
        self,
        audio_codec: DACVAEConfig | dict | None = None,
        transformer: PEAVTransformerConfig | dict | None = None,
        text_model: ModernBertConfig | dict | None = None,
        finetune_transformer: PEAVTransformerConfig | dict | None = None,
        nth_text_layer: int = 22,
        bottleneck_dim: int = 256,
    ):
        self.audio_codec = self._build(DACVAEConfig, audio_codec)
        self.transformer = self._build(PEAVTransformerConfig, transformer)
        self.text_model = self._build(ModernBertConfig, text_model)
        self.finetune_transformer = self._build(
            PEAVTransformerConfig, finetune_transformer
        )
        self.nth_text_layer = nth_text_layer
        self.bottleneck_dim = bottleneck_dim

    @staticmethod
    def _build(config_cls, value):
        """Return ``value`` if it is a ``config_cls``, else build one from it."""
        # A ready-made config object cannot be splatted with `**`; the
        # original code raised TypeError for exactly the input its
        # annotations advertised.
        if isinstance(value, config_cls):
            return value
        return config_cls(**(value or {}))
|
|
|