"""Merge PEFT LoRA weights into a base Flux Fill transformer.

Steps performed:
    1. Load the base Flux Fill pipeline (its transformer is the merge target).
    2. Load the PEFT LoRA weights on top of that transformer.
    3. Merge the LoRA weights into the base weights.
    4. Save either the merged transformer or the complete pipeline.

Usage:
    python combine_peft_weights.py \
        --base_model_path /path/to/base/model \
        --lora_weights_path /path/to/lora/weights \
        --output_path /path/to/output \
        --save_full_pipeline  # optional: save the whole pipeline, not just the transformer
"""

import argparse
import os
import sys
import traceback
from pathlib import Path
from typing import Optional, Sequence

import torch
from diffusers import FluxFillPipeline
from peft import PeftModel

# Width of the "=====" separator lines printed around each step.
_BANNER_WIDTH = 80

# Name under which PeftModel.from_pretrained registers the loaded adapter here.
_ADAPTER_NAME = "default"

# CLI dtype names -> torch dtypes. Insertion order is also the order of the
# argparse ``choices`` shown in ``--help``.
_DTYPE_MAP = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}


def _print_banner(title: str, leading_blank: bool = False) -> None:
    """Print ``title`` framed by two separator lines."""
    rule = "=" * _BANNER_WIDTH
    print("\n" + rule if leading_blank else rule)
    print(title)
    print(rule)


def _param_count_billions(model) -> float:
    """Return the number of parameters of ``model`` in billions."""
    return sum(p.numel() for p in model.parameters()) / 1e9


def _load_pipeline(base_model_path: str, dtype: torch.dtype) -> FluxFillPipeline:
    """Load the base FluxFillPipeline from ``base_model_path`` in ``dtype``."""
    return FluxFillPipeline.from_pretrained(
        base_model_path,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )


def _move_transformer(pipe: FluxFillPipeline, device: str) -> None:
    """Move the pipeline's transformer to ``device`` unless it is ``"cpu"``."""
    if device == "cpu":
        return
    print(f" 移动 transformer 到 {device}...")
    pipe.transformer = pipe.transformer.to(device)


def _attach_lora(transformer, lora_weights_path: str) -> PeftModel:
    """Wrap ``transformer`` with the LoRA adapter stored at ``lora_weights_path``."""
    peft_model = PeftModel.from_pretrained(
        transformer,
        lora_weights_path,
        is_trainable=False,
    )
    peft_model.set_adapter(_ADAPTER_NAME)
    return peft_model


def _write_merge_info(output_path: str, lines: Sequence[str]) -> str:
    """Write ``lines`` to ``merge_info.txt`` in ``output_path``; return its path."""
    info_path = os.path.join(output_path, "merge_info.txt")
    # Explicit UTF-8: the recorded paths may contain non-ASCII characters and
    # the locale default encoding is not guaranteed to represent them.
    with open(info_path, "w", encoding="utf-8") as f:
        f.writelines(lines)
    return info_path


def merge_and_save_transformer(
    base_model_path: str,
    lora_weights_path: str,
    output_path: str,
    dtype: torch.dtype = torch.bfloat16,
    device: str = "cpu",
):
    """Merge LoRA weights into the transformer and save only the transformer.

    Args:
        base_model_path: Path of the base Flux Fill model.
        lora_weights_path: Path of the PEFT LoRA weights.
        output_path: Directory that receives the merged transformer and a
            ``merge_info.txt`` summary. Created if missing.
        dtype: Data type the base model is loaded in.
        device: Device the transformer is moved to before merging; ``"cpu"``
            (recommended to save GPU memory) leaves it where it was loaded.

    Returns:
        The merged model returned by ``merge_and_unload()``.
    """
    _print_banner("Step 1: 加载基础 Flux Fill 模型...")

    # NOTE(review): the whole pipeline is loaded here although only its
    # transformer is used afterwards; loading just the transformer would need
    # less memory but is left unchanged to keep loading behavior identical.
    pipe = _load_pipeline(base_model_path, dtype)
    print(f"✓ 基础模型加载完成: {base_model_path}")
    print(f" Transformer 参数量: {_param_count_billions(pipe.transformer):.2f}B")

    _move_transformer(pipe, device)

    _print_banner("Step 2: 加载 PEFT LoRA 权重...", leading_blank=True)

    print(f" 从 {lora_weights_path} 加载 LoRA 权重...")
    peft_model = _attach_lora(pipe.transformer, lora_weights_path)
    print("✓ PEFT 模型加载完成")

    # Report the adapter configuration, if one is registered under the
    # adapter name used above.
    lora_config = peft_model.peft_config.get(_ADAPTER_NAME, None)
    if lora_config:
        print(" LoRA 配置:")
        print(f" - Rank (r): {lora_config.r}")
        print(f" - Alpha: {lora_config.lora_alpha}")
        print(f" - Dropout: {lora_config.lora_dropout}")
        print(f" - Target modules: {lora_config.target_modules}")

    _print_banner("Step 3: 合并 LoRA 权重到基础模型...", leading_blank=True)

    merged_model = peft_model.merge_and_unload()
    print("✓ 权重合并完成")
    print(f" 合并后模型参数量: {_param_count_billions(merged_model):.2f}B")

    _print_banner("Step 4: 保存合并后的模型...", leading_blank=True)

    os.makedirs(output_path, exist_ok=True)

    print(f" 保存到 {output_path}...")
    merged_model.save_pretrained(
        output_path,
        safe_serialization=True,  # write safetensors files
        max_shard_size="5GB",  # shard size limit
    )
    print(f"✓ 模型保存完成: {output_path}")

    # Record what was merged next to the weights.
    info_lines = [
        f"Base model: {base_model_path}\n",
        f"LoRA weights: {lora_weights_path}\n",
        f"Merged model: {output_path}\n",
        f"Data type: {dtype}\n",
    ]
    if lora_config:
        info_lines += [
            "\nLoRA Configuration:\n",
            f" Rank (r): {lora_config.r}\n",
            f" Alpha: {lora_config.lora_alpha}\n",
            f" Dropout: {lora_config.lora_dropout}\n",
            f" Target modules: {lora_config.target_modules}\n",
        ]
    info_path = _write_merge_info(output_path, info_lines)
    print(f"✓ 合并信息保存到: {info_path}")

    return merged_model


def merge_and_save_full_pipeline(
    base_model_path: str,
    lora_weights_path: str,
    output_path: str,
    dtype: torch.dtype = torch.bfloat16,
    device: str = "cpu",
):
    """Merge LoRA weights and save the complete FluxFillPipeline.

    Args:
        base_model_path: Path of the base Flux Fill model.
        lora_weights_path: Path of the PEFT LoRA weights.
        output_path: Directory that receives the full pipeline and a
            ``merge_info.txt`` summary. Created if missing.
        dtype: Data type the pipeline is loaded in.
        device: Device the transformer is moved to before merging; ``"cpu"``
            leaves it where it was loaded.

    Returns:
        The pipeline whose ``transformer`` has been replaced by the merged one.
    """
    _print_banner("Step 1: 加载基础 Flux Fill Pipeline...")

    pipe = _load_pipeline(base_model_path, dtype)
    print(f"✓ Pipeline 加载完成: {base_model_path}")

    _move_transformer(pipe, device)

    _print_banner("Step 2: 加载并合并 PEFT LoRA 权重...", leading_blank=True)

    peft_model = _attach_lora(pipe.transformer, lora_weights_path)

    # Swap the merged transformer into the pipeline so that saving the
    # pipeline writes the merged weights.
    pipe.transformer = peft_model.merge_and_unload()
    print("✓ 权重合并完成")

    _print_banner("Step 3: 保存完整 Pipeline...", leading_blank=True)

    os.makedirs(output_path, exist_ok=True)

    print(f" 保存到 {output_path}...")
    pipe.save_pretrained(
        output_path,
        safe_serialization=True,
        max_shard_size="5GB",
    )
    print(f"✓ 完整 Pipeline 保存完成: {output_path}")

    # Record what was merged next to the saved pipeline.
    info_path = _write_merge_info(
        output_path,
        [
            f"Base model: {base_model_path}\n",
            f"LoRA weights: {lora_weights_path}\n",
            f"Merged pipeline: {output_path}\n",
            f"Data type: {dtype}\n",
            "Components saved:\n",
            " - Transformer (merged with LoRA)\n",
            " - VAE\n",
            " - Text Encoder\n",
            " - Scheduler\n",
            " - Other components\n",
        ],
    )
    print(f"✓ 合并信息保存到: {info_path}")

    return pipe


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the merge tool."""
    parser = argparse.ArgumentParser(
        description="将 PEFT LoRA 权重合并到基础 Flux Transformer 模型中"
    )
    parser.add_argument(
        "--base_model_path",
        type=str,
        default="/home/tione/notebook/research/retofan/ckpt/FLUX.1-Fill-dev-UDG-1121_e4",
        help="基础 Flux Fill 模型路径",
    )
    parser.add_argument(
        "--lora_weights_path",
        type=str,
        default="/home/tione/notebook2/research/retofan/code/RL/flow_grpo/logs/defectgen_det/flux_fill_redux/checkpoints/checkpoint-60/lora",
        help="PEFT LoRA 权重路径",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="/home/tione/notebook/research/retofan/ckpt/FLUX.1-Fill-dev-UDG-1121_e4_defect_gen_det_e60",
        help="输出路径(保存合并后的模型)",
    )
    parser.add_argument(
        "--save_full_pipeline",
        action="store_true",
        help="保存完整 FluxFillPipeline(包含 VAE、Text Encoder 等),而不只是 Transformer",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=list(_DTYPE_MAP),
        help="数据类型",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="加载设备(cpu 或 cuda:0 等)。建议使用 cpu 以节省显存",
    )
    return parser


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Run the merge tool from the command line.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 when an input path is missing or
        the merge raised an exception.
    """
    args = _build_parser().parse_args(argv)
    dtype = _DTYPE_MAP[args.dtype]

    # Validate input paths up front. A missing path is a failure, so report a
    # non-zero status instead of falling through with an implicit None (which
    # the interpreter would turn into exit status 0).
    if not os.path.exists(args.base_model_path):
        print(f"错误: 基础模型路径不存在: {args.base_model_path}")
        return 1

    if not os.path.exists(args.lora_weights_path):
        print(f"错误: LoRA 权重路径不存在: {args.lora_weights_path}")
        return 1

    _print_banner("PEFT LoRA 权重合并工具", leading_blank=True)
    print(f"基础模型: {args.base_model_path}")
    print(f"LoRA 权重: {args.lora_weights_path}")
    print(f"输出路径: {args.output_path}")
    print(f"保存类型: {'完整 Pipeline' if args.save_full_pipeline else '仅 Transformer'}")
    print(f"数据类型: {args.dtype}")
    print(f"加载设备: {args.device}")
    print("=" * _BANNER_WIDTH + "\n")

    if args.save_full_pipeline:
        merge = merge_and_save_full_pipeline
    else:
        merge = merge_and_save_transformer

    # Top-level boundary: report any failure with its traceback and turn it
    # into a non-zero exit status.
    try:
        merge(
            base_model_path=args.base_model_path,
            lora_weights_path=args.lora_weights_path,
            output_path=args.output_path,
            dtype=dtype,
            device=args.device,
        )
    except Exception as e:
        print(f"\n❌ 错误: {e}")
        traceback.print_exc()
        return 1

    _print_banner("✅ 合并完成!", leading_blank=True)
    print(f"\n合并后的模型已保存到: {args.output_path}")
    print("\n使用方法:")

    if args.save_full_pipeline:
        print(" # 直接加载合并后的完整 pipeline")
        print(" from diffusers import FluxFillPipeline")
        print(f" pipe = FluxFillPipeline.from_pretrained('{args.output_path}')")
    else:
        print(" # 加载基础 pipeline,然后替换 transformer")
        print(" from diffusers import FluxFillPipeline")
        print(" from diffusers.models import FluxTransformer2DModel")
        print(f" pipe = FluxFillPipeline.from_pretrained('{args.base_model_path}')")
        print(f" pipe.transformer = FluxTransformer2DModel.from_pretrained('{args.output_path}')")

    print("\n" + "=" * _BANNER_WIDTH)
    return 0


if __name__ == "__main__":
    # sys.exit rather than the site-provided exit() helper, which is not
    # guaranteed to exist (e.g. under ``python -S``).
    sys.exit(main())