retofan23333 committed
Commit e6a67f8 · verified · 1 Parent(s): ca9c515

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,158 @@
+ ---
+ license: apache-2.0
+ tags:
+ - defect-generation
+ - anomaly-detection
+ - industrial-inspection
+ - lora
+ - flux
+ - diffusion
+ - rlhf
+ language: en
+ pipeline_tag: image-to-image
+ ---
+
+ # UniDG-RFT-LoRA
+
+ LoRA weights for **UniDG** (Universal Defect Generation), trained via **Consistency-RFT** with Flow-GRPO and dual reward models on the UDG dataset (300K quadruplets).
+
+ [[Paper]](https://arxiv.org/abs/2604.08915) [[Code]](https://github.com/RetoFan233/UniDG) [[UniDG-SFT-LoRA]](https://huggingface.co/retofan23333/UniDG-SFT-LoRA-Release)
+
+ ## Overview
+
+ UniDG is a universal defect generation foundation model that transfers defects from a reference image to a target region via **Defect-Context Editing** and **MM-DiT multimodal attention**, without per-category fine-tuning. This checkpoint is the **Consistency-RFT** variant, further refined from UniDG-SFT using Flow-GRPO with dual reward models (Defect-Und-Reward & Defect-Recog-Reward) for improved defect fidelity and consistency.
+
+ | Variant | Training | Focus |
+ |---------|----------|-------|
+ | UniDG-SFT | Diversity-SFT with complementary sampling | Diverse defect patterns |
+ | **UniDG-RFT** (this) | Consistency-RFT with Flow-GRPO + dual rewards | Consistent & faithful defects |
+
+ ## Important: Usage Difference from UniDG-SFT-LoRA
+
+ **The UniDG-RFT-LoRA weights are stored in PEFT format** (`adapter_model.safetensors` + `adapter_config.json`), unlike UniDG-SFT-LoRA, which uses `pytorch_lora_weights.safetensors`. This means:
+
+ - **UniDG-SFT-LoRA** can be loaded directly via the `lora_weights_path` parameter of `ImageUniDG`.
+ - **UniDG-RFT-LoRA** must first be **merged into the base SFT model** using the provided `combine_peft_weights.py` script. After merging, the resulting model can be loaded directly, with no additional LoRA loading step.
+
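Before choosing a loading path, it can help to check which of the two layouts a local folder actually contains. The sketch below is illustrative only: `detect_lora_format` is a hypothetical helper (not part of the UniDG code) that inspects nothing but the file names mentioned above.

```python
from pathlib import Path

# File names taken from the description above.
PEFT_FILES = ("adapter_model.safetensors", "adapter_config.json")
SFT_FILE = "pytorch_lora_weights.safetensors"


def detect_lora_format(lora_dir: str) -> str:
    """Return "peft", "sft", or "unknown" depending on which files exist.

    Hypothetical helper for illustration; it only looks at file names.
    """
    root = Path(lora_dir)
    if all((root / name).is_file() for name in PEFT_FILES):
        return "peft"  # RFT layout: merge first with combine_peft_weights.py
    if (root / SFT_FILE).is_file():
        return "sft"  # SFT layout: pass the file as lora_weights_path
    return "unknown"
```

A folder reported as `"peft"` would go through the merge steps described below, while `"sft"` can be handed to `ImageUniDG` directly.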
+ ## Repository Contents
+
+ | File | Description |
+ |------|-------------|
+ | `adapter_model.safetensors` | PEFT LoRA weights (Consistency-RFT) |
+ | `adapter_config.json` | LoRA configuration (rank=64, alpha=128) |
+ | `combine_peft_weights.py` | Script to merge RFT LoRA into the base SFT model |
+
+ ## Step-by-Step Usage
+
+ ### Prerequisites
+
+ - [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) (inpainting backbone)
+ - [FLUX.1-Redux-dev](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) (reference conditioning)
+ - [UniDG-SFT-LoRA](https://huggingface.co/retofan23333/UniDG-SFT-LoRA-Release) (base SFT model; the RFT LoRA is fine-tuned on top of this)
+ - [UniDG code](https://github.com/RetoFan233/UniDG) (inference framework)
+ - Python dependencies: `diffusers`, `peft`, `torch`
+
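A quick way to confirm that the Python dependencies are installed before running the steps below is to probe for them with the standard library. This is a minimal sketch; `missing_packages` is a hypothetical helper, and it checks only that the packages can be found, not their versions.

```python
from importlib.util import find_spec

# Package names from the prerequisites list above.
REQUIRED_PACKAGES = ("diffusers", "peft", "torch")


def missing_packages(names=REQUIRED_PACKAGES):
    """Return the names from `names` that cannot be found in this environment."""
    return [name for name in names if find_spec(name) is None]
```

Calling `missing_packages()` returns an empty list when all three packages are available.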
+ ### Step 1: Prepare the Base SFT Model
+
+ First, you need a base FLUX.1-Fill-dev model with the UniDG-SFT-LoRA weights already merged in. If you don't have one yet, load the SFT LoRA, fuse it into the base weights, and save the merged pipeline:
+
+ ```python
+ from diffusers import FluxFillPipeline
+ import torch
+
+ # Load base FLUX.1-Fill-dev
+ pipe = FluxFillPipeline.from_pretrained(
+     "path/to/FLUX.1-Fill-dev",
+     torch_dtype=torch.bfloat16,
+ )
+
+ # Load SFT LoRA weights
+ pipe.load_lora_weights("path/to/UniDG-SFT-LoRA-Release/pytorch_lora_weights.safetensors")
+
+ # Fuse the LoRA into the base weights, then drop the adapter layers so the
+ # saved checkpoint contains plain merged weights
+ pipe.fuse_lora()
+ pipe.unload_lora_weights()
+
+ # Save the merged SFT model as the base for RFT merging
+ pipe.save_pretrained("path/to/FLUX.1-Fill-dev-UDG-SFT", safe_serialization=True, max_shard_size="5GB")
+ ```
+
+ ### Step 2: Merge RFT LoRA into the Base SFT Model
+
+ Use the provided `combine_peft_weights.py` to merge the RFT LoRA weights into the base SFT model:
+
+ ```bash
+ python combine_peft_weights.py \
+     --base_model_path path/to/FLUX.1-Fill-dev-UDG-SFT \
+     --lora_weights_path path/to/UniDG-RFT-LoRA-Release \
+     --output_path path/to/FLUX.1-Fill-dev-UDG-RFT \
+     --save_full_pipeline
+ ```
+
+ Parameters:
+ - `--base_model_path`: Path to the base SFT model (from Step 1)
+ - `--lora_weights_path`: Path to this RFT LoRA repository (containing `adapter_model.safetensors` and `adapter_config.json`)
+ - `--output_path`: Output path for the merged model
+ - `--save_full_pipeline`: Save the full pipeline (including VAE, text encoder, etc.) so you can load it directly later
+ - `--dtype`: Data type, default `bfloat16`
+ - `--device`: Device for loading, default `cpu` (recommended to avoid OOM)
+
+ > **Tip**: Use `--device cpu` (default) to save GPU memory during the merge process. The merge only needs to run once.
+
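If the merge is driven from Python rather than a shell, the same invocation can be assembled as an argument list. The sketch below is an assumption-laden convenience wrapper: `build_merge_command` is a hypothetical helper, and it assumes `combine_peft_weights.py` sits in the current working directory. The flags are the ones listed above.

```python
import sys


def build_merge_command(base_model_path, lora_weights_path, output_path,
                        save_full_pipeline=True, dtype="bfloat16", device="cpu"):
    """Build the argv list for combine_peft_weights.py from the flags above."""
    cmd = [
        sys.executable, "combine_peft_weights.py",
        "--base_model_path", base_model_path,
        "--lora_weights_path", lora_weights_path,
        "--output_path", output_path,
        "--dtype", dtype,
        "--device", device,
    ]
    if save_full_pipeline:
        # store_true flag: present means "save the whole pipeline"
        cmd.append("--save_full_pipeline")
    return cmd
```

The resulting list can be passed to `subprocess.run(cmd, check=True)`. Passing a list instead of a shell string avoids quoting problems with paths that contain spaces.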
+ ### Step 3: Use the Merged Model with UniDG
+
+ After merging, the model can be used directly with the UniDG inference code; **no additional LoRA loading is needed**:
+
+ ```python
+ from unidg import ImageUniDG
+ from PIL import Image
+ import torch
+
+ # Load the merged RFT model; set lora_weights_path="" since LoRA is already merged
+ model = ImageUniDG(
+     flux_model_path="path/to/FLUX.1-Fill-dev-UDG-RFT",
+     redux_model_path="path/to/FLUX.1-Redux-dev",
+     lora_weights_path="",  # No additional LoRA needed!
+     device="cuda:0",
+     dtype=torch.bfloat16,
+ )
+
+ result, mask = model.process_images(
+     target_image=Image.open("target.jpg"),
+     reference_image=Image.open("reference.jpg"),
+     reference_mask=Image.open("reference_mask.png"),
+     target_mask=Image.open("target_mask.png"),
+     num_inference_steps=28,
+     guidance_scale=3.5,
+     seed=42,
+ )
+ result.save("result.png")
+ ```
+
+ ### Quick Reference: SFT vs RFT Usage
+
+ | | UniDG-SFT | UniDG-RFT |
+ |---|-----------|-----------|
+ | Weight format | `pytorch_lora_weights.safetensors` | `adapter_model.safetensors` + `adapter_config.json` |
+ | Merge required? | No | Yes (with SFT base model) |
+ | `lora_weights_path` | Path to SFT weights | `""` (empty, after merge) |
+ | `flux_model_path` | `path/to/FLUX.1-Fill-dev` | `path/to/merged-RFT-model` |
+ | Load time | LoRA loaded on-the-fly | Pre-merged, no LoRA overhead |
+
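The `flux_model_path` and `lora_weights_path` rows of the table can be captured in a small lookup so a script can switch variants with one argument. This is only a sketch: `unidg_model_paths` is a hypothetical helper, and the returned keys simply mirror the `ImageUniDG` keyword names used in this README.

```python
def unidg_model_paths(variant, flux_fill_path="", sft_lora_path="", merged_rft_path=""):
    """Return the flux_model_path / lora_weights_path pair for a variant.

    Hypothetical helper mirroring the quick-reference table.
    """
    if variant == "sft":
        # Plain FLUX.1-Fill-dev plus the SFT LoRA loaded on the fly
        return {"flux_model_path": flux_fill_path, "lora_weights_path": sft_lora_path}
    if variant == "rft":
        # Pre-merged model, so no LoRA path is passed
        return {"flux_model_path": merged_rft_path, "lora_weights_path": ""}
    raise ValueError(f"unknown variant: {variant!r}")
```

The returned dictionary can be unpacked into the constructor call alongside `redux_model_path`, `device`, and `dtype`.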
+ ## LoRA Configuration
+
+ | Parameter | Value |
+ |-----------|-------|
+ | PEFT type | LORA |
+ | Rank (r) | 64 |
+ | Alpha | 128 |
+ | Dropout | 0.0 |
+ | Target modules | `ff.net.0.proj`, `ff.net.2`, `ff_context.net.0.proj`, `proj_mlp`, `attn.to_q`, `attn.to_v`, `attn.to_add_out`, `attn.add_k_proj`, `attn.add_v_proj`, `ff_context.net.2`, `attn.add_q_proj`, `attn.to_out.0`, `attn.to_k` |
+ | Base model | FLUX.1-Fill-dev + UniDG-SFT-LoRA |
+
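For intuition about what merging does with these values: in the standard LoRA formulation (with `use_rslora` disabled, as in `adapter_config.json`), the low-rank update is folded into each target weight as W + (alpha / r) · B · A, so rank 64 and alpha 128 give a scale of 2. The toy sketch below uses tiny hand-written matrices in plain Python rather than the real weights, and `merge_lora` is a hypothetical name.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight, lora_a, lora_b, alpha, rank):
    """Return W + (alpha / rank) * (B @ A), the merged weight."""
    scale = alpha / rank
    delta = matmul(lora_b, lora_a)
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(weight, delta)]


# Toy 2x2 weight with a rank-1 adapter; alpha / rank = 2.0, like 128 / 64.
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 2.0]]      # shape (rank, in_features)
B = [[0.5], [1.0]]    # shape (out_features, rank)
merged = merge_lora(W, A, B, alpha=2, rank=1)
```

After merging, a single matrix multiply with `merged` gives the same result as applying the base weight and the scaled adapter path separately, which is why the merged model needs no LoRA layers at inference time.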
+ ## Citation
+
+ ```bibtex
+ @article{fan2026unidg,
+   title={Large-Scale Universal Defect Generation: Foundation Models and Datasets},
+   author={Fan, Yuanting and Liu, Jun and Gao, Bin-Bin and Chen, Xiaochen and Lin, Yuhuan and Dai, Zhewei and Zhan, Jiawei and Wang, Chengjie},
+   journal={arXiv preprint arXiv:2604.08915},
+   year={2026}
+ }
+ ```
adapter_config.json ADDED
@@ -0,0 +1,51 @@
+ {
+   "alpha_pattern": {},
+   "auto_mapping": {
+     "base_model_class": "FluxTransformer2DModel",
+     "parent_library": "diffusers.models.transformers.transformer_flux"
+   },
+   "base_model_name_or_path": null,
+   "bias": "none",
+   "corda_config": null,
+   "eva_config": null,
+   "exclude_modules": null,
+   "fan_in_fan_out": false,
+   "inference_mode": true,
+   "init_lora_weights": "gaussian",
+   "layer_replication": null,
+   "layers_pattern": null,
+   "layers_to_transform": null,
+   "loftq_config": {},
+   "lora_alpha": 128,
+   "lora_bias": false,
+   "lora_dropout": 0.0,
+   "megatron_config": null,
+   "megatron_core": "megatron.core",
+   "modules_to_save": null,
+   "peft_type": "LORA",
+   "qalora_group_size": 16,
+   "r": 64,
+   "rank_pattern": {},
+   "revision": null,
+   "target_modules": [
+     "ff.net.0.proj",
+     "ff.net.2",
+     "ff_context.net.0.proj",
+     "proj_mlp",
+     "attn.to_q",
+     "attn.to_v",
+     "attn.to_add_out",
+     "attn.add_k_proj",
+     "attn.add_v_proj",
+     "ff_context.net.2",
+     "attn.add_q_proj",
+     "attn.to_out.0",
+     "attn.to_k"
+   ],
+   "target_parameters": null,
+   "task_type": null,
+   "trainable_token_indices": null,
+   "use_dora": false,
+   "use_qalora": false,
+   "use_rslora": false
+ }
adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ffdb61c5049d61ff469601d1f4a84a983d8b0aba0c2aa3c05b1da75bcfc887e7
+ size 433431520
combine_peft_weights.py ADDED
@@ -0,0 +1,349 @@
+ """
+ Merge PEFT LoRA weights into the base Flux transformer model.
+
+ What it does:
+ 1. Load the transformer of the base Flux Fill model
+ 2. Load the RL-trained PEFT LoRA weights
+ 3. Merge the LoRA weights into the base model
+ 4. Save the merged full model
+
+ Usage:
+     python combine_peft_weights.py \
+         --base_model_path /path/to/base/model \
+         --lora_weights_path /path/to/lora/weights \
+         --output_path /path/to/output \
+         --save_full_pipeline  # optional: save the full pipeline, not just the transformer
+ """
+
+ import torch
+ import argparse
+ import os
+ from pathlib import Path
+ from diffusers import FluxFillPipeline
+ from peft import PeftModel
+
+
+ def merge_and_save_transformer(
+     base_model_path: str,
+     lora_weights_path: str,
+     output_path: str,
+     dtype: torch.dtype = torch.bfloat16,
+     device: str = "cpu"
+ ):
+     """
+     Merge the LoRA weights into the transformer and save it.
+
+     Args:
+         base_model_path: Path to the base Flux Fill model
+         lora_weights_path: Path to the PEFT LoRA weights
+         output_path: Output path (where the merged transformer is saved)
+         dtype: Data type
+         device: Device to load on (CPU is recommended to save GPU memory)
+     """
+     print("=" * 80)
+     print("Step 1: Loading the base Flux Fill model...")
+     print("=" * 80)
+
+     # Load the base pipeline (every component is loaded; only the transformer is merged and saved)
+     pipe = FluxFillPipeline.from_pretrained(
+         base_model_path,
+         torch_dtype=dtype,
+         low_cpu_mem_usage=True
+     )
+
+     print(f"✓ Base model loaded: {base_model_path}")
+     print(f"  Transformer parameters: {sum(p.numel() for p in pipe.transformer.parameters()) / 1e9:.2f}B")
+
+     # Move to the requested device
+     if device != "cpu":
+         print(f"  Moving transformer to {device}...")
+         pipe.transformer = pipe.transformer.to(device)
+
+     print("\n" + "=" * 80)
+     print("Step 2: Loading the PEFT LoRA weights...")
+     print("=" * 80)
+
+     # Load the PEFT model
+     print(f"  Loading LoRA weights from {lora_weights_path}...")
+     peft_model = PeftModel.from_pretrained(
+         pipe.transformer,
+         lora_weights_path,
+         is_trainable=False
+     )
+     peft_model.set_adapter("default")
+
+     print(f"✓ PEFT model loaded")
+
+     # Inspect the LoRA configuration
+     lora_config = peft_model.peft_config.get("default", None)
+     if lora_config:
+         print(f"  LoRA configuration:")
+         print(f"    - Rank (r): {lora_config.r}")
+         print(f"    - Alpha: {lora_config.lora_alpha}")
+         print(f"    - Dropout: {lora_config.lora_dropout}")
+         print(f"    - Target modules: {lora_config.target_modules}")
+
+     print("\n" + "=" * 80)
+     print("Step 3: Merging the LoRA weights into the base model...")
+     print("=" * 80)
+
+     # Merge the weights
+     merged_model = peft_model.merge_and_unload()
+
+     print(f"✓ Weights merged")
+     print(f"  Merged model parameters: {sum(p.numel() for p in merged_model.parameters()) / 1e9:.2f}B")
+
+     print("\n" + "=" * 80)
+     print("Step 4: Saving the merged model...")
+     print("=" * 80)
+
+     # Create the output directory
+     os.makedirs(output_path, exist_ok=True)
+
+     # Save the merged transformer
+     print(f"  Saving to {output_path}...")
+     merged_model.save_pretrained(
+         output_path,
+         safe_serialization=True,  # use the safetensors format
+         max_shard_size="5GB"  # shard size
+     )
+
+     print(f"✓ Model saved: {output_path}")
+
+     # Save the merge metadata
+     info_path = os.path.join(output_path, "merge_info.txt")
+     with open(info_path, "w") as f:
+         f.write(f"Base model: {base_model_path}\n")
+         f.write(f"LoRA weights: {lora_weights_path}\n")
+         f.write(f"Merged model: {output_path}\n")
+         f.write(f"Data type: {dtype}\n")
+         if lora_config:
+             f.write(f"\nLoRA Configuration:\n")
+             f.write(f"  Rank (r): {lora_config.r}\n")
+             f.write(f"  Alpha: {lora_config.lora_alpha}\n")
+             f.write(f"  Dropout: {lora_config.lora_dropout}\n")
+             f.write(f"  Target modules: {lora_config.target_modules}\n")
+
+     print(f"✓ Merge info saved to: {info_path}")
+
+     return merged_model
+
+
+ def merge_and_save_full_pipeline(
+     base_model_path: str,
+     lora_weights_path: str,
+     output_path: str,
+     dtype: torch.dtype = torch.bfloat16,
+     device: str = "cpu"
+ ):
+     """
+     Merge the LoRA weights and save the complete FluxFillPipeline.
+
+     Args:
+         base_model_path: Path to the base Flux Fill model
+         lora_weights_path: Path to the PEFT LoRA weights
+         output_path: Output path (where the full pipeline is saved)
+         dtype: Data type
+         device: Device to load on
+     """
+     print("=" * 80)
+     print("Step 1: Loading the base Flux Fill pipeline...")
+     print("=" * 80)
+
+     # Load the full pipeline
+     pipe = FluxFillPipeline.from_pretrained(
+         base_model_path,
+         torch_dtype=dtype,
+         low_cpu_mem_usage=True
+     )
+
+     print(f"✓ Pipeline loaded: {base_model_path}")
+
+     # Move to the requested device
+     if device != "cpu":
+         print(f"  Moving transformer to {device}...")
+         pipe.transformer = pipe.transformer.to(device)
+
+     print("\n" + "=" * 80)
+     print("Step 2: Loading and merging the PEFT LoRA weights...")
+     print("=" * 80)
+
+     # Load the PEFT model
+     peft_model = PeftModel.from_pretrained(
+         pipe.transformer,
+         lora_weights_path,
+         is_trainable=False
+     )
+     peft_model.set_adapter("default")
+
+     # Merge the weights
+     merged_transformer = peft_model.merge_and_unload()
+
+     # Replace the transformer in the pipeline
+     pipe.transformer = merged_transformer
+
+     print(f"✓ Weights merged")
+
+     print("\n" + "=" * 80)
+     print("Step 3: Saving the full pipeline...")
+     print("=" * 80)
+
+     # Create the output directory
+     os.makedirs(output_path, exist_ok=True)
+
+     # Save the full pipeline
+     print(f"  Saving to {output_path}...")
+     pipe.save_pretrained(
+         output_path,
+         safe_serialization=True,
+         max_shard_size="5GB"
+     )
+
+     print(f"✓ Full pipeline saved: {output_path}")
+
+     # Save the merge metadata
+     info_path = os.path.join(output_path, "merge_info.txt")
+     with open(info_path, "w") as f:
+         f.write(f"Base model: {base_model_path}\n")
+         f.write(f"LoRA weights: {lora_weights_path}\n")
+         f.write(f"Merged pipeline: {output_path}\n")
+         f.write(f"Data type: {dtype}\n")
+         f.write(f"Components saved:\n")
+         f.write(f"  - Transformer (merged with LoRA)\n")
+         f.write(f"  - VAE\n")
+         f.write(f"  - Text Encoder\n")
+         f.write(f"  - Scheduler\n")
+         f.write(f"  - Other components\n")
+
+     print(f"✓ Merge info saved to: {info_path}")
+
+     return pipe
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Merge PEFT LoRA weights into the base Flux transformer model"
+     )
+
+     parser.add_argument(
+         "--base_model_path",
+         type=str,
+         default="/home/tione/notebook/research/retofan/ckpt/FLUX.1-Fill-dev-UDG-1121_e4",
+         help="Path to the base Flux Fill model"
+     )
+
+     parser.add_argument(
+         "--lora_weights_path",
+         type=str,
+         default="/home/tione/notebook2/research/retofan/code/RL/flow_grpo/logs/defectgen_det/flux_fill_redux/checkpoints/checkpoint-60/lora",
+         help="Path to the PEFT LoRA weights"
+     )
+
+     parser.add_argument(
+         "--output_path",
+         type=str,
+         default="/home/tione/notebook/research/retofan/ckpt/FLUX.1-Fill-dev-UDG-1121_e4_defect_gen_det_e60",
+         help="Output path (where the merged model is saved)"
+     )
+
+     parser.add_argument(
+         "--save_full_pipeline",
+         action="store_true",
+         help="Save the full FluxFillPipeline (including VAE, text encoder, etc.) instead of only the transformer"
+     )
+
+     parser.add_argument(
+         "--dtype",
+         type=str,
+         default="bfloat16",
+         choices=["float32", "float16", "bfloat16"],
+         help="Data type"
+     )
+
+     parser.add_argument(
+         "--device",
+         type=str,
+         default="cpu",
+         help="Device to load on (cpu, cuda:0, etc.). cpu is recommended to save GPU memory"
+     )
+
+     args = parser.parse_args()
+
+     # Convert the dtype name to a torch dtype
+     dtype_map = {
+         "float32": torch.float32,
+         "float16": torch.float16,
+         "bfloat16": torch.bfloat16
+     }
+     dtype = dtype_map[args.dtype]
+
+     # Validate the input paths; return a non-zero status so callers can detect the failure
+     if not os.path.exists(args.base_model_path):
+         print(f"Error: base model path does not exist: {args.base_model_path}")
+         return 1
+
+     if not os.path.exists(args.lora_weights_path):
+         print(f"Error: LoRA weights path does not exist: {args.lora_weights_path}")
+         return 1
+
+     print("\n" + "=" * 80)
+     print("PEFT LoRA weight merging tool")
+     print("=" * 80)
+     print(f"Base model: {args.base_model_path}")
+     print(f"LoRA weights: {args.lora_weights_path}")
+     print(f"Output path: {args.output_path}")
+     print(f"Save mode: {'full pipeline' if args.save_full_pipeline else 'transformer only'}")
+     print(f"Data type: {args.dtype}")
+     print(f"Device: {args.device}")
+     print("=" * 80 + "\n")
+
+     try:
+         if args.save_full_pipeline:
+             # Save the full pipeline
+             merge_and_save_full_pipeline(
+                 base_model_path=args.base_model_path,
+                 lora_weights_path=args.lora_weights_path,
+                 output_path=args.output_path,
+                 dtype=dtype,
+                 device=args.device
+             )
+         else:
+             # Save only the transformer
+             merge_and_save_transformer(
+                 base_model_path=args.base_model_path,
+                 lora_weights_path=args.lora_weights_path,
+                 output_path=args.output_path,
+                 dtype=dtype,
+                 device=args.device
+             )
+
+         print("\n" + "=" * 80)
+         print("✅ Merge complete!")
+         print("=" * 80)
+         print(f"\nThe merged model has been saved to: {args.output_path}")
+         print("\nUsage:")
+
+         if args.save_full_pipeline:
+             print("  # Load the merged full pipeline directly")
+             print("  from diffusers import FluxFillPipeline")
+             print(f"  pipe = FluxFillPipeline.from_pretrained('{args.output_path}')")
+         else:
+             print("  # Load the base pipeline, then replace the transformer")
+             print("  from diffusers import FluxFillPipeline")
+             print("  from diffusers.models import FluxTransformer2DModel")
+             print(f"  pipe = FluxFillPipeline.from_pretrained('{args.base_model_path}')")
+             print(f"  pipe.transformer = FluxTransformer2DModel.from_pretrained('{args.output_path}')")
+
+         print("\n" + "=" * 80)
+
+     except Exception as e:
+         print(f"\n❌ Error: {e}")
+         import traceback
+         traceback.print_exc()
+         return 1
+
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())