Buckets:
ltx-community/ltx2-trainer-src-v2/packages/ltx-trainer/src/ltx_trainer/training_strategies/base_strategy.py
| """Base class for training strategies. | |
| This module defines the abstract base class that all training strategies must implement, | |
| along with the base configuration class. | |
| """ | |
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass | |
| from typing import Any, Literal | |
| import torch | |
| from pydantic import BaseModel, ConfigDict, Field | |
| from torch import Tensor | |
| from ltx_core.components.patchifiers import ( | |
| AudioPatchifier, | |
| VideoLatentPatchifier, | |
| get_pixel_coords, | |
| ) | |
| from ltx_core.model.transformer.modality import Modality | |
| from ltx_core.types import AudioLatentShape, SpatioTemporalScaleFactors, VideoLatentShape | |
| from ltx_trainer.timestep_samplers import TimestepSampler | |
# Fallback frame rate (frames per second) for videos whose metadata has no FPS entry
DEFAULT_FPS = 24

# VAE scale factors for LTX-2
VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
class TrainingStrategyConfigBase(BaseModel):
    """Base configuration class for training strategies.

    All strategy-specific configuration classes should inherit from this
    and implement :meth:`get_data_sources`.
    """

    # Unknown keys are rejected at validation time instead of being ignored.
    model_config = ConfigDict(extra="forbid")

    name: Literal["text_to_video", "video_to_video", "flexible"] = Field(
        description="Unique name identifying the training strategy type"
    )

    @abstractmethod
    def get_data_sources(self) -> dict[str, str]:
        """Get the required data sources for this strategy.

        Returns a mapping of directory name (relative to ``preprocessed_data_root``)
        to the dataset output key under which that directory's contents are exposed.

        This is the single source of truth for which directories the strategy needs:
        it drives both dataset wiring (in the trainer) and existence validation
        (in ``LtxTrainerConfig``).

        Raises:
            NotImplementedError: Always, in this base class. Concrete strategy
                configurations must override this method.
        """
        # Fail loudly rather than implicitly returning None, which would violate
        # the declared ``dict[str, str]`` return type and surface later as an
        # unrelated error in whatever consumes the mapping.
        raise NotImplementedError(f"{type(self).__name__} must implement get_data_sources()")
@dataclass
class ModelInputs:
    """Container for model inputs using the Modality-based interface.

    Bundles the per-modality transformer inputs together with the targets and
    masks needed for loss computation. Every field is required and may be
    ``None``.
    """

    # Per-modality transformer inputs
    video: Modality | None
    audio: Modality | None

    # Training targets (for loss computation)
    video_targets: Tensor | None
    audio_targets: Tensor | None

    # Masks for loss computation (True = compute loss for this token)
    video_loss_mask: Tensor | None
    audio_loss_mask: Tensor | None
class TrainingStrategy(ABC):
    """Abstract base class for training strategies.

    Each strategy encapsulates the logic for a specific training mode,
    handling input preparation and loss computation. Subclasses must implement
    :meth:`prepare_training_inputs` and :meth:`compute_loss`; the remaining
    methods are shared helpers.
    """

    def __init__(self, config: TrainingStrategyConfigBase):
        """Initialize strategy with configuration.

        Args:
            config: Strategy-specific configuration
        """
        self.config = config
        # Patchifiers are only used here to build position grids for the latents.
        self._video_patchifier = VideoLatentPatchifier(patch_size=1)
        self._audio_patchifier = AudioPatchifier(patch_size=1)

    @abstractmethod
    def prepare_training_inputs(
        self,
        batch: dict[str, Any],
        timestep_sampler: TimestepSampler,
    ) -> ModelInputs:
        """Prepare training inputs from a raw data batch.

        Args:
            batch: Raw batch data from the dataset. Contains:
                - "latents": Video latent data
                - "conditions": Text embeddings with keys:
                    - "video_prompt_embeds": Already processed by embedding connectors
                    - "audio_prompt_embeds": Already processed by embedding connectors
                    - "prompt_attention_mask": Attention mask
                - Additional keys depending on strategy (e.g., "ref_latents" for IC-LoRA)
            timestep_sampler: Sampler for generating timesteps and noise

        Returns:
            ModelInputs containing Modality objects and training targets
        """

    @abstractmethod
    def compute_loss(
        self,
        video_pred: Tensor,
        audio_pred: Tensor | None,
        inputs: ModelInputs,
    ) -> Tensor:
        """Compute the training loss.

        Args:
            video_pred: Video prediction from the transformer model
            audio_pred: Audio prediction from the transformer model (None for video-only)
            inputs: The prepared model inputs containing targets and masks

        Returns:
            Per-element loss tensor of shape [B,]. The trainer reduces to a scalar
            before backward(). Returning unreduced loss enables per-sigma-bucket tracking.
        """

    def get_checkpoint_metadata(self) -> dict[str, Any]:
        """Get strategy-specific metadata to include in checkpoint files.

        Override this method in subclasses to add custom metadata,
        e.g. any parameters that a downstream inference pipeline may need.

        Returns:
            Dictionary of metadata key-value pairs (values must be JSON-serializable).
            The base implementation returns an empty dictionary.
        """
        return {}

    def _get_video_positions(
        self,
        num_frames: int,
        height: int,
        width: int,
        batch_size: int,
        fps: float,
        device: torch.device,
    ) -> Tensor:
        """Generate video position embeddings using ltx_core's native implementation.

        Args:
            num_frames: Number of latent frames
            height: Latent height
            width: Latent width
            batch_size: Batch size
            fps: Frames per second (must be positive)
            device: Target device

        Returns:
            Position tensor of shape [B, 3, seq_len, 2] (float32)

        Raises:
            ValueError: If ``fps`` is not a positive number.
        """
        # Written as ``not fps > 0`` so NaN is rejected as well as zero/negatives;
        # otherwise the division below would silently yield inf/NaN positions.
        if not fps > 0:
            raise ValueError(f"fps must be a positive number, got {fps!r}")

        latent_coords = self._video_patchifier.get_patch_grid_bounds(
            output_shape=VideoLatentShape(
                frames=num_frames,
                height=height,
                width=width,
                batch=batch_size,
                channels=128,  # Video latent channels
            ),
            device=device,
        )

        # Convert latent coords to pixel coords with causal fix
        pixel_coords = get_pixel_coords(
            latent_coords=latent_coords,
            scale_factors=VIDEO_SCALE_FACTORS,
            causal_fix=True,
        ).float()

        # Scale temporal dimension by 1/fps to get time in seconds
        pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps

        return pixel_coords

    def _get_audio_positions(
        self,
        num_time_steps: int,
        batch_size: int,
        device: torch.device,
    ) -> Tensor:
        """Generate audio position embeddings using ltx_core's native implementation.

        Args:
            num_time_steps: Number of audio time steps (T, not T*mel_bins)
            batch_size: Batch size
            device: Target device

        Returns:
            Position tensor of shape [B, 1, num_time_steps, 2]

        Note:
            Audio latents should be in patchified format [B, T, C*F] = [B, T, 128]
            where T is the number of time steps, C=8 channels, F=16 mel bins.
            This matches the format produced by AudioPatchifier.patchify().
        """
        mel_bins = 16

        return self._audio_patchifier.get_patch_grid_bounds(
            output_shape=AudioLatentShape(
                frames=num_time_steps,
                mel_bins=mel_bins,
                batch=batch_size,
                channels=8,  # Audio latent channels
            ),
            device=device,
        )

    @staticmethod
    def _create_per_token_timesteps(conditioning_mask: Tensor, sampled_sigma: Tensor) -> Tensor:
        """Create per-token timesteps based on conditioning mask.

        Args:
            conditioning_mask: Boolean mask of shape (batch_size, sequence_length),
                where True = conditioning token (timestep=0), False = target token (use sigma)
            sampled_sigma: Sampled sigma values of shape (batch_size,) or (batch_size, 1, 1)

        Returns:
            Timesteps tensor of shape [batch_size, sequence_length]
        """
        # Collapse to [B, 1], then broadcast to the mask shape [B, seq_len].
        # reshape (rather than view) does not depend on the sigma's memory layout.
        expanded_sigma = sampled_sigma.reshape(-1, 1).expand_as(conditioning_mask)

        # Conditioning tokens get 0, target tokens get the sampled sigma
        return torch.where(conditioning_mask, torch.zeros_like(expanded_sigma), expanded_sigma)

    @staticmethod
    def _create_first_frame_conditioning_mask(
        batch_size: int,
        sequence_length: int,
        height: int,
        width: int,
        device: torch.device,
        first_frame_conditioning_p: float = 0.0,
    ) -> Tensor:
        """Create conditioning mask for first frame conditioning.

        Args:
            batch_size: Batch size
            sequence_length: Total sequence length
            height: Latent height
            width: Latent width
            device: Target device
            first_frame_conditioning_p: Probability of conditioning on the first frame

        Returns:
            Boolean mask of shape [batch_size, sequence_length] where True indicates
            first frame tokens (if conditioning is enabled). The conditioning decision
            is drawn independently per batch element so the training signal across
            samples in a batch is i.i.d. The mask is all False when the probability
            is not positive or when the first frame would cover the whole sequence.
        """
        conditioning_mask = torch.zeros(batch_size, sequence_length, dtype=torch.bool, device=device)

        if first_frame_conditioning_p > 0:
            # The first frame occupies the first height * width tokens of the sequence.
            first_frame_end_idx = height * width

            # Skip when the first frame is the entire sequence: nothing would be left to predict.
            if first_frame_end_idx < sequence_length:
                # Per-sample Bernoulli draw so each batch element is independently conditioned.
                per_sample_condition = torch.rand(batch_size, device=device) < first_frame_conditioning_p
                conditioning_mask[per_sample_condition, :first_frame_end_idx] = True

        return conditioning_mask
Xet Storage Details
- Size: 9.48 kB
- Xet hash: 8d7f1ceba6c5721ac41b185a2ba3fcf16a41d4470b82f192bf8ec02cba4393d9
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.