rootlocalghost commited on
Commit
f239fbd
Β·
verified Β·
1 Parent(s): a6158dc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +215 -0
app.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import torch
4
+ import shutil
5
+ import uuid
6
+ import gradio as gr
7
+ from huggingface_hub import HfApi, hf_hub_download
8
+ from safetensors.torch import load_file, save_file
9
+
10
+ # --- ARCHITECTURE PROFILES ---
11
+ # Defines which sensitive layers must stay in BF16 during INT8 quantization to prevent precision collapse.
12
+ ARCH_PROFILES = {
13
+ "FLUX / Generic Rectified Flow": ["norm", "ln_", "embed", "time_in", "vector_in", "guidance_in", "txt_in", "img_in"],
14
+ "Z-Image / DiT Core": ["t_embedder", "cap_embedder", "all_x_embedder", "all_final_layer", "rope_embedder", "embed_tokens", "norm", "ln_", "shared"],
15
+ "Stable Diffusion (SDXL/SD3)": ["time_embed", "label_emb", "norm", "ln_", "out."]
16
+ }
17
+
18
+ def convert_and_upload(token, source_repo, target_repo, precision, target_components, arch_profile):
19
+ if not token:
20
+ yield "❌ Error: Please provide a valid Hugging Face Write Token."
21
+ return
22
+ if not target_repo.strip() or "/" not in target_repo:
23
+ yield "❌ Error: Target Repository must be in format 'username/repo-name'."
24
+ return
25
+ if not target_components:
26
+ yield "❌ Error: Please select at least one component to quantize."
27
+ return
28
+
29
+ # Map precision
30
+ target_dtype = None
31
+ is_int8 = precision == "INT8"
32
+ if precision == "FP8": target_dtype = torch.float8_e4m3fn
33
+ elif precision == "FP16": target_dtype = torch.float16
34
+ elif precision == "BF16": target_dtype = torch.bfloat16
35
+
36
+ api = HfApi(token=token)
37
+ yield f"πŸ”„ Verifying target repo: {target_repo}..."
38
+
39
+ try:
40
+ api.create_repo(repo_id=target_repo, exist_ok=True, private=False)
41
+ except Exception as e:
42
+ yield f"❌ Error creating repo: {str(e)}"
43
+ return
44
+
45
+ yield f"πŸ“‹ Fetching files from {source_repo}..."
46
+ try:
47
+ files = api.list_repo_files(source_repo)
48
+ except Exception as e:
49
+ yield f"❌ Error fetching files: {str(e)}"
50
+ return
51
+
52
+ cache_dir = f"./hf_cache_{uuid.uuid4().hex[:8]}"
53
+ success_count, error_count = 0, 0
54
+ exclude_prefixes = ARCH_PROFILES.get(arch_profile, [])
55
+
56
+ for file in files:
57
+ if "/" not in file and file.endswith(".safetensors"):
58
+ yield f"πŸ—‘οΈ Auto-skipping massive root model: {file}..."
59
+ continue
60
+
61
+ yield f"⏳ Processing {file}..."
62
+
63
+ try:
64
+ os.makedirs(cache_dir, exist_ok=True)
65
+ local_path = hf_hub_download(repo_id=source_repo, filename=file, cache_dir=cache_dir, token=token)
66
+
67
+ in_target_component = any(f"{comp}/" in file for comp in target_components)
68
+
69
+ if file.endswith(".safetensors") and in_target_component:
70
+ yield f"🧠 Quantizing {file} to {precision}..."
71
+
72
+ tensors = load_file(local_path)
73
+ new_tensors = {}
74
+
75
+ for k, v in tensors.items():
76
+ if is_int8:
77
+ is_2d_weight = "weight" in k and len(v.shape) == 2
78
+ is_excluded = any(ex in k for ex in exclude_prefixes)
79
+
80
+ if is_2d_weight and not is_excluded:
81
+ if v.dtype == torch.float8_e4m3fn: v = v.to(torch.bfloat16)
82
+ scale = v.abs().max(dim=1, keepdim=True)[0] / 127.0
83
+ scale = scale.clamp(min=1e-8)
84
+ new_tensors[f"{k.rsplit('.', 1)[0]}.weight_int8"] = torch.round(v / scale).clamp(-127, 127).to(torch.int8)
85
+ new_tensors[f"{k.rsplit('.', 1)[0]}.weight_scale"] = scale.to(torch.bfloat16)
86
+ else:
87
+ new_tensors[k] = v.to(torch.bfloat16) if v.is_floating_point() else v
88
+ else:
89
+ new_tensors[k] = v.to(target_dtype) if v.is_floating_point() else v
90
+
91
+ converted_path = "converted.safetensors"
92
+ save_file(new_tensors, converted_path)
93
+
94
+ del tensors, new_tensors
95
+ gc.collect()
96
+
97
+ yield f"☁️ Uploading {precision} version of {file}..."
98
+ api.upload_file(path_or_fileobj=converted_path, path_in_repo=file, repo_id=target_repo)
99
+ os.remove(converted_path)
100
+
101
+ else:
102
+ yield f"☁️ Copying {file} as-is..."
103
+ api.upload_file(path_or_fileobj=local_path, path_in_repo=file, repo_id=target_repo)
104
+
105
+ success_count += 1
106
+ if os.path.exists(cache_dir): shutil.rmtree(cache_dir)
107
+ gc.collect()
108
+
109
+ except Exception as e:
110
+ error_count += 1
111
+ yield f"⚠️ Error processing {file}: {str(e)}\nSkipping..."
112
+
113
+ if os.path.exists(cache_dir): shutil.rmtree(cache_dir)
114
+ yield f"βœ… Finished! Processed: {success_count} | Errors: {error_count}."
115
+
116
+ # --- UI LOGIC ---
117
+ def generate_target_repo(source, precision):
118
+ model_name = source.split("/")[-1] if "/" in source else source
119
+ return f"your-username/{model_name}-{precision}"
120
+
121
+ def toggle_int8_warning(precision):
122
+ return gr.update(visible=(precision == "INT8"))
123
+
124
+ # --- GUI ---
125
+ with gr.Blocks(theme=gr.themes.Base(primary_hue="blue", neutral_hue="slate")) as demo:
126
+ gr.Markdown(
127
+ """
128
+ # ⚑ Universal Model Quantizer Hub
129
+ Convert massive diffusion and transformer models directly on the Hugging Face hub.
130
+ Engineered with aggressive cache-clearing to prevent storage crashes on free-tier Spaces.
131
+ """
132
+ )
133
+
134
+ with gr.Row():
135
+ # LEFT COLUMN: Configuration
136
+ with gr.Column(scale=5):
137
+ with gr.Tabs():
138
+
139
+ with gr.TabItem("1. Authentication & Source"):
140
+ hf_token = gr.Textbox(label="HF Access Token (Write)", type="password", placeholder="hf_...")
141
+ source_repo = gr.Textbox(
142
+ label="Source Repository",
143
+ placeholder="e.g., black-forest-labs/FLUX.1-dev",
144
+ info="Paste any Hugging Face model repository ID."
145
+ )
146
+
147
+ gr.Markdown("### Popular Presets")
148
+ with gr.Row():
149
+ preset_flux = gr.Button("FLUX.2-klein-9B", size="sm")
150
+ preset_zimage = gr.Button("Z-Image-Turbo", size="sm")
151
+ preset_sd3 = gr.Button("SD3.5-Large", size="sm")
152
+
153
+ with gr.TabItem("2. Quantization Rules"):
154
+ arch_profile = gr.Radio(
155
+ choices=list(ARCH_PROFILES.keys()),
156
+ value="FLUX / Generic Rectified Flow",
157
+ label="Architecture Profile",
158
+ info="Crucial for INT8: Selects which layers to protect from precision loss."
159
+ )
160
+ target_components = gr.CheckboxGroup(
161
+ choices=["transformer", "text_encoder", "text_encoder_2", "vae"],
162
+ value=["transformer"],
163
+ label="Folders to Quantize",
164
+ info="Unselected folders will be copied to the new repo unchanged."
165
+ )
166
+
167
+ with gr.TabItem("3. Output Settings"):
168
+ precision = gr.Dropdown(
169
+ choices=["FP8", "FP16", "BF16", "INT8"],
170
+ value="INT8",
171
+ label="Target Precision"
172
+ )
173
+ int8_warning = gr.Markdown(
174
+ "⚠️ **INT8 Selected:** Keys will be split into `weight_int8` and `weight_scale`. "
175
+ "Requires custom XPU/CUDA native linear classes to execute.",
176
+ visible=True
177
+ )
178
+ target_repo = gr.Textbox(
179
+ label="Target Repository",
180
+ placeholder="your-username/model-name",
181
+ interactive=True
182
+ )
183
+
184
+ start_btn = gr.Button("πŸš€ Start Cloud Quantization", variant="primary", size="lg")
185
+
186
+ # RIGHT COLUMN: Logs
187
+ with gr.Column(scale=4):
188
+ output_log = gr.Textbox(
189
+ label="Terminal Output",
190
+ lines=24,
191
+ interactive=False,
192
+ max_lines=30,
193
+ show_copy_button=True
194
+ )
195
+
196
+ # --- WIRING ---
197
+ # Presets
198
+ preset_flux.click(lambda: ("black-forest-labs/FLUX.2-klein-9B", "FLUX / Generic Rectified Flow"), outputs=[source_repo, arch_profile])
199
+ preset_zimage.click(lambda: ("your-username/Z-Image-Turbo", "Z-Image / DiT Core"), outputs=[source_repo, arch_profile])
200
+ preset_sd3.click(lambda: ("stabilityai/stable-diffusion-3.5-large", "Stable Diffusion (SDXL/SD3)"), outputs=[source_repo, arch_profile])
201
+
202
+ # Dynamic Updates
203
+ source_repo.change(fn=generate_target_repo, inputs=[source_repo, precision], outputs=[target_repo])
204
+ precision.change(fn=generate_target_repo, inputs=[source_repo, precision], outputs=[target_repo])
205
+ precision.change(fn=toggle_int8_warning, inputs=[precision], outputs=[int8_warning])
206
+
207
+ # Execution
208
+ start_btn.click(
209
+ fn=convert_and_upload,
210
+ inputs=[hf_token, source_repo, target_repo, precision, target_components, arch_profile],
211
+ outputs=[output_log]
212
+ )
213
+
214
+ if __name__ == "__main__":
215
+ demo.launch()