signsur4739379373 committed on
Commit
b717e92
·
verified ·
1 Parent(s): 1ff7948

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +340 -0
app.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import subprocess
4
+ import shutil
5
+ import json
6
+ import time
7
+ from pathlib import Path
8
+ import torch
9
+ import spaces
10
+ from diffusers import DiffusionPipeline
11
+
12
+ # ==========================================
13
+ # 1. SETUP & GLOBAL VARS
14
+ # ==========================================
15
+
16
# Working directories: uploaded datasets go to DATASET_DIR, training
# artefacts (config + checkpoints) go to OUTPUT_DIR.
DATASET_DIR = Path("./datasets")
OUTPUT_DIR = Path("./output")
for _workdir in (DATASET_DIR, OUTPUT_DIR):
    _workdir.mkdir(exist_ok=True)

# Registry of LoRAs trained during this process' lifetime.
# Maps the friendly project name to the path of its .safetensors file.
AVAILABLE_LORAS = {}

# The inference pipeline is built once at import time and shared by every
# request handled by generate_image.
print("loading z-image-turbo pipeline...")
_pipeline_options = {
    "torch_dtype": torch.bfloat16,
    "low_cpu_mem_usage": False,
}
pipe = DiffusionPipeline.from_pretrained(
    "Tongyi-MAI/Z-Image-Turbo",
    **_pipeline_options,
)
pipe.to("cuda")
print("pipeline loaded!")
33
+
34
+ # ==========================================
35
+ # 2. TRAINING LOGIC
36
+ # ==========================================
37
+
38
def check_gpu():
    """Return a one-line status string saying whether torch can see a CUDA GPU.

    When a device is available the message includes the name of device 0.
    """
    if not torch.cuda.is_available():
        return "⚠️ no gpu detected"

    device_name = torch.cuda.get_device_name(0)
    return f"✅ gpu available: {device_name}"
42
+
43
def upload_and_prepare_dataset(files, dataset_name, trigger_word, dataset_root=None):
    """Copy uploaded images into a dataset folder and write one caption per image.

    Every accepted image is copied into ``<root>/<dataset_name>/`` and gets a
    sibling ``.txt`` file (same stem) containing the caption text.

    Args:
        files: Uploaded items. Each item is either a path (``str`` or
            ``os.PathLike``) or an object whose ``name`` attribute is the path.
        dataset_name: Desired folder name. Only its final path component is
            used; when that is empty a ``dataset_<unix time>`` name is generated.
        trigger_word: Caption written for every image; ``"a photo"`` is used
            when it is empty.
        dataset_root: Optional parent directory for the dataset folder.
            Defaults to the module-level ``DATASET_DIR``.

    Returns:
        A ``(status, dataset_path, dataset_name)`` tuple. On failure the path
        is ``None`` and the name is ``""``.
    """
    if not files:
        return "❌ upload images first", None, ""

    # The name comes from a free-text box: keep only its last path component
    # so values such as "../x" cannot escape the dataset root.
    dataset_name = Path(str(dataset_name or "").strip().replace("\\", "/")).name
    if dataset_name in ("", ".", ".."):
        dataset_name = f"dataset_{int(time.time())}"

    root = DATASET_DIR if dataset_root is None else Path(dataset_root)
    dataset_path = root / dataset_name
    dataset_path.mkdir(exist_ok=True, parents=True)

    image_suffixes = ('.png', '.jpg', '.jpeg', '.webp', '.bmp')
    caption_text = trigger_word if trigger_word else "a photo"

    image_count = 0
    for file in files:
        # Accept both plain paths and file wrappers exposing the path as .name.
        if isinstance(file, (str, os.PathLike)):
            source = os.fspath(file)
        else:
            source = file.name

        if not source.lower().endswith(image_suffixes):
            continue

        dest = dataset_path / Path(source).name
        shutil.copy(source, dest)

        # Caption file shares the image's stem, e.g. cat.png -> cat.txt.
        dest.with_suffix('.txt').write_text(caption_text, encoding="utf-8")
        image_count += 1

    if image_count == 0:
        return "❌ no valid images found", None, ""

    return f"✅ ready: {image_count} images in {dataset_name}", str(dataset_path), dataset_name
71
+
72
def _build_training_config(
    project_name,
    output_path,
    dataset_path,
    trigger_word,
    steps,
    learning_rate,
    lora_rank,
    resolution,
):
    """Return the ai-toolkit job configuration dict for one LoRA training run."""
    return {
        "job": "extension",
        "config": {
            "name": project_name,
            "process": [{
                "type": "sd_trainer",
                "training_folder": str(output_path),
                "device": "cuda:0",
                "trigger_word": trigger_word or "",
                "network": {
                    "type": "lora",
                    "linear": int(lora_rank),
                    "linear_alpha": int(lora_rank),
                },
                "save": {
                    "dtype": "float16",
                    "save_every": int(steps),  # save only at end to save space
                    "max_step_saves_to_keep": 1,
                },
                "datasets": [{
                    "folder_path": dataset_path,
                    "caption_ext": "txt",
                    "caption_dropout_rate": 0.05,
                    "resolution": [int(resolution), int(resolution)],
                }],
                "train": {
                    "batch_size": 1,
                    "steps": int(steps),
                    "gradient_accumulation_steps": 1,
                    "train_unet": True,
                    "train_text_encoder": False,
                    "gradient_checkpointing": True,
                    "noise_scheduler": "flowmatch",
                    "optimizer": "adamw8bit",
                    "lr": float(learning_rate),
                    "ema_config": {"use_ema": True, "ema_decay": 0.99},
                    "dtype": "bf16",
                },
                "model": {
                    "name_or_path": "Tongyi-MAI/Z-Image-Base",
                    "is_v_pred": False,
                    "quantize": True,
                },
            }]
        }
    }


def _find_latest_lora(output_path):
    """Return the most recently modified .safetensors under output_path, or None."""
    # Search recursively: the trainer may write checkpoints into a
    # sub-folder of the training folder rather than directly into it
    # (NOTE(review): confirm against the ai-toolkit version in use).
    return max(
        output_path.rglob("*.safetensors"),
        key=lambda candidate: candidate.stat().st_mtime,
        default=None,
    )


# request 1 hour gpu for training
@spaces.GPU(duration=3600)
def train_lora(
    dataset_path,
    project_name,
    trigger_word,
    steps,
    learning_rate,
    lora_rank,
    resolution,
    progress=gr.Progress()
):
    """Train a LoRA with ai-toolkit in a subprocess and register the result.

    Writes ``config.json`` into ``OUTPUT_DIR/<project_name>``, clones and
    installs ai-toolkit on first use, runs its ``run.py`` on the config and,
    on success, records the produced checkpoint in ``AVAILABLE_LORAS``.

    Args:
        dataset_path: Folder containing images and ``.txt`` captions.
        project_name: Name of the run. Only its final path component is used;
            a ``lora_<unix time>`` name is generated when it is empty.
        trigger_word: Trigger word stored in the config ("" when empty).
        steps: Number of training steps (also used as the save interval).
        learning_rate: Optimizer learning rate.
        lora_rank: LoRA rank; also used as the alpha value.
        resolution: Square training resolution in pixels.
        progress: Gradio progress tracker.

    Returns:
        A ``(status message, checkpoint path or None)`` tuple. Failures are
        reported through the status message rather than raised.
    """
    # Local import keeps this block self-contained; sys is only needed here.
    import sys

    if not dataset_path or not Path(dataset_path).is_dir():
        return "❌ no dataset", None

    # The project name is user supplied: keep only the last path component so
    # it cannot point outside OUTPUT_DIR.
    project_name = Path(str(project_name or "").strip().replace("\\", "/")).name
    if project_name in ("", ".", ".."):
        project_name = f"lora_{int(time.time())}"

    output_path = OUTPUT_DIR / project_name
    output_path.mkdir(exist_ok=True, parents=True)

    # config generation
    config = _build_training_config(
        project_name,
        output_path,
        dataset_path,
        trigger_word,
        steps,
        learning_rate,
        lora_rank,
        resolution,
    )
    config_path = output_path / "config.json"
    with open(config_path, 'w', encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # install ai-toolkit
    progress(0.1, desc="setting up environment...")
    toolkit_dir = Path("./ai-toolkit")
    if not toolkit_dir.exists():
        try:
            subprocess.run(["git", "clone", "https://github.com/ostris/ai-toolkit.git"], check=True)
            # Install with the interpreter that runs this app so the packages
            # land in the environment the training subprocess will use.
            subprocess.run(
                [sys.executable, "-m", "pip", "install", "-q", "-r", "ai-toolkit/requirements.txt"],
                check=True,
            )
        except (subprocess.SubprocessError, OSError) as e:
            # Remove a half-installed checkout; otherwise the existence check
            # above would skip the dependency install on the next attempt.
            shutil.rmtree(toolkit_dir, ignore_errors=True)
            return f"❌ setup failed: {e}", None

    progress(0.2, desc="training (this takes time)...")

    try:
        # run training script
        # explicitly passing environment to ensure cuda visibility in subprocess
        proc = subprocess.run(
            [sys.executable, "ai-toolkit/run.py", str(config_path)],
            capture_output=True,
            text=True,
            env=os.environ.copy(),
            timeout=3500,  # stay below the 3600 s GPU allocation
        )
    except (subprocess.SubprocessError, OSError) as e:
        return f"❌ fatal error: {e}", None

    if proc.returncode != 0:
        return f"❌ training crashed:\n{proc.stderr}", None

    # find result
    lora_file = _find_latest_lora(output_path)
    if lora_file is None:
        return "⚠️ finished but no safetensors found", None

    AVAILABLE_LORAS[project_name] = str(lora_file)
    return f"✅ trained: {project_name}", str(lora_file)
187
+
188
+ # ==========================================
189
+ # 3. INFERENCE LOGIC
190
+ # ==========================================
191
+
192
@spaces.GPU
def generate_image(
    prompt,
    height,
    width,
    steps,
    seed,
    randomize_seed,
    lora_path,
    lora_scale
):
    """Generate one image with the shared pipeline, optionally with a LoRA.

    Args:
        prompt: Text prompt.
        height: Output height in pixels.
        width: Output width in pixels.
        steps: Number of inference steps.
        seed: Seed to use when ``randomize_seed`` is false. ``None`` (an
            emptied number field) is treated like a request for a random seed.
        randomize_seed: When true, a fresh random seed replaces ``seed``.
        lora_path: Path of a LoRA file to load, or a falsy value for none.
        lora_scale: Accepted for interface compatibility; currently not
            applied, the LoRA is loaded at its default strength.

    Returns:
        A ``(image, seed)`` tuple where ``seed`` is the integer actually used.
    """
    # handle lora loading/unloading
    pipe.unload_lora_weights()  # clean slate

    if lora_path and os.path.exists(lora_path):
        print(f"loading lora: {lora_path}")
        try:
            pipe.load_lora_weights(lora_path)
            # manual scaling is not always supported directly without fuse,
            # but the weights are usually applied by default.
            # for simplicity we just load it.
        except Exception as e:
            # Best effort: fall back to generating without the LoRA.
            print(f"lora load failed: {e}")

    # An emptied seed field arrives as None; int(None) would raise.
    if randomize_seed or seed is None:
        seed = torch.randint(0, 2**32 - 1, (1,)).item()

    # Normalise once so the generator and the reported seed always agree
    # (the number field may deliver a float such as 42.0).
    seed = int(seed)
    generator = torch.Generator("cuda").manual_seed(seed)

    image = pipe(
        prompt=prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(steps),
        guidance_scale=0.0,
        generator=generator,
    ).images[0]

    return image, seed
231
+
232
def update_lora_list():
    """Rebuild the LoRA dropdown from the current contents of AVAILABLE_LORAS.

    The first entry is always the ("None", None) placeholder, followed by one
    (friendly name, path) pair per registered LoRA.
    """
    options = [("None", None)]
    options.extend(AVAILABLE_LORAS.items())
    return gr.Dropdown(choices=options)
236
+
237
+ # ==========================================
238
+ # 4. UI CONSTRUCTION
239
+ # ==========================================
240
+
241
custom_theme = gr.themes.Soft(primary_hue="yellow", secondary_hue="slate")

with gr.Blocks(theme=custom_theme, title="Z-Image ZeroGPU Trainer") as demo:

    gr.Markdown("# ⚡ Z-Image-Turbo: Train & Test")

    # generate_image takes a lora_scale argument but has no UI control for it;
    # feed it a neutral constant instead of an unrelated training slider.
    lora_scale_state = gr.State(1.0)

    with gr.Tabs():

        # TAB 1: INFERENCE
        with gr.Tab("🎨 Generate"):
            with gr.Row():
                with gr.Column():
                    prompt_input = gr.Textbox(label="Prompt", lines=3)

                    with gr.Row():
                        lora_selector = gr.Dropdown(
                            label="Select LoRA",
                            choices=[("None", None)],
                            value=None,
                            interactive=True
                        )
                        refresh_btn = gr.Button("🔄", size="sm", scale=0)

                    with gr.Accordion("Settings", open=False):
                        h_slider = gr.Slider(512, 2048, 1024, step=64, label="Height")
                        w_slider = gr.Slider(512, 2048, 1024, step=64, label="Width")
                        steps_slider = gr.Slider(1, 50, 9, step=1, label="Steps")
                        seed_num = gr.Number(42, label="Seed")
                        rand_seed = gr.Checkbox(True, label="Randomize Seed")

                    gen_btn = gr.Button("Generate", variant="primary")

                with gr.Column():
                    out_img = gr.Image(label="Result")
                    out_seed = gr.Number(label="Seed Used")

        # TAB 2: TRAINING
        with gr.Tab("🏋️ Train LoRA"):
            gr.Markdown("⚠️ **Note:** Requires paid GPU space for long timeouts.")

            with gr.Row():
                with gr.Column():
                    train_files = gr.Files(label="Images", file_types=["image"])
                    train_name = gr.Textbox(label="Project Name", value="my_lora")
                    train_trigger = gr.Textbox(label="Trigger Word", value="ohwx")

                    # hidden state for dataset path
                    dataset_path_state = gr.State()

                    upload_btn = gr.Button("1. Process Dataset")
                    upload_status = gr.Textbox(label="Dataset Status")

                    gr.Markdown("---")

                    train_steps = gr.Slider(100, 2000, 500, step=100, label="Steps")
                    train_lr = gr.Slider(1e-5, 1e-3, 1e-4, step=1e-5, label="Learning Rate")
                    train_rank = gr.Slider(4, 128, 16, step=4, label="Rank")
                    # Dedicated control so the training resolution is visible on
                    # this tab instead of being taken from the Generate tab's
                    # height slider. Same range and default as that slider.
                    train_res = gr.Slider(512, 2048, 1024, step=64, label="Resolution")

                    start_train_btn = gr.Button("2. Start Training", variant="stop")

                with gr.Column():
                    train_log = gr.Textbox(label="Training Log", lines=10)
                    lora_file_download = gr.File(label="Download LoRA")

    # WIRING

    # Refresh LoRA list
    refresh_btn.click(update_lora_list, outputs=lora_selector)

    # Upload: the (possibly generated) dataset name is written back into the
    # project-name box so training uses the same name.
    upload_btn.click(
        upload_and_prepare_dataset,
        [train_files, train_name, train_trigger],
        [upload_status, dataset_path_state, train_name]
    )

    # Train, then refresh the dropdown so the new LoRA is selectable at once.
    start_train_btn.click(
        train_lora,
        [dataset_path_state, train_name, train_trigger, train_steps, train_lr, train_rank, train_res],
        [train_log, lora_file_download]
    ).then(
        update_lora_list,
        outputs=[lora_selector]
    )

    # Generate
    gen_btn.click(
        generate_image,
        [prompt_input, h_slider, w_slider, steps_slider, seed_num, rand_seed, lora_selector, lora_scale_state],
        [out_img, out_seed]
    )

if __name__ == "__main__":
    demo.launch()