"""HF Jobs backend for the LTX-2 trainer Space. Per training request: 1. stage the dataset (videos + dataset.json) + config.yaml + job_config.json locally, 2. sync them to a per-run HF bucket, 3. generate a self-contained UV job script, 4. submit it with `hf jobs uv run --flavor a100-large --secrets HF_TOKEN --detach`. On the Job, the script syncs the source bucket + run bucket, runs `uv sync --frozen` (reproducing the working trainer env from the lockfile), downloads the base checkpoint and Gemma, then runs process_dataset.py → train.py, which pushes the trained LoRA to the Hub. For IC-LoRA, references are user-supplied (paired `*_reference` videos) — no auto-derivation. The Space itself only needs gradio + huggingface_hub + pyyaml (no torch). """ from __future__ import annotations import json import os import re import shutil import tempfile import zipfile from pathlib import Path import yaml from huggingface_hub import HfApi SRC_BUCKET = os.environ.get("LTX_SRC_BUCKET", "ltx-community/ltx2-trainer-src-v2") DEFAULT_FLAVOR = "rtx-pro-6000" # default single-GPU flavor (Blackwell, 96GB; cu128 torch in v2) VIDEO_EXTS = {".mp4", ".mov", ".mkv", ".webm", ".avi", ".m4v"} # Deterministic on-Job paths (the Space bakes these into config.yaml). JOB_ROOT = "/tmp/ltxjob" JOB_MODEL = f"{JOB_ROOT}/models/ltx-2.3-22b-dev.safetensors" JOB_GEMMA = f"{JOB_ROOT}/gemma" JOB_RUN = f"{JOB_ROOT}/run" # `recommended` holds the per-mode hyperparameters the example configs use (configs/*.yaml): # IC-LoRA/v2v → lr 2e-4, 3000 steps; T2V/I2V → lr 1e-4, 2000 steps. Rank/alpha stay 32. 
# Training modes offered by the Space. `needs_reference`: every target clip must have a
# paired `*_reference` video (IC-LoRA). `first_frame_prob`: probability of the `first_frame`
# condition for non-reference modes (see build_config_dict). `target_modules`: LoRA targets.
MODES = {
    "IC-LoRA (in-context control)": {
        "needs_reference": True,
        "recommended": {"learning_rate": 2e-4, "steps": 3000},
        "target_modules": [
            "attn1.to_k", "attn1.to_q", "attn1.to_v", "attn1.to_out.0",
            "attn2.to_k", "attn2.to_q", "attn2.to_v", "attn2.to_out.0",
            "ff.net.0.proj", "ff.net.2",
        ],
    },
    "Text-to-Video LoRA": {
        "needs_reference": False,
        "first_frame_prob": 0.0,
        "recommended": {"learning_rate": 1e-4, "steps": 2000},
        "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
    },
    "Image-to-Video LoRA": {
        "needs_reference": False,
        "first_frame_prob": 0.5,
        "recommended": {"learning_rate": 1e-4, "steps": 2000},
        "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
    },
}


def parse_resolution(resolution: str) -> tuple[int, int, int]:
    """Parse a ``"WxHxF"`` string into ``(width, height, frames)``.

    All whitespace is ignored and the ``x`` separator is case-insensitive.

    Raises:
        ValueError: if the string is not three integers, any value is not positive,
            width/height are not divisible by 32, or ``frames % 8 != 1``.
    """
    parts = "".join(resolution.split()).lower().split("x")
    if len(parts) != 3:
        raise ValueError(f"Resolution must be 'WxHxF', got {resolution!r}")
    try:
        w, h, f = (int(p) for p in parts)
    except ValueError:
        raise ValueError(f"Resolution must be three integers 'WxHxF', got {resolution!r}") from None
    # Without this, negative values slip through the modulo checks (-7 % 8 == 1 in Python).
    if w <= 0 or h <= 0 or f <= 0:
        raise ValueError(f"Width, height and frame count must be positive (got {w}x{h}x{f}).")
    if w % 32 or h % 32:
        raise ValueError(f"Width and height must be divisible by 32 (got {w}x{h}).")
    if f % 8 != 1:
        raise ValueError(f"Frame count must satisfy frames % 8 == 1 (got {f}).")
    return w, h, f


def _is_reference(p: Path) -> bool:
    """True when the file stem ends in ``_reference`` (an IC-LoRA reference clip)."""
    return p.stem.endswith("_reference")


def _is_zip_payload(member: str) -> bool:
    """True for zip members worth staging: not a directory and not macOS metadata.

    Skips `__MACOSX/...` entries and AppleDouble `._name` files, which would otherwise be
    staged as (unreadable) videos or captions.
    """
    if member.startswith("__MACOSX") or member.endswith("/"):
        return False
    return not Path(member).name.startswith("._")


def _stage_uploads(uploaded: list[str], videos_dir: Path) -> tuple[list[Path], dict[str, str]]:
    """Copy uploaded videos into videos_dir and read per-clip `.txt` caption sidecars.

    Returns (sorted video paths, {stem: caption}). `.txt` files are matched to a video by
    stem (e.g. `clip.txt` captions `clip.mp4`). Handles `.zip` archives too. Files are
    flattened to their basename, so later files with the same name overwrite earlier ones.
    """
    videos_dir.mkdir(parents=True, exist_ok=True)
    caption_files: dict[str, str] = {}

    def handle(name: str, open_binary) -> None:
        # `open_binary()` returns a readable binary file object; it is only opened when the
        # entry is actually used, and videos are streamed rather than loaded into memory.
        base = Path(name)
        suffix = base.suffix.lower()
        if suffix in VIDEO_EXTS:
            # Basename only: also keeps zip members like `../x.mp4` inside videos_dir.
            with open_binary() as fsrc, (videos_dir / base.name).open("wb") as fdst:
                shutil.copyfileobj(fsrc, fdst)
        elif suffix == ".txt":
            with open_binary() as fsrc:
                # utf-8-sig drops a BOM (common in Windows-saved captions); .strip() would not.
                caption_files[base.stem] = fsrc.read().decode("utf-8-sig", "ignore").strip()

    for p in uploaded or []:
        src = Path(p)
        if src.suffix.lower() == ".zip":
            with zipfile.ZipFile(src) as zf:
                for m in zf.namelist():
                    if _is_zip_payload(m):
                        handle(m, lambda m=m: zf.open(m))
        else:
            handle(src.name, lambda src=src: src.open("rb"))

    videos = sorted(p for p in videos_dir.glob("*") if p.suffix.lower() in VIDEO_EXTS)
    return videos, caption_files


def build_dataset_items(
    videos: list[Path],
    caption_all: str,
    caption_files: dict[str, str],
    needs_reference: bool,
    trigger: str = "",
) -> tuple[list[dict], list[dict]]:
    """Build dataset.json rows from uploaded clips.

    For IC-LoRA, pair each target `X.ext` with a user-supplied `X_reference.ext` (no
    auto-derivation). Per-clip caption resolves to: the `X.txt` sidecar if present, else the
    single `caption_all` (used for every pair), else "a video". If `trigger` is set, it is
    prepended to every caption as "TRG, <caption>" (a trailing comma on the trigger is
    tolerated) for LoRA activation at inference.

    Returns (items, items) — second value kept for signature parity with the Hub path.
    Raises ValueError on missing references or no targets.
    """
    vids = [v for v in videos if v.suffix.lower() in VIDEO_EXTS]
    if needs_reference:
        targets = sorted(v for v in vids if not _is_reference(v))
        refs = {v.stem[: -len("_reference")]: v for v in vids if _is_reference(v)}
    else:
        targets, refs = sorted(vids), {}

    default = (caption_all or "").strip() or "a video"
    # "TRG," would otherwise yield "TRG,, caption".
    trigger = (trigger or "").strip().rstrip(",").rstrip()

    items, missing = [], []
    for v in targets:
        cap = caption_files.get(v.stem) or default
        if trigger:
            cap = f"{trigger}, {cap}"
        row = {"media_path": f"videos/{v.name}", "caption": cap}
        if needs_reference:
            ref = refs.get(v.stem)
            if ref is None:
                missing.append(v.name)
                continue
            row["reference_video"] = f"videos/{ref.name}"
        items.append(row)

    if needs_reference and missing:
        raise ValueError(
            "Missing reference video(s) for: " + ", ".join(missing)
            + ". For IC-LoRA, every target `X.mp4` needs a paired `X_reference.mp4`."
        )
    if not items:
        raise ValueError("No target videos found in the upload.")
    return items, items


def _load_dataset_rows(path: str | Path) -> list[dict]:
    """Parse a trainer-format dataset.json and return its rows.

    Raises:
        ValueError: if the file can't be read/parsed or is not a non-empty JSON list of objects.
    """
    try:
        # utf-8-sig: json.loads rejects a leading BOM, which some editors write.
        rows = json.loads(Path(path).read_text(encoding="utf-8-sig"))
    except (OSError, ValueError) as e:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
        raise ValueError(f"`dataset.json` could not be parsed ({e}).") from e
    if not isinstance(rows, list) or not rows or not all(isinstance(r, dict) for r in rows):
        raise ValueError("`dataset.json` is empty or not a JSON list of objects.")
    return rows


def _caption_text(row: dict) -> str:
    """Return a dataset row's caption as a stripped string ("" when absent or empty)."""
    return str(row.get("caption") or "").strip()


def _upload_names(uploaded: list[str] | None) -> list[Path]:
    """Return the distinct basenames `_stage_uploads` would see, listing `.zip` members
    without extracting them.

    Raises:
        ValueError: if an uploaded `.zip` cannot be opened as an archive.
    """
    names: list[Path] = []
    for p in uploaded or []:
        src = Path(p)
        if src.suffix.lower() != ".zip":
            names.append(Path(src.name))
            continue
        try:
            with zipfile.ZipFile(src) as zf:
                members = zf.namelist()
        except (OSError, zipfile.BadZipFile) as e:
            raise ValueError(f"`{src.name}` is not a readable zip archive ({e}).") from e
        names.extend(Path(Path(m).name) for m in members if _is_zip_payload(m))
    # Staging flattens to basenames (later duplicates overwrite), so count each name once.
    return list(dict.fromkeys(names))


def _caption_note(with_txt: int, total: int, caption_all: str) -> str:
    """Describe caption coverage the way build_dataset_items resolves it
    (sidecar `.txt` → shared caption box → "a video")."""
    note = f"{with_txt}/{total} clips have a `.txt`"
    if with_txt == total:
        return note
    if caption_all:
        return note + "; the rest use the *same-caption* box"
    return note + f"; {total - with_txt} will default to *“a video”*"


def _validate_hub_dataset(dataset_repo: str, needs_reference: bool, token: str) -> str:
    """Markdown report for a Hub dataset: reads only its dataset.json."""
    repo = (dataset_repo or "").strip()
    if not repo:
        return "⚠️ Enter a Hub dataset repo id first."
    from huggingface_hub import hf_hub_download  # noqa: PLC0415

    try:
        dj = hf_hub_download(repo, "dataset.json", repo_type="dataset", token=token or None)
    except Exception as e:  # noqa: BLE001
        return (f"❌ Couldn't read `dataset.json` from **{repo}**. The dataset must contain a "
                f"trainer-format `dataset.json` (`media_path` + `caption`[ + `reference_video`]). "
                f"_{str(e)[:100]}_")
    try:
        items = _load_dataset_rows(dj)
    except ValueError as e:
        return f"❌ {e}"
    if any(not it.get("media_path") for it in items):
        return "❌ Some rows are missing `media_path`."

    lines = [f"✅ **{len(items)} items** in `{repo}`."]
    if needs_reference:
        no_ref = [it for it in items if not it.get("reference_video")]
        if no_ref:
            return (f"❌ **IC-LoRA needs a `reference_video` per row** — {len(no_ref)}/{len(items)} "
                    f"rows are missing it. (Pick a different mode, or use a dataset with references.)")
        lines.append("✅ `reference_video` present on every row — IC-LoRA ready.")
    nocap = sum(1 for it in items if not _caption_text(it))
    if nocap:
        lines.append(f"⚠️ {nocap} rows have no caption (they'll default to *“a video”*).")
    lines.append(f"📝 Sample: `{items[0]['media_path']}` — *“{_caption_text(items[0])[:80]}”*")
    return "\n\n".join(lines)


def _validate_uploads(mode: str, uploaded: list[str], caption_all: str, needs_reference: bool) -> str:
    """Markdown report for uploaded files; inspects file/zip-member names only (no copying)."""
    try:
        files = _upload_names(uploaded)
    except ValueError as e:
        return f"❌ {e}"
    vids = [p for p in files if p.suffix.lower() in VIDEO_EXTS]
    txts = {p.stem for p in files if p.suffix.lower() == ".txt"}
    if not vids:
        return "⚠️ Upload at least one video."

    if not needs_reference:
        with_txt = sum(1 for p in vids if p.stem in txts)
        return f"✅ **{len(vids)} clips** ready for {mode}.\n\n📝 {_caption_note(with_txt, len(vids), caption_all)}."

    targets = [p for p in vids if not _is_reference(p)]
    refs = {p.stem[: -len("_reference")] for p in vids if _is_reference(p)}
    if not targets:
        return "❌ No target clips found — only `*_reference` files were uploaded."
    missing = [p.name for p in targets if p.stem not in refs]
    if missing:
        return (f"❌ **{len(missing)} target(s) have no paired `*_reference`:** "
                f"{', '.join(missing[:6])}{'…' if len(missing) > 6 else ''}")
    with_txt = sum(1 for p in targets if p.stem in txts)
    return (f"✅ **{len(targets)} IC-LoRA pairs** — every target has its reference.\n\n"
            f"📝 {_caption_note(with_txt, len(targets), caption_all)}.")


def validate_dataset(mode: str, from_hub: bool, dataset_repo: str, uploaded: list[str],
                     caption_all: str, token: str = "") -> str:
    """Pre-flight check (no job): does the dataset match the chosen mode?

    For a Hub dataset, downloads and checks its dataset.json. For uploads, inspects file
    names (including `.zip` contents) without copying. Returns a markdown report.
    """
    needs_reference = MODES[mode]["needs_reference"]
    caption_all = (caption_all or "").strip()
    if from_hub:
        return _validate_hub_dataset(dataset_repo, needs_reference, token)
    return _validate_uploads(mode, uploaded, caption_all, needs_reference)


def build_config_dict(params: dict, items: list[dict]) -> dict:
    """Build the trainer config.yaml dict.

    `items` are the dataset.json rows, used to seed the validation sample's prompt (and, for
    IC-LoRA, its reference video — prompt and reference come from the same row). Model,
    dataset and output paths are absolute on-Job paths under JOB_ROOT.

    Raises:
        ValueError: if `params["resolution"]` is invalid (see parse_resolution).
    """
    w, h, f = parse_resolution(params["resolution"])
    steps = int(params["steps"])
    mode_cfg = MODES[params["mode"]]
    needs_reference = mode_cfg["needs_reference"]

    conditions: list[dict] = []
    if needs_reference:
        conditions.append({"type": "reference", "latents_dir": "reference_latents", "probability": 1.0})
        conditions.append({"type": "first_frame", "probability": 0.2})
    elif mode_cfg.get("first_frame_prob", 0.0) > 0:
        conditions.append({"type": "first_frame", "probability": mode_cfg["first_frame_prob"]})

    ref_item = next((it for it in items if it.get("reference_video")), None) if needs_reference else None
    prompt_item = ref_item or (items[0] if items else None)
    # Hub rows may lack a caption; fall back instead of raising KeyError.
    prompt = _caption_text(prompt_item) if prompt_item else ""
    val_sample: dict = {"prompt": prompt or "a video"}
    if ref_item:
        val_sample["conditions"] = [
            {"type": "reference", "video": f"{JOB_RUN}/dataset/{ref_item['reference_video']}",
             "include_in_output": True}
        ]

    return {
        "model": {
            "model_path": JOB_MODEL,
            "text_encoder_path": JOB_GEMMA,
            "training_mode": "lora",
            "load_checkpoint": None,
        },
        "lora": {
            "rank": int(params["rank"]),
            "alpha": int(params["alpha"]),
            "dropout": 0.0,
            "target_modules": mode_cfg["target_modules"],
        },
        "training_strategy": {
            "name": "flexible",
            "video": {"is_generated": True, "latents_dir": "latents", "conditions": conditions},
        },
        "optimization": {
            "learning_rate": float(params["learning_rate"]),
            "steps": steps,
            "batch_size": int(params["batch_size"]),
            "gradient_accumulation_steps": int(params["gradient_accumulation_steps"]),
            "max_grad_norm": 1.0,
            "optimizer_type": params["optimizer_type"],
            "scheduler_type": "linear",
            "scheduler_params": {},
            "enable_gradient_checkpointing": True,
        },
        "acceleration": {
            "mixed_precision_mode": "bf16",
            "quantization": params["quantization"] or None,
            "load_text_encoder_in_8bit": bool(params["load_text_encoder_in_8bit"]),
            "offload_optimizer_during_validation": True,
        },
        "data": {"preprocessed_data_root": f"{JOB_RUN}/dataset/.precomputed", "num_dataloader_workers": 2},
        "validation": {
            "samples": [val_sample],
            "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
            "video_dims": [w, h, f],
            "frame_rate": 25.0,
            "seed": int(params["seed"]),
            # 30 inference steps (configs/*); validate every N steps, clamped to <= steps
            # so a short run still gets at least one sample.
            "inference_steps": 30,
            "interval": max(1, min(int(params.get("validation_interval", 250)), steps)),
            "guidance_scale": 4.0,
            "stg_scale": 1.0,
            "stg_blocks": [29],
            "stg_mode": "stg_v",
            "generate_audio": False,
            "skip_initial_validation": True,
        },
        # keep_last_n: -1 avoids a trainer bug — when steps is a multiple of the checkpoint
        # interval, the final step is saved twice and cleanup (keep_last_n>0) deletes the file
        # before push_to_hub. Keeping all checkpoints (tiny for LoRA) sidesteps it.
        "checkpoints": {"interval": max(steps, 1), "keep_last_n": -1, "precision": "bfloat16"},
        "flow_matching": {"timestep_sampling_mode": "shifted_logit_normal", "timestep_sampling_params": {}},
        "hub": {"push_to_hub": bool(params["push_to_hub"]), "hub_model_id": params["hub_model_id"] or None},
        "wandb": {
            "enabled": False,
            "project": "ltx-2-trainer",
            "entity": None,
            "tags": ["ltx2", "jobs"],
            "log_validation_videos": True,
        },
        "seed": int(params["seed"]),
        "output_dir": f"{JOB_RUN}/outputs",
    }


# --------------------------------------------------------------------------------------
# UV job script (runs on HF Jobs hardware)
# --------------------------------------------------------------------------------------
# Rendered with str.format: single-brace fields are filled by submit(); doubled braces
# survive as f-string placeholders inside the generated script.
JOB_SCRIPT_TEMPLATE = '''# /// script
# requires-python = ">=3.10"
# dependencies = ["huggingface_hub[hf-xet]>=1.5", "hf_transfer"]
# ///
"""Auto-generated LTX-2 training job. Reproduces the trainer env via uv sync --frozen."""
import json, os, subprocess, sys
from pathlib import Path

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
JOB = Path("{JOB_ROOT}"); SRC = JOB / "src"; RUN = JOB / "run"
MODEL = "{JOB_MODEL}"; GEMMA = "{JOB_GEMMA}"
SRC_BUCKET = "{SRC_BUCKET}"; RUN_BUCKET = "{RUN_BUCKET}"


def sh(cmd, cwd=None):
    print(">>>", " ".join(cmd), flush=True)
    subprocess.run(cmd, cwd=str(cwd) if cwd else None, check=True)


def main():
    JOB.mkdir(parents=True, exist_ok=True)
    print("=== 0/6 system libs (opencv needs libGL) ===", flush=True)
    subprocess.run("apt-get update && apt-get install -y --no-install-recommends libgl1 libglib2.0-0", shell=True, check=False)
    print("=== 1/6 sync source bucket ===", flush=True)
    sh(["hf", "buckets", "sync", f"hf://buckets/{{SRC_BUCKET}}", str(SRC)])
    print("=== 2/6 sync run bucket (dataset + config) ===", flush=True)
    sh(["hf", "buckets", "sync", f"hf://buckets/{{RUN_BUCKET}}", str(RUN)])
    jc = json.loads((RUN / "job_config.json").read_text())
    print("=== 3/6 uv sync (reproduce trainer env) ===", flush=True)
    sh(["uv", "sync", "--frozen"], cwd=SRC)
    print("=== 4/6 download base checkpoint + Gemma ===", flush=True)
    from huggingface_hub import hf_hub_download, snapshot_download
    # Derived from MODEL so the download always lands where config.yaml points.
    hf_hub_download("Lightricks/LTX-2.3", Path(MODEL).name, local_dir=str(Path(MODEL).parent))
    snapshot_download("google/gemma-3-12b-it-qat-q4_0-unquantized", local_dir=GEMMA)
    tr = SRC / "packages" / "ltx-trainer"
    def uvrun(args):
        sh(["uv", "run", "python", *args], cwd=tr)
    if jc.get("dataset_repo"):
        print("=== 4b/6 download Hub dataset ===", flush=True)
        snapshot_download(jc["dataset_repo"], repo_type="dataset", local_dir=str(RUN / "dataset"))
    ds_json = RUN / "dataset" / "dataset.json"
    print("=== 5/6 preprocess dataset ===", flush=True)
    uvrun(["scripts/process_dataset.py", str(ds_json), "--resolution-buckets", jc["resolution"], "--model-path", MODEL, "--text-encoder-path", GEMMA, "--skip-audio"])
    print("=== 6/6 train (pushes LoRA to the Hub) ===", flush=True)
    uvrun(["scripts/train.py", str(RUN / "config.yaml"), "--disable-progress-bars"])
    print("=== DONE ===", flush=True)


if __name__ == "__main__":
    main()
'''


def _run_bucket_name(run_name: str) -> str:
    """Bucket name for a run: `ltx2-train-<slug>`, slug = lowercase alnum/dash (or "run")."""
    safe = re.sub(r"[^a-zA-Z0-9-]+", "-", run_name).strip("-").lower() or "run"
    return f"ltx2-train-{safe}"


def _namespace(token: str | None) -> str:
    """Account name that owns `token` (falls back to the HF_TOKEN env var), via `whoami`."""
    from huggingface_hub import whoami  # noqa: PLC0415

    return whoami(token=token or os.environ.get("HF_TOKEN"))["name"]


def _fetch_hub_rows(dataset_repo: str, token: str, needs_reference: bool) -> list[dict]:
    """Download a Hub dataset's dataset.json and check it fits the mode (the Job fetches
    the media itself). Applies the same row rules as validate_dataset.

    Raises:
        ValueError: if dataset.json is unreachable/malformed, a row lacks `media_path`, or
            (IC-LoRA) any row lacks `reference_video`.
    """
    from huggingface_hub import hf_hub_download  # noqa: PLC0415

    try:
        dj = hf_hub_download(dataset_repo, "dataset.json", repo_type="dataset", token=token)
    except Exception as e:  # noqa: BLE001
        raise ValueError(
            f"Could not read dataset.json from dataset `{dataset_repo}`. The Hub dataset "
            f"must contain a trainer-format dataset.json (media_path + caption "
            f"[+ reference_video]). ({e})"
        ) from e
    try:
        items = _load_dataset_rows(dj)
    except ValueError as e:
        raise ValueError(f"Dataset `{dataset_repo}`: {e}") from e
    if any(not it.get("media_path") for it in items):
        raise ValueError(f"Dataset `{dataset_repo}`: some rows are missing `media_path`.")
    if needs_reference:
        missing = sum(1 for it in items if not it.get("reference_video"))
        if missing:
            raise ValueError(
                f"Dataset `{dataset_repo}`: {missing}/{len(items)} rows have no `reference_video` "
                "— required for IC-LoRA."
            )
    return items


def _stage_local_dataset(params: dict, uploaded: list[str], dataset_dir: Path,
                         needs_reference: bool) -> list[dict]:
    """Stage uploads into `dataset_dir/videos`, write `dataset_dir/dataset.json`, return its rows."""
    videos, caption_files = _stage_uploads(uploaded, dataset_dir / "videos")
    if not videos:
        raise ValueError("No valid video files in the upload.")
    items, _ = build_dataset_items(videos, params.get("caption_all", ""), caption_files,
                                   needs_reference, trigger=params.get("trigger_word", ""))
    (dataset_dir / "dataset.json").write_text(json.dumps(items, indent=2), encoding="utf-8")
    return items


def submit(params: dict, uploaded_videos: list[str], flavor: str, timeout: str) -> dict:
    """Stage data → bucket → generate UV script → submit job.

    Returns {job_id, url, bucket, log}.

    Raises:
        ValueError: no token available, invalid resolution, or a dataset that doesn't fit
            the chosen mode.
        RuntimeError: if syncing the run bucket fails.
    Other huggingface_hub errors (whoami, bucket creation, job submission) propagate.
    """
    token = (params.get("hf_token") or os.environ.get("HF_TOKEN") or "").strip()
    if not token:
        # The Job needs a real token (bucket sync + push_to_hub); an empty secret only
        # fails later, on the Job.
        raise ValueError("A Hugging Face token is required: it authenticates the submission "
                         "and is passed to the Job as the HF_TOKEN secret.")
    # Validate before any network call.
    w, h, f = parse_resolution(params["resolution"])
    ns = _namespace(token)
    bucket = f"{ns}/{_run_bucket_name(params['run_name'])}"
    needs_reference = MODES[params["mode"]]["needs_reference"]
    dataset_repo = (params.get("dataset_repo") or "").strip()

    tmp = Path(tempfile.mkdtemp(prefix="ltxrun-"))
    try:
        # Normalized so the Job's `--resolution-buckets` gets plain "WxHxF" even when the
        # user typed spaces or an uppercase "X" (both accepted by parse_resolution).
        job_cfg: dict = {"resolution": f"{w}x{h}x{f}"}
        if dataset_repo:
            # Hub dataset: only dataset.json is read here (to validate + seed validation).
            items = _fetch_hub_rows(dataset_repo, token, needs_reference)
            job_cfg["dataset_repo"] = dataset_repo
        else:
            items = _stage_local_dataset(params, uploaded_videos, tmp / "dataset", needs_reference)

        # config.yaml + job_config.json (read at the run root, sibling of dataset/ on the Job)
        cfg = build_config_dict(params, items)
        (tmp / "config.yaml").write_text(yaml.safe_dump(cfg, sort_keys=False), encoding="utf-8")
        (tmp / "job_config.json").write_text(json.dumps(job_cfg, indent=2), encoding="utf-8")

        # create + sync the per-run bucket via the huggingface_hub Python API (no `hf` CLI —
        # the CLI can crash at import on some Space base images due to dep conflicts).
        api = HfApi(token=token)
        bucket_name = bucket.split("/", 1)[1]
        api.create_bucket(bucket_name, exist_ok=True, token=token)
        try:
            api.sync_bucket(source=str(tmp), dest=f"hf://buckets/{bucket}", token=token)
        except Exception as e:  # noqa: BLE001
            raise RuntimeError(f"bucket sync failed: {e}") from e

        # render + write the job script (after the sync, so it isn't uploaded), then submit
        # it as a UV job
        script = JOB_SCRIPT_TEMPLATE.format(
            JOB_ROOT=JOB_ROOT, JOB_MODEL=JOB_MODEL, JOB_GEMMA=JOB_GEMMA,
            SRC_BUCKET=SRC_BUCKET, RUN_BUCKET=bucket,
        )
        script_path = tmp / "job_train.py"
        script_path.write_text(script, encoding="utf-8")
        job = api.run_uv_job(
            str(script_path), flavor=flavor, timeout=timeout,
            secrets={"HF_TOKEN": token}, token=token,
        )
        job_id = getattr(job, "id", "") or ""
        url = getattr(job, "url", "") or (f"https://huggingface.co/jobs/{ns}/{job_id}" if job_id else "")
        return {"job_id": job_id, "url": url, "bucket": bucket, "log": f"Submitted job {job_id} on {flavor}."}
    finally:
        shutil.rmtree(tmp, ignore_errors=True)


def job_logs(job_id: str, token: str = "") -> str:
    """Return the Job's logs joined by newlines, or a "(could not fetch logs: ...)" note."""
    # Normalize "" → None so huggingface_hub uses its default token lookup
    # (same convention as validate_dataset).
    auth = token or None
    try:
        # Joining consumes the log iterator, so it stays inside the try.
        return "\n".join(str(line) for line in HfApi(token=auth).fetch_job_logs(job_id=job_id, token=auth))
    except Exception as e:  # noqa: BLE001
        return f"(could not fetch logs: {e})"


def job_status(job_id: str, token: str = "") -> str:
    """Return the Job's status stage as a string, or "UNKNOWN (<error>)" if it can't be read.

    `status` may be an object exposing `.stage` or a plain dict; both are handled.
    """
    auth = token or None  # see job_logs
    try:
        job = HfApi(token=auth).inspect_job(job_id=job_id, token=auth)
    except Exception as e:  # noqa: BLE001
        return f"UNKNOWN ({e})"
    status = getattr(job, "status", None)
    stage = getattr(status, "stage", None)
    if stage is None and isinstance(status, dict):
        stage = status.get("stage")
    return str(stage or status or "UNKNOWN")