linoyts HF Staff commited on
Commit
b325d81
·
verified ·
1 Parent(s): dc0ad54

OAuth login (no pasted tokens) + ai-toolkit-style UI + generic IC-LoRA (paired references, drop Canny)

Browse files
Files changed (4) hide show
  1. README.md +29 -13
  2. app.py +132 -61
  3. jobs.py +51 -22
  4. requirements.txt +1 -1
README.md CHANGED
@@ -8,6 +8,12 @@ sdk_version: 6.18.0
8
  app_file: app.py
9
  pinned: false
10
  hardware: cpu-basic
 
 
 
 
 
 
11
  ---
12
 
13
  # LTX-2 LoRA Trainer (HF Jobs)
@@ -15,9 +21,11 @@ hardware: cpu-basic
15
  Train a **LoRA / IC-LoRA for LTX-2.3** from your own videos, entirely on Hugging Face
16
  infrastructure:
17
 
 
 
18
  - the **Space** (this app, `cpu-basic`) collects your videos + hyperparameters and submits a job;
19
- - training runs on **HF Jobs** (`a100-large`), reproducing the trainer environment from the
20
- monorepo lockfile (`uv sync --frozen`);
21
  - your dataset is staged on **HF buckets** (`hf://buckets/<you>/ltx2-train-<run>`);
22
  - the trained LoRA is pushed to the Hub model repo you choose.
23
 
@@ -25,21 +33,29 @@ You only pay for the Job's actual GPU runtime.
25
 
26
  ## Setup
27
 
28
- 1. **Space secret `HF_TOKEN`** — a write token (Settings Variables and secrets). Used to
29
- create buckets, submit the job, download the gated Gemma encoder, and push the LoRA.
30
- 2. The job pulls the trainer source from the bucket **`LTX_SRC_BUCKET`**
31
- (default `linoyts/ltx2-trainer-src`). To use your own, upload the monorepo
32
- (`pyproject.toml`, `uv.lock`, `packages/`) to a bucket and set the `LTX_SRC_BUCKET`
33
- Space variable.
 
 
34
 
35
  ## Usage
36
 
37
- 1. Upload videos (or a `.zip`) and one caption per line (alphabetical by filename).
38
- 2. Pick a mode (IC-LoRA Canny / T2V / I2V), set hyperparameters, and a Hub model id.
39
- 3. **Submit training job** copy the job id into the monitor box and **Refresh** to stream logs.
 
 
 
 
 
 
40
 
41
- The job: sync buckets → `uv sync` → download base checkpoint + Gemma → (Canny references) →
42
- `process_dataset.py` → `train.py` → push LoRA to the Hub.
43
 
44
  ## Notes
45
 
 
8
  app_file: app.py
9
  pinned: false
10
  hardware: cpu-basic
11
+ hf_oauth: true
12
+ hf_oauth_scopes:
13
+ - read-repos
14
+ - write-repos
15
+ - manage-repos
16
+ - jobs
17
  ---
18
 
19
  # LTX-2 LoRA Trainer (HF Jobs)
 
21
  Train a **LoRA / IC-LoRA for LTX-2.3** from your own videos, entirely on Hugging Face
22
  infrastructure:
23
 
24
+ - **Sign in with Hugging Face** — the dataset, the job, and the pushed LoRA all run under
25
+ **your** account and billing (no pasted tokens);
26
  - the **Space** (this app, `cpu-basic`) collects your videos + hyperparameters and submits a job;
27
+ - training runs on **HF Jobs**, reproducing the trainer environment from the monorepo
28
+ lockfile (`uv sync --frozen`);
29
  - your dataset is staged on **HF buckets** (`hf://buckets/<you>/ltx2-train-<run>`);
30
  - the trained LoRA is pushed to the Hub model repo you choose.
31
 
 
33
 
34
  ## Setup
35
 
36
+ Just **sign in with Hugging Face** in the app the OAuth token (scopes: repos + `jobs`) is
37
+ used to create buckets, submit the job, download the gated Gemma encoder, and push the LoRA,
38
+ all as the signed-in user. No Space secret is required.
39
+
40
+ The job pulls the trainer source from the bucket **`LTX_SRC_BUCKET`**
41
+ (default `linoyts/ltx2-trainer-src`). To use your own, upload the monorepo
42
+ (`pyproject.toml`, `uv.lock`, `packages/`) to a bucket and set the `LTX_SRC_BUCKET`
43
+ Space variable.
44
 
45
  ## Usage
46
 
47
+ 1. Sign in with Hugging Face.
48
+ 2. Pick a mode and upload your videos:
49
+ - **IC-LoRA (in-context control):** upload each target `clip.mp4` **and** its control video
50
+ `clip_reference.mp4` (depth, pose, edges, an inpainting-masked version, …) — matched by
51
+ filename. References are user-supplied; nothing is auto-derived.
52
+ - **T2V / I2V:** upload only the target clips.
53
+ 3. Add one caption per line (in filename order of the target clips), set hyperparameters and a
54
+ Hub model id, then **Submit training job**.
55
+ 4. Copy the job id into the monitor box and **Refresh** to stream logs.
56
 
57
+ The job: sync buckets → `uv sync` → download base checkpoint + Gemma → `process_dataset.py`
58
+ → `train.py` → push LoRA to the Hub.
59
 
60
  ## Notes
61
 
app.py CHANGED
@@ -1,41 +1,66 @@
1
  """LTX-2 LoRA Trainer — HF Space (HF Jobs + buckets).
2
 
3
- Upload videos, set hyperparameters, and submit a training job to HF Jobs. The job runs
4
- on a100-large, reproduces the trainer env from the lockfile, trains a LoRA/IC-LoRA for
5
- LTX-2.3, and pushes it to your Hub repo. Datasets are staged on HF buckets.
 
6
 
7
  Runs on `cpu-basic` — the Space only submits + monitors jobs (no GPU, no torch here).
8
  """
9
 
10
  from __future__ import annotations
11
 
12
- import os
13
 
14
  import gradio as gr
15
 
16
  import jobs
17
 
18
- HEADER = """# 🎬 LTX-2 LoRA Trainer (HF Jobs)
19
- Train a LoRA / IC-LoRA for **LTX-2.3** from your own videos — training runs on **HF Jobs**
20
- (a100-large), data is staged on **HF buckets**, and the trained LoRA is pushed to your Hub repo.
21
-
22
- You only pay for the Job's actual runtime. Set `HF_TOKEN` as a Space secret (or paste a token below).
23
- """
24
-
25
  FLAVORS = ["a100-large", "a100x4", "l40sx1", "l40sx4"]
26
  MAX_LOG = 60_000
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
-
29
- def _extract_id(link_md: str) -> str:
30
- import re
31
- m = re.search(r"\[([^\]]+)\]\(http", link_md or "")
32
- return m.group(1) if m else ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
 
35
  def submit_job(
36
  files, captions_text, run_name, mode, resolution, rank, alpha, lr, steps, batch_size,
37
- grad_accum, quantization, optimizer_type, te_8bit, gen_canny, push, hub_id, hf_token, flavor, timeout,
 
38
  ):
 
 
39
  if not files:
40
  return "❌ Upload at least one video.", "", ""
41
  try:
@@ -43,7 +68,7 @@ def submit_job(
43
  except ValueError as e:
44
  return f"❌ {e}", "", ""
45
  if push and not hub_id.strip():
46
- return "❌ Set a Hub model id (username/my-lora) or disable push.", "", ""
47
 
48
  params = {
49
  "run_name": run_name or "ltx2-lora", "mode": mode, "resolution": resolution.strip(),
@@ -51,8 +76,8 @@ def submit_job(
51
  "gradient_accumulation_steps": grad_accum,
52
  "quantization": None if quantization in ("none", None) else quantization,
53
  "optimizer_type": optimizer_type, "load_text_encoder_in_8bit": bool(te_8bit),
54
- "generate_canny_reference": bool(gen_canny), "push_to_hub": bool(push),
55
- "hub_model_id": hub_id.strip(), "hf_token": (hf_token or os.environ.get("HF_TOKEN", "")).strip(),
56
  "captions": captions_text.splitlines() if captions_text else [], "seed": 42,
57
  }
58
  try:
@@ -60,7 +85,7 @@ def submit_job(
60
  except Exception as e: # noqa: BLE001
61
  return f"❌ Submission failed: {e}", "", ""
62
 
63
- status = f"✅ Job submitted on **{flavor}**."
64
  link = ""
65
  if res["url"]:
66
  link = f"**Job:** [{res['job_id']}]({res['url']}) \n**Bucket:** `{res['bucket']}`"
@@ -69,67 +94,113 @@ def submit_job(
69
  return status, link, res["log"]
70
 
71
 
72
- def refresh(job_id, hf_token):
73
  if not job_id.strip():
74
  return "Enter a job id.", ""
75
- token = (hf_token or os.environ.get("HF_TOKEN", "")).strip()
76
  st = jobs.job_status(job_id.strip(), token)
77
  logs = jobs.job_logs(job_id.strip(), token)
78
- return f"**Status:** {st}", logs[-MAX_LOG:] if len(logs) > MAX_LOG else logs
 
 
 
 
 
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- with gr.Blocks(title="LTX-2 LoRA Trainer (HF Jobs)") as demo:
82
- gr.Markdown(HEADER)
83
  with gr.Row():
84
- with gr.Column(scale=1):
85
- run_name = gr.Textbox(label="Run name", value="ltx2-ic-lora")
86
- mode = gr.Dropdown(list(jobs.MODES.keys()), value=list(jobs.MODES.keys())[0], label="Training mode")
87
- files = gr.File(label="Videos (or a .zip)", file_count="multiple", file_types=["video", ".zip"])
88
- captions = gr.Textbox(label="Captions (one per line, alphabetical by filename)", lines=4)
89
- gen_canny = gr.Checkbox(label="Auto-generate Canny references (IC-LoRA)", value=True)
90
- with gr.Accordion("Hyperparameters", open=True):
91
- resolution = gr.Textbox(label="Resolution WxHxF", value="512x320x25")
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  with gr.Row():
93
  rank = gr.Number(label="LoRA rank", value=32, precision=0)
94
  alpha = gr.Number(label="LoRA alpha", value=32, precision=0)
95
  with gr.Row():
96
  lr = gr.Number(label="Learning rate", value=2e-4)
97
  steps = gr.Number(label="Steps", value=2000, precision=0)
98
- with gr.Row():
99
- batch_size = gr.Number(label="Batch size", value=1, precision=0)
100
- grad_accum = gr.Number(label="Grad accumulation", value=1, precision=0)
101
- with gr.Row():
102
- quantization = gr.Dropdown(["none", "int8-quanto", "fp8-quanto"], value="none",
103
- label="Quantization", info="a100-large (80GB) fits 22B in bf16.")
104
- optimizer_type = gr.Dropdown(["adamw", "adamw8bit"], value="adamw", label="Optimizer")
105
- te_8bit = gr.Checkbox(label="Load text encoder in 8-bit", value=False)
106
- with gr.Accordion("Hub & hardware", open=True):
107
- push = gr.Checkbox(label="Push trained LoRA to the Hub", value=True)
108
- hub_id = gr.Textbox(label="Hub model id", value="linoyts/ltx2-ic-lora", placeholder="username/my-lora")
109
- hf_token = gr.Textbox(label="HF token (blank → Space secret HF_TOKEN)", type="password")
 
 
 
 
 
 
 
110
  with gr.Row():
111
  flavor = gr.Dropdown(FLAVORS, value=jobs.DEFAULT_FLAVOR, label="GPU flavor")
112
  timeout = gr.Textbox(label="Timeout", value="4h")
113
- submit_btn = gr.Button("🚀 Submit training job", variant="primary")
114
- with gr.Column(scale=1):
115
- status = gr.Markdown("Ready.")
116
- joblink = gr.Markdown("")
117
- sublog = gr.Textbox(label="Submission output", lines=6)
118
- gr.Markdown("### Monitor a job")
119
- with gr.Row():
120
- job_id = gr.Textbox(label="Job id", scale=3)
121
- refresh_btn = gr.Button("🔄 Refresh", scale=1)
122
- mon_status = gr.Markdown("")
123
- mon_logs = gr.Textbox(label="Job logs", lines=18, autoscroll=True)
 
 
 
 
124
 
125
  submit_btn.click(
126
  submit_job,
127
  inputs=[files, captions, run_name, mode, resolution, rank, alpha, lr, steps, batch_size,
128
- grad_accum, quantization, optimizer_type, te_8bit, gen_canny, push, hub_id, hf_token, flavor, timeout],
129
  outputs=[status, joblink, sublog],
130
- ).then(lambda link: _extract_id(link), inputs=joblink, outputs=job_id)
131
- refresh_btn.click(refresh, inputs=[job_id, hf_token], outputs=[mon_status, mon_logs])
132
 
133
 
134
  if __name__ == "__main__":
135
- demo.queue(default_concurrency_limit=2).launch()
 
1
  """LTX-2 LoRA Trainer — HF Space (HF Jobs + buckets).
2
 
3
+ Sign in with Hugging Face, upload videos + captions, set hyperparameters, and submit a
4
+ training job to HF Jobs. The job runs on a GPU flavor, reproduces the trainer env from the
5
+ lockfile, trains a LoRA / IC-LoRA for LTX-2.3, and pushes it to your Hub repo. Datasets are
6
+ staged on HF buckets. Everything runs under the *signed-in user's* account — no pasted tokens.
7
 
8
  Runs on `cpu-basic` — the Space only submits + monitors jobs (no GPU, no torch here).
9
  """
10
 
11
  from __future__ import annotations
12
 
13
+ import re
14
 
15
  import gradio as gr
16
 
17
  import jobs
18
 
 
 
 
 
 
 
 
19
  FLAVORS = ["a100-large", "a100x4", "l40sx1", "l40sx4"]
20
  MAX_LOG = 60_000
21
+ MODE_KEYS = list(jobs.MODES.keys())
22
+
23
+ CSS = """
24
+ #hdr {text-align:left; padding:4px 2px 0 2px;}
25
+ #hdr h1 {margin:0; font-size:1.7rem;}
26
+ #hdr p {margin:.25rem 0 0 0; color:var(--body-text-color-subdued);}
27
+ .section-card {border:1px solid var(--block-border-color); border-radius:14px;
28
+ padding:14px 16px; background:var(--block-background-fill);}
29
+ .step-badge {font-weight:600; color:var(--primary-500);}
30
+ #banner {border-radius:12px; padding:10px 14px; font-size:.95rem;}
31
+ footer {visibility:hidden;}
32
+ """
33
 
34
+ THEME = gr.themes.Soft(
35
+ primary_hue=gr.themes.colors.indigo,
36
+ secondary_hue=gr.themes.colors.purple,
37
+ radius_size=gr.themes.sizes.radius_lg,
38
+ )
39
+
40
+
41
+ def _signin_state(profile: gr.OAuthProfile | None):
42
+ """Banner + default hub id, recomputed on load and after sign-in."""
43
+ if profile is None:
44
+ return (
45
+ "🔒 **You're not signed in.** Use **Sign in with Hugging Face** (top-right) to start — "
46
+ "your dataset, the training job, and the resulting LoRA all run under **your** account "
47
+ "and billing. No tokens to paste.",
48
+ gr.update(),
49
+ )
50
+ return (
51
+ f"✅ Signed in as **{profile.username}** — jobs, buckets and the pushed LoRA will live under "
52
+ f"your account.",
53
+ gr.update(value=f"{profile.username}/ltx2-lora"),
54
+ )
55
 
56
 
57
  def submit_job(
58
  files, captions_text, run_name, mode, resolution, rank, alpha, lr, steps, batch_size,
59
+ grad_accum, quantization, optimizer_type, te_8bit, push, hub_id, flavor, timeout,
60
+ profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None,
61
  ):
62
+ if oauth_token is None or profile is None:
63
+ return "❌ Please **sign in with Hugging Face** first (top-right).", "", ""
64
  if not files:
65
  return "❌ Upload at least one video.", "", ""
66
  try:
 
68
  except ValueError as e:
69
  return f"❌ {e}", "", ""
70
  if push and not hub_id.strip():
71
+ return "❌ Set a Hub model id (e.g. you/my-lora) or disable push.", "", ""
72
 
73
  params = {
74
  "run_name": run_name or "ltx2-lora", "mode": mode, "resolution": resolution.strip(),
 
76
  "gradient_accumulation_steps": grad_accum,
77
  "quantization": None if quantization in ("none", None) else quantization,
78
  "optimizer_type": optimizer_type, "load_text_encoder_in_8bit": bool(te_8bit),
79
+ "push_to_hub": bool(push), "hub_model_id": hub_id.strip(),
80
+ "hf_token": oauth_token.token,
81
  "captions": captions_text.splitlines() if captions_text else [], "seed": 42,
82
  }
83
  try:
 
85
  except Exception as e: # noqa: BLE001
86
  return f"❌ Submission failed: {e}", "", ""
87
 
88
+ status = f"✅ Job submitted on **{flavor}**, running as **{profile.username}**."
89
  link = ""
90
  if res["url"]:
91
  link = f"**Job:** [{res['job_id']}]({res['url']}) \n**Bucket:** `{res['bucket']}`"
 
94
  return status, link, res["log"]
95
 
96
 
97
+ def refresh(job_id, oauth_token: gr.OAuthToken | None):
98
  if not job_id.strip():
99
  return "Enter a job id.", ""
100
+ token = oauth_token.token if oauth_token else ""
101
  st = jobs.job_status(job_id.strip(), token)
102
  logs = jobs.job_logs(job_id.strip(), token)
103
+ return f"**Status:** `{st}`", logs[-MAX_LOG:] if len(logs) > MAX_LOG else logs
104
+
105
+
106
+ def _extract_id(link_md: str) -> str:
107
+ m = re.search(r"\[([^\]]+)\]\(http", link_md or "")
108
+ return m.group(1) if m else ""
109
+
110
 
111
+ MODE_HELP = (
112
+ "**IC-LoRA (in-context control)** learns from *pairs*: a target video `clip.mp4` and its "
113
+ "control/reference video `clip_reference.mp4` (depth, pose, edges, an inpainting-masked "
114
+ "version, …). Upload both — they're matched by filename. **T2V / I2V** need only the target clips."
115
+ )
116
+
117
+ with gr.Blocks(title="LTX-2 LoRA Trainer") as demo:
118
+ with gr.Row(equal_height=True):
119
+ gr.Markdown(
120
+ "# 🎬 LTX-2 LoRA Trainer\n"
121
+ "Train a **LoRA / IC-LoRA for LTX-2.3** on your own videos. Training runs on "
122
+ "**HF Jobs**, data is staged on **HF buckets**, and the LoRA is pushed to your Hub — "
123
+ "you only pay for the GPU runtime.",
124
+ elem_id="hdr",
125
+ )
126
+ gr.LoginButton(scale=0, min_width=200)
127
+
128
+ banner = gr.Markdown(elem_id="banner")
129
 
 
 
130
  with gr.Row():
131
+ # ---------------------------------------------------------------- left: data + training
132
+ with gr.Column(scale=3):
133
+ with gr.Group():
134
+ gr.Markdown("<span class='step-badge'>STEP 1</span> &nbsp; **Dataset**")
135
+ mode = gr.Dropdown(MODE_KEYS, value=MODE_KEYS[0], label="Training mode")
136
+ gr.Markdown(MODE_HELP)
137
+ files = gr.File(
138
+ label="Videos — targets + (for IC-LoRA) their *_reference videos, or a .zip",
139
+ file_count="multiple", file_types=["video", ".zip"], height=160,
140
+ )
141
+ captions = gr.Textbox(
142
+ label="Captions — one per line, in filename order of the target clips",
143
+ lines=4, placeholder="a fluffy cat in a sunlit room\na sweeping green landscape\n…",
144
+ )
145
+
146
+ with gr.Group():
147
+ gr.Markdown("<span class='step-badge'>STEP 2</span> &nbsp; **Training settings**")
148
+ resolution = gr.Textbox(
149
+ label="Resolution (W×H×F)", value="768x512x49",
150
+ info="W,H divisible by 32 · frames F satisfy F % 8 == 1 (25, 49, 81, …)",
151
+ )
152
  with gr.Row():
153
  rank = gr.Number(label="LoRA rank", value=32, precision=0)
154
  alpha = gr.Number(label="LoRA alpha", value=32, precision=0)
155
  with gr.Row():
156
  lr = gr.Number(label="Learning rate", value=2e-4)
157
  steps = gr.Number(label="Steps", value=2000, precision=0)
158
+ with gr.Accordion("Advanced", open=False):
159
+ with gr.Row():
160
+ batch_size = gr.Number(label="Batch size", value=1, precision=0)
161
+ grad_accum = gr.Number(label="Grad accumulation", value=1, precision=0)
162
+ with gr.Row():
163
+ quantization = gr.Dropdown(
164
+ ["none", "int8-quanto", "fp8-quanto"], value="none", label="Quantization",
165
+ info="a100-large (80 GB) fits 22B in bf16 — quantize only on smaller GPUs.",
166
+ )
167
+ optimizer_type = gr.Dropdown(["adamw", "adamw8bit"], value="adamw", label="Optimizer")
168
+ te_8bit = gr.Checkbox(label="Load text encoder in 8-bit", value=False)
169
+
170
+ # ---------------------------------------------------------------- right: output + submit + monitor
171
+ with gr.Column(scale=2):
172
+ with gr.Group():
173
+ gr.Markdown("<span class='step-badge'>STEP 3</span> &nbsp; **Output & launch**")
174
+ run_name = gr.Textbox(label="Run name", value="ltx2-ic-lora")
175
+ push = gr.Checkbox(label="Push the trained LoRA to my Hub", value=True)
176
+ hub_id = gr.Textbox(label="Hub model id", placeholder="username/my-lora")
177
  with gr.Row():
178
  flavor = gr.Dropdown(FLAVORS, value=jobs.DEFAULT_FLAVOR, label="GPU flavor")
179
  timeout = gr.Textbox(label="Timeout", value="4h")
180
+ submit_btn = gr.Button("🚀 Submit training job", variant="primary", size="lg")
181
+ status = gr.Markdown("")
182
+ joblink = gr.Markdown("")
183
+
184
+ with gr.Group():
185
+ gr.Markdown("**Monitor a job**")
186
+ with gr.Row():
187
+ job_id = gr.Textbox(label="Job id", scale=3)
188
+ refresh_btn = gr.Button("🔄 Refresh", scale=1)
189
+ mon_status = gr.Markdown("")
190
+ mon_logs = gr.Textbox(label="Job logs", lines=16, autoscroll=True, max_lines=16)
191
+
192
+ sublog = gr.Textbox(label="Submission output", lines=4, visible=False)
193
+
194
+ demo.load(_signin_state, inputs=None, outputs=[banner, hub_id])
195
 
196
  submit_btn.click(
197
  submit_job,
198
  inputs=[files, captions, run_name, mode, resolution, rank, alpha, lr, steps, batch_size,
199
+ grad_accum, quantization, optimizer_type, te_8bit, push, hub_id, flavor, timeout],
200
  outputs=[status, joblink, sublog],
201
+ ).then(_extract_id, inputs=joblink, outputs=job_id)
202
+ refresh_btn.click(refresh, inputs=[job_id], outputs=[mon_status, mon_logs])
203
 
204
 
205
  if __name__ == "__main__":
206
+ demo.queue(default_concurrency_limit=2).launch(theme=THEME, css=CSS)
jobs.py CHANGED
@@ -8,8 +8,9 @@ Per training request:
8
 
9
  On the Job, the script syncs the source bucket + run bucket, runs `uv sync --frozen`
10
  (reproducing the working trainer env from the lockfile), downloads the base checkpoint
11
- and Gemma, then runs compute_reference (Canny) → process_dataset.py → train.py, which
12
- pushes the trained LoRA to the Hub.
 
13
 
14
  The Space itself only needs gradio + huggingface_hub + pyyaml (no torch).
15
  """
@@ -38,7 +39,7 @@ JOB_GEMMA = f"{JOB_ROOT}/gemma"
38
  JOB_RUN = f"{JOB_ROOT}/run"
39
 
40
  MODES = {
41
- "IC-LoRA (video→video, Canny control)": {
42
  "needs_reference": True,
43
  "target_modules": [
44
  "attn1.to_k", "attn1.to_q", "attn1.to_v", "attn1.to_out.0",
@@ -80,6 +81,47 @@ def _collect_videos(uploaded: list[str], dest: Path) -> list[Path]:
80
  return sorted(p for p in dest.glob("*") if p.suffix.lower() in VIDEO_EXTS)
81
 
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  def build_config_dict(params: dict, videos: list[Path]) -> dict:
84
  w, h, f = parse_resolution(params["resolution"])
85
  mode_cfg = MODES[params["mode"]]
@@ -182,14 +224,6 @@ def main():
182
  sh(["uv", "run", "python", *args], cwd=tr)
183
 
184
  ds_json = RUN / "dataset" / "dataset.json"
185
- if jc.get("needs_reference") and jc.get("generate_canny_reference"):
186
- print("=== 5a/6 generate Canny reference videos ===", flush=True)
187
- uvrun(["scripts/compute_reference.py", str(RUN / "dataset"), "--output", str(ds_json)])
188
- items = json.loads(ds_json.read_text())
189
- for it in items:
190
- if "reference_path" in it:
191
- it["reference_video"] = it.pop("reference_path")
192
- ds_json.write_text(json.dumps(items, indent=2))
193
 
194
  print("=== 5/6 preprocess dataset ===", flush=True)
195
  uvrun(["scripts/process_dataset.py", str(ds_json),
@@ -227,20 +261,15 @@ def submit(params: dict, uploaded_videos: list[str], flavor: str, timeout: str)
227
  videos = _collect_videos(uploaded_videos, tmp / "dataset" / "videos")
228
  if not videos:
229
  raise ValueError("No valid video files in the upload.")
230
- # dataset.json (media_path + caption, relative to dataset dir)
231
- caps = params.get("captions", [])
232
- items = [{"media_path": f"videos/{v.name}",
233
- "caption": caps[i] if i < len(caps) and caps[i].strip() else "a video"}
234
- for i, v in enumerate(videos)]
235
  (tmp / "dataset" / "dataset.json").write_text(json.dumps(items, indent=2))
236
- # config.yaml + job_config.json
237
- cfg = build_config_dict(params, videos)
238
  (tmp / "config.yaml").write_text(yaml.safe_dump(cfg, sort_keys=False))
239
- job_cfg = {"resolution": params["resolution"],
240
- "needs_reference": MODES[params["mode"]]["needs_reference"],
241
- "generate_canny_reference": bool(params.get("generate_canny_reference", True))}
242
  # job script reads job_config.json + config.yaml at the run root (sibling of dataset/)
243
- (tmp / "job_config.json").write_text(json.dumps(job_cfg, indent=2))
244
 
245
  env = os.environ.copy()
246
  if token:
 
8
 
9
  On the Job, the script syncs the source bucket + run bucket, runs `uv sync --frozen`
10
  (reproducing the working trainer env from the lockfile), downloads the base checkpoint
11
+ and Gemma, then runs process_dataset.py → train.py, which pushes the trained LoRA to the
12
+ Hub. For IC-LoRA, references are user-supplied (paired `*_reference` videos) — no
13
+ auto-derivation.
14
 
15
  The Space itself only needs gradio + huggingface_hub + pyyaml (no torch).
16
  """
 
39
  JOB_RUN = f"{JOB_ROOT}/run"
40
 
41
  MODES = {
42
+ "IC-LoRA (in-context control)": {
43
  "needs_reference": True,
44
  "target_modules": [
45
  "attn1.to_k", "attn1.to_q", "attn1.to_v", "attn1.to_out.0",
 
81
  return sorted(p for p in dest.glob("*") if p.suffix.lower() in VIDEO_EXTS)
82
 
83
 
84
+ def _is_reference(p: Path) -> bool:
85
+ return p.stem.endswith("_reference")
86
+
87
+
88
+ def build_dataset_items(
89
+ videos: list[Path], captions: list[str], needs_reference: bool
90
+ ) -> tuple[list[dict], list[Path]]:
91
+ """Build dataset.json rows. For IC-LoRA, pair each target `X.ext` with `X_reference.ext`
92
+ (user-supplied — no auto-derivation). Captions align to the sorted target clips.
93
+
94
+ Returns (items, targets). Raises ValueError on missing references or no targets.
95
+ """
96
+ vids = [v for v in videos if v.suffix.lower() in VIDEO_EXTS]
97
+ if needs_reference:
98
+ targets = sorted(v for v in vids if not _is_reference(v))
99
+ refs = {v.stem[: -len("_reference")]: v for v in vids if _is_reference(v)}
100
+ else:
101
+ targets, refs = sorted(vids), {}
102
+
103
+ items, missing = [], []
104
+ for i, v in enumerate(targets):
105
+ cap = captions[i] if i < len(captions) and captions[i].strip() else "a video"
106
+ row = {"media_path": f"videos/{v.name}", "caption": cap}
107
+ if needs_reference:
108
+ ref = refs.get(v.stem)
109
+ if ref is None:
110
+ missing.append(v.name)
111
+ continue
112
+ row["reference_video"] = f"videos/{ref.name}"
113
+ items.append(row)
114
+
115
+ if needs_reference and missing:
116
+ raise ValueError(
117
+ "Missing reference video(s) for: " + ", ".join(missing)
118
+ + ". For IC-LoRA, every target `X.mp4` needs a paired `X_reference.mp4`."
119
+ )
120
+ if not items:
121
+ raise ValueError("No target videos found in the upload.")
122
+ return items, targets
123
+
124
+
125
  def build_config_dict(params: dict, videos: list[Path]) -> dict:
126
  w, h, f = parse_resolution(params["resolution"])
127
  mode_cfg = MODES[params["mode"]]
 
224
  sh(["uv", "run", "python", *args], cwd=tr)
225
 
226
  ds_json = RUN / "dataset" / "dataset.json"
 
 
 
 
 
 
 
 
227
 
228
  print("=== 5/6 preprocess dataset ===", flush=True)
229
  uvrun(["scripts/process_dataset.py", str(ds_json),
 
261
  videos = _collect_videos(uploaded_videos, tmp / "dataset" / "videos")
262
  if not videos:
263
  raise ValueError("No valid video files in the upload.")
264
+ # dataset.json pairs targets with user-supplied references for IC-LoRA
265
+ needs_reference = MODES[params["mode"]]["needs_reference"]
266
+ items, targets = build_dataset_items(videos, params.get("captions", []), needs_reference)
 
 
267
  (tmp / "dataset" / "dataset.json").write_text(json.dumps(items, indent=2))
268
+ # config.yaml + job_config.json (validation sample references the first target's reference)
269
+ cfg = build_config_dict(params, targets)
270
  (tmp / "config.yaml").write_text(yaml.safe_dump(cfg, sort_keys=False))
 
 
 
271
  # job script reads job_config.json + config.yaml at the run root (sibling of dataset/)
272
+ (tmp / "job_config.json").write_text(json.dumps({"resolution": params["resolution"]}, indent=2))
273
 
274
  env = os.environ.copy()
275
  if token:
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- gradio>=6.18
2
  huggingface_hub[hf-xet]>=1.5
3
  pyyaml
4
  hf_transfer
 
1
+ gradio[oauth]>=6.18
2
  huggingface_hub[hf-xet]>=1.5
3
  pyyaml
4
  hf_transfer