# Page header captured with this file (not part of the program):
# author: linoyts (HF Staff) · commit 498a04b (verified) · size: 23 kB
# commit message: "Default GPU flavor -> rtx-pro-6000 (Blackwell)"
"""LTX-2.3 LoRA Trainer — HF Space (HF Jobs + buckets).
Sign in with Hugging Face, provide a dataset (upload videos or point to a Hub dataset), set
hyperparameters, and submit a training job to HF Jobs. The job reproduces the trainer env from
the lockfile, trains a LoRA / IC-LoRA for LTX-2.3, and pushes it to your Hub repo. Everything
runs under the *signed-in user's* account — no pasted tokens.
Runs on `cpu-basic` — the Space only submits + monitors jobs (no GPU, no torch here).
"""
from __future__ import annotations
import re
import gradio as gr
import jobs
# Single-GPU flavors that fit the 22B model. The v2 trainer pins torch 2.9.1/cu128, which has
# Blackwell (sm_120) kernels, so rtx-pro-6000 works here. Multi-GPU flavors omitted (single-GPU job).
# These are the choices offered by the "GPU flavor" dropdown on the Launch tab.
FLAVORS = ["l40sx1", "a100-large", "rtx-pro-6000", "h200"]
# Markdown rendered inside the "GPU guide" accordion on the Launch tab. Keep the rows in sync
# with FLAVORS when a flavor is added or removed.
FLAVOR_GUIDE = """**Which GPU?** You're billed per-minute of actual runtime.
| Flavor | VRAM | $/h | Best for |
|---|---|---|---|
| `l40sx1` | 48 GB | $1.80 | cheapest — pair with the **Low-VRAM** profile (int8) |
| `a100-large` | 80 GB | $2.50 | fits the 22B model in bf16 |
| `rtx-pro-6000` | 96 GB | $2.75 | **recommended (default)** — Blackwell, more headroom |
| `h200` | 141 GB | $5.00 | fastest — best for higher resolution / longer clips |
"""
# Upper bound on how much of the job log the Monitor tab displays: only the trailing
# MAX_LOG items of the fetched logs are kept (see `refresh`).
MAX_LOG = 60_000
# Training-mode names in the order declared by `jobs.MODES`; the first entry is the
# initial selection of the "Training mode" dropdown.
MODE_KEYS = list(jobs.MODES.keys())
# Resolution presets (label → "WxHxF"); grounded in configs/*.yaml + dataset-preparation.md.
# Editable dropdown — users can also type a custom W×H×F (validated on submit: W,H÷32, F%8==1).
RESOLUTION_PRESETS = [
    ("512×320×25 · quick smoke test", "512x320x25"),
    ("768×512×49 · balanced (default)", "768x512x49"),
    ("960×544×49 · higher quality", "960x544x49"),
    ("512×512×81 · square, longer", "512x512x81"),
    ("576×576×49 · low-VRAM", "576x576x49"),
]
# Dark, orange-accented Gradio theme, passed to `launch()` at the bottom of the file.
# Most tokens are set twice (plain + `_dark`) with the same value, so the UI keeps the same
# dark palette whichever colour mode Gradio renders in.
THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.orange,
    secondary_hue=gr.themes.colors.amber,
    neutral_hue=gr.themes.colors.stone,
    font=[gr.themes.GoogleFont("Space Grotesk"), "ui-sans-serif", "system-ui", "sans-serif"],
    radius_size=gr.themes.sizes.radius_lg,
).set(
    # page background + body text
    body_background_fill="#0b0b0d", body_background_fill_dark="#0b0b0d",
    body_text_color="#ECEAE6", body_text_color_dark="#ECEAE6",
    body_text_color_subdued="#b3aea4", body_text_color_subdued_dark="#b3aea4",
    background_fill_primary="#0b0b0d", background_fill_primary_dark="#0b0b0d",
    background_fill_secondary="rgba(255,255,255,0.03)", background_fill_secondary_dark="rgba(255,255,255,0.03)",
    # blocks (cards), their labels and helper text
    block_background_fill="rgba(255,255,255,0.04)", block_background_fill_dark="rgba(255,255,255,0.04)",
    block_border_color="rgba(255,255,255,0.08)", block_border_color_dark="rgba(255,255,255,0.08)",
    block_label_text_color="#cfcabf", block_label_text_color_dark="#cfcabf",
    block_info_text_color="#c4bfb5", block_info_text_color_dark="#c4bfb5",
    input_placeholder_color="#a39e94", input_placeholder_color_dark="#a39e94",
    block_title_text_color="#ECEAE6", block_title_text_color_dark="#ECEAE6",
    # inputs and borders
    border_color_primary="rgba(255,255,255,0.10)", border_color_primary_dark="rgba(255,255,255,0.10)",
    input_background_fill="rgba(255,255,255,0.04)", input_background_fill_dark="rgba(255,255,255,0.04)",
    input_border_color="rgba(255,255,255,0.10)", input_border_color_dark="rgba(255,255,255,0.10)",
    # buttons: orange→amber gradient for primary, translucent white for secondary
    button_primary_background_fill="linear-gradient(92deg,#ff8a3d,#ffc04d)",
    button_primary_background_fill_dark="linear-gradient(92deg,#ff8a3d,#ffc04d)",
    button_primary_background_fill_hover="linear-gradient(92deg,#ff9b54,#ffcc66)",
    button_primary_text_color="#1c1206", button_primary_text_color_dark="#1c1206",
    button_secondary_background_fill="rgba(255,255,255,0.06)", button_secondary_background_fill_dark="rgba(255,255,255,0.06)",
    button_secondary_text_color="#ECEAE6", button_secondary_text_color_dark="#ECEAE6",
    color_accent_soft="rgba(255,160,60,0.14)",
)
# Extra stylesheet passed to `launch()` alongside THEME. The selectors `#hero`, `#banner`,
# `#ds-upload` and `.flat-group` match the `elem_id` / `elem_classes` values assigned to
# components in the Blocks layout below; rename them in both places together.
CSS = """
.gradio-container {
background:
radial-gradient(1100px 480px at 50% -8%, rgba(255,150,60,0.11), transparent 60%),
radial-gradient(820px 460px at 92% 2%, rgba(255,205,90,0.06), transparent 55%),
#0b0b0d !important;
}
#hero h1 {margin:0; font-size:1.95rem; letter-spacing:-.01em; font-weight:700;}
#hero p {margin:.3rem 0 0; color:#9b968e;}
#banner {border-radius:12px; padding:10px 14px; font-size:.95rem;
background:rgba(255,255,255,0.04); border:1px solid rgba(255,255,255,0.08);}
/* glass card around each step */
.tabitem {background:rgba(255,255,255,0.035) !important; border:1px solid rgba(255,255,255,0.08) !important;
border-radius:16px !important; backdrop-filter:blur(8px); padding:18px !important;}
/* segmented pill tabs */
.tab-nav {border:none !important; gap:6px; margin-bottom:10px;}
.tab-nav button {border:1px solid rgba(255,255,255,0.08) !important; border-radius:999px !important;
padding:6px 16px !important; background:rgba(255,255,255,0.03) !important; color:#bdb8b0 !important;}
.tab-nav button.selected {background:linear-gradient(92deg,#ff8a3d,#ffc04d) !important;
color:#1c1206 !important; border-color:transparent !important; font-weight:600;}
button.primary {box-shadow:0 6px 22px rgba(255,140,50,0.28);}
/* readable placeholders + helper/info text on the dark inputs */
.gradio-container input::placeholder, .gradio-container textarea::placeholder {color:#b3ada3 !important; opacity:1 !important;}
.gradio-container .info, .gradio-container [data-testid="block-info"] {color:#c4bfb5 !important;}
/* flatten the upload/hub groups so they don't stack a second lighter panel over the tab card */
.flat-group {background:transparent !important; border:none !important; box-shadow:none !important; padding:0 !important;}
/* file drop zone: a single input-shade panel (kill the lighter inner button) + readable text */
#ds-upload {background:rgba(255,255,255,0.04) !important;}
#ds-upload * {background-color:transparent !important; color:#cfcabf !important;}
footer {visibility:hidden;}
"""
# Markdown shown under the "Training mode" dropdown: explains which files each mode expects.
MODE_HELP = (
    "**IC-LoRA (in-context control)** learns from *pairs* — a target clip `clip.mp4` and its "
    "control video `clip_reference.mp4` (depth, pose, edges, an inpainting-masked version, …), "
    "matched by filename. **T2V / I2V** use only the target clips."
)
# Captioning tips shown under the file drop zone.
# Collapsed into a click-to-expand disclosure so it stays out of the way until needed.
UPLOAD_HELP = (
    "<details><summary>ℹ️ <b>Captioning tips</b></summary>\n\n"
    "- Drop a <code>clip.txt</code> next to each target clip for a <b>per-clip caption</b>. "
    "The “same caption for all” box is used only where a clip has no <code>.txt</code>.\n"
    "- LTX-2.3 likes <b>long, detailed, chronological</b> captions (~200 words): describe motion, "
    "camera, lighting, and any audio.\n"
    "- The <b>trigger word</b> is prepended to every caption so you can invoke the LoRA at inference.\n"
    "- For <b>IC-LoRA</b>, also upload each <code>clip_reference.*</code> (matched by filename).\n"
    "</details>"
)
# Markdown shown under the "Hub dataset repo id" box when the Hub dataset source is selected.
HUB_HELP = (
    "Point to a Hub **dataset** repo containing a trainer-format `dataset.json` "
    "(`media_path` + `caption`, plus `reference_video` for IC-LoRA) and its media. "
    "The job downloads it directly."
)
def _signin_state(profile: gr.OAuthProfile | None):
    """Build the sign-in banner and the default Hub model id for the current visitor.

    Args:
        profile: The visitor's OAuth profile, or ``None`` when nobody is signed in.

    Returns:
        A ``(banner_markdown, hub_id_update)`` pair. When signed in, the Hub model id is
        pre-filled with ``<username>/ltx2-lora``; otherwise it is left untouched.
    """
    if profile is not None:
        user = profile.username
        signed_in_md = (
            f"✅ Signed in as **{user}** — jobs, buckets and the pushed LoRA live under your account."
        )
        return signed_in_md, gr.update(value=f"{user}/ltx2-lora")

    signed_out_md = (
        "🔒 **You're not signed in.** Use **Sign in with Hugging Face** (top-right) to start — "
        "your dataset, the training job, and the resulting LoRA all run under **your** account "
        "and billing."
    )
    # An empty update leaves whatever is currently in the Hub model id box as-is.
    return signed_out_md, gr.update()
def _toggle_source(source: str):
    """Show exactly one of the two dataset panels for the selected dataset source.

    Returns visibility updates for ``(upload_group, hub_group)``: the upload panel is visible
    only for ``"Upload videos"``; any other value shows the Hub panel instead.
    """
    show_upload = source == "Upload videos"
    upload_update = gr.update(visible=show_upload)
    hub_update = gr.update(visible=not show_upload)
    return upload_update, hub_update
def _apply_mode_defaults(mode: str, lr_dirty: bool, steps_dirty: bool):
    """Push the selected mode's recommended learning rate / step count into the form.

    A field whose dirty flag is set (the user edited it by hand) receives an empty update and
    keeps its current value. Returns the updates in ``(lr, steps)`` order.
    """
    recipe = jobs.MODES[mode].get("recommended", {})

    def _fill(key: str, dirty: bool):
        # Empty update == "leave the component alone".
        if dirty:
            return gr.update()
        return gr.update(value=recipe.get(key))

    return _fill("learning_rate", lr_dirty), _fill("steps", steps_dirty)
# Recommended recipe of the first mode in MODE_KEYS; used as the initial value of the
# learning-rate and steps fields before any mode change fires.
_FIRST_REC = jobs.MODES[MODE_KEYS[0]]["recommended"]  # IC-LoRA recipe for initial render
# Performance profiles bundle the memory recipe (configs/*.yaml). Quality = 80 GB bf16;
# Low-VRAM = the t2v_lora_low_vram.yaml recipe (~32 GB). Custom = manual control.
# The three strings below are also the visible labels of the "Performance profile" radio.
PROFILE_QUALITY = "Quality (80 GB)"
PROFILE_LOWVRAM = "Low-VRAM (≤32 GB)"
PROFILE_CUSTOM = "Custom"
# Per-profile values cascaded into the form by `_apply_profile`. PROFILE_CUSTOM deliberately has
# no entry: selecting it changes no values and only reveals the manual memory controls.
PROFILES = {
    PROFILE_QUALITY: {"quantization": "none", "optimizer_type": "adamw", "te_8bit": False,
                      "rank": 32, "alpha": 32, "flavor": "rtx-pro-6000"},
    PROFILE_LOWVRAM: {"quantization": "int8-quanto", "optimizer_type": "adamw8bit", "te_8bit": True,
                      "rank": 16, "alpha": 16, "flavor": "l40sx1"},
}
def _apply_profile(profile: str, rank_dirty: bool, alpha_dirty: bool, flavor_dirty: bool):
    """Translate a performance-profile choice into component updates.

    The returned 7-tuple is ordered as
    ``(quantization, optimizer, te_8bit, rank, alpha, manual-group visibility, flavor)``.

    * ``Custom`` changes no values; it only makes the manual memory controls visible.
    * Any other profile writes its quantization / optimizer / text-encoder settings, hides the
      manual controls, and fills rank, alpha and flavor unless the matching dirty flag says the
      user has already edited that field by hand.
    """
    if profile == PROFILE_CUSTOM:
        untouched = [gr.update() for _ in range(5)]
        return (*untouched, gr.update(visible=True), gr.update())

    recipe = PROFILES[profile]

    def _unless_dirty(dirty: bool, key: str):
        # Preserve manual edits: a dirty field gets an empty (no-op) update.
        if dirty:
            return gr.update()
        return gr.update(value=recipe[key])

    quant_update = gr.update(value=recipe["quantization"])
    optim_update = gr.update(value=recipe["optimizer_type"])
    te_update = gr.update(value=recipe["te_8bit"])
    return (
        quant_update,
        optim_update,
        te_update,
        _unless_dirty(rank_dirty, "rank"),
        _unless_dirty(alpha_dirty, "alpha"),
        gr.update(visible=False),
        _unless_dirty(flavor_dirty, "flavor"),
    )
def _suggest_profile(flavor: str, profile_dirty: bool):
    """Pick the performance profile that matches a GPU flavor.

    Flavors whose name starts with ``"l40s"`` map to the Low-VRAM profile, everything else to
    Quality. If the user already chose a profile by hand (``profile_dirty``), nothing changes.
    """
    if not profile_dirty:
        suggestion = PROFILE_LOWVRAM if flavor.startswith("l40s") else PROFILE_QUALITY
        return gr.update(value=suggestion)
    return gr.update()
def submit_job(
    files, caption_all, trigger_word, dataset_source, dataset_repo, run_name, mode, resolution, rank,
    alpha, lr, steps, batch_size, grad_accum, validation_interval, quantization, optimizer_type, te_8bit,
    push, hub_id, flavor, timeout,
    profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None,
):
    """Validate the form and submit a training job under the signed-in user's account.

    The positional parameters mirror the form components wired to the submit button;
    ``profile`` and ``oauth_token`` identify the signed-in user (``None`` when signed out).

    Returns:
        ``(status_markdown, job_link_markdown, submission_log)``. On any validation or
        submission failure the status carries a ``❌`` message and the other two are empty.
    """
    if oauth_token is None or profile is None:
        return "❌ Please **sign in with Hugging Face** first (top-right).", "", ""

    # Normalise free-text inputs once so an empty/None field can never raise on `.strip()`.
    dataset_repo = (dataset_repo or "").strip()
    hub_id = (hub_id or "").strip()
    resolution = (resolution or "").strip()

    # Cheap local checks first, before anything is delegated to the `jobs` module.
    from_hub = dataset_source == "Hugging Face dataset"
    if from_hub and not dataset_repo:
        return "❌ Enter a Hub dataset repo id (e.g. you/my-dataset).", "", ""
    if not from_hub and not files:
        return "❌ Upload at least one video.", "", ""
    if push and not hub_id:
        return "❌ Set a Hub model id (e.g. you/my-lora) or disable push.", "", ""
    try:
        jobs.parse_resolution(resolution)
    except ValueError as e:
        return f"❌ {e}", "", ""

    params = {
        "run_name": run_name or "ltx2-lora", "mode": mode, "resolution": resolution,
        "rank": rank, "alpha": alpha, "learning_rate": lr, "steps": steps, "batch_size": batch_size,
        "gradient_accumulation_steps": grad_accum, "validation_interval": validation_interval,
        # The UI uses the string "none" for "no quantization"; the trainer params use None.
        "quantization": None if quantization in ("none", None) else quantization,
        "optimizer_type": optimizer_type, "load_text_encoder_in_8bit": bool(te_8bit),
        "push_to_hub": bool(push), "hub_model_id": hub_id, "hf_token": oauth_token.token,
        "caption_all": caption_all or "", "trigger_word": trigger_word or "",
        "dataset_repo": dataset_repo if from_hub else "", "seed": 42,
    }
    # A Hub dataset is fetched by the job itself, so no local files are passed along.
    uploads = [] if from_hub else list(files)
    try:
        res = jobs.submit(params, uploads, flavor=flavor, timeout=timeout)
    except Exception as e:  # noqa: BLE001 - surface any submission failure in the UI
        return f"❌ Submission failed: {e}", "", ""

    status = f"✅ Job submitted on **{flavor}**, running as **{profile.username}**."
    link = ""
    if res["url"]:
        link = f"**Job:** [{res['job_id']}]({res['url']})  \n**Bucket:** `{res['bucket']}`"
    elif res["job_id"]:
        link = f"**Job id:** `{res['job_id']}`  \n**Bucket:** `{res['bucket']}`"
    return status, link, res["log"]
def refresh(job_id, oauth_token: gr.OAuthToken | None):
    """Fetch the status and logs of a job for the Monitor tab.

    Args:
        job_id: Job id typed by the user or filled in after a submission; may be empty.
        oauth_token: OAuth token of the signed-in user, or ``None`` when signed out
            (an empty token string is then passed to the ``jobs`` helpers).

    Returns:
        ``(status_markdown, logs)``. Only the trailing ``MAX_LOG`` part of the logs is
        returned. If the lookup fails, the status carries the error and the log box is
        left unchanged.
    """
    job_id = (job_id or "").strip()
    if not job_id:
        return "Enter a job id.", ""
    token = oauth_token.token if oauth_token else ""
    try:
        st = jobs.job_status(job_id, token)
        logs = jobs.job_logs(job_id, token) or ""
    except Exception as e:  # noqa: BLE001 - report lookup failures in the UI like the other handlers
        return f"❌ Could not fetch job `{job_id}`: {e}", gr.update()
    # Slicing already returns everything when the logs are shorter than MAX_LOG.
    return f"**Status:** `{st}`", logs[-MAX_LOG:]
def do_validate(files, dataset_source, dataset_repo, mode, caption_all,
                oauth_token: gr.OAuthToken | None):
    """Run the dataset check for the "Validate dataset" button.

    Delegates to ``jobs.validate_dataset`` and returns its result; any exception raised by the
    check is turned into a ``❌ Validation error: …`` message instead of propagating.
    """
    token = "" if not oauth_token else oauth_token.token
    use_hub = dataset_source == "Hugging Face dataset"
    fallback_caption = caption_all or ""
    try:
        report = jobs.validate_dataset(mode, use_hub, dataset_repo, files, fallback_caption, token)
    except Exception as err:  # noqa: BLE001
        return f"❌ Validation error: {err}"
    return report
def _extract_id(link_md: str) -> str:
m = re.search(r"\[([^\]]+)\]\(http", link_md or "")
return m.group(1) if m else ""
# UI definition: a hero row with the login button, a sign-in banner, four tabs
# (Dataset → Training → Launch → Monitor), then the event wiring between them.
# NOTE(review): the nesting of the layout containers below was restored from a copy of the file
# that had lost its indentation — confirm the container boundaries against the deployed layout.
with gr.Blocks(title="LTX-2.3 LoRA Trainer") as demo:
    # ── hero ───────────────────────────────────────────────────────────────
    with gr.Row(equal_height=True):
        gr.Markdown(
            "# ✦ LTX-2.3 LoRA Trainer\n"
            "Train a LoRA / IC-LoRA on your own videos — runs on **HF Jobs**, pushed to your Hub.",
            elem_id="hero",
        )
        gr.LoginButton(scale=0, min_width=200)
    # Sign-in banner; its text is filled in by `_signin_state` when the page loads.
    banner = gr.Markdown(elem_id="banner")
    with gr.Tabs():
        # ----------------------------------------------------------------- 1 · Dataset
        with gr.Tab("1 · Dataset"):
            mode = gr.Dropdown(MODE_KEYS, value=MODE_KEYS[0], label="Training mode")
            gr.Markdown(MODE_HELP)
            dataset_source = gr.Radio(
                ["Upload videos", "Hugging Face dataset"], value="Upload videos", label="Dataset source",
            )
            # Exactly one of the two groups below is visible at a time (see `_toggle_source`).
            with gr.Group(visible=True, elem_classes=["flat-group"]) as upload_group:
                files = gr.File(
                    label="Videos + (IC-LoRA) *_reference videos + optional per-clip .txt captions, or a .zip",
                    file_count="multiple", file_types=["video", ".zip", ".txt"], height=200,
                    elem_id="ds-upload",
                )
                gr.Markdown(UPLOAD_HELP)
                caption_all = gr.Textbox(
                    label="Same caption for all clips (used only where no .txt is provided)",
                    lines=2, placeholder="a cinematic shot of …",
                )
                trigger_word = gr.Textbox(
                    label="Trigger word (optional)", placeholder="e.g. TOK or my-style",
                    info="Prepended to every caption (uploads only — Hub-dataset captions are left as-is).",
                )
            with gr.Group(visible=False, elem_classes=["flat-group"]) as hub_group:
                dataset_repo = gr.Textbox(label="Hub dataset repo id", placeholder="username/my-dataset")
                gr.Markdown(HUB_HELP)
            validate_btn = gr.Button("✓ Validate dataset", variant="secondary")
            dataset_status = gr.Markdown("")
        # ----------------------------------------------------------------- 2 · Training
        with gr.Tab("2 · Training"):
            gr.Markdown("⚙️ *These defaults were auto-picked from your **mode** and **performance "
                        "profile** — tweak anything you like.*")
            # Presets are (label, value) pairs; a typed custom value is validated on submit.
            resolution = gr.Dropdown(
                choices=RESOLUTION_PRESETS, value="768x512x49", allow_custom_value=True,
                label="Resolution (W×H×F)",
                info="Pick a preset or type your own. W,H divisible by 32 · frames F satisfy F % 8 == 1.",
            )
            perf_profile = gr.Radio(
                [PROFILE_QUALITY, PROFILE_LOWVRAM, PROFILE_CUSTOM], value=PROFILE_QUALITY,
                label="Performance profile",
                info="Bundles the memory recipe. Quality = 80 GB bf16 · Low-VRAM = int8 + 8-bit · Custom = manual.",
            )
            with gr.Row():
                rank = gr.Number(label="LoRA rank", value=32, precision=0,
                                 info="Recommended 32 (16 for Low-VRAM). Range 8–128.")
                alpha = gr.Number(label="LoRA alpha", value=32, precision=0,
                                  info="Keep equal to rank.")
            with gr.Row():
                # Initial values come from the first mode's recommended recipe (_FIRST_REC).
                lr = gr.Number(label="Learning rate", value=_FIRST_REC["learning_rate"],
                               info="Auto-set per mode (IC-LoRA 2e-4 · T2V/I2V 1e-4); edit to override.")
                steps = gr.Number(label="Steps", value=_FIRST_REC["steps"], precision=0,
                                  info="Auto-set per mode (IC-LoRA 3000 · T2V/I2V 2000); edit to override.")
            with gr.Accordion("Advanced", open=False):
                with gr.Row():
                    batch_size = gr.Number(label="Batch size", value=1, precision=0,
                                           info="Keep at 1 (required for multi-resolution datasets).")
                    grad_accum = gr.Number(label="Grad accumulation", value=1, precision=0,
                                           info="Raises effective batch size without more VRAM.")
                validation_interval = gr.Number(
                    label="Validate every N steps", value=250, precision=0,
                    info="Generate a validation sample every N steps (clamped to ≤ total steps). "
                         "Lower = more mid-run previews but longer jobs. Inference uses 30 steps.",
                )
                # Manual memory knobs — shown only under the Custom profile.
                with gr.Group(visible=False) as manual_mem_group:
                    with gr.Row():
                        quantization = gr.Dropdown(
                            ["none", "int8-quanto", "fp8-quanto"], value="none", label="Quantization",
                            info="a100-large (80 GB) fits 22B in bf16 — quantize only on smaller GPUs.",
                        )
                        optimizer_type = gr.Dropdown(["adamw", "adamw8bit"], value="adamw", label="Optimizer")
                    te_8bit = gr.Checkbox(label="Load text encoder in 8-bit", value=False)
        # ----------------------------------------------------------------- 3 · Launch
        with gr.Tab("3 · Launch"):
            run_name = gr.Textbox(label="Run name", value="ltx2-ic-lora")
            push = gr.Checkbox(label="Push the trained LoRA to my Hub", value=True)
            # Pre-filled with "<username>/ltx2-lora" by `_signin_state` once the user is signed in.
            hub_id = gr.Textbox(label="Hub model id", placeholder="username/my-lora")
            with gr.Row():
                flavor = gr.Dropdown(FLAVORS, value=jobs.DEFAULT_FLAVOR, label="GPU flavor",
                                     info="Auto-set from your profile (Quality → rtx-pro-6000 · Low-VRAM → l40sx1); "
                                          "change it freely. See the guide below.")
                timeout = gr.Textbox(label="Timeout", value="6h",
                                     info="Max job runtime (e.g. 4h). First run spends ~minutes downloading the model.")
            with gr.Accordion("GPU guide — which flavor fits when", open=False):
                gr.Markdown(FLAVOR_GUIDE)
            submit_btn = gr.Button("🚀 Submit training job", variant="primary", size="lg")
            status = gr.Markdown("")
            joblink = gr.Markdown("")
        # ----------------------------------------------------------------- 4 · Monitor
        with gr.Tab("4 · Monitor"):
            with gr.Row():
                job_id = gr.Textbox(label="Job id", scale=3)
                refresh_btn = gr.Button("🔄 Refresh", scale=1)
            mon_status = gr.Markdown("")
            mon_logs = gr.Textbox(label="Job logs", lines=20, autoscroll=True, max_lines=20)
    # Hidden sink for the third value returned by `submit_job` (the submission log).
    sublog = gr.Textbox(visible=False)
    # dirty flags: flipped by a field's `.input` (user typing only), so programmatic auto-fills
    # (mode/profile/flavor cascades) never overwrite values the user has manually edited.
    lr_dirty = gr.State(False)
    steps_dirty = gr.State(False)
    rank_dirty = gr.State(False)
    alpha_dirty = gr.State(False)
    profile_dirty = gr.State(False)
    flavor_dirty = gr.State(False)
    # On page load: render the sign-in banner and pre-fill the Hub model id.
    demo.load(_signin_state, inputs=None, outputs=[banner, hub_id])
    dataset_source.change(_toggle_source, inputs=dataset_source, outputs=[upload_group, hub_group])
    # `do_validate` also takes an `oauth_token` argument that is not listed in `inputs`.
    validate_btn.click(do_validate,
                       inputs=[files, dataset_source, dataset_repo, mode, caption_all],
                       outputs=dataset_status)
    # Mark a field dirty as soon as the user edits it by hand.
    lr.input(lambda: True, outputs=lr_dirty)
    steps.input(lambda: True, outputs=steps_dirty)
    rank.input(lambda: True, outputs=rank_dirty)
    alpha.input(lambda: True, outputs=alpha_dirty)
    mode.change(_apply_mode_defaults, inputs=[mode, lr_dirty, steps_dirty], outputs=[lr, steps])
    # #3 performance profile ↔ flavor: a user-picked profile/flavor marks it dirty; a profile change
    # cascades the recipe + a matching flavor (unless flavor was hand-set); a flavor change suggests
    # a profile (unless profile was hand-set). Dirty flags prevent the two from ping-ponging.
    perf_profile.input(lambda: True, outputs=profile_dirty)
    flavor.input(lambda: True, outputs=flavor_dirty)
    perf_profile.change(_apply_profile, inputs=[perf_profile, rank_dirty, alpha_dirty, flavor_dirty],
                        outputs=[quantization, optimizer_type, te_8bit, rank, alpha, manual_mem_group, flavor])
    flavor.change(_suggest_profile, inputs=[flavor, profile_dirty], outputs=perf_profile)
    # The order of `inputs` must match the positional parameters of `submit_job`; its trailing
    # `profile` / `oauth_token` parameters are not listed here. After submission, the job id is
    # copied from the job-link markdown into the Monitor tab's "Job id" box.
    submit_btn.click(
        submit_job,
        inputs=[files, caption_all, trigger_word, dataset_source, dataset_repo, run_name, mode, resolution,
                rank, alpha, lr, steps, batch_size, grad_accum, validation_interval, quantization,
                optimizer_type, te_8bit, push, hub_id, flavor, timeout],
        outputs=[status, joblink, sublog],
    ).then(_extract_id, inputs=joblink, outputs=job_id)
    refresh_btn.click(refresh, inputs=[job_id], outputs=[mon_status, mon_logs])
if __name__ == "__main__":
    # Enable the request queue (default concurrency limit of 2 per event), then start the
    # server with the custom theme and stylesheet.
    _queued_app = demo.queue(default_concurrency_limit=2)
    _queued_app.launch(theme=THEME, css=CSS)