UniDG-RFT-LoRA-Release / combine_peft_weights.py
retofan23333's picture
Upload folder using huggingface_hub
e6a67f8 verified
Raw
History Blame Contribute Delete
11 kB
"""
将 PEFT LoRA 权重合并到基础 Flux Transformer 模型中
功能:
1. 加载基础 Flux Fill 模型的 Transformer
2. 加载 RL 训练的 PEFT LoRA 权重
3. 将 LoRA 权重合并到基础模型
4. 保存合并后的完整模型
使用方法:
python combine_peft_weights.py \
--base_model_path /path/to/base/model \
--lora_weights_path /path/to/lora/weights \
--output_path /path/to/output \
--save_full_pipeline # 可选:保存完整 pipeline 而不只是 transformer
"""
import argparse
import os
import sys
from pathlib import Path

import torch
from diffusers import FluxFillPipeline
from peft import PeftModel
def merge_and_save_transformer(
    base_model_path: str,
    lora_weights_path: str,
    output_path: str,
    dtype: torch.dtype = torch.bfloat16,
    device: str = "cpu"
):
    """
    Merge PEFT LoRA weights into the Flux transformer and save only the transformer.

    Only the ``transformer`` sub-folder of the base checkpoint is loaded. The
    other pipeline components (VAE, text encoders, ...) are not needed for the
    merge, so skipping them keeps peak memory down.

    Args:
        base_model_path: Path to the base Flux Fill model in diffusers layout
            (the same directory ``FluxFillPipeline.from_pretrained`` accepts,
            which contains a ``transformer`` sub-folder).
        lora_weights_path: Path to the PEFT LoRA adapter directory.
        output_path: Directory the merged transformer is written to (created
            if missing).
        dtype: Torch dtype used when loading the base transformer.
        device: Device the transformer is moved to before merging; "cpu" is
            recommended to avoid using GPU memory.

    Returns:
        The transformer module returned by ``merge_and_unload()``, i.e. with
        the LoRA deltas folded into the base weights and PEFT wrappers removed.
    """
    # Local import: only this code path needs the standalone transformer class.
    from diffusers import FluxTransformer2DModel

    print("=" * 80)
    print("Step 1: 加载基础 Flux Fill 模型...")
    print("=" * 80)
    # Load just the transformer component instead of the full pipeline, which
    # would also materialize the VAE and text encoders for no benefit here.
    transformer = FluxTransformer2DModel.from_pretrained(
        base_model_path,
        subfolder="transformer",
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    print(f"✓ 基础模型加载完成: {base_model_path}")
    print(f" Transformer 参数量: {sum(p.numel() for p in transformer.parameters()) / 1e9:.2f}B")

    # Optionally move the transformer to an accelerator for the merge.
    if device != "cpu":
        print(f" 移动 transformer 到 {device}...")
        transformer = transformer.to(device)

    print("\n" + "=" * 80)
    print("Step 2: 加载 PEFT LoRA 权重...")
    print("=" * 80)
    # Wrap the base transformer with the saved LoRA adapter (inference only).
    print(f" 从 {lora_weights_path} 加载 LoRA 权重...")
    peft_model = PeftModel.from_pretrained(
        transformer,
        lora_weights_path,
        is_trainable=False
    )
    # PeftModel.from_pretrained registers the adapter under the name "default".
    peft_model.set_adapter("default")
    print("✓ PEFT 模型加载完成")

    # Report the adapter's hyper-parameters, if the config is available.
    lora_config = peft_model.peft_config.get("default", None)
    if lora_config is not None:
        print(" LoRA 配置:")
        print(f" - Rank (r): {lora_config.r}")
        print(f" - Alpha: {lora_config.lora_alpha}")
        print(f" - Dropout: {lora_config.lora_dropout}")
        print(f" - Target modules: {lora_config.target_modules}")

    print("\n" + "=" * 80)
    print("Step 3: 合并 LoRA 权重到基础模型...")
    print("=" * 80)
    # Fold the LoRA weights into the base layers and drop the PEFT wrappers.
    merged_model = peft_model.merge_and_unload()
    print("✓ 权重合并完成")
    print(f" 合并后模型参数量: {sum(p.numel() for p in merged_model.parameters()) / 1e9:.2f}B")

    print("\n" + "=" * 80)
    print("Step 4: 保存合并后的模型...")
    print("=" * 80)
    os.makedirs(output_path, exist_ok=True)
    print(f" 保存到 {output_path}...")
    merged_model.save_pretrained(
        output_path,
        safe_serialization=True,  # write .safetensors files
        max_shard_size="5GB"      # split large checkpoints into shards
    )
    print(f"✓ 模型保存完成: {output_path}")

    # Record provenance of the merge next to the weights. Paths may contain
    # non-ASCII characters, so the encoding is fixed rather than locale-dependent.
    info_path = os.path.join(output_path, "merge_info.txt")
    with open(info_path, "w", encoding="utf-8") as f:
        f.write(f"Base model: {base_model_path}\n")
        f.write(f"LoRA weights: {lora_weights_path}\n")
        f.write(f"Merged model: {output_path}\n")
        f.write(f"Data type: {dtype}\n")
        if lora_config is not None:
            f.write("\nLoRA Configuration:\n")
            f.write(f" Rank (r): {lora_config.r}\n")
            f.write(f" Alpha: {lora_config.lora_alpha}\n")
            f.write(f" Dropout: {lora_config.lora_dropout}\n")
            f.write(f" Target modules: {lora_config.target_modules}\n")
    print(f"✓ 合并信息保存到: {info_path}")
    return merged_model
def merge_and_save_full_pipeline(
    base_model_path: str,
    lora_weights_path: str,
    output_path: str,
    dtype: torch.dtype = torch.bfloat16,
    device: str = "cpu"
):
    """
    Merge PEFT LoRA weights into the transformer and save the whole FluxFillPipeline.

    Args:
        base_model_path: Path to the base Flux Fill pipeline (diffusers layout).
        lora_weights_path: Path to the PEFT LoRA adapter directory.
        output_path: Directory the merged pipeline is written to (created if
            missing).
        dtype: Torch dtype used when loading the pipeline.
        device: Device the transformer is moved to before merging; the other
            components stay where ``from_pretrained`` placed them.

    Returns:
        The ``FluxFillPipeline`` whose ``transformer`` has been replaced by the
        merged module returned from ``merge_and_unload()``.
    """
    print("=" * 80)
    print("Step 1: 加载基础 Flux Fill Pipeline...")
    print("=" * 80)
    # Every component is loaded because all of them are written back out below.
    pipe = FluxFillPipeline.from_pretrained(
        base_model_path,
        torch_dtype=dtype,
        low_cpu_mem_usage=True
    )
    print(f"✓ Pipeline 加载完成: {base_model_path}")

    # Only the transformer takes part in the merge, so only it is moved.
    if device != "cpu":
        print(f" 移动 transformer 到 {device}...")
        pipe.transformer = pipe.transformer.to(device)

    print("\n" + "=" * 80)
    print("Step 2: 加载并合并 PEFT LoRA 权重...")
    print("=" * 80)
    # Wrap the transformer with the saved LoRA adapter (inference only).
    peft_model = PeftModel.from_pretrained(
        pipe.transformer,
        lora_weights_path,
        is_trainable=False
    )
    # PeftModel.from_pretrained registers the adapter under the name "default".
    peft_model.set_adapter("default")

    # Report the adapter's hyper-parameters (same output as the
    # transformer-only path) so both modes document what was merged.
    lora_config = peft_model.peft_config.get("default", None)
    if lora_config is not None:
        print(" LoRA 配置:")
        print(f" - Rank (r): {lora_config.r}")
        print(f" - Alpha: {lora_config.lora_alpha}")
        print(f" - Dropout: {lora_config.lora_dropout}")
        print(f" - Target modules: {lora_config.target_modules}")

    # Fold the LoRA weights into the base layers, drop the PEFT wrappers and
    # put the resulting plain module back into the pipeline.
    merged_transformer = peft_model.merge_and_unload()
    pipe.transformer = merged_transformer
    print("✓ 权重合并完成")

    print("\n" + "=" * 80)
    print("Step 3: 保存完整 Pipeline...")
    print("=" * 80)
    os.makedirs(output_path, exist_ok=True)
    print(f" 保存到 {output_path}...")
    pipe.save_pretrained(
        output_path,
        safe_serialization=True,
        max_shard_size="5GB"
    )
    print(f"✓ 完整 Pipeline 保存完成: {output_path}")

    # Record provenance of the merge next to the pipeline. Paths may contain
    # non-ASCII characters, so the encoding is fixed rather than locale-dependent.
    info_path = os.path.join(output_path, "merge_info.txt")
    with open(info_path, "w", encoding="utf-8") as f:
        f.write(f"Base model: {base_model_path}\n")
        f.write(f"LoRA weights: {lora_weights_path}\n")
        f.write(f"Merged pipeline: {output_path}\n")
        f.write(f"Data type: {dtype}\n")
        f.write("Components saved:\n")
        f.write(" - Transformer (merged with LoRA)\n")
        f.write(" - VAE\n")
        f.write(" - Text Encoder\n")
        f.write(" - Scheduler\n")
        f.write(" - Other components\n")
        if lora_config is not None:
            f.write("\nLoRA Configuration:\n")
            f.write(f" Rank (r): {lora_config.r}\n")
            f.write(f" Alpha: {lora_config.lora_alpha}\n")
            f.write(f" Dropout: {lora_config.lora_dropout}\n")
            f.write(f" Target modules: {lora_config.target_modules}\n")
    print(f"✓ 合并信息保存到: {info_path}")
    return pipe
def main():
    """
    Command-line entry point: parse arguments, validate input paths, run the merge.

    Depending on ``--save_full_pipeline`` either only the merged transformer or
    the whole FluxFillPipeline is written to ``--output_path``.

    Returns:
        0 on success; 1 if an input path does not exist or the merge raised.
        The value is meant to be used as the process exit status.
    """
    parser = argparse.ArgumentParser(
        description="将 PEFT LoRA 权重合并到基础 Flux Transformer 模型中"
    )
    parser.add_argument(
        "--base_model_path",
        type=str,
        default="/home/tione/notebook/research/retofan/ckpt/FLUX.1-Fill-dev-UDG-1121_e4",
        help="基础 Flux Fill 模型路径"
    )
    parser.add_argument(
        "--lora_weights_path",
        type=str,
        default="/home/tione/notebook2/research/retofan/code/RL/flow_grpo/logs/defectgen_det/flux_fill_redux/checkpoints/checkpoint-60/lora",
        help="PEFT LoRA 权重路径"
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="/home/tione/notebook/research/retofan/ckpt/FLUX.1-Fill-dev-UDG-1121_e4_defect_gen_det_e60",
        help="输出路径(保存合并后的模型)"
    )
    parser.add_argument(
        "--save_full_pipeline",
        action="store_true",
        help="保存完整 FluxFillPipeline(包含 VAE、Text Encoder 等),而不只是 Transformer"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
        help="数据类型"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="加载设备(cpu 或 cuda:0 等)。建议使用 cpu 以节省显存"
    )
    args = parser.parse_args()

    # Map the CLI dtype name onto the torch dtype object.
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16
    }
    dtype = dtype_map[args.dtype]

    # Validate inputs up front. A missing path is a failure, so it is reported
    # on stderr and yields a non-zero status (previously this returned None,
    # which made the process exit with status 0).
    if not os.path.exists(args.base_model_path):
        print(f"错误: 基础模型路径不存在: {args.base_model_path}", file=sys.stderr)
        return 1
    if not os.path.exists(args.lora_weights_path):
        print(f"错误: LoRA 权重路径不存在: {args.lora_weights_path}", file=sys.stderr)
        return 1

    print("\n" + "=" * 80)
    print("PEFT LoRA 权重合并工具")
    print("=" * 80)
    print(f"基础模型: {args.base_model_path}")
    print(f"LoRA 权重: {args.lora_weights_path}")
    print(f"输出路径: {args.output_path}")
    print(f"保存类型: {'完整 Pipeline' if args.save_full_pipeline else '仅 Transformer'}")
    print(f"数据类型: {args.dtype}")
    print(f"加载设备: {args.device}")
    print("=" * 80 + "\n")

    # Top-level boundary: any failure during loading/merging/saving is
    # reported with a traceback and converted into exit status 1.
    try:
        if args.save_full_pipeline:
            merge_and_save_full_pipeline(
                base_model_path=args.base_model_path,
                lora_weights_path=args.lora_weights_path,
                output_path=args.output_path,
                dtype=dtype,
                device=args.device
            )
        else:
            merge_and_save_transformer(
                base_model_path=args.base_model_path,
                lora_weights_path=args.lora_weights_path,
                output_path=args.output_path,
                dtype=dtype,
                device=args.device
            )
        print("\n" + "=" * 80)
        print("✅ 合并完成!")
        print("=" * 80)
        print(f"\n合并后的模型已保存到: {args.output_path}")
        # Show how to load the result for the chosen save mode.
        print("\n使用方法:")
        if args.save_full_pipeline:
            print(" # 直接加载合并后的完整 pipeline")
            print(" from diffusers import FluxFillPipeline")
            print(f" pipe = FluxFillPipeline.from_pretrained('{args.output_path}')")
        else:
            print(" # 加载基础 pipeline,然后替换 transformer")
            print(" from diffusers import FluxFillPipeline")
            print(" from diffusers.models import FluxTransformer2DModel")
            print(f" pipe = FluxFillPipeline.from_pretrained('{args.base_model_path}')")
            print(f" pipe.transformer = FluxTransformer2DModel.from_pretrained('{args.output_path}')")
        print("\n" + "=" * 80)
    except Exception as e:
        print(f"\n❌ 错误: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        return 1
    return 0
if __name__ == "__main__":
    # sys.exit is the program-level exit; the builtin exit() is injected by the
    # site module for interactive use and is missing under `python -S`.
    sys.exit(main())