# app.py """LTX 2.3 All-in-One — Gradio entry point.""" from __future__ import annotations import os import pathlib import sys import time from typing import Any import gradio as gr import backend as backend_module import modes import ui import workflow as wf_module # --------------------------------------------------------------------------- # Bootstrap — runs once on cold start. # --------------------------------------------------------------------------- def _on_spaces() -> bool: return bool(os.environ.get("SPACES_ZERO_GPU")) COMFYUI_REPO = "https://github.com/comfyanonymous/ComfyUI.git" # Pinned to the same commit the local git submodule uses (set in Task 5). # Override via env var only when intentionally testing a different ComfyUI version. COMFYUI_COMMIT = os.environ.get( "LTX23_AIO_COMFYUI_COMMIT", "eb0686bbb60c83e44c3a3e4f7defd0f589cfef10", ) CUSTOM_NODES_PINNED: list[tuple[str, str]] = [ ("https://github.com/Lightricks/ComfyUI-LTXVideo.git", "main"), ("https://github.com/kijai/ComfyUI-KJNodes.git", "main"), ("https://github.com/rgthree/rgthree-comfy.git", "main"), ("https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite.git", "main"), ("https://github.com/pythongosssss/ComfyUI-Custom-Scripts.git", "main"), ] def _git_clone(url: str, dst: pathlib.Path, ref: str) -> None: import subprocess subprocess.check_call(["git", "clone", "--depth", "1", "--branch", ref, url, str(dst)]) def _bootstrap() -> None: on_spaces = _on_spaces() comfy_dir = pathlib.Path("/data/comfyui" if on_spaces else "comfyui") if on_spaces and not comfy_dir.exists(): comfy_dir.parent.mkdir(parents=True, exist_ok=True) _git_clone(COMFYUI_REPO, comfy_dir, ref=COMFYUI_COMMIT) for node_url, node_ref in CUSTOM_NODES_PINNED: name = node_url.rstrip(".git").rsplit("/", 1)[-1] _git_clone(node_url, comfy_dir / "custom_nodes" / name, ref=node_ref) # Install custom node deps import subprocess for cn in (comfy_dir / "custom_nodes").iterdir(): req = cn / "requirements.txt" if req.exists(): subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", str(req)]) if str(comfy_dir) not in sys.path: sys.path.insert(0, str(comfy_dir)) os.environ.setdefault( "COMFY_MODELS_DIR", str(pathlib.Path("/data/models") if on_spaces else (comfy_dir / "models")), ) _bootstrap() # --------------------------------------------------------------------------- # Gradio app # --------------------------------------------------------------------------- _CUSTOM_CSS = """ .status-card { padding: 14px 16px; border-radius: 10px; background: rgba(255,255,255,0.04); border: 1px solid rgba(255,255,255,0.08); } .status-row { display: flex; gap: 14px; align-items: center; margin-bottom: 8px; } .status-stage { font-weight: 600; } .status-meta { font-size: 12px; opacity: 0.75; } .status-bar { height: 6px; background: rgba(255,255,255,0.08); border-radius: 99px; overflow: hidden; } .status-fill { height: 100%; background: linear-gradient(90deg,#6ea8fe,#8de9fe); transition: width .3s; } .status-mem { font-size: 11px; opacity: 0.6; margin-top: 6px; font-family: ui-monospace, monospace; } """ def build_app() -> gr.Blocks: with gr.Blocks(theme=gr.themes.Soft(), title="LTX 2.3 All-in-One", css=_CUSTOM_CSS) as app: gr.Markdown("# ⚡ LTX 2.3 All-in-One") with gr.Row(): with gr.Column(scale=1, min_width=200): _render_sidebar() with gr.Column(scale=4): handles = _render_mode_panels() for name, h in handles.items(): inputs = _collect_inputs_for_mode(name, h) h["generate_btn"].click( fn=_make_handler(name, h), inputs=inputs, outputs=[h["status"], h["video_out"]], ) return app def _render_sidebar() -> None: gr.Markdown("### Modes") for mode in modes.MODE_REGISTRY.values(): gr.Markdown(f"- {mode.icon} {mode.label}") gr.Markdown("---\n### Models") gr.Button("Unload all models", variant="secondary") def _render_mode_panels() -> dict[str, dict]: """Render one form per mode. Returns the component handles keyed by mode.""" handles: dict[str, dict] = {} with gr.Tabs(): for name, mode in modes.MODE_REGISTRY.items(): with gr.Tab(label=f"{mode.icon} {mode.label}"): handles[name] = _render_one_mode(name) return handles def _render_one_mode(name: str) -> dict: """Render a per-mode form. Returns component handles for the generate handler.""" handles: dict = {"mode": name} with gr.Row(): with gr.Column(scale=2): handles["prompt"] = gr.Textbox( label="Prompt", lines=4, placeholder="Describe the shot..." ) # Mode-specific media inputs if name == "i2v": handles["image"] = gr.Image(label="Source image", type="filepath") elif name == "a2v": handles["audio"] = gr.Audio(label="Source audio", type="filepath") elif name == "lipsync": handles["image"] = gr.Image(label="Portrait", type="filepath") handles["audio"] = gr.Audio(label="Speech audio", type="filepath") elif name == "keyframe": handles["first_frame"] = gr.Image(label="First frame", type="filepath") handles["last_frame"] = gr.Image(label="Last frame", type="filepath") elif name == "style": handles["input_video"] = gr.Video(label="Source video") handles["preset"] = ui.preset_bar() with gr.Row(): handles["width"] = gr.Slider(256, 1280, value=512, step=32, label="Width") handles["height"] = gr.Slider(256, 1280, value=768, step=32, label="Height") with gr.Row(): handles["frames"] = gr.Slider(9, 121, value=81, step=8, label="Frames (8k+1)") handles["fps"] = gr.Slider(8, 30, value=24, step=1, label="FPS") handles["seed"] = gr.Number(label="Seed", value=42, precision=0) with gr.Accordion("Advanced ▾", open=False): handles["lora"] = ui.lora_chrome(name) handles["negative_prompt"] = gr.Textbox(label="Negative prompt", lines=2) handles["generate_btn"] = gr.Button("▶ Generate", variant="primary", size="lg") with gr.Column(scale=2): handles["status"] = ui.status_banner() handles["video_out"] = gr.Video(label="Output", autoplay=True) handles["history"] = gr.Markdown("") return handles _BACKEND: backend_module.ComfyUILibraryBackend | None = None def _get_backend() -> backend_module.ComfyUILibraryBackend: global _BACKEND if _BACKEND is None: _BACKEND = backend_module.ComfyUILibraryBackend() return _BACKEND PRESET_DURATION = {"Fast": 60, "Balanced": 120, "Quality": 300} async def _on_generate(mode_name: str, **inputs: Any): """Generate handler — async generator yielding (status_html, video_path).""" mode = modes.MODE_REGISTRY[mode_name] # Translate UI inputs into the parameterize_fn input dict. params: dict[str, Any] = { "prompt": inputs.get("prompt", ""), "negative_prompt": inputs.get("negative_prompt", ""), "preset": inputs.get("preset", "Balanced").lower(), "width": int(inputs.get("width", 512)), "height": int(inputs.get("height", 768)), "frames": int(inputs.get("frames", 81)), "fps": int(inputs.get("fps", 24)), "seed": int(inputs.get("seed", 42)), } for k in ( "image", "audio", "first_frame", "last_frame", "input_video", "camera_lora", "camera_strength", "detailer_on", "detailer_strength", "ic_lora", "ic_strength", "pose_on", "audio_cfg", "image_strength", ): if k in inputs: params[k] = inputs[k] patches = mode.parameterize_fn(params) workflow = wf_module.load_template(mode_name) for patch in patches: wf_module.set_input(workflow, *patch) wf_module.validate(workflow) backend = _get_backend() duration = PRESET_DURATION.get(inputs.get("preset", "Balanced"), 120) started = time.time() async for event in backend.submit(mode_name, workflow, gpu_duration=duration): elapsed = time.time() - started if isinstance(event, backend_module.DownloadEvent): status = ui.render_status( stage_index=0, stage_label=f"Downloading {event.filename}", step=int(event.mb_done), total_steps=int(max(event.mb_total, 1)), elapsed_s=elapsed, eta_s=0, ) yield status, gr.update() elif isinstance(event, backend_module.ProgressEvent): stage = ( mode.stage_map[event.stage] if event.stage < len(mode.stage_map) else mode.stage_map[-1] ) eta = (elapsed / max(event.step, 1)) * (event.total_steps - event.step) status = ui.render_status( stage_index=event.stage + 1, stage_label=stage.label, step=event.step, total_steps=event.total_steps, elapsed_s=elapsed, eta_s=eta, ) yield status, gr.update() elif isinstance(event, backend_module.OutputEvent): yield ui._render_idle(), event.video_path elif isinstance(event, backend_module.ErrorEvent): error_html = ( f'