"""Generate one image with Z-Image using FP8 transformer weights run in BF16.

The transformer checkpoint is stored as FP8 safetensors; it is loaded with
``torch_dtype=torch.bfloat16`` so the script also works on GPUs without FP8
kernel support. The script times a single 1024x1024 generation, reports the
peak CUDA memory allocated, and saves the result to ``example-bf16.png``.

Environment variables:
    MODEL_ID:  Hugging Face repo id to load from
               (default ``ykarout/Z-Image-Turbo-FP8-Full``).
    USE_LOCAL: set to ``1`` to load everything from the directory containing
               this script instead of the Hub. An empty ``MODEL_ID`` has the
               same effect.
"""

import os
import time

import torch
from safetensors.torch import load_file
from diffusers import ZImagePipeline, ZImageTransformer2DModel
from diffusers.models.attention_dispatch import attention_backend
from huggingface_hub import hf_hub_download

REPO_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_ID = os.environ.get("MODEL_ID", "ykarout/Z-Image-Turbo-FP8-Full")
# Local mode is explicit (USE_LOCAL=1) or implied by an empty MODEL_ID.
USE_LOCAL = os.environ.get("USE_LOCAL", "") == "1" or not MODEL_ID
MODEL_SRC = REPO_DIR if USE_LOCAL else MODEL_ID

# Select which FP8 transformer file to use (weights will be cast to BF16 for compute).
if USE_LOCAL:
    TRANSFORMER_FP8 = os.path.join(REPO_DIR, "transformer", "diffusion_pytorch_model.safetensors")  # E4M3FN (local)
    # TRANSFORMER_FP8 = os.path.join(REPO_DIR, "transformer", "diffusion_pytorch_model_e5m2.safetensors")  # E5M2 (local)
else:
    TRANSFORMER_FP8 = hf_hub_download(
        repo_id=MODEL_ID,
        filename="transformer/diffusion_pytorch_model.safetensors",
        # USE_LOCAL is always False in this branch; passed for symmetry with
        # the other loader calls below.
        local_files_only=USE_LOCAL,
    )
    # TRANSFORMER_FP8 = hf_hub_download(repo_id=MODEL_ID, filename="transformer/diffusion_pytorch_model_e5m2.safetensors")


def _strip_prefix(state_dict: dict) -> dict:
    """Return a copy of *state_dict* with ``model.diffusion_model.`` removed from keys.

    Keys that do not start with the prefix are kept unchanged; values are
    passed through untouched.
    """
    prefix = "model.diffusion_model."
    return {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in state_dict.items()}


def load_fp8_transformer(fp8_path: str, config_root: str) -> ZImageTransformer2DModel:
    """Load the FP8 transformer checkpoint as a BF16 ``ZImageTransformer2DModel``.

    Args:
        fp8_path: Path to the FP8 ``.safetensors`` transformer checkpoint.
        config_root: Repo id or directory whose ``transformer`` subfolder
            provides the model config.

    Returns:
        The transformer built from the prefix-stripped state dict, with
        ``torch_dtype=torch.bfloat16`` requested.
    """
    # Load FP8 weights but cast to BF16 for compute (works on GPUs without FP8 support)
    raw = load_file(fp8_path)
    raw = _strip_prefix(raw)
    return ZImageTransformer2DModel.from_single_file(
        raw,
        config=config_root,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
        local_files_only=USE_LOCAL,
        low_cpu_mem_usage=False,
    )


# 1. Load the pipeline
# BF16 compute path for GPUs without FP8 kernel support.
pipe = ZImagePipeline.from_pretrained(
    MODEL_SRC,
    torch_dtype=torch.bfloat16,
    local_files_only=USE_LOCAL,
    low_cpu_mem_usage=False,
)
# Replace the pipeline's transformer with the one built from the selected FP8 file.
pipe.transformer = load_fp8_transformer(TRANSFORMER_FP8, MODEL_SRC)

# [Optional] Attention Backend
# Diffusers uses SDPA by default. Switch to Flash Attention for better efficiency if supported:
# pipe.transformer.set_attention_backend("flash")    # Flash-Attention-2
# pipe.transformer.set_attention_backend("_flash_3") # Flash-Attention-3

# [Optional] Model Compilation
# pipe.transformer.compile()

# [Optional] CPU Offloading
pipe.enable_model_cpu_offload()

prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."

# Reset the peak-memory counter and drain pending GPU work so the measurement
# below covers only the generation call.
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
start = time.perf_counter()

# 2. Generate Image
with attention_backend("_native_flash"), torch.inference_mode():
    image = pipe(
        prompt=prompt,
        height=1024,
        width=1024,
        num_inference_steps=9,
        guidance_scale=0.0,
        generator=torch.Generator("cuda").manual_seed(42),  # fixed seed
    ).images[0]

# Wait for the GPU to finish before stopping the clock.
torch.cuda.synchronize()
elapsed_s = time.perf_counter() - start
peak_gb = torch.cuda.max_memory_allocated() / (1024**3)  # bytes -> GiB

image.save("example-bf16.png")
print(f"saved example-bf16.png | elapsed_s={elapsed_s:.3f} peak_allocated_gb={peak_gb:.3f}")