Buckets:
| """Flat denoiser classes — transformer received at call time, not stored. | |
| Three implementations of the :class:`~ltx_pipelines.utils.types.Denoiser` protocol: | |
| * :class:`SimpleDenoiser` — single transformer call, no guidance. | |
| * :class:`GuidedDenoiser` — static guiders, handles CFG + STG + isolated modality. | |
| * :class:`FactoryGuidedDenoiser` — resolves guiders per-step from sigma. | |
| ``GuidedDenoiser`` and ``FactoryGuidedDenoiser`` share the core multi-pass | |
| logic via the module-level :func:`_guided_denoise` function, which batches | |
| all guidance passes into a single transformer call. | |
| """ | |
| import torch | |
| from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderFactory, MultiModalGuiderParams | |
| from ltx_core.guidance.perturbations import ( | |
| BatchedPerturbationConfig, | |
| Perturbation, | |
| PerturbationConfig, | |
| PerturbationType, | |
| ) | |
| from ltx_core.model.transformer import X0Model | |
| from ltx_core.types import LatentState | |
| from ltx_pipelines.utils.helpers import modality_from_latent_state | |
| from ltx_pipelines.utils.types import DenoisedLatentResult | |
| _POSITIVE_ONLY_GUIDER = MultiModalGuider( | |
| params=MultiModalGuiderParams(cfg_scale=1.0, stg_scale=0.0, modality_scale=1.0), | |
| ) | |
| """Guider that only runs the conditioned pass and returns cond unchanged.""" | |
| def _ensure_guider(guider: MultiModalGuider | None) -> MultiModalGuider: | |
| """Return the guider as-is, or a positive-only guider for absent modalities.""" | |
| return guider if guider is not None else _POSITIVE_ONLY_GUIDER | |
| def _repeat_state(state: LatentState, n: int) -> LatentState: | |
| """Repeat a ``LatentState`` *n* times along the batch dimension. | |
| ``(B, ...) → (n*B, ...)`` by tiling the whole tensor n times, so the | |
| ordering is ``[item0, item1, ..., item0, item1, ...]`` — matching | |
| ``torch.cat`` of n per-pass contexts. | |
| """ | |
| def _repeat(t: torch.Tensor) -> torch.Tensor: | |
| repeats = [1] * t.dim() | |
| repeats[0] = n | |
| return t.repeat(repeats) | |
| return LatentState( | |
| latent=_repeat(state.latent), | |
| denoise_mask=_repeat(state.denoise_mask), | |
| positions=_repeat(state.positions), | |
| clean_latent=_repeat(state.clean_latent), | |
| attention_mask=_repeat(state.attention_mask) if state.attention_mask is not None else None, | |
| ) | |
| def _guided_denoise( # noqa: PLR0913,PLR0915 | |
| transformer: X0Model, | |
| video_state: LatentState | None, | |
| audio_state: LatentState | None, | |
| sigma: torch.Tensor, | |
| video_guider: MultiModalGuider, | |
| audio_guider: MultiModalGuider, | |
| v_context: torch.Tensor | None, | |
| a_context: torch.Tensor | None, | |
| *, | |
| last_denoised_video: torch.Tensor | None, | |
| last_denoised_audio: torch.Tensor | None, | |
| step_index: int, | |
| force_uncond_pass: bool = False, | |
| ) -> tuple[DenoisedLatentResult | None, DenoisedLatentResult | None]: | |
| """Core guided denoising — batches all guidance passes into one transformer call. | |
| Collects per-pass contexts first, then builds a single batched Modality | |
| per present modality via :func:`modality_from_latent_state`. When wrapped | |
| with :class:`~ltx_core.batch_split.BatchSplitAdapter`, the transformer may | |
| split this batch into sequential chunks internally. | |
| Guiders must not be ``None``. For absent modalities, callers should pass | |
| :data:`_POSITIVE_ONLY_GUIDER` (via :func:`_ensure_guider`) so that only | |
| the conditioned pass runs and ``calculate()`` returns cond unchanged. | |
| """ | |
| v_skip = video_guider.should_skip_step(step_index) | |
| a_skip = audio_guider.should_skip_step(step_index) | |
| if v_skip and a_skip: | |
| video_result = DenoisedLatentResult.result_or_none(denoised=last_denoised_video) | |
| audio_result = DenoisedLatentResult.result_or_none(denoised=last_denoised_audio) | |
| return video_result, audio_result | |
| if video_state is not None and v_context is None: | |
| raise ValueError("v_context is required when video_state is provided") | |
| if audio_state is not None and a_context is None: | |
| raise ValueError("a_context is required when audio_state is provided") | |
| # Define passes: (name, video_context, audio_context, perturbation_config). | |
| # Context is None for absent modalities — filtered out during collection. | |
| _pass = tuple[str, torch.Tensor | None, torch.Tensor | None, PerturbationConfig] | |
| passes: list[_pass] = [("cond", v_context, a_context, PerturbationConfig.empty())] | |
| v_needs_neg = video_guider.do_unconditional_generation() or (force_uncond_pass and video_state is not None) | |
| a_needs_neg = audio_guider.do_unconditional_generation() or (force_uncond_pass and audio_state is not None) | |
| if v_needs_neg or a_needs_neg: | |
| if v_needs_neg and video_guider.negative_context is None: | |
| raise ValueError("Negative context is required for unconditioned denoising") | |
| if a_needs_neg and audio_guider.negative_context is None: | |
| raise ValueError("Negative context is required for unconditioned denoising") | |
| v_neg = video_guider.negative_context if video_guider.negative_context is not None else v_context | |
| a_neg = audio_guider.negative_context if audio_guider.negative_context is not None else a_context | |
| passes.append(("uncond", v_neg, a_neg, PerturbationConfig.empty())) | |
| stg_perturbations: list[Perturbation] = [] | |
| if video_guider.do_perturbed_generation(): | |
| stg_perturbations.append( | |
| Perturbation(type=PerturbationType.SKIP_VIDEO_SELF_ATTN, blocks=video_guider.params.stg_blocks) | |
| ) | |
| if audio_guider.do_perturbed_generation(): | |
| stg_perturbations.append( | |
| Perturbation(type=PerturbationType.SKIP_AUDIO_SELF_ATTN, blocks=audio_guider.params.stg_blocks) | |
| ) | |
| if stg_perturbations: | |
| passes.append(("ptb", v_context, a_context, PerturbationConfig(stg_perturbations))) | |
| if video_guider.do_isolated_modality_generation() or audio_guider.do_isolated_modality_generation(): | |
| passes.append( | |
| ( | |
| "mod", | |
| v_context, | |
| a_context, | |
| PerturbationConfig( | |
| [ | |
| Perturbation(type=PerturbationType.SKIP_A2V_CROSS_ATTN, blocks=None), | |
| Perturbation(type=PerturbationType.SKIP_V2A_CROSS_ATTN, blocks=None), | |
| ] | |
| ), | |
| ) | |
| ) | |
| # Collect contexts, repeat states, and build batched modalities. | |
| pass_names = [name for name, _, _, _ in passes] | |
| ptb_configs = [ptb for _, _, _, ptb in passes] | |
| n = len(passes) | |
| orig_b = (video_state or audio_state).latent.shape[0] | |
| def _batched_sigma(state: LatentState) -> torch.Tensor: | |
| """Expand scalar sigma to (n * B,) matching the repeated state.""" | |
| return sigma.expand(state.latent.shape[0] * n) | |
| batched_video = None | |
| if video_state is not None: | |
| v_context = torch.cat([vc for _, vc, _, _ in passes], dim=0) | |
| batched_video = modality_from_latent_state( | |
| _repeat_state(video_state, n), | |
| v_context, | |
| _batched_sigma(video_state), | |
| enabled=not v_skip, | |
| ) | |
| batched_audio = None | |
| if audio_state is not None: | |
| a_context = torch.cat([ac for _, _, ac, _ in passes], dim=0) | |
| batched_audio = modality_from_latent_state( | |
| _repeat_state(audio_state, n), | |
| a_context, | |
| _batched_sigma(audio_state), | |
| enabled=not a_skip, | |
| ) | |
| # Replicate each pass's PerturbationConfig to all `orig_b` samples it | |
| # carries, so `BatchedPerturbationConfig.mask_like` returns a per-sample | |
| # mask (length n*orig_b) instead of a per-pass mask (length n). Without | |
| # this expansion the mask is broadcast against a (n*orig_b, T, D) tensor | |
| # and the multiplication fails with a batch-dim mismatch whenever | |
| # `orig_b > 1` (e.g. multi-prompt benchmark panels). | |
| batched_ptb_configs = [ptb for ptb in ptb_configs for _ in range(orig_b)] | |
| all_v, all_a = transformer( | |
| video=batched_video, audio=batched_audio, perturbations=BatchedPerturbationConfig(batched_ptb_configs) | |
| ) | |
| # Split results back and combine via guiders. | |
| splits_v = list(all_v.chunk(n)) if all_v is not None else [0.0] * n | |
| splits_a = list(all_a.chunk(n)) if all_a is not None else [0.0] * n | |
| r = dict(zip(pass_names, zip(splits_v, splits_a, strict=True), strict=True)) | |
| cond_v, cond_a = r["cond"] | |
| uncond_v, uncond_a = r.get("uncond", (0.0, 0.0)) | |
| ptb_v, ptb_a = r.get("ptb", (0.0, 0.0)) | |
| mod_v, mod_a = r.get("mod", (0.0, 0.0)) | |
| denoised_video = last_denoised_video if v_skip else video_guider.calculate(cond_v, uncond_v, ptb_v, mod_v) | |
| denoised_audio = last_denoised_audio if a_skip else audio_guider.calculate(cond_a, uncond_a, ptb_a, mod_a) | |
| return ( | |
| DenoisedLatentResult.result_or_none( | |
| denoised=denoised_video, uncond=uncond_v, cond=cond_v, ptb=ptb_v, mod=mod_v | |
| ), | |
| DenoisedLatentResult.result_or_none( | |
| denoised=denoised_audio, uncond=uncond_a, cond=cond_a, ptb=ptb_a, mod=mod_a | |
| ), | |
| ) | |
| class SimpleDenoiser: | |
| """Single transformer call, no guidance. | |
| Passes ``None`` Modality for absent modalities. | |
| """ | |
| def __init__( | |
| self, | |
| v_context: torch.Tensor | None, | |
| a_context: torch.Tensor | None, | |
| ) -> None: | |
| self.v_context = v_context | |
| self.a_context = a_context | |
| def __call__( | |
| self, | |
| transformer: X0Model, | |
| video_state: LatentState | None, | |
| audio_state: LatentState | None, | |
| sigmas: torch.Tensor, | |
| step_index: int, | |
| ) -> tuple[DenoisedLatentResult | None, DenoisedLatentResult | None]: | |
| sigma = sigmas[step_index] | |
| pos_video = modality_from_latent_state(video_state, self.v_context, sigma) if video_state is not None else None | |
| pos_audio = modality_from_latent_state(audio_state, self.a_context, sigma) if audio_state is not None else None | |
| denoised_video, denoised_audio = transformer(video=pos_video, audio=pos_audio, perturbations=None) | |
| return ( | |
| DenoisedLatentResult.result_or_none(denoised=denoised_video), | |
| DenoisedLatentResult.result_or_none(denoised=denoised_audio), | |
| ) | |
| class GuidedDenoiser: | |
| """Static guiders — handles CFG + STG + isolated modality. | |
| Context/guider can be ``None`` for absent modalities (a positive-only | |
| guider is substituted at call time). | |
| """ | |
| def __init__( | |
| self, | |
| v_context: torch.Tensor | None, | |
| a_context: torch.Tensor | None, | |
| video_guider: MultiModalGuider | None = None, | |
| audio_guider: MultiModalGuider | None = None, | |
| force_uncond_pass: bool = False, | |
| ) -> None: | |
| self.v_context = v_context | |
| self.a_context = a_context | |
| self.video_guider = video_guider | |
| self.audio_guider = audio_guider | |
| self.force_uncond_pass = force_uncond_pass | |
| self._last_denoised_video: torch.Tensor | None = None | |
| self._last_denoised_audio: torch.Tensor | None = None | |
| def __call__( | |
| self, | |
| transformer: X0Model, | |
| video_state: LatentState | None, | |
| audio_state: LatentState | None, | |
| sigmas: torch.Tensor, | |
| step_index: int, | |
| ) -> tuple[DenoisedLatentResult | None, DenoisedLatentResult | None]: | |
| guided_denoise_result_v, guided_denoise_result_a = _guided_denoise( | |
| transformer=transformer, | |
| video_state=video_state, | |
| audio_state=audio_state, | |
| sigma=sigmas[step_index], | |
| video_guider=_ensure_guider(self.video_guider), | |
| audio_guider=_ensure_guider(self.audio_guider), | |
| v_context=self.v_context, | |
| a_context=self.a_context, | |
| last_denoised_video=self._last_denoised_video, | |
| last_denoised_audio=self._last_denoised_audio, | |
| step_index=step_index, | |
| force_uncond_pass=self.force_uncond_pass, | |
| ) | |
| self._last_denoised_video = guided_denoise_result_v.denoised | |
| self._last_denoised_audio = guided_denoise_result_a.denoised | |
| return guided_denoise_result_v, guided_denoise_result_a | |
| class FactoryGuidedDenoiser: | |
| """Resolves guiders per-step from sigma, then delegates to shared guided logic.""" | |
| def __init__( | |
| self, | |
| v_context: torch.Tensor | None, | |
| a_context: torch.Tensor | None, | |
| video_guider_factory: MultiModalGuiderFactory | None = None, | |
| audio_guider_factory: MultiModalGuiderFactory | None = None, | |
| force_uncond_pass: bool = False, | |
| ) -> None: | |
| self.v_context = v_context | |
| self.a_context = a_context | |
| self.video_guider_factory = video_guider_factory | |
| self.audio_guider_factory = audio_guider_factory | |
| self.force_uncond_pass = force_uncond_pass | |
| self._last_denoised_video: torch.Tensor | None = None | |
| self._last_denoised_audio: torch.Tensor | None = None | |
| self._sigma_vals_cached: list[float] | None = None | |
| def __call__( | |
| self, | |
| transformer: X0Model, | |
| video_state: LatentState | None, | |
| audio_state: LatentState | None, | |
| sigmas: torch.Tensor, | |
| step_index: int, | |
| ) -> tuple[DenoisedLatentResult | None, DenoisedLatentResult | None]: | |
| if self._sigma_vals_cached is None: | |
| self._sigma_vals_cached = sigmas.detach().cpu().tolist() | |
| sigma_val = self._sigma_vals_cached[step_index] | |
| video_guider = _ensure_guider( | |
| self.video_guider_factory.build_from_sigma(sigma_val) if self.video_guider_factory else None | |
| ) | |
| audio_guider = _ensure_guider( | |
| (self.audio_guider_factory or self.video_guider_factory).build_from_sigma(sigma_val) | |
| if self.video_guider_factory or self.audio_guider_factory | |
| else None | |
| ) | |
| guided_denoise_result_v, guided_denoise_result_a = _guided_denoise( | |
| transformer=transformer, | |
| video_state=video_state, | |
| audio_state=audio_state, | |
| sigma=sigmas[step_index], | |
| video_guider=video_guider, | |
| audio_guider=audio_guider, | |
| v_context=self.v_context, | |
| a_context=self.a_context, | |
| last_denoised_video=self._last_denoised_video, | |
| last_denoised_audio=self._last_denoised_audio, | |
| step_index=step_index, | |
| force_uncond_pass=self.force_uncond_pass, | |
| ) | |
| self._last_denoised_video = guided_denoise_result_v.denoised | |
| self._last_denoised_audio = guided_denoise_result_a.denoised | |
| return guided_denoise_result_v, guided_denoise_result_a | |
Xet Storage Details
- Size:
- 14.7 kB
- Xet hash:
- 72017164bcb2c4183cda7957af92b9707af250077aae853ffdc8c324ced6242d
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.