import gc
import glob
import json
import os

import gradio as gr
import torch
import torch.nn as nn
from huggingface_hub import HfApi, snapshot_download
from safetensors.torch import load_file
from torchao.quantization import Int8WeightOnlyConfig, quantize_

_WEIGHT_SUFFIX = ".weight"
_BIAS_SUFFIX = ".bias"
_SAFETENSORS_EXT = ".safetensors"
_PT_EXT = ".pt"
_TRANSFORMER_SUBDIR = "transformer"
_SAFETENSORS_INDEX = "diffusion_pytorch_model.safetensors.index.json"
_PT_INDEX = "diffusion_pytorch_model.pt.index.json"


def _bias_key_for(weight_key):
    """Return the bias key paired with a key that ends in ``.weight``.

    Only the trailing suffix is swapped, so a ``.weight`` substring occurring
    earlier in the key is left untouched.
    """
    return weight_key[: -len(_WEIGHT_SUFFIX)] + _BIAS_SUFFIX


def _to_pt_name(name):
    """Swap a trailing ``.safetensors`` extension for ``.pt``.

    Names that do not end in ``.safetensors`` are returned unchanged.
    """
    if name.endswith(_SAFETENSORS_EXT):
        return name[: -len(_SAFETENSORS_EXT)] + _PT_EXT
    return name


def _quantize_shard(shard_dict):
    """Return a new state dict with every 2D ``*.weight`` tensor INT8-quantized.

    Each 2D tensor whose key ends in ``.weight`` is treated as the weight of a
    linear layer: it is wrapped in a throwaway ``nn.Linear`` (together with its
    bias, when the same shard holds one), quantized in place with torchao, and
    the resulting tensors are copied back under the original keys. Every other
    tensor passes through untouched.

    NOTE(review): any 2D ``*.weight`` tensor is assumed to belong to a linear
    layer; a 2D embedding table would be quantized as well — confirm this is
    acceptable for the target checkpoints.
    """
    quantized = {}
    consumed_biases = set()

    for key, tensor in shard_dict.items():
        # Biases already emitted alongside their weight must not be re-added.
        if key in consumed_biases:
            continue

        if not (key.endswith(_WEIGHT_SUFFIX) and tensor.dim() == 2):
            # 1D tensors (LayerNorm, etc.) pass through untouched.
            quantized[key] = tensor
            continue

        out_features, in_features = tensor.shape
        bias_key = _bias_key_for(key)
        has_bias = bias_key in shard_dict

        # Build the shell on the meta device so no throwaway float32 weight is
        # allocated and randomly initialised before the real tensor is attached.
        layer = nn.Linear(in_features, out_features, bias=has_bias, device="meta")
        layer.weight = nn.Parameter(tensor, requires_grad=False)
        if has_bias:
            layer.bias = nn.Parameter(shard_dict[bias_key], requires_grad=False)
            consumed_biases.add(bias_key)

        # Apply INT8 weight-only quantization in place.
        quantize_(layer, Int8WeightOnlyConfig())

        # Copy the quantized tensors back under the original keys.
        state = layer.state_dict()
        quantized[key] = state["weight"]
        if has_bias:
            quantized[bias_key] = state["bias"]

    return quantized


def _rewrite_index(transformer_dir):
    """Point the sharded-weights index at the new ``.pt`` shard names.

    Does nothing when the safetensors index file is absent. Otherwise writes
    ``diffusion_pytorch_model.pt.index.json`` with every ``weight_map`` value
    renamed, then removes the original index file.
    """
    index_path = os.path.join(transformer_dir, _SAFETENSORS_INDEX)
    if not os.path.exists(index_path):
        return

    with open(index_path, "r", encoding="utf-8") as f:
        index_data = json.load(f)

    index_data["weight_map"] = {
        key: _to_pt_name(shard_name)
        for key, shard_name in index_data["weight_map"].items()
    }

    new_index_path = os.path.join(transformer_dir, _PT_INDEX)
    with open(new_index_path, "w", encoding="utf-8") as f:
        json.dump(index_data, f, indent=2)
    os.remove(index_path)


def quantize_and_upload(source_repo, target_repo, hf_token, mount_path):
    """Download a repo, INT8-quantize its transformer shards, and upload it.

    This is a generator used as a Gradio event handler: every yielded string
    is a status message for the log textbox.

    Args:
        source_repo: Hub repo id to download.
        target_repo: Hub repo id to create (if needed) and upload into.
        hf_token: Hugging Face token used for both download and upload.
        mount_path: Existing directory under which the
            ``quantization_workspace`` working folder is created.

    Yields:
        Human-readable progress messages. Failures are reported as a final
        "❌ ..." message rather than raised.
    """
    if not (source_repo and target_repo and hf_token and mount_path):
        yield "❌ Error: Please fill in all fields."
        return

    if not os.path.exists(mount_path):
        yield f"❌ Error: Mount path '{mount_path}' does not exist. Check your Space settings."
        return

    api = HfApi(token=hf_token)
    work_dir = os.path.join(mount_path, "quantization_workspace")

    try:
        # Inside the try so a filesystem error is reported in the log as well.
        os.makedirs(work_dir, exist_ok=True)

        yield f"🚀 Creating target repository: {target_repo}..."
        api.create_repo(repo_id=target_repo, exist_ok=True, private=False)

        yield f"⬇️ Downloading full repository to bucket at {work_dir}..."
        local_dir = snapshot_download(
            repo_id=source_repo,
            local_dir=work_dir,
            token=hf_token,
        )

        transformer_dir = os.path.join(local_dir, _TRANSFORMER_SUBDIR)
        shard_files = sorted(
            glob.glob(os.path.join(transformer_dir, "*" + _SAFETENSORS_EXT))
        )
        yield f"🧠 Found {len(shard_files)} shards. Beginning Sequential Quantization..."

        if not shard_files:
            # Without this guard the untouched repo would be uploaded and
            # reported as a finished INT8 model.
            yield f"❌ Error: No .safetensors shards found in '{transformer_dir}'. Nothing to quantize."
            return

        # Process one shard at a time so only a single shard is held in RAM.
        for shard_number, filepath in enumerate(shard_files, start=1):
            filename = os.path.basename(filepath)
            yield f" -> Processing Shard {shard_number}/{len(shard_files)}: {filename}"

            # 1. Load shard into RAM
            shard_dict = load_file(filepath)

            # 2. Iterate through tensors and quantize the linear weights
            new_shard_dict = _quantize_shard(shard_dict)

            # 3. Save the quantized shard as a .pt file
            # NOTE(review): torch.save pickles the quantized tensor objects;
            # presumably torchao must be importable when loading — confirm.
            new_filename = _to_pt_name(filename)
            new_filepath = os.path.join(transformer_dir, new_filename)
            torch.save(new_shard_dict, new_filepath)

            yield f" ✅ Saved {new_filename}. Freeing RAM and deleting BF16 source..."

            # 4. Aggressive memory cleanup & disk cleanup
            del shard_dict
            del new_shard_dict
            gc.collect()
            os.remove(filepath)

        yield "🔗 Updating model index file to point to new INT8 .pt shards..."
        _rewrite_index(transformer_dir)

        yield f"☁️ Uploading the complete INT8 repository to {target_repo}..."
        api.upload_folder(
            folder_path=local_dir,
            repo_id=target_repo,
            repo_type="model",
            commit_message="Upload sequential INT8 quantized shards",
        )

        yield f"✅ Success! Your INT8 model is ready at: https://huggingface.co/{target_repo}"

    except Exception as e:
        # Top-level boundary of the UI handler: report the failure in the log.
        yield f"❌ Process failed: {str(e)}"


# --- Gradio UI Layout ---
with gr.Blocks() as app:
    gr.Markdown("# 🚀 Krea-2 Transformer INT8 Sequential Quantizer")
    gr.Markdown("Downloads a BF16 repo directly into a bucket, processes weights shard-by-shard to bypass RAM limits, quantizes to INT8, updates the index, and uploads.")

    with gr.Row():
        source_repo = gr.Textbox(label="Source Repository", value="rootlocalghost/Krea-2-Turbo-BF16", interactive=True)
        target_repo = gr.Textbox(label="Target Repository (e.g., your-username/Krea-2-INT8)", placeholder="username/Repo-Name")

    with gr.Row():
        hf_token = gr.Textbox(label="Hugging Face Write Token", type="password", placeholder="hf_...")
        mount_path = gr.Textbox(label="Bucket Mount Path", value="/data", interactive=True)

    start_btn = gr.Button("Start Quantization & Upload", variant="primary")
    output_log = gr.Textbox(label="Process Logs", lines=15, interactive=False)

    start_btn.click(
        fn=quantize_and_upload,
        inputs=[source_repo, target_repo, hf_token, mount_path],
        outputs=output_log,
    )

if __name__ == "__main__":
    # NOTE(review): passing `theme` to launch() only works on Gradio releases
    # whose launch() accepts it; older releases take it on gr.Blocks(...) —
    # confirm against the pinned Gradio version.
    app.launch(theme=gr.themes.Monochrome())