import torch
import torch.nn as nn
import torch.nn.functional as F


class EMAVectorQuantizer(nn.Module):
    """Vector quantizer whose codebook is learned by exponential moving averages.

    The codebook is a buffer, not a trainable parameter: in training mode it is
    recomputed on every ``forward`` from EMA statistics (per-code counts and
    per-code sums of the assigned inputs). Gradients reach the input through a
    straight-through estimator, and the returned loss is only the commitment
    term.

    Args:
        num_embeddings: Number of codebook vectors.
        embedding_dim: Size of each codebook vector; must equal the channel
            dimension of the input.
        commitment_cost: Weight applied to the commitment (MSE) loss.
        decay: EMA decay for the count and sum statistics.
        epsilon: Laplace-smoothing constant applied to the per-code counts.
    """

    def __init__(
        self,
        num_embeddings=512,
        embedding_dim=256,
        commitment_cost=0.25,
        decay=0.99,
        epsilon=1e-5,
    ):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.decay = decay
        self.epsilon = epsilon

        embed = torch.randn(num_embeddings, embedding_dim)
        self.register_buffer("embedding", embed)
        # Every code starts with one pseudo-observation located at its initial
        # vector. With a zero initial count, a code unused on the first step
        # would end up with ema_w = decay * e over a smoothed count of ~epsilon,
        # i.e. be blown up by ~decay / epsilon (~1e5) and never selected again.
        # Starting ema_w and cluster_size in the same ratio keeps unused codes
        # where they are.
        self.register_buffer("cluster_size", torch.ones(num_embeddings))
        self.register_buffer("ema_w", embed.clone())

    def forward(self, z):
        """Quantize ``z`` to its nearest codebook vectors.

        Args:
            z: Tensor of shape (B, C, T) with ``C == embedding_dim``.

        Returns:
            Tuple ``(z_q, indices, loss)``:
              * ``z_q`` -- same shape as ``z``; values are the selected codes,
                gradient passes straight through to ``z``.
              * ``indices`` -- (B, T) tensor of selected code indices.
              * ``loss`` -- scalar commitment loss.

        Raises:
            ValueError: if ``z`` is not 3-D or its channel size differs from
                ``embedding_dim``.
        """
        if z.dim() != 3 or z.shape[1] != self.embedding_dim:
            raise ValueError(
                f"expected input of shape (B, {self.embedding_dim}, T), "
                f"got {tuple(z.shape)}"
            )
        z = z.permute(0, 2, 1).contiguous()  # (B, T, C)
        z_flattened = z.view(-1, self.embedding_dim)

        # Squared Euclidean distances ||z||^2 + ||e||^2 - 2 z.e for all pairs.
        distances = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding**2, dim=1)
            - 2 * torch.matmul(z_flattened, self.embedding.t())
        )
        min_encoding_indices = torch.argmin(distances, dim=1)
        # Look up codes before the EMA step so this batch uses the old codebook.
        z_q = F.embedding(min_encoding_indices, self.embedding)

        if self.training:
            self._ema_update(z_flattened, min_encoding_indices)

        loss = self.commitment_cost * F.mse_loss(z_q.detach(), z_flattened)
        # Straight-through estimator: forward value is z_q, gradient is identity.
        z_q = z_flattened + (z_q - z_flattened).detach()
        z_q = z_q.view(z.shape).permute(0, 2, 1).contiguous()
        return z_q, min_encoding_indices.view(z.shape[0], z.shape[1]), loss

    @torch.no_grad()
    def _ema_update(self, z_flattened, indices):
        """Fold one batch into the EMA statistics and rebuild the codebook."""
        # Work in the buffers' dtype so float64 models (and mixed inputs) work.
        flat = z_flattened.detach().to(self.ema_w.dtype)
        encodings = F.one_hot(indices, self.num_embeddings).to(flat.dtype)

        self.cluster_size.mul_(self.decay).add_(encodings.sum(0), alpha=1 - self.decay)
        # Laplace smoothing: keeps every count strictly positive while
        # preserving the total mass n.
        n = self.cluster_size.sum()
        cluster_size = (
            (self.cluster_size + self.epsilon)
            / (n + self.num_embeddings * self.epsilon)
            * n
        )

        dw = torch.matmul(encodings.t(), flat)  # per-code sum of assigned inputs
        self.ema_w.mul_(self.decay).add_(dw, alpha=1 - self.decay)
        self.embedding.copy_(self.ema_w / cluster_size.unsqueeze(1))


class RVQ(nn.Module):
    """Residual vector quantizer.

    A stack of ``EMAVectorQuantizer`` levels; each level quantizes the residual
    left over by the levels before it, and the output is the sum of all levels.

    Args:
        num_levels: Number of quantization levels (must be >= 1).
        num_embeddings: Codebook size of every level.
        embedding_dim: Channel size of the input and of each codebook vector.
        commitment_cost, decay, epsilon: Forwarded to every level.

    Raises:
        ValueError: if ``num_levels < 1``.
    """

    def __init__(
        self,
        num_levels=3,
        num_embeddings=512,
        embedding_dim=256,
        commitment_cost=0.25,
        decay=0.99,
        epsilon=1e-5,
    ):
        super().__init__()
        if num_levels < 1:
            raise ValueError(f"num_levels must be >= 1, got {num_levels}")
        self.num_levels = num_levels
        self.quantizers = nn.ModuleList(
            [
                EMAVectorQuantizer(
                    num_embeddings,
                    embedding_dim,
                    commitment_cost=commitment_cost,
                    decay=decay,
                    epsilon=epsilon,
                )
                for _ in range(num_levels)
            ]
        )

    def forward(self, z):
        """Quantize ``z`` with every level in turn.

        Args:
            z: Tensor of shape (B, embedding_dim, T).

        Returns:
            Tuple ``(quantized, indices, loss)``: ``quantized`` has the shape of
            ``z`` (sum of all levels, gradient passes straight through to
            ``z``); ``indices`` has shape (B, num_levels, T); ``loss`` is the
            sum of the per-level commitment losses.
        """
        quantized_out = torch.zeros_like(z)
        residual = z
        all_indices = []
        total_loss = 0
        for quantizer in self.quantizers:
            z_q, indices, loss = quantizer(residual)
            # Subtract the *detached* codes. Subtracting the straight-through
            # output would make every later residual z - (z + const), which
            # has no gradient path to z, so the commitment losses of levels
            # 2..L would never reach the encoder. The straight-through
            # estimator is applied once to the total below instead.
            z_q = z_q.detach()
            residual = residual - z_q
            quantized_out = quantized_out + z_q
            all_indices.append(indices)
            total_loss = total_loss + loss

        quantized_out = z + (quantized_out - z).detach()
        return quantized_out, torch.stack(all_indices, dim=1), total_loss


class ResBlock1D(nn.Module):
    """Residual block of two length-preserving 3-tap 1-D convolutions."""

    def __init__(self, channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(channels, channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Return ``x + net(x)``; ``x`` is (B, channels, T)."""
        return x + self.net(x)


class MotionEncoder(nn.Module):
    """Convolutional encoder that downsamples the time axis by 4.

    Maps (B, in_channels, T) to (B, latent_dim, (T - 4) // 4 + 1), which is
    T / 4 when T is a multiple of 4. Requires T >= 4.

    Args:
        in_channels: Per-frame feature size of the input (default 263,
            presumably the motion feature dimension of the dataset -- confirm).
        latent_dim: Channel size of the output latent sequence.
    """

    def __init__(self, in_channels=263, latent_dim=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(in_channels, 512, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            ResBlock1D(512),
            ResBlock1D(512),
            ResBlock1D(512),
            # kernel 8 / stride 4 / padding 2: 4x temporal downsampling.
            nn.Conv1d(512, latent_dim, kernel_size=8, stride=4, padding=2),
        )

    def forward(self, x):
        """Encode ``x`` of shape (B, in_channels, T)."""
        return self.net(x)


class MotionDecoder(nn.Module):
    """Convolutional decoder that upsamples the time axis by exactly 4.

    Maps (B, latent_dim, T') to (B, out_channels, 4 * T').

    Args:
        latent_dim: Channel size of the input latent sequence.
        out_channels: Per-frame feature size of the reconstruction.
    """

    def __init__(self, latent_dim=512, out_channels=263):
        super().__init__()
        self.net = nn.Sequential(
            # kernel 8 / stride 4 / padding 2: output length (T'-1)*4 - 4 + 8.
            nn.ConvTranspose1d(latent_dim, 512, kernel_size=8, stride=4, padding=2),
            nn.LeakyReLU(0.2, inplace=True),
            ResBlock1D(512),
            ResBlock1D(512),
            ResBlock1D(512),
            nn.Conv1d(512, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, z_q):
        """Decode ``z_q`` of shape (B, latent_dim, T')."""
        return self.net(z_q)


class MotionRVQ_VAE(nn.Module):
    """Encoder -> residual VQ -> decoder autoencoder for 1-D feature sequences.

    The reconstruction has the input's length only when T is a multiple of 4
    (the encoder floors (T - 4) / 4, the decoder multiplies by 4).

    Args:
        in_channels: Per-frame feature size of input and reconstruction.
        latent_dim: Latent channel size; also the RVQ codebook vector size.
        num_levels: Number of residual quantization levels.
        num_embeddings: Codebook size of each level.
    """

    def __init__(self, in_channels=263, latent_dim=512, num_levels=4, num_embeddings=1024):
        super().__init__()
        self.encoder = MotionEncoder(in_channels=in_channels, latent_dim=latent_dim)
        self.rvq = RVQ(
            num_levels=num_levels,
            num_embeddings=num_embeddings,
            embedding_dim=latent_dim,
        )
        self.decoder = MotionDecoder(latent_dim=latent_dim, out_channels=in_channels)

    def forward(self, x):
        """Autoencode ``x`` of shape (B, in_channels, T).

        Returns:
            Tuple ``(x_recon, token_indices, commitment_loss)``: ``x_recon`` is
            (B, in_channels, 4 * ((T - 4) // 4 + 1)); ``token_indices`` is
            (B, num_levels, (T - 4) // 4 + 1); ``commitment_loss`` is a scalar.
        """
        z = self.encoder(x)
        z_q, token_indices, commitment_loss = self.rvq(z)
        x_recon = self.decoder(z_q)
        return x_recon, token_indices, commitment_loss