| import os |
| import gc |
| import torch |
| import shutil |
| import uuid |
| import gradio as gr |
| from huggingface_hub import HfApi, hf_hub_download |
| from safetensors.torch import load_file, save_file |
|
|
def _resolve_precision(precision):
    """Map a UI precision label to ``(target_dtype, is_int8)``.

    Returns ``None`` for an unrecognised label so the caller can report it
    up front instead of proceeding with no target dtype.
    """
    if precision == "INT8":
        return None, True
    if precision == "FP8":
        # NOTE(review): floats are cast without clamping; values outside the
        # float8_e4m3fn range may not survive the cast — confirm acceptable.
        return torch.float8_e4m3fn, False
    if precision == "FP16":
        return torch.float16, False
    if precision == "BF16":
        return torch.bfloat16, False
    return None


def _in_target_component(path, components):
    """Return True if any directory segment of ``path`` equals a selected component.

    Segment equality (rather than a substring test on ``"<comp>/"``) avoids
    false matches such as ``"my_vae/..."`` for the ``vae`` component.
    """
    directories = path.split("/")[:-1]
    return any(comp in directories for comp in components)


def _is_int8_candidate(key, tensor):
    """Return True if ``tensor`` should be stored as row-wise INT8.

    Only 2-D floating-point tensors whose key ends in ``.weight`` qualify. A
    plain ``"weight" in key`` test would also catch e.g. ``x.weight_g`` and
    ``x.weight_v``, which would both be written to the same ``x.weight_int8``.
    Keys containing any of the substrings below (matched anywhere in the key)
    are left in bf16.
    """
    excluded = (
        "t_embedder", "cap_embedder", "all_x_embedder", "all_final_layer", "rope_embedder",
        "embed_tokens", "norm", "ln_", "shared",
    )
    return (
        key.endswith(".weight")
        and tensor.ndim == 2
        and tensor.is_floating_point()
        and not any(sub in key for sub in excluded)
    )


def _quantize_rows_int8(weight):
    """Symmetric per-row INT8 quantization of a 2-D float tensor.

    Returns ``(weight_int8, scale)``. ``weight_int8`` is int8 with the input's
    shape. ``scale`` is bf16 with shape ``(rows, 1)``, so that
    ``weight_int8 * scale`` approximates ``weight``.
    """
    # fp32 math: doing v/scale in bf16 (8-bit significand) can round values
    # near +/-127 to the wrong integer level.
    w = weight.to(torch.float32)
    scale = (w.abs().amax(dim=1, keepdim=True) / 127.0).clamp(min=1e-8)
    # Quantize against the scale exactly as it will be stored (bf16), so the
    # int8 values match the stored scale.
    scale = scale.to(torch.bfloat16).to(torch.float32)
    weight_int8 = torch.round(w / scale).clamp(-127, 127).to(torch.int8)
    return weight_int8, scale.to(torch.bfloat16)


def _convert_state_dict(tensors, target_dtype, is_int8):
    """Return a new name->tensor dict converted to the requested precision.

    Float modes cast every floating-point tensor to ``target_dtype``.

    INT8 mode replaces each qualifying ``<base>.weight`` with
    ``<base>.weight_int8`` and ``<base>.weight_scale``, and casts all other
    floating-point tensors to bf16.

    In both modes, non-float tensors pass through unchanged.
    """
    new_tensors = {}
    for key, tensor in tensors.items():
        if not is_int8:
            new_tensors[key] = tensor.to(target_dtype) if tensor.is_floating_point() else tensor
        elif _is_int8_candidate(key, tensor):
            base_name = key.rsplit(".", 1)[0]
            weight_int8, scale = _quantize_rows_int8(tensor)
            new_tensors[f"{base_name}.weight_int8"] = weight_int8
            new_tensors[f"{base_name}.weight_scale"] = scale
        else:
            new_tensors[key] = tensor.to(torch.bfloat16) if tensor.is_floating_point() else tensor
    return new_tensors


def _convert_safetensors(src_path, dst_path, target_dtype, is_int8):
    """Load ``src_path``, convert its tensors, and write them to ``dst_path``."""
    tensors = load_file(src_path)
    new_tensors = _convert_state_dict(tensors, target_dtype, is_int8)
    # Release the source tensors before serializing so tensors that were
    # converted away can be freed.
    del tensors
    save_file(new_tensors, dst_path)
    del new_tensors
    gc.collect()


def convert_and_upload(token, source_repo, target_repo, precision, target_components):
    """Copy ``source_repo`` to ``target_repo``, converting selected weights.

    This is a generator that yields human-readable status lines for the UI log.

    Args:
        token: Hugging Face token; it must allow creating and writing to
            ``target_repo``.
        source_repo: Repo id to read files from.
        target_repo: Repo id to write to. It is created if missing, and the
            ``your-username`` placeholder is rejected.
        precision: One of ``"FP8"``, ``"FP16"``, ``"BF16"``, ``"INT8"``.
        target_components: Directory names whose ``.safetensors`` files are
            converted. All other files are copied unchanged.

    Root-level ``.safetensors`` files are not copied. Any copy of them in the
    target repo is deleted on a best-effort basis. Per-file errors are
    reported and counted, and do not stop the run.
    """
    source_repo = (source_repo or "").strip()
    target_repo = (target_repo or "").strip()

    if not token:
        yield "❌ Error: Please provide a valid Hugging Face Write Token."
        return
    if not target_repo or "your-username" in target_repo:
        yield "❌ Error: Please specify a valid Target Repository (e.g., your-username/repo-name)."
        return
    if not target_components:
        yield "❌ Error: Please select at least one component to quantize."
        return
    if not source_repo:
        yield "❌ Error: Please specify a Source Repository."
        return

    resolved = _resolve_precision(precision)
    if resolved is None:
        yield f"❌ Error: Unsupported precision '{precision}'. Choose one of FP8, FP16, BF16, INT8."
        return
    target_dtype, is_int8 = resolved

    api = HfApi(token=token)
    yield f"🔄 Connecting to Hugging Face and verifying target repo: {target_repo}..."
    try:
        api.create_repo(repo_id=target_repo, exist_ok=True, private=False)
    except Exception as e:
        yield f"❌ Error checking/creating repo: {str(e)}\nMake sure your token has 'Write' permissions."
        return

    yield f"📋 Fetching file list from {source_repo}..."
    try:
        files = api.list_repo_files(source_repo)
    except Exception as e:
        yield f"❌ Error fetching files: {str(e)}"
        return

    # Unique per run so concurrent sessions never share scratch space.
    cache_dir = f"./hf_cache_{uuid.uuid4().hex[:8]}"
    # Kept inside cache_dir rather than as a fixed name in the CWD. This stops
    # concurrent runs from overwriting each other's output, and the file is
    # removed together with the cache even if the upload fails.
    converted_path = os.path.join(cache_dir, "converted.safetensors")
    success_count = 0
    error_count = 0

    try:
        for file in files:
            if "/" not in file and file.endswith(".safetensors"):
                yield f"🗑️ Auto-skipping root model: {file}..."
                try:
                    api.delete_file(path_in_repo=file, repo_id=target_repo, token=token, commit_message=f"Auto-deleted {file}")
                except Exception:
                    # Deliberate best effort: presumably the file is usually
                    # just absent from the target repo.
                    pass
                continue

            yield f"⏳ Processing {file}..."
            try:
                os.makedirs(cache_dir, exist_ok=True)
                local_path = hf_hub_download(
                    repo_id=source_repo,
                    filename=file,
                    cache_dir=cache_dir,
                    token=token,
                )

                if file.endswith(".safetensors") and _in_target_component(file, target_components):
                    yield f"🧠 Quantizing {file} to {precision}..."
                    _convert_safetensors(local_path, converted_path, target_dtype, is_int8)

                    yield f"☁️ Uploading {precision} version of {file}..."
                    api.upload_file(
                        path_or_fileobj=converted_path,
                        path_in_repo=file,
                        repo_id=target_repo,
                        commit_message=f"Upload {precision} quantized {file}",
                    )
                else:
                    yield f"☁️ Copying {file} as-is..."
                    api.upload_file(
                        path_or_fileobj=local_path,
                        path_in_repo=file,
                        repo_id=target_repo,
                        commit_message=f"Copy {file} from original repo",
                    )

                success_count += 1
            except Exception as e:
                error_count += 1
                yield f"⚠️ Error processing {file}: {str(e)}\nSkipping..."
            finally:
                # Clean up after every file, failures included, so one bad file
                # cannot leave downloads on disk for the rest of the run.
                shutil.rmtree(cache_dir, ignore_errors=True)
                gc.collect()
    finally:
        # Also runs if the generator is closed before it finishes.
        shutil.rmtree(cache_dir, ignore_errors=True)

    yield f"✅ Finished! Successfully processed {success_count} files. Errors encountered: {error_count}."
|
|
|
|
def update_target_repo(username, source, precision):
    """Suggest a target repo id of the form ``<user>/<model>-<precision>``.

    Args:
        username: Hugging Face username. If it is empty or None, the
            ``your-username`` placeholder is used, which
            ``convert_and_upload`` rejects, so the user must fill it in.
        source: Source repo id. Only its last path segment is used as the
            model name. Surrounding whitespace and trailing slashes are
            ignored, and None is treated as empty.
        precision: Precision label appended as a suffix.

    Returns:
        The suggested repo id string.
    """
    user_prefix = (username or "").strip() or "your-username"
    model_name = (source or "").strip().rstrip("/").split("/")[-1]
    return f"{user_prefix}/{model_name}-{precision}"
|
|
def update_warnings(precision):
    """Return a Gradio update that shows the INT8 caveat only for INT8.

    For any other precision the warning component is hidden.
    """
    if precision != "INT8":
        return gr.update(visible=False)
    int8_notice = "⚠️ **INT8 Warning:** Modifies layer keys (`weight_int8`, `weight_scale`). Requires the custom `NativeInt8Linear` XPU inference code to run."
    return gr.update(value=int8_notice, visible=True)
|
|
# Gradio UI: a settings column on the left and the operation log on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Universal Z-Image Quantizer")
    gr.Markdown(
        "Convert sharded Z-Image models directly on Hugging Face to floating-point precisions (FP8/FP16/BF16) or dynamically trigger symmetric integer quantization (INT8)."
    )

    with gr.Row():
        # Left column: all inputs plus the start button.
        with gr.Column(scale=2):
            hf_token = gr.Textbox(label="Hugging Face Token (Write Access)", type="password", placeholder="hf_...")
            hf_username = gr.Textbox(label="Hugging Face Username", placeholder="e.g., rootlocalghost")

            # allow_custom_value lets the user type any repo id, not only the presets.
            source_repo = gr.Dropdown(
                choices=["your-username/Z-Image-Turbo", "your-username/Z-Image-Base"],
                value="your-username/Z-Image-Turbo",
                label="Source Repository",
                allow_custom_value=True
            )

            # Directories whose .safetensors files are converted; other files are copied as-is.
            target_components = gr.CheckboxGroup(
                choices=["text_encoder", "transformer", "vae"],
                value=["transformer"],
                label="Components to Quantize"
            )

            precision = gr.Dropdown(
                choices=["FP8", "FP16", "BF16", "INT8"],
                value="INT8",
                label="Target Precision"
            )

            # Visible at startup because the default precision is INT8; later toggled
            # by update_warnings. NOTE: this text duplicates the message in
            # update_warnings — keep the two in sync.
            int8_warning = gr.Markdown(visible=True, value="⚠️ **INT8 Warning:** Modifies layer keys (`weight_int8`, `weight_scale`). Requires the custom `NativeInt8Linear` XPU inference code to run.")

            # Initial value matches the defaults above; regenerated by update_target_repo
            # whenever an input changes, but still user-editable.
            target_repo = gr.Textbox(label="Target Repository (Auto-generated)", value="your-username/Z-Image-Turbo-INT8", interactive=True)
            start_btn = gr.Button("Start Quantization & Upload", variant="primary")

        # Right column: read-only status log fed by convert_and_upload.
        with gr.Column(scale=3):
            output_log = gr.Textbox(label="Operation Logs", lines=20, interactive=False, max_lines=25)

    # A change to any of these inputs recomputes the suggested target repo name.
    inputs_to_watch = [hf_username, source_repo, precision]
    for inp in inputs_to_watch:
        inp.change(fn=update_target_repo, inputs=inputs_to_watch, outputs=[target_repo])

    precision.change(fn=update_warnings, inputs=[precision], outputs=[int8_warning])

    # convert_and_upload is a generator; each status string it yields is displayed in output_log.
    start_btn.click(
        fn=convert_and_upload,
        inputs=[hf_token, source_repo, target_repo, precision, target_components],
        outputs=[output_log]
    )
|
|
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()