"""Train a MotionRVQ_VAE on HumanML3D windows and save its weights as safetensors."""

import multiprocessing
from pathlib import Path

import torch
import torch.nn.functional as F
import torch.optim as optim
from safetensors.torch import save_file

from rvq_model import MotionRVQ_VAE

# Weight applied to the finite-difference ("velocity") term of the reconstruction loss.
VELOCITY_WEIGHT = 1.5


def _reconstruction_loss(reconstructed: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Return the MSE on values plus the weighted MSE on last-axis differences.

    Both tensors are indexed as ``[:, :, ...]`` and therefore must have at
    least three dimensions.
    NOTE(review): the difference is taken along the last axis (dim 2), which
    presumably is the time axis of the motion window - confirm against the
    dataset's output layout.
    """
    pos_loss = F.mse_loss(reconstructed, batch)
    vel_orig = batch[:, :, 1:] - batch[:, :, :-1]
    vel_recon = reconstructed[:, :, 1:] - reconstructed[:, :, :-1]
    vel_loss = F.mse_loss(vel_recon, vel_orig)
    return pos_loss + (VELOCITY_WEIGHT * vel_loss)


def main() -> None:
    """Run the training loop and write the weights next to this script.

    Raises:
        RuntimeError: If the dataloader yields no batches (for example fewer
            samples than ``batch_size`` while ``drop_last=True``).
    """
    # Imported here (as in the original script) so the dataset module is only
    # loaded when the file runs as a script.
    from rvq_humanml_dataset import DataLoader, HumanML3DDataset

    base_dir = Path(__file__).resolve().parent
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}")

    model = MotionRVQ_VAE().to(device)
    dataset = HumanML3DDataset(data_dir=str(base_dir / "new_joint_vecs"), window_size=100)
    dataloader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4,
        drop_last=True,
    )

    # Computed once: used for progress logging and the per-epoch average.
    num_batches = len(dataloader)
    if num_batches == 0:
        # Without this guard the loop body never runs and the per-epoch
        # average below fails with an uninformative ZeroDivisionError.
        raise RuntimeError(
            "The dataloader produced no batches; the dataset needs at least "
            "one full batch because drop_last=True."
        )

    optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
    num_epochs = 5

    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch_idx, batch in enumerate(dataloader):
            batch = batch.to(device)
            optimizer.zero_grad()

            # The model returns three values; the middle one is unused here.
            reconstructed, _, commit_loss = model(batch)
            reconstruction_loss = _reconstruction_loss(reconstructed, batch)
            loss = reconstruction_loss + commit_loss

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

            if batch_idx % 50 == 0:
                # The "MSE" field reports the combined reconstruction loss
                # (value term plus weighted difference term).
                print(
                    f"Epoch [{epoch + 1}/{num_epochs}] Batch [{batch_idx}/{num_batches}] "
                    f"MSE: {reconstruction_loss.item():.4f} | Commit: {commit_loss.item():.4f}"
                )

        print(f"--- End of epoch {epoch + 1} | Avg loss: {epoch_loss / num_batches:.4f} ---")

    weights_path = base_dir / "motion_rvq_weights.safetensors"
    # Detach and move every tensor to CPU before serializing.
    state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    save_file(state_dict, str(weights_path))
    print(f"Training complete and model saved to: {weights_path}")


if __name__ == "__main__":
    # No-op except in frozen Windows executables; must run before workers start.
    multiprocessing.freeze_support()
    main()