| """ |
| 将 PEFT LoRA 权重合并到基础 Flux Transformer 模型中 |
| |
| 功能: |
| 1. 加载基础 Flux Fill 模型的 Transformer |
| 2. 加载 RL 训练的 PEFT LoRA 权重 |
| 3. 将 LoRA 权重合并到基础模型 |
| 4. 保存合并后的完整模型 |
| |
| 使用方法: |
| python combine_peft_weights.py \ |
| --base_model_path /path/to/base/model \ |
| --lora_weights_path /path/to/lora/weights \ |
| --output_path /path/to/output \ |
| --save_full_pipeline # 可选:保存完整 pipeline 而不只是 transformer |
| """ |
|
|
| import torch |
| import argparse |
| import os |
| from pathlib import Path |
| from diffusers import FluxFillPipeline |
| from peft import PeftModel |
|
|
|
|
def merge_and_save_transformer(
    base_model_path: str,
    lora_weights_path: str,
    output_path: str,
    dtype: torch.dtype = torch.bfloat16,
    device: str = "cpu"
):
    """Merge PEFT LoRA weights into the base Flux Fill transformer and save it.

    Only the ``transformer`` component of the base pipeline is loaded. The VAE
    and text encoders play no part in the merge, so loading the full
    ``FluxFillPipeline`` would only consume extra memory.

    Args:
        base_model_path: Base Flux Fill pipeline directory; its
            ``transformer`` subfolder is loaded.
        lora_weights_path: Directory containing the PEFT LoRA adapter.
        output_path: Directory the merged transformer is written to (created
            if it does not exist).
        dtype: Dtype the base transformer is loaded in.
        device: Device the transformer is moved to before merging; "cpu"
            avoids using GPU memory.

    Returns:
        The merged transformer produced by ``merge_and_unload()``.
    """
    # Local import: only this function needs the standalone transformer class.
    from diffusers import FluxTransformer2DModel

    print("=" * 80)
    print("Step 1: 加载基础 Flux Fill 模型...")
    print("=" * 80)

    transformer = FluxTransformer2DModel.from_pretrained(
        base_model_path,
        subfolder="transformer",
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )

    print(f"✓ 基础模型加载完成: {base_model_path}")
    print(f"  Transformer 参数量: {sum(p.numel() for p in transformer.parameters()) / 1e9:.2f}B")

    if device != "cpu":
        print(f"  移动 transformer 到 {device}...")
        transformer = transformer.to(device)

    print("\n" + "=" * 80)
    print("Step 2: 加载 PEFT LoRA 权重...")
    print("=" * 80)

    print(f"  从 {lora_weights_path} 加载 LoRA 权重...")
    peft_model = PeftModel.from_pretrained(
        transformer,
        lora_weights_path,
        is_trainable=False,
    )
    peft_model.set_adapter("default")

    print("✓ PEFT 模型加载完成")

    # Read the adapter config now; it is reported and recorded in merge_info.txt.
    lora_config = peft_model.peft_config.get("default", None)
    if lora_config:
        print("  LoRA 配置:")
        print(f"    - Rank (r): {lora_config.r}")
        print(f"    - Alpha: {lora_config.lora_alpha}")
        print(f"    - Dropout: {lora_config.lora_dropout}")
        print(f"    - Target modules: {lora_config.target_modules}")

    print("\n" + "=" * 80)
    print("Step 3: 合并 LoRA 权重到基础模型...")
    print("=" * 80)

    # Fold the LoRA deltas into the base weights and strip the PEFT wrappers.
    merged_model = peft_model.merge_and_unload()

    print("✓ 权重合并完成")
    print(f"  合并后模型参数量: {sum(p.numel() for p in merged_model.parameters()) / 1e9:.2f}B")

    print("\n" + "=" * 80)
    print("Step 4: 保存合并后的模型...")
    print("=" * 80)

    os.makedirs(output_path, exist_ok=True)

    print(f"  保存到 {output_path}...")
    merged_model.save_pretrained(
        output_path,
        safe_serialization=True,
        max_shard_size="5GB",
    )

    print(f"✓ 模型保存完成: {output_path}")

    # Explicit UTF-8: the recorded paths may contain non-ASCII characters.
    info_path = os.path.join(output_path, "merge_info.txt")
    with open(info_path, "w", encoding="utf-8") as f:
        f.write(f"Base model: {base_model_path}\n")
        f.write(f"LoRA weights: {lora_weights_path}\n")
        f.write(f"Merged model: {output_path}\n")
        f.write(f"Data type: {dtype}\n")
        if lora_config:
            f.write("\nLoRA Configuration:\n")
            f.write(f"  Rank (r): {lora_config.r}\n")
            f.write(f"  Alpha: {lora_config.lora_alpha}\n")
            f.write(f"  Dropout: {lora_config.lora_dropout}\n")
            f.write(f"  Target modules: {lora_config.target_modules}\n")

    print(f"✓ 合并信息保存到: {info_path}")

    return merged_model
|
|
|
|
def merge_and_save_full_pipeline(
    base_model_path: str,
    lora_weights_path: str,
    output_path: str,
    dtype: torch.dtype = torch.bfloat16,
    device: str = "cpu"
):
    """Merge PEFT LoRA weights into the transformer and save the whole FluxFillPipeline.

    Args:
        base_model_path: Base Flux Fill pipeline directory.
        lora_weights_path: Directory containing the PEFT LoRA adapter.
        output_path: Directory the complete pipeline is written to (created
            if it does not exist).
        dtype: Dtype the pipeline is loaded in.
        device: Device the transformer is moved to before merging; the other
            components stay where ``from_pretrained`` put them.

    Returns:
        The pipeline whose ``transformer`` has been replaced by the merged model.
    """
    print("=" * 80)
    print("Step 1: 加载基础 Flux Fill Pipeline...")
    print("=" * 80)

    pipe = FluxFillPipeline.from_pretrained(
        base_model_path,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )

    print(f"✓ Pipeline 加载完成: {base_model_path}")

    if device != "cpu":
        print(f"  移动 transformer 到 {device}...")
        pipe.transformer = pipe.transformer.to(device)

    print("\n" + "=" * 80)
    print("Step 2: 加载并合并 PEFT LoRA 权重...")
    print("=" * 80)

    peft_model = PeftModel.from_pretrained(
        pipe.transformer,
        lora_weights_path,
        is_trainable=False,
    )
    peft_model.set_adapter("default")

    # Read the adapter config before unloading so it can be reported and recorded.
    lora_config = peft_model.peft_config.get("default", None)
    if lora_config:
        print("  LoRA 配置:")
        print(f"    - Rank (r): {lora_config.r}")
        print(f"    - Alpha: {lora_config.lora_alpha}")
        print(f"    - Dropout: {lora_config.lora_dropout}")
        print(f"    - Target modules: {lora_config.target_modules}")

    # Fold the LoRA deltas into the base weights and put the plain model back
    # into the pipeline so save_pretrained writes the merged transformer.
    pipe.transformer = peft_model.merge_and_unload()

    print("✓ 权重合并完成")

    print("\n" + "=" * 80)
    print("Step 3: 保存完整 Pipeline...")
    print("=" * 80)

    os.makedirs(output_path, exist_ok=True)

    print(f"  保存到 {output_path}...")
    pipe.save_pretrained(
        output_path,
        safe_serialization=True,
        max_shard_size="5GB",
    )

    print(f"✓ 完整 Pipeline 保存完成: {output_path}")

    # Record the components the pipeline actually holds rather than a fixed list.
    saved_components = [
        name for name, component in pipe.components.items() if component is not None
    ]

    # Explicit UTF-8: the recorded paths may contain non-ASCII characters.
    info_path = os.path.join(output_path, "merge_info.txt")
    with open(info_path, "w", encoding="utf-8") as f:
        f.write(f"Base model: {base_model_path}\n")
        f.write(f"LoRA weights: {lora_weights_path}\n")
        f.write(f"Merged pipeline: {output_path}\n")
        f.write(f"Data type: {dtype}\n")
        f.write("Components saved:\n")
        for name in saved_components:
            suffix = " (merged with LoRA)" if name == "transformer" else ""
            f.write(f"  - {name}{suffix}\n")
        if lora_config:
            f.write("\nLoRA Configuration:\n")
            f.write(f"  Rank (r): {lora_config.r}\n")
            f.write(f"  Alpha: {lora_config.lora_alpha}\n")
            f.write(f"  Dropout: {lora_config.lora_dropout}\n")
            f.write(f"  Target modules: {lora_config.target_modules}\n")

    print(f"✓ 合并信息保存到: {info_path}")

    return pipe
|
|
|
|
def main():
    """Command-line entry point: merge a PEFT LoRA into Flux Fill and save the result.

    Returns:
        Process exit status: 0 on success; 1 if an input path does not exist,
        if the output path is the base model directory, or if merging raises.
    """
    # Function-scope imports, matching the file's existing local import style.
    import sys
    import traceback

    parser = argparse.ArgumentParser(
        description="将 PEFT LoRA 权重合并到基础 Flux Transformer 模型中"
    )

    parser.add_argument(
        "--base_model_path",
        type=str,
        default="/home/tione/notebook/research/retofan/ckpt/FLUX.1-Fill-dev-UDG-1121_e4",
        help="基础 Flux Fill 模型路径",
    )

    parser.add_argument(
        "--lora_weights_path",
        type=str,
        default="/home/tione/notebook2/research/retofan/code/RL/flow_grpo/logs/defectgen_det/flux_fill_redux/checkpoints/checkpoint-60/lora",
        help="PEFT LoRA 权重路径",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        default="/home/tione/notebook/research/retofan/ckpt/FLUX.1-Fill-dev-UDG-1121_e4_defect_gen_det_e60",
        help="输出路径(保存合并后的模型)",
    )

    parser.add_argument(
        "--save_full_pipeline",
        action="store_true",
        help="保存完整 FluxFillPipeline(包含 VAE、Text Encoder 等),而不只是 Transformer",
    )

    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
        help="数据类型",
    )

    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="加载设备(cpu 或 cuda:0 等)。建议使用 cpu 以节省显存",
    )

    args = parser.parse_args()

    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    dtype = dtype_map[args.dtype]

    # Validation failures are errors: report them on stderr and exit non-zero
    # so shell scripts and schedulers can detect them.
    if not os.path.exists(args.base_model_path):
        print(f"错误: 基础模型路径不存在: {args.base_model_path}", file=sys.stderr)
        return 1

    if not os.path.exists(args.lora_weights_path):
        print(f"错误: LoRA 权重路径不存在: {args.lora_weights_path}", file=sys.stderr)
        return 1

    # Saving into the base model directory would overwrite the files that the
    # merge is reading from.
    if os.path.realpath(args.output_path) == os.path.realpath(args.base_model_path):
        print(f"错误: 输出路径不能与基础模型路径相同: {args.output_path}", file=sys.stderr)
        return 1

    print("\n" + "=" * 80)
    print("PEFT LoRA 权重合并工具")
    print("=" * 80)
    print(f"基础模型: {args.base_model_path}")
    print(f"LoRA 权重: {args.lora_weights_path}")
    print(f"输出路径: {args.output_path}")
    print(f"保存类型: {'完整 Pipeline' if args.save_full_pipeline else '仅 Transformer'}")
    print(f"数据类型: {args.dtype}")
    print(f"加载设备: {args.device}")
    print("=" * 80 + "\n")

    try:
        merge_fn = (
            merge_and_save_full_pipeline
            if args.save_full_pipeline
            else merge_and_save_transformer
        )
        merge_fn(
            base_model_path=args.base_model_path,
            lora_weights_path=args.lora_weights_path,
            output_path=args.output_path,
            dtype=dtype,
            device=args.device,
        )
    except Exception as e:
        # Top-level boundary: report the failure with a traceback and exit non-zero.
        print(f"\n❌ 错误: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1

    print("\n" + "=" * 80)
    print("✅ 合并完成!")
    print("=" * 80)
    print(f"\n合并后的模型已保存到: {args.output_path}")
    print("\n使用方法:")

    if args.save_full_pipeline:
        print("  # 直接加载合并后的完整 pipeline")
        print("  from diffusers import FluxFillPipeline")
        print(f"  pipe = FluxFillPipeline.from_pretrained('{args.output_path}')")
    else:
        print("  # 加载基础 pipeline,然后替换 transformer")
        print("  from diffusers import FluxFillPipeline")
        print("  from diffusers.models import FluxTransformer2DModel")
        print(f"  pipe = FluxFillPipeline.from_pretrained('{args.base_model_path}')")
        print(f"  pipe.transformer = FluxTransformer2DModel.from_pretrained('{args.output_path}')")

    print("\n" + "=" * 80)

    return 0
|
|
|
|
if __name__ == "__main__":
    import sys

    # sys.exit rather than the site-injected exit(): the latter is meant for the
    # interactive interpreter and does not exist when Python runs with -S.
    sys.exit(main())
|
|