LTX2.3-Studio / app.py
techfreakworm's picture
ci: run unit tests + ruff lint on every push
3ea399a unverified
Raw
History Blame
12.4 kB
# app.py
"""LTX 2.3 All-in-One — Gradio entry point."""
from __future__ import annotations
import os
import pathlib
import sys
import time
from typing import Any
import gradio as gr
import backend as backend_module
import modes
import ui
import workflow as wf_module
# ---------------------------------------------------------------------------
# Bootstrap — runs once on cold start.
# ---------------------------------------------------------------------------
def _on_spaces() -> bool:
    """Report whether the ``SPACES_ZERO_GPU`` environment variable is set and non-empty."""
    flag = os.environ.get("SPACES_ZERO_GPU")
    if flag:
        return True
    return False
# Upstream ComfyUI repository that _bootstrap() clones when running on Spaces
# and no checkout exists yet.
COMFYUI_REPO = "https://github.com/comfyanonymous/ComfyUI.git"
# Pinned to the same commit the local git submodule uses (set in Task 5).
# Override via env var only when intentionally testing a different ComfyUI version.
# The value is read once, at import time.
COMFYUI_COMMIT = os.environ.get(
"LTX23_AIO_COMFYUI_COMMIT",
"eb0686bbb60c83e44c3a3e4f7defd0f589cfef10",
)
# (repository URL, git ref) pairs; _bootstrap() clones each one into
# <comfy_dir>/custom_nodes/<repo name>.
# NOTE(review): despite the name, every entry currently uses the ref "main",
# i.e. a moving branch rather than a fixed commit — confirm this is intended.
CUSTOM_NODES_PINNED: list[tuple[str, str]] = [
("https://github.com/Lightricks/ComfyUI-LTXVideo.git", "main"),
("https://github.com/kijai/ComfyUI-KJNodes.git", "main"),
("https://github.com/rgthree/rgthree-comfy.git", "main"),
("https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite.git", "main"),
("https://github.com/pythongosssss/ComfyUI-Custom-Scripts.git", "main"),
]
def _git_clone(url: str, dst: pathlib.Path, ref: str) -> None:
    """Create a depth-1 checkout of ``url`` at ``ref`` in ``dst``.

    ``git clone --branch`` only accepts branch and tag names, so a commit hash
    cannot be cloned that way. When ``ref`` is a full commit hash (40 hex
    characters, or 64 for SHA-256 repositories) the checkout is built with
    ``git init`` + ``git fetch --depth 1 origin <hash>`` + a detached checkout
    of ``FETCH_HEAD`` instead. Any other ``ref`` is treated as a branch or tag
    name and cloned directly. Abbreviated hashes are not recognised because
    the remote cannot resolve them in a fetch.

    Args:
        url: Remote repository URL.
        dst: Destination directory for the working tree.
        ref: Branch name, tag name, or full commit hash.

    Raises:
        subprocess.CalledProcessError: If any git command exits non-zero.
    """
    import re
    import shutil
    import subprocess

    target = str(dst)
    is_commit_hash = re.fullmatch(r"[0-9a-fA-F]{40}|[0-9a-fA-F]{64}", ref) is not None

    if not is_commit_hash:
        subprocess.check_call(["git", "clone", "--depth", "1", "--branch", ref, url, target])
        return

    # `git clone` removes a half-created destination when it fails; mirror that
    # here so callers that test `dst.exists()` never see a broken checkout.
    existed_before = dst.exists()
    completed = False
    try:
        subprocess.check_call(["git", "init", "--quiet", target])
        subprocess.check_call(["git", "-C", target, "remote", "add", "origin", url])
        subprocess.check_call(["git", "-C", target, "fetch", "--depth", "1", "origin", ref])
        subprocess.check_call(
            ["git", "-C", target, "checkout", "--quiet", "--detach", "FETCH_HEAD"]
        )
        completed = True
    finally:
        if not completed and not existed_before:
            shutil.rmtree(dst, ignore_errors=True)
def _bootstrap() -> None:
    """Prepare the ComfyUI environment for this process.

    On Spaces (see ``_on_spaces``), when ``/data/comfyui`` does not exist yet,
    this clones ComfyUI at ``COMFYUI_COMMIT``, clones every repository listed
    in ``CUSTOM_NODES_PINNED`` into ``custom_nodes/`` and pip-installs each
    custom node's ``requirements.txt`` if it has one. Off Spaces the local
    ``comfyui`` directory is used as-is and nothing is cloned or installed.

    In both cases the ComfyUI directory is put at the front of ``sys.path``
    (if not already present) and ``COMFY_MODELS_DIR`` is set unless the
    environment already defines it.

    Raises:
        subprocess.CalledProcessError: If a git or pip command fails.
    """
    import subprocess

    on_spaces = _on_spaces()
    comfy_dir = pathlib.Path("/data/comfyui" if on_spaces else "comfyui")

    if on_spaces and not comfy_dir.exists():
        comfy_dir.parent.mkdir(parents=True, exist_ok=True)
        _git_clone(COMFYUI_REPO, comfy_dir, ref=COMFYUI_COMMIT)

        nodes_dir = comfy_dir / "custom_nodes"
        for node_url, node_ref in CUSTOM_NODES_PINNED:
            # removesuffix() drops the literal ".git" suffix only. The previous
            # rstrip(".git") stripped any trailing '.', 'g', 'i' or 't'
            # characters and would mangle names such as "foo-git.git".
            name = node_url.rstrip("/").removesuffix(".git").rsplit("/", 1)[-1]
            _git_clone(node_url, nodes_dir / name, ref=node_ref)

        # Install custom node deps
        for cn in nodes_dir.iterdir():
            req = cn / "requirements.txt"
            if req.exists():
                subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", str(req)])

    comfy_path = str(comfy_dir)
    if comfy_path not in sys.path:
        sys.path.insert(0, comfy_path)

    models_dir = pathlib.Path("/data/models") if on_spaces else (comfy_dir / "models")
    os.environ.setdefault("COMFY_MODELS_DIR", str(models_dir))
# Import-time side effect: set up the ComfyUI checkout (on Spaces), sys.path and
# COMFY_MODELS_DIR before the rest of the module runs.
_bootstrap()
# ---------------------------------------------------------------------------
# Gradio app
# ---------------------------------------------------------------------------
# Stylesheet passed to gr.Blocks(css=...) in build_app(). The .status-* classes
# are presumably the ones emitted by the ui status helpers — confirm in ui.py.
# NOTE(review): the error card built in _on_generate also carries a
# `status-error` class, which has no rule below.
_CUSTOM_CSS = """
.status-card { padding: 14px 16px; border-radius: 10px; background: rgba(255,255,255,0.04); border: 1px solid rgba(255,255,255,0.08); }
.status-row { display: flex; gap: 14px; align-items: center; margin-bottom: 8px; }
.status-stage { font-weight: 600; }
.status-meta { font-size: 12px; opacity: 0.75; }
.status-bar { height: 6px; background: rgba(255,255,255,0.08); border-radius: 99px; overflow: hidden; }
.status-fill { height: 100%; background: linear-gradient(90deg,#6ea8fe,#8de9fe); transition: width .3s; }
.status-mem { font-size: 11px; opacity: 0.6; margin-top: 6px; font-family: ui-monospace, monospace; }
"""
def build_app() -> gr.Blocks:
    """Assemble the Gradio UI and wire each mode's Generate button.

    Layout: a title, then one row with a narrow sidebar column and a wide
    column holding the per-mode panels. Every panel's ``generate_btn`` is
    connected to a handler from ``_make_handler`` whose outputs are that
    panel's status banner and output video.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    app = gr.Blocks(theme=gr.themes.Soft(), title="LTX 2.3 All-in-One", css=_CUSTOM_CSS)
    with app:
        gr.Markdown("# ⚡ LTX 2.3 All-in-One")

        with gr.Row():
            with gr.Column(scale=1, min_width=200):
                _render_sidebar()
            with gr.Column(scale=4):
                panels = _render_mode_panels()

        # Event listeners are registered while the Blocks context is still open.
        for mode_name, panel in panels.items():
            wired_inputs = _collect_inputs_for_mode(mode_name, panel)
            callback = _make_handler(mode_name, panel)
            panel["generate_btn"].click(
                fn=callback,
                inputs=wired_inputs,
                outputs=[panel["status"], panel["video_out"]],
            )

    return app
def _render_sidebar() -> None:
    """Render the sidebar: a bullet per registered mode, then the Models section.

    Must be called inside an active Gradio layout context. The
    "Unload all models" button is created here but no click handler is
    attached in this function.
    """
    gr.Markdown("### Modes")
    for entry in modes.MODE_REGISTRY.values():
        bullet = f"- {entry.icon} {entry.label}"
        gr.Markdown(bullet)

    gr.Markdown("---\n### Models")
    gr.Button("Unload all models", variant="secondary")
def _render_mode_panels() -> dict[str, dict]:
    """Build one tab per registered mode and return the handles keyed by mode name.

    Each tab is titled ``"<icon> <label>"`` and filled by ``_render_one_mode``.
    """
    panels: dict[str, dict] = {}
    with gr.Tabs():
        for mode_key, entry in modes.MODE_REGISTRY.items():
            tab_title = f"{entry.icon} {entry.label}"
            with gr.Tab(label=tab_title):
                panel = _render_one_mode(mode_key)
            panels[mode_key] = panel
    return panels
def _render_one_mode(name: str) -> dict:
    """Build the form for a single mode and return its component handles.

    The left column holds the prompt, the mode-specific media inputs, the
    preset bar, size/length/seed controls, the Advanced accordion and the
    Generate button. The right column holds the status banner, the output
    video and a history Markdown area.

    Args:
        name: Mode key; selects which media inputs are rendered.

    Returns:
        Dict of handles. It always contains ``mode``, ``prompt``, ``preset``,
        ``width``, ``height``, ``frames``, ``fps``, ``seed``, ``lora``,
        ``negative_prompt``, ``generate_btn``, ``status``, ``video_out`` and
        ``history``, plus the media keys for the given mode.
    """
    # (handle key, component class, constructor kwargs) in render order.
    media_specs = {
        "i2v": [
            ("image", gr.Image, {"label": "Source image", "type": "filepath"}),
        ],
        "a2v": [
            ("audio", gr.Audio, {"label": "Source audio", "type": "filepath"}),
        ],
        "lipsync": [
            ("image", gr.Image, {"label": "Portrait", "type": "filepath"}),
            ("audio", gr.Audio, {"label": "Speech audio", "type": "filepath"}),
        ],
        "keyframe": [
            ("first_frame", gr.Image, {"label": "First frame", "type": "filepath"}),
            ("last_frame", gr.Image, {"label": "Last frame", "type": "filepath"}),
        ],
        "style": [
            ("input_video", gr.Video, {"label": "Source video"}),
        ],
    }

    widgets: dict = {"mode": name}
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=2):
            widgets["prompt"] = gr.Textbox(
                label="Prompt", lines=4, placeholder="Describe the shot..."
            )

            # Mode-specific media inputs
            for key, component_cls, options in media_specs.get(name, []):
                widgets[key] = component_cls(**options)

            widgets["preset"] = ui.preset_bar()

            with gr.Row():
                widgets["width"] = gr.Slider(256, 1280, value=512, step=32, label="Width")
                widgets["height"] = gr.Slider(256, 1280, value=768, step=32, label="Height")
            with gr.Row():
                widgets["frames"] = gr.Slider(9, 121, value=81, step=8, label="Frames (8k+1)")
                widgets["fps"] = gr.Slider(8, 30, value=24, step=1, label="FPS")
                widgets["seed"] = gr.Number(label="Seed", value=42, precision=0)

            with gr.Accordion("Advanced ▾", open=False):
                widgets["lora"] = ui.lora_chrome(name)
                widgets["negative_prompt"] = gr.Textbox(label="Negative prompt", lines=2)

            widgets["generate_btn"] = gr.Button("▶ Generate", variant="primary", size="lg")

        # Right column: progress and result.
        with gr.Column(scale=2):
            widgets["status"] = ui.status_banner()
            widgets["video_out"] = gr.Video(label="Output", autoplay=True)
            widgets["history"] = gr.Markdown("")

    return widgets
# Process-wide backend instance; stays None until _get_backend() is first called.
_BACKEND: backend_module.ComfyUILibraryBackend | None = None


def _get_backend() -> backend_module.ComfyUILibraryBackend:
    """Return the shared backend, constructing it on the first call."""
    global _BACKEND
    if _BACKEND is not None:
        return _BACKEND
    _BACKEND = backend_module.ComfyUILibraryBackend()
    return _BACKEND
# Value passed to backend.submit(gpu_duration=...) for each preset label
# (presumably seconds — confirm against the backend).
PRESET_DURATION = {"Fast": 60, "Balanced": 120, "Quality": 300}

# Optional form inputs copied verbatim into the parameterize_fn dict when present.
_OPTIONAL_PARAM_KEYS = (
    "image",
    "audio",
    "first_frame",
    "last_frame",
    "input_video",
    "camera_lora",
    "camera_strength",
    "detailer_on",
    "detailer_strength",
    "ic_lora",
    "ic_strength",
    "pose_on",
    "audio_cfg",
    "image_strength",
)


async def _on_generate(mode_name: str, **inputs: Any):
    """Generate handler — async generator yielding (status_html, video_path).

    Translates the form values into the mode's parameter dict, applies the
    resulting patches to the mode's workflow template, validates it and
    submits it to the shared backend. Each backend event is turned into one
    UI update:

    * download / progress events -> a status card; the video is left
      untouched (``gr.update()``),
    * output event -> the idle status plus the produced video path,
    * error event -> an error card; the video is left untouched.

    Args:
        mode_name: Key into ``modes.MODE_REGISTRY``.
        **inputs: Form values keyed by input name. Missing or ``None`` numeric
            fields fall back to their defaults (512x768, 81 frames, 24 fps,
            seed 42); a missing or empty preset falls back to ``"Balanced"``.

    Raises:
        KeyError: If ``mode_name`` is not registered.
        Exception: Whatever ``parameterize_fn``, ``load_template``,
            ``set_input`` or ``validate`` raise is propagated unchanged.
    """
    import html  # local import keeps this block self-contained

    def _int_input(key: str, default: int) -> int:
        # A cleared numeric field arrives as None; int(None) would raise.
        value = inputs.get(key)
        return default if value is None else int(value)

    mode = modes.MODE_REGISTRY[mode_name]
    preset = inputs.get("preset") or "Balanced"

    # Translate UI inputs into the parameterize_fn input dict.
    params: dict[str, Any] = {
        "prompt": inputs.get("prompt", ""),
        "negative_prompt": inputs.get("negative_prompt", ""),
        "preset": preset.lower(),
        "width": _int_input("width", 512),
        "height": _int_input("height", 768),
        "frames": _int_input("frames", 81),
        "fps": _int_input("fps", 24),
        "seed": _int_input("seed", 42),
    }
    params.update({k: inputs[k] for k in _OPTIONAL_PARAM_KEYS if k in inputs})

    patches = mode.parameterize_fn(params)
    workflow = wf_module.load_template(mode_name)
    for patch in patches:
        wf_module.set_input(workflow, *patch)
    wf_module.validate(workflow)

    backend = _get_backend()
    duration = PRESET_DURATION.get(preset, 120)
    started = time.time()

    async for event in backend.submit(mode_name, workflow, gpu_duration=duration):
        elapsed = time.time() - started

        if isinstance(event, backend_module.DownloadEvent):
            status = ui.render_status(
                stage_index=0,
                stage_label=f"Downloading {event.filename}",
                step=int(event.mb_done),
                total_steps=int(max(event.mb_total, 1)),
                elapsed_s=elapsed,
                eta_s=0,
            )
            yield status, gr.update()

        elif isinstance(event, backend_module.ProgressEvent):
            # Stage indices past the end of the map reuse the last stage.
            stage = (
                mode.stage_map[event.stage]
                if event.stage < len(mode.stage_map)
                else mode.stage_map[-1]
            )
            # Linear extrapolation from the average time per completed step;
            # clamped so an overshooting step counter never shows a negative ETA.
            remaining_steps = max(event.total_steps - event.step, 0)
            eta = (elapsed / max(event.step, 1)) * remaining_steps
            status = ui.render_status(
                stage_index=event.stage + 1,
                stage_label=stage.label,
                step=event.step,
                total_steps=event.total_steps,
                elapsed_s=elapsed,
                eta_s=eta,
            )
            yield status, gr.update()

        elif isinstance(event, backend_module.OutputEvent):
            yield ui._render_idle(), event.video_path

        elif isinstance(event, backend_module.ErrorEvent):
            # Error text can contain '<', '>' or '&' (e.g. "<class 'X'>");
            # escape it so it is shown literally instead of parsed as markup.
            category = html.escape(str(event.category))
            message = html.escape(str(event.message))
            error_html = (
                f'<div class="status-card status-error">'
                f'  <div class="status-row"><span class="status-stage">Error · {category}</span></div>'
                f"  <div>{message}</div>"
                f"</div>"
            )
            yield error_html, gr.update()
def _input_keys_for_mode(mode_name: str, h: dict) -> list[str]:
    """List the keyword names for a mode's inputs, in positional order.

    The order is: common fields, the mode's media fields, ``negative_prompt``,
    the camera/detailer LoRA fields, then ``ic_lora``/``ic_strength`` and
    ``pose_on`` when the corresponding attributes of ``h["lora"]`` are not
    ``None``. It mirrors the component order of ``_collect_inputs_for_mode``.
    """
    media_keys = {
        "i2v": ("image",),
        "a2v": ("audio",),
        "lipsync": ("image", "audio"),
        "keyframe": ("first_frame", "last_frame"),
        "style": ("input_video",),
    }

    keys = [
        "prompt",
        "preset",
        "width",
        "height",
        "frames",
        "fps",
        "seed",
        *media_keys.get(mode_name, ()),
        "negative_prompt",
        "camera_lora",
        "camera_strength",
        "detailer_on",
        "detailer_strength",
    ]

    lora = h["lora"]
    if lora.ic_lora is not None:
        keys += ["ic_lora", "ic_strength"]
    if lora.pose_on is not None:
        keys += ["pose_on"]
    return keys
def _collect_inputs_for_mode(mode_name: str, h: dict) -> list:
    """Return the mode's input components in the order the click handler receives them.

    The order is: common fields, the mode's media fields, the negative
    prompt, the camera/detailer LoRA controls, then the IC-LoRA pair and the
    pose toggle when ``h["lora"]`` provides them (attribute is not ``None``).
    It mirrors the key order of ``_input_keys_for_mode``.
    """
    media_keys = {
        "i2v": ("image",),
        "a2v": ("audio",),
        "lipsync": ("image", "audio"),
        "keyframe": ("first_frame", "last_frame"),
        "style": ("input_video",),
    }
    common_keys = ("prompt", "preset", "width", "height", "frames", "fps", "seed")

    components = [h[key] for key in common_keys]
    components.extend(h[key] for key in media_keys.get(mode_name, ()))
    components.append(h["negative_prompt"])

    lora = h["lora"]
    components += [
        lora.camera_lora,
        lora.camera_strength,
        lora.detailer_on,
        lora.detailer_strength,
    ]
    if lora.ic_lora is not None:
        components += [lora.ic_lora, lora.ic_strength]
    if lora.pose_on is not None:
        components += [lora.pose_on]
    return components
def _make_handler(mode_name: str, h: dict):
    """Create the click callback for one mode.

    Gradio calls the callback with positional values; they are paired with the
    names from ``_input_keys_for_mode`` and forwarded to ``_on_generate`` as
    keyword arguments. Every item that generator yields is re-yielded as-is.
    """
    field_names = _input_keys_for_mode(mode_name, h)

    async def handler(*values):
        named_values = {}
        for field, value in zip(field_names, values, strict=False):
            named_values[field] = value
        stream = _on_generate(mode_name, **named_values)
        async for update in stream:
            yield update

    return handler
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on every interface, port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    app = build_app()
    app.launch(**launch_options)