Upload folder using huggingface_hub
Browse files- README.md +158 -0
- adapter_config.json +51 -0
- adapter_model.safetensors +3 -0
- combine_peft_weights.py +349 -0
README.md
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- defect-generation
|
| 5 |
+
- anomaly-detection
|
| 6 |
+
- industrial-inspection
|
| 7 |
+
- lora
|
| 8 |
+
- flux
|
| 9 |
+
- diffusion
|
| 10 |
+
- rlhf
|
| 11 |
+
language: en
|
| 12 |
+
pipeline_tag: image-to-image
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# UniDG-RFT-LoRA
|
| 16 |
+
|
| 17 |
+
LoRA weights for **UniDG** (Universal Defect Generation), trained via **Consistency-RFT** with Flow-GRPO and dual reward models on the UDG dataset (300K quadruplets).
|
| 18 |
+
|
| 19 |
+
[[Paper]](https://arxiv.org/abs/2604.08915) [[Code]](https://github.com/RetoFan233/UniDG) [[UniDG-SFT-LoRA]](https://huggingface.co/retofan23333/UniDG-SFT-LoRA-Release)
|
| 20 |
+
|
| 21 |
+
## Overview
|
| 22 |
+
|
| 23 |
+
UniDG is a universal defect generation foundation model that transfers defects from a reference image to a target region via **Defect-Context Editing** and **MM-DiT multimodal attention**, without per-category fine-tuning. This checkpoint is the **Consistency-RFT** variant, further refined from UniDG-SFT using Flow-GRPO with dual reward models (Defect-Und-Reward & Defect-Recog-Reward) for improved defect fidelity and consistency.
|
| 24 |
+
|
| 25 |
+
| Variant | Training | Focus |
|
| 26 |
+
|---------|----------|-------|
|
| 27 |
+
| UniDG-SFT | Diversity-SFT with complementary sampling | Diverse defect patterns |
|
| 28 |
+
| **UniDG-RFT** (this) | Consistency-RFT with Flow-GRPO + dual rewards | Consistent & faithful defects |
|
| 29 |
+
|
| 30 |
+
## Important: Usage Difference from UniDG-SFT-LoRA
|
| 31 |
+
|
| 32 |
+
**The UniDG-RFT-LoRA weights are stored in PEFT format** (`adapter_model.safetensors` + `adapter_config.json`), which is different from UniDG-SFT-LoRA (which uses `pytorch_lora_weights.safetensors`). This means:
|
| 33 |
+
|
| 34 |
+
- **UniDG-SFT-LoRA** can be directly loaded via the `lora_weights_path` parameter in `ImageUniDG`.
|
| 35 |
+
- **UniDG-RFT-LoRA** must first be **merged into the base SFT model** using the provided `combine_peft_weights.py` script. After merging, the resulting model can be loaded directly without any additional LoRA loading step.
|
| 36 |
+
|
| 37 |
+
## Repository Contents
|
| 38 |
+
|
| 39 |
+
| File | Description |
|
| 40 |
+
|------|-------------|
|
| 41 |
+
| `adapter_model.safetensors` | PEFT LoRA weights (Consistency-RFT) |
|
| 42 |
+
| `adapter_config.json` | LoRA configuration (rank=64, alpha=128) |
|
| 43 |
+
| `combine_peft_weights.py` | Script to merge RFT LoRA into the base SFT model |
|
| 44 |
+
|
| 45 |
+
## Step-by-Step Usage
|
| 46 |
+
|
| 47 |
+
### Prerequisites
|
| 48 |
+
|
| 49 |
+
- [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) (inpainting backbone)
|
| 50 |
+
- [FLUX.1-Redux-dev](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) (reference conditioning)
|
| 51 |
+
- [UniDG-SFT-LoRA](https://huggingface.co/retofan23333/UniDG-SFT-LoRA-Release) (base SFT model — the RFT LoRA is fine-tuned on top of this)
|
| 52 |
+
- [UniDG code](https://github.com/RetoFan233/UniDG) (inference framework)
|
| 53 |
+
- Python dependencies: `diffusers`, `peft`, `torch`
|
| 54 |
+
|
| 55 |
+
### Step 1: Prepare the Base SFT Model
|
| 56 |
+
|
| 57 |
+
First, you need a base FLUX.1-Fill-dev model with UniDG-SFT-LoRA weights already merged in. If you haven't done this, you can prepare it by loading the SFT model and saving the merged weights:
|
| 58 |
+
|
| 59 |
+
```python
|
| 60 |
+
from diffusers import FluxFillPipeline
|
| 61 |
+
import torch
|
| 62 |
+
|
| 63 |
+
# Load base FLUX.1-Fill-dev
|
| 64 |
+
pipe = FluxFillPipeline.from_pretrained(
|
| 65 |
+
"path/to/FLUX.1-Fill-dev",
|
| 66 |
+
torch_dtype=torch.bfloat16,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
# Load SFT LoRA weights
|
| 70 |
+
pipe.load_lora_weights("path/to/UniDG-SFT-LoRA-Release/pytorch_lora_weights.safetensors")
|
| 71 |
+
|
| 72 |
+
# Save the merged SFT model as the base for RFT merging
|
| 73 |
+
pipe.save_pretrained("path/to/FLUX.1-Fill-dev-UDG-SFT", safe_serialization=True, max_shard_size="5GB")
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
### Step 2: Merge RFT LoRA into the Base SFT Model
|
| 77 |
+
|
| 78 |
+
Use the provided `combine_peft_weights.py` to merge the RFT LoRA weights into the base SFT model:
|
| 79 |
+
|
| 80 |
+
```bash
|
| 81 |
+
python combine_peft_weights.py \
|
| 82 |
+
--base_model_path path/to/FLUX.1-Fill-dev-UDG-SFT \
|
| 83 |
+
--lora_weights_path path/to/UniDG-RFT-LoRA-Release \
|
| 84 |
+
--output_path path/to/FLUX.1-Fill-dev-UDG-RFT \
|
| 85 |
+
--save_full_pipeline
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
Parameters:
|
| 89 |
+
- `--base_model_path`: Path to the base SFT model (from Step 1)
|
| 90 |
+
- `--lora_weights_path`: Path to this RFT LoRA repository (containing `adapter_model.safetensors` and `adapter_config.json`)
|
| 91 |
+
- `--output_path`: Output path for the merged model
|
| 92 |
+
- `--save_full_pipeline`: Save the full pipeline (including VAE, text encoder, etc.) so you can load it directly later
|
| 93 |
+
- `--dtype`: Data type, default `bfloat16`
|
| 94 |
+
- `--device`: Device for loading, default `cpu` (recommended to avoid OOM)
|
| 95 |
+
|
| 96 |
+
> **Tip**: Use `--device cpu` (default) to save GPU memory during the merge process. The merge only needs to run once.
|
| 97 |
+
|
| 98 |
+
### Step 3: Use the Merged Model with UniDG
|
| 99 |
+
|
| 100 |
+
After merging, the model can be used directly with the UniDG inference code — **no additional LoRA loading is needed**:
|
| 101 |
+
|
| 102 |
+
```python
|
| 103 |
+
from unidg import ImageUniDG
|
| 104 |
+
from PIL import Image
|
| 105 |
+
import torch
|
| 106 |
+
|
| 107 |
+
# Load the merged RFT model — set lora_weights_path="" since LoRA is already merged
|
| 108 |
+
model = ImageUniDG(
|
| 109 |
+
flux_model_path="path/to/FLUX.1-Fill-dev-UDG-RFT",
|
| 110 |
+
redux_model_path="path/to/FLUX.1-Redux-dev",
|
| 111 |
+
lora_weights_path="", # No additional LoRA needed!
|
| 112 |
+
device="cuda:0",
|
| 113 |
+
dtype=torch.bfloat16,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
result, mask = model.process_images(
|
| 117 |
+
target_image=Image.open("target.jpg"),
|
| 118 |
+
reference_image=Image.open("reference.jpg"),
|
| 119 |
+
reference_mask=Image.open("reference_mask.png"),
|
| 120 |
+
target_mask=Image.open("target_mask.png"),
|
| 121 |
+
num_inference_steps=28,
|
| 122 |
+
guidance_scale=3.5,
|
| 123 |
+
seed=42,
|
| 124 |
+
)
|
| 125 |
+
result.save("result.png")
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
### Quick Reference: SFT vs RFT Usage
|
| 129 |
+
|
| 130 |
+
| | UniDG-SFT | UniDG-RFT |
|
| 131 |
+
|---|-----------|-----------|
|
| 132 |
+
| Weight format | `pytorch_lora_weights.safetensors` | `adapter_model.safetensors` + `adapter_config.json` |
|
| 133 |
+
| Merge required? | No | Yes (with SFT base model) |
|
| 134 |
+
| `lora_weights_path` | Path to SFT weights | `""` (empty, after merge) |
|
| 135 |
+
| `flux_model_path` | `path/to/FLUX.1-Fill-dev` | `path/to/merged-RFT-model` |
|
| 136 |
+
| Load time | LoRA loaded on-the-fly | Pre-merged, no LoRA overhead |
|
| 137 |
+
|
| 138 |
+
## LoRA Configuration
|
| 139 |
+
|
| 140 |
+
| Parameter | Value |
|
| 141 |
+
|-----------|-------|
|
| 142 |
+
| PEFT type | LORA |
|
| 143 |
+
| Rank (r) | 64 |
|
| 144 |
+
| Alpha | 128 |
|
| 145 |
+
| Dropout | 0.0 |
|
| 146 |
+
| Target modules | `ff.net.0.proj`, `ff.net.2`, `ff_context.net.0.proj`, `proj_mlp`, `attn.to_q`, `attn.to_v`, `attn.to_add_out`, `attn.add_k_proj`, `attn.add_v_proj`, `ff_context.net.2`, `attn.add_q_proj`, `attn.to_out.0`, `attn.to_k` |
|
| 147 |
+
| Base model | FLUX.1-Fill-dev + UniDG-SFT-LoRA |
|
| 148 |
+
|
| 149 |
+
## Citation
|
| 150 |
+
|
| 151 |
+
```bibtex
|
| 152 |
+
@article{fan2026unidg,
|
| 153 |
+
title={Large-Scale Universal Defect Generation: Foundation Models and Datasets},
|
| 154 |
+
author={Fan, Yuanting and Liu, Jun and Gao, Bin-Bin and Chen, Xiaochen and Lin, Yuhuan and Dai, Zhewei and Zhan, Jiawei and Wang, Chengjie},
|
| 155 |
+
journal={arXiv preprint arXiv:2604.08915},
|
| 156 |
+
year={2026}
|
| 157 |
+
}
|
| 158 |
+
```
|
adapter_config.json
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": {
|
| 4 |
+
"base_model_class": "FluxTransformer2DModel",
|
| 5 |
+
"parent_library": "diffusers.models.transformers.transformer_flux"
|
| 6 |
+
},
|
| 7 |
+
"base_model_name_or_path": null,
|
| 8 |
+
"bias": "none",
|
| 9 |
+
"corda_config": null,
|
| 10 |
+
"eva_config": null,
|
| 11 |
+
"exclude_modules": null,
|
| 12 |
+
"fan_in_fan_out": false,
|
| 13 |
+
"inference_mode": true,
|
| 14 |
+
"init_lora_weights": "gaussian",
|
| 15 |
+
"layer_replication": null,
|
| 16 |
+
"layers_pattern": null,
|
| 17 |
+
"layers_to_transform": null,
|
| 18 |
+
"loftq_config": {},
|
| 19 |
+
"lora_alpha": 128,
|
| 20 |
+
"lora_bias": false,
|
| 21 |
+
"lora_dropout": 0.0,
|
| 22 |
+
"megatron_config": null,
|
| 23 |
+
"megatron_core": "megatron.core",
|
| 24 |
+
"modules_to_save": null,
|
| 25 |
+
"peft_type": "LORA",
|
| 26 |
+
"qalora_group_size": 16,
|
| 27 |
+
"r": 64,
|
| 28 |
+
"rank_pattern": {},
|
| 29 |
+
"revision": null,
|
| 30 |
+
"target_modules": [
|
| 31 |
+
"ff.net.0.proj",
|
| 32 |
+
"ff.net.2",
|
| 33 |
+
"ff_context.net.0.proj",
|
| 34 |
+
"proj_mlp",
|
| 35 |
+
"attn.to_q",
|
| 36 |
+
"attn.to_v",
|
| 37 |
+
"attn.to_add_out",
|
| 38 |
+
"attn.add_k_proj",
|
| 39 |
+
"attn.add_v_proj",
|
| 40 |
+
"ff_context.net.2",
|
| 41 |
+
"attn.add_q_proj",
|
| 42 |
+
"attn.to_out.0",
|
| 43 |
+
"attn.to_k"
|
| 44 |
+
],
|
| 45 |
+
"target_parameters": null,
|
| 46 |
+
"task_type": null,
|
| 47 |
+
"trainable_token_indices": null,
|
| 48 |
+
"use_dora": false,
|
| 49 |
+
"use_qalora": false,
|
| 50 |
+
"use_rslora": false
|
| 51 |
+
}
|
adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ffdb61c5049d61ff469601d1f4a84a983d8b0aba0c2aa3c05b1da75bcfc887e7
|
| 3 |
+
size 433431520
|
combine_peft_weights.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
将 PEFT LoRA 权重合并到基础 Flux Transformer 模型中
|
| 3 |
+
|
| 4 |
+
功能:
|
| 5 |
+
1. 加载基础 Flux Fill 模型的 Transformer
|
| 6 |
+
2. 加载 RL 训练的 PEFT LoRA 权重
|
| 7 |
+
3. 将 LoRA 权重合并到基础模型
|
| 8 |
+
4. 保存合并后的完整模型
|
| 9 |
+
|
| 10 |
+
使用方法:
|
| 11 |
+
python combine_peft_weights.py \
|
| 12 |
+
--base_model_path /path/to/base/model \
|
| 13 |
+
--lora_weights_path /path/to/lora/weights \
|
| 14 |
+
--output_path /path/to/output \
|
| 15 |
+
--save_full_pipeline # 可选:保存完整 pipeline 而不只是 transformer
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import argparse
|
| 20 |
+
import os
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from diffusers import FluxFillPipeline
|
| 23 |
+
from peft import PeftModel
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def merge_and_save_transformer(
|
| 27 |
+
base_model_path: str,
|
| 28 |
+
lora_weights_path: str,
|
| 29 |
+
output_path: str,
|
| 30 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 31 |
+
device: str = "cpu"
|
| 32 |
+
):
|
| 33 |
+
"""
|
| 34 |
+
合并 LoRA 权重到 Transformer 并保存
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
base_model_path: 基础 Flux Fill 模型路径
|
| 38 |
+
lora_weights_path: PEFT LoRA 权重路径
|
| 39 |
+
output_path: 输出路径(保存合并后的 transformer)
|
| 40 |
+
dtype: 数据类型
|
| 41 |
+
device: 加载设备(建议用 CPU 以节省显存)
|
| 42 |
+
"""
|
| 43 |
+
print("=" * 80)
|
| 44 |
+
print("Step 1: 加载基础 Flux Fill 模型...")
|
| 45 |
+
print("=" * 80)
|
| 46 |
+
|
| 47 |
+
# 加载基础模型(只加载 transformer 部分以节省内存)
|
| 48 |
+
pipe = FluxFillPipeline.from_pretrained(
|
| 49 |
+
base_model_path,
|
| 50 |
+
torch_dtype=dtype,
|
| 51 |
+
low_cpu_mem_usage=True
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
print(f"✓ 基础模型加载完成: {base_model_path}")
|
| 55 |
+
print(f" Transformer 参数量: {sum(p.numel() for p in pipe.transformer.parameters()) / 1e9:.2f}B")
|
| 56 |
+
|
| 57 |
+
# 移动到指定设备
|
| 58 |
+
if device != "cpu":
|
| 59 |
+
print(f" 移动 transformer 到 {device}...")
|
| 60 |
+
pipe.transformer = pipe.transformer.to(device)
|
| 61 |
+
|
| 62 |
+
print("\n" + "=" * 80)
|
| 63 |
+
print("Step 2: 加载 PEFT LoRA 权重...")
|
| 64 |
+
print("=" * 80)
|
| 65 |
+
|
| 66 |
+
# 加载 PEFT 模型
|
| 67 |
+
print(f" 从 {lora_weights_path} 加载 LoRA 权重...")
|
| 68 |
+
peft_model = PeftModel.from_pretrained(
|
| 69 |
+
pipe.transformer,
|
| 70 |
+
lora_weights_path,
|
| 71 |
+
is_trainable=False
|
| 72 |
+
)
|
| 73 |
+
peft_model.set_adapter("default")
|
| 74 |
+
|
| 75 |
+
print(f"✓ PEFT 模型加载完成")
|
| 76 |
+
|
| 77 |
+
# 检查 LoRA 配置
|
| 78 |
+
lora_config = peft_model.peft_config.get("default", None)
|
| 79 |
+
if lora_config:
|
| 80 |
+
print(f" LoRA 配置:")
|
| 81 |
+
print(f" - Rank (r): {lora_config.r}")
|
| 82 |
+
print(f" - Alpha: {lora_config.lora_alpha}")
|
| 83 |
+
print(f" - Dropout: {lora_config.lora_dropout}")
|
| 84 |
+
print(f" - Target modules: {lora_config.target_modules}")
|
| 85 |
+
|
| 86 |
+
print("\n" + "=" * 80)
|
| 87 |
+
print("Step 3: 合并 LoRA 权重到基础模型...")
|
| 88 |
+
print("=" * 80)
|
| 89 |
+
|
| 90 |
+
# 合并权重
|
| 91 |
+
merged_model = peft_model.merge_and_unload()
|
| 92 |
+
|
| 93 |
+
print(f"✓ 权重合并完成")
|
| 94 |
+
print(f" 合并后模型参数量: {sum(p.numel() for p in merged_model.parameters()) / 1e9:.2f}B")
|
| 95 |
+
|
| 96 |
+
print("\n" + "=" * 80)
|
| 97 |
+
print("Step 4: 保存合并后的模型...")
|
| 98 |
+
print("=" * 80)
|
| 99 |
+
|
| 100 |
+
# 创建输出目录
|
| 101 |
+
os.makedirs(output_path, exist_ok=True)
|
| 102 |
+
|
| 103 |
+
# 保存合并后的 transformer
|
| 104 |
+
print(f" 保存到 {output_path}...")
|
| 105 |
+
merged_model.save_pretrained(
|
| 106 |
+
output_path,
|
| 107 |
+
safe_serialization=True, # 使用 safetensors 格式
|
| 108 |
+
max_shard_size="5GB" # 分片大小
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
print(f"✓ 模型保存完成: {output_path}")
|
| 112 |
+
|
| 113 |
+
# 保存模型配置信息
|
| 114 |
+
info_path = os.path.join(output_path, "merge_info.txt")
|
| 115 |
+
with open(info_path, "w") as f:
|
| 116 |
+
f.write(f"Base model: {base_model_path}\n")
|
| 117 |
+
f.write(f"LoRA weights: {lora_weights_path}\n")
|
| 118 |
+
f.write(f"Merged model: {output_path}\n")
|
| 119 |
+
f.write(f"Data type: {dtype}\n")
|
| 120 |
+
if lora_config:
|
| 121 |
+
f.write(f"\nLoRA Configuration:\n")
|
| 122 |
+
f.write(f" Rank (r): {lora_config.r}\n")
|
| 123 |
+
f.write(f" Alpha: {lora_config.lora_alpha}\n")
|
| 124 |
+
f.write(f" Dropout: {lora_config.lora_dropout}\n")
|
| 125 |
+
f.write(f" Target modules: {lora_config.target_modules}\n")
|
| 126 |
+
|
| 127 |
+
print(f"✓ 合并信息保存到: {info_path}")
|
| 128 |
+
|
| 129 |
+
return merged_model
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def merge_and_save_full_pipeline(
|
| 133 |
+
base_model_path: str,
|
| 134 |
+
lora_weights_path: str,
|
| 135 |
+
output_path: str,
|
| 136 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 137 |
+
device: str = "cpu"
|
| 138 |
+
):
|
| 139 |
+
"""
|
| 140 |
+
合并 LoRA 权重并保存完整的 FluxFillPipeline
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
base_model_path: 基础 Flux Fill 模型路径
|
| 144 |
+
lora_weights_path: PEFT LoRA 权重路径
|
| 145 |
+
output_path: 输出路径(保存完整 pipeline)
|
| 146 |
+
dtype: 数据类型
|
| 147 |
+
device: 加载设备
|
| 148 |
+
"""
|
| 149 |
+
print("=" * 80)
|
| 150 |
+
print("Step 1: 加载基础 Flux Fill Pipeline...")
|
| 151 |
+
print("=" * 80)
|
| 152 |
+
|
| 153 |
+
# 加载完整 pipeline
|
| 154 |
+
pipe = FluxFillPipeline.from_pretrained(
|
| 155 |
+
base_model_path,
|
| 156 |
+
torch_dtype=dtype,
|
| 157 |
+
low_cpu_mem_usage=True
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
print(f"✓ Pipeline 加载完成: {base_model_path}")
|
| 161 |
+
|
| 162 |
+
# 移动到指定设备
|
| 163 |
+
if device != "cpu":
|
| 164 |
+
print(f" 移动 transformer 到 {device}...")
|
| 165 |
+
pipe.transformer = pipe.transformer.to(device)
|
| 166 |
+
|
| 167 |
+
print("\n" + "=" * 80)
|
| 168 |
+
print("Step 2: 加载并合并 PEFT LoRA 权重...")
|
| 169 |
+
print("=" * 80)
|
| 170 |
+
|
| 171 |
+
# 加载 PEFT 模型
|
| 172 |
+
peft_model = PeftModel.from_pretrained(
|
| 173 |
+
pipe.transformer,
|
| 174 |
+
lora_weights_path,
|
| 175 |
+
is_trainable=False
|
| 176 |
+
)
|
| 177 |
+
peft_model.set_adapter("default")
|
| 178 |
+
|
| 179 |
+
# 合并权重
|
| 180 |
+
merged_transformer = peft_model.merge_and_unload()
|
| 181 |
+
|
| 182 |
+
# 替换 pipeline 中的 transformer
|
| 183 |
+
pipe.transformer = merged_transformer
|
| 184 |
+
|
| 185 |
+
print(f"✓ 权重合并完成")
|
| 186 |
+
|
| 187 |
+
print("\n" + "=" * 80)
|
| 188 |
+
print("Step 3: 保存完整 Pipeline...")
|
| 189 |
+
print("=" * 80)
|
| 190 |
+
|
| 191 |
+
# 创建输出目录
|
| 192 |
+
os.makedirs(output_path, exist_ok=True)
|
| 193 |
+
|
| 194 |
+
# 保存完整 pipeline
|
| 195 |
+
print(f" 保存到 {output_path}...")
|
| 196 |
+
pipe.save_pretrained(
|
| 197 |
+
output_path,
|
| 198 |
+
safe_serialization=True,
|
| 199 |
+
max_shard_size="5GB"
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
print(f"✓ 完整 Pipeline 保存完成: {output_path}")
|
| 203 |
+
|
| 204 |
+
# 保存合并信息
|
| 205 |
+
info_path = os.path.join(output_path, "merge_info.txt")
|
| 206 |
+
with open(info_path, "w") as f:
|
| 207 |
+
f.write(f"Base model: {base_model_path}\n")
|
| 208 |
+
f.write(f"LoRA weights: {lora_weights_path}\n")
|
| 209 |
+
f.write(f"Merged pipeline: {output_path}\n")
|
| 210 |
+
f.write(f"Data type: {dtype}\n")
|
| 211 |
+
f.write(f"Components saved:\n")
|
| 212 |
+
f.write(f" - Transformer (merged with LoRA)\n")
|
| 213 |
+
f.write(f" - VAE\n")
|
| 214 |
+
f.write(f" - Text Encoder\n")
|
| 215 |
+
f.write(f" - Scheduler\n")
|
| 216 |
+
f.write(f" - Other components\n")
|
| 217 |
+
|
| 218 |
+
print(f"✓ 合并信息保存到: {info_path}")
|
| 219 |
+
|
| 220 |
+
return pipe
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def main():
|
| 224 |
+
parser = argparse.ArgumentParser(
|
| 225 |
+
description="将 PEFT LoRA 权重合并到基础 Flux Transformer 模型中"
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
parser.add_argument(
|
| 229 |
+
"--base_model_path",
|
| 230 |
+
type=str,
|
| 231 |
+
default="/home/tione/notebook/research/retofan/ckpt/FLUX.1-Fill-dev-UDG-1121_e4",
|
| 232 |
+
help="基础 Flux Fill 模型路径"
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
parser.add_argument(
|
| 236 |
+
"--lora_weights_path",
|
| 237 |
+
type=str,
|
| 238 |
+
default="/home/tione/notebook2/research/retofan/code/RL/flow_grpo/logs/defectgen_det/flux_fill_redux/checkpoints/checkpoint-60/lora",
|
| 239 |
+
help="PEFT LoRA 权重路径"
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
parser.add_argument(
|
| 243 |
+
"--output_path",
|
| 244 |
+
type=str,
|
| 245 |
+
default="/home/tione/notebook/research/retofan/ckpt/FLUX.1-Fill-dev-UDG-1121_e4_defect_gen_det_e60",
|
| 246 |
+
help="输出路径(保存合并后的模型)"
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
parser.add_argument(
|
| 250 |
+
"--save_full_pipeline",
|
| 251 |
+
action="store_true",
|
| 252 |
+
help="保存完整 FluxFillPipeline(包含 VAE、Text Encoder 等),而不只是 Transformer"
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
parser.add_argument(
|
| 256 |
+
"--dtype",
|
| 257 |
+
type=str,
|
| 258 |
+
default="bfloat16",
|
| 259 |
+
choices=["float32", "float16", "bfloat16"],
|
| 260 |
+
help="数据类型"
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
parser.add_argument(
|
| 264 |
+
"--device",
|
| 265 |
+
type=str,
|
| 266 |
+
default="cpu",
|
| 267 |
+
help="加载设备(cpu 或 cuda:0 等)。建议使用 cpu 以节省显存"
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
args = parser.parse_args()
|
| 271 |
+
|
| 272 |
+
# 转换 dtype
|
| 273 |
+
dtype_map = {
|
| 274 |
+
"float32": torch.float32,
|
| 275 |
+
"float16": torch.float16,
|
| 276 |
+
"bfloat16": torch.bfloat16
|
| 277 |
+
}
|
| 278 |
+
dtype = dtype_map[args.dtype]
|
| 279 |
+
|
| 280 |
+
# 检查路径
|
| 281 |
+
if not os.path.exists(args.base_model_path):
|
| 282 |
+
print(f"错误: 基础模型路径不存在: {args.base_model_path}")
|
| 283 |
+
return
|
| 284 |
+
|
| 285 |
+
if not os.path.exists(args.lora_weights_path):
|
| 286 |
+
print(f"错误: LoRA 权重路径不存在: {args.lora_weights_path}")
|
| 287 |
+
return
|
| 288 |
+
|
| 289 |
+
print("\n" + "=" * 80)
|
| 290 |
+
print("PEFT LoRA 权重合并工具")
|
| 291 |
+
print("=" * 80)
|
| 292 |
+
print(f"基础模型: {args.base_model_path}")
|
| 293 |
+
print(f"LoRA 权重: {args.lora_weights_path}")
|
| 294 |
+
print(f"输出路径: {args.output_path}")
|
| 295 |
+
print(f"保存类型: {'完整 Pipeline' if args.save_full_pipeline else '仅 Transformer'}")
|
| 296 |
+
print(f"数据类型: {args.dtype}")
|
| 297 |
+
print(f"加载设备: {args.device}")
|
| 298 |
+
print("=" * 80 + "\n")
|
| 299 |
+
|
| 300 |
+
try:
|
| 301 |
+
if args.save_full_pipeline:
|
| 302 |
+
# 保存完整 pipeline
|
| 303 |
+
merge_and_save_full_pipeline(
|
| 304 |
+
base_model_path=args.base_model_path,
|
| 305 |
+
lora_weights_path=args.lora_weights_path,
|
| 306 |
+
output_path=args.output_path,
|
| 307 |
+
dtype=dtype,
|
| 308 |
+
device=args.device
|
| 309 |
+
)
|
| 310 |
+
else:
|
| 311 |
+
# 只保存 transformer
|
| 312 |
+
merge_and_save_transformer(
|
| 313 |
+
base_model_path=args.base_model_path,
|
| 314 |
+
lora_weights_path=args.lora_weights_path,
|
| 315 |
+
output_path=args.output_path,
|
| 316 |
+
dtype=dtype,
|
| 317 |
+
device=args.device
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
print("\n" + "=" * 80)
|
| 321 |
+
print("✅ 合并完成!")
|
| 322 |
+
print("=" * 80)
|
| 323 |
+
print(f"\n合并后的模型已保存到: {args.output_path}")
|
| 324 |
+
print("\n使用方法:")
|
| 325 |
+
|
| 326 |
+
if args.save_full_pipeline:
|
| 327 |
+
print(" # 直接加载合并后的完整 pipeline")
|
| 328 |
+
print(" from diffusers import FluxFillPipeline")
|
| 329 |
+
print(f" pipe = FluxFillPipeline.from_pretrained('{args.output_path}')")
|
| 330 |
+
else:
|
| 331 |
+
print(" # 加载基础 pipeline,然后替换 transformer")
|
| 332 |
+
print(" from diffusers import FluxFillPipeline")
|
| 333 |
+
print(" from diffusers.models import FluxTransformer2DModel")
|
| 334 |
+
print(f" pipe = FluxFillPipeline.from_pretrained('{args.base_model_path}')")
|
| 335 |
+
print(f" pipe.transformer = FluxTransformer2DModel.from_pretrained('{args.output_path}')")
|
| 336 |
+
|
| 337 |
+
print("\n" + "=" * 80)
|
| 338 |
+
|
| 339 |
+
except Exception as e:
|
| 340 |
+
print(f"\n❌ 错误: {e}")
|
| 341 |
+
import traceback
|
| 342 |
+
traceback.print_exc()
|
| 343 |
+
return 1
|
| 344 |
+
|
| 345 |
+
return 0
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
if __name__ == "__main__":
|
| 349 |
+
exit(main())
|