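# app.py -- Gradio app that quantizes Z-Image checkpoints and uploads them to the Hugging Face Hub.
# Captured from the Hub file viewer: author "rootlocalghost", commit "Update app.py" (17e0b73, verified), 8.74 kB.
# (Viewer toolbar labels "Raw / History / Blame / Contribute / Delete" were page chrome, not file content.)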
import os
import gc
import torch
import shutil
import uuid
import gradio as gr
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file
# Z-IMAGE SPECIFIC EXCLUSIONS
# Substrings of tensor names that are never INT8-quantized. Protects the DiT's
# embedders/final layers and the Text Encoder's sensitive norms from INT8
# destruction. Matching is by substring anywhere in the key, not by prefix.
_INT8_EXCLUDED_KEY_PARTS = (
    "t_embedder", "cap_embedder", "all_x_embedder", "all_final_layer", "rope_embedder",
    "embed_tokens", "norm", "ln_", "shared",
)

# Precision label -> name of the torch dtype attribute. Resolved lazily with
# getattr so that an unsupported label never touches torch at all.
_FLOAT_DTYPE_NAMES = {"FP8": "float8_e4m3fn", "FP16": "float16", "BF16": "bfloat16"}


def _resolve_precision(precision):
    """Return ``(target_dtype, is_int8)`` for a precision label, or ``None`` if unsupported."""
    if precision == "INT8":
        return None, True
    dtype_name = _FLOAT_DTYPE_NAMES.get(precision)
    if dtype_name is None:
        return None
    return getattr(torch, dtype_name), False


def _in_selected_component(repo_path, components):
    """Return True when one of the directories in *repo_path* is a selected component."""
    # Compare whole directory names so "my_transformer/x" does not match "transformer".
    directories = repo_path.split("/")[:-1]
    return any(component in directories for component in components)


def _quantize_tensors_int8(tensors):
    """Return a new state dict with eligible 2-D weights symmetrically quantized to INT8.

    An eligible tensor ``<base>.<leaf>`` is replaced by ``<base>.weight_int8``
    (int8 values) and ``<base>.weight_scale`` (one BF16 scale per row). Every
    other floating-point tensor is cast to BF16; non-float tensors pass through.
    """
    quantized = {}
    for key, tensor in tensors.items():
        eligible = (
            "weight" in key
            and tensor.dim() == 2
            and tensor.is_floating_point()  # never re-quantize integer tensors
            and tensor.numel() > 0          # a max over an empty row would raise
            and not any(part in key for part in _INT8_EXCLUDED_KEY_PARTS)
        )
        if not eligible:
            quantized[key] = tensor.to(torch.bfloat16) if tensor.is_floating_point() else tensor
            continue

        # Do the math in FP32: BF16 division rounds quotients to 8 significant
        # bits (off-by-one int8 values), FP16 turns the 1e-8 floor into 0
        # (NaN on all-zero rows), and FP8 does not support these ops at all.
        values = tensor.to(torch.float32)
        scale = (values.abs().max(dim=1, keepdim=True).values / 127.0).clamp(min=1e-8)
        # Round the scale to its stored dtype *before* dividing so the int8
        # values are computed with exactly the scale that is written out.
        scale = scale.to(torch.bfloat16)
        weight_int8 = (
            torch.round(values / scale.to(torch.float32)).clamp(-127, 127).to(torch.int8)
        )

        base_name = key.rsplit(".", 1)[0]
        quantized[f"{base_name}.weight_int8"] = weight_int8
        quantized[f"{base_name}.weight_scale"] = scale
    return quantized


def _cast_floating_tensors(tensors, target_dtype):
    """Return a new state dict with every floating-point tensor cast to *target_dtype*."""
    return {
        key: tensor.to(target_dtype) if tensor.is_floating_point() else tensor
        for key, tensor in tensors.items()
    }


def _write_converted_checkpoint(local_path, converted_path, target_dtype, is_int8):
    """Load *local_path*, convert its tensors, and save them to *converted_path*."""
    # Kept in its own function so both state dicts are released as soon as it
    # returns (or raises), before the upload starts.
    tensors = load_file(local_path)
    if is_int8:
        converted = _quantize_tensors_int8(tensors)
    else:
        converted = _cast_floating_tensors(tensors, target_dtype)
    # Some loaders refuse safetensors files without the "format" metadata entry.
    save_file(converted, converted_path, metadata={"format": "pt"})


def convert_and_upload(token, source_repo, target_repo, precision, target_components):
    """Copy a Hub repo file by file, converting the selected components on the way.

    Every file of *source_repo* is downloaded and uploaded to *target_repo*.
    ``.safetensors`` files located in one of the *target_components*
    directories are converted to *precision* first; all other files are copied
    unchanged. ``.safetensors`` files at the repo root are skipped, and a
    best-effort delete of the same path in the target repo is attempted.

    Args:
        token: Hugging Face access token with write permission.
        source_repo: Repo id to read from.
        target_repo: Repo id to write to; created (public) if missing.
        precision: One of ``"FP8"``, ``"FP16"``, ``"BF16"``, ``"INT8"``.
        target_components: Directory names whose safetensors get converted.

    Yields:
        The cumulative operation log as one newline-joined string. A Gradio
        output is replaced by each yielded value, so yielding the whole log
        keeps earlier per-file errors visible instead of overwriting them.
    """
    log_lines = []

    def log(message):
        log_lines.append(message)
        return "\n".join(log_lines)

    # ---- Input validation (nothing touches the network before this passes) ----
    if not token or not token.strip():
        yield log("❌ Error: Please provide a valid Hugging Face Write Token.")
        return
    if not target_repo or not target_repo.strip() or "your-username" in target_repo:
        yield log("❌ Error: Please specify a valid Target Repository (e.g., your-username/repo-name).")
        return
    if not source_repo or not source_repo.strip():
        yield log("❌ Error: Please specify a Source Repository.")
        return
    if not target_components:
        yield log("❌ Error: Please select at least one component to quantize.")
        return
    resolved = _resolve_precision(precision)
    if resolved is None:
        yield log(f"❌ Error: Unsupported precision '{precision}'. Choose FP8, FP16, BF16 or INT8.")
        return
    target_dtype, is_int8 = resolved

    token = token.strip()
    source_repo = source_repo.strip()
    target_repo = target_repo.strip()

    api = HfApi(token=token)
    yield log(f"🔄 Connecting to Hugging Face and verifying target repo: {target_repo}...")
    try:
        api.create_repo(repo_id=target_repo, exist_ok=True, private=False)
    except Exception as exc:
        yield log(f"❌ Error checking/creating repo: {str(exc)}\nMake sure your token has 'Write' permissions.")
        return

    yield log(f"📋 Fetching file list from {source_repo}...")
    try:
        files = api.list_repo_files(source_repo)
    except Exception as exc:
        yield log(f"❌ Error fetching files: {str(exc)}")
        return

    # Unique per call, so concurrent sessions never share download or scratch files.
    cache_dir = f"./hf_cache_{uuid.uuid4().hex[:8]}"
    success_count = 0
    error_count = 0

    for file in files:
        if "/" not in file and file.endswith(".safetensors"):
            yield log(f"🗑️ Auto-skipping root model: {file}...")
            try:
                api.delete_file(
                    path_in_repo=file,
                    repo_id=target_repo,
                    token=token,
                    commit_message=f"Auto-deleted {file}",
                )
            except Exception:
                # Deliberate best effort: the file normally does not exist in
                # the target repo, and failing to delete it must not stop the run.
                pass
            continue

        yield log(f"⏳ Processing {file}...")
        try:
            os.makedirs(cache_dir, exist_ok=True)
            local_path = hf_hub_download(
                repo_id=source_repo,
                filename=file,
                cache_dir=cache_dir,
                token=token,
            )

            if file.endswith(".safetensors") and _in_selected_component(file, target_components):
                yield log(f"🧠 Quantizing {file} to {precision}...")
                # Scratch file lives inside the per-call cache dir, so it is
                # removed by the cleanup below even when the upload fails.
                converted_path = os.path.join(cache_dir, "converted.safetensors")
                _write_converted_checkpoint(local_path, converted_path, target_dtype, is_int8)
                gc.collect()

                yield log(f"☁️ Uploading {precision} version of {file}...")
                api.upload_file(
                    path_or_fileobj=converted_path,
                    path_in_repo=file,
                    repo_id=target_repo,
                    commit_message=f"Upload {precision} quantized {file}",
                )
            else:
                yield log(f"☁️ Copying {file} as-is...")
                api.upload_file(
                    path_or_fileobj=local_path,
                    path_in_repo=file,
                    repo_id=target_repo,
                    commit_message=f"Copy {file} from original repo",
                )
            success_count += 1
        except Exception as exc:
            error_count += 1
            yield log(f"⚠️ Error processing {file}: {str(exc)}\nSkipping...")
        finally:
            # Runs on success, on failure, and when the generator is closed
            # mid-file, so disk usage stays bounded to one file at a time.
            shutil.rmtree(cache_dir, ignore_errors=True)
            gc.collect()

    yield log(f"✅ Finished! Successfully processed {success_count} files. Errors encountered: {error_count}.")
def update_target_repo(username, source, precision):
    """Build the suggested target repo id ``<user>/<model>-<precision>``.

    Args:
        username: Hub username; blank or ``None`` falls back to the
            ``"your-username"`` placeholder.
        source: Source repo id or bare model name; only the last path
            segment is used. ``None`` is treated as an empty name.
        precision: Precision label appended as a suffix.

    Returns:
        The generated repo id string.
    """
    # Gradio passes None for a cleared dropdown, so never call str methods on
    # the raw values.
    user_prefix = (username or "").strip() or "your-username"
    # Drop surrounding whitespace and a trailing slash so they cannot leak
    # into the generated repo id, then keep only the final path segment.
    model_name = (source or "").strip().rstrip("/").rsplit("/", 1)[-1]
    return f"{user_prefix}/{model_name}-{precision}"
def update_warnings(precision):
    """Return a Gradio update that shows the INT8 notice only for ``"INT8"``.

    For any other precision the component is hidden and its text is left
    untouched.
    """
    if precision != "INT8":
        return gr.update(visible=False)
    return gr.update(
        value="⚠️ **INT8 Warning:** Modifies layer keys (`weight_int8`, `weight_scale`). Requires the custom `NativeInt8Linear` XPU inference code to run.",
        visible=True,
    )
# ---- Gradio UI: settings in the left column, operation log in the right one ----
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Universal Z-Image Quantizer")
    gr.Markdown(
        "Convert sharded Z-Image models directly on Hugging Face to floating-point precisions (FP8/FP16/BF16) or dynamically trigger symmetric integer quantization (INT8)."
    )

    with gr.Row():
        with gr.Column(scale=2):
            hf_token = gr.Textbox(
                label="Hugging Face Token (Write Access)",
                type="password",
                placeholder="hf_...",
            )
            hf_username = gr.Textbox(
                label="Hugging Face Username",
                placeholder="e.g., rootlocalghost",
            )
            source_repo = gr.Dropdown(
                choices=["your-username/Z-Image-Turbo", "your-username/Z-Image-Base"],
                value="your-username/Z-Image-Turbo",
                label="Source Repository",
                allow_custom_value=True,
            )
            target_components = gr.CheckboxGroup(
                choices=["text_encoder", "transformer", "vae"],
                value=["transformer"],
                label="Components to Quantize",
            )
            precision = gr.Dropdown(
                choices=["FP8", "FP16", "BF16", "INT8"],
                value="INT8",
                label="Target Precision",
            )
            # Visible at start-up because the default precision above is INT8.
            int8_warning = gr.Markdown(
                visible=True,
                value="⚠️ **INT8 Warning:** Modifies layer keys (`weight_int8`, `weight_scale`). Requires the custom `NativeInt8Linear` XPU inference code to run.",
            )
            target_repo = gr.Textbox(
                label="Target Repository (Auto-generated)",
                value="your-username/Z-Image-Turbo-INT8",
                interactive=True,
            )
            start_btn = gr.Button("Start Quantization & Upload", variant="primary")

        with gr.Column(scale=3):
            output_log = gr.Textbox(
                label="Operation Logs",
                lines=20,
                interactive=False,
                max_lines=25,
            )

    # Regenerate the suggested target repo id whenever any of its parts changes.
    repo_name_inputs = [hf_username, source_repo, precision]
    for watched in repo_name_inputs:
        watched.change(
            fn=update_target_repo,
            inputs=repo_name_inputs,
            outputs=[target_repo],
        )

    # Toggle the INT8 notice together with the precision selection.
    precision.change(fn=update_warnings, inputs=[precision], outputs=[int8_warning])

    start_btn.click(
        fn=convert_and_upload,
        inputs=[hf_token, source_repo, target_repo, precision, target_components],
        outputs=[output_log],
    )

if __name__ == "__main__":
    demo.launch()