OAuth login (no pasted tokens) + ai-toolkit-style UI + generic IC-LoRA (paired references, drop Canny)
Browse files- README.md +29 -13
- app.py +132 -61
- jobs.py +51 -22
- requirements.txt +1 -1
README.md
CHANGED
|
@@ -8,6 +8,12 @@ sdk_version: 6.18.0
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
hardware: cpu-basic
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
# LTX-2 LoRA Trainer (HF Jobs)
|
|
@@ -15,9 +21,11 @@ hardware: cpu-basic
|
|
| 15 |
Train a **LoRA / IC-LoRA for LTX-2.3** from your own videos, entirely on Hugging Face
|
| 16 |
infrastructure:
|
| 17 |
|
|
|
|
|
|
|
| 18 |
- the **Space** (this app, `cpu-basic`) collects your videos + hyperparameters and submits a job;
|
| 19 |
-
- training runs on **HF Jobs**
|
| 20 |
-
|
| 21 |
- your dataset is staged on **HF buckets** (`hf://buckets/<you>/ltx2-train-<run>`);
|
| 22 |
- the trained LoRA is pushed to the Hub model repo you choose.
|
| 23 |
|
|
@@ -25,21 +33,29 @@ You only pay for the Job's actual GPU runtime.
|
|
| 25 |
|
| 26 |
## Setup
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
| 34 |
|
| 35 |
## Usage
|
| 36 |
|
| 37 |
-
1.
|
| 38 |
-
2. Pick a mode
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
The job: sync buckets → `uv sync` → download base checkpoint + Gemma →
|
| 42 |
-
|
| 43 |
|
| 44 |
## Notes
|
| 45 |
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
hardware: cpu-basic
|
| 11 |
+
hf_oauth: true
|
| 12 |
+
hf_oauth_scopes:
|
| 13 |
+
- read-repos
|
| 14 |
+
- write-repos
|
| 15 |
+
- manage-repos
|
| 16 |
+
- jobs
|
| 17 |
---
|
| 18 |
|
| 19 |
# LTX-2 LoRA Trainer (HF Jobs)
|
|
|
|
| 21 |
Train a **LoRA / IC-LoRA for LTX-2.3** from your own videos, entirely on Hugging Face
|
| 22 |
infrastructure:
|
| 23 |
|
| 24 |
+
- **Sign in with Hugging Face** — the dataset, the job, and the pushed LoRA all run under
|
| 25 |
+
**your** account and billing (no pasted tokens);
|
| 26 |
- the **Space** (this app, `cpu-basic`) collects your videos + hyperparameters and submits a job;
|
| 27 |
+
- training runs on **HF Jobs**, reproducing the trainer environment from the monorepo
|
| 28 |
+
lockfile (`uv sync --frozen`);
|
| 29 |
- your dataset is staged on **HF buckets** (`hf://buckets/<you>/ltx2-train-<run>`);
|
| 30 |
- the trained LoRA is pushed to the Hub model repo you choose.
|
| 31 |
|
|
|
|
| 33 |
|
| 34 |
## Setup
|
| 35 |
|
| 36 |
+
Just **sign in with Hugging Face** in the app — the OAuth token (scopes: repos + `jobs`) is
|
| 37 |
+
used to create buckets, submit the job, download the gated Gemma encoder, and push the LoRA,
|
| 38 |
+
all as the signed-in user. No Space secret is required.
|
| 39 |
+
|
| 40 |
+
The job pulls the trainer source from the bucket **`LTX_SRC_BUCKET`**
|
| 41 |
+
(default `linoyts/ltx2-trainer-src`). To use your own, upload the monorepo
|
| 42 |
+
(`pyproject.toml`, `uv.lock`, `packages/`) to a bucket and set the `LTX_SRC_BUCKET`
|
| 43 |
+
Space variable.
|
| 44 |
|
| 45 |
## Usage
|
| 46 |
|
| 47 |
+
1. Sign in with Hugging Face.
|
| 48 |
+
2. Pick a mode and upload your videos:
|
| 49 |
+
- **IC-LoRA (in-context control):** upload each target `clip.mp4` **and** its control video
|
| 50 |
+
`clip_reference.mp4` (depth, pose, edges, an inpainting-masked version, …) — matched by
|
| 51 |
+
filename. References are user-supplied; nothing is auto-derived.
|
| 52 |
+
- **T2V / I2V:** upload only the target clips.
|
| 53 |
+
3. Add one caption per line (in filename order of the target clips), set hyperparameters and a
|
| 54 |
+
Hub model id, then **Submit training job**.
|
| 55 |
+
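4. After submission the job id is filled into the monitor box automatically (or paste any job id yourself); click **Refresh** to fetch the job's current status and latest logs.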
|
| 56 |
|
| 57 |
+
The job: sync buckets → `uv sync --frozen` → download base checkpoint + Gemma → `process_dataset.py`
|
| 58 |
+
→ `train.py` → push LoRA to the Hub.
|
| 59 |
|
| 60 |
## Notes
|
| 61 |
|
app.py
CHANGED
|
@@ -1,41 +1,66 @@
|
|
| 1 |
"""LTX-2 LoRA Trainer — HF Space (HF Jobs + buckets).
|
| 2 |
|
| 3 |
-
|
| 4 |
-
on
|
| 5 |
-
LTX-2.3, and pushes it to your Hub repo. Datasets are
|
|
|
|
| 6 |
|
| 7 |
Runs on `cpu-basic` — the Space only submits + monitors jobs (no GPU, no torch here).
|
| 8 |
"""
|
| 9 |
|
| 10 |
from __future__ import annotations
|
| 11 |
|
| 12 |
-
import
|
| 13 |
|
| 14 |
import gradio as gr
|
| 15 |
|
| 16 |
import jobs
|
| 17 |
|
| 18 |
-
HEADER = """# 🎬 LTX-2 LoRA Trainer (HF Jobs)
|
| 19 |
-
Train a LoRA / IC-LoRA for **LTX-2.3** from your own videos — training runs on **HF Jobs**
|
| 20 |
-
(a100-large), data is staged on **HF buckets**, and the trained LoRA is pushed to your Hub repo.
|
| 21 |
-
|
| 22 |
-
You only pay for the Job's actual runtime. Set `HF_TOKEN` as a Space secret (or paste a token below).
|
| 23 |
-
"""
|
| 24 |
-
|
| 25 |
FLAVORS = ["a100-large", "a100x4", "l40sx1", "l40sx4"]
|
| 26 |
MAX_LOG = 60_000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
def submit_job(
|
| 36 |
files, captions_text, run_name, mode, resolution, rank, alpha, lr, steps, batch_size,
|
| 37 |
-
grad_accum, quantization, optimizer_type, te_8bit,
|
|
|
|
| 38 |
):
|
|
|
|
|
|
|
| 39 |
if not files:
|
| 40 |
return "❌ Upload at least one video.", "", ""
|
| 41 |
try:
|
|
@@ -43,7 +68,7 @@ def submit_job(
|
|
| 43 |
except ValueError as e:
|
| 44 |
return f"❌ {e}", "", ""
|
| 45 |
if push and not hub_id.strip():
|
| 46 |
-
return "❌ Set a Hub model id (
|
| 47 |
|
| 48 |
params = {
|
| 49 |
"run_name": run_name or "ltx2-lora", "mode": mode, "resolution": resolution.strip(),
|
|
@@ -51,8 +76,8 @@ def submit_job(
|
|
| 51 |
"gradient_accumulation_steps": grad_accum,
|
| 52 |
"quantization": None if quantization in ("none", None) else quantization,
|
| 53 |
"optimizer_type": optimizer_type, "load_text_encoder_in_8bit": bool(te_8bit),
|
| 54 |
-
"
|
| 55 |
-
"
|
| 56 |
"captions": captions_text.splitlines() if captions_text else [], "seed": 42,
|
| 57 |
}
|
| 58 |
try:
|
|
@@ -60,7 +85,7 @@ def submit_job(
|
|
| 60 |
except Exception as e: # noqa: BLE001
|
| 61 |
return f"❌ Submission failed: {e}", "", ""
|
| 62 |
|
| 63 |
-
status = f"✅ Job submitted on **{flavor}**."
|
| 64 |
link = ""
|
| 65 |
if res["url"]:
|
| 66 |
link = f"**Job:** [{res['job_id']}]({res['url']}) \n**Bucket:** `{res['bucket']}`"
|
|
@@ -69,67 +94,113 @@ def submit_job(
|
|
| 69 |
return status, link, res["log"]
|
| 70 |
|
| 71 |
|
| 72 |
-
def refresh(job_id,
|
| 73 |
if not job_id.strip():
|
| 74 |
return "Enter a job id.", ""
|
| 75 |
-
token =
|
| 76 |
st = jobs.job_status(job_id.strip(), token)
|
| 77 |
logs = jobs.job_logs(job_id.strip(), token)
|
| 78 |
-
return f"**Status:** {st}", logs[-MAX_LOG:] if len(logs) > MAX_LOG else logs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
-
with gr.Blocks(title="LTX-2 LoRA Trainer (HF Jobs)") as demo:
|
| 82 |
-
gr.Markdown(HEADER)
|
| 83 |
with gr.Row():
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
with gr.Row():
|
| 93 |
rank = gr.Number(label="LoRA rank", value=32, precision=0)
|
| 94 |
alpha = gr.Number(label="LoRA alpha", value=32, precision=0)
|
| 95 |
with gr.Row():
|
| 96 |
lr = gr.Number(label="Learning rate", value=2e-4)
|
| 97 |
steps = gr.Number(label="Steps", value=2000, precision=0)
|
| 98 |
-
with gr.
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
with gr.Row():
|
| 111 |
flavor = gr.Dropdown(FLAVORS, value=jobs.DEFAULT_FLAVOR, label="GPU flavor")
|
| 112 |
timeout = gr.Textbox(label="Timeout", value="4h")
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
submit_btn.click(
|
| 126 |
submit_job,
|
| 127 |
inputs=[files, captions, run_name, mode, resolution, rank, alpha, lr, steps, batch_size,
|
| 128 |
-
grad_accum, quantization, optimizer_type, te_8bit,
|
| 129 |
outputs=[status, joblink, sublog],
|
| 130 |
-
).then(
|
| 131 |
-
refresh_btn.click(refresh, inputs=[job_id
|
| 132 |
|
| 133 |
|
| 134 |
if __name__ == "__main__":
|
| 135 |
-
demo.queue(default_concurrency_limit=2).launch()
|
|
|
|
| 1 |
"""LTX-2 LoRA Trainer — HF Space (HF Jobs + buckets).
|
| 2 |
|
| 3 |
+
Sign in with Hugging Face, upload videos + captions, set hyperparameters, and submit a
|
| 4 |
+
training job to HF Jobs. The job runs on a GPU flavor, reproduces the trainer env from the
|
| 5 |
+
lockfile, trains a LoRA / IC-LoRA for LTX-2.3, and pushes it to your Hub repo. Datasets are
|
| 6 |
+
staged on HF buckets. Everything runs under the *signed-in user's* account — no pasted tokens.
|
| 7 |
|
| 8 |
Runs on `cpu-basic` — the Space only submits + monitors jobs (no GPU, no torch here).
|
| 9 |
"""
|
| 10 |
|
| 11 |
from __future__ import annotations
|
| 12 |
|
| 13 |
+
import re
|
| 14 |
|
| 15 |
import gradio as gr
|
| 16 |
|
| 17 |
import jobs
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
FLAVORS = ["a100-large", "a100x4", "l40sx1", "l40sx4"]
|
| 20 |
MAX_LOG = 60_000
|
| 21 |
+
MODE_KEYS = list(jobs.MODES.keys())
|
| 22 |
+
|
| 23 |
+
CSS = """
|
| 24 |
+
#hdr {text-align:left; padding:4px 2px 0 2px;}
|
| 25 |
+
#hdr h1 {margin:0; font-size:1.7rem;}
|
| 26 |
+
#hdr p {margin:.25rem 0 0 0; color:var(--body-text-color-subdued);}
|
| 27 |
+
.section-card {border:1px solid var(--block-border-color); border-radius:14px;
|
| 28 |
+
padding:14px 16px; background:var(--block-background-fill);}
|
| 29 |
+
.step-badge {font-weight:600; color:var(--primary-500);}
|
| 30 |
+
#banner {border-radius:12px; padding:10px 14px; font-size:.95rem;}
|
| 31 |
+
footer {visibility:hidden;}
|
| 32 |
+
"""
|
| 33 |
|
| 34 |
+
THEME = gr.themes.Soft(
|
| 35 |
+
primary_hue=gr.themes.colors.indigo,
|
| 36 |
+
secondary_hue=gr.themes.colors.purple,
|
| 37 |
+
radius_size=gr.themes.sizes.radius_lg,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _signin_state(profile: gr.OAuthProfile | None):
    """Return the sign-in banner text and an update for the Hub model id box.

    The result is a ``(banner_markdown, hub_id_update)`` pair:

    * not signed in (``profile is None``): a prompt to sign in, plus an empty
      ``gr.update()`` so the Hub model id field is left as it is;
    * signed in: a confirmation naming the user, plus an update that sets the
      Hub model id to ``<username>/ltx2-lora``.
    """
    if profile is not None:
        user = profile.username
        signed_in_banner = (
            f"✅ Signed in as **{user}** — jobs, buckets and the pushed LoRA will live under "
            f"your account."
        )
        default_repo = gr.update(value=f"{user}/ltx2-lora")
        return signed_in_banner, default_repo

    signed_out_banner = (
        "🔒 **You're not signed in.** Use **Sign in with Hugging Face** (top-right) to start — "
        "your dataset, the training job, and the resulting LoRA all run under **your** account "
        "and billing. No tokens to paste."
    )
    # An argument-less update leaves the current Hub model id untouched.
    return signed_out_banner, gr.update()
|
| 55 |
|
| 56 |
|
| 57 |
def submit_job(
|
| 58 |
files, captions_text, run_name, mode, resolution, rank, alpha, lr, steps, batch_size,
|
| 59 |
+
grad_accum, quantization, optimizer_type, te_8bit, push, hub_id, flavor, timeout,
|
| 60 |
+
profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None,
|
| 61 |
):
|
| 62 |
+
if oauth_token is None or profile is None:
|
| 63 |
+
return "❌ Please **sign in with Hugging Face** first (top-right).", "", ""
|
| 64 |
if not files:
|
| 65 |
return "❌ Upload at least one video.", "", ""
|
| 66 |
try:
|
|
|
|
| 68 |
except ValueError as e:
|
| 69 |
return f"❌ {e}", "", ""
|
| 70 |
if push and not hub_id.strip():
|
| 71 |
+
return "❌ Set a Hub model id (e.g. you/my-lora) or disable push.", "", ""
|
| 72 |
|
| 73 |
params = {
|
| 74 |
"run_name": run_name or "ltx2-lora", "mode": mode, "resolution": resolution.strip(),
|
|
|
|
| 76 |
"gradient_accumulation_steps": grad_accum,
|
| 77 |
"quantization": None if quantization in ("none", None) else quantization,
|
| 78 |
"optimizer_type": optimizer_type, "load_text_encoder_in_8bit": bool(te_8bit),
|
| 79 |
+
"push_to_hub": bool(push), "hub_model_id": hub_id.strip(),
|
| 80 |
+
"hf_token": oauth_token.token,
|
| 81 |
"captions": captions_text.splitlines() if captions_text else [], "seed": 42,
|
| 82 |
}
|
| 83 |
try:
|
|
|
|
| 85 |
except Exception as e: # noqa: BLE001
|
| 86 |
return f"❌ Submission failed: {e}", "", ""
|
| 87 |
|
| 88 |
+
status = f"✅ Job submitted on **{flavor}**, running as **{profile.username}**."
|
| 89 |
link = ""
|
| 90 |
if res["url"]:
|
| 91 |
link = f"**Job:** [{res['job_id']}]({res['url']}) \n**Bucket:** `{res['bucket']}`"
|
|
|
|
| 94 |
return status, link, res["log"]
|
| 95 |
|
| 96 |
|
| 97 |
+
def refresh(job_id, oauth_token: gr.OAuthToken | None):
    """Fetch the current status and the tail of the logs for a job.

    Args:
        job_id: Job identifier typed into (or auto-filled in) the monitor box.
            ``None`` and whitespace-only values are treated as "no id given".
        oauth_token: OAuth token of the signed-in user, or ``None`` when nobody
            is signed in.

    Returns:
        A ``(status_markdown, logs)`` pair. ``logs`` is limited to the last
        ``MAX_LOG`` characters so the log box stays responsive.
    """
    # Normalise once; a None value must not crash the handler.
    job = (job_id or "").strip()
    if not job:
        return "Enter a job id.", ""

    # An empty token is forwarded when nobody is signed in.
    # NOTE(review): how `jobs` treats an empty token is not visible here — confirm.
    token = oauth_token.token if oauth_token else ""
    st = jobs.job_status(job, token)
    logs = jobs.job_logs(job, token)
    # A negative-start slice already returns everything when the log is short.
    return f"**Status:** `{st}`", logs[-MAX_LOG:]
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _extract_id(link_md: str) -> str:
|
| 107 |
+
m = re.search(r"\[([^\]]+)\]\(http", link_md or "")
|
| 108 |
+
return m.group(1) if m else ""
|
| 109 |
+
|
| 110 |
|
| 111 |
+
MODE_HELP = (
|
| 112 |
+
"**IC-LoRA (in-context control)** learns from *pairs*: a target video `clip.mp4` and its "
|
| 113 |
+
"control/reference video `clip_reference.mp4` (depth, pose, edges, an inpainting-masked "
|
| 114 |
+
"version, …). Upload both — they're matched by filename. **T2V / I2V** need only the target clips."
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
with gr.Blocks(title="LTX-2 LoRA Trainer") as demo:
|
| 118 |
+
with gr.Row(equal_height=True):
|
| 119 |
+
gr.Markdown(
|
| 120 |
+
"# 🎬 LTX-2 LoRA Trainer\n"
|
| 121 |
+
"Train a **LoRA / IC-LoRA for LTX-2.3** on your own videos. Training runs on "
|
| 122 |
+
"**HF Jobs**, data is staged on **HF buckets**, and the LoRA is pushed to your Hub — "
|
| 123 |
+
"you only pay for the GPU runtime.",
|
| 124 |
+
elem_id="hdr",
|
| 125 |
+
)
|
| 126 |
+
gr.LoginButton(scale=0, min_width=200)
|
| 127 |
+
|
| 128 |
+
banner = gr.Markdown(elem_id="banner")
|
| 129 |
|
|
|
|
|
|
|
| 130 |
with gr.Row():
|
| 131 |
+
# ---------------------------------------------------------------- left: data + training
|
| 132 |
+
with gr.Column(scale=3):
|
| 133 |
+
with gr.Group():
|
| 134 |
+
gr.Markdown("<span class='step-badge'>STEP 1</span> **Dataset**")
|
| 135 |
+
mode = gr.Dropdown(MODE_KEYS, value=MODE_KEYS[0], label="Training mode")
|
| 136 |
+
gr.Markdown(MODE_HELP)
|
| 137 |
+
files = gr.File(
|
| 138 |
+
label="Videos — targets + (for IC-LoRA) their *_reference videos, or a .zip",
|
| 139 |
+
file_count="multiple", file_types=["video", ".zip"], height=160,
|
| 140 |
+
)
|
| 141 |
+
captions = gr.Textbox(
|
| 142 |
+
label="Captions — one per line, in filename order of the target clips",
|
| 143 |
+
lines=4, placeholder="a fluffy cat in a sunlit room\na sweeping green landscape\n…",
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
with gr.Group():
|
| 147 |
+
gr.Markdown("<span class='step-badge'>STEP 2</span> **Training settings**")
|
| 148 |
+
resolution = gr.Textbox(
|
| 149 |
+
label="Resolution (W×H×F)", value="768x512x49",
|
| 150 |
+
info="W,H divisible by 32 · frames F satisfy F % 8 == 1 (25, 49, 81, …)",
|
| 151 |
+
)
|
| 152 |
with gr.Row():
|
| 153 |
rank = gr.Number(label="LoRA rank", value=32, precision=0)
|
| 154 |
alpha = gr.Number(label="LoRA alpha", value=32, precision=0)
|
| 155 |
with gr.Row():
|
| 156 |
lr = gr.Number(label="Learning rate", value=2e-4)
|
| 157 |
steps = gr.Number(label="Steps", value=2000, precision=0)
|
| 158 |
+
with gr.Accordion("Advanced", open=False):
|
| 159 |
+
with gr.Row():
|
| 160 |
+
batch_size = gr.Number(label="Batch size", value=1, precision=0)
|
| 161 |
+
grad_accum = gr.Number(label="Grad accumulation", value=1, precision=0)
|
| 162 |
+
with gr.Row():
|
| 163 |
+
quantization = gr.Dropdown(
|
| 164 |
+
["none", "int8-quanto", "fp8-quanto"], value="none", label="Quantization",
|
| 165 |
+
info="a100-large (80 GB) fits 22B in bf16 — quantize only on smaller GPUs.",
|
| 166 |
+
)
|
| 167 |
+
optimizer_type = gr.Dropdown(["adamw", "adamw8bit"], value="adamw", label="Optimizer")
|
| 168 |
+
te_8bit = gr.Checkbox(label="Load text encoder in 8-bit", value=False)
|
| 169 |
+
|
| 170 |
+
# ---------------------------------------------------------------- right: output + submit + monitor
|
| 171 |
+
with gr.Column(scale=2):
|
| 172 |
+
with gr.Group():
|
| 173 |
+
gr.Markdown("<span class='step-badge'>STEP 3</span> **Output & launch**")
|
| 174 |
+
run_name = gr.Textbox(label="Run name", value="ltx2-ic-lora")
|
| 175 |
+
push = gr.Checkbox(label="Push the trained LoRA to my Hub", value=True)
|
| 176 |
+
hub_id = gr.Textbox(label="Hub model id", placeholder="username/my-lora")
|
| 177 |
with gr.Row():
|
| 178 |
flavor = gr.Dropdown(FLAVORS, value=jobs.DEFAULT_FLAVOR, label="GPU flavor")
|
| 179 |
timeout = gr.Textbox(label="Timeout", value="4h")
|
| 180 |
+
submit_btn = gr.Button("🚀 Submit training job", variant="primary", size="lg")
|
| 181 |
+
status = gr.Markdown("")
|
| 182 |
+
joblink = gr.Markdown("")
|
| 183 |
+
|
| 184 |
+
with gr.Group():
|
| 185 |
+
gr.Markdown("**Monitor a job**")
|
| 186 |
+
with gr.Row():
|
| 187 |
+
job_id = gr.Textbox(label="Job id", scale=3)
|
| 188 |
+
refresh_btn = gr.Button("🔄 Refresh", scale=1)
|
| 189 |
+
mon_status = gr.Markdown("")
|
| 190 |
+
mon_logs = gr.Textbox(label="Job logs", lines=16, autoscroll=True, max_lines=16)
|
| 191 |
+
|
| 192 |
+
sublog = gr.Textbox(label="Submission output", lines=4, visible=False)
|
| 193 |
+
|
| 194 |
+
demo.load(_signin_state, inputs=None, outputs=[banner, hub_id])
|
| 195 |
|
| 196 |
submit_btn.click(
|
| 197 |
submit_job,
|
| 198 |
inputs=[files, captions, run_name, mode, resolution, rank, alpha, lr, steps, batch_size,
|
| 199 |
+
grad_accum, quantization, optimizer_type, te_8bit, push, hub_id, flavor, timeout],
|
| 200 |
outputs=[status, joblink, sublog],
|
| 201 |
+
).then(_extract_id, inputs=joblink, outputs=job_id)
|
| 202 |
+
refresh_btn.click(refresh, inputs=[job_id], outputs=[mon_status, mon_logs])
|
| 203 |
|
| 204 |
|
| 205 |
if __name__ == "__main__":
|
| 206 |
+
demo.queue(default_concurrency_limit=2).launch(theme=THEME, css=CSS)
|
jobs.py
CHANGED
|
@@ -8,8 +8,9 @@ Per training request:
|
|
| 8 |
|
| 9 |
On the Job, the script syncs the source bucket + run bucket, runs `uv sync --frozen`
|
| 10 |
(reproducing the working trainer env from the lockfile), downloads the base checkpoint
|
| 11 |
-
and Gemma, then runs
|
| 12 |
-
|
|
|
|
| 13 |
|
| 14 |
The Space itself only needs gradio + huggingface_hub + pyyaml (no torch).
|
| 15 |
"""
|
|
@@ -38,7 +39,7 @@ JOB_GEMMA = f"{JOB_ROOT}/gemma"
|
|
| 38 |
JOB_RUN = f"{JOB_ROOT}/run"
|
| 39 |
|
| 40 |
MODES = {
|
| 41 |
-
"IC-LoRA (
|
| 42 |
"needs_reference": True,
|
| 43 |
"target_modules": [
|
| 44 |
"attn1.to_k", "attn1.to_q", "attn1.to_v", "attn1.to_out.0",
|
|
@@ -80,6 +81,47 @@ def _collect_videos(uploaded: list[str], dest: Path) -> list[Path]:
|
|
| 80 |
return sorted(p for p in dest.glob("*") if p.suffix.lower() in VIDEO_EXTS)
|
| 81 |
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
def build_config_dict(params: dict, videos: list[Path]) -> dict:
|
| 84 |
w, h, f = parse_resolution(params["resolution"])
|
| 85 |
mode_cfg = MODES[params["mode"]]
|
|
@@ -182,14 +224,6 @@ def main():
|
|
| 182 |
sh(["uv", "run", "python", *args], cwd=tr)
|
| 183 |
|
| 184 |
ds_json = RUN / "dataset" / "dataset.json"
|
| 185 |
-
if jc.get("needs_reference") and jc.get("generate_canny_reference"):
|
| 186 |
-
print("=== 5a/6 generate Canny reference videos ===", flush=True)
|
| 187 |
-
uvrun(["scripts/compute_reference.py", str(RUN / "dataset"), "--output", str(ds_json)])
|
| 188 |
-
items = json.loads(ds_json.read_text())
|
| 189 |
-
for it in items:
|
| 190 |
-
if "reference_path" in it:
|
| 191 |
-
it["reference_video"] = it.pop("reference_path")
|
| 192 |
-
ds_json.write_text(json.dumps(items, indent=2))
|
| 193 |
|
| 194 |
print("=== 5/6 preprocess dataset ===", flush=True)
|
| 195 |
uvrun(["scripts/process_dataset.py", str(ds_json),
|
|
@@ -227,20 +261,15 @@ def submit(params: dict, uploaded_videos: list[str], flavor: str, timeout: str)
|
|
| 227 |
videos = _collect_videos(uploaded_videos, tmp / "dataset" / "videos")
|
| 228 |
if not videos:
|
| 229 |
raise ValueError("No valid video files in the upload.")
|
| 230 |
-
# dataset.json
|
| 231 |
-
|
| 232 |
-
items =
|
| 233 |
-
"caption": caps[i] if i < len(caps) and caps[i].strip() else "a video"}
|
| 234 |
-
for i, v in enumerate(videos)]
|
| 235 |
(tmp / "dataset" / "dataset.json").write_text(json.dumps(items, indent=2))
|
| 236 |
-
# config.yaml + job_config.json
|
| 237 |
-
cfg = build_config_dict(params,
|
| 238 |
(tmp / "config.yaml").write_text(yaml.safe_dump(cfg, sort_keys=False))
|
| 239 |
-
job_cfg = {"resolution": params["resolution"],
|
| 240 |
-
"needs_reference": MODES[params["mode"]]["needs_reference"],
|
| 241 |
-
"generate_canny_reference": bool(params.get("generate_canny_reference", True))}
|
| 242 |
# job script reads job_config.json + config.yaml at the run root (sibling of dataset/)
|
| 243 |
-
(tmp / "job_config.json").write_text(json.dumps(
|
| 244 |
|
| 245 |
env = os.environ.copy()
|
| 246 |
if token:
|
|
|
|
| 8 |
|
| 9 |
On the Job, the script syncs the source bucket + run bucket, runs `uv sync --frozen`
|
| 10 |
(reproducing the working trainer env from the lockfile), downloads the base checkpoint
|
| 11 |
+
and Gemma, then runs process_dataset.py → train.py, which pushes the trained LoRA to the
|
| 12 |
+
Hub. For IC-LoRA, references are user-supplied (paired `*_reference` videos) — no
|
| 13 |
+
auto-derivation.
|
| 14 |
|
| 15 |
The Space itself only needs gradio + huggingface_hub + pyyaml (no torch).
|
| 16 |
"""
|
|
|
|
| 39 |
JOB_RUN = f"{JOB_ROOT}/run"
|
| 40 |
|
| 41 |
MODES = {
|
| 42 |
+
"IC-LoRA (in-context control)": {
|
| 43 |
"needs_reference": True,
|
| 44 |
"target_modules": [
|
| 45 |
"attn1.to_k", "attn1.to_q", "attn1.to_v", "attn1.to_out.0",
|
|
|
|
| 81 |
return sorted(p for p in dest.glob("*") if p.suffix.lower() in VIDEO_EXTS)
|
| 82 |
|
| 83 |
|
| 84 |
+
def _is_reference(p: Path) -> bool:
|
| 85 |
+
return p.stem.endswith("_reference")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def build_dataset_items(
    videos: list[Path], captions: list[str], needs_reference: bool
) -> tuple[list[dict], list[Path]]:
    """Build the ``dataset.json`` rows for a run.

    For modes that need a reference (IC-LoRA), every target ``X.ext`` is paired
    with a user-supplied ``X_reference.ext`` matched by file stem; nothing is
    auto-derived. Captions are aligned by position to the sorted target clips.

    Args:
        videos: Candidate files; anything whose extension is not in the
            module-level ``VIDEO_EXTS`` is ignored.
        captions: One caption per target, in sorted-target order. ``None``,
            missing, or blank entries fall back to ``"a video"``. Surrounding
            whitespace is stripped from each caption.
        needs_reference: Whether each target must have a paired reference clip.

    Returns:
        ``(items, targets)`` where ``items`` are the dataset rows
        (``media_path``, ``caption`` and, when ``needs_reference``,
        ``reference_video``) and ``targets`` is the sorted list of target clips.

    Raises:
        ValueError: If a target has no paired reference (when
            ``needs_reference``) or if no target videos were found.
    """
    captions = captions or []  # tolerate None from callers
    vids = [v for v in videos if v.suffix.lower() in VIDEO_EXTS]
    if needs_reference:
        targets = sorted(v for v in vids if not _is_reference(v))
        # Key each reference by the stem of the target it belongs to.
        refs = {v.stem[: -len("_reference")]: v for v in vids if _is_reference(v)}
    else:
        targets, refs = sorted(vids), {}

    items, missing = [], []
    for i, v in enumerate(targets):
        # Strip so the stored caption matches what the blank check looked at.
        cap = captions[i].strip() if i < len(captions) else ""
        row = {"media_path": f"videos/{v.name}", "caption": cap or "a video"}
        if needs_reference:
            ref = refs.get(v.stem)
            if ref is None:
                # Collect every unpaired target so one error lists them all.
                missing.append(v.name)
                continue
            row["reference_video"] = f"videos/{ref.name}"
        items.append(row)

    if needs_reference and missing:
        raise ValueError(
            "Missing reference video(s) for: " + ", ".join(missing)
            + ". For IC-LoRA, every target `X.mp4` needs a paired `X_reference.mp4`."
        )
    if not items:
        raise ValueError("No target videos found in the upload.")
    return items, targets
|
| 123 |
+
|
| 124 |
+
|
| 125 |
def build_config_dict(params: dict, videos: list[Path]) -> dict:
|
| 126 |
w, h, f = parse_resolution(params["resolution"])
|
| 127 |
mode_cfg = MODES[params["mode"]]
|
|
|
|
| 224 |
sh(["uv", "run", "python", *args], cwd=tr)
|
| 225 |
|
| 226 |
ds_json = RUN / "dataset" / "dataset.json"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
|
| 228 |
print("=== 5/6 preprocess dataset ===", flush=True)
|
| 229 |
uvrun(["scripts/process_dataset.py", str(ds_json),
|
|
|
|
| 261 |
videos = _collect_videos(uploaded_videos, tmp / "dataset" / "videos")
|
| 262 |
if not videos:
|
| 263 |
raise ValueError("No valid video files in the upload.")
|
| 264 |
+
# dataset.json — pairs targets with user-supplied references for IC-LoRA
|
| 265 |
+
needs_reference = MODES[params["mode"]]["needs_reference"]
|
| 266 |
+
items, targets = build_dataset_items(videos, params.get("captions", []), needs_reference)
|
|
|
|
|
|
|
| 267 |
(tmp / "dataset" / "dataset.json").write_text(json.dumps(items, indent=2))
|
| 268 |
+
# config.yaml + job_config.json (validation sample references the first target's reference)
|
| 269 |
+
cfg = build_config_dict(params, targets)
|
| 270 |
(tmp / "config.yaml").write_text(yaml.safe_dump(cfg, sort_keys=False))
|
|
|
|
|
|
|
|
|
|
| 271 |
# job script reads job_config.json + config.yaml at the run root (sibling of dataset/)
|
| 272 |
+
(tmp / "job_config.json").write_text(json.dumps({"resolution": params["resolution"]}, indent=2))
|
| 273 |
|
| 274 |
env = os.environ.copy()
|
| 275 |
if token:
|
requirements.txt
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
gradio>=6.18
|
| 2 |
huggingface_hub[hf-xet]>=1.5
|
| 3 |
pyyaml
|
| 4 |
hf_transfer
|
|
|
|
| 1 |
+
gradio[oauth]>=6.18
|
| 2 |
huggingface_hub[hf-xet]>=1.5
|
| 3 |
pyyaml
|
| 4 |
hf_transfer
|