---
license: apache-2.0
tags:
- defect-generation
- anomaly-detection
- industrial-inspection
- lora
- flux
- diffusion
- rlhf
language: en
pipeline_tag: image-to-image
---

# UniDG-RFT-LoRA

LoRA weights for **UniDG** (Universal Defect Generation), trained via **Consistency-RFT** with Flow-GRPO and dual reward models on the UDG dataset (300K quadruplets).

[[Paper]](https://arxiv.org/abs/2604.08915) [[Code]](https://github.com/RetoFan233/UniDG) [[UniDG-SFT-LoRA]](https://huggingface.co/retofan23333/UniDG-SFT-LoRA-Release)

## Overview

UniDG is a universal defect generation foundation model. It transfers defects from a reference image to a target region via **Defect-Context Editing** and **MM-DiT multimodal attention**, without per-category fine-tuning. This checkpoint is the **Consistency-RFT** variant. It was refined further from UniDG-SFT using Flow-GRPO with dual reward models (Defect-Und-Reward & Defect-Recog-Reward) to improve defect fidelity and consistency.

| Variant | Training | Focus |
|---------|----------|-------|
| UniDG-SFT | Diversity-SFT with complementary sampling | Diverse defect patterns |
| **UniDG-RFT** (this) | Consistency-RFT with Flow-GRPO + dual rewards | Consistent & faithful defects |

## Important: Usage Difference from UniDG-SFT-LoRA

**The UniDG-RFT-LoRA weights are stored in PEFT format** (`adapter_model.safetensors` + `adapter_config.json`). UniDG-SFT-LoRA instead uses the diffusers-style `pytorch_lora_weights.safetensors` file. This means:

- **UniDG-SFT-LoRA** can be loaded directly via the `lora_weights_path` parameter of `ImageUniDG`.
- **UniDG-RFT-LoRA** must first be **merged into the base SFT model** using the provided `combine_peft_weights.py` script. After merging, the resulting model can be loaded directly without any additional LoRA loading step.

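The two releases differ in which files they ship, so you can check a downloaded folder to see which loading path applies. This is a minimal sketch that relies only on the file names listed in this card. The helper name `detect_lora_format` is hypothetical and is not part of the UniDG code.

```python
from pathlib import Path


def detect_lora_format(repo_dir: str) -> str:
    """Classify a downloaded LoRA folder by the files it contains.

    Returns "peft" (RFT release: merge before use), "diffusers"
    (SFT release: pass via lora_weights_path), or "unknown".
    """
    root = Path(repo_dir)
    # PEFT checkpoints ship the adapter weights together with adapter_config.json
    if (root / "adapter_model.safetensors").is_file() and (root / "adapter_config.json").is_file():
        return "peft"
    # diffusers-style LoRA checkpoints use this default file name
    if (root / "pytorch_lora_weights.safetensors").is_file():
        return "diffusers"
    return "unknown"
```

For a local copy of this repository, the helper should return `"peft"`. That tells you to go through the merge steps below rather than passing the folder to `lora_weights_path`.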
## Repository Contents

| File | Description |
|------|-------------|
| `adapter_model.safetensors` | PEFT LoRA weights (Consistency-RFT) |
| `adapter_config.json` | LoRA configuration (rank=64, alpha=128) |
| `combine_peft_weights.py` | Script to merge the RFT LoRA into the base SFT model |

## Step-by-Step Usage

### Prerequisites

- [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) (inpainting backbone)
- [FLUX.1-Redux-dev](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) (reference conditioning)
- [UniDG-SFT-LoRA](https://huggingface.co/retofan23333/UniDG-SFT-LoRA-Release) (base SFT model; the RFT LoRA is fine-tuned on top of it)
- [UniDG code](https://github.com/RetoFan233/UniDG) (inference framework)
- Python dependencies: `diffusers`, `peft`, `torch`

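Before starting, you can confirm that the Python packages listed above are importable. This is a minimal sketch that uses only the standard library. The helper name `missing_packages` is illustrative. It does not check the UniDG code itself, which is installed from its repository.

```python
import importlib.util

# Python packages named in the prerequisites list
REQUIRED_PACKAGES = ("diffusers", "peft", "torch")


def missing_packages(names=REQUIRED_PACKAGES):
    """Return the subset of top-level package `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]
```

An empty list from `missing_packages()` means the three packages are present. The check confirms only that the packages exist; it does not check their versions.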
### Step 1: Prepare the Base SFT Model

First, you need a FLUX.1-Fill-dev model with the UniDG-SFT-LoRA weights already merged in. If you do not have one yet, load FLUX.1-Fill-dev, fuse the SFT LoRA into its weights, and save the result:

```python
from diffusers import FluxFillPipeline
import torch

# Load base FLUX.1-Fill-dev
pipe = FluxFillPipeline.from_pretrained(
    "path/to/FLUX.1-Fill-dev",
    torch_dtype=torch.bfloat16,
)

# Load SFT LoRA weights
pipe.load_lora_weights("path/to/UniDG-SFT-LoRA-Release/pytorch_lora_weights.safetensors")

# Fuse the LoRA into the base weights, then remove the separate adapter layers
# so that save_pretrained writes plain, merged weights
pipe.fuse_lora()
pipe.unload_lora_weights()

# Save the merged SFT model as the base for RFT merging
pipe.save_pretrained("path/to/FLUX.1-Fill-dev-UDG-SFT", safe_serialization=True, max_shard_size="5GB")
```

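Before Step 2, you may want to confirm that the saved folder is a complete diffusers pipeline. This sketch assumes the standard `save_pretrained` layout: a `model_index.json` file plus one subfolder per component. It also assumes that `[null, null]` entries mark optional components left unset; as far as we know, that is what diffusers writes for them. `missing_pipeline_components` is a hypothetical helper. The same check applies to the Step 2 output when `--save_full_pipeline` is used.

```python
import json
from pathlib import Path


def missing_pipeline_components(pipeline_dir: str) -> list[str]:
    """List components named in model_index.json that have no subfolder on disk."""
    root = Path(pipeline_dir)
    index = json.loads((root / "model_index.json").read_text())
    missing = []
    for name, spec in index.items():
        if name.startswith("_"):
            continue  # metadata such as "_class_name" or "_diffusers_version"
        if not isinstance(spec, list) or spec == [None, None]:
            continue  # not a saved component (e.g. an unset optional one)
        if not (root / name).is_dir():
            missing.append(name)
    return missing
```

An empty result from `missing_pipeline_components("path/to/FLUX.1-Fill-dev-UDG-SFT")` suggests that every listed component (transformer, VAE, text encoders, and so on) was written to disk.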
### Step 2: Merge RFT LoRA into the Base SFT Model

Use the provided `combine_peft_weights.py` to merge the RFT LoRA weights into the base SFT model:

```bash
python combine_peft_weights.py \
    --base_model_path path/to/FLUX.1-Fill-dev-UDG-SFT \
    --lora_weights_path path/to/UniDG-RFT-LoRA-Release \
    --output_path path/to/FLUX.1-Fill-dev-UDG-RFT \
    --save_full_pipeline
```

Parameters:
- `--base_model_path`: Path to the base SFT model (from Step 1)
- `--lora_weights_path`: Path to this RFT LoRA repository (containing `adapter_model.safetensors` and `adapter_config.json`)
- `--output_path`: Output path for the merged model
- `--save_full_pipeline`: Save the full pipeline (including VAE, text encoder, etc.) so you can load it directly later
- `--dtype`: Data type, default `bfloat16`
- `--device`: Device for loading, default `cpu` (recommended to avoid OOM)

> **Tip**: Keep the default `--device cpu` to avoid running out of GPU memory during the merge. The merge only needs to run once.

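Under the standard LoRA formulation, each targeted linear layer with base weight `W` gains a low-rank update `(lora_alpha / r) * B @ A`. Merging adds that update to `W` once. We have not inspected `combine_peft_weights.py`, so treat this as an assumption about its behavior. With this adapter's `r=64` and `lora_alpha=128`, the scale is 2, assuming rsLoRA is off. The toy numpy sketch below uses small shapes and random matrices to check that the merged weight reproduces the on-the-fly adapter output:

```python
import numpy as np

# Toy dimensions; the real adapter uses r=64 and lora_alpha=128 (also scale 2.0)
d_out, d_in, r, lora_alpha = 8, 6, 4, 8
scale = lora_alpha / r

rng = np.random.default_rng(0)
W = rng.standard_normal((d_out, d_in))  # frozen base weight (here: the SFT model)
A = rng.standard_normal((r, d_in))      # LoRA down-projection
B = rng.standard_normal((d_out, r))     # LoRA up-projection
x = rng.standard_normal(d_in)           # an input activation


def lora_forward(x):
    """Adapter applied on the fly: base path plus scaled low-rank path."""
    return W @ x + scale * (B @ (A @ x))


# Merging folds the low-rank update into the base weight once
W_merged = W + scale * (B @ A)


def merged_forward(x):
    """Plain linear layer with the merged weight; no adapter computation."""
    return W_merged @ x


assert np.allclose(lora_forward(x), merged_forward(x))
```

After merging, each layer is a single matrix multiply. That is why Step 3 passes an empty `lora_weights_path`.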
### Step 3: Use the Merged Model with UniDG

After merging, the model can be used directly with the UniDG inference code; **no additional LoRA loading is needed**:

```python
from unidg import ImageUniDG
from PIL import Image
import torch

# Load the merged RFT model; set lora_weights_path="" because the LoRA is already merged
model = ImageUniDG(
    flux_model_path="path/to/FLUX.1-Fill-dev-UDG-RFT",
    redux_model_path="path/to/FLUX.1-Redux-dev",
    lora_weights_path="",  # No additional LoRA needed
    device="cuda:0",
    dtype=torch.bfloat16,
)

result, mask = model.process_images(
    target_image=Image.open("target.jpg"),
    reference_image=Image.open("reference.jpg"),
    reference_mask=Image.open("reference_mask.png"),
    target_mask=Image.open("target_mask.png"),
    num_inference_steps=28,
    guidance_scale=3.5,
    seed=42,
)
result.save("result.png")
```

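To run the merged model over many samples, you need to group your files into the four inputs `process_images` takes: target image, reference image, reference mask, and target mask. This sketch assumes a hypothetical naming convention where all four files share a sample prefix. `find_quadruplets`, `run_batch`, and the file suffixes are illustrative and not part of the UniDG code. `run_batch` reuses the Step 3 settings.

```python
from pathlib import Path

# Hypothetical layout: files <id>_target.jpg, <id>_reference.jpg,
# <id>_reference_mask.png and <id>_target_mask.png; adjust to your data.
# Keys match the keyword names of process_images in Step 3.
ROLES = {
    "target_image": "_target.jpg",
    "reference_image": "_reference.jpg",
    "reference_mask": "_reference_mask.png",
    "target_mask": "_target_mask.png",
}


def find_quadruplets(data_dir):
    """Group files into {sample_id: {role: path}}, keeping only complete samples."""
    samples = {}
    for path in sorted(Path(data_dir).iterdir()):
        for role, suffix in ROLES.items():
            if path.name.endswith(suffix):
                sample_id = path.name[: -len(suffix)]
                samples.setdefault(sample_id, {})[role] = path
                break
    return {sid: files for sid, files in samples.items() if len(files) == len(ROLES)}


def run_batch(model, data_dir, out_dir):
    """Call model.process_images with the Step 3 settings on every complete sample."""
    from PIL import Image  # imported here so find_quadruplets works without Pillow

    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for sample_id, files in find_quadruplets(data_dir).items():
        inputs = {role: Image.open(path) for role, path in files.items()}
        result, _ = model.process_images(
            **inputs, num_inference_steps=28, guidance_scale=3.5, seed=42
        )
        result.save(out / f"{sample_id}_result.png")
```

With the `model` built in Step 3, `run_batch(model, "path/to/samples", "path/to/results")` writes one `<id>_result.png` per complete sample. It skips samples with missing files without warning.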
### Quick Reference: SFT vs RFT Usage

| | UniDG-SFT | UniDG-RFT |
|---|-----------|-----------|
| Weight format | `pytorch_lora_weights.safetensors` | `adapter_model.safetensors` + `adapter_config.json` |
| Merge required? | No | Yes (into the SFT base model) |
| `lora_weights_path` | Path to SFT weights | `""` (empty, after merge) |
| `flux_model_path` | `path/to/FLUX.1-Fill-dev` | `path/to/FLUX.1-Fill-dev-UDG-RFT` (merged model) |
| LoRA handling | LoRA loaded on-the-fly | Pre-merged, no LoRA overhead |

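If a script needs to switch between the two variants, the rows above reduce to a single choice of `flux_model_path` and `lora_weights_path`. This is a minimal sketch. `unidg_model_paths` and its argument names are hypothetical, and the other `ImageUniDG` arguments stay as in Step 3.

```python
def unidg_model_paths(variant, fill_dev_path, merged_rft_path, sft_lora_path):
    """Return the flux_model_path / lora_weights_path pair for ImageUniDG.

    SFT: base FLUX.1-Fill-dev plus the SFT LoRA loaded on the fly.
    RFT: the pre-merged model from Step 2, with no extra LoRA.
    """
    if variant == "sft":
        return {"flux_model_path": fill_dev_path, "lora_weights_path": sft_lora_path}
    if variant == "rft":
        return {"flux_model_path": merged_rft_path, "lora_weights_path": ""}
    raise ValueError(f"unknown variant {variant!r}; expected 'sft' or 'rft'")
```

Unpack the returned dict into the constructor, for example `ImageUniDG(**unidg_model_paths("rft", ...), redux_model_path=..., device="cuda:0", dtype=torch.bfloat16)`.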
## LoRA Configuration

| Parameter | Value |
|-----------|-------|
| PEFT type | LORA |
| Rank (r) | 64 |
| Alpha | 128 |
| Dropout | 0.0 |
| Target modules | `ff.net.0.proj`, `ff.net.2`, `ff_context.net.0.proj`, `proj_mlp`, `attn.to_q`, `attn.to_v`, `attn.to_add_out`, `attn.add_k_proj`, `attn.add_v_proj`, `ff_context.net.2`, `attn.add_q_proj`, `attn.to_out.0`, `attn.to_k` |
| Base model | FLUX.1-Fill-dev + UniDG-SFT-LoRA |

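You can read the update scale these values imply back from `adapter_config.json`. This sketch uses PEFT's config field names (`r`, `lora_alpha`, `use_rslora`). It ignores per-module `rank_pattern`/`alpha_pattern` overrides. `lora_scale` is a hypothetical helper.

```python
import json
import math
from pathlib import Path


def lora_scale(adapter_dir):
    """Read a PEFT adapter_config.json and return (r, lora_alpha, scale).

    Standard LoRA scales the low-rank update by lora_alpha / r; PEFT uses
    lora_alpha / sqrt(r) instead when use_rslora is true.
    """
    cfg = json.loads((Path(adapter_dir) / "adapter_config.json").read_text())
    r, alpha = cfg["r"], cfg["lora_alpha"]
    if cfg.get("use_rslora", False):
        scale = alpha / math.sqrt(r)
    else:
        scale = alpha / r
    return r, alpha, scale
```

The table above does not say whether rsLoRA is enabled. If `use_rslora` is false in this repository's config, rank 64 and alpha 128 give a scale of 2.0.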
## Citation

```bibtex
@article{fan2026unidg,
  title={Large-Scale Universal Defect Generation: Foundation Models and Datasets},
  author={Fan, Yuanting and Liu, Jun and Gao, Bin-Bin and Chen, Xiaochen and Lin, Yuhuan and Dai, Zhewei and Zhan, Jiawei and Wang, Chengjie},
  journal={arXiv preprint arXiv:2604.08915},
  year={2026}
}
```