LTX2.3-Studio / backend.py
techfreakworm's picture
ci: run unit tests + ruff lint on every push
3ea399a unverified
Raw
History Blame
7.18 kB
"""ComfyUI library-mode backend.
Single-process, single-implementation. Apart from where the ComfyUI checkout
lives (see _comfy_dir), the @spaces.GPU decorator is the only divergence
between local and HF Spaces deployment.
"""
from __future__ import annotations
import asyncio
import os
import pathlib
import sys
import threading
import traceback as tb_mod
from collections.abc import AsyncIterator, Iterable
from dataclasses import dataclass, field
from typing import Any
import models
@dataclass
class DownloadEvent:
    """Progress report for a model file fetched before a workflow runs.

    Emitted by the pre-flight model check in ``submit``; sizes are in
    megabytes.
    """

    filename: str  # name of the file being fetched
    mb_done: float  # megabytes transferred so far
    mb_total: float  # expected total size in megabytes
@dataclass
class ProgressEvent:
    """A single progress update reported while a workflow executes."""

    stage: int  # index of the stage being reported
    stage_label: str  # short human-readable stage name
    step: int  # steps completed so far in this stage
    total_steps: int  # number of steps expected for this stage
@dataclass
class OutputEvent:
video_path: str
audio_path: str | None = None
meta: dict = field(default_factory=dict)
@dataclass
class ErrorEvent:
category: str # "oom" | "zerogpu_timeout" | "execution" | "interrupt" | "download"
message: str
stage: int | None = None
traceback: str = ""
def _on_spaces() -> bool:
    """Report whether this process runs in the Spaces (ZeroGPU) deployment.

    True exactly when the SPACES_ZERO_GPU environment variable is set to a
    non-empty value.
    """
    flag = os.environ.get("SPACES_ZERO_GPU", "")
    return flag != ""
def _comfy_dir() -> pathlib.Path:
    """Return the directory where the ComfyUI checkout is expected.

    On Spaces this is the fixed path /data/comfyui; locally it is the
    ``comfyui`` directory next to this module.
    """
    if not _on_spaces():
        return pathlib.Path(__file__).parent / "comfyui"
    return pathlib.Path("/data/comfyui")
class _HeadlessServer:
    """Stand-in for ComfyUI's web server when ComfyUI runs as a library.

    PromptExecutor is constructed with a server object and reports its
    status through it. No web server exists in this process, so this stub
    only holds the attributes the executor assigns (``client_id``,
    ``last_node_id``, ``last_prompt_id``) and discards every ``send_sync``
    message.

    NOTE(review): the attribute list mirrors ComfyUI's execution.py at the
    time of writing; re-check it when bumping the pinned ComfyUI version.
    """

    def __init__(self) -> None:
        self.client_id = None
        self.last_node_id = None
        self.last_prompt_id = None

    def send_sync(self, *_args: Any, **_kwargs: Any) -> None:
        """Drop a status message; nothing is listening in library mode."""


class ComfyUILibraryBackend:
    """Wraps PromptExecutor for in-process workflow execution.

    One executor is created at construction and reused by every ``submit``
    call. Each run executes on a background thread (wrapped with
    ``spaces.GPU`` on Spaces) and streams events back to the caller's event
    loop through an ``asyncio.Queue``.

    NOTE(review): how the executor is driven (server object,
    ``execute_outputs``, ``status_messages``, ``history_result``, progress
    hook arguments) follows recent ComfyUI sources; re-check these details
    when bumping the pinned ComfyUI version. Concurrent ``submit`` calls share
    the single executor and are not serialized here.
    """

    def __init__(self) -> None:
        """Locate ComfyUI, make it importable, load custom nodes, build the executor.

        Raises:
            RuntimeError: if the ComfyUI checkout directory does not exist.
        """
        self._comfy_dir = _comfy_dir()
        if not self._comfy_dir.exists():
            raise RuntimeError(
                f"ComfyUI not found at {self._comfy_dir}. "
                f"Local: run `bash setup.sh`. Spaces: see app.py:_bootstrap()."
            )
        if str(self._comfy_dir) not in sys.path:
            sys.path.insert(0, str(self._comfy_dir))
        # Defer comfy imports until the path is set up.
        # NOTE: ComfyUI ships PromptExecutor in the top-level `execution.py`
        # module, NOT under `comfy.execution`. Same for `nodes`. Both must be
        # imported AFTER the sys.path insert above.
        import comfy.cli_args  # noqa: F401 — side-effect: registers CLI flags
        import execution  # top-level module — provides PromptExecutor
        import nodes  # top-level module — provides init_extra_nodes (async)

        # init_extra_nodes is a coroutine in modern ComfyUI; run it once to
        # discover custom_nodes/. asyncio.run() requires that no event loop is
        # already running in the calling thread.
        asyncio.run(nodes.init_extra_nodes())
        # The executor's first parameter is its server object. Pass a stub
        # rather than None, because the executor assigns attributes on it and
        # sends status messages through it. Passing it positionally avoids
        # depending on the parameter's name.
        self._executor = execution.PromptExecutor(_HeadlessServer())

    def __repr__(self) -> str:
        return f"ComfyUILibraryBackend(comfy_dir={self._comfy_dir!r})"

    async def submit(
        self, mode: str, workflow: dict, gpu_duration: int = 120
    ) -> AsyncIterator[Any]:
        """Run a workflow end-to-end. Yields Download/Progress/Output/Error events.

        Args:
            mode: Accepted for interface compatibility; not used here.
            workflow: The prompt handed to PromptExecutor (presumably
                ComfyUI API format: ``{node_id: {"class_type": ..., "inputs": ...}}``).
            gpu_duration: Duration passed to ``spaces.GPU`` (Spaces only).

        Yields:
            DownloadEvents from the model pre-flight check, then
            ProgressEvents from ComfyUI's global progress-bar hook, then an
            OutputEvent or ErrorEvent. A pre-flight failure yields a single
            ErrorEvent(category="download") and stops.
        """
        # Pre-flight: ensure all model files exist. ensure_models() is
        # iterated synchronously and presumably blocks while it fetches files,
        # so each step is pulled on a worker thread to keep the event loop
        # responsive.
        try:
            needed = models.walk_workflow_for_models(workflow)
            pending = iter(models.ensure_models(needed))
            exhausted = object()
            while True:
                download_event = await asyncio.to_thread(next, pending, exhausted)
                if download_event is exhausted:
                    break
                yield download_event
        except Exception as e:
            yield ErrorEvent(
                category="download",
                message=str(e),
                traceback=tb_mod.format_exc(),
            )
            return

        # Run the inference in a worker thread; pass events through a queue.
        queue: asyncio.Queue = asyncio.Queue()
        loop = asyncio.get_running_loop()

        def _push(event: Any) -> None:
            # asyncio.Queue is not thread-safe: enqueue on the loop's thread.
            try:
                loop.call_soon_threadsafe(queue.put_nowait, event)
            except RuntimeError:
                # Loop already closed (consumer gone); nobody left to notify.
                pass

        def _hook(
            value: int, total: int, _preview: Any = None, *_args: Any, **_kwargs: Any
        ) -> None:
            # Extra arguments (recent ComfyUI passes node_id=...) are accepted
            # and ignored so progress reporting never raises TypeError.
            import comfy.model_management as mm

            # No web server installs ComfyUI's own hook in library mode, so
            # check the interrupt flag here to let interrupt() stop a run
            # between sampling steps, not only between nodes.
            mm.throw_exception_if_processing_interrupted()
            _push(
                ProgressEvent(
                    stage=0,
                    stage_label="diffusion",
                    step=int(value),
                    total_steps=int(total),
                )
            )

        def _worker() -> None:
            import comfy.utils

            saved_hook = getattr(comfy.utils, "PROGRESS_BAR_HOOK", None)
            try:
                # Use the public setter; it writes the same global the
                # ProgressBar class reads, but is the documented API.
                comfy.utils.set_progress_bar_global_hook(_hook)
                self._executor.execute(
                    workflow,
                    prompt_id="ltx23-aio",
                    extra_data={"client_id": "ltx23-aio"},
                    # Only the listed nodes (and what they depend on) are
                    # executed; an empty list would run nothing at all.
                    execute_outputs=self._output_node_ids(workflow),
                )
                # Node failures/interrupts are recorded by the executor rather
                # than raised, so inspect its status before reporting success.
                failure = self._failure_from_status(
                    getattr(self._executor, "status_messages", None)
                )
                if failure is not None:
                    _push(failure)
                else:
                    video_path = _first_video_path(self._collect_outputs()) or ""
                    _push(OutputEvent(video_path=video_path))
            except Exception as exc:
                _push(
                    ErrorEvent(
                        category=_classify(exc),
                        message=str(exc),
                        traceback=tb_mod.format_exc(),
                    )
                )
            finally:
                comfy.utils.set_progress_bar_global_hook(saved_hook)
                _free_memory()

        if _on_spaces():
            import spaces

            run_workflow = spaces.GPU(duration=gpu_duration)(_worker)
        else:
            run_workflow = _worker

        def _run() -> None:
            # Always terminate the consumer. This also covers the GPU wrapper
            # raising before or after _worker runs, and _worker running
            # somewhere its _push calls cannot reach this loop.
            try:
                run_workflow()
            except Exception as exc:
                _push(
                    ErrorEvent(
                        category=_classify(exc),
                        message=str(exc),
                        traceback=tb_mod.format_exc(),
                    )
                )
            finally:
                _push(None)  # sentinel: stop the consumer

        threading.Thread(target=_run, daemon=True).start()
        while True:
            event = await queue.get()
            if event is None:
                return
            yield event

    @staticmethod
    def _output_node_ids(workflow: dict) -> list:
        """Return the ids of workflow nodes whose class declares OUTPUT_NODE = True.

        Entries that are not dicts, and class types unknown to
        ``nodes.NODE_CLASS_MAPPINGS``, are skipped.
        """
        import nodes

        mappings = getattr(nodes, "NODE_CLASS_MAPPINGS", {})
        output_ids = []
        for node_id, node in workflow.items():
            if not isinstance(node, dict):
                continue
            node_cls = mappings.get(node.get("class_type"))
            if node_cls is not None and getattr(node_cls, "OUTPUT_NODE", False) is True:
                output_ids.append(node_id)
        return output_ids

    @staticmethod
    def _failure_from_status(messages: Any) -> "ErrorEvent | None":
        """Turn the executor's last error/interrupt status message into an ErrorEvent.

        ``messages`` is expected to be a sequence of ``(event, data)`` pairs.
        Returns None when it is not a list/tuple or holds no
        "execution_error" / "execution_interrupted" entry.
        """
        if not isinstance(messages, (list, tuple)):
            return None
        for entry in reversed(messages):
            try:
                event, data = entry
            except (TypeError, ValueError):
                continue
            if not isinstance(data, dict):
                data = {}
            if event == "execution_interrupted":
                return ErrorEvent(category="interrupt", message="Execution interrupted")
            if event == "execution_error":
                message = str(data.get("exception_message", ""))
                exc_type = str(data.get("exception_type", "")).lower()
                if "outofmemory" in exc_type or "out of memory" in message.lower():
                    category = "oom"
                else:
                    category = "execution"
                trace = data.get("traceback", "")
                if isinstance(trace, (list, tuple)):
                    trace = "".join(str(line) for line in trace)
                return ErrorEvent(category=category, message=message, traceback=str(trace))
        return None

    def _collect_outputs(self) -> list:
        """Return the per-node UI output dicts recorded by the executor.

        Tries, in order, ``history_result["outputs"]`` (recent ComfyUI),
        ``outputs_ui`` (older releases) and ``outputs`` (what this backend
        originally read). The first one that is a dict is used; otherwise
        the result is empty.
        """
        history = getattr(self._executor, "history_result", None)
        candidates = (
            history.get("outputs") if isinstance(history, dict) else None,
            getattr(self._executor, "outputs_ui", None),
            getattr(self._executor, "outputs", None),
        )
        for candidate in candidates:
            if isinstance(candidate, dict):
                return list(candidate.values())
        return []

    def interrupt(self) -> None:
        """Cancel the currently running workflow (if any).

        Best-effort: sets ComfyUI's interrupt flag and never raises.
        """
        try:
            import comfy.model_management as mm

            mm.interrupt_current_processing()
        except Exception:
            # An interrupt request must never propagate an error to the caller.
            pass
def _classify(exc: Exception) -> str:
    """Map an exception raised during a run to an ErrorEvent category.

    Returns:
        "oom" when the exception type name contains "OutOfMemory" or its
        message contains "out of memory" (this matches both "CUDA out of
        memory" and "MPS backend out of memory" texts);
        "interrupt" when the type name contains "interrupt";
        "execution" otherwise.
    """
    type_name = type(exc).__name__.lower()
    # Match the generic phrase rather than only "cuda out of memory": the
    # backend also runs on Apple MPS (see _free_memory), whose OOM errors
    # would otherwise be reported as generic execution failures.
    if "outofmemory" in type_name or "out of memory" in str(exc).lower():
        return "oom"
    if "interrupt" in type_name:
        return "interrupt"
    return "execution"
def _free_memory() -> None:
    """Release loaded models and cached accelerator memory after a run.

    Runs whether the workflow succeeded or failed. Three cleanup steps are
    attempted independently; any exception one of them raises is ignored,
    so cleanup can never mask the run's actual result.
    """

    def _unload_comfy_models() -> None:
        import comfy.model_management as mm

        mm.unload_all_models()

    def _empty_mps_cache() -> None:
        import torch

        if torch.backends.mps.is_available():
            torch.mps.empty_cache()

    def _empty_cuda_cache() -> None:
        import torch

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    for cleanup_step in (_unload_comfy_models, _empty_mps_cache, _empty_cuda_cache):
        try:
            cleanup_step()
        except Exception:
            # Best-effort: a missing module or backend is not an error here.
            pass
def _first_video_path(outputs: Iterable) -> str | None:
"""Find the first .mp4 path emitted by VHS_VideoCombine in PromptExecutor outputs."""
for output in outputs:
if not isinstance(output, dict):
continue
for value in output.values():
if isinstance(value, list):
for item in value:
if isinstance(item, dict) and "filename" in item:
fn = item["filename"]
if fn.endswith((".mp4", ".webm", ".mov")):
return item.get("fullpath", fn)
return None