"""LTX-2.3 LoRA Trainer — HF Space (HF Jobs + buckets). Sign in with Hugging Face, provide a dataset (upload videos or point to a Hub dataset), set hyperparameters, and submit a training job to HF Jobs. The job reproduces the trainer env from the lockfile, trains a LoRA / IC-LoRA for LTX-2.3, and pushes it to your Hub repo. Everything runs under the *signed-in user's* account — no pasted tokens. Runs on `cpu-basic` — the Space only submits + monitors jobs (no GPU, no torch here). """ from __future__ import annotations import re import gradio as gr import jobs # Single-GPU flavors that fit the 22B model. The v2 trainer pins torch 2.9.1/cu128, which has # Blackwell (sm_120) kernels, so rtx-pro-6000 works here. Multi-GPU flavors omitted (single-GPU job). FLAVORS = ["l40sx1", "a100-large", "rtx-pro-6000", "h200"] FLAVOR_GUIDE = """**Which GPU?** You're billed per-minute of actual runtime. | Flavor | VRAM | $/h | Best for | |---|---|---|---| | `l40sx1` | 48 GB | $1.80 | cheapest — pair with the **Low-VRAM** profile (int8) | | `a100-large` | 80 GB | $2.50 | fits the 22B model in bf16 | | `rtx-pro-6000` | 96 GB | $2.75 | **recommended (default)** — Blackwell, more headroom | | `h200` | 141 GB | $5.00 | fastest — best for higher resolution / longer clips | """ MAX_LOG = 60_000 MODE_KEYS = list(jobs.MODES.keys()) # Resolution presets (label → "WxHxF"); grounded in configs/*.yaml + dataset-preparation.md. # Editable dropdown — users can also type a custom W×H×F (validated on submit: W,H÷32, F%8==1). 
RESOLUTION_PRESETS = [
    ("512×320×25 · quick smoke test", "512x320x25"),
    ("768×512×49 · balanced (default)", "768x512x49"),
    ("960×544×49 · higher quality", "960x544x49"),
    ("512×512×81 · square, longer", "512x512x81"),
    ("576×576×49 · low-VRAM", "576x576x49"),
]

# Dark theme: every colour is set for both the light and the `_dark` variant so the app looks the
# same regardless of the visitor's colour-scheme preference.
THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.orange,
    secondary_hue=gr.themes.colors.amber,
    neutral_hue=gr.themes.colors.stone,
    font=[gr.themes.GoogleFont("Space Grotesk"), "ui-sans-serif", "system-ui", "sans-serif"],
    radius_size=gr.themes.sizes.radius_lg,
).set(
    body_background_fill="#0b0b0d",
    body_background_fill_dark="#0b0b0d",
    body_text_color="#ECEAE6",
    body_text_color_dark="#ECEAE6",
    body_text_color_subdued="#b3aea4",
    body_text_color_subdued_dark="#b3aea4",
    background_fill_primary="#0b0b0d",
    background_fill_primary_dark="#0b0b0d",
    background_fill_secondary="rgba(255,255,255,0.03)",
    background_fill_secondary_dark="rgba(255,255,255,0.03)",
    block_background_fill="rgba(255,255,255,0.04)",
    block_background_fill_dark="rgba(255,255,255,0.04)",
    block_border_color="rgba(255,255,255,0.08)",
    block_border_color_dark="rgba(255,255,255,0.08)",
    block_label_text_color="#cfcabf",
    block_label_text_color_dark="#cfcabf",
    block_info_text_color="#c4bfb5",
    block_info_text_color_dark="#c4bfb5",
    input_placeholder_color="#a39e94",
    input_placeholder_color_dark="#a39e94",
    block_title_text_color="#ECEAE6",
    block_title_text_color_dark="#ECEAE6",
    border_color_primary="rgba(255,255,255,0.10)",
    border_color_primary_dark="rgba(255,255,255,0.10)",
    input_background_fill="rgba(255,255,255,0.04)",
    input_background_fill_dark="rgba(255,255,255,0.04)",
    input_border_color="rgba(255,255,255,0.10)",
    input_border_color_dark="rgba(255,255,255,0.10)",
    button_primary_background_fill="linear-gradient(92deg,#ff8a3d,#ffc04d)",
    button_primary_background_fill_dark="linear-gradient(92deg,#ff8a3d,#ffc04d)",
    button_primary_background_fill_hover="linear-gradient(92deg,#ff9b54,#ffcc66)",
    button_primary_text_color="#1c1206",
    button_primary_text_color_dark="#1c1206",
    button_secondary_background_fill="rgba(255,255,255,0.06)",
    button_secondary_background_fill_dark="rgba(255,255,255,0.06)",
    button_secondary_text_color="#ECEAE6",
    button_secondary_text_color_dark="#ECEAE6",
    color_accent_soft="rgba(255,160,60,0.14)",
)

CSS = """
.gradio-container {
  background:
    radial-gradient(1100px 480px at 50% -8%, rgba(255,150,60,0.11), transparent 60%),
    radial-gradient(820px 460px at 92% 2%, rgba(255,205,90,0.06), transparent 55%),
    #0b0b0d !important;
}
#hero h1 {margin:0; font-size:1.95rem; letter-spacing:-.01em; font-weight:700;}
#hero p {margin:.3rem 0 0; color:#9b968e;}
#banner {border-radius:12px; padding:10px 14px; font-size:.95rem;
  background:rgba(255,255,255,0.04); border:1px solid rgba(255,255,255,0.08);}
/* glass card around each step */
.tabitem {background:rgba(255,255,255,0.035) !important; border:1px solid rgba(255,255,255,0.08) !important;
  border-radius:16px !important; backdrop-filter:blur(8px); padding:18px !important;}
/* segmented pill tabs */
.tab-nav {border:none !important; gap:6px; margin-bottom:10px;}
.tab-nav button {border:1px solid rgba(255,255,255,0.08) !important; border-radius:999px !important;
  padding:6px 16px !important; background:rgba(255,255,255,0.03) !important; color:#bdb8b0 !important;}
.tab-nav button.selected {background:linear-gradient(92deg,#ff8a3d,#ffc04d) !important;
  color:#1c1206 !important; border-color:transparent !important; font-weight:600;}
button.primary {box-shadow:0 6px 22px rgba(255,140,50,0.28);}
/* readable placeholders + helper/info text on the dark inputs */
.gradio-container input::placeholder, .gradio-container textarea::placeholder {color:#b3ada3 !important; opacity:1 !important;}
.gradio-container .info, .gradio-container [data-testid="block-info"] {color:#c4bfb5 !important;}
/* flatten the upload/hub groups so they don't stack a second lighter panel over the tab card */
.flat-group {background:transparent !important; border:none !important; box-shadow:none !important; padding:0 !important;}
/* file drop zone: a single input-shade panel (kill the lighter inner button) + readable text */
#ds-upload {background:rgba(255,255,255,0.04) !important;}
#ds-upload * {background-color:transparent !important; color:#cfcabf !important;}
footer {visibility:hidden;}
"""

MODE_HELP = (
    "**IC-LoRA (in-context control)** learns from *pairs* — a target clip `clip.mp4` and its "
    "control video `clip_reference.mp4` (depth, pose, edges, an inpainting-masked version, …), "
    "matched by filename. **T2V / I2V** use only the target clips."
)

# Collapsed into a click-to-expand disclosure so it stays out of the way until needed.
# NOTE(review): the <details>/<summary> wrapper was reconstructed from the comment above — the
# literal was broken at exactly these two places; confirm against the deployed Space.
UPLOAD_HELP = (
    "<details><summary>ℹ️ Captioning tips</summary>\n\n"
    "- Drop a clip.txt next to each target clip for a per-clip caption. "
    "The “same caption for all” box is used only where a clip has no .txt.\n"
    "- LTX-2.3 likes long, detailed, chronological captions (~200 words): describe motion, "
    "camera, lighting, and any audio.\n"
    "- The trigger word is prepended to every caption so you can invoke the LoRA at inference.\n"
    "- For IC-LoRA, also upload each clip_reference.* (matched by filename).\n"
    "</details>"
)

HUB_HELP = (
    "Point to a Hub **dataset** repo containing a trainer-format `dataset.json` "
    "(`media_path` + `caption`, plus `reference_video` for IC-LoRA) and its media. "
    "The job downloads it directly."
)


def _signin_state(profile: gr.OAuthProfile | None):
    """Build the sign-in banner and the default Hub model id on page load.

    Returns a ``(banner_markdown, hub_id_update)`` pair. When nobody is signed in the Hub model
    id field is left untouched; otherwise it is pre-filled with ``<username>/ltx2-lora``.
    """
    if profile is None:
        return (
            "🔒 **You're not signed in.** Use **Sign in with Hugging Face** (top-right) to start — "
            "your dataset, the training job, and the resulting LoRA all run under **your** account "
            "and billing.",
            gr.update(),
        )
    return (
        f"✅ Signed in as **{profile.username}** — jobs, buckets and the pushed LoRA live under your account.",
        gr.update(value=f"{profile.username}/ltx2-lora"),
    )


def _toggle_source(source: str):
    """Show the upload group or the Hub-dataset group, depending on the selected source.

    Returns visibility updates for ``(upload_group, hub_group)``; exactly one is visible.
    """
    upload = source == "Upload videos"
    return gr.update(visible=upload), gr.update(visible=not upload)


def _apply_mode_defaults(mode: str, lr_dirty: bool, steps_dirty: bool):
    """On mode change, fill LR/Steps with the mode's recommended recipe.

    A field is only overwritten when the user has not edited it by hand (its dirty flag is
    False) *and* the mode actually defines a recommended value for it — a missing value leaves
    the field as it is instead of blanking it with ``None``.

    Returns updates for ``(lr, steps)``.
    """
    rec = jobs.MODES[mode].get("recommended", {})

    def _fill(key: str, dirty: bool):
        value = rec.get(key)
        if dirty or value is None:
            return gr.update()  # no-op: keep whatever the field currently shows
        return gr.update(value=value)

    return _fill("learning_rate", lr_dirty), _fill("steps", steps_dirty)


_FIRST_REC = jobs.MODES[MODE_KEYS[0]]["recommended"]  # IC-LoRA recipe for initial render

# Performance profiles bundle the memory recipe (configs/*.yaml). Quality = 80 GB bf16;
# Low-VRAM = the t2v_lora_low_vram.yaml recipe (~32 GB). Custom = manual control.
PROFILE_QUALITY = "Quality (80 GB)" PROFILE_LOWVRAM = "Low-VRAM (≤32 GB)" PROFILE_CUSTOM = "Custom" PROFILES = { PROFILE_QUALITY: {"quantization": "none", "optimizer_type": "adamw", "te_8bit": False, "rank": 32, "alpha": 32, "flavor": "rtx-pro-6000"}, PROFILE_LOWVRAM: {"quantization": "int8-quanto", "optimizer_type": "adamw8bit", "te_8bit": True, "rank": 16, "alpha": 16, "flavor": "l40sx1"}, } def _apply_profile(profile: str, rank_dirty: bool, alpha_dirty: bool, flavor_dirty: bool): """Cascade a profile to (quantization, optimizer, te_8bit, rank, alpha, manual-group visibility, flavor). rank/alpha/flavor respect dirty flags (preserve manual edits); Custom reveals manual controls and leaves flavor alone.""" if profile == PROFILE_CUSTOM: return (gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(visible=True), gr.update()) p = PROFILES[profile] return ( gr.update(value=p["quantization"]), gr.update(value=p["optimizer_type"]), gr.update(value=p["te_8bit"]), gr.update() if rank_dirty else gr.update(value=p["rank"]), gr.update() if alpha_dirty else gr.update(value=p["alpha"]), gr.update(visible=False), gr.update() if flavor_dirty else gr.update(value=p["flavor"]), ) def _suggest_profile(flavor: str, profile_dirty: bool): """Suggest a profile from the GPU flavor — unless the user has manually chosen one.""" if profile_dirty: return gr.update() return gr.update(value=PROFILE_LOWVRAM if flavor.startswith("l40s") else PROFILE_QUALITY) def submit_job( files, caption_all, trigger_word, dataset_source, dataset_repo, run_name, mode, resolution, rank, alpha, lr, steps, batch_size, grad_accum, validation_interval, quantization, optimizer_type, te_8bit, push, hub_id, flavor, timeout, profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None, ): if oauth_token is None or profile is None: return "❌ Please **sign in with Hugging Face** first (top-right).", "", "" from_hub = dataset_source == "Hugging Face dataset" if from_hub and not (dataset_repo or 
"").strip(): return "❌ Enter a Hub dataset repo id (e.g. you/my-dataset).", "", "" if not from_hub and not files: return "❌ Upload at least one video.", "", "" try: jobs.parse_resolution(resolution) except ValueError as e: return f"❌ {e}", "", "" if push and not hub_id.strip(): return "❌ Set a Hub model id (e.g. you/my-lora) or disable push.", "", "" params = { "run_name": run_name or "ltx2-lora", "mode": mode, "resolution": resolution.strip(), "rank": rank, "alpha": alpha, "learning_rate": lr, "steps": steps, "batch_size": batch_size, "gradient_accumulation_steps": grad_accum, "validation_interval": validation_interval, "quantization": None if quantization in ("none", None) else quantization, "optimizer_type": optimizer_type, "load_text_encoder_in_8bit": bool(te_8bit), "push_to_hub": bool(push), "hub_model_id": hub_id.strip(), "hf_token": oauth_token.token, "caption_all": caption_all or "", "trigger_word": trigger_word or "", "dataset_repo": dataset_repo.strip() if from_hub else "", "seed": 42, } try: res = jobs.submit(params, [] if from_hub else [f for f in files], flavor=flavor, timeout=timeout) except Exception as e: # noqa: BLE001 return f"❌ Submission failed: {e}", "", "" status = f"✅ Job submitted on **{flavor}**, running as **{profile.username}**." 
link = "" if res["url"]: link = f"**Job:** [{res['job_id']}]({res['url']}) \n**Bucket:** `{res['bucket']}`" elif res["job_id"]: link = f"**Job id:** `{res['job_id']}` \n**Bucket:** `{res['bucket']}`" return status, link, res["log"] def refresh(job_id, oauth_token: gr.OAuthToken | None): if not job_id.strip(): return "Enter a job id.", "" token = oauth_token.token if oauth_token else "" st = jobs.job_status(job_id.strip(), token) logs = jobs.job_logs(job_id.strip(), token) return f"**Status:** `{st}`", logs[-MAX_LOG:] if len(logs) > MAX_LOG else logs def do_validate(files, dataset_source, dataset_repo, mode, caption_all, oauth_token: gr.OAuthToken | None): from_hub = dataset_source == "Hugging Face dataset" token = oauth_token.token if oauth_token else "" try: return jobs.validate_dataset(mode, from_hub, dataset_repo, files, caption_all or "", token) except Exception as e: # noqa: BLE001 return f"❌ Validation error: {e}" def _extract_id(link_md: str) -> str: m = re.search(r"\[([^\]]+)\]\(http", link_md or "") return m.group(1) if m else "" with gr.Blocks(title="LTX-2.3 LoRA Trainer") as demo: # ── hero ─────────────────────────────────────────────────────────────── with gr.Row(equal_height=True): gr.Markdown( "# ✦ LTX-2.3 LoRA Trainer\n" "Train a LoRA / IC-LoRA on your own videos — runs on **HF Jobs**, pushed to your Hub.", elem_id="hero", ) gr.LoginButton(scale=0, min_width=200) banner = gr.Markdown(elem_id="banner") with gr.Tabs(): # ----------------------------------------------------------------- 1 · Dataset with gr.Tab("1 · Dataset"): mode = gr.Dropdown(MODE_KEYS, value=MODE_KEYS[0], label="Training mode") gr.Markdown(MODE_HELP) dataset_source = gr.Radio( ["Upload videos", "Hugging Face dataset"], value="Upload videos", label="Dataset source", ) with gr.Group(visible=True, elem_classes=["flat-group"]) as upload_group: files = gr.File( label="Videos + (IC-LoRA) *_reference videos + optional per-clip .txt captions, or a .zip", file_count="multiple", 
file_types=["video", ".zip", ".txt"], height=200, elem_id="ds-upload", ) gr.Markdown(UPLOAD_HELP) caption_all = gr.Textbox( label="Same caption for all clips (used only where no .txt is provided)", lines=2, placeholder="a cinematic shot of …", ) trigger_word = gr.Textbox( label="Trigger word (optional)", placeholder="e.g. TOK or my-style", info="Prepended to every caption (uploads only — Hub-dataset captions are left as-is).", ) with gr.Group(visible=False, elem_classes=["flat-group"]) as hub_group: dataset_repo = gr.Textbox(label="Hub dataset repo id", placeholder="username/my-dataset") gr.Markdown(HUB_HELP) validate_btn = gr.Button("✓ Validate dataset", variant="secondary") dataset_status = gr.Markdown("") # ----------------------------------------------------------------- 2 · Training with gr.Tab("2 · Training"): gr.Markdown("⚙️ *These defaults were auto-picked from your **mode** and **performance " "profile** — tweak anything you like.*") resolution = gr.Dropdown( choices=RESOLUTION_PRESETS, value="768x512x49", allow_custom_value=True, label="Resolution (W×H×F)", info="Pick a preset or type your own. W,H divisible by 32 · frames F satisfy F % 8 == 1.", ) perf_profile = gr.Radio( [PROFILE_QUALITY, PROFILE_LOWVRAM, PROFILE_CUSTOM], value=PROFILE_QUALITY, label="Performance profile", info="Bundles the memory recipe. Quality = 80 GB bf16 · Low-VRAM = int8 + 8-bit · Custom = manual.", ) with gr.Row(): rank = gr.Number(label="LoRA rank", value=32, precision=0, info="Recommended 32 (16 for Low-VRAM). 
Range 8–128.") alpha = gr.Number(label="LoRA alpha", value=32, precision=0, info="Keep equal to rank.") with gr.Row(): lr = gr.Number(label="Learning rate", value=_FIRST_REC["learning_rate"], info="Auto-set per mode (IC-LoRA 2e-4 · T2V/I2V 1e-4); edit to override.") steps = gr.Number(label="Steps", value=_FIRST_REC["steps"], precision=0, info="Auto-set per mode (IC-LoRA 3000 · T2V/I2V 2000); edit to override.") with gr.Accordion("Advanced", open=False): with gr.Row(): batch_size = gr.Number(label="Batch size", value=1, precision=0, info="Keep at 1 (required for multi-resolution datasets).") grad_accum = gr.Number(label="Grad accumulation", value=1, precision=0, info="Raises effective batch size without more VRAM.") validation_interval = gr.Number( label="Validate every N steps", value=250, precision=0, info="Generate a validation sample every N steps (clamped to ≤ total steps). " "Lower = more mid-run previews but longer jobs. Inference uses 30 steps.", ) # Manual memory knobs — shown only under the Custom profile. with gr.Group(visible=False) as manual_mem_group: with gr.Row(): quantization = gr.Dropdown( ["none", "int8-quanto", "fp8-quanto"], value="none", label="Quantization", info="a100-large (80 GB) fits 22B in bf16 — quantize only on smaller GPUs.", ) optimizer_type = gr.Dropdown(["adamw", "adamw8bit"], value="adamw", label="Optimizer") te_8bit = gr.Checkbox(label="Load text encoder in 8-bit", value=False) # ----------------------------------------------------------------- 3 · Launch with gr.Tab("3 · Launch"): run_name = gr.Textbox(label="Run name", value="ltx2-ic-lora") push = gr.Checkbox(label="Push the trained LoRA to my Hub", value=True) hub_id = gr.Textbox(label="Hub model id", placeholder="username/my-lora") with gr.Row(): flavor = gr.Dropdown(FLAVORS, value=jobs.DEFAULT_FLAVOR, label="GPU flavor", info="Auto-set from your profile (Quality → rtx-pro-6000 · Low-VRAM → l40sx1); " "change it freely. 
See the guide below.") timeout = gr.Textbox(label="Timeout", value="6h", info="Max job runtime (e.g. 4h). First run spends ~minutes downloading the model.") with gr.Accordion("GPU guide — which flavor fits when", open=False): gr.Markdown(FLAVOR_GUIDE) submit_btn = gr.Button("🚀 Submit training job", variant="primary", size="lg") status = gr.Markdown("") joblink = gr.Markdown("") # ----------------------------------------------------------------- 4 · Monitor with gr.Tab("4 · Monitor"): with gr.Row(): job_id = gr.Textbox(label="Job id", scale=3) refresh_btn = gr.Button("🔄 Refresh", scale=1) mon_status = gr.Markdown("") mon_logs = gr.Textbox(label="Job logs", lines=20, autoscroll=True, max_lines=20) sublog = gr.Textbox(visible=False) # dirty flags: flipped by a field's `.input` (user typing only), so programmatic auto-fills # (mode/profile/flavor cascades) never overwrite values the user has manually edited. lr_dirty = gr.State(False) steps_dirty = gr.State(False) rank_dirty = gr.State(False) alpha_dirty = gr.State(False) profile_dirty = gr.State(False) flavor_dirty = gr.State(False) demo.load(_signin_state, inputs=None, outputs=[banner, hub_id]) dataset_source.change(_toggle_source, inputs=dataset_source, outputs=[upload_group, hub_group]) validate_btn.click(do_validate, inputs=[files, dataset_source, dataset_repo, mode, caption_all], outputs=dataset_status) lr.input(lambda: True, outputs=lr_dirty) steps.input(lambda: True, outputs=steps_dirty) rank.input(lambda: True, outputs=rank_dirty) alpha.input(lambda: True, outputs=alpha_dirty) mode.change(_apply_mode_defaults, inputs=[mode, lr_dirty, steps_dirty], outputs=[lr, steps]) # #3 performance profile ↔ flavor: a user-picked profile/flavor marks it dirty; a profile change # cascades the recipe + a matching flavor (unless flavor was hand-set); a flavor change suggests # a profile (unless profile was hand-set). Dirty flags prevent the two from ping-ponging. 
perf_profile.input(lambda: True, outputs=profile_dirty) flavor.input(lambda: True, outputs=flavor_dirty) perf_profile.change(_apply_profile, inputs=[perf_profile, rank_dirty, alpha_dirty, flavor_dirty], outputs=[quantization, optimizer_type, te_8bit, rank, alpha, manual_mem_group, flavor]) flavor.change(_suggest_profile, inputs=[flavor, profile_dirty], outputs=perf_profile) submit_btn.click( submit_job, inputs=[files, caption_all, trigger_word, dataset_source, dataset_repo, run_name, mode, resolution, rank, alpha, lr, steps, batch_size, grad_accum, validation_interval, quantization, optimizer_type, te_8bit, push, hub_id, flavor, timeout], outputs=[status, joblink, sublog], ).then(_extract_id, inputs=joblink, outputs=job_id) refresh_btn.click(refresh, inputs=[job_id], outputs=[mon_status, mon_logs]) if __name__ == "__main__": demo.queue(default_concurrency_limit=2).launch(theme=THEME, css=CSS)