| """LTX-2.3 LoRA Trainer — HF Space (HF Jobs + buckets). |
| |
| Sign in with Hugging Face, provide a dataset (upload videos or point to a Hub dataset), set |
| hyperparameters, and submit a training job to HF Jobs. The job reproduces the trainer env from |
| the lockfile, trains a LoRA / IC-LoRA for LTX-2.3, and pushes it to your Hub repo. Everything |
| runs under the *signed-in user's* account — no pasted tokens. |
| |
| Runs on `cpu-basic` — the Space only submits + monitors jobs (no GPU, no torch here). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import re |
|
|
| import gradio as gr |
|
|
| import jobs |
|
|
| |
| |
# GPU flavors offered in the "GPU flavor" dropdown of the Launch tab.
FLAVORS = ["l40sx1", "a100-large", "rtx-pro-6000", "h200"]
# Markdown rendered inside the "GPU guide" accordion of the Launch tab.
# NOTE(review): VRAM/price figures are static text — nothing here checks them against
# current HF Jobs pricing; keep in sync manually.
FLAVOR_GUIDE = """**Which GPU?** You're billed per-minute of actual runtime.

| Flavor | VRAM | $/h | Best for |
|---|---|---|---|
| `l40sx1` | 48 GB | $1.80 | cheapest — pair with the **Low-VRAM** profile (int8) |
| `a100-large` | 80 GB | $2.50 | fits the 22B model in bf16 |
| `rtx-pro-6000` | 96 GB | $2.75 | **recommended (default)** — Blackwell, more headroom |
| `h200` | 141 GB | $5.00 | fastest — best for higher resolution / longer clips |
"""
# Maximum number of log characters shown in the Monitor tab; `refresh` keeps only the tail.
MAX_LOG = 60_000
# Training-mode names in jobs.MODES order; the first one is the dropdown's initial value.
MODE_KEYS = list(jobs.MODES.keys())
|
|
| |
| |
# (label, value) pairs for the "Resolution (W×H×F)" dropdown. Each value is a "WxHxF" string
# that satisfies the rule stated in that dropdown's help text: W and H divisible by 32 and
# the frame count F with F % 8 == 1. Users may also type a custom value.
RESOLUTION_PRESETS = [
    ("512×320×25 · quick smoke test", "512x320x25"),
    ("768×512×49 · balanced (default)", "768x512x49"),
    ("960×544×49 · higher quality", "960x544x49"),
    ("512×512×81 · square, longer", "512x512x81"),
    ("576×576×49 · low-VRAM", "576x576x49"),
]
|
|
# Dark, warm-orange theme handed to ``launch(theme=...)`` at the bottom of the file.
# Every color is set for both the light variable and its ``_dark`` twin so the app looks the
# same whatever light/dark preference the visitor's browser (or the Hub) reports.
THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.orange,
    secondary_hue=gr.themes.colors.amber,
    neutral_hue=gr.themes.colors.stone,
    font=[gr.themes.GoogleFont("Space Grotesk"), "ui-sans-serif", "system-ui", "sans-serif"],
    radius_size=gr.themes.sizes.radius_lg,
).set(
    body_background_fill="#0b0b0d", body_background_fill_dark="#0b0b0d",
    body_text_color="#ECEAE6", body_text_color_dark="#ECEAE6",
    body_text_color_subdued="#b3aea4", body_text_color_subdued_dark="#b3aea4",
    background_fill_primary="#0b0b0d", background_fill_primary_dark="#0b0b0d",
    background_fill_secondary="rgba(255,255,255,0.03)", background_fill_secondary_dark="rgba(255,255,255,0.03)",
    block_background_fill="rgba(255,255,255,0.04)", block_background_fill_dark="rgba(255,255,255,0.04)",
    block_border_color="rgba(255,255,255,0.08)", block_border_color_dark="rgba(255,255,255,0.08)",
    block_label_text_color="#cfcabf", block_label_text_color_dark="#cfcabf",
    block_info_text_color="#c4bfb5", block_info_text_color_dark="#c4bfb5",
    input_placeholder_color="#a39e94", input_placeholder_color_dark="#a39e94",
    block_title_text_color="#ECEAE6", block_title_text_color_dark="#ECEAE6",
    border_color_primary="rgba(255,255,255,0.10)", border_color_primary_dark="rgba(255,255,255,0.10)",
    input_background_fill="rgba(255,255,255,0.04)", input_background_fill_dark="rgba(255,255,255,0.04)",
    input_border_color="rgba(255,255,255,0.10)", input_border_color_dark="rgba(255,255,255,0.10)",
    button_primary_background_fill="linear-gradient(92deg,#ff8a3d,#ffc04d)",
    button_primary_background_fill_dark="linear-gradient(92deg,#ff8a3d,#ffc04d)",
    # The _dark twin is required too; without it dark mode falls back to Base's default hover.
    button_primary_background_fill_hover="linear-gradient(92deg,#ff9b54,#ffcc66)",
    button_primary_background_fill_hover_dark="linear-gradient(92deg,#ff9b54,#ffcc66)",
    button_primary_text_color="#1c1206", button_primary_text_color_dark="#1c1206",
    button_secondary_background_fill="rgba(255,255,255,0.06)", button_secondary_background_fill_dark="rgba(255,255,255,0.06)",
    button_secondary_text_color="#ECEAE6", button_secondary_text_color_dark="#ECEAE6",
    color_accent_soft="rgba(255,160,60,0.14)", color_accent_soft_dark="rgba(255,160,60,0.14)",
)
|
|
# Custom stylesheet passed to ``launch(css=...)``. It targets the elem_id / elem_classes used
# in the layout below: #hero (title Markdown), #banner (sign-in banner), .flat-group
# (upload/hub groups), #ds-upload (file drop zone), plus Gradio's own tab and button classes.
CSS = """
.gradio-container {
background:
radial-gradient(1100px 480px at 50% -8%, rgba(255,150,60,0.11), transparent 60%),
radial-gradient(820px 460px at 92% 2%, rgba(255,205,90,0.06), transparent 55%),
#0b0b0d !important;
}
#hero h1 {margin:0; font-size:1.95rem; letter-spacing:-.01em; font-weight:700;}
#hero p {margin:.3rem 0 0; color:#9b968e;}
#banner {border-radius:12px; padding:10px 14px; font-size:.95rem;
background:rgba(255,255,255,0.04); border:1px solid rgba(255,255,255,0.08);}
/* glass card around each step */
.tabitem {background:rgba(255,255,255,0.035) !important; border:1px solid rgba(255,255,255,0.08) !important;
border-radius:16px !important; backdrop-filter:blur(8px); padding:18px !important;}
/* segmented pill tabs */
.tab-nav {border:none !important; gap:6px; margin-bottom:10px;}
.tab-nav button {border:1px solid rgba(255,255,255,0.08) !important; border-radius:999px !important;
padding:6px 16px !important; background:rgba(255,255,255,0.03) !important; color:#bdb8b0 !important;}
.tab-nav button.selected {background:linear-gradient(92deg,#ff8a3d,#ffc04d) !important;
color:#1c1206 !important; border-color:transparent !important; font-weight:600;}
button.primary {box-shadow:0 6px 22px rgba(255,140,50,0.28);}
/* readable placeholders + helper/info text on the dark inputs */
.gradio-container input::placeholder, .gradio-container textarea::placeholder {color:#b3ada3 !important; opacity:1 !important;}
.gradio-container .info, .gradio-container [data-testid="block-info"] {color:#c4bfb5 !important;}
/* flatten the upload/hub groups so they don't stack a second lighter panel over the tab card */
.flat-group {background:transparent !important; border:none !important; box-shadow:none !important; padding:0 !important;}
/* file drop zone: a single input-shade panel (kill the lighter inner button) + readable text */
#ds-upload {background:rgba(255,255,255,0.04) !important;}
#ds-upload * {background-color:transparent !important; color:#cfcabf !important;}
footer {visibility:hidden;}
"""
|
|
# Markdown shown under the "Training mode" dropdown: explains the clip / clip_reference
# pairing that IC-LoRA expects versus target-only T2V / I2V.
MODE_HELP = (
    "**IC-LoRA (in-context control)** learns from *pairs* — a target clip `clip.mp4` and its "
    "control video `clip_reference.mp4` (depth, pose, edges, an inpainting-masked version, …), "
    "matched by filename. **T2V / I2V** use only the target clips."
)

# Collapsible (<details>) captioning tips rendered with the upload controls.
UPLOAD_HELP = (
    "<details><summary>ℹ️ <b>Captioning tips</b></summary>\n\n"
    "- Drop a <code>clip.txt</code> next to each target clip for a <b>per-clip caption</b>. "
    "The “same caption for all” box is used only where a clip has no <code>.txt</code>.\n"
    "- LTX-2.3 likes <b>long, detailed, chronological</b> captions (~200 words): describe motion, "
    "camera, lighting, and any audio.\n"
    "- The <b>trigger word</b> is prepended to every caption so you can invoke the LoRA at inference.\n"
    "- For <b>IC-LoRA</b>, also upload each <code>clip_reference.*</code> (matched by filename).\n"
    "</details>"
)
# Markdown shown in the Hub-dataset group: the expected dataset.json layout of the repo.
HUB_HELP = (
    "Point to a Hub **dataset** repo containing a trainer-format `dataset.json` "
    "(`media_path` + `caption`, plus `reference_video` for IC-LoRA) and its media. "
    "The job downloads it directly."
)
|
|
|
|
def _signin_state(profile: gr.OAuthProfile | None):
    """Return the sign-in banner Markdown and an update for the Hub model id field.

    ``profile`` is injected by Gradio from the OAuth session (``None`` when signed out).
    Signed in: greet the user and prefill the Hub id with ``<username>/ltx2-lora``.
    Signed out: show a sign-in prompt and leave the Hub id field untouched.
    """
    if profile is not None:
        user = profile.username
        greeting = f"✅ Signed in as **{user}** — jobs, buckets and the pushed LoRA live under your account."
        return greeting, gr.update(value=f"{user}/ltx2-lora")

    prompt = (
        "🔒 **You're not signed in.** Use **Sign in with Hugging Face** (top-right) to start — "
        "your dataset, the training job, and the resulting LoRA all run under **your** account "
        "and billing."
    )
    return prompt, gr.update()
|
|
|
|
def _toggle_source(source: str):
    """Show exactly one of the two dataset-source groups.

    Returns visibility updates for ``(upload_group, hub_group)``: the upload group is shown
    only for the "Upload videos" choice, the Hub group for anything else.
    """
    show_upload = source == "Upload videos"
    show_hub = not show_upload
    return gr.update(visible=show_upload), gr.update(visible=show_hub)
|
|
|
|
def _apply_mode_defaults(mode: str, lr_dirty: bool, steps_dirty: bool):
    """On mode change, fill Learning rate / Steps with the mode's recommended recipe.

    A field is overwritten only when
      * the user has not edited it by hand (its dirty flag is False), and
      * the mode's recipe actually provides a value for it.
    A missing mode, recipe or key leaves the field as-is instead of blanking it (which would
    otherwise later submit ``None`` as the learning rate / step count).

    Returns:
        Updates for ``(lr, steps)``.
    """
    rec = jobs.MODES.get(mode, {}).get("recommended") or {}

    def _fill(dirty: bool, key: str):
        value = rec.get(key)
        # Preserve manual edits, and never replace a value with "nothing".
        if dirty or value is None:
            return gr.update()
        return gr.update(value=value)

    return _fill(lr_dirty, "learning_rate"), _fill(steps_dirty, "steps")
|
|
|
|
# Recommended recipe of the default (first) training mode, used to seed the initial
# Learning rate / Steps values in the Training tab. It is indexed directly, so the first mode
# in jobs.MODES must define "recommended" (with "learning_rate" and "steps", which the layout
# reads at import time) or the module fails to load.
_FIRST_REC = jobs.MODES[MODE_KEYS[0]]["recommended"]
|
|
| |
| |
# Labels of the "Performance profile" radio in the Training tab.
PROFILE_QUALITY = "Quality (80 GB)"
PROFILE_LOWVRAM = "Low-VRAM (≤32 GB)"
PROFILE_CUSTOM = "Custom"
# Memory recipe that _apply_profile pushes into the form for each non-custom profile:
# quantization, optimizer, 8-bit text encoder, LoRA rank/alpha, and the suggested GPU flavor.
# PROFILE_CUSTOM deliberately has no entry — it reveals the manual memory controls instead.
PROFILES = {
    PROFILE_QUALITY: {"quantization": "none", "optimizer_type": "adamw", "te_8bit": False,
                      "rank": 32, "alpha": 32, "flavor": "rtx-pro-6000"},
    PROFILE_LOWVRAM: {"quantization": "int8-quanto", "optimizer_type": "adamw8bit", "te_8bit": True,
                      "rank": 16, "alpha": 16, "flavor": "l40sx1"},
}
|
|
|
|
def _apply_profile(profile: str, rank_dirty: bool, alpha_dirty: bool, flavor_dirty: bool):
    """Propagate a performance profile to the controls that depend on it.

    Returns updates for ``(quantization, optimizer_type, te_8bit, rank, alpha,
    manual_mem_group, flavor)``.

    * ``Custom``: every value is left untouched and the manual memory group is shown.
    * Named profile: its recipe from ``PROFILES`` is written and the manual group hidden;
      rank, alpha and flavor are skipped when their dirty flag says the user set them by hand.
    """
    keep = gr.update

    if profile == PROFILE_CUSTOM:
        untouched = [keep() for _ in range(5)]
        return (*untouched, gr.update(visible=True), keep())

    recipe = PROFILES[profile]

    def unless_dirty(dirty: bool, key: str):
        return keep() if dirty else gr.update(value=recipe[key])

    quant_update = gr.update(value=recipe["quantization"])
    optim_update = gr.update(value=recipe["optimizer_type"])
    te_update = gr.update(value=recipe["te_8bit"])
    return (
        quant_update,
        optim_update,
        te_update,
        unless_dirty(rank_dirty, "rank"),
        unless_dirty(alpha_dirty, "alpha"),
        gr.update(visible=False),
        unless_dirty(flavor_dirty, "flavor"),
    )
|
|
|
|
def _suggest_profile(flavor: str, profile_dirty: bool):
    """Pick a performance profile matching the chosen GPU flavor.

    L40S flavors map to the Low-VRAM profile, every other flavor to Quality. A profile the
    user chose by hand (``profile_dirty``) is never overridden.
    """
    if not profile_dirty:
        suggested = PROFILE_LOWVRAM if flavor.startswith("l40s") else PROFILE_QUALITY
        return gr.update(value=suggested)
    return gr.update()
|
|
|
|
def submit_job(
    files, caption_all, trigger_word, dataset_source, dataset_repo, run_name, mode, resolution, rank,
    alpha, lr, steps, batch_size, grad_accum, validation_interval, quantization, optimizer_type, te_8bit,
    push, hub_id, flavor, timeout,
    profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None,
):
    """Validate the form and submit a training job to HF Jobs as the signed-in user.

    The positional arguments are the current values of the form components (see the
    ``submit_btn.click`` wiring). ``profile`` and ``oauth_token`` are injected by Gradio from
    the OAuth session and are ``None`` when the visitor is signed out.

    Returns:
        ``(status_markdown, job_link_markdown, submission_log)``. Every validation or
        submission failure returns a ``❌`` status and two empty strings.
    """
    if oauth_token is None or profile is None:
        return "❌ Please **sign in with Hugging Face** first (top-right).", "", ""

    from_hub = dataset_source == "Hugging Face dataset"
    # Text inputs may arrive as None (e.g. a cleared custom dropdown value). Normalize them once
    # so validation and the submitted params see exactly the same stripped strings.
    repo_id = (dataset_repo or "").strip()
    hub_model_id = (hub_id or "").strip()
    resolution = (resolution or "").strip()

    if from_hub and not repo_id:
        return "❌ Enter a Hub dataset repo id (e.g. you/my-dataset).", "", ""
    if not from_hub and not files:
        return "❌ Upload at least one video.", "", ""
    try:
        jobs.parse_resolution(resolution)
    except ValueError as e:
        return f"❌ {e}", "", ""
    if push and not hub_model_id:
        return "❌ Set a Hub model id (e.g. you/my-lora) or disable push.", "", ""

    params = {
        # A blank or whitespace-only run name falls back to the default.
        "run_name": (run_name or "").strip() or "ltx2-lora", "mode": mode, "resolution": resolution,
        "rank": rank, "alpha": alpha, "learning_rate": lr, "steps": steps, "batch_size": batch_size,
        "gradient_accumulation_steps": grad_accum, "validation_interval": validation_interval,
        # The UI's "none" choice means "no quantization" for the backend.
        "quantization": None if quantization in ("none", None) else quantization,
        "optimizer_type": optimizer_type, "load_text_encoder_in_8bit": bool(te_8bit),
        "push_to_hub": bool(push), "hub_model_id": hub_model_id, "hf_token": oauth_token.token,
        "caption_all": caption_all or "", "trigger_word": trigger_word or "",
        "dataset_repo": repo_id if from_hub else "", "seed": 42,
    }
    try:
        uploads = [] if from_hub else list(files)
        res = jobs.submit(params, uploads, flavor=flavor, timeout=timeout)
    except Exception as e:  # UI boundary: report any backend/API failure instead of crashing
        return f"❌ Submission failed: {e}", "", ""

    msg = f"✅ Job submitted on **{flavor}**, running as **{profile.username}**."
    # Two trailing spaces before "\n" are a Markdown hard line break (job on one line, bucket
    # on the next). _extract_id parses these two link formats to prefill the Monitor tab.
    link = ""
    if res["url"]:
        link = f"**Job:** [{res['job_id']}]({res['url']})  \n**Bucket:** `{res['bucket']}`"
    elif res["job_id"]:
        link = f"**Job id:** `{res['job_id']}`  \n**Bucket:** `{res['bucket']}`"
    return msg, link, res["log"]
|
|
|
|
def refresh(job_id, oauth_token: gr.OAuthToken | None):
    """Fetch the status and the tail of the logs for ``job_id`` (Monitor tab).

    Uses the signed-in user's OAuth token when available (empty string otherwise). Logs are
    cut to their last ``MAX_LOG`` characters so the textbox stays responsive.

    Returns:
        ``(status_markdown, logs_text)``. Backend errors are reported in the status text
        (``❌ …``), consistent with ``submit_job`` and ``do_validate``.
    """
    jid = (job_id or "").strip()
    if not jid:
        return "Enter a job id.", ""
    token = oauth_token.token if oauth_token else ""
    try:
        st = jobs.job_status(jid, token)
        logs = jobs.job_logs(jid, token) or ""
    except Exception as e:  # UI boundary: unknown id, auth or network failures
        return f"❌ Could not fetch job `{jid}`: {e}", ""
    # A negative slice already returns the whole string when it is shorter than MAX_LOG.
    return f"**Status:** `{st}`", logs[-MAX_LOG:]
|
|
|
|
def do_validate(files, dataset_source, dataset_repo, mode, caption_all,
                oauth_token: gr.OAuthToken | None):
    """Check the dataset with ``jobs.validate_dataset`` and return its Markdown report.

    The OAuth token (empty string when signed out) is forwarded so private Hub datasets can be
    read. Any exception raised during validation is turned into a ``❌`` message instead of
    propagating to Gradio.
    """
    use_hub = dataset_source == "Hugging Face dataset"
    token = oauth_token.token if oauth_token else ""
    try:
        report = jobs.validate_dataset(mode, use_hub, dataset_repo, files, caption_all or "", token)
    except Exception as e:
        return f"❌ Validation error: {e}"
    return report
|
|
|
|
| def _extract_id(link_md: str) -> str: |
| m = re.search(r"\[([^\]]+)\]\(http", link_md or "") |
| return m.group(1) if m else "" |
|
|
|
|
# UI layout: a header row, a sign-in banner, then four step tabs
# (Dataset → Training → Launch → Monitor), followed by per-session state and event wiring.
# Component creation order inside each `with` defines the on-screen layout.
with gr.Blocks(title="LTX-2.3 LoRA Trainer") as demo:

    # Header: title/subtitle (styled via #hero in CSS) next to the OAuth login button.
    with gr.Row(equal_height=True):
        gr.Markdown(
            "# ✦ LTX-2.3 LoRA Trainer\n"
            "Train a LoRA / IC-LoRA on your own videos — runs on **HF Jobs**, pushed to your Hub.",
            elem_id="hero",
        )
        gr.LoginButton(scale=0, min_width=200)

    # Filled on page load by _signin_state (signed-in greeting or sign-in prompt).
    banner = gr.Markdown(elem_id="banner")

    with gr.Tabs():

        # Step 1: training mode and dataset source (uploaded files or a Hub dataset repo).
        with gr.Tab("1 · Dataset"):
            mode = gr.Dropdown(MODE_KEYS, value=MODE_KEYS[0], label="Training mode")
            gr.Markdown(MODE_HELP)
            dataset_source = gr.Radio(
                ["Upload videos", "Hugging Face dataset"], value="Upload videos", label="Dataset source",
            )
            # Only one of these two groups is visible at a time (see _toggle_source).
            with gr.Group(visible=True, elem_classes=["flat-group"]) as upload_group:
                files = gr.File(
                    label="Videos + (IC-LoRA) *_reference videos + optional per-clip .txt captions, or a .zip",
                    file_count="multiple", file_types=["video", ".zip", ".txt"], height=200,
                    elem_id="ds-upload",
                )
                gr.Markdown(UPLOAD_HELP)
                caption_all = gr.Textbox(
                    label="Same caption for all clips (used only where no .txt is provided)",
                    lines=2, placeholder="a cinematic shot of …",
                )
                trigger_word = gr.Textbox(
                    label="Trigger word (optional)", placeholder="e.g. TOK or my-style",
                    info="Prepended to every caption (uploads only — Hub-dataset captions are left as-is).",
                )
            with gr.Group(visible=False, elem_classes=["flat-group"]) as hub_group:
                dataset_repo = gr.Textbox(label="Hub dataset repo id", placeholder="username/my-dataset")
                gr.Markdown(HUB_HELP)
            validate_btn = gr.Button("✓ Validate dataset", variant="secondary")
            dataset_status = gr.Markdown("")

        # Step 2: resolution, performance profile and optimization hyperparameters.
        with gr.Tab("2 · Training"):
            gr.Markdown("⚙️ *These defaults were auto-picked from your **mode** and **performance "
                        "profile** — tweak anything you like.*")
            resolution = gr.Dropdown(
                choices=RESOLUTION_PRESETS, value="768x512x49", allow_custom_value=True,
                label="Resolution (W×H×F)",
                info="Pick a preset or type your own. W,H divisible by 32 · frames F satisfy F % 8 == 1.",
            )
            perf_profile = gr.Radio(
                [PROFILE_QUALITY, PROFILE_LOWVRAM, PROFILE_CUSTOM], value=PROFILE_QUALITY,
                label="Performance profile",
                info="Bundles the memory recipe. Quality = 80 GB bf16 · Low-VRAM = int8 + 8-bit · Custom = manual.",
            )
            # Initial rank/alpha match PROFILES[PROFILE_QUALITY], the radio's initial value.
            with gr.Row():
                rank = gr.Number(label="LoRA rank", value=32, precision=0,
                                 info="Recommended 32 (16 for Low-VRAM). Range 8–128.")
                alpha = gr.Number(label="LoRA alpha", value=32, precision=0,
                                  info="Keep equal to rank.")
            # NOTE(review): the info texts hard-code per-mode values; keep them in sync with
            # the "recommended" recipes in jobs.MODES.
            with gr.Row():
                lr = gr.Number(label="Learning rate", value=_FIRST_REC["learning_rate"],
                               info="Auto-set per mode (IC-LoRA 2e-4 · T2V/I2V 1e-4); edit to override.")
                steps = gr.Number(label="Steps", value=_FIRST_REC["steps"], precision=0,
                                  info="Auto-set per mode (IC-LoRA 3000 · T2V/I2V 2000); edit to override.")
            with gr.Accordion("Advanced", open=False):
                with gr.Row():
                    batch_size = gr.Number(label="Batch size", value=1, precision=0,
                                           info="Keep at 1 (required for multi-resolution datasets).")
                    grad_accum = gr.Number(label="Grad accumulation", value=1, precision=0,
                                           info="Raises effective batch size without more VRAM.")
                validation_interval = gr.Number(
                    label="Validate every N steps", value=250, precision=0,
                    info="Generate a validation sample every N steps (clamped to ≤ total steps). "
                         "Lower = more mid-run previews but longer jobs. Inference uses 30 steps.",
                )

            # Manual memory controls: hidden unless the "Custom" profile is selected; the named
            # profiles set these values programmatically via _apply_profile.
            with gr.Group(visible=False) as manual_mem_group:
                with gr.Row():
                    quantization = gr.Dropdown(
                        ["none", "int8-quanto", "fp8-quanto"], value="none", label="Quantization",
                        info="a100-large (80 GB) fits 22B in bf16 — quantize only on smaller GPUs.",
                    )
                    optimizer_type = gr.Dropdown(["adamw", "adamw8bit"], value="adamw", label="Optimizer")
                    te_8bit = gr.Checkbox(label="Load text encoder in 8-bit", value=False)

        # Step 3: output repo, GPU flavor, timeout and the submit button.
        with gr.Tab("3 · Launch"):
            run_name = gr.Textbox(label="Run name", value="ltx2-ic-lora")
            push = gr.Checkbox(label="Push the trained LoRA to my Hub", value=True)
            # Prefilled with "<username>/ltx2-lora" on load when signed in.
            hub_id = gr.Textbox(label="Hub model id", placeholder="username/my-lora")
            with gr.Row():
                flavor = gr.Dropdown(FLAVORS, value=jobs.DEFAULT_FLAVOR, label="GPU flavor",
                                     info="Auto-set from your profile (Quality → rtx-pro-6000 · Low-VRAM → l40sx1); "
                                          "change it freely. See the guide below.")
                timeout = gr.Textbox(label="Timeout", value="6h",
                                     info="Max job runtime (e.g. 4h). First run spends ~minutes downloading the model.")
            with gr.Accordion("GPU guide — which flavor fits when", open=False):
                gr.Markdown(FLAVOR_GUIDE)
            submit_btn = gr.Button("🚀 Submit training job", variant="primary", size="lg")
            status = gr.Markdown("")
            joblink = gr.Markdown("")

        # Step 4: poll a job's status and logs by id (prefilled after a submit).
        with gr.Tab("4 · Monitor"):
            with gr.Row():
                job_id = gr.Textbox(label="Job id", scale=3)
                refresh_btn = gr.Button("🔄 Refresh", scale=1)
            mon_status = gr.Markdown("")
            mon_logs = gr.Textbox(label="Job logs", lines=20, autoscroll=True, max_lines=20)

    # Hidden sink for the submission log returned by submit_job.
    sublog = gr.Textbox(visible=False)

    # Per-session "dirty" flags: set to True only by a user's own edit (`.input` events), and
    # read by _apply_mode_defaults / _apply_profile / _suggest_profile so programmatic
    # auto-fill never overwrites a value the user chose by hand.
    lr_dirty = gr.State(False)
    steps_dirty = gr.State(False)
    rank_dirty = gr.State(False)
    alpha_dirty = gr.State(False)
    profile_dirty = gr.State(False)
    flavor_dirty = gr.State(False)

    # Event wiring.
    demo.load(_signin_state, inputs=None, outputs=[banner, hub_id])
    dataset_source.change(_toggle_source, inputs=dataset_source, outputs=[upload_group, hub_group])
    validate_btn.click(do_validate,
                       inputs=[files, dataset_source, dataset_repo, mode, caption_all],
                       outputs=dataset_status)
    lr.input(lambda: True, outputs=lr_dirty)
    steps.input(lambda: True, outputs=steps_dirty)
    rank.input(lambda: True, outputs=rank_dirty)
    alpha.input(lambda: True, outputs=alpha_dirty)
    mode.change(_apply_mode_defaults, inputs=[mode, lr_dirty, steps_dirty], outputs=[lr, steps])

    # Profile <-> flavor are linked in both directions: a profile change may set the flavor,
    # and a flavor change may suggest a profile. The dirty flags stop either side from
    # overriding a manual choice.
    perf_profile.input(lambda: True, outputs=profile_dirty)
    flavor.input(lambda: True, outputs=flavor_dirty)
    perf_profile.change(_apply_profile, inputs=[perf_profile, rank_dirty, alpha_dirty, flavor_dirty],
                        outputs=[quantization, optimizer_type, te_8bit, rank, alpha, manual_mem_group, flavor])
    flavor.change(_suggest_profile, inputs=[flavor, profile_dirty], outputs=perf_profile)

    # After every submit (including failed ones, where the link is empty), copy the job id
    # parsed from the link Markdown into the Monitor tab.
    submit_btn.click(
        submit_job,
        inputs=[files, caption_all, trigger_word, dataset_source, dataset_repo, run_name, mode, resolution,
                rank, alpha, lr, steps, batch_size, grad_accum, validation_interval, quantization,
                optimizer_type, te_8bit, push, hub_id, flavor, timeout],
        outputs=[status, joblink, sublog],
    ).then(_extract_id, inputs=joblink, outputs=job_id)
    refresh_btn.click(refresh, inputs=[job_id], outputs=[mon_status, mon_logs])
|
|
|
|
if __name__ == "__main__":
    # Enable the request queue; event listeners without an explicit concurrency_limit may
    # run up to two events at once. Theme and CSS are supplied at launch time.
    app = demo.queue(default_concurrency_limit=2)
    app.launch(theme=THEME, css=CSS)
|
|