import gc
import os
import shutil
import uuid

import gradio as gr
import torch
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file

# Z-IMAGE SPECIFIC EXCLUSIONS
# Protects the DiT's embedders/final layers and the Text Encoder's sensitive norms
# from INT8 destruction. These are matched as *substrings* anywhere in a tensor key
# (e.g. "norm" also matches "layernorm"), not only as key prefixes.
INT8_EXCLUDE_PATTERNS = (
    "t_embedder",
    "cap_embedder",
    "all_x_embedder",
    "all_final_layer",
    "rope_embedder",
    "embed_tokens",
    "norm",
    "ln_",
    "shared",
)

# UI precision label -> torch dtype for the plain floating-point casting path.
# "INT8" is handled separately by symmetric quantization.
FLOAT_PRECISIONS = {
    "FP8": torch.float8_e4m3fn,
    "FP16": torch.float16,
    "BF16": torch.bfloat16,
}

INT8_WARNING = (
    "⚠️ **INT8 Warning:** Modifies layer keys (`weight_int8`, `weight_scale`). "
    "Requires the custom `NativeInt8Linear` XPU inference code to run."
)


def _in_target_component(file, target_components):
    """Return True if one of the directories in repo path *file* is a selected component."""
    # Compare whole path segments so "transformer" does not also match "my_transformer/".
    parent_dirs = file.split("/")[:-1]
    return any(comp in parent_dirs for comp in target_components)


def _quantize_int8_weight(weight):
    """Symmetric per-row INT8 quantization of a 2D floating-point weight.

    Returns ``(weight_int8, scale)``. ``scale`` has shape ``(rows, 1)`` and dtype
    bfloat16, and ``weight ≈ weight_int8 * scale``.
    """
    # Do the math in float32: bf16/fp16 have too few significand bits and can push
    # weight / scale across a rounding boundary before torch.round.
    work = weight.to(torch.float32)
    scale = work.abs().amax(dim=1, keepdim=True) / 127.0
    # Floor the scale so all-zero rows don't divide by zero. Round it to the storage
    # dtype *before* quantizing, so dequantizing with the stored scale matches the
    # values the rounding was computed against.
    scale = scale.clamp(min=1e-8).to(torch.bfloat16)
    weight_int8 = torch.round(work / scale.to(torch.float32)).clamp(-127, 127).to(torch.int8)
    return weight_int8, scale


def _cast_float_tensor(tensor, target_dtype):
    """Cast a floating-point tensor to *target_dtype*, saturating values out of range.

    Without saturation, overflow yields NaN for float8_e4m3fn and inf for float16.
    """
    target_info = torch.finfo(target_dtype)
    if target_info.max >= torch.finfo(tensor.dtype).max:
        return tensor.to(target_dtype)
    # Clamp in float32 (or float64) so the bounds themselves are exactly
    # representable; e.g. 65504 (FP16 max) would round up to 65536 in bf16.
    work = tensor if tensor.dtype == torch.float64 else tensor.to(torch.float32)
    return work.clamp(min=target_info.min, max=target_info.max).to(target_dtype)


def _convert_tensors(tensors, target_dtype, is_int8):
    """Return a new state dict with every tensor converted for the requested precision.

    INT8 mode: each non-excluded, non-empty 2D floating-point ``<name>.weight`` is
    replaced by ``<name>.weight_int8`` and ``<name>.weight_scale``. Every other
    floating-point tensor is cast to bfloat16, and non-float tensors pass through.
    Float mode: floating-point tensors are cast (saturating) to *target_dtype*, and
    non-float tensors pass through.
    """
    converted = {}
    for key, value in tensors.items():
        # --- BRANCH 1: INT8 Symmetric Quantization ---
        if is_int8:
            quantizable = (
                # Match the exact ".weight" suffix. A substring test would also catch
                # e.g. "foo.weight_g" / "foo.weight_v", whose outputs would collide
                # on "foo.weight_int8".
                key.endswith(".weight")
                and value.ndim == 2
                and value.is_floating_point()
                and value.numel() > 0
                and not any(pattern in key for pattern in INT8_EXCLUDE_PATTERNS)
            )
            if quantizable:
                base_name = key[: -len(".weight")]
                weight_int8, scale = _quantize_int8_weight(value)
                converted[f"{base_name}.weight_int8"] = weight_int8
                converted[f"{base_name}.weight_scale"] = scale
            elif value.is_floating_point():
                converted[key] = value.to(torch.bfloat16)
            else:
                converted[key] = value
        # --- BRANCH 2: Standard Floating Point Casting ---
        elif value.is_floating_point():
            converted[key] = _cast_float_tensor(value, target_dtype)
        else:
            converted[key] = value
    return converted


def convert_and_upload(token, source_repo, target_repo, precision, target_components):
    """Copy *source_repo* into *target_repo*, converting selected component weights.

    Root-level ``*.safetensors`` files are skipped, and a best-effort delete of the
    same path is attempted in the target repo. ``.safetensors`` files inside a
    selected component directory are converted to *precision*. Every other file is
    copied unchanged.

    NOTE(review): in INT8 mode, ``*.safetensors.index.json`` files are copied as-is,
    so their ``weight_map`` still lists the original ``.weight`` keys for quantized
    layers. Confirm the inference loader tolerates that.

    Args:
        token: Hugging Face token with write access.
        source_repo: Repo id to read files from.
        target_repo: Repo id to create (public, if missing) and upload into.
        precision: One of "FP8", "FP16", "BF16", "INT8".
        target_components: Component directory names (e.g. "transformer") to convert.

    Yields:
        Human-readable status strings for the Gradio log box.
    """
    token = (token or "").strip()
    source_repo = (source_repo or "").strip()
    target_repo = (target_repo or "").strip()

    if not token:
        yield "❌ Error: Please provide a valid Hugging Face Write Token."
        return
    if not target_repo or "your-username" in target_repo:
        yield "❌ Error: Please specify a valid Target Repository (e.g., your-username/repo-name)."
        return
    if not target_components:
        yield "❌ Error: Please select at least one component to quantize."
        return

    if precision == "INT8":
        target_dtype, is_int8 = None, True
    elif precision in FLOAT_PRECISIONS:
        target_dtype, is_int8 = FLOAT_PRECISIONS[precision], False
    else:
        yield f"❌ Error: Unsupported precision: {precision!r}."
        return

    api = HfApi(token=token)

    yield f"🔄 Connecting to Hugging Face and verifying target repo: {target_repo}..."
    try:
        api.create_repo(repo_id=target_repo, exist_ok=True, private=False)
    except Exception as e:
        yield f"❌ Error checking/creating repo: {str(e)}\nMake sure your token has 'Write' permissions."
        return

    yield f"📋 Fetching file list from {source_repo}..."
    try:
        files = api.list_repo_files(source_repo)
    except Exception as e:
        yield f"❌ Error fetching files: {str(e)}"
        return

    # Unique per call, so concurrent Gradio sessions never share downloads or
    # overwrite each other's converted shard.
    cache_dir = f"./hf_cache_{uuid.uuid4().hex[:8]}"
    converted_path = os.path.join(cache_dir, "converted.safetensors")
    success_count = 0
    error_count = 0

    try:
        for file in files:
            is_root_safetensor = "/" not in file and file.endswith(".safetensors")
            if is_root_safetensor:
                yield f"🗑️ Auto-skipping root model: {file}..."
                try:
                    api.delete_file(path_in_repo=file, repo_id=target_repo, token=token, commit_message=f"Auto-deleted {file}")
                except Exception:
                    # Deliberate best-effort: usually the file is simply not present in
                    # the target repo, and that must not abort the run.
                    pass
                continue

            yield f"⏳ Processing {file}..."
            try:
                os.makedirs(cache_dir, exist_ok=True)
                local_path = hf_hub_download(
                    repo_id=source_repo,
                    filename=file,
                    cache_dir=cache_dir,
                    token=token,
                )

                if file.endswith(".safetensors") and _in_target_component(file, target_components):
                    yield f"🧠 Quantizing {file} to {precision}..."
                    tensors = load_file(local_path)
                    new_tensors = _convert_tensors(tensors, target_dtype, is_int8)
                    del tensors
                    save_file(new_tensors, converted_path)
                    del new_tensors
                    gc.collect()

                    yield f"☁️ Uploading {precision} version of {file}..."
                    api.upload_file(
                        path_or_fileobj=converted_path,
                        path_in_repo=file,
                        repo_id=target_repo,
                        commit_message=f"Upload {precision} quantized {file}",
                    )
                else:
                    yield f"☁️ Copying {file} as-is..."
                    api.upload_file(
                        path_or_fileobj=local_path,
                        path_in_repo=file,
                        repo_id=target_repo,
                        commit_message=f"Copy {file} from original repo",
                    )

                success_count += 1
            except Exception as e:
                error_count += 1
                yield f"⚠️ Error processing {file}: {str(e)}\nSkipping..."
            finally:
                # Free disk space after every file, including failed ones.
                shutil.rmtree(cache_dir, ignore_errors=True)
                gc.collect()
    finally:
        # Also runs if Gradio cancels/closes the generator mid-run.
        shutil.rmtree(cache_dir, ignore_errors=True)

    yield f"✅ Finished! Successfully processed {success_count} files. Errors encountered: {error_count}."


def update_target_repo(username, source, precision):
    """Suggest a target repo id: "<username>/<source model name>-<precision>"."""
    user_prefix = (username or "").strip() or "your-username"
    # The last path segment is the model name; without a "/" this is the whole string.
    model_name = (source or "").split("/")[-1]
    return f"{user_prefix}/{model_name}-{precision}"


def update_warnings(precision):
    """Show the INT8 key-format warning only when INT8 is selected."""
    if precision == "INT8":
        return gr.update(value=INT8_WARNING, visible=True)
    return gr.update(visible=False)


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Universal Z-Image Quantizer")
    gr.Markdown(
        "Convert sharded Z-Image models directly on Hugging Face to floating-point precisions (FP8/FP16/BF16) or dynamically trigger symmetric integer quantization (INT8)."
    )

    with gr.Row():
        with gr.Column(scale=2):
            hf_token = gr.Textbox(label="Hugging Face Token (Write Access)", type="password", placeholder="hf_...")
            hf_username = gr.Textbox(label="Hugging Face Username", placeholder="e.g., rootlocalghost")

            source_repo = gr.Dropdown(
                choices=["your-username/Z-Image-Turbo", "your-username/Z-Image-Base"],
                value="your-username/Z-Image-Turbo",
                label="Source Repository",
                allow_custom_value=True,
            )

            target_components = gr.CheckboxGroup(
                choices=["text_encoder", "transformer", "vae"],
                value=["transformer"],
                label="Components to Quantize",
            )

            precision = gr.Dropdown(
                choices=["FP8", "FP16", "BF16", "INT8"],
                value="INT8",
                label="Target Precision",
            )

            # Visible initially because the default precision above is INT8.
            int8_warning = gr.Markdown(visible=True, value=INT8_WARNING)

            target_repo = gr.Textbox(label="Target Repository (Auto-generated)", value="your-username/Z-Image-Turbo-INT8", interactive=True)
            start_btn = gr.Button("Start Quantization & Upload", variant="primary")

        with gr.Column(scale=3):
            output_log = gr.Textbox(label="Operation Logs", lines=20, interactive=False, max_lines=25)

    inputs_to_watch = [hf_username, source_repo, precision]
    for inp in inputs_to_watch:
        inp.change(fn=update_target_repo, inputs=inputs_to_watch, outputs=[target_repo])

    precision.change(fn=update_warnings, inputs=[precision], outputs=[int8_warning])

    start_btn.click(
        fn=convert_and_upload,
        inputs=[hf_token, source_repo, target_repo, precision, target_components],
        outputs=[output_log],
    )

if __name__ == "__main__":
    demo.launch()