# Implementing Custom Training Strategies

This guide explains how to implement your own training strategy for specialized recipes that cannot be expressed with
the built-in `flexible` strategy.

## Overview

The trainer uses the **Strategy Pattern** to separate training logic from the core training loop. Each strategy defines:

1. **What data is needed** - Which preprocessed data directories to load
2. **How to prepare inputs** - Transform batch data into model inputs
3. **How to compute loss** - Calculate the training objective

This architecture lets you implement new training modes without modifying the core trainer code.

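
The three responsibilities above correspond to three methods. Below is a minimal, stdlib-only sketch of that interface. Plain lists and dicts stand in for tensors and batches. `StrategySketch`, `IdentityStrategy`, and `IncompleteStrategy` are hypothetical names. The real `TrainingStrategy` signatures differ: for example, `prepare_training_inputs()` also receives a timestep sampler, and `get_data_sources()` lives on the config.

```python
from abc import ABC, abstractmethod
from typing import Any


class StrategySketch(ABC):
    """Simplified stand-in for a training strategy (hypothetical, not the real base class)."""

    @abstractmethod
    def get_data_sources(self) -> dict[str, str]:
        """Map preprocessed directory names to batch keys."""

    @abstractmethod
    def prepare_training_inputs(self, batch: dict[str, Any]) -> dict[str, Any]:
        """Turn a raw batch into model inputs."""

    @abstractmethod
    def compute_loss(self, prediction: list[float], inputs: dict[str, Any]) -> float:
        """Reduce predictions and targets to a scalar loss."""


class IdentityStrategy(StrategySketch):
    def get_data_sources(self) -> dict[str, str]:
        return {"latents": "latents", "conditions": "conditions"}

    def prepare_training_inputs(self, batch: dict[str, Any]) -> dict[str, Any]:
        # Pass the latents through unchanged and use them as the target.
        return {"x": batch["latents"], "target": batch["latents"]}

    def compute_loss(self, prediction: list[float], inputs: dict[str, Any]) -> float:
        # Mean squared error over a flat list.
        diffs = [(p - t) ** 2 for p, t in zip(prediction, inputs["target"])]
        return sum(diffs) / len(diffs)


class IncompleteStrategy(StrategySketch):
    # Only one of the three hooks is implemented.
    def get_data_sources(self) -> dict[str, str]:
        return {}
```

In this sketch the hooks are abstract. A subclass that forgets one of them fails when it is constructed (`IncompleteStrategy()` raises `TypeError`), not partway through a training run.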
### When You Need a Custom Strategy

> [!NOTE]
> The built-in `flexible` strategy already supports most conditioning scenarios out of the box:
> first-frame conditioning, video extension (prefix/suffix), spatial crop (outpainting),
> mask-based inpainting, IC-LoRA reference conditioning, and frozen modality cross-conditioning
> (audio-to-video, video-to-audio). Only implement a custom strategy if your use case requires
> fundamentally different training logic that cannot be expressed through the flexible strategy's
> configuration.

Consider implementing a custom strategy when you need:

- **Custom loss computation** (e.g., weighted losses, auxiliary losses, perceptual losses)
- **Non-standard noise application** (e.g., noise schedules different from flow matching)
- **Novel conditioning mechanisms** not covered by the flexible strategy's condition types
- **Additional model outputs** beyond the standard video/audio predictions

## Architecture Overview

### How Strategies Fit Into the Trainer

The trainer delegates all training-mode-specific logic to the strategy:

1. **Initialization**: The trainer calls `config.get_data_sources()` to determine which preprocessed data directories to load
2. **Each training step:**
   - Calls `prepare_training_inputs()` to transform the raw batch into model-ready inputs
   - Runs the transformer forward pass
   - Calls `compute_loss()` to compute the training objective

The trainer handles everything else: optimization, checkpointing, validation, and distributed training.

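
The per-step flow can be illustrated with a toy loop. This is a hypothetical sketch, not the trainer's code:

- `RecordingStrategy` and `train()` are invented names.
- The identity `model` stands in for the transformer.
- For brevity, the data-source lookup sits on the strategy object, although the real trainer calls it on the config.

```python
from typing import Any, Callable


class RecordingStrategy:
    """Toy strategy that records which hook the loop called (hypothetical)."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def get_data_sources(self) -> dict[str, str]:
        self.calls.append("get_data_sources")
        return {"latents": "latents"}

    def prepare_training_inputs(self, batch: dict[str, Any]) -> dict[str, Any]:
        self.calls.append("prepare_training_inputs")
        return {"x": batch["latents"], "target": [0.0] * len(batch["latents"])}

    def compute_loss(self, prediction: list[float], inputs: dict[str, Any]) -> float:
        self.calls.append("compute_loss")
        return sum((p - t) ** 2 for p, t in zip(prediction, inputs["target"])) / len(prediction)


def train(
    strategy: RecordingStrategy,
    batches: list[dict[str, Any]],
    model: Callable[[list[float]], list[float]],
) -> list[float]:
    """Hypothetical loop: resolve data sources once, then delegate every step."""
    data_sources = strategy.get_data_sources()  # initialization
    losses = []
    for batch in batches:
        # Keep only the batch keys the strategy asked for.
        batch = {key: batch[key] for key in data_sources.values()}
        inputs = strategy.prepare_training_inputs(batch)
        prediction = model(inputs["x"])  # stand-in for the transformer forward pass
        losses.append(strategy.compute_loss(prediction, inputs))
        # The real trainer would step the optimizer, checkpoint, etc. here.
    return losses
```

The recorded call order shows the split. Data sources are resolved once, while `prepare_training_inputs()` and `compute_loss()` run on every step. Everything else stays in the loop.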
### Key Components

| Component                                                                               | Purpose                                                      |
|-----------------------------------------------------------------------------------------|--------------------------------------------------------------|
| [`TrainingStrategyConfigBase`](../src/ltx_trainer/training_strategies/base_strategy.py) | Base class for strategy configuration (Pydantic model)       |
| [`TrainingStrategy`](../src/ltx_trainer/training_strategies/base_strategy.py)           | Abstract base class defining the strategy interface          |
| [`ModelInputs`](../src/ltx_trainer/training_strategies/base_strategy.py)                | Dataclass containing prepared inputs for the transformer     |
| [`Modality`](../../ltx-core/src/ltx_core/model/transformer/modality.py)                 | ltx-core dataclass representing video or audio modality data |

## Step-by-Step Implementation

### Step 1: Plan Your Strategy

Before writing code, answer these questions:

1. **What additional data does your strategy need?**
   - Example: A perceptual-loss strategy may need auxiliary feature targets
   - Example: A novel conditioning mechanism may need an additional precomputed directory
2. **What does conditioning look like?**
   - Which tokens should be noised vs. kept clean?
   - How should conditioning tokens be structured (e.g., first frame, reference video, mask)?
3. **How should loss be computed?**
   - Which tokens contribute to the loss?
   - Are there multiple loss terms to combine?

### Step 2: Extend Data Preprocessing (If Needed)

If your strategy requires additional preprocessed data beyond video latents, audio latents, and text embeddings, you'll
need to extend the preprocessing pipeline.

#### Option A: Modify `process_dataset.py`

For integrated preprocessing, add new arguments and processing steps to the main script. For example, to add mask
preprocessing:

```python
# In process_dataset.py, add a new argument
@app.command()
def main(
    # ... existing arguments ...
    mask_column: str | None = typer.Option(
        default=None,
        help="Column name containing mask video paths (for inpainting)",
    ),
) -> None:
    # ... existing processing ...

    # Process masks if provided
    if mask_column:
        logger.info("Processing mask videos for inpainting training...")
        mask_latents_dir = output_base / "mask_latents"
        compute_latents(
            dataset_file=dataset_path,
            video_column=mask_column,
            resolution_buckets=parsed_resolution_buckets,
            output_dir=str(mask_latents_dir),
            model_path=model_path,
            # ... other args ...
        )
```

#### Option B: Create a Standalone Script

For complex preprocessing that doesn't fit naturally into the existing pipeline, create a dedicated script
(e.g., `scripts/process_masks.py`). Use [`scripts/compute_reference.py`](../scripts/compute_reference.py) as a
template - it shows how to process paired data and update the dataset JSON.

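
A rough illustration of the "update the dataset JSON" part of such a script: the sketch below pairs each video with a mask file by naming convention. It assumes the dataset JSON is a list of records with a `media_path` key and that masks are stored as `<stem>_mask<ext>`. Both the layout and the `add_mask_column()` helper are hypothetical. Check `compute_reference.py` for the conventions the trainer actually uses.

```python
import json
from pathlib import Path


def add_mask_column(
    dataset_json: Path, mask_dir: Path, column: str = "mask_path", suffix: str = "_mask"
) -> int:
    """Attach a mask path to every record whose mask file exists (hypothetical layout).

    Each record is assumed to have a "media_path" key; the matching mask is
    <mask_dir>/<video stem><suffix><video extension>. Returns the number of paired records.
    """
    records = json.loads(dataset_json.read_text())
    paired = 0
    for record in records:
        media = Path(record["media_path"])
        mask = mask_dir / f"{media.stem}{suffix}{media.suffix}"
        if mask.exists():
            record[column] = str(mask)
            paired += 1
    # Rewrite the dataset file in place with the new column.
    dataset_json.write_text(json.dumps(records, indent=2))
    return paired
```

Records without a matching mask are left unchanged and are not counted as paired. Depending on your strategy, you may prefer to drop them or fail loudly instead.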
#### Expected Output Structure

Your preprocessing should create a directory structure that the strategy can reference:

```
preprocessed_data_root/
├── latents/             # Video latents (standard)
├── conditions/          # Text embeddings (standard)
├── audio_latents/       # Audio latents (if with_audio)
├── mask_latents/        # Your custom data directory
└── reference_latents/   # Reference videos (for IC-LoRA)
```

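
The directory names in this tree are what `get_data_sources()` returns as keys. The sketch below shows one way to turn such a mapping into `{batch_key: directory}` while failing fast on missing directories. It assumes a simple existence check. `resolve_data_sources()` is a hypothetical helper, not the trainer's implementation.

```python
from pathlib import Path


def resolve_data_sources(root: Path, data_sources: dict[str, str]) -> dict[str, Path]:
    """Map each batch key to its directory under root, failing fast on missing ones.

    Hypothetical helper: keys of data_sources are directory names under root,
    values are the batch keys the strategy reads.
    """
    missing = [name for name in data_sources if not (root / name).is_dir()]
    if missing:
        raise FileNotFoundError(f"Missing preprocessed directories under {root}: {missing}")
    return {batch_key: root / dir_name for dir_name, batch_key in data_sources.items()}
```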
### Step 3: Create the Strategy Configuration

Create a new file for your strategy (e.g., `src/ltx_trainer/training_strategies/inpainting.py`):

```python
"""Inpainting training strategy.

This strategy implements video inpainting training where:
- Mask latents indicate which regions to inpaint
- Loss is computed only on masked (inpainted) regions
"""

from typing import Any, Literal

import torch
from pydantic import Field
from torch import Tensor

from ltx_core.model.transformer.modality import Modality
from ltx_trainer.timestep_samplers import TimestepSampler
from ltx_trainer.training_strategies.base_strategy import (
    ModelInputs,
    TrainingStrategy,
    TrainingStrategyConfigBase,
)


class InpaintingConfig(TrainingStrategyConfigBase):
    """Configuration for inpainting training strategy."""

    # The 'name' field acts as a discriminator for the config union
    name: Literal["inpainting"] = "inpainting"

    mask_latents_dir: str = Field(
        default="mask_latents",
        description="Directory name for mask latents",
    )

    # Add any strategy-specific parameters
    mask_threshold: float = Field(
        default=0.5,
        description="Threshold for binary mask conversion",
        ge=0.0,
        le=1.0,
    )

    def get_data_sources(self) -> dict[str, str]:
        """Define which data directories to load.

        Returns a mapping of directory names (under preprocessed_data_root) to
        batch keys. The trainer loads .pt files from each directory and exposes
        them in the batch under the specified key. The trainer also uses this
        mapping to validate that all required directories exist.
        """
        return {
            "latents": "latents",  # -> batch["latents"]
            "conditions": "conditions",  # -> batch["conditions"]
            self.mask_latents_dir: "masks",  # -> batch["masks"]
        }
```

**Key points:**

- Inherit from `TrainingStrategyConfigBase`
- Use `Literal["your_strategy_name"]` for the `name` field - this enables automatic strategy selection
- Use Pydantic `Field` for validation and documentation
- Implement `get_data_sources()` on the config: it's the single source of truth for data directories (used for both dataset wiring and existence validation)

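
To see what the Pydantic constraints buy you, here is a stdlib-only stand-in for `InpaintingConfig`. It is an illustration only. `InpaintingConfigSketch` is hypothetical, and its `__post_init__` hand-codes the checks that `Literal` and `Field(ge=0.0, le=1.0)` perform in the real Pydantic model.

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class InpaintingConfigSketch:
    """Stdlib stand-in for the Pydantic InpaintingConfig (illustration only)."""

    name: Literal["inpainting"] = "inpainting"
    mask_latents_dir: str = "mask_latents"
    mask_threshold: float = 0.5

    def __post_init__(self) -> None:
        # dataclasses do not enforce Literal at runtime, so check it by hand.
        if self.name != "inpainting":
            raise ValueError(f"unexpected strategy name: {self.name!r}")
        # Mirrors Field(ge=0.0, le=1.0).
        if not 0.0 <= self.mask_threshold <= 1.0:
            raise ValueError("mask_threshold must be in [0, 1]")

    def get_data_sources(self) -> dict[str, str]:
        # Directory name (under preprocessed_data_root) -> batch key.
        return {
            "latents": "latents",
            "conditions": "conditions",
            self.mask_latents_dir: "masks",
        }
```

Renaming `mask_latents_dir` changes only which directory is loaded. The batch key stays `"masks"`, so the strategy code does not need to change.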
### Step 4: Implement the Strategy Class

```python
class InpaintingStrategy(TrainingStrategy):
    """Inpainting training strategy.

    Trains the model to fill in masked regions of videos while
    keeping unmasked regions as conditioning.
    """

    config: InpaintingConfig

    def __init__(self, config: InpaintingConfig):
        super().__init__(config)

    def prepare_training_inputs(
        self,
        batch: dict[str, Any],
        timestep_sampler: TimestepSampler,
    ) -> ModelInputs:
        """Transform batch data into model inputs.

        This is where the core training logic lives:
        1. Extract and patchify latents
        2. Sample noise and apply it appropriately
        3. Create conditioning masks
        4. Build Modality objects for the transformer
        """
        # Get video latents [B, C, F, H, W]
        latents_data = batch["latents"]
        video_latents = latents_data["latents"]

        # Get dimensions
        num_frames = latents_data["num_frames"][0].item()
        height = latents_data["height"][0].item()
        width = latents_data["width"][0].item()

        # Patchify: [B, C, F, H, W] -> [B, seq_len, C]
        video_latents = self._video_patchifier.patchify(video_latents)
        batch_size, seq_len, _ = video_latents.shape
        device = video_latents.device

        # Get mask latents and process them
        mask_data = batch["masks"]
        mask_latents = mask_data["latents"]
        mask_latents = self._video_patchifier.patchify(mask_latents)

        # Create binary mask: True = inpaint this region, False = keep original
        inpaint_mask = mask_latents.mean(dim=-1) > self.config.mask_threshold

        # Sample noise and sigmas
        sigmas = timestep_sampler.sample_for(video_latents)
        noise = torch.randn_like(video_latents)

        # Apply noise only to inpaint regions
        sigmas_expanded = sigmas.view(-1, 1, 1)
        noisy_latents = (1 - sigmas_expanded) * video_latents + sigmas_expanded * noise

        # Keep original latents for non-inpaint regions (conditioning)
        inpaint_mask_expanded = inpaint_mask.unsqueeze(-1)
        noisy_latents = torch.where(inpaint_mask_expanded, noisy_latents, video_latents)

        # Create per-token timesteps
        # Conditioning tokens (non-inpaint) get timestep=0
        # Inpaint tokens get the sampled sigma
        timesteps = self._create_per_token_timesteps(~inpaint_mask, sigmas.squeeze())

        # Compute targets (velocity prediction: noise - clean)
        targets = noise - video_latents

        # Get text embeddings
        conditions = batch["conditions"]
        video_prompt_embeds = conditions["video_prompt_embeds"]
        prompt_attention_mask = conditions["prompt_attention_mask"]

        # Generate position embeddings
        positions = self._get_video_positions(
            num_frames=num_frames,
            height=height,
            width=width,
            batch_size=batch_size,
            fps=24.0,  # Or get from latents_data
            device=device,
        )

        # Create video Modality
        video_modality = Modality(
            enabled=True,
            latent=noisy_latents,
            sigma=sigmas,
            timesteps=timesteps,
            positions=positions,
            context=video_prompt_embeds,
            context_mask=prompt_attention_mask,
        )

        # Loss mask: only compute loss on inpaint regions
        loss_mask = inpaint_mask

        return ModelInputs(
            video=video_modality,
            audio=None,
            video_targets=targets,
            audio_targets=None,
            video_loss_mask=loss_mask,
            audio_loss_mask=None,
        )

    def compute_loss(
        self,
        video_pred: Tensor,
        audio_pred: Tensor | None,
        inputs: ModelInputs,
    ) -> Tensor:
        """Compute training loss on inpaint regions only. Returns [B,]."""
        # MSE loss
        loss = (video_pred - inputs.video_targets).pow(2)

        # Apply loss mask and reduce to a per-sample loss [B,]
        loss_mask = inputs.video_loss_mask.unsqueeze(-1).float()
        masked = loss.mul(loss_mask)
        # Dividing by the mask mean turns this into the average over masked elements only
        return masked.mean(dim=[-2, -1]) / loss_mask.mean(dim=[-2, -1]).clamp(min=1e-8)
```

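
The noising and loss logic above can be checked on a toy example without a GPU. This sketch re-implements it for a single sample, using nested Python lists in place of tensors. `noise_inpaint_regions()` and `masked_mse()` are hypothetical names.

The sketch also makes the loss normalization explicit. With T tokens, C channels, and a per-token mask m:

mean(loss · m) / mean(m) = (Σ masked errors / (T·C)) / (Σm / T) = Σ masked errors / (C · Σm)

That is, the result is the MSE averaged over masked elements only.

```python
def noise_inpaint_regions(
    latents: list[list[float]],
    noise: list[list[float]],
    inpaint_mask: list[bool],
    sigma: float,
) -> tuple[list[list[float]], list[list[float]], list[float]]:
    """Flow-matching interpolation applied only where inpaint_mask is True.

    latents/noise: one list of channel values per token (single sample).
    Returns (noisy_latents, velocity_targets, per_token_timesteps).
    """
    noisy, targets, timesteps = [], [], []
    for x, n, m in zip(latents, noise, inpaint_mask):
        if m:
            noisy.append([(1 - sigma) * xi + sigma * ni for xi, ni in zip(x, n)])
            timesteps.append(sigma)
        else:
            noisy.append(list(x))  # conditioning token stays clean...
            timesteps.append(0.0)  # ...and gets timestep 0
        targets.append([ni - xi for xi, ni in zip(x, n)])  # velocity target: noise - clean
    return noisy, targets, timesteps


def masked_mse(
    pred: list[list[float]],
    targets: list[list[float]],
    loss_mask: list[bool],
    eps: float = 1e-8,
) -> float:
    """Same reduction as compute_loss(): mean(loss * mask) / max(mean(mask), eps)."""
    num_tokens, num_channels = len(targets), len(targets[0])
    masked_sum = sum(
        (p - t) ** 2
        for p_tok, t_tok, m in zip(pred, targets, loss_mask)
        if m
        for p, t in zip(p_tok, t_tok)
    )
    mean_loss = masked_sum / (num_tokens * num_channels)
    mean_mask = sum(loss_mask) / num_tokens  # mask has a single channel after unsqueeze(-1)
    return mean_loss / max(mean_mask, eps)
```

The clamp in `compute_loss()` corresponds to `max(mean_mask, eps)` here. It keeps a sample with an empty inpaint mask from dividing by zero, so such a sample simply contributes a loss of 0.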
### Step 5: Register the Strategy

You need to register your strategy in two places:

**1. Update [`src/ltx_trainer/training_strategies/__init__.py`](../src/ltx_trainer/training_strategies/__init__.py):**

```python
# Add import for your strategy
from ltx_trainer.training_strategies.inpainting import InpaintingConfig, InpaintingStrategy

# Add to the TrainingStrategyConfig type alias
TrainingStrategyConfig = TextToVideoConfig | VideoToVideoConfig | FlexibleStrategyConfig | InpaintingConfig

# Add to __all__
__all__ = [
    # ... existing exports ...
    "InpaintingConfig",
    "InpaintingStrategy",
]


# Add case in get_training_strategy()
def get_training_strategy(config: TrainingStrategyConfig) -> TrainingStrategy:
    match config:
        # ... existing cases ...
        case InpaintingConfig():
            strategy = InpaintingStrategy(config)
```

**2. Update [`src/ltx_trainer/config.py`](../src/ltx_trainer/config.py):**

```python
# Add import
from ltx_trainer.training_strategies.inpainting import InpaintingConfig

# Add to the TrainingStrategyConfig union with a Tag matching your strategy name
TrainingStrategyConfig = Annotated[
    Annotated[TextToVideoConfig, Tag("text_to_video")]
    | Annotated[VideoToVideoConfig, Tag("video_to_video")]
    | Annotated[FlexibleStrategyConfig, Tag("flexible")]
    | Annotated[InpaintingConfig, Tag("inpainting")],
    Discriminator(_get_strategy_discriminator),
]
```

### Step 6: Create a Configuration File

Create an example config in `configs/`:

```yaml
# configs/custom_inpainting_lora.yaml
model:
  model_path: "/path/to/ltx2.safetensors"
  text_encoder_path: "/path/to/gemma"
  training_mode: "lora"

training_strategy:
  name: "inpainting"  # Must match your Literal type
  mask_latents_dir: "mask_latents"
  mask_threshold: 0.5

lora:
  rank: 32
  alpha: 32
  target_modules:
    - "to_k"
    - "to_q"
    - "to_v"
    - "to_out.0"

data:
  preprocessed_data_root: "/path/to/preprocessed/dataset"

optimization:
  learning_rate: 1e-4
  steps: 2000
  batch_size: 1

# ... other config sections ...
```

## Helper Methods Reference

The base `TrainingStrategy` class provides these helper methods:

| Method                                                          | Purpose                                         |
|-----------------------------------------------------------------|-------------------------------------------------|
| `_video_patchifier.patchify(latents)`                           | Convert `[B, C, F, H, W]` → `[B, seq_len, C]`   |
| `_audio_patchifier.patchify(latents)`                           | Convert `[B, C, T, F]` → `[B, T, C*F]`          |
| `_get_video_positions(...)`                                     | Generate position embeddings for video          |
| `_get_audio_positions(...)`                                     | Generate position embeddings for audio          |
| `_create_per_token_timesteps(conditioning_mask, sampled_sigma)` | Create timesteps with 0 for conditioning tokens |
| `_create_first_frame_conditioning_mask(...)`                    | Create mask for first-frame conditioning        |

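
The last two helpers are easiest to understand on a tiny token grid. The sketch below is a hypothetical, list-based re-implementation for one sample. It assumes frame-major token order after patchifying and a patch size of 1 (one token per latent pixel). The real helpers operate on batched tensors and may differ in detail.

```python
def first_frame_conditioning_mask(num_frames: int, height: int, width: int) -> list[bool]:
    """True for tokens of latent frame 0 (assumes frame-major order, patch size 1)."""
    tokens_per_frame = height * width
    return [i < tokens_per_frame for i in range(num_frames * tokens_per_frame)]


def per_token_timesteps(conditioning_mask: list[bool], sampled_sigma: float) -> list[float]:
    """Conditioning tokens get 0.0; every other token gets the sampled sigma."""
    return [0.0 if is_cond else sampled_sigma for is_cond in conditioning_mask]
```

With 3 latent frames of 2×2 tokens, the first 4 tokens are conditioning tokens with timestep 0, and the remaining 8 carry the sampled sigma.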
## Understanding ModelInputs

The `ModelInputs` dataclass contains everything needed for the forward pass and loss computation:

```python
from dataclasses import dataclass

from torch import Tensor

from ltx_core.model.transformer.modality import Modality


@dataclass
class ModelInputs:
    video: Modality | None  # Video modality data
    audio: Modality | None  # Audio modality data
    video_targets: Tensor | None  # Target values for video loss (velocity)
    audio_targets: Tensor | None  # Target values for audio loss (velocity)
    video_loss_mask: Tensor | None  # Boolean loss mask for video tokens
    audio_loss_mask: Tensor | None  # Boolean loss mask for audio tokens
```

## Understanding Modality

The `Modality` dataclass (from ltx-core) represents a single modality's data:

```python
from dataclasses import dataclass

from torch import Tensor


@dataclass(frozen=True)
class Modality:
    latent: Tensor  # [B, T, D]: patchified latent tokens
    sigma: Tensor  # [B,]: per-sample noise level (for cross-attn conditioning)
    timesteps: Tensor  # [B, T]: per-token timesteps
    positions: Tensor  # [B, 3, T, 2] for video, [B, 1, T, 2] for audio: positional bounds
    context: Tensor  # text conditioning embeddings
    enabled: bool = True
    context_mask: Tensor | None = None  # attention mask for text context
    attention_mask: Tensor | None = None  # optional 2D self-attention mask [B, T, T]
```

> [!NOTE]
> **Per-token timesteps:** Each token in the sequence has its own timestep. Conditioning tokens (those that should remain
> un-noised) must have `timestep=0`. This is how the model distinguishes clean reference tokens from tokens to denoise. Use
> `_create_per_token_timesteps(conditioning_mask, sampled_sigma)` to set this up correctly.

> [!NOTE]
> `Modality` is immutable (frozen dataclass). Use `dataclasses.replace()` to create modified copies.

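
Here is a short stdlib demonstration of what "frozen" means in practice. It uses a two-field stand-in; `ModalitySketch` is hypothetical, not the ltx-core class.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class ModalitySketch:
    """Minimal frozen stand-in for Modality with only two fields (illustration only)."""

    latent: tuple[float, ...]
    enabled: bool = True


video = ModalitySketch(latent=(0.1, 0.2))

try:
    video.enabled = False  # in-place mutation is rejected on a frozen dataclass
except FrozenInstanceError:
    pass

# replace() builds a new instance with the given fields changed; the original is untouched.
disabled = replace(video, enabled=False)
```

The same pattern applies to the real class, for example `dataclasses.replace(video_modality, latent=new_latent)`.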
## Testing Your Strategy

1. **Verify your training configuration is valid:**

   ```bash
   uv run python -c "
   from ltx_trainer.config import LtxTrainerConfig
   import yaml
   with open('configs/custom_inpainting_lora.yaml') as f:
       config = LtxTrainerConfig(**yaml.safe_load(f))
   print(f'Strategy: {config.training_strategy.name}')
   "
   ```

2. **Test strategy instantiation:**

   ```bash
   uv run python -c "
   from ltx_trainer.training_strategies import get_training_strategy
   from ltx_trainer.training_strategies.inpainting import InpaintingConfig
   config = InpaintingConfig()
   strategy = get_training_strategy(config)
   print(f'Strategy class: {type(strategy).__name__}')
   print(f'Data sources: {config.get_data_sources()}')
   "
   ```

3. **Run a short training test:**

   ```bash
   uv run python scripts/train.py configs/custom_inpainting_lora.yaml
   ```

## Tips and Best Practices

### Debugging

- Set `data.num_dataloader_workers: 0` to get clearer error messages
- Use a small dataset and few steps for initial testing
- Check tensor shapes at each step with print statements

## Related Documentation

- [Training Modes](training-modes.md) - Overview of built-in training modes
- [Configuration Reference](configuration-reference.md) - All configuration options
- [Dataset Preparation](dataset-preparation.md) - Preprocessing workflow
- [ltx-core Documentation](../../ltx-core/README.md) - Core model components

## Reference: Existing Strategies

Study these implementations for guidance:

| Strategy                                                                           | Complexity | Key Features                                                 |
|------------------------------------------------------------------------------------|------------|--------------------------------------------------------------|
| [`FlexibleStrategy`](../src/ltx_trainer/training_strategies/flexible.py)           | Medium     | Unified conditioning framework; supports all built-in modes  |
| [`TextToVideoStrategy`](../src/ltx_trainer/training_strategies/text_to_video.py)   | Simple     | First-frame conditioning, optional audio (deprecated)        |
| [`VideoToVideoStrategy`](../src/ltx_trainer/training_strategies/video_to_video.py) | Medium     | Reference video concatenation, split loss mask (deprecated)  |