linoyts's picture
linoyts HF Staff
Default GPU flavor -> rtx-pro-6000 (Blackwell)
498a04b verified
Raw
History Blame Contribute Delete
21.8 kB
"""HF Jobs backend for the LTX-2 trainer Space.
Per training request:
1. stage the dataset (videos + dataset.json) + config.yaml + job_config.json locally,
2. sync them to a per-run HF bucket,
3. generate a self-contained UV job script,
4. submit it as a detached UV job on the requested flavor (the API equivalent of `hf jobs uv run --flavor <flavor> --secrets HF_TOKEN --detach`).
On the Job, the script syncs the source bucket + run bucket, runs `uv sync --frozen`
(reproducing the working trainer env from the lockfile), downloads the base checkpoint
and Gemma, then runs process_dataset.py → train.py, which pushes the trained LoRA to the
Hub. For IC-LoRA, references are user-supplied (paired `*_reference` videos) — no
auto-derivation.
The Space itself only needs gradio + huggingface_hub + pyyaml (no torch).
"""
from __future__ import annotations
import json
import os
import re
import shutil
import tempfile
import zipfile
from pathlib import Path
import yaml
from huggingface_hub import HfApi
# Bucket the Job syncs the trainer source tree from; override with the LTX_SRC_BUCKET env var.
SRC_BUCKET = os.environ.get("LTX_SRC_BUCKET", "ltx-community/ltx2-trainer-src-v2")
DEFAULT_FLAVOR = "rtx-pro-6000" # default single-GPU flavor (Blackwell, 96GB; cu128 torch in v2)
# Lower-cased file suffixes that count as video clips when staging / validating uploads.
VIDEO_EXTS = {".mp4", ".mov", ".mkv", ".webm", ".avi", ".m4v"}
# Deterministic on-Job paths (the Space bakes these into config.yaml).
JOB_ROOT = "/tmp/ltxjob"
# Base checkpoint location on the Job (the job script downloads it into JOB_ROOT/models).
JOB_MODEL = f"{JOB_ROOT}/models/ltx-2.3-22b-dev.safetensors"
# Text-encoder (Gemma) snapshot directory on the Job.
JOB_GEMMA = f"{JOB_ROOT}/gemma"
# Local mirror of the per-run bucket on the Job (dataset/, config.yaml, job_config.json).
JOB_RUN = f"{JOB_ROOT}/run"
# Per-mode presets, keyed by the mode label passed around as `params["mode"]`.
#   needs_reference  - True when every target clip must be paired with a reference video.
#   first_frame_prob - probability of the `first_frame` condition (non-reference modes only;
#                      build_config_dict skips the condition when it is 0).
#   target_modules   - LoRA target module names copied into config.yaml.
# `recommended` holds the per-mode hyperparameters the example configs use (configs/*.yaml):
# IC-LoRA/v2v → lr 2e-4, 3000 steps; T2V/I2V → lr 1e-4, 2000 steps. Rank/alpha stay 32.
MODES = {
    "IC-LoRA (in-context control)": {
        "needs_reference": True,
        "recommended": {"learning_rate": 2e-4, "steps": 3000},
        "target_modules": [
            "attn1.to_k", "attn1.to_q", "attn1.to_v", "attn1.to_out.0",
            "attn2.to_k", "attn2.to_q", "attn2.to_v", "attn2.to_out.0",
            "ff.net.0.proj", "ff.net.2",
        ],
    },
    "Text-to-Video LoRA": {"needs_reference": False, "first_frame_prob": 0.0,
                           "recommended": {"learning_rate": 1e-4, "steps": 2000},
                           "target_modules": ["to_k", "to_q", "to_v", "to_out.0"]},
    "Image-to-Video LoRA": {"needs_reference": False, "first_frame_prob": 0.5,
                            "recommended": {"learning_rate": 1e-4, "steps": 2000},
                            "target_modules": ["to_k", "to_q", "to_v", "to_out.0"]},
}
def parse_resolution(resolution: str) -> tuple[int, int, int]:
    """Parse a ``"WxHxF"`` string into ``(width, height, frames)``.

    Spaces are ignored and the separator is case-insensitive, so ``"768 X 512 x 49"`` is
    accepted.

    Raises:
        ValueError: if the string does not have three integer parts, any part is not
            positive, width/height are not multiples of 32, or ``frames % 8 != 1``.
    """
    parts = resolution.lower().replace(" ", "").split("x")
    if len(parts) != 3:
        raise ValueError(f"Resolution must be 'WxHxF', got {resolution!r}")
    try:
        w, h, f = (int(p) for p in parts)
    except ValueError:
        raise ValueError(
            f"Resolution must be 'WxHxF' with integer values, got {resolution!r}"
        ) from None
    # The modulo checks below cannot catch these on their own: 0 % 32 == 0 and, in Python,
    # -7 % 8 == 1, so zero/negative sizes would otherwise slip through.
    if w <= 0 or h <= 0 or f <= 0:
        raise ValueError(f"Width, height and frame count must be positive (got {w}x{h}x{f}).")
    if w % 32 or h % 32:
        raise ValueError(f"Width and height must be divisible by 32 (got {w}x{h}).")
    if f % 8 != 1:
        raise ValueError(f"Frame count must satisfy frames % 8 == 1 (got {f}).")
    return w, h, f
def _is_reference(p: Path) -> bool:
return p.stem.endswith("_reference")
def _stage_uploads(uploaded: list[str], videos_dir: Path) -> tuple[list[Path], dict[str, str]]:
    """Copy uploaded videos into videos_dir and read per-clip `.txt` caption sidecars.

    Returns (sorted video paths, {stem: caption}). `.txt` files are matched to a video by
    stem (e.g. `clip.txt` captions `clip.mp4`). Handles `.zip` archives too; archive members
    are flattened to their base name, so nothing is ever written outside `videos_dir`.

    Videos are streamed to disk in chunks rather than loaded whole into memory, and captions
    are decoded as `utf-8-sig` so a leading byte-order mark does not end up inside the
    caption text.
    """
    videos_dir.mkdir(parents=True, exist_ok=True)
    caption_files: dict[str, str] = {}

    def handle(name: str, stream) -> None:
        # `stream` is a binary file object positioned at the start of the payload.
        suffix = Path(name).suffix.lower()
        if suffix in VIDEO_EXTS:
            with open(videos_dir / Path(name).name, "wb") as out:
                shutil.copyfileobj(stream, out)
        elif suffix == ".txt":
            caption_files[Path(name).stem] = stream.read().decode("utf-8-sig", "ignore").strip()
        # anything else (images, metadata, ...) is ignored

    for p in uploaded or []:
        src = Path(p)
        if src.suffix.lower() == ".zip":
            with zipfile.ZipFile(src) as zf:
                for m in zf.namelist():
                    # skip macOS resource-fork folders and directory entries
                    if m.startswith("__MACOSX") or m.endswith("/"):
                        continue
                    with zf.open(m) as stream:
                        handle(m, stream)
        else:
            with src.open("rb") as stream:
                handle(src.name, stream)

    videos = sorted(p for p in videos_dir.glob("*") if p.suffix.lower() in VIDEO_EXTS)
    return videos, caption_files
def build_dataset_items(
    videos: list[Path], caption_all: str, caption_files: dict[str, str], needs_reference: bool,
    trigger: str = "",
) -> tuple[list[dict], list[dict]]:
    """Turn staged clips into dataset.json rows.

    When `needs_reference` is set (IC-LoRA), each target `X.ext` must come with a
    user-supplied `X_reference.ext`; nothing is derived automatically. A clip's caption is
    its `X.txt` sidecar when one exists, otherwise `caption_all`, otherwise "a video". A
    non-empty `trigger` is prepended to every caption as "TRG, <caption>".

    Returns the row list twice — (items, items) — for signature parity with the Hub path.
    Raises ValueError when a target lacks its reference or when no target clip is found.
    """
    marker = "_reference"
    clips = [clip for clip in videos if clip.suffix.lower() in VIDEO_EXTS]

    # Split clips into targets and a {target stem: reference clip} lookup.
    references: dict[str, Path] = {}
    if needs_reference:
        targets = []
        for clip in clips:
            if _is_reference(clip):
                references[clip.stem[: -len(marker)]] = clip
            else:
                targets.append(clip)
        targets.sort()
    else:
        targets = sorted(clips)

    fallback = "a video"
    if caption_all and caption_all.strip():
        fallback = caption_all.strip()
    prefix = (trigger or "").strip()

    rows, unpaired = [], []
    for clip in targets:
        caption = caption_files.get(clip.stem) or fallback
        if prefix:
            caption = f"{prefix}, {caption}"
        row = {"media_path": f"videos/{clip.name}", "caption": caption}
        if needs_reference:
            partner = references.get(clip.stem)
            if partner is None:
                # collect every offender so the error lists them all at once
                unpaired.append(clip.name)
                continue
            row["reference_video"] = f"videos/{partner.name}"
        rows.append(row)

    if needs_reference and unpaired:
        raise ValueError(
            "Missing reference video(s) for: " + ", ".join(unpaired)
            + ". For IC-LoRA, every target `X.mp4` needs a paired `X_reference.mp4`."
        )
    if not rows:
        raise ValueError("No target videos found in the upload.")
    return rows, rows
def _upload_member_paths(uploaded: list[str]) -> list[Path]:
    """List the files an upload contains, looking inside `.zip` archives without extracting."""
    members: list[Path] = []
    for p in uploaded or []:
        src = Path(p)
        if src.suffix.lower() != ".zip":
            members.append(src)
            continue
        with zipfile.ZipFile(src) as zf:
            # same skip rules as _stage_uploads: macOS resource forks and directory entries
            members.extend(
                Path(m) for m in zf.namelist()
                if not (m.startswith("__MACOSX") or m.endswith("/"))
            )
    return members


def _validate_hub_dataset(repo: str, needs_reference: bool, token: str) -> str:
    """Check a Hub dataset's `dataset.json` against the chosen mode; return a markdown report."""
    from huggingface_hub import hf_hub_download  # noqa: PLC0415

    try:
        dj = hf_hub_download(repo, "dataset.json", repo_type="dataset", token=token or None)
    except Exception as e:  # noqa: BLE001
        return (f"❌ Couldn't read `dataset.json` from **{repo}**. The dataset must contain a "
                f"trainer-format `dataset.json` (`media_path` + `caption`[ + `reference_video`]). "
                f"_{str(e)[:100]}_")
    try:
        items = json.loads(Path(dj).read_text(encoding="utf-8"))
    except (OSError, ValueError):  # unreadable file, bad encoding or invalid JSON
        items = None
    # Explicit checks instead of `assert`, which is stripped when Python runs with -O.
    if not isinstance(items, list) or not items:
        return "❌ `dataset.json` is empty or not a JSON list."
    if not all(isinstance(it, dict) for it in items):
        return "❌ Every `dataset.json` row must be a JSON object."
    if any(not it.get("media_path") for it in items):
        return "❌ Some rows are missing `media_path`."

    lines = [f"✅ **{len(items)} items** in `{repo}`."]
    if needs_reference:
        no_ref = [it for it in items if not it.get("reference_video")]
        if no_ref:
            return (f"❌ **IC-LoRA needs a `reference_video` per row** — {len(no_ref)}/{len(items)} "
                    f"rows are missing it. (Pick a different mode, or use a dataset with references.)")
        lines.append("✅ `reference_video` present on every row — IC-LoRA ready.")
    nocap = sum(1 for it in items if not (it.get("caption") or "").strip())
    if nocap:
        lines.append(f"⚠️ {nocap} rows have no caption (they'll default to *“a video”*).")
    lines.append(f"📝 Sample: `{items[0]['media_path']}` — *“{(items[0].get('caption') or '')[:80]}”*")
    return "\n\n".join(lines)


def _validate_uploads(mode: str, needs_reference: bool, uploaded: list[str], caption_all: str) -> str:
    """Check uploaded file names (and zip member names) against the chosen mode."""
    try:
        files = _upload_member_paths(uploaded)
    except (OSError, zipfile.BadZipFile) as e:
        return f"❌ Couldn't read an uploaded `.zip` archive. _{str(e)[:100]}_"
    vids = [p for p in files if p.suffix.lower() in VIDEO_EXTS]
    txts = {p.stem for p in files if p.suffix.lower() == ".txt"}
    if not vids:
        return "⚠️ Upload at least one video."

    if needs_reference:
        targets = [p for p in vids if not _is_reference(p)]
        refs = {p.stem[: -len("_reference")] for p in vids if _is_reference(p)}
        if not targets:
            return "❌ No target clips found — only `*_reference` files were uploaded."
        missing = [p.name for p in targets if p.stem not in refs]
        if missing:
            return (f"❌ **{len(missing)} target(s) have no paired `*_reference`:** "
                    f"{', '.join(missing[:6])}{'…' if len(missing) > 6 else ''}")
        with_txt = sum(1 for p in targets if p.stem in txts)
        cap = (f"{with_txt}/{len(targets)} clips have a `.txt`"
               + ("" if with_txt == len(targets)
                  else "; the rest use the *same-caption* box" if caption_all
                  else f"; {len(targets) - with_txt} will default to *“a video”*"))
        return f"✅ **{len(targets)} IC-LoRA pairs** — every target has its reference.\n\n📝 {cap}."

    with_txt = sum(1 for p in vids if p.stem in txts)
    cap = (f"{with_txt}/{len(vids)} clips have a `.txt`"
           + ("" if with_txt == len(vids) or caption_all else f"; {len(vids) - with_txt} default to *“a video”*"))
    return f"✅ **{len(vids)} clips** ready for {mode}.\n\n📝 {cap}."


def validate_dataset(mode: str, from_hub: bool, dataset_repo: str, uploaded: list[str],
                     caption_all: str, token: str = "") -> str:
    """Pre-flight check (no job): does the dataset match the chosen mode? Returns a markdown report.

    With `from_hub`, the report is based on the `dataset.json` of `dataset_repo`; otherwise
    on the names of the uploaded files, including the member names of uploaded `.zip`
    archives (which `_stage_uploads` also accepts). Problems are reported in the returned
    text rather than raised; only an unknown `mode` raises (KeyError).
    """
    needs_reference = MODES[mode]["needs_reference"]
    caption_all = (caption_all or "").strip()
    if from_hub:
        repo = (dataset_repo or "").strip()
        if not repo:
            return "⚠️ Enter a Hub dataset repo id first."
        return _validate_hub_dataset(repo, needs_reference, token)
    return _validate_uploads(mode, needs_reference, uploaded, caption_all)
def build_config_dict(params: dict, items: list[dict]) -> dict:
    """Build the trainer config.yaml dict.

    `params` supplies the user-chosen settings (mode, resolution, LoRA rank/alpha,
    optimisation, quantisation, hub and seed options). `items` are the dataset.json rows,
    used only to seed the validation sample's prompt and, for reference modes, its reference
    video. Model, data and output paths are the fixed on-Job locations under JOB_ROOT.

    Raises ValueError for a bad resolution and KeyError for an unknown mode or a missing
    required `params` key.
    """
    w, h, f = parse_resolution(params["resolution"])
    mode_cfg = MODES[params["mode"]]
    steps = int(params["steps"])
    seed = int(params["seed"])

    conditions: list[dict] = []
    if mode_cfg["needs_reference"]:
        conditions.append({"type": "reference", "latents_dir": "reference_latents", "probability": 1.0})
        conditions.append({"type": "first_frame", "probability": 0.2})
    elif mode_cfg.get("first_frame_prob", 0.0) > 0:
        conditions.append({"type": "first_frame", "probability": mode_cfg["first_frame_prob"]})

    # Hub rows may legitimately have no (or an empty) caption, so never index the key
    # directly; fall back to the same "a video" default used elsewhere.
    first_caption = items[0].get("caption") if items else None
    val_sample: dict = {"prompt": first_caption or "a video"}
    if mode_cfg["needs_reference"]:
        ref_item = next((it for it in items if it.get("reference_video")), None)
        if ref_item:
            val_sample["conditions"] = [
                {"type": "reference", "video": f"{JOB_RUN}/dataset/{ref_item['reference_video']}",
                 "include_in_output": True}
            ]

    return {
        "model": {"model_path": JOB_MODEL, "text_encoder_path": JOB_GEMMA,
                  "training_mode": "lora", "load_checkpoint": None},
        "lora": {"rank": int(params["rank"]), "alpha": int(params["alpha"]), "dropout": 0.0,
                 # copy so callers cannot mutate the shared MODES preset through the result
                 "target_modules": list(mode_cfg["target_modules"])},
        "training_strategy": {"name": "flexible",
                              "video": {"is_generated": True, "latents_dir": "latents",
                                        "conditions": conditions}},
        "optimization": {"learning_rate": float(params["learning_rate"]), "steps": steps,
                         "batch_size": int(params["batch_size"]),
                         "gradient_accumulation_steps": int(params["gradient_accumulation_steps"]),
                         "max_grad_norm": 1.0, "optimizer_type": params["optimizer_type"],
                         "scheduler_type": "linear", "scheduler_params": {},
                         "enable_gradient_checkpointing": True},
        "acceleration": {"mixed_precision_mode": "bf16",
                         "quantization": params["quantization"] or None,
                         "load_text_encoder_in_8bit": bool(params["load_text_encoder_in_8bit"]),
                         "offload_optimizer_during_validation": True},
        "data": {"preprocessed_data_root": f"{JOB_RUN}/dataset/.precomputed", "num_dataloader_workers": 2},
        "validation": {"samples": [val_sample],
                       "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
                       "video_dims": [w, h, f], "frame_rate": 25.0, "seed": seed,
                       # 30 inference steps (configs/*); validate every N steps, clamped to <= steps
                       # so a short run still gets at least one sample.
                       "inference_steps": 30,
                       "interval": max(1, min(int(params.get("validation_interval", 250)), steps)),
                       "guidance_scale": 4.0, "stg_scale": 1.0, "stg_blocks": [29], "stg_mode": "stg_v",
                       "generate_audio": False, "skip_initial_validation": True},
        # keep_last_n: -1 avoids a trainer bug — when steps is a multiple of the checkpoint
        # interval, the final step is saved twice and cleanup (keep_last_n>0) deletes the file
        # before push_to_hub. Keeping all checkpoints (tiny for LoRA) sidesteps it.
        "checkpoints": {"interval": max(steps, 1), "keep_last_n": -1, "precision": "bfloat16"},
        "flow_matching": {"timestep_sampling_mode": "shifted_logit_normal", "timestep_sampling_params": {}},
        "hub": {"push_to_hub": bool(params["push_to_hub"]), "hub_model_id": params["hub_model_id"] or None},
        "wandb": {"enabled": False, "project": "ltx-2-trainer", "entity": None,
                  "tags": ["ltx2", "jobs"], "log_validation_videos": True},
        "seed": seed,
        "output_dir": f"{JOB_RUN}/outputs",
    }
# --------------------------------------------------------------------------------------
# UV job script (runs on HF Jobs hardware)
# --------------------------------------------------------------------------------------
# Source of the self-contained UV script executed on the Job. Rendered with `str.format`:
# single-brace fields ({JOB_ROOT}, {JOB_MODEL}, {JOB_GEMMA}, {SRC_BUCKET}, {RUN_BUCKET}) are
# filled in by the Space, while doubled braces ({{...}}) survive as literal braces so the
# f-strings inside the script are evaluated on the Job. The body must keep valid Python
# indentation — it is written to disk and executed as-is.
JOB_SCRIPT_TEMPLATE = '''# /// script
# requires-python = ">=3.10"
# dependencies = ["huggingface_hub[hf-xet]>=1.5", "hf_transfer"]
# ///
"""Auto-generated LTX-2 training job. Reproduces the trainer env via uv sync --frozen."""
import json, os, subprocess, sys
from pathlib import Path

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
JOB = Path("{JOB_ROOT}"); SRC = JOB / "src"; RUN = JOB / "run"
MODEL = "{JOB_MODEL}"; GEMMA = "{JOB_GEMMA}"
SRC_BUCKET = "{SRC_BUCKET}"; RUN_BUCKET = "{RUN_BUCKET}"


def sh(cmd, cwd=None):
    print(">>>", " ".join(cmd), flush=True)
    subprocess.run(cmd, cwd=str(cwd) if cwd else None, check=True)


def main():
    JOB.mkdir(parents=True, exist_ok=True)
    print("=== 0/6 system libs (opencv needs libGL) ===", flush=True)
    apt = subprocess.run("apt-get update && apt-get install -y --no-install-recommends libgl1 libglib2.0-0",
                         shell=True, check=False)
    if apt.returncode != 0:
        # best-effort step: keep going, but leave a trace in the job log
        print(f"!!! apt-get exited with code {{apt.returncode}}; continuing", flush=True)
    print("=== 1/6 sync source bucket ===", flush=True)
    sh(["hf", "buckets", "sync", f"hf://buckets/{{SRC_BUCKET}}", str(SRC)])
    print("=== 2/6 sync run bucket (dataset + config) ===", flush=True)
    sh(["hf", "buckets", "sync", f"hf://buckets/{{RUN_BUCKET}}", str(RUN)])
    jc = json.loads((RUN / "job_config.json").read_text())
    print("=== 3/6 uv sync (reproduce trainer env) ===", flush=True)
    sh(["uv", "sync", "--frozen"], cwd=SRC)
    print("=== 4/6 download base checkpoint + Gemma ===", flush=True)
    from huggingface_hub import hf_hub_download, snapshot_download
    hf_hub_download("Lightricks/LTX-2.3", "ltx-2.3-22b-dev.safetensors", local_dir=str(JOB / "models"))
    snapshot_download("google/gemma-3-12b-it-qat-q4_0-unquantized", local_dir=GEMMA)
    tr = SRC / "packages" / "ltx-trainer"

    def uvrun(args):
        sh(["uv", "run", "python", *args], cwd=tr)

    if jc.get("dataset_repo"):
        print("=== 4b/6 download Hub dataset ===", flush=True)
        snapshot_download(jc["dataset_repo"], repo_type="dataset", local_dir=str(RUN / "dataset"))
    ds_json = RUN / "dataset" / "dataset.json"
    print("=== 5/6 preprocess dataset ===", flush=True)
    uvrun(["scripts/process_dataset.py", str(ds_json),
           "--resolution-buckets", jc["resolution"],
           "--model-path", MODEL, "--text-encoder-path", GEMMA, "--skip-audio"])
    print("=== 6/6 train (pushes LoRA to the Hub) ===", flush=True)
    uvrun(["scripts/train.py", str(RUN / "config.yaml"), "--disable-progress-bars"])
    print("=== DONE ===", flush=True)


if __name__ == "__main__":
    main()
'''
def _run_bucket_name(run_name: str) -> str:
safe = re.sub(r"[^a-zA-Z0-9-]+", "-", run_name).strip("-").lower() or "run"
return f"ltx2-train-{safe}"
def _namespace(token: str | None) -> str:
    """Look up the Hub account name for `token`, falling back to the `HF_TOKEN` env var."""
    from huggingface_hub import whoami  # noqa: PLC0415

    effective_token = token or os.environ.get("HF_TOKEN")
    identity = whoami(token=effective_token)
    return identity["name"]
def submit(params: dict, uploaded_videos: list[str], flavor: str, timeout: str) -> dict:
    """Stage data → bucket → generate UV script → submit job. Returns {job_id, url, bucket, log}.

    `params` carries the form values (mode, run_name, resolution, training settings and
    optionally hf_token, dataset_repo, caption_all, trigger_word). When `dataset_repo` is
    set the Job downloads that Hub dataset itself; otherwise `uploaded_videos` are staged
    and synced to the per-run bucket.

    Raises:
        ValueError: no token available, bad resolution, unusable Hub `dataset.json`, or an
            upload that does not yield a valid dataset.
        RuntimeError: syncing the staged files to the bucket failed.
    """
    token = (params.get("hf_token") or os.environ.get("HF_TOKEN") or "").strip()
    if not token:
        # An empty string would be sent to the Hub as an (invalid) explicit token and handed
        # to the Job as its HF_TOKEN secret, so fail here with an actionable message.
        raise ValueError("No Hugging Face token available: provide one or set HF_TOKEN.")
    # Validate up front and hand the Job the canonical "WxHxF" form rather than raw user
    # input (which may contain spaces or an upper-case "X").
    width, height, frames = parse_resolution(params["resolution"])
    ns = _namespace(token)
    bucket = f"{ns}/{_run_bucket_name(params['run_name'])}"
    needs_reference = MODES[params["mode"]]["needs_reference"]
    dataset_repo = (params.get("dataset_repo") or "").strip()
    tmp = Path(tempfile.mkdtemp(prefix="ltxrun-"))
    try:
        job_cfg = {"resolution": f"{width}x{height}x{frames}"}
        if dataset_repo:
            # Hub dataset: read its dataset.json (to seed validation) — the Job downloads the
            # full dataset itself. The dataset repo must contain a trainer-format dataset.json.
            from huggingface_hub import hf_hub_download  # noqa: PLC0415
            try:
                dj = hf_hub_download(dataset_repo, "dataset.json", repo_type="dataset", token=token)
            except Exception as e:  # noqa: BLE001
                raise ValueError(
                    f"Could not read dataset.json from dataset `{dataset_repo}`. The Hub dataset "
                    f"must contain a trainer-format dataset.json (media_path + caption "
                    f"[+ reference_video]). ({e})"
                ) from e
            try:
                items = json.loads(Path(dj).read_text(encoding="utf-8"))
            except (OSError, ValueError) as e:
                raise ValueError(
                    f"dataset.json in dataset `{dataset_repo}` is not valid JSON. ({e})"
                ) from e
            # Reject malformed content here instead of failing later with an opaque
            # AttributeError/TypeError (or only once the Job is already running).
            if (not isinstance(items, list) or not items
                    or not all(isinstance(it, dict) for it in items)):
                raise ValueError(
                    f"dataset.json in dataset `{dataset_repo}` must be a non-empty JSON list of rows."
                )
            if needs_reference and not any(it.get("reference_video") for it in items):
                raise ValueError(
                    f"Dataset `{dataset_repo}` has no `reference_video` column — required for IC-LoRA."
                )
            job_cfg["dataset_repo"] = dataset_repo
        else:
            videos, caption_files = _stage_uploads(uploaded_videos, tmp / "dataset" / "videos")
            if not videos:
                raise ValueError("No valid video files in the upload.")
            items, _ = build_dataset_items(videos, params.get("caption_all", ""), caption_files,
                                           needs_reference, trigger=params.get("trigger_word", ""))
            (tmp / "dataset" / "dataset.json").write_text(json.dumps(items, indent=2))

        # config.yaml + job_config.json (read at the run root, sibling of dataset/ on the Job)
        cfg = build_config_dict(params, items)
        (tmp / "config.yaml").write_text(yaml.safe_dump(cfg, sort_keys=False))
        (tmp / "job_config.json").write_text(json.dumps(job_cfg, indent=2))

        # create + sync the per-run bucket via the huggingface_hub Python API (no `hf` CLI —
        # the CLI can crash at import on some Space base images due to dep conflicts).
        api = HfApi(token=token)
        bucket_name = bucket.split("/", 1)[1]
        api.create_bucket(bucket_name, exist_ok=True, token=token)
        try:
            api.sync_bucket(source=str(tmp), dest=f"hf://buckets/{bucket}", token=token)
        except Exception as e:  # noqa: BLE001
            raise RuntimeError(f"bucket sync failed: {e}") from e

        # render + write the job script, then submit it as a detached UV job
        script = JOB_SCRIPT_TEMPLATE.format(
            JOB_ROOT=JOB_ROOT, JOB_MODEL=JOB_MODEL, JOB_GEMMA=JOB_GEMMA,
            SRC_BUCKET=SRC_BUCKET, RUN_BUCKET=bucket,
        )
        script_path = tmp / "job_train.py"
        script_path.write_text(script)
        job = api.run_uv_job(
            str(script_path), flavor=flavor, timeout=timeout,
            secrets={"HF_TOKEN": token}, token=token,
        )
        job_id = getattr(job, "id", "") or ""
        url = getattr(job, "url", "") or (f"https://huggingface.co/jobs/{ns}/{job_id}" if job_id else "")
        return {"job_id": job_id, "url": url, "bucket": bucket, "log": f"Submitted job {job_id} on {flavor}."}
    finally:
        # the staging dir is only needed until the bucket sync and job submission are done
        shutil.rmtree(tmp, ignore_errors=True)
def job_logs(job_id: str, token: str = "") -> str:
    """Return the Job's log lines joined with newlines.

    Never raises: on any failure a "(could not fetch logs: ...)" placeholder is returned.
    """
    # Normalise the "" default to None (as validate_dataset does) so the Hub client falls
    # back to its own token lookup instead of treating "" as an explicit token.
    auth = token or None
    try:
        return "\n".join(HfApi(token=auth).fetch_job_logs(job_id=job_id, token=auth))
    except Exception as e:  # noqa: BLE001
        return f"(could not fetch logs: {e})"
def job_status(job_id: str, token: str = "") -> str:
    """Return the Job's current stage as a string.

    Reads `status.stage` from the inspected job, accepting either an object or a dict for
    `status`. Falls back to the status itself, then to "UNKNOWN". Never raises: failures
    are reported as "UNKNOWN (<error>)".
    """
    # Normalise the "" default to None (as validate_dataset does) so the Hub client falls
    # back to its own token lookup instead of treating "" as an explicit token.
    auth = token or None
    try:
        job = HfApi(token=auth).inspect_job(job_id=job_id, token=auth)
        status = getattr(job, "status", None)
        stage = getattr(status, "stage", None)
        if stage is None and isinstance(status, dict):
            stage = status.get("stage")
        return str(stage or status or "UNKNOWN")
    except Exception as e:  # noqa: BLE001
        return f"UNKNOWN ({e})"