Initial commit: Lumina_Dev_Legacy (archived)
Committed by TAI Research
Commit 29691f6 · 0 parent(s)
- LICENSE +21 -0
- README.md +158 -0
- SOLUTION_SUMMARY.md +49 -0
- configs/data/laion_filtered.yaml +52 -0
- configs/model/diffusion.yaml +19 -0
- configs/model/unet_light.yaml +22 -0
- configs/training/p4_optimized.yaml +69 -0
- configs/training/schedule_256.yaml +33 -0
- data/laion/dataset_info.json +14 -0
- data/laion/metadata.parquet +0 -0
- docs/LAION_DATASET_GUIDE.md +279 -0
- requirements.txt +28 -0
- scripts/benchmark.py +490 -0
- scripts/download_laion.py +272 -0
- scripts/export.py +270 -0
- scripts/train.py +293 -0
- src/data/__pycache__/dataset.cpython-313.pyc +0 -0
- src/data/dataset.py +298 -0
- src/data/preprocessing.py +295 -0
- src/data/text_encoder.py +227 -0
- src/inference/api.py +631 -0
- src/inference/optimization.py +427 -0
- src/inference/sampler.py +428 -0
- src/models/attention.py +143 -0
- src/models/diffusion.py +263 -0
- src/models/unet_light.py +379 -0
- src/training/callbacks.py +324 -0
- src/training/memory_manager.py +245 -0
- src/training/trainer_p4.py +378 -0
- tests/test_basic.py +250 -0
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2026 TAI Research

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,158 @@
---
license: mit
tags:
- image-generation
- legacy
- archived
- research
- pytorch
---

# 🎨 Lumina_Dev_Legacy — [ARCHIVED]

> **⚠️ Project status: archived / development discontinued**
>
> This is an early exploratory project from TAI Research. Development has ended and the code is kept for archival purposes only. The model was never fully trained, and no pretrained weights are provided.

| Project info | |
|---------|---|
| **Status** | 🔴 Deprecated (Archived) |
| **Reason** | Project direction changed; resources were reallocated to other research areas |
| **Last updated** | 2026 |
| **Training status** | ❌ Training not completed |
| **Available resources** | Code framework only; no pretrained weights |

---

## Project Background

Lumina originally aimed to build a **lightweight image generation model optimized for limited hardware (specifically the NVIDIA P4 with 8 GB of VRAM)**, based on a diffusion architecture and focused on text-to-image generation.

### Original Design Goals

- **Aggressive VRAM optimization**: tuned specifically for GPUs with 8 GB of VRAM
- **Lightweight architecture**: fewer than 20 million parameters
- **Complete training pipeline**: from data preprocessing through model training
- **Efficient inference**: support for multiple samplers
- **Modular design**: easy to extend and customize

### Why Was It Abandoned?

1. **Change of direction**: TAI Research shifted its focus to language models and AI safety
2. **Resource constraints**: training image generation models requires substantial compute
3. **Competitive landscape**: the open-source community already has mature solutions (Stable Diffusion, Flux, etc.)

---

## Code Contents

This repository contains only the code framework from the development phase:

```
lumina_legacy/
├── configs/                  # Configuration files
│   ├── model/                # Model configs (UNet architecture)
│   ├── training/             # Training configs (P4-optimized)
│   └── data/                 # Data configs
├── src/                      # Source code
│   ├── models/               # UNet + attention
│   ├── training/             # Trainer + memory optimization
│   ├── data/                 # LAION data processing
│   └── inference/            # Samplers (DDIM/DPM/LCM)
├── scripts/                  # Utility scripts
│   ├── train.py              # Training entry point
│   ├── download_laion.py     # Data download
│   └── webui.py              # Gradio UI
└── tests/                    # Unit tests
```

**⚠️ Note**:
- The code has not been fully tested and may contain bugs
- The training pipeline has not been validated
- Reproducibility is not guaranteed

---

## Tech Stack

| Component | Technology |
|------|------|
| Framework | PyTorch 2.0+ |
| Architecture | Lightweight UNet + Cross-Attention |
| Diffusion | DDPM / DDIM |
| Text encoding | CLIP (pretrained, frozen) |
| Precision | FP16 mixed precision |
| Optimization | Gradient checkpointing + gradient accumulation |

### Original Design Architecture

```
Input (4×64×64)
↓
Conv2d (4→64)
↓
[Downsampling block × 4]
↓
[Attention layer (8×8)]
↓
[Upsampling block × 4]
↓
Conv2d (64→4)
↓
Output (4×64×64)
```

---

## How to Use This Code

> ⚠️ The code is archived only and comes with no guarantees

```bash
# Clone the repository
git clone https://huggingface.co/TAI-Research/Lumina_Dev_Legacy
cd Lumina_Dev_Legacy

# Install dependencies (if needed)
pip install -r requirements.txt

# Try a training run (only checks whether the code runs)
python scripts/train.py --config configs/training/p4_optimized.yaml --dummy
```

---

## Known Issues

- [ ] Training pipeline not fully validated
- [ ] Data processing module has performance problems
- [ ] Inference samplers may contain bugs
- [ ] Insufficient unit test coverage

---

## What Next

If you are interested in this project, consider using a mature open-source solution instead:

- [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
- [Hugging Face Diffusers](https://github.com/huggingface/diffusers)
- [Flux](https://github.com/black-forest-labs/flux)

---

## License

MIT License: the code may be used freely, but comes with no warranty.

---

## Related Links

- [TAI Research Hugging Face](https://huggingface.co/TAI-Research)
- [GTC-Guard-0](https://huggingface.co/TAI-Research/GTC-Guard-0)

---

**Last updated**: 2026-05-23
**Archived by**: TAI Research
|
SOLUTION_SUMMARY.md
ADDED
|
@@ -0,0 +1,49 @@
---
license: mit
tags:
- image-generation
- legacy
- archived
- research
- pytorch
---

# 🎨 Lumina_Dev_Legacy — [ARCHIVED]

> **⚠️ Project status: archived / development discontinued**
>
> This is an early exploratory project from TAI Research. Development has ended and the code is kept for archival purposes only. The model was never fully trained, and no pretrained weights are provided.

| Project info | |
|---------|---|
| **Status** | 🔴 Deprecated (Archived) |
| **Reason** | Project direction changed; resources were reallocated to other research areas |
| **Last updated** | 2026 |
| **Training status** | ❌ Training not completed |
| **Available resources** | Code framework only; no pretrained weights |

---

## Project Background

Lumina originally aimed to build a **lightweight image generation model optimized for limited hardware (specifically the NVIDIA P4 with 8 GB of VRAM)**, based on a diffusion architecture and focused on text-to-image generation.

### Original Design Goals

- **Aggressive VRAM optimization**: tuned specifically for GPUs with 8 GB of VRAM
- **Lightweight architecture**: fewer than 20 million parameters
- **Complete training pipeline**: from data preprocessing through model training
- **Efficient inference**: support for multiple samplers
- **Modular design**: easy to extend and customize

### Why Was It Abandoned?

1. **Change of direction**: TAI Research shifted its focus to language models and AI safety
2. **Resource constraints**: training image generation models requires substantial compute
3. **Competitive landscape**: the open-source community already has mature solutions (Stable Diffusion, Flux, etc.)

---

## Code Contents

This repository contains only the code framework from the development phase:
|
configs/data/laion_filtered.yaml
ADDED
|
@@ -0,0 +1,52 @@
# Data processing configuration
dataset:
  name: "laion-aesthetic"
  path: "./data/laion"  # Data storage path
  metadata_file: "./data/laion/metadata.parquet"

  # Filter criteria
  filters:
    aesthetic_score: 6.0
    watermark_prob: 0.5
    nsfw: false

  # Data split
  split:
    train: 0.95
    val: 0.05
    seed: 42

  # Processing options
  max_samples: 2000000  # Maximum number of samples
  shuffle: true
  shuffle_seed: 42

preprocessing:
  # Image processing
  target_size: 512
  resize_mode: "center_crop"  # "center_crop", "random_crop", "resize"
  random_crop: true
  random_flip: true

  # Normalization
  normalize:
    mean: [0.5, 0.5, 0.5]
    std: [0.5, 0.5, 0.5]

  # Text processing
  tokenizer: "openai/clip-vit-base-patch32"
  max_length: 77
  truncation: true
  padding: "max_length"

  # Caching
  use_cache: true
  cache_dir: "./data/cache"
  cache_compression: true

loader:
  batch_size: 1
  shuffle: true
  num_workers: 2
  prefetch_factor: 2
  persistent_workers: true
|
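To make the `filters` and `split` settings above concrete, here is a minimal sketch of how they might be applied to metadata rows. It assumes that `aesthetic_score` is a lower bound, that `watermark_prob` is an upper bound, and that the `NSFW` column is a boolean; the config states none of this. `apply_filters` and `split_rows` are hypothetical helpers, not functions from this repository.

```python
import random

# Values copied from configs/data/laion_filtered.yaml.
FILTERS = {"aesthetic_score": 6.0, "watermark_prob": 0.5, "nsfw": False}
SPLIT = {"train": 0.95, "val": 0.05, "seed": 42}


def apply_filters(rows, filters=FILTERS):
    """Keep rows that pass all three filters.

    Assumptions: aesthetic_score is a minimum, watermark_prob is a maximum,
    and nsfw=False means rows flagged NSFW are dropped.
    """
    kept = []
    for row in rows:
        if row["aesthetic_score"] < filters["aesthetic_score"]:
            continue
        if row["watermark_prob"] > filters["watermark_prob"]:
            continue
        if row["NSFW"] and not filters["nsfw"]:
            continue
        kept.append(row)
    return kept


def split_rows(rows, split=SPLIT):
    """Shuffle deterministically with the seed, then cut into train/val."""
    shuffled = list(rows)
    random.Random(split["seed"]).shuffle(shuffled)
    n_train = round(len(shuffled) * split["train"])
    return shuffled[:n_train], shuffled[n_train:]


rows = [
    {"url": "a", "aesthetic_score": 6.8, "watermark_prob": 0.1, "NSFW": False},
    {"url": "b", "aesthetic_score": 5.2, "watermark_prob": 0.1, "NSFW": False},
    {"url": "c", "aesthetic_score": 7.1, "watermark_prob": 0.9, "NSFW": False},
    {"url": "d", "aesthetic_score": 6.3, "watermark_prob": 0.2, "NSFW": True},
]
print([r["url"] for r in apply_filters(rows)])  # → ['a']
```

Seeding a private `random.Random` instance keeps the split reproducible without touching the global random state.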
configs/model/diffusion.yaml
ADDED
|
@@ -0,0 +1,19 @@
# Diffusion process configuration
diffusion:
  beta_schedule: "scaled_linear"
  beta_start: 0.00085
  beta_end: 0.012
  num_train_timesteps: 1000
  num_inference_timesteps: 50

  # Loss function
  loss_type: "l2"  # or "l1"
  snr_gamma: null

  # Sampler
  sampler_type: "ddim"  # or "dpm++_2m"
  prediction_type: "epsilon"

  # Training parameters
  vae_scale_factor: 0.18215
  offset_noise_strength: 0.1
|
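The noise schedule above can be sketched in a few lines. This assumes `beta_schedule: "scaled_linear"` means what it means in the Hugging Face Diffusers schedulers: betas spaced linearly in square-root space and then squared. Whether `DiffusionProcess` in this repository computes it the same way is not visible in this commit view, and the helper names below are hypothetical.

```python
import math


def scaled_linear_betas(beta_start=0.00085, beta_end=0.012, num_steps=1000):
    """Betas spaced linearly in sqrt(beta), then squared."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    step = (hi - lo) / (num_steps - 1)
    return [(lo + i * step) ** 2 for i in range(num_steps)]


def alphas_cumprod(betas):
    """Cumulative product of (1 - beta_t): the signal fraction left at step t."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


betas = scaled_linear_betas()
abar = alphas_cumprod(betas)
print(f"beta_0={betas[0]:.5f}  beta_T={betas[-1]:.5f}  alpha_bar_T={abar[-1]:.4f}")
```

With `prediction_type: "epsilon"`, a noisy latent at step `t` would then be formed as `sqrt(abar[t]) * x0 + sqrt(1 - abar[t]) * noise`, and the model would be trained to predict `noise`.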
configs/model/unet_light.yaml
ADDED
|
@@ -0,0 +1,22 @@
# Lightweight UNet configuration
model:
  in_channels: 4  # Number of latent-space channels
  out_channels: 4
  base_channels: 64
  channel_mults: [1, 2, 4, 8]  # 4 downsampling stages
  num_res_blocks: 2
  attention_resolutions: [8]  # Apply attention only at the lowest resolution
  dropout: 0.0
  use_checkpoint: true
  num_heads: 4

  # Text conditioning
  context_dim: 768  # CLIP text-embedding dimension
  use_linear_projection: true

  # Timestep embedding
  time_embed_dim: 256

  # Optimization settings
  use_flash_attention: false  # Not supported on the P4, but the option is kept
  gradient_checkpointing: true
|
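As a rough illustration of what the UNet config above implies, the sketch below lists the feature-map size and channel width at each encoder level. It assumes a 64×64 latent input (as in the README diagram), that `attention_resolutions` refers to spatial size, and that the resolution halves between consecutive levels so the lowest level is 8×8. The actual `UNetLight` layout may differ, and `unet_levels` is a hypothetical helper.

```python
def unet_levels(input_size=64, base_channels=64, channel_mults=(1, 2, 4, 8),
                attention_resolutions=(8,)):
    """Return (resolution, channels, has_attention) for each encoder level.

    Assumption: resolution halves between consecutive levels, so four
    multipliers give three halvings (64 -> 32 -> 16 -> 8).
    """
    levels = []
    size = input_size
    for i, mult in enumerate(channel_mults):
        levels.append((size, base_channels * mult, size in attention_resolutions))
        if i < len(channel_mults) - 1:
            size //= 2
    return levels


for size, channels, attn in unet_levels():
    print(f"{size}x{size}: channels={channels}, attention={attn}")
```

Restricting attention to the 8×8 level keeps the attention matrix at 64×64 tokens, which is one of the cheaper ways to fit cross-attention into a small memory budget.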
configs/training/p4_optimized.yaml
ADDED
|
@@ -0,0 +1,69 @@
# P4-optimized training configuration
hardware:
  gpu_memory: 8  # GB
  batch_size: 1  # Actual per-step batch size
  gradient_accumulation_steps: 8
  num_workers: 2
  pin_memory: true

  # VRAM optimization strategies
  mixed_precision: "fp16"
  gradient_checkpointing: true
  attention_slicing: "auto"
  cpu_offload: true
  tiled_vae: false  # Whether to use tiled VAE decoding

  # Dynamic VRAM management
  memory_threshold_gb: 6.5
  warning_threshold_gb: 6.0
  cleanup_frequency: 100

training:
  max_epochs: 50
  learning_rate: 1.0e-4  # Decimal point included so PyYAML loads a float, not a string
  learning_rate_scheduler: "cosine"
  warmup_steps: 1000
  weight_decay: 0.01
  adam_beta1: 0.9
  adam_beta2: 0.999
  adam_epsilon: 1.0e-8

  # Training strategy
  gradient_clip: 1.0
  use_ema: true
  ema_decay: 0.9999
  save_checkpoint_every: 1000
  save_best_model: true

  # Validation and monitoring
  validation_steps: 500
  sample_steps: 500
  log_steps: 50

  # Optimizer state
  optimizer_on_cpu: true

data:
  resolution: 512
  center_crop: true
  random_flip: true
  cache_dataset: true

  # Data augmentation
  augmentation:
    random_crop: true
    color_jitter: 0.05
    random_rotation: 5.0  # Degrees

logging:
  use_wandb: false
  use_tensorboard: true
  log_dir: "./logs"
  project_name: "lumina"
  run_name: "lumina-v0.1"

checkpoint:
  save_dir: "./checkpoints"
  keep_last: 5
  save_compressed: true
  save_onnx: false
|
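The `batch_size: 1` and `gradient_accumulation_steps: 8` pair above trades memory for time: gradients from eight micro-batches are summed before a single optimizer update. The sketch below shows only the bookkeeping, under the assumption that the trainer counts micro-batches from 1 and averages the loss; the helper names are hypothetical and not taken from `trainer_p4.py`.

```python
def effective_batch_size(batch_size, accumulation_steps):
    """Number of samples that contribute to one optimizer update."""
    return batch_size * accumulation_steps


def is_update_step(micro_step, accumulation_steps):
    """True on the micro-batch after which the optimizer should step.

    micro_step is a 1-based count of processed micro-batches.
    """
    return micro_step % accumulation_steps == 0


def scaled_loss(loss, accumulation_steps):
    """Divide each micro-batch loss so the summed gradient is an average."""
    return loss / accumulation_steps


print(effective_batch_size(1, 8))                            # → 8
print([s for s in range(1, 25) if is_update_step(s, 8)])     # → [8, 16, 24]
```

In a PyTorch loop this would correspond to calling `scaled_loss(loss, 8).backward()` on every micro-batch and `optimizer.step()` followed by `optimizer.zero_grad()` only when `is_update_step` is true.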
configs/training/schedule_256.yaml
ADDED
|
@@ -0,0 +1,33 @@
# 256x256 training schedule
phases:
  - name: "phase1_warmup"
    epochs: 5
    resolution: 256
    learning_rate: 1.0e-5
    batch_size: 1
    gradient_accumulation: 8
    description: "Warm-up phase, low resolution"

  - name: "phase2_main"
    epochs: 20
    resolution: 256
    learning_rate: 1.0e-4
    batch_size: 1
    gradient_accumulation: 8
    description: "Main training phase"

  - name: "phase3_refine"
    epochs: 10
    resolution: 256
    learning_rate: 5.0e-5
    batch_size: 1
    gradient_accumulation: 8
    description: "Fine-grained refinement"

  - name: "phase4_upscale"
    epochs: 15
    resolution: 512
    learning_rate: 2.0e-5
    batch_size: 1
    gradient_accumulation: 4
    description: "Scale up to 512 resolution"
|
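A quick way to sanity-check a phased schedule like the one above is to expand it into epoch ranges. The sketch below hard-codes the four phases from `schedule_256.yaml` and uses a hypothetical `summarize` helper; it does not reflect how `trainer_p4.py` actually consumes the file.

```python
# Phase values copied from configs/training/schedule_256.yaml.
PHASES = [
    {"name": "phase1_warmup", "epochs": 5, "resolution": 256,
     "batch_size": 1, "gradient_accumulation": 8},
    {"name": "phase2_main", "epochs": 20, "resolution": 256,
     "batch_size": 1, "gradient_accumulation": 8},
    {"name": "phase3_refine", "epochs": 10, "resolution": 256,
     "batch_size": 1, "gradient_accumulation": 8},
    {"name": "phase4_upscale", "epochs": 15, "resolution": 512,
     "batch_size": 1, "gradient_accumulation": 4},
]


def summarize(phases):
    """Return (name, first_epoch, last_epoch, effective_batch) per phase.

    Epoch numbers are 1-based and inclusive.
    """
    rows, start = [], 1
    for phase in phases:
        end = start + phase["epochs"] - 1
        effective = phase["batch_size"] * phase["gradient_accumulation"]
        rows.append((phase["name"], start, end, effective))
        start = end + 1
    return rows


for name, first, last, effective in summarize(PHASES):
    print(f"{name}: epochs {first}-{last}, effective batch {effective}")
print("total epochs:", sum(p["epochs"] for p in PHASES))  # → total epochs: 50
```

The total of 50 epochs matches `max_epochs: 50` in `p4_optimized.yaml`, and the effective batch size drops from 8 to 4 in the 512-resolution phase.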
data/laion/dataset_info.json
ADDED
|
@@ -0,0 +1,14 @@
{
  "total_samples": 10,
  "columns": [
    "url",
    "caption",
    "aesthetic_score",
    "watermark_prob",
    "NSFW",
    "image_file"
  ],
  "dataset": "dummy",
  "description": "Dummy test dataset (10 records)",
  "note": "Contains no actual image files; used only to test the code flow"
}
|
data/laion/metadata.parquet
ADDED
|
Binary file (4.91 kB).
|
docs/LAION_DATASET_GUIDE.md
ADDED
|
@@ -0,0 +1,279 @@
# LAION Dataset Download Guide

## Problem

Running the code fails with the following error:
```
FileNotFoundError: [Errno 2] No such file or directory: './data/laion/metadata.parquet'
```

The project expects a LAION dataset, but the dataset files are missing.

## About LAION

LAION (Large-scale Artificial Intelligence Open Network) publishes large-scale multimodal datasets, including:
- **LAION-5B**: 5.85 billion image-text pairs
- **LAION-Aesthetic**: a high-quality subset filtered by aesthetic score
- **LAION-400M**: 400 million image-text pairs

## Solutions

### Option 1: Test with a Dummy Dataset (Recommended)

For testing and development, a dummy dataset is enough:

```bash
# Create a dummy dataset
python scripts/download_laion.py --dummy --dummy-size 100

# Or simply run
python scripts/download_laion.py --dummy
```

The first command creates a metadata file with 100 dummy records for exercising the code path.

### Option 2: Download a LAION-Aesthetic Subset

LAION-Aesthetic is a high-quality dataset filtered by aesthetic score:

```bash
# Download the LAION-Aesthetic 6.5+ subset (default)
python scripts/download_laion.py

# Download the LAION-Aesthetic 7.0+ subset (higher quality)
python scripts/download_laion.py --subset "7.0+"
```

### Option 3: Download a LAION-5B Sample

```bash
# Download a sample of 10,000 records
python scripts/download_laion.py --sample-size 10000
```

## Manual Download

### Method 1: Download from Hugging Face

1. **Open the dataset page on Hugging Face**:
   - LAION-Aesthetic 6.5+: https://huggingface.co/datasets/laion/laion-aesthetic-6.5plus
   - LAION-Aesthetic 7.0+: https://huggingface.co/datasets/laion/laion-aesthetic-7.0plus
   - LAION-5B (English subset): https://huggingface.co/datasets/laion/laion2b-en

2. **Download the metadata file**:
   ```bash
   # Create the directory
   mkdir -p data/laion

   # Download the LAION-Aesthetic 6.5+ metadata
   wget https://huggingface.co/datasets/laion/laion-aesthetic-6.5plus/resolve/main/data/00000.parquet -O data/laion/metadata.parquet

   # Or use curl
   curl -L https://huggingface.co/datasets/laion/laion-aesthetic-6.5plus/resolve/main/data/00000.parquet -o data/laion/metadata.parquet
   ```

### Method 2: Use img2dataset

`img2dataset` is a tool for downloading image datasets such as LAION from lists of URLs:

```bash
# Install img2dataset
pip install img2dataset

# Download a LAION-400M subset
img2dataset \
  --url_list "path/to/laion-400m.parquet" \
  --input_format "parquet" \
  --url_col "URL" \
  --caption_col "TEXT" \
  --output_folder "data/laion/images" \
  --processes_count 16 \
  --thread_count 64 \
  --image_size 512 \
  --resize_mode "keep_ratio" \
  --output_format "webdataset"
```

### Method 3: Run img2dataset from Source

The img2dataset repository can also be cloned and run from a source checkout:

```bash
# Clone the img2dataset repository
git clone https://github.com/rom1504/img2dataset.git
cd img2dataset

# Show the usage instructions
python -m img2dataset --help
```

## Dataset Layout

After downloading, the dataset directory should look like this:

```
data/
└── laion/
    ├── metadata.parquet      # Metadata file (required)
    ├── dataset_info.json     # Dataset info file
    └── images/               # Image directory (optional)
        ├── 00000.tar
        ├── 00001.tar
        └── ...
```

### Metadata File Format

The `metadata.parquet` file usually contains the following columns:
- `url`: image URL
- `caption` or `text`: image description text
- `aesthetic_score`: aesthetic score (specific to LAION-Aesthetic)
- `watermark_prob`: watermark probability
- `NSFW`: adult-content flag
- `width`/`height`: image dimensions

## Verifying the Dataset

Once the download finishes, check that the dataset is in place:

```python
import pandas as pd
import os

# Check that the file exists
metadata_path = "./data/laion/metadata.parquet"
if os.path.exists(metadata_path):
    print(f"Metadata file found: {metadata_path}")

    # Load the metadata and show the first rows
    df = pd.read_parquet(metadata_path)
    print(f"Record count: {len(df)}")
    print(f"Columns: {list(df.columns)}")
    print("\nFirst 5 records:")
    print(df.head())
else:
    print(f"Error: file not found {metadata_path}")
```

## Common Problems

### Problem 1: Slow Downloads
- **Solution**: use a regional mirror or a proxy
- For pip packages you can try the Tsinghua mirror: `pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple`

### Problem 2: Not Enough Disk Space
- **Solution**:
  1. Download a smaller subset (such as LAION-Aesthetic)
  2. Test with the dummy dataset
  3. Download only the metadata, not the images

### Problem 3: Network Connection Issues
- **Solution**:
  1. Use the `--dummy` flag to create a dummy dataset
  2. Manually download a small sample file
  3. Use an existing local dataset

### Problem 4: Parquet Read Errors
- **Solution**:
  ```bash
  # Install compatible versions of pandas and pyarrow
  pip install pandas pyarrow fastparquet

  # Or read the file with dask
  pip install "dask[dataframe]"
  ```

## Advanced Usage

### Custom Datasets

To use a custom dataset, edit the configuration file:

```yaml
# configs/data/laion_filtered.yaml
dataset:
  name: "custom-dataset"
  path: "./data/custom"  # Change the path
  metadata_file: "./data/custom/metadata.parquet"  # Change the metadata file path
```

### Dataset Preprocessing

The project includes a data preprocessing module:

```python
from src.data.preprocessing import get_transform

# Build the data transform
transform = get_transform(config, mode='train')

# Create the dataset
from src.data.dataset import LAIONDataset
dataset = LAIONDataset(config, transform=transform, split='train')
```

### Bulk Image Download

To download the actual image files:

```python
import os
import requests
from PIL import Image
from io import BytesIO
import pandas as pd

# Load the metadata
df = pd.read_parquet("./data/laion/metadata.parquet")

# Make sure the output directory exists
os.makedirs("./data/laion/images", exist_ok=True)

# Download the first N images
for i, row in df.head(10).iterrows():
    try:
        response = requests.get(row['url'], timeout=10)
        response.raise_for_status()  # Treat HTTP errors as failures
        img = Image.open(BytesIO(response.content))
        # JPEG cannot store alpha or palette modes, so convert to RGB first
        img.convert("RGB").save(f"./data/laion/images/image_{i:06d}.jpg")
        print(f"Downloaded: {i}")
    except Exception as e:
        print(f"Download failed {row['url']}: {e}")
```

## Performance Tips

1. **Use the cache**: enable dataset caching to speed up training
   ```yaml
   preprocessing:
     use_cache: true
     cache_dir: "./data/cache"
   ```

2. **Parallel data loading**: load data with multiple workers
   ```yaml
   loader:
     num_workers: 4
     prefetch_factor: 2
   ```

3. **Memory mapping**: for large datasets, read the file memory-mapped (with the pyarrow engine)
   ```python
   df = pd.read_parquet("metadata.parquet", memory_map=True)
   ```

## References

1. [LAION website](https://laion.ai/)
2. [LAION dataset paper](https://arxiv.org/abs/2210.08402)
3. [Hugging Face datasets](https://huggingface.co/datasets/laion)
4. [img2dataset](https://github.com/rom1504/img2dataset)
5. [WebDataset format](https://github.com/webdataset/webdataset)

## Getting Help

If you run into problems:
1. Check the error message
2. Review the log files
3. Consult the project README
4. Search GitHub Issues for similar problems
5. Open a new issue to ask for help

---

**Note**: Images referenced by the LAION datasets are subject to copyright; make sure you comply with the applicable terms of use and license requirements.
|
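The column list in the metadata format section can be turned into a small schema check that runs before training. This is a sketch only: which columns are strictly required is an assumption (the guide says the file "usually" contains them), and `check_columns` is a hypothetical helper rather than part of this repository.

```python
# Assumed-required columns, based on the guide's metadata format section.
REQUIRED = ("url", "aesthetic_score", "watermark_prob", "NSFW")
# The caption may be stored under either of these names.
CAPTION_ALIASES = ("caption", "text")


def check_columns(columns):
    """Return a list of problems; an empty list means the schema looks usable."""
    present = set(columns)
    problems = [f"missing column: {name}" for name in REQUIRED if name not in present]
    if not any(alias in present for alias in CAPTION_ALIASES):
        problems.append("missing caption column: expected 'caption' or 'text'")
    return problems


dummy_columns = ["url", "caption", "aesthetic_score", "watermark_prob", "NSFW", "image_file"]
print(check_columns(dummy_columns))  # → []
```

The check is case-sensitive, so a file that uses `URL` and `TEXT` (the column names passed to `img2dataset` in the guide) would be reported as missing `url` and a caption column and would need its columns renamed first.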
requirements.txt
ADDED
|
@@ -0,0 +1,28 @@
# Core dependencies
torch>=2.0.0
torchvision>=0.15.0
transformers>=4.30.0
diffusers>=0.20.0
accelerate>=0.21.0

# Data processing
Pillow>=10.0.0
numpy>=1.24.0
pandas>=2.0.0
pyarrow>=12.0.0

# Training tools
wandb>=0.15.0
tqdm>=4.65.0
matplotlib>=3.7.0
tensorboard>=2.13.0

# API and deployment
gradio>=3.41.0
fastapi>=0.100.0
uvicorn>=0.23.0

# Development tools
black>=23.7.0
flake8>=6.0.0
isort>=5.12.0
|
scripts/benchmark.py
ADDED
|
@@ -0,0 +1,490 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
性能基准测试脚本
|
| 4 |
+
测试模型训练和推理性能
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import time
|
| 10 |
+
import argparse
|
| 11 |
+
import yaml
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch.utils.data import DataLoader
|
| 15 |
+
import numpy as np
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
|
| 18 |
+
# 添加项目根目录到Python路径
|
| 19 |
+
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
| 20 |
+
|
| 21 |
+
from src.models.unet_light import UNetLight
|
| 22 |
+
from src.models.diffusion import DiffusionProcess
|
| 23 |
+
from src.data.dataset import create_data_loaders
|
| 24 |
+
from src.inference.optimization import InferenceBenchmark
|
| 25 |
+
from src.inference.sampler import DDIMSampler
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def load_config(config_path: str) -> dict:
|
| 29 |
+
"""加载配置文件"""
|
| 30 |
+
with open(config_path, 'r') as f:
|
| 31 |
+
config = yaml.safe_load(f)
|
| 32 |
+
return config
|
| 33 |
+
|
| 34 |
+
|
def benchmark_training(config: dict):
    """Benchmark training performance."""
    print("=" * 60)
    print("Training performance benchmark")
    print("=" * 60)

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the model configuration
    model_config = load_config('configs/model/unet_light.yaml')

    # Create the model
    model = UNetLight(model_config).to(device)
    model.train()

    # Create the diffusion process
    diffusion_config = load_config('configs/model/diffusion.yaml')
    diffusion = DiffusionProcess(diffusion_config)

    # Create the data loaders
    data_config = load_config('configs/data/laion_filtered.yaml')
    train_loader, _ = create_data_loaders(data_config)

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Warm-up
    print("Warming up...")
    warmup_batches = 5
    for i, batch in enumerate(train_loader):
        if i >= warmup_batches:
            break

        images = batch['images'].to(device)
        text_embeddings = batch['text_embeddings'].to(device)

        # Forward pass
        loss = diffusion.compute_loss(model, images, text_embeddings)

        # Backward pass; run the full step so the optimizer state is
        # allocated before timing starts
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Synchronize
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # Benchmark
    print("Running training benchmark...")
    num_batches = 20
    batch_times = []
    memory_usage = []

    for i, batch in enumerate(train_loader):
        if i >= num_batches:
            break

        images = batch['images'].to(device)
        text_embeddings = batch['text_embeddings'].to(device)

        # Start timing (perf_counter is monotonic and high resolution)
        start_time = time.perf_counter()

        # Forward pass
        loss = diffusion.compute_loss(model, images, text_embeddings)

        # Backward pass
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Synchronize so the timer covers all queued GPU work
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # Stop timing
        end_time = time.perf_counter()
        batch_time = end_time - start_time
        batch_times.append(batch_time)

        # Record memory usage
        if torch.cuda.is_available():
            memory_allocated = torch.cuda.memory_allocated() / 1024**3
            memory_usage.append(memory_allocated)

        # Progress
        print(f"Batch {i+1}/{num_batches}: {batch_time:.3f}s")

    # Statistics
    batch_times = np.array(batch_times)

    print("\n" + "=" * 60)
    print("Training benchmark results:")
    print(f"  Mean batch time: {batch_times.mean():.3f} ± {batch_times.std():.3f} s")
    print(f"  Min batch time: {batch_times.min():.3f} s")
    print(f"  Max batch time: {batch_times.max():.3f} s")
    print(f"  Throughput: {1 / batch_times.mean():.2f} batches/s")

    if memory_usage:
        memory_usage = np.array(memory_usage)
        print(f"  Mean GPU memory usage: {memory_usage.mean():.2f} ± {memory_usage.std():.2f} GB")

    print("=" * 60)

    return batch_times.mean()


def benchmark_inference(config: dict):
    """Benchmark inference performance."""
    print("\n" + "=" * 60)
    print("Inference performance benchmark")
    print("=" * 60)

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the model configuration
    model_config = load_config('configs/model/unet_light.yaml')

    # Create the model
    model = UNetLight(model_config).to(device)
    model.eval()

    # Create the diffusion process
    diffusion_config = load_config('configs/model/diffusion.yaml')
    diffusion = DiffusionProcess(diffusion_config)

    # Create the benchmark runner
    benchmark = InferenceBenchmark(model, device)

    # Test several resolutions
    resolutions = [(256, 256), (512, 512), (768, 768)]
    results = {}

    for height, width in resolutions:
        print(f"\nTesting resolution: {width}x{height}")

        # Latent-space size (image size divided by 8)
        latent_height = height // 8
        latent_width = width // 8

        # Run the benchmark
        stats = benchmark.benchmark(
            input_shape=(1, model.in_channels, latent_height, latent_width),
            num_iterations=10,
            warmup_iterations=3
        )

        results[f"{width}x{height}"] = stats

    # Print summary
    print("\n" + "=" * 60)
    print("Inference benchmark summary:")

    for resolution, stats in results.items():
        print(f"\n  Resolution {resolution}:")
        print(f"    Mean time: {stats['mean_ms']:.1f} ms")
        print(f"    FPS: {stats['fps']:.1f}")

    print("=" * 60)

    return results


def benchmark_sampling(config: dict):
    """Benchmark sampling performance."""
    print("\n" + "=" * 60)
    print("Sampling performance benchmark")
    print("=" * 60)

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the model configuration
    model_config = load_config('configs/model/unet_light.yaml')

    # Create the model
    model = UNetLight(model_config).to(device)
    model.eval()

    # Create the diffusion process
    diffusion_config = load_config('configs/model/diffusion.yaml')
    diffusion = DiffusionProcess(diffusion_config)

    # Create the sampler
    sampler = DDIMSampler(model, diffusion, num_inference_steps=50)

    # Test several sampling step counts
    step_configs = [20, 30, 50]
    results = {}

    for num_steps in step_configs:
        print(f"\nTesting sampling steps: {num_steps}")

        # Set the number of sampling steps
        sampler.set_timesteps(num_steps)

        # Prepare inputs
        prompt_embeds = torch.randn(1, 77, 768, device=device)

        # Warm-up
        print("  Warming up...")
        with torch.no_grad():
            for _ in range(3):
                _ = sampler.sample(
                    prompt_embeds=prompt_embeds,
                    height=512,
                    width=512,
                    progress_bar=False
                )

        # Make sure warm-up work has finished before timing starts
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # Benchmark
        print("  Running benchmark...")
        times = []

        for i in range(5):
            start_time = time.perf_counter()

            with torch.no_grad():
                _ = sampler.sample(
                    prompt_embeds=prompt_embeds,
                    height=512,
                    width=512,
                    progress_bar=False
                )

            if torch.cuda.is_available():
                torch.cuda.synchronize()

            end_time = time.perf_counter()
            times.append(end_time - start_time)

            print(f"    Iteration {i+1}: {times[-1]:.2f}s")

        # Statistics
        times = np.array(times)

        results[num_steps] = {
            'mean_time': times.mean(),
            'std_time': times.std(),
            'fps': 1 / times.mean()
        }

    # Print summary
    print("\n" + "=" * 60)
    print("Sampling benchmark summary:")

    for num_steps, stats in results.items():
        print(f"\n  Sampling steps {num_steps}:")
        print(f"    Mean time: {stats['mean_time']:.2f} ± {stats['std_time']:.2f} s")
        print(f"    FPS: {stats['fps']:.2f}")

    print("=" * 60)

    return results


def benchmark_memory(config: dict):
    """Benchmark memory usage."""
    print("\n" + "=" * 60)
    print("Memory usage benchmark")
    print("=" * 60)

    if not torch.cuda.is_available():
        print("GPU not available; skipping memory benchmark")
        return {}

    # Device
    device = torch.device('cuda')

    # Load the model configuration
    model_config = load_config('configs/model/unet_light.yaml')

    # Test several batch sizes
    batch_sizes = [1, 2, 4, 8]
    results = {}

    for batch_size in batch_sizes:
        print(f"\nTesting batch size: {batch_size}")

        # Create the model
        model = UNetLight(model_config).to(device)
        model.eval()

        # Prepare inputs
        input_shape = (batch_size, model.in_channels, 64, 64)
        x = torch.randn(*input_shape, device=device)
        t = torch.tensor([500] * batch_size, device=device)
        context = torch.randn(batch_size, 77, 768, device=device)

        # Clear the cache and reset the peak counter; without the reset,
        # max_memory_allocated() would still report the peak of earlier,
        # possibly larger, iterations
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Record the initial memory
        initial_memory = torch.cuda.memory_allocated()

        # Forward pass
        with torch.no_grad():
            _ = model(x, t, context)

        # Record the peak memory
        peak_memory = torch.cuda.max_memory_allocated()
        current_memory = torch.cuda.memory_allocated()

        # Compute the memory used by the forward pass
        memory_used = peak_memory - initial_memory

        results[batch_size] = {
            'initial_memory_gb': initial_memory / 1024**3,
            'peak_memory_gb': peak_memory / 1024**3,
            'current_memory_gb': current_memory / 1024**3,
            'memory_used_gb': memory_used / 1024**3,
            'memory_per_sample_gb': memory_used / (batch_size * 1024**3)
        }

        print(f"  Initial memory: {initial_memory / 1024**3:.2f} GB")
        print(f"  Peak memory: {peak_memory / 1024**3:.2f} GB")
        print(f"  Current memory: {current_memory / 1024**3:.2f} GB")
        print(f"  Memory used: {memory_used / 1024**3:.2f} GB")
        print(f"  Memory per sample: {memory_used / (batch_size * 1024**3):.2f} GB")

        # Clean up
        del model
        torch.cuda.empty_cache()

    # Print summary
    print("\n" + "=" * 60)
    print("Memory benchmark summary:")

    for batch_size, stats in results.items():
        print(f"\n  Batch size {batch_size}:")
        print(f"    Total memory used: {stats['memory_used_gb']:.2f} GB")
        print(f"    Memory per sample: {stats['memory_per_sample_gb']:.2f} GB")

    print("=" * 60)

    return results


def generate_report(results: dict, output_file: str = "benchmark_report.md"):
    """Generate the benchmark report."""
    print(f"\nGenerating report: {output_file}")

    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("# Lumina Performance Benchmark Report\n\n")
        f.write(f"Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")

        f.write("## System Information\n")
        f.write(f"- PyTorch version: {torch.__version__}\n")
        f.write(f"- CUDA available: {torch.cuda.is_available()}\n")
        if torch.cuda.is_available():
            f.write(f"- GPU: {torch.cuda.get_device_name(0)}\n")
            f.write(f"- CUDA version: {torch.version.cuda}\n")
        # os.sysconf is only available on POSIX systems
        f.write(f"- System memory: {os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES') / 1024**3:.1f} GB\n\n")

        if 'training' in results:
            f.write("## Training Performance\n")
            f.write(f"- Mean batch time: {results['training']:.3f} s\n")
            f.write(f"- Throughput: {1/results['training']:.2f} batches/s\n\n")

        if 'inference' in results:
            f.write("## Inference Performance\n")
            for resolution, stats in results['inference'].items():
                f.write(f"### Resolution {resolution}\n")
                f.write(f"- Mean inference time: {stats['mean_ms']:.1f} ms\n")
                f.write(f"- FPS: {stats['fps']:.1f}\n\n")

        if 'sampling' in results:
            f.write("## Sampling Performance\n")
            for num_steps, stats in results['sampling'].items():
                f.write(f"### Sampling steps {num_steps}\n")
                f.write(f"- Mean sampling time: {stats['mean_time']:.2f} s\n")
                f.write(f"- FPS: {stats['fps']:.2f}\n\n")

        if 'memory' in results:
            f.write("## Memory Usage\n")
            for batch_size, stats in results['memory'].items():
                f.write(f"### Batch size {batch_size}\n")
                f.write(f"- Total memory used: {stats['memory_used_gb']:.2f} GB\n")
                f.write(f"- Memory per sample: {stats['memory_per_sample_gb']:.2f} GB\n\n")

        f.write("## Recommendations\n")
        f.write("1. Choose a batch size that fits the available GPU memory\n")
        f.write("2. At inference time, pick a sampling step count that balances quality and speed\n")
        f.write("3. During training, use gradient accumulation to simulate large-batch training\n")

    print(f"Report saved: {output_file}")


def main():
    """Entry point."""
    parser = argparse.ArgumentParser(description="Lumina performance benchmark")

    parser.add_argument(
        "--config",
        type=str,
        default="configs/training/p4_optimized.yaml",
        help="Path to the configuration file"
    )

    parser.add_argument(
        "--test",
        type=str,
        nargs="+",
        default=['all'],
        choices=['training', 'inference', 'sampling', 'memory', 'all'],
        help="Benchmarks to run"
    )

    parser.add_argument(
        "--output",
        type=str,
        default="benchmark_report.md",
        help="Output report file"
    )

    args = parser.parse_args()

    # Load configuration
    config = load_config(args.config)

    # Run the benchmarks
    results = {}

    if 'all' in args.test or 'training' in args.test:
        try:
            results['training'] = benchmark_training(config)
        except Exception as e:
            print(f"Training benchmark failed: {e}")

    if 'all' in args.test or 'inference' in args.test:
        try:
            results['inference'] = benchmark_inference(config)
        except Exception as e:
            print(f"Inference benchmark failed: {e}")

    if 'all' in args.test or 'sampling' in args.test:
        try:
            results['sampling'] = benchmark_sampling(config)
        except Exception as e:
            print(f"Sampling benchmark failed: {e}")

    if 'all' in args.test or 'memory' in args.test:
        try:
            results['memory'] = benchmark_memory(config)
        except Exception as e:
            print(f"Memory benchmark failed: {e}")

    # Generate the report
    generate_report(results, args.output)


if __name__ == "__main__":
    main()
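The benchmark functions in scripts/benchmark.py share one pattern: a few untimed warm-up runs, then timed iterations summarized as a mean and a throughput figure. A minimal standalone sketch of that pattern follows, using only the standard library. The helper name `time_callable` and its parameters are hypothetical and not part of the repository; the `mean_ms` and `fps` keys only mirror the keys the script reads from `InferenceBenchmark`. For GPU work, the callable itself would need to end with `torch.cuda.synchronize()` so the timer covers the queued kernels.

```python
import statistics
import time
from typing import Callable, Dict


def time_callable(fn: Callable[[], object], warmup: int = 3, iterations: int = 10) -> Dict[str, float]:
    """Time fn after discarding warm-up runs (hypothetical helper, not in the repo)."""
    for _ in range(warmup):
        fn()  # warm-up runs are executed but not measured

    samples = []
    for _ in range(iterations):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        samples.append(time.perf_counter() - start)

    mean_s = statistics.fmean(samples)
    return {
        "mean_ms": mean_s * 1000.0,
        # population standard deviation, matching numpy's default std()
        "std_ms": statistics.pstdev(samples) * 1000.0,
        "fps": 1.0 / mean_s if mean_s > 0 else float("inf"),
        "iterations": float(len(samples)),
    }


stats = time_callable(lambda: sum(range(10_000)), warmup=2, iterations=5)
print(f"mean: {stats['mean_ms']:.3f} ms, fps: {stats['fps']:.1f}")
```

Keeping the warm-up loop separate from the measured loop means one-off costs (memory allocation, kernel selection, caches) do not inflate the reported mean.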
scripts/download_laion.py
ADDED
|
@@ -0,0 +1,272 @@
#!/usr/bin/env python3
"""
LAION dataset download script.

This script helps download different versions of the LAION dataset.
LAION is large; you usually need to download both metadata files and image files.

Note: the full LAION dataset is very large (several TB). Downloading a subset
or using an existing cache is recommended.
"""

import os
import argparse
import subprocess
import pandas as pd
from pathlib import Path
import requests
import json
from tqdm import tqdm
import sys

def download_file(url, output_path, chunk_size=8192):
    """Download a file and show a progress bar."""
    # The timeout applies to connecting and to each read, so a stalled
    # server cannot hang the script indefinitely
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()

    total_size = int(response.headers.get('content-length', 0))

    with open(output_path, 'wb') as f, tqdm(
        desc=os.path.basename(output_path),
        total=total_size,
        unit='B',
        unit_scale=True,
        unit_divisor=1024,
    ) as pbar:
        for chunk in response.iter_content(chunk_size=chunk_size):
            f.write(chunk)
            pbar.update(len(chunk))

    return output_path

def download_laion_aesthetic(output_dir="./data/laion", subset="6.5+"):
    """
    Download the LAION-Aesthetic dataset.

    Args:
        output_dir: Output directory
        subset: Subset version, either "6.5+" (score of 6.5 or higher) or "7.0+" (score of 7.0 or higher)
    """
    os.makedirs(output_dir, exist_ok=True)

    # LAION-Aesthetic dataset information
    datasets = {
        "6.5+": {
            "metadata": "https://huggingface.co/datasets/laion/laion-aesthetic-6.5plus/resolve/main/data/00000.parquet",
            "description": "LAION-Aesthetic 6.5+ (aesthetic score of 6.5 or higher)"
        },
        "7.0+": {
            "metadata": "https://huggingface.co/datasets/laion/laion-aesthetic-7.0plus/resolve/main/data/00000.parquet",
            "description": "LAION-Aesthetic 7.0+ (aesthetic score of 7.0 or higher)"
        }
    }

    if subset not in datasets:
        print(f"Error: unsupported subset {subset}")
        print(f"Available subsets: {list(datasets.keys())}")
        return False

    dataset_info = datasets[subset]
    print(f"Downloading {dataset_info['description']}")

    # Download the metadata file
    metadata_url = dataset_info["metadata"]
    metadata_path = os.path.join(output_dir, "metadata.parquet")

    print(f"Downloading metadata file to: {metadata_path}")
    try:
        download_file(metadata_url, metadata_path)
        print(f"Metadata file downloaded: {metadata_path}")

        # Validate the file
        df = pd.read_parquet(metadata_path)
        print(f"Metadata contains {len(df)} records")
        print(f"Columns: {list(df.columns)}")

        # Save the sample information
        sample_info = {
            "total_samples": len(df),
            "columns": list(df.columns),
            "subset": subset,
            "description": dataset_info["description"]
        }

        with open(os.path.join(output_dir, "dataset_info.json"), "w") as f:
            json.dump(sample_info, f, indent=2)

        print(f"Dataset info saved to: {os.path.join(output_dir, 'dataset_info.json')}")

        return True

    except Exception as e:
        print(f"Download failed: {e}")
        return False

def download_laion_5b_sample(output_dir="./data/laion", num_samples=10000):
    """
    Download a sample of the LAION-5B dataset.

    Args:
        output_dir: Output directory
        num_samples: Number of samples
    """
    os.makedirs(output_dir, exist_ok=True)

    print(f"Downloading LAION-5B dataset sample ({num_samples} records)")

    # Example LAION-5B shard URL
    # Note: the full dataset has tens of thousands of shards; only one sample shard is downloaded here
    sample_shard = "https://huggingface.co/datasets/laion/laion2b-en/resolve/main/part-00000-5b54c5d5-bbcf-484d-a2ce-0d6f73df1a36-c000.snappy.parquet"

    metadata_path = os.path.join(output_dir, "metadata_sample.parquet")

    print(f"Downloading sample shard to: {metadata_path}")
    try:
        download_file(sample_shard, metadata_path)
        print(f"Sample shard downloaded: {metadata_path}")

        # Read and sample; remember the shard size so the file does not
        # have to be read a second time
        df = pd.read_parquet(metadata_path)
        original_count = len(df)

        if len(df) > num_samples:
            df = df.sample(num_samples, random_state=42)

        # Save the sampled data
        sampled_path = os.path.join(output_dir, "metadata.parquet")
        df.to_parquet(sampled_path)

        print(f"Sampled data saved to: {sampled_path}")
        print(f"Sample contains {len(df)} records")
        print(f"Columns: {list(df.columns)}")

        # Save the sample information
        sample_info = {
            "total_samples": len(df),
            "original_samples": original_count,
            "columns": list(df.columns),
            "dataset": "LAION-5B-sample",
            "description": f"LAION-5B dataset sample ({num_samples} records)"
        }

        with open(os.path.join(output_dir, "dataset_info.json"), "w") as f:
            json.dump(sample_info, f, indent=2)

        print(f"Dataset info saved to: {os.path.join(output_dir, 'dataset_info.json')}")

        return True

    except Exception as e:
        print(f"Download failed: {e}")
        return False

def create_dummy_dataset(output_dir="./data/laion", num_samples=100):
    """
    Create a dummy dataset for testing.

    Args:
        output_dir: Output directory
        num_samples: Number of samples
    """
    os.makedirs(output_dir, exist_ok=True)

    print(f"Creating dummy dataset ({num_samples} records)")

    import numpy as np

    # Create the dummy data
    data = {
        'url': [f'https://example.com/image_{i}.jpg' for i in range(num_samples)],
        'caption': [f'A beautiful image number {i}' for i in range(num_samples)],
        'aesthetic_score': np.random.uniform(5.0, 9.0, num_samples),
        'watermark_prob': np.random.uniform(0.0, 1.0, num_samples),
        'NSFW': ['UNLIKELY'] * num_samples,
        'image_file': [f'image_{i:06d}.jpg' for i in range(num_samples)]
    }

    df = pd.DataFrame(data)

    # Save the metadata
    metadata_path = os.path.join(output_dir, "metadata.parquet")
    df.to_parquet(metadata_path)

    print(f"Dummy metadata created: {metadata_path}")
    print(f"Contains {len(df)} records")
    print(f"Columns: {list(df.columns)}")

    # Create the dummy image directory
    images_dir = os.path.join(output_dir, "images")
    os.makedirs(images_dir, exist_ok=True)

    print(f"Dummy image directory: {images_dir}")
    print("Note: the dummy dataset contains no actual image files; it only exercises the code path")

    # Save the dataset information
    sample_info = {
        "total_samples": len(df),
        "columns": list(df.columns),
        "dataset": "dummy",
        "description": f"Dummy test dataset ({num_samples} records)",
        "note": "Contains no actual image files; only exercises the code path"
    }

    with open(os.path.join(output_dir, "dataset_info.json"), "w") as f:
        json.dump(sample_info, f, indent=2)

    print(f"Dataset info saved to: {os.path.join(output_dir, 'dataset_info.json')}")

    return True

def main():
    parser = argparse.ArgumentParser(description="Download LAION datasets")
    parser.add_argument("--output-dir", default="./data/laion", help="Output directory")
    parser.add_argument("--subset", default="6.5+", choices=["6.5+", "7.0+"],
                        help="LAION-Aesthetic subset version")
    parser.add_argument("--sample-size", type=int, default=10000,
                        help="LAION-5B sample size")
    parser.add_argument("--dummy", action="store_true",
                        help="Create a dummy dataset for testing")
    parser.add_argument("--dummy-size", type=int, default=100,
                        help="Dummy dataset size")

    args = parser.parse_args()

    print("=" * 60)
    print("LAION dataset download tool")
    print("=" * 60)

    if args.dummy:
        print("\nDummy dataset mode...")
        success = create_dummy_dataset(args.output_dir, args.dummy_size)
    else:
        print("\nDownloading the LAION-Aesthetic dataset...")
        print(f"Output directory: {args.output_dir}")
        print(f"Subset version: {args.subset}")
        print("\nNote:")
        print("1. LAION datasets are large; downloading takes time and storage space")
        print("2. Metadata files are typically hundreds of MB to several GB")
        print("3. Image files must be downloaded separately")
        print("4. Consider testing the code path with the dummy dataset first")

        response = input("\nContinue? (y/n): ")
        if response.lower() != 'y':
            print("Download cancelled")
            return

        success = download_laion_aesthetic(args.output_dir, args.subset)

    if success:
        print("\n" + "=" * 60)
        print("Download complete!")
        print("=" * 60)
        print("\nNext steps:")
        print("1. Inspect the downloaded files:")
        print(f"   ls -lh {args.output_dir}/")
        print("2. Test loading the dataset:")
        print("   python -c \"import pandas as pd; df=pd.read_parquet('{}'); print('Records:', len(df))\"".format(
            os.path.join(args.output_dir, "metadata.parquet")))
        print("3. Run the test script:")
        print("   python src/data/dataset.py")
    else:
        print("\nDownload failed; check the error messages above")

if __name__ == "__main__":
    main()
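In scripts/download_laion.py, main() parses `--sample-size` but only ever calls create_dummy_dataset or download_laion_aesthetic, so download_laion_5b_sample is never reached from the command line. One possible way to wire it in is sketched below, assuming a hypothetical `--source` flag that the original script does not define. The handlers are stand-ins that only record their arguments; a real integration would pass the three functions from the script instead.

```python
import argparse
from typing import Callable, Dict, List


def build_parser() -> argparse.ArgumentParser:
    """Parser mirroring the script's options, plus a hypothetical --source flag."""
    parser = argparse.ArgumentParser(description="Download LAION datasets")
    parser.add_argument("--output-dir", default="./data/laion")
    # Hypothetical flag: not present in the original script.
    parser.add_argument("--source", default="aesthetic",
                        choices=["aesthetic", "laion5b-sample", "dummy"])
    parser.add_argument("--subset", default="6.5+", choices=["6.5+", "7.0+"])
    parser.add_argument("--sample-size", type=int, default=10000)
    parser.add_argument("--dummy-size", type=int, default=100)
    return parser


def dispatch(args: argparse.Namespace, handlers: Dict[str, Callable[..., bool]]) -> bool:
    """Call the handler that matches --source with its own size or subset option."""
    if args.source == "dummy":
        return handlers["dummy"](args.output_dir, args.dummy_size)
    if args.source == "laion5b-sample":
        return handlers["laion5b-sample"](args.output_dir, args.sample_size)
    return handlers["aesthetic"](args.output_dir, args.subset)


# Stand-in handlers that only record the call (no network access).
calls: List[tuple] = []


def _record(name: str) -> Callable[..., bool]:
    def handler(output_dir: str, option) -> bool:
        calls.append((name, output_dir, option))
        return True
    return handler


handlers = {name: _record(name) for name in ("aesthetic", "laion5b-sample", "dummy")}
args = build_parser().parse_args(["--source", "laion5b-sample", "--sample-size", "500"])
ok = dispatch(args, handlers)
print(ok, calls)
```

Passing the handlers as a dictionary keeps the dispatch logic testable without touching the network or the file system.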
scripts/export.py
ADDED
|
@@ -0,0 +1,270 @@
#!/usr/bin/env python3
"""
Model export script.
Exports a trained model to several formats.
"""

import os
import sys
import argparse
import yaml
import torch
import torch.nn as nn

# Add the project root to the Python path (abspath keeps this working
# regardless of the directory the script is launched from)
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.models.unet_light import UNetLight
from src.models.diffusion import DiffusionProcess
from src.inference.optimization import ModelOptimizer, ONNXExporter, optimize_model_for_p4


def load_config(config_path: str) -> dict:
    """Load a YAML configuration file."""
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    return config


def load_model(checkpoint_path: str, config: dict, device: torch.device) -> nn.Module:
    """Load the model from a checkpoint."""
    # Load the model configuration
    model_config_path = config.get('model_config', 'configs/model/unet_light.yaml')
    model_config = load_config(model_config_path)

    # Create the model
    model = UNetLight(model_config)

    # Load the checkpoint
    print(f"Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Load the model weights; the checkpoint may wrap them under
    # different keys or be a bare state dict
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    elif 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint)

    # Move to the target device
    model = model.to(device)
    model.eval()

    print("Model loaded")

    return model


def export_torchscript(model: nn.Module, output_path: str):
    """Export the model to TorchScript."""
    print(f"Exporting to TorchScript: {output_path}")

    # Create example inputs on the same device as the model; CPU inputs
    # would fail when the model has been moved to CUDA
    device = next(model.parameters()).device
    example_input = torch.randn(1, model.in_channels, 64, 64, device=device)
    example_timestep = torch.tensor([500], device=device)
    example_context = torch.randn(1, 77, 768, device=device)

    # Trace the model
    traced_model = torch.jit.trace(
        model,
        (example_input, example_timestep, example_context),
        check_trace=False
    )

    # Save
    traced_model.save(output_path)
    print(f"TorchScript model saved: {output_path}")

    return traced_model


def export_onnx(model: nn.Module, output_path: str, opset_version: int = 14):
    """Export the model to ONNX."""
    print(f"Exporting to ONNX: {output_path}")

    # Create example inputs on the same device as the model; CPU inputs
    # would fail when the model has been moved to CUDA
    device = next(model.parameters()).device
    example_input = torch.randn(1, model.in_channels, 64, 64, device=device)
    example_timestep = torch.tensor([500], device=device)
    example_context = torch.randn(1, 77, 768, device=device)

    # Mark the batch dimension as dynamic
    dynamic_axes = {
        'input': {0: 'batch_size'},
        'timestep': {0: 'batch_size'},
        'context': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }

    # Export
    torch.onnx.export(
        model,
        (example_input, example_timestep, example_context),
        output_path,
        input_names=['input', 'timestep', 'context'],
        output_names=['output'],
        dynamic_axes=dynamic_axes,
        opset_version=opset_version,
        do_constant_folding=True,
        verbose=False
    )

    print(f"ONNX model saved: {output_path}")

    # Validate the ONNX model
    import onnx
    onnx_model = onnx.load(output_path)
    onnx.checker.check_model(onnx_model)
    print("ONNX model validated successfully")


def export_safetensors(model: nn.Module, output_path: str):
    """Export the model weights in safetensors format."""
    try:
        from safetensors.torch import save_file

        # Save the state dict in safetensors format
        state_dict = model.state_dict()
        save_file(state_dict, output_path)

        print(f"Safetensors model saved: {output_path}")
    except ImportError:
        print("safetensors is not installed; skipping safetensors export")
        print("Install with: pip install safetensors")


def optimize_and_export(
    checkpoint_path: str,
    output_dir: str,
    formats: list = ['torchscript', 'onnx', 'safetensors'],
    optimize_for_p4: bool = True
):
    """Optimize and export the model."""
    # Create the output directory
    os.makedirs(output_dir, exist_ok=True)

    # Load configuration
    config_path = "configs/training/p4_optimized.yaml"
    config = load_config(config_path)

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the model
    model = load_model(checkpoint_path, config, device)

    # Optimize the model (for P4)
    if optimize_for_p4:
        print("Optimizing model (for P4)...")
        model = optimize_model_for_p4(model)

    # Collect model information
    total_params = sum(p.numel() for p in model.parameters())
    model_size_mb = total_params * 4 / 1024**2  # fp32: 4 bytes per parameter

    print("\nModel info:")
    print(f"  Parameters: {total_params:,}")
    print(f"  Model size: {model_size_mb:.2f} MB (fp32)")

    # Export to each requested format
    base_name = os.path.splitext(os.path.basename(checkpoint_path))[0]

    for fmt in formats:
        if fmt == 'torchscript':
            output_path = os.path.join(output_dir, f"{base_name}.torchscript.pt")
            export_torchscript(model, output_path)

        elif fmt == 'onnx':
            output_path = os.path.join(output_dir, f"{base_name}.onnx")
            export_onnx(model, output_path)

        elif fmt == 'safetensors':
            output_path = os.path.join(output_dir, f"{base_name}.safetensors")
            export_safetensors(model, output_path)

        elif fmt == 'pth':
            output_path = os.path.join(output_dir, f"{base_name}.pth")
            torch.save(model.state_dict(), output_path)
            print(f"PyTorch model saved: {output_path}")

        else:
            print(f"Unknown format: {fmt}")
|
| 192 |
+
|
| 193 |
+
print(f"\n所有模型已导出到: {output_dir}")
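The `if`/`elif` chain in `optimize_and_export` could also be written as a lookup table. The following is only a sketch under assumptions: `make_recorder`, `export_all`, `EXPORTERS`, and `SUFFIXES` are hypothetical names that do not exist in the repository, and the stand-in exporters merely record their calls instead of writing files.

```python
import os


def make_recorder(log, name):
    """Return a stand-in exporter that records (format, path) instead of writing a file."""
    def exporter(model, output_path):
        log.append((name, output_path))
    return exporter


def export_all(model, base_name, output_dir, formats, exporters, suffixes):
    """Call the exporter registered for each format; return the formats that were unknown."""
    unknown = []
    for fmt in formats:
        exporter = exporters.get(fmt)
        if exporter is None:
            unknown.append(fmt)
            continue
        output_path = os.path.join(output_dir, f"{base_name}{suffixes[fmt]}")
        exporter(model, output_path)
    return unknown


# Hypothetical wiring; the suffixes mirror the file names used in optimize_and_export
SUFFIXES = {
    'torchscript': '.torchscript.pt',
    'onnx': '.onnx',
    'safetensors': '.safetensors',
    'pth': '.pth',
}
calls = []
EXPORTERS = {name: make_recorder(calls, name) for name in SUFFIXES}
unknown_formats = export_all(None, 'final_model', 'out', ['onnx', 'pth', 'gguf'],
                             EXPORTERS, SUFFIXES)
```

With a table like this, adding a format would mean registering one more entry rather than extending the branch chain, and unknown formats can be reported together at the end.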


def create_lite_version(model: nn.Module, reduction_factor: float = 0.5) -> nn.Module:
    """Create a lite version of the model (by reducing channel counts).

    Placeholder: the model is currently returned unchanged.
    """
    # Note: this is an example and must be adapted to the actual model structure
    print(f"Creating lite version, reduction factor: {reduction_factor}")

    # The actual slimming logic should be implemented here,
    # for example reducing the number of UNet channels

    return model


def main():
    """Entry point."""
    parser = argparse.ArgumentParser(description="Export a Lumina model")

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the model checkpoint"
    )

    parser.add_argument(
        "--output-dir",
        type=str,
        default="./exported_models",
        help="Output directory"
    )

    parser.add_argument(
        "--formats",
        type=str,
        nargs="+",
        default=['torchscript', 'onnx'],
        choices=['torchscript', 'onnx', 'safetensors', 'pth'],
        help="Export formats"
    )

    parser.add_argument(
        "--optimize",
        action="store_true",
        help="Optimize the model (for P4)"
    )

    parser.add_argument(
        "--lite",
        action="store_true",
        help="Create a lite version"
    )

    parser.add_argument(
        "--lite-factor",
        type=float,
        default=0.5,
        help="Reduction factor for the lite version"
    )

    args = parser.parse_args()

    # Check the input file
    if not os.path.exists(args.checkpoint):
        print(f"Error: checkpoint file does not exist: {args.checkpoint}")
        return

    # --lite and --lite-factor are parsed but not wired into the export pipeline
    if args.lite:
        print(f"Note: --lite (factor {args.lite_factor}) is ignored; create_lite_version is a placeholder")

    # Optimize and export
    optimize_and_export(
        checkpoint_path=args.checkpoint,
        output_dir=args.output_dir,
        formats=args.formats,
        optimize_for_p4=args.optimize
    )


if __name__ == "__main__":
    main()
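As a usage illustration for the export CLI above, the sketch below rebuilds a reduced version of the parser and parses an explicit argument list instead of `sys.argv`. `build_parser` and the checkpoint path `checkpoints/final_model.pt` are assumptions made for the example, not names taken from the script.

```python
import argparse


def build_parser():
    """Rebuild the subset of the export CLI needed for this illustration."""
    parser = argparse.ArgumentParser(description="Export a Lumina model")
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--output-dir", type=str, default="./exported_models")
    parser.add_argument(
        "--formats",
        type=str,
        nargs="+",
        default=['torchscript', 'onnx'],
        choices=['torchscript', 'onnx', 'safetensors', 'pth'],
    )
    parser.add_argument("--optimize", action="store_true")
    return parser


# Equivalent to running:
#   python scripts/export.py --checkpoint checkpoints/final_model.pt --formats onnx pth --optimize
args = build_parser().parse_args([
    "--checkpoint", "checkpoints/final_model.pt",  # hypothetical path
    "--formats", "onnx", "pth",
    "--optimize",
])
```

Because `--formats` uses `nargs="+"`, several formats are passed as space-separated values after a single flag, and `--output-dir` is exposed on the namespace as `args.output_dir`.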
scripts/train.py
ADDED
@@ -0,0 +1,293 @@
#!/usr/bin/env python3
"""
Lumina training script.
Trains a lightweight image generation model.
"""

import os
import sys
import argparse
import yaml
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import warnings

# Add the project root to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.models.unet_light import UNetLight
from src.models.diffusion import DiffusionProcess, DiffusionModel
from src.data.dataset import create_data_loaders
from src.data.text_encoder import create_text_encoder
from src.training.trainer_p4 import P4Trainer
from src.training.memory_manager import MemoryOptimizer
from src.training.callbacks import create_default_callbacks


def load_config(config_path: str) -> dict:
    """Load a YAML configuration file."""
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config


def setup_environment(config: dict):
    """Set up the training environment and return the device name."""
    # Set the random seed
    seed = config.get('seed', 42)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Select the CUDA device
    device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
    if device == 'cuda' and not torch.cuda.is_available():
        warnings.warn("CUDA is not available; falling back to CPU")
        device = 'cpu'

    # Create the output directory
    output_dir = config.get('output_dir', './output')
    os.makedirs(output_dir, exist_ok=True)

    # Set up logging
    log_dir = config.get('log_dir', './logs')
    os.makedirs(log_dir, exist_ok=True)

    print("Environment ready:")
    print(f"  Device: {device}")
    print(f"  Random seed: {seed}")
    print(f"  Output directory: {output_dir}")
    print(f"  Log directory: {log_dir}")

    return device


def create_model(config: dict, device: torch.device) -> nn.Module:
    """Create the model."""
    # Load the model configuration
    model_config_path = config.get('model_config', 'configs/model/unet_light.yaml')
    model_config = load_config(model_config_path)

    # Create the UNet model
    model = UNetLight(model_config)

    # Load pretrained weights (if any)
    pretrained_path = config.get('pretrained_path')
    if pretrained_path and os.path.exists(pretrained_path):
        print(f"Loading pretrained weights: {pretrained_path}")
        checkpoint = torch.load(pretrained_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])

    # Move to the device
    model = model.to(device)

    # Print model statistics
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Model created:")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
    print(f"  Model size: {total_params * 4 / 1024**2:.2f} MB (fp32)")

    return model
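The size printed by `create_model` is plain arithmetic: parameter count times bytes per parameter, divided by 1024². A minimal sketch of the same estimate for other precisions follows; `BYTES_PER_PARAM` and `model_size_mb` are hypothetical helpers, and the result covers weights only (no optimizer state, gradients, or activations).

```python
# Bytes needed to store one parameter at each precision
BYTES_PER_PARAM = {'fp32': 4, 'fp16': 2, 'int8': 1}


def model_size_mb(total_params: int, dtype: str = 'fp32') -> float:
    """Estimate the weight storage size in MiB for a given parameter count."""
    return total_params * BYTES_PER_PARAM[dtype] / 1024**2


# 1,048,576 parameters occupy exactly 4 MiB in fp32 and 2 MiB in fp16
size_fp32 = model_size_mb(1024**2, 'fp32')
size_fp16 = model_size_mb(1024**2, 'fp16')
```

Storing the same weights in fp16 halves this figure, which is the part of the memory budget that the estimate can speak to on a small GPU.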


def create_diffusion(config: dict) -> DiffusionProcess:
    """Create the diffusion process."""
    diffusion_config_path = config.get('diffusion_config', 'configs/model/diffusion.yaml')
    diffusion_config = load_config(diffusion_config_path)

    diffusion = DiffusionProcess(diffusion_config)

    print("Diffusion process created:")
    print(f"  Training timesteps: {diffusion.num_train_timesteps}")
    print(f"  Inference timesteps: {diffusion.num_inference_timesteps}")
    print(f"  Beta schedule: {diffusion.beta_schedule}")

    return diffusion


def create_data_pipeline(config: dict):
    """Create the data pipeline."""
    data_config_path = config.get('data_config', 'configs/data/laion_filtered.yaml')
    data_config = load_config(data_config_path)

    # Create the text encoder
    text_encoder = create_text_encoder(data_config)

    # Create the data loaders
    train_loader, val_loader = create_data_loaders(data_config)

    print("Data pipeline created:")
    print(f"  Training set size: {len(train_loader.dataset)}")
    print(f"  Validation set size: {len(val_loader.dataset) if val_loader else 0}")
    print(f"  Batch size: {train_loader.batch_size}")
    print(f"  Gradient accumulation steps: {config.get('gradient_accumulation_steps', 8)}")

    return train_loader, val_loader, text_encoder


def create_optimizer(model: nn.Module, config: dict):
    """Create the optimizer."""
    optimizer_config = config.get('optimizer', {})
    optimizer_type = optimizer_config.get('type', 'AdamW')
    learning_rate = optimizer_config.get('learning_rate', 1e-4)
    weight_decay = optimizer_config.get('weight_decay', 0.01)

    if optimizer_type == 'AdamW':
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.999),
            eps=1e-8
        )
    elif optimizer_type == 'Adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay
        )
    else:
        raise ValueError(f"Unknown optimizer type: {optimizer_type}")

    print("Optimizer created:")
    print(f"  Type: {optimizer_type}")
    print(f"  Learning rate: {learning_rate}")
    print(f"  Weight decay: {weight_decay}")

    return optimizer


def setup_memory_optimization(model: nn.Module, optimizer, config: dict):
    """Set up memory optimizations."""
    memory_optimizer = MemoryOptimizer(config)
    memory_optimizer.setup_model_optimizations(model, optimizer)

    # Print memory statistics
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        print("Memory optimization ready:")
        print(f"  GPU allocated: {allocated:.2f} GB")
        print(f"  GPU reserved: {reserved:.2f} GB")

    return memory_optimizer


def train(config_path: str, resume_from: str = None):
    """Main training routine."""
    print("=" * 60)
    print("Lumina training started")
    print("=" * 60)

    # Load the configuration
    config = load_config(config_path)

    # Set up the environment
    device = setup_environment(config)

    # Create the model
    model = create_model(config, device)

    # Create the diffusion process
    diffusion = create_diffusion(config)

    # Create the diffusion model
    diffusion_model = DiffusionModel(model, diffusion)

    # Create the data pipeline
    train_loader, val_loader, text_encoder = create_data_pipeline(config)

    # Create the optimizer
    optimizer = create_optimizer(model, config)

    # Set up memory optimizations
    memory_optimizer = setup_memory_optimization(model, optimizer, config)

    # Create the trainer
    trainer = P4Trainer(
        model=model,
        diffusion=diffusion,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        device=device
    )

    # Create callbacks
    callbacks = create_default_callbacks(config)

    # Load a checkpoint (if one exists)
    if resume_from and os.path.exists(resume_from):
        print(f"Resuming training from checkpoint: {resume_from}")
        trainer.load_checkpoint(resume_from)

    # Start training
    try:
        print("\nStarting training...")
        trainer.train()

        print("\n" + "=" * 60)
        print("Training complete!")
        print(f"Best validation loss: {trainer.best_loss:.4f}")
        print(f"Total training steps: {trainer.global_step}")
        print("=" * 60)

    except KeyboardInterrupt:
        print("\nTraining interrupted")
    except Exception as e:
        print(f"\nTraining failed: {e}")
        import traceback
        traceback.print_exc()

    finally:
        # Save the final checkpoint; make sure the target directory exists first
        checkpoint_dir = config.get('checkpoint_dir', './checkpoints')
        os.makedirs(checkpoint_dir, exist_ok=True)
        final_checkpoint = os.path.join(checkpoint_dir, 'final_model.pt')
        trainer.save_checkpoint(final_checkpoint)
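The `try`/`except`/`finally` structure in `train` means the final checkpoint is written whether training finishes, is interrupted, or fails. A small sketch of that control flow, with stand-in callables in place of the real trainer (`run_with_final_save` and `interrupted_training` are hypothetical names):

```python
def run_with_final_save(train_fn, save_fn):
    """Run train_fn and always call save_fn afterwards; return a status string."""
    status = "completed"
    try:
        train_fn()
    except KeyboardInterrupt:
        status = "interrupted"
    except Exception as e:
        status = f"failed: {e}"
    finally:
        # Runs on success, on Ctrl+C, and on any other exception
        save_fn()
    return status


saved = []


def interrupted_training():
    """Stand-in for trainer.train() that simulates the user pressing Ctrl+C."""
    raise KeyboardInterrupt


status = run_with_final_save(interrupted_training, lambda: saved.append('final_model.pt'))
```

One consequence of this design is that a run that crashed early still overwrites `final_model.pt`, so the file name alone does not indicate that training completed.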


def main():
    """Entry point."""
    parser = argparse.ArgumentParser(description="Train the Lumina image generation model")

    parser.add_argument(
        "--config",
        type=str,
        default="configs/training/p4_optimized.yaml",
        help="Path to the training configuration file"
    )

    parser.add_argument(
        "--resume",
        type=str,
        help="Resume training from a checkpoint"
    )

    parser.add_argument(
        "--debug",
        action="store_true",
        help="Debug mode"
    )

    args = parser.parse_args()

    # Debug mode settings (warnings is already imported at module level)
    if args.debug:
        warnings.filterwarnings("always")
        torch.autograd.set_detect_anomaly(True)
        print("Debug mode enabled")

    # Start training
    train(args.config, args.resume)


if __name__ == "__main__":
    main()
src/data/__pycache__/dataset.cpython-313.pyc
ADDED
Binary file (13.8 kB)
src/data/dataset.py
ADDED
@@ -0,0 +1,298 @@
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import os
from typing import Dict, List, Optional, Tuple
import numpy as np
import json
from pathlib import Path


class LAIONDataset(Dataset):
    """LAION dataset."""
    def __init__(self, config: dict, transform=None, split: str = 'train'):
        self.config = config
        self.transform = transform
        self.split = split

        # Load the metadata
        metadata_path = config['dataset'].get('metadata_file', './data/laion/metadata.parquet')
        self.metadata = pd.read_parquet(metadata_path)

        # Apply the filters
        self._apply_filters(config.get('filters', {}))

        # Split the data
        self._split_data(config.get('split', {}))

        # Cache settings
        self.use_cache = config.get('use_cache', False)
        self.cache_dir = config.get('cache_dir', './data/cache')
        if self.use_cache:
            os.makedirs(self.cache_dir, exist_ok=True)

        # Limit the number of samples
        max_samples = config.get('max_samples', None)
        if max_samples is not None and len(self.metadata) > max_samples:
            self.metadata = self.metadata.sample(max_samples, random_state=42)

        # In-memory text cache
        self.text_cache = {}

        print(f"Dataset loaded: {len(self.metadata)} samples ({split} split)")

    def _apply_filters(self, filters: dict):
        """Apply the filter criteria."""
        if 'aesthetic_score' in filters:
            threshold = filters['aesthetic_score']
            if 'aesthetic_score' in self.metadata.columns:
                self.metadata = self.metadata[self.metadata['aesthetic_score'] >= threshold]

        if 'watermark_prob' in filters:
            threshold = filters['watermark_prob']
            if 'watermark_prob' in self.metadata.columns:
                self.metadata = self.metadata[self.metadata['watermark_prob'] <= threshold]

        if 'nsfw' in filters and not filters['nsfw']:
            if 'NSFW' in self.metadata.columns:
                self.metadata = self.metadata[self.metadata['NSFW'] != 'NSFW']

    def _split_data(self, split_config: dict):
        """Split the dataset into train and validation parts."""
        if self.split not in ['train', 'val']:
            return

        train_ratio = split_config.get('train', 0.95)
        val_ratio = split_config.get('val', 0.05)

        # Normalize so the split ratios sum to 1
        total = train_ratio + val_ratio
        train_ratio /= total
        val_ratio /= total

        # Random split; the same seed gives the train and val instances the same shuffle
        seed = split_config.get('seed', 42)
        shuffled = self.metadata.sample(frac=1, random_state=seed).reset_index(drop=True)

        if self.split == 'train':
            split_point = int(len(shuffled) * train_ratio)
            self.metadata = shuffled[:split_point]
        else:
            split_point = int(len(shuffled) * train_ratio)
            self.metadata = shuffled[split_point:]

    def __len__(self) -> int:
        return len(self.metadata)

    def _get_image_path(self, row) -> str:
        """Resolve the image path for a metadata row."""
        # Try the different column names
        for col in ['image_file', 'filepath', 'path', 'url_local']:
            if col in row:
                path = row[col]
                # If the path is relative, prepend the base path
                if not os.path.isabs(path):
                    base_path = self.config['dataset'].get('path', './data/laion')
                    path = os.path.join(base_path, path)
                return path

        # If no path column was found, fall back to a hash of the URL
        if 'url' in row:
            import hashlib
            url_hash = hashlib.md5(row['url'].encode()).hexdigest()
            base_path = self.config['dataset'].get('path', './data/laion')
            path = os.path.join(base_path, f"{url_hash}.jpg")
            return path

        raise ValueError(f"Could not find an image path: {row}")

    def __getitem__(self, idx: int) -> Dict:
        row = self.metadata.iloc[idx]

        # Cache key
        cache_key = f"{self.split}_{idx}"

        # Check the cache
        if self.use_cache and cache_key in self.text_cache:
            text_embedding = self.text_cache[cache_key]
        else:
            # Get the text description
            text = row.get('caption', row.get('text', row.get('description', '')))

            # A text encoder should be called here; for simplicity the raw text is returned.
            # In real use, a pretrained CLIP encoder should be used.
            text_embedding = text

            # Cache
            if self.use_cache:
                self.text_cache[cache_key] = text_embedding

        # Load the image
        image_path = ''  # defined up front so the except branch can always reference it
        try:
            image_path = self._get_image_path(row)
            image = Image.open(image_path).convert('RGB')

            # Apply the transform
            if self.transform:
                image = self.transform(image)
        except Exception as e:
            # If the image fails to load, return a blank image
            # (fixed at 512x512, regardless of the configured target size)
            print(f"Failed to load image {image_path}: {e}")
            image = torch.zeros(3, 512, 512)
            text = "invalid image"
            text_embedding = text

        # Note: 'text_embedding' is None whenever only raw text is available,
        # and the default DataLoader collate function cannot batch None values
        return {
            'image': image,
            'text': text_embedding if isinstance(text_embedding, str) else '',
            'text_embedding': text_embedding if not isinstance(text_embedding, str) else None,
            'image_path': image_path,
            'index': idx
        }
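The split logic in `_split_data` (normalize the ratios, shuffle with a fixed seed, cut at one index) can be illustrated on plain indices. This sketch uses the standard library's `random.Random` instead of `DataFrame.sample`, so the resulting order differs from the pandas version; `split_indices` is a hypothetical helper.

```python
import random


def split_indices(n, train=0.95, val=0.05, seed=42):
    """Shuffle range(n) deterministically and cut it into train and val index lists."""
    # Normalize so the two ratios sum to 1
    total = train + val
    train_ratio = train / total

    indices = list(range(n))
    random.Random(seed).shuffle(indices)

    split_point = int(n * train_ratio)
    return indices[:split_point], indices[split_point:]


train_idx, val_idx = split_indices(100)
```

Because both halves come from the same seeded shuffle, two calls with the same arguments produce disjoint and complementary parts, which is what lets the train and validation datasets be constructed independently.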


class TextImageDataset(Dataset):
    """Dataset of text-image pairs."""
    def __init__(self, image_dir: str, caption_file: str, transform=None):
        self.image_dir = image_dir
        self.transform = transform

        # Load the caption file
        if caption_file.endswith('.json'):
            with open(caption_file, 'r', encoding='utf-8') as f:
                self.captions = json.load(f)
        elif caption_file.endswith('.csv'):
            # Convert rows to dicts; iterating a DataFrame directly would yield column names
            self.captions = pd.read_csv(caption_file).to_dict('records')
        else:
            raise ValueError(f"Unsupported caption file format: {caption_file}")

        # Keep only samples whose image file exists
        self.valid_samples = []
        for item in self.captions:
            if isinstance(item, dict):
                image_name = item.get('image_name', item.get('file_name', ''))
                caption = item.get('caption', '')
            else:
                image_name = item[0]
                caption = item[1]

            image_path = os.path.join(self.image_dir, image_name)
            if os.path.exists(image_path):
                self.valid_samples.append((image_path, caption))

        print(f"Found {len(self.valid_samples)} valid samples")

    def __len__(self) -> int:
        return len(self.valid_samples)

    def __getitem__(self, idx: int) -> Dict:
        image_path, caption = self.valid_samples[idx]

        # Load the image
        image = Image.open(image_path).convert('RGB')

        # Apply the transform
        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            'text': caption,
            'image_path': image_path
        }
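`TextImageDataset` accepts caption records either as dicts (keyed by `image_name` or `file_name`) or as two-element sequences. The sketch below shows that normalization with the standard library only; `load_caption_pairs` is a hypothetical helper that parses strings rather than files, and it uses `csv.DictReader` where the class itself uses pandas.

```python
import csv
import io
import json


def load_caption_pairs(text: str, fmt: str):
    """Parse caption data into a list of (image_name, caption) pairs."""
    if fmt == 'json':
        records = json.loads(text)
    elif fmt == 'csv':
        # DictReader yields one dict per row, keyed by the header line
        records = list(csv.DictReader(io.StringIO(text)))
    else:
        raise ValueError(f"Unsupported caption format: {fmt}")

    pairs = []
    for item in records:
        if isinstance(item, dict):
            name = item.get('image_name', item.get('file_name', ''))
            caption = item.get('caption', '')
        else:
            name, caption = item[0], item[1]
        pairs.append((name, caption))
    return pairs


json_pairs = load_caption_pairs(
    '[{"file_name": "a.jpg", "caption": "a cat"}, ["b.jpg", "a dog"]]', 'json')
csv_pairs = load_caption_pairs("image_name,caption\nc.jpg,a bird\n", 'csv')
```

Reading CSV rows as dicts keeps both file formats on the same code path, so the per-record handling does not need to know where the data came from.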


class CachedDataset(Dataset):
    """Dataset wrapper that caches samples on disk to speed up training."""
    def __init__(self, dataset: Dataset, cache_dir: str = './cache'):
        self.dataset = dataset
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        self.cache_files = []
        for i in range(len(dataset)):
            cache_file = os.path.join(cache_dir, f'sample_{i}.pt')
            self.cache_files.append(cache_file)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Dict:
        cache_file = self.cache_files[idx]

        # If the cache entry exists, load it directly
        if os.path.exists(cache_file):
            try:
                return torch.load(cache_file)
            except Exception:
                # Unreadable cache entry: fall through and rebuild it
                pass

        # Otherwise, load from the underlying dataset and cache the result
        sample = self.dataset[idx]
        torch.save(sample, cache_file)

        return sample
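The caching pattern used by `CachedDataset` (load from disk if present, otherwise compute and store) can be shown without PyTorch. This is a sketch under assumptions: `DiskCache` and `CountingSource` are hypothetical, and `pickle` stands in for `torch.save`/`torch.load`.

```python
import os
import pickle
import tempfile


class DiskCache:
    """Wrap an indexable source and cache each item on disk after the first access."""
    def __init__(self, source, cache_dir):
        self.source = source
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def __len__(self):
        return len(self.source)

    def __getitem__(self, idx):
        cache_file = os.path.join(self.cache_dir, f"sample_{idx}.pkl")
        if os.path.exists(cache_file):
            try:
                with open(cache_file, 'rb') as f:
                    return pickle.load(f)
            except Exception:
                pass  # unreadable cache entry: fall through and rebuild it
        sample = self.source[idx]
        with open(cache_file, 'wb') as f:
            pickle.dump(sample, f)
        return sample


class CountingSource:
    """Stand-in dataset that counts how often an item is actually computed."""
    def __init__(self):
        self.reads = 0

    def __len__(self):
        return 3

    def __getitem__(self, idx):
        self.reads += 1
        return {'index': idx, 'text': f"caption {idx}"}


source = CountingSource()
with tempfile.TemporaryDirectory() as tmp:
    cached = DiskCache(source, tmp)
    first = cached[1]
    second = cached[1]  # served from disk; the source is not read again
```

Since entries are keyed only by index, a cache like this goes stale when the underlying data, filters, or split change, and any random augmentation applied before caching is frozen at its first result.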


def create_data_loaders(config: dict) -> Tuple[DataLoader, Optional[DataLoader]]:
    """Create the training and validation data loaders."""
    from .preprocessing import get_transform

    # Build the data transforms
    train_transform = get_transform(config, mode='train')
    val_transform = get_transform(config, mode='val')

    # Create the datasets
    train_dataset = LAIONDataset(config, transform=train_transform, split='train')
    val_dataset = LAIONDataset(config, transform=val_transform, split='val')

    # Optional: enable caching
    if config.get('cache_dataset', True):
        train_dataset = CachedDataset(train_dataset, cache_dir='./data/cache/train')
        val_dataset = CachedDataset(val_dataset, cache_dir='./data/cache/val')

    num_workers = config.get('num_workers', 2)

    # prefetch_factor and persistent_workers are only valid with worker processes;
    # passing them with num_workers=0 makes DataLoader raise a ValueError
    worker_kwargs = {}
    if num_workers > 0:
        worker_kwargs['prefetch_factor'] = config.get('prefetch_factor', 2)
        worker_kwargs['persistent_workers'] = config.get('persistent_workers', True)

    # Create the data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.get('batch_size', 1),
        shuffle=config.get('shuffle', True),
        num_workers=num_workers,
        pin_memory=config.get('pin_memory', True),
        **worker_kwargs
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=1,  # validation uses a batch size of 1
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return train_loader, val_loader
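How the loader options depend on `num_workers` can be checked in isolation by building only the keyword dictionary. `build_loader_kwargs` is a hypothetical helper written for this illustration; it mirrors the config keys and defaults read by `create_data_loaders` but does not construct a `DataLoader`.

```python
def build_loader_kwargs(config: dict) -> dict:
    """Assemble DataLoader keyword arguments from a config dict."""
    num_workers = config.get('num_workers', 2)
    kwargs = {
        'batch_size': config.get('batch_size', 1),
        'shuffle': config.get('shuffle', True),
        'num_workers': num_workers,
        'pin_memory': config.get('pin_memory', True),
    }
    # Worker-only options are added only when worker processes are used
    if num_workers > 0:
        kwargs['prefetch_factor'] = config.get('prefetch_factor', 2)
        kwargs['persistent_workers'] = config.get('persistent_workers', True)
    return kwargs


default_kwargs = build_loader_kwargs({})
debug_kwargs = build_loader_kwargs({'num_workers': 0, 'batch_size': 4})
```

Setting `num_workers` to 0 is a common choice when debugging, since data loading then runs in the main process; the worker-only options are simply left out in that case.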


def test_dataset():
    """Smoke-test the dataset."""
    import yaml

    # Load the configuration
    with open('configs/data/laion_filtered.yaml', 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # Create the dataset
    dataset = LAIONDataset(config, split='train')

    # Inspect one sample
    sample = dataset[0]
    print(f"Sample keys: {list(sample.keys())}")
    print(f"Image shape: {sample['image'].shape if hasattr(sample['image'], 'shape') else type(sample['image'])}")
    print(f"Text: {sample['text'][:100]}...")

    return dataset


if __name__ == '__main__':
    test_dataset()
src/data/preprocessing.py
ADDED
@@ -0,0 +1,295 @@
import torch
import torchvision.transforms as T
from PIL import Image
import numpy as np
from typing import Dict, List, Optional, Tuple
import random


def get_transform(config: dict, mode: str = 'train') -> T.Compose:
    """Build the image preprocessing transform."""
    preprocessing = config.get('preprocessing', {})
    target_size = preprocessing.get('target_size', 512)
    resize_mode = preprocessing.get('resize_mode', 'center_crop')

    transforms_list = []

    # Training and validation use different transforms
    if mode == 'train':
        # Random crop
        if preprocessing.get('random_crop', True):
            transforms_list.append(T.RandomResizedCrop(
                target_size,
                scale=(0.8, 1.0),
                ratio=(0.8, 1.2)
            ))
        else:
            # Note: an int size only scales the shorter side, so non-square
            # inputs stay non-square
            transforms_list.append(T.Resize(target_size, interpolation=T.InterpolationMode.BILINEAR))

        # Random horizontal flip
        if preprocessing.get('random_flip', True):
            transforms_list.append(T.RandomHorizontalFlip(p=0.5))

        # Color jitter (read once with a default, so a missing key cannot raise KeyError)
        color_jitter = preprocessing.get('color_jitter', 0.05)
        if color_jitter > 0:
            transforms_list.append(T.ColorJitter(
                brightness=color_jitter,
                contrast=color_jitter,
                saturation=color_jitter,
                hue=min(0.1, color_jitter)
            ))

        # Random rotation
        max_angle = preprocessing.get('random_rotation', 0.0)
        if max_angle > 0:
            transforms_list.append(T.RandomRotation(degrees=(-max_angle, max_angle)))

    else:  # validation/test mode
        # Center crop
        if resize_mode == 'center_crop':
            transforms_list.extend([
                T.Resize(target_size, interpolation=T.InterpolationMode.BILINEAR),
                T.CenterCrop(target_size)
            ])
        elif resize_mode == 'resize':
            # Note: an int size only scales the shorter side (aspect ratio is kept)
            transforms_list.append(T.Resize(target_size, interpolation=T.InterpolationMode.BILINEAR))
        elif resize_mode == 'random_crop':
            transforms_list.append(T.RandomCrop(target_size))
        else:
            raise ValueError(f"Unknown resize_mode: {resize_mode}")

    # Convert to a tensor
    transforms_list.append(T.ToTensor())

    # Normalize
    normalize_config = preprocessing.get('normalize', {})
    mean = normalize_config.get('mean', [0.5, 0.5, 0.5])
    std = normalize_config.get('std', [0.5, 0.5, 0.5])
    transforms_list.append(T.Normalize(mean=mean, std=std))

    return T.Compose(transforms_list)
|
| 72 |
+
|
| 73 |
+
|
+class TextPreprocessor:
+    """Text preprocessor."""
+    def __init__(self, config: dict):
+        self.config = config.get('preprocessing', {})
+
+        # Text processing parameters
+        self.max_length = self.config.get('max_length', 77)
+        self.truncation = self.config.get('truncation', True)
+        self.padding = self.config.get('padding', 'max_length')
+
+        # Try to load the tokenizer
+        self.tokenizer = None
+        self._init_tokenizer()
+
+    def _init_tokenizer(self):
+        """Initialize the tokenizer."""
+        try:
+            from transformers import CLIPTokenizer
+            tokenizer_name = self.config.get('tokenizer', 'openai/clip-vit-base-patch32')
+            self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name)
+        except ImportError:
+            print("Warning: transformers is not installed; the CLIP tokenizer is unavailable")
+        except Exception as e:
+            print(f"Failed to load tokenizer: {e}")
+
+    def preprocess_text(self, text: str) -> Dict:
+        """Preprocess a single text string."""
+        if self.tokenizer is not None:
+            # Use the CLIP tokenizer
+            inputs = self.tokenizer(
+                text,
+                max_length=self.max_length,
+                padding=self.padding,
+                truncation=self.truncation,
+                return_tensors="pt"
+            )
+            return {
+                'input_ids': inputs['input_ids'].squeeze(0),
+                'attention_mask': inputs['attention_mask'].squeeze(0)
+            }
+        else:
+            # Simple fallback: return the raw text
+            return {
+                'text': text,
+                'length': len(text)
+            }
+
+    def batch_preprocess(self, texts: List[str]) -> Dict:
+        """Preprocess a batch of text strings."""
+        if self.tokenizer is not None:
+            inputs = self.tokenizer(
+                texts,
+                max_length=self.max_length,
+                padding=self.padding,
+                truncation=self.truncation,
+                return_tensors="pt"
+            )
+            return inputs
+        else:
+            return {'texts': texts}
+
+
+class ImagePreprocessor:
+    """Image preprocessor."""
+    def __init__(self, config: dict):
+        self.config = config.get('preprocessing', {})
+        self.transform = get_transform(config, mode='train')
+
+    def preprocess_image(self, image: Image.Image) -> torch.Tensor:
+        """Preprocess a single image."""
+        return self.transform(image)
+
+    def batch_preprocess(self, images: List[Image.Image]) -> torch.Tensor:
+        """Preprocess a batch of images."""
+        return torch.stack([self.transform(img) for img in images])
+
+    def preprocess_for_vae(self, image: torch.Tensor) -> torch.Tensor:
+        """Prepare an image for VAE encoding."""
+        # The VAE expects inputs in [-1, 1]; this maps a [0, 1] tensor to that range
+        return image * 2.0 - 1.0
+
+    def postprocess_from_vae(self, latents: torch.Tensor) -> torch.Tensor:
+        """Post-process an image decoded by the VAE."""
+        # Map the [-1, 1] range back to [0, 1]
+        return (latents + 1.0) / 2.0
+
+
+class DataPreprocessor:
+    """Data preprocessor that combines text and image processing."""
+    def __init__(self, config: dict):
+        self.config = config
+        self.image_preprocessor = ImagePreprocessor(config)
+        self.text_preprocessor = TextPreprocessor(config)
+
+        # Text encoder
+        self.text_encoder = None
+        self._init_text_encoder()
+
+    def _init_text_encoder(self):
+        """Initialize the text encoder."""
+        try:
+            from transformers import CLIPTextModel
+            model_name = self.config.get('preprocessing', {}).get('text_encoder', 'openai/clip-vit-base-patch32')
+            self.text_encoder = CLIPTextModel.from_pretrained(model_name)
+
+            # Freeze the parameters
+            for param in self.text_encoder.parameters():
+                param.requires_grad = False
+
+            # Switch to evaluation mode
+            self.text_encoder.eval()
+
+            print(f"Loaded text encoder: {model_name}")
+        except Exception as e:
+            print(f"Failed to load text encoder: {e}")
+
+    def encode_text(self, text: str) -> torch.Tensor:
+        """Encode a text string into an embedding tensor."""
+        if self.text_encoder is None:
+            raise ValueError("Text encoder is not initialized")
+
+        # Preprocess the text
+        inputs = self.text_preprocessor.preprocess_text(text)
+
+        # Encode
+        with torch.no_grad():
+            if 'input_ids' in inputs:
+                outputs = self.text_encoder(
+                    input_ids=inputs['input_ids'].unsqueeze(0),
+                    attention_mask=inputs['attention_mask'].unsqueeze(0) if 'attention_mask' in inputs else None
+                )
+                return outputs.last_hidden_state.squeeze(0)
+            else:
+                # Fall back to a random placeholder embedding (no tokenizer available)
+                return torch.randn(77, 768)  # default dimensions
+
+    def batch_encode_text(self, texts: List[str]) -> torch.Tensor:
+        """Encode a batch of text strings."""
+        if self.text_encoder is None:
+            raise ValueError("Text encoder is not initialized")
+
+        # Preprocess the texts
+        inputs = self.text_preprocessor.batch_preprocess(texts)
+
+        # Encode
+        with torch.no_grad():
+            if 'input_ids' in inputs:
+                outputs = self.text_encoder(
+                    input_ids=inputs['input_ids'],
+                    attention_mask=inputs.get('attention_mask', None)
+                )
+                return outputs.last_hidden_state
+            else:
+                # Fall back to random placeholder embeddings (no tokenizer available)
+                batch_size = len(texts)
+                return torch.randn(batch_size, 77, 768)
+
+    def preprocess_batch(self, batch: List[Dict]) -> Dict:
+        """Preprocess a batch of samples."""
+        images = [item['image'] for item in batch]
+        texts = [item['text'] for item in batch]
+
+        # Preprocess the images
+        image_tensors = self.image_preprocessor.batch_preprocess(images)
+
+        # Encode the texts
+        if self.text_encoder is not None:
+            text_embeddings = self.batch_encode_text(texts)
+        else:
+            text_embeddings = None
+
+        return {
+            'images': image_tensors,
+            'text_embeddings': text_embeddings,
+            'texts': texts,
+            'image_paths': [item.get('image_path', '') for item in batch]
+        }
+
+
+def test_preprocessing():
+    """Test the preprocessing pipeline."""
+    import yaml
+    from PIL import Image
+
+    # Create a test image
+    test_image = Image.new('RGB', (512, 512), color='red')
+
+    # Load the config
+    with open('configs/data/laion_filtered.yaml', 'r') as f:
+        config = yaml.safe_load(f)
+
+    # Test image preprocessing
+    image_preprocessor = ImagePreprocessor(config)
+    processed_image = image_preprocessor.preprocess_image(test_image)
+
+    print(f"Original image: {test_image.size}")
+    print(f"Processed image shape: {processed_image.shape}")
+
+    # Test text preprocessing
+    text_preprocessor = TextPreprocessor(config)
+    processed_text = text_preprocessor.preprocess_text("A red square image")
+
+    print(f"Text preprocessing result: {processed_text}")
+
+    # Test the combined data preprocessor
+    data_preprocessor = DataPreprocessor(config)
+
+    test_batch = [
+        {'image': test_image, 'text': "A red square"},
+        {'image': test_image, 'text': "A blue circle"}
+    ]
+
+    processed_batch = data_preprocessor.preprocess_batch(test_batch)
+
+    print(f"Batch image shape: {processed_batch['images'].shape}")
+    print(f"Text embedding shape: {processed_batch['text_embeddings'].shape if processed_batch['text_embeddings'] is not None else 'None'}")
+
+    return processed_batch
+
+
+if __name__ == '__main__':
+    test_preprocessing()
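Note on value ranges in `src/data/preprocessing.py`: with the default `mean` and `std` of 0.5, `T.Normalize` already maps `ToTensor` output from [0, 1] to [-1, 1], which is the same arithmetic as `preprocess_for_vae`. The sketch below uses plain Python on scalar values (it is not the repository's code, and the helper names are illustrative) to show what happens if both mappings are applied to the same values. Whether the training loop actually chains them is not visible in this file.

```python
def normalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Per-channel normalization arithmetic: (x - mean) / std."""
    return (x - mean) / std


def to_vae_range(x: float) -> float:
    """Same arithmetic as preprocess_for_vae: maps [0, 1] to [-1, 1]."""
    return x * 2.0 - 1.0


def from_vae_range(x: float) -> float:
    """Same arithmetic as postprocess_from_vae: maps [-1, 1] to [0, 1]."""
    return (x + 1.0) / 2.0


pixels = [0.0, 0.25, 0.5, 1.0]  # sample values in the ToTensor range [0, 1]
normalized = [normalize(p) for p in pixels]

# With mean = std = 0.5 the two mappings coincide.
same_mapping = normalized == [to_vae_range(p) for p in pixels]

# Applying the VAE mapping to already-normalized values leaves [-1, 1].
double_mapped = [to_vae_range(v) for v in normalized]

# The pre/post pair is an exact round trip on [0, 1] inputs.
round_trip = [from_vae_range(to_vae_range(p)) for p in pixels]

print(same_mapping, min(double_mapped), max(double_mapped))
```

If the transform output is fed to the VAE directly, calling `preprocess_for_vae` on it again would shift the range; the helper appears intended for tensors that are still in [0, 1].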
src/data/text_encoder.py
ADDED
@@ -0,0 +1,227 @@
+import torch
+import torch.nn as nn
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPConfig
+from typing import Optional, Tuple, List
+import os
+
+
+class LightTextEncoder(nn.Module):
+    """Lightweight text encoder."""
+    def __init__(self, config: dict):
+        super().__init__()
+        self.config = config
+
+        # Encoder parameters
+        self.vocab_size = config.get('vocab_size', 49408)
+        self.hidden_size = config.get('hidden_size', 512)
+        self.num_hidden_layers = config.get('num_hidden_layers', 8)
+        self.num_attention_heads = config.get('num_attention_heads', 8)
+        self.max_position_embeddings = config.get('max_position_embeddings', 77)
+
+        # Build the encoder
+        self.token_embedding = nn.Embedding(self.vocab_size, self.hidden_size)
+        self.position_embedding = nn.Embedding(self.max_position_embeddings, self.hidden_size)
+
+        # Transformer layers
+        self.layers = nn.ModuleList([
+            TransformerLayer(self.hidden_size, self.num_attention_heads)
+            for _ in range(self.num_hidden_layers)
+        ])
+
+        self.final_layer_norm = nn.LayerNorm(self.hidden_size)
+
+    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        # Embeddings
+        token_embeddings = self.token_embedding(input_ids)
+        position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0)
+        position_embeddings = self.position_embedding(position_ids)
+
+        hidden_states = token_embeddings + position_embeddings
+
+        # Transformer layers
+        for layer in self.layers:
+            hidden_states = layer(hidden_states, attention_mask)
+
+        # Final layer normalization
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        return hidden_states
+
+
+class TransformerLayer(nn.Module):
+    """Transformer layer."""
+    def __init__(self, hidden_size: int, num_heads: int):
+        super().__init__()
+        self.attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
+        self.attention_norm = nn.LayerNorm(hidden_size)
+
+        self.mlp = nn.Sequential(
+            nn.Linear(hidden_size, hidden_size * 4),
+            nn.GELU(),
+            nn.Linear(hidden_size * 4, hidden_size)
+        )
+        self.mlp_norm = nn.LayerNorm(hidden_size)
+
+    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        # Self-attention; attention_mask is assumed to be 1 for real tokens and 0 for padding, while key_padding_mask is True at padded positions
+        attn_output, _ = self.attention(x, x, x, key_padding_mask=None if attention_mask is None else attention_mask == 0)
+        x = self.attention_norm(x + attn_output)
+
+        # Feed-forward network
+        mlp_output = self.mlp(x)
+        x = self.mlp_norm(x + mlp_output)
+
+        return x
+
+
+class CLIPTextEncoderWrapper:
+    """Wrapper around the CLIP text encoder."""
+    def __init__(self, model_name: str = 'openai/clip-vit-base-patch32', device: str = 'cuda'):
+        self.model_name = model_name
+        self.device = device
+
+        # Load the tokenizer and model
+        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
+
+        # Load only the text model
+        self.text_model = CLIPTextModel.from_pretrained(model_name).to(device)
+
+        # Freeze the parameters
+        for param in self.text_model.parameters():
+            param.requires_grad = False
+
+        # Switch to evaluation mode
+        self.text_model.eval()
+
+        print(f"Loaded CLIP text encoder: {model_name}")
+
+    def encode(self, texts: List[str], return_tensors: str = 'pt') -> torch.Tensor:
+        """Encode a list of texts."""
+        # Tokenize
+        inputs = self.tokenizer(
+            texts,
+            padding='max_length',  # fixed length so separately encoded batches can be concatenated
+            truncation=True,
+            max_length=77,
+            return_tensors=return_tensors
+        )
+
+        # Move to the target device
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        # Encode
+        with torch.no_grad():
+            outputs = self.text_model(**inputs)
+            return outputs.last_hidden_state
+
+    def encode_batch(self, texts: List[str], batch_size: int = 32) -> torch.Tensor:
+        """Encode texts in mini-batches."""
+        all_embeddings = []
+
+        for i in range(0, len(texts), batch_size):
+            batch_texts = texts[i:i + batch_size]
+            batch_embeddings = self.encode(batch_texts)
+            all_embeddings.append(batch_embeddings.cpu())
+
+        return torch.cat(all_embeddings, dim=0)
+
+    def save_embeddings(self, texts: List[str], save_path: str):
+        """Encode texts and save the embeddings to disk."""
+        embeddings = self.encode_batch(texts)
+        torch.save(embeddings, save_path)
+        print(f"Embeddings saved to: {save_path}")
+
+
+class CachedTextEncoder:
+    """Text encoder with an in-memory and on-disk cache."""
+    def __init__(self, encoder, cache_dir: str = './text_cache'):
+        self.encoder = encoder
+        self.cache_dir = cache_dir
+        os.makedirs(cache_dir, exist_ok=True)
+
+        # In-memory cache
+        self.memory_cache = {}
+
+    def encode(self, text: str) -> torch.Tensor:
+        """Encode a single text, using the cache when possible."""
+        # Build the cache key
+        import hashlib
+        cache_key = hashlib.md5(text.encode()).hexdigest()
+
+        # Check the in-memory cache
+        if cache_key in self.memory_cache:
+            return self.memory_cache[cache_key]
+
+        # Check the on-disk cache
+        cache_file = os.path.join(self.cache_dir, f"{cache_key}.pt")
+        if os.path.exists(cache_file):
+            embedding = torch.load(cache_file)
+            self.memory_cache[cache_key] = embedding
+            return embedding
+
+        # Encode and cache
+        embedding = self.encoder.encode([text])[0]
+
+        # Store in the in-memory cache
+        self.memory_cache[cache_key] = embedding
+
+        # Store in the on-disk cache
+        torch.save(embedding, cache_file)
+
+        return embedding
+
+    def encode_batch(self, texts: List[str]) -> torch.Tensor:
+        """Encode a batch of texts one by one through the cache."""
+        embeddings = []
+
+        for text in texts:
+            embedding = self.encode(text)
+            embeddings.append(embedding.unsqueeze(0))
+
+        return torch.cat(embeddings, dim=0)
+
+
+def create_text_encoder(config: dict) -> CLIPTextEncoderWrapper:
+    """Create a text encoder, optionally wrapped with a cache."""
+    model_name = config.get('preprocessing', {}).get('tokenizer', 'openai/clip-vit-base-patch32')
+    device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
+
+    encoder = CLIPTextEncoderWrapper(model_name, device)
+
+    # Wrap with the caching layer if caching is enabled
+    if config.get('use_cache', True):
+        cache_dir = config.get('cache_dir', './data/text_cache')
+        encoder = CachedTextEncoder(encoder, cache_dir)
+
+    return encoder
+
+
+def test_text_encoder():
+    """Test the text encoder."""
+    config = {
+        'preprocessing': {
+            'tokenizer': 'openai/clip-vit-base-patch32'
+        },
+        'device': 'cuda' if torch.cuda.is_available() else 'cpu'
+    }
+
+    encoder = create_text_encoder(config)
+
+    # Test encoding (encode_batch accepts a list on both the plain and cached encoders)
+    texts = [
+        "A beautiful sunset over the mountains",
+        "A cute cat playing with a ball",
+        "An astronaut riding a horse on Mars"
+    ]
+
+    embeddings = encoder.encode_batch(texts)
+
+    print(f"Number of texts: {len(texts)}")
+    print(f"Embedding shape: {embeddings.shape}")
+    print(f"Embedding range: [{embeddings.min():.4f}, {embeddings.max():.4f}]")
+
+    return encoder, embeddings
+
+
+if __name__ == '__main__':
+    encoder, embeddings = test_text_encoder()
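Note on cache keys in `CachedTextEncoder`: the key is an MD5 digest of the prompt text alone, so a cache directory shared between two different encoders would return whichever embedding was written first. One possible variant, sketched here with a hypothetical `make_cache_key` helper that is not part of the repository, hashes the model name together with the text.

```python
import hashlib


def make_cache_key(text: str, model_name: str) -> str:
    """Hypothetical helper: derive a cache key from the model name and the text.

    A NUL separator keeps ("ab", "c") and ("a", "bc") from colliding.
    """
    payload = f"{model_name}\x00{text}".encode("utf-8")
    return hashlib.sha256(payload).hexdigest()


# The same prompt gets a different key for each encoder checkpoint.
key_b32 = make_cache_key("a red square", "openai/clip-vit-base-patch32")
key_l14 = make_cache_key("a red square", "openai/clip-vit-large-patch14")
print(key_b32 != key_l14)
```

Adopting this would require passing the model name into the cache wrapper, which the current constructor does not do; that wiring is left out of the sketch.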
src/inference/api.py
ADDED
@@ -0,0 +1,631 @@
+from fastapi import FastAPI, UploadFile, File, Form, HTTPException
+from fastapi.responses import JSONResponse, FileResponse, StreamingResponse
+from pydantic import BaseModel
+from typing import Optional, List, Dict, Any
+import torch
+import io
+from PIL import Image
+import base64
+import json
+import time
+from datetime import datetime
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+import uuid
+import os
+
+from .sampler import TextToImagePipeline, SamplerFactory
+from ..data.text_encoder import CLIPTextEncoderWrapper
+
+
+# Request/response models
+class TextToImageRequest(BaseModel):
+    prompt: str
+    negative_prompt: Optional[str] = ""
+    width: int = 512
+    height: int = 512
+    num_steps: int = 50
+    guidance_scale: float = 7.5
+    num_images: int = 1
+    seed: Optional[int] = None
+    sampler: str = "ddim"
+
+
+class ImageResponse(BaseModel):
+    images: List[str]  # base64-encoded images
+    metadata: Dict[str, Any]
+    request_id: str
+    generation_time: float
+
+
+class BatchRequest(BaseModel):
+    requests: List[TextToImageRequest]
+    priority: int = 0
+
+
+class StatusResponse(BaseModel):
+    status: str
+    queue_length: int
+    active_tasks: int
+    gpu_memory_usage: float
+    uptime: float
+
+
+# API application
+class LuminaAPI:
+    """Lumina API server."""
+    def __init__(
+        self,
+        model,
+        diffusion,
+        text_encoder,
+        vae_decoder=None,
+        host: str = "0.0.0.0",
+        port: int = 8000,
+        max_queue_size: int = 100,
+        max_workers: int = 2
+    ):
+        self.model = model
+        self.diffusion = diffusion
+        self.text_encoder = text_encoder
+        self.vae_decoder = vae_decoder
+        self.host = host
+        self.port = port
+
+        # Create the FastAPI application
+        self.app = FastAPI(
+            title="Lumina Image Generation API",
+            description="API for a lightweight image generation model",
+            version="1.0.0"
+        )
+
+        # Task queue
+        self.task_queue = asyncio.Queue(maxsize=max_queue_size)
+        self.executor = ThreadPoolExecutor(max_workers=max_workers)
+        self.active_tasks = 0
+
+        # Request history
+        self.request_history = []
+        self.max_history = 1000
+
+        # Statistics
+        self.start_time = time.time()
+        self.total_requests = 0
+        self.total_images = 0
+
+        # Initialize the pipeline
+        self.pipeline = TextToImagePipeline(
+            model=model,
+            diffusion=diffusion,
+            text_encoder=text_encoder,
+            vae_decoder=vae_decoder,
+            sampler_type="ddim"
+        )
+
+        # Register the routes
+        self._setup_routes()
+
+    def _setup_routes(self):
+        """Register the API routes."""
+
+        @self.app.get("/")
+        async def root():
+            return {
+                "message": "Lumina Image Generation API",
+                "version": "1.0.0",
+                "docs": "/docs",
+                "endpoints": [
+                    "/generate",
+                    "/batch_generate",
+                    "/status",
+                    "/health"
+                ]
+            }
+
+        @self.app.get("/health")
+        async def health_check():
+            """Health check."""
+            gpu_available = torch.cuda.is_available()
+            gpu_memory = torch.cuda.memory_allocated() / 1024**3 if gpu_available else 0
+
+            return {
+                "status": "healthy",
+                "gpu_available": gpu_available,
+                "gpu_memory_gb": gpu_memory,
+                "model_loaded": self.model is not None,
+                "text_encoder_loaded": self.text_encoder is not None
+            }
+
+        @self.app.get("/status")
+        async def get_status():
+            """Return the service status."""
+            uptime = time.time() - self.start_time
+
+            # GPU memory usage
+            if torch.cuda.is_available():
+                gpu_memory = torch.cuda.memory_allocated() / 1024**3
+            else:
+                gpu_memory = 0
+
+            return StatusResponse(
+                status="running",
+                queue_length=self.task_queue.qsize(),
+                active_tasks=self.active_tasks,
+                gpu_memory_usage=gpu_memory,
+                uptime=uptime
+            )
+
+        @self.app.post("/generate", response_model=ImageResponse)
+        async def generate_image(request: TextToImageRequest):
+            """Generate images for a single request."""
+            request_id = str(uuid.uuid4())
+
+            # Record the request
+            self.request_history.append({
+                "request_id": request_id,
+                "prompt": request.prompt,
+                "timestamp": datetime.now().isoformat()
+            })
+
+            # Cap the size of the history
+            if len(self.request_history) > self.max_history:
+                self.request_history = self.request_history[-self.max_history:]
+
+            # Generate the images
+            start_time = time.time()
+
+            try:
+                # Run the generation task in the thread pool
+                loop = asyncio.get_running_loop()
+                images = await loop.run_in_executor(
+                    self.executor,
+                    self._generate_sync,
+                    request
+                )
+
+                generation_time = time.time() - start_time
+
+                # Convert to base64
+                image_b64_list = []
+                for img_tensor in images:
+                    if isinstance(img_tensor, torch.Tensor):
+                        # Convert the tensor to a PIL image
+                        if img_tensor.dim() == 4:
+                            img_tensor = img_tensor.squeeze(0)
+
+                        # Scale to [0, 255] (assumes values in [0, 1])
+                        img_tensor = torch.clamp(img_tensor * 255, 0, 255).byte()
+
+                        # Convert to a NumPy array
+                        if img_tensor.shape[0] == 3:  # CHW layout
+                            img_array = img_tensor.permute(1, 2, 0).cpu().numpy()
+                        else:
+                            img_array = img_tensor.cpu().numpy()
+
+                        # Build the PIL image
+                        img = Image.fromarray(img_array)
+
+                        # Encode as base64
+                        buffered = io.BytesIO()
+                        img.save(buffered, format="PNG")
+                        img_b64 = base64.b64encode(buffered.getvalue()).decode()
+                        image_b64_list.append(img_b64)
+                    else:
+                        image_b64_list.append("")
+
+                # Update the statistics
+                self.total_requests += 1
+                self.total_images += len(images)
+
+                return ImageResponse(
+                    images=image_b64_list,
+                    metadata={
+                        "prompt": request.prompt,
+                        "negative_prompt": request.negative_prompt,
+                        "width": request.width,
+                        "height": request.height,
+                        "num_steps": request.num_steps,
+                        "guidance_scale": request.guidance_scale,
+                        "seed": request.seed,
+                        "sampler": request.sampler
+                    },
+                    request_id=request_id,
+                    generation_time=generation_time
+                )
+
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=str(e))
+
+        @self.app.post("/batch_generate")
+        async def batch_generate(batch_request: BatchRequest):
+            """Generate images for a batch of requests."""
+            request_ids = [str(uuid.uuid4()) for _ in batch_request.requests]
+            results = []
+
+            # Create one task per request
+            tasks = []
+            for req, req_id in zip(batch_request.requests, request_ids):
+                task = asyncio.create_task(self._process_single_request(req, req_id))
+                tasks.append(task)
+
+            # Wait for all tasks to finish
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+
+            # Sort the results into successes and failures
+            successful_results = []
+            failed_results = []
+
+            for result in results:
+                if isinstance(result, Exception):
+                    failed_results.append({"error": str(result)})
+                else:
+                    successful_results.append(result)
+
+            return {
+                "successful": successful_results,
+                "failed": failed_results,
+                "total_requests": len(batch_request.requests),
+                "successful_count": len(successful_results),
+                "failed_count": len(failed_results)
+            }
+
+        @self.app.get("/history")
+        async def get_history(limit: int = 50):
+            """Return the most recent requests."""
+            return self.request_history[-limit:]
+
        @self.app.post("/txt2img")  # Stable Diffusion-compatible endpoint
        async def txt2img(
            prompt: str = Form(...),
            negative_prompt: str = Form(""),
            width: int = Form(512),
            height: int = Form(512),
            steps: int = Form(50),
            cfg_scale: float = Form(7.5),
            seed: int = Form(-1)
        ):
            """Endpoint compatible with the Stable Diffusion WebUI."""
            request = TextToImageRequest(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                num_steps=steps,
                guidance_scale=cfg_scale,
                seed=seed if seed != -1 else None
            )

            response = await generate_image(request)

            # Return the first image
            if response.images:
                return StreamingResponse(
                    io.BytesIO(base64.b64decode(response.images[0])),
                    media_type="image/png"
                )
            else:
                raise HTTPException(status_code=500, detail="Generation failed")

        @self.app.get("/stats")
        async def get_stats():
            """Return server statistics."""
            uptime = time.time() - self.start_time

            return {
                "total_requests": self.total_requests,
                "total_images": self.total_images,
                "requests_per_minute": self.total_requests / (uptime / 60) if uptime > 0 else 0,
                "avg_generation_time": None,  # timing logic could be added here
                "uptime_seconds": uptime,
                "queue_size": self.task_queue.qsize(),
                "active_workers": self.active_tasks
            }

    def _generate_sync(self, request: TextToImageRequest) -> List[torch.Tensor]:
        """Generate images synchronously (runs in a separate thread)."""
        # Update the pipeline sampler
        self.pipeline.sampler = SamplerFactory.create_sampler(
            request.sampler,
            self.model,
            self.diffusion,
            request.num_steps
        )

        # Generate the images
        images = self.pipeline(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            height=request.height,
            width=request.width,
            num_inference_steps=request.num_steps,
            guidance_scale=request.guidance_scale,
            num_images=request.num_images,
            seed=request.seed,
            progress_bar=False
        )

        return images

    async def _process_single_request(self, request: TextToImageRequest, request_id: str) -> Dict:
        """Process a single request."""
        try:
            # Add the task to the queue
            await self.task_queue.put((request, request_id))

            # Update the active-task counter
            self.active_tasks += 1

            # Run the blocking generation call in the executor
            loop = asyncio.get_running_loop()
            images = await loop.run_in_executor(
                self.executor,
                self._generate_sync,
                request
            )

            # Convert the images
            image_b64_list = []
            for img_tensor in images:
                # Simplified conversion logic
                img_b64 = "placeholder"  # should actually be converted to base64
                image_b64_list.append(img_b64)

            # Update the active-task counter
            self.active_tasks -= 1

            return {
                "request_id": request_id,
                "images": image_b64_list,
                "success": True
            }

        except Exception as e:
            self.active_tasks -= 1
            return {
                "request_id": request_id,
                "error": str(e),
                "success": False
            }

    def run(self):
        """Run the API server."""
        import uvicorn
        print(f"Starting Lumina API server at http://{self.host}:{self.port}")
        print(f"API docs: http://{self.host}:{self.port}/docs")

        uvicorn.run(
            self.app,
            host=self.host,
            port=self.port,
            log_level="info"
        )


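The batch endpoint above awaits all per-request tasks with `asyncio.gather(..., return_exceptions=True)` and then sorts the results into successes and failures. A minimal, self-contained sketch of that pattern follows. `fake_generate`, `partition_results` and `run_batch` are hypothetical stand-ins, not part of this repository, and the worker is simulated. Because `_process_single_request` reports errors in-band through a `success` flag, the sketch checks that flag as well as raised exceptions.

```python
import asyncio
from typing import Any, Dict, List, Tuple


async def fake_generate(prompt: str) -> Dict[str, Any]:
    """Hypothetical stand-in for a per-request worker."""
    await asyncio.sleep(0)  # yield to the event loop, as a real worker would
    if not prompt:
        raise ValueError("empty prompt")
    if prompt == "bad":
        # Mirrors a worker that catches its own error and reports it in-band
        return {"prompt": prompt, "success": False, "error": "generation failed"}
    return {"prompt": prompt, "success": True}


def partition_results(results: List[Any]) -> Tuple[List[dict], List[dict]]:
    """Split gather() output into (successful, failed) lists."""
    successful, failed = [], []
    for result in results:
        if isinstance(result, Exception):
            failed.append({"error": str(result)})
        elif not result.get("success", False):
            failed.append(result)
        else:
            successful.append(result)
    return successful, failed


async def run_batch(prompts: List[str]) -> Tuple[List[dict], List[dict]]:
    """Run all prompts concurrently and partition the outcomes."""
    tasks = [asyncio.create_task(fake_generate(p)) for p in prompts]
    # return_exceptions=True keeps one failure from cancelling the whole batch
    results = await asyncio.gather(*tasks, return_exceptions=True)
    return partition_results(results)


successful, failed = asyncio.run(run_batch(["a cat", "", "bad"]))
```

`gather` preserves input order, so each entry in `results` lines up with the request that produced it. That makes it possible to attach request IDs to failures if needed.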
class SimpleWebUI:
    """Simple web UI (built with Gradio)."""
    def __init__(self, pipeline: TextToImagePipeline):
        self.pipeline = pipeline

    def create_interface(self):
        """Create the Gradio interface."""
        try:
            import gradio as gr
        except ImportError:
            print("Gradio is not installed; cannot create the web UI")
            return None

        def generate_image_ui(
            prompt,
            negative_prompt,
            width,
            height,
            num_steps,
            guidance_scale,
            seed,
            sampler
        ):
            """Generation callback for the UI."""
            # Note: `sampler` is received from the UI but is not passed to the pipeline here.
            # Set the seed (gr.Number yields a float, so cast to int first)
            seed = int(seed) if seed is not None else -1
            if seed > 0:
                torch.manual_seed(seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed(seed)

            # Generate the image
            images = self.pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                num_images=1,
                seed=seed if seed > 0 else None,
                progress_bar=False
            )

            # Convert to a PIL image
            if images:
                img_tensor = images[0]
                if isinstance(img_tensor, torch.Tensor):
                    if img_tensor.dim() == 4:
                        img_tensor = img_tensor.squeeze(0)

                    # Scale to [0, 255] (assumes values in [0, 1])
                    img_tensor = torch.clamp(img_tensor * 255, 0, 255).byte()

                    # Convert to a numpy array
                    if img_tensor.shape[0] == 3:  # CHW layout
                        img_array = img_tensor.permute(1, 2, 0).cpu().numpy()
                    else:
                        img_array = img_tensor.cpu().numpy()

                    # Convert to a PIL image
                    from PIL import Image
                    img = Image.fromarray(img_array)
                    return img

            return None

        # Build the interface
        with gr.Blocks(title="Lumina Image Generator") as interface:
            gr.Markdown("# 🎨 Lumina - Lightweight Image Generation")
            gr.Markdown("A diffusion-based text-to-image generation system")

            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe the image you want...",
                        lines=3
                    )
                    negative_prompt = gr.Textbox(
                        label="Negative prompt",
                        placeholder="Content you do not want in the image...",
                        lines=2
                    )

                    with gr.Row():
                        width = gr.Slider(
                            minimum=256,
                            maximum=1024,
                            value=512,
                            step=64,
                            label="Width"
                        )
                        height = gr.Slider(
                            minimum=256,
                            maximum=1024,
                            value=512,
                            step=64,
                            label="Height"
                        )

                    with gr.Row():
                        num_steps = gr.Slider(
                            minimum=1,
                            maximum=100,
                            value=30,
                            step=1,
                            label="Sampling steps"
                        )
                        guidance_scale = gr.Slider(
                            minimum=1.0,
                            maximum=20.0,
                            value=7.5,
                            step=0.5,
                            label="Guidance scale"
                        )

                    with gr.Row():
                        seed = gr.Number(
                            value=-1,
                            label="Random seed (-1 for random)"
                        )
                        sampler = gr.Dropdown(
                            choices=["ddim", "dpm", "lcm"],
                            value="ddim",
                            label="Sampler"
                        )

                    generate_btn = gr.Button("Generate image", variant="primary")

                with gr.Column():
                    output_image = gr.Image(
                        label="Generated image",
                        type="pil"
                    )

            # Examples
            gr.Markdown("### Example prompts")
            examples = gr.Examples(
                examples=[
                    ["A beautiful sunset over mountains, digital art", "", 512, 512, 30, 7.5, -1],
                    ["A cute cat playing with a ball of yarn", "blurry, deformed", 512, 512, 25, 8.0, -1],
                    ["An astronaut riding a horse on Mars", "cartoon, anime", 512, 512, 40, 7.0, -1]
                ],
                inputs=[prompt, negative_prompt, width, height, num_steps, guidance_scale, seed]
            )

            # Event handling
            generate_btn.click(
                fn=generate_image_ui,
                inputs=[prompt, negative_prompt, width, height, num_steps, guidance_scale, seed, sampler],
                outputs=output_image
            )

        return interface

    def launch(self, share: bool = False, server_name: str = "0.0.0.0", server_port: int = 7860):
        """Launch the web UI."""
        interface = self.create_interface()
        if interface:
            interface.launch(
                share=share,
                server_name=server_name,
                server_port=server_port
            )


def create_api_server(config: dict, model, diffusion, text_encoder, vae_decoder=None):
    """Create the API server."""
    # Determine the host and port
    host = config.get('host', '0.0.0.0')
    port = config.get('port', 8000)

    # Create the API server
    api_server = LuminaAPI(
        model=model,
        diffusion=diffusion,
        text_encoder=text_encoder,
        vae_decoder=vae_decoder,
        host=host,
        port=port,
        max_queue_size=config.get('max_queue_size', 100),
        max_workers=config.get('max_workers', 2)
    )

    return api_server


def test_api():
    """Test the API."""
    import torch.nn as nn

    # Create mock components
    class MockModel(nn.Module):
        def forward(self, x, t, context):
            return torch.randn_like(x)

    class MockDiffusion:
        pass

    class MockTextEncoder:
        def encode(self, texts):
            return torch.randn(len(texts), 77, 768)

    model = MockModel()
    diffusion = MockDiffusion()
    text_encoder = MockTextEncoder()

    # Create the API server
    api = LuminaAPI(
        model=model,
        diffusion=diffusion,
        text_encoder=text_encoder
    )

    print("API server created successfully")
    print("Endpoints:")
    print(" POST /generate - generate images")
    print(" GET /health - health check")
    print(" GET /status - status information")

    return api


if __name__ == '__main__':
    # Test the API
    api = test_api()

    # Note: call api.run() to actually start the server
src/inference/optimization.py
ADDED
|
@@ -0,0 +1,427 @@
import torch
import torch.nn as nn
from typing import Optional, Dict, Any
import onnx
import onnxruntime as ort
import os


class ModelOptimizer:
    """Model optimizer."""
    def __init__(self, model: nn.Module, device: str = 'cuda'):
        self.model = model
        self.device = device
        self.optimized_model = None

    def optimize_for_inference(self, use_jit: bool = True, use_cuda_graph: bool = False):
        """Optimize the model for inference."""
        self.model.eval()

        # Apply a series of optimizations
        if use_jit:
            self._jit_compile()

        if use_cuda_graph and torch.cuda.is_available():
            self._capture_cuda_graph()

        # Apply the remaining optimizations
        self._apply_inference_optimizations()

        return self.optimized_model or self.model

    def _jit_compile(self):
        """Compile the model with TorchScript."""
        try:
            # Create example inputs
            example_input = torch.randn(1, 4, 64, 64, device=self.device)
            example_timestep = torch.tensor([500], device=self.device)
            example_context = torch.randn(1, 77, 768, device=self.device)

            # Trace the model (torch.jit.trace records one concrete forward pass)
            scripted_model = torch.jit.trace(
                self.model,
                (example_input, example_timestep, example_context),
                check_trace=False
            )

            self.optimized_model = scripted_model
            print("Model compiled with TorchScript")
        except Exception as e:
            print(f"TorchScript compilation failed: {e}")

    def _capture_cuda_graph(self):
        """Capture a CUDA graph (for repeated inference with fixed shapes)."""
        if not torch.cuda.is_available():
            return

        # Create static inputs; every replay reuses these exact buffers.
        # Note: they are float16, so the model must already be in half precision.
        static_input = torch.randn(1, 4, 64, 64, device='cuda', dtype=torch.float16)
        static_timestep = torch.tensor([500], device='cuda')
        static_context = torch.randn(1, 77, 768, device='cuda', dtype=torch.float16)

        # Warm up
        with torch.no_grad():
            for _ in range(3):
                _ = self.model(static_input, static_timestep, static_context)

        # Capture the graph (without autograd bookkeeping, since this is inference only)
        graph = torch.cuda.CUDAGraph()
        with torch.no_grad(), torch.cuda.graph(graph):
            static_output = self.model(static_input, static_timestep, static_context)

        # Create a wrapper that copies new inputs into the static buffers and replays
        def graph_executor(input_tensor, timestep, context):
            static_input.copy_(input_tensor)
            static_timestep.copy_(timestep)
            static_context.copy_(context)
            graph.replay()
            return static_output.clone()

        self.optimized_model = graph_executor
        print("CUDA graph captured")

    def _apply_inference_optimizations(self):
        """Apply inference optimizations."""
        # Switch to evaluation mode
        self.model.eval()

        # Fuse operations with torch.compile (if available)
        if hasattr(torch, 'compile'):
            try:
                self.model = torch.compile(self.model, mode='max-autotune')
                print("Model optimized with torch.compile")
            except Exception as e:
                print(f"torch.compile failed, continuing without it: {e}")

        # Use half precision
        if self.device == 'cuda':
            self.model.half()
            print("Model converted to half precision")

    def quantize(self, quantization_mode: str = 'dynamic'):
        """Quantize the model."""
        if quantization_mode == 'dynamic':
            # Dynamic quantization. PyTorch's dynamic mode covers nn.Linear
            # (and recurrent layers) but not nn.Conv2d, so only Linear is listed.
            quantized_model = torch.quantization.quantize_dynamic(
                self.model,
                {nn.Linear},
                dtype=torch.qint8
            )
            self.optimized_model = quantized_model
            print("Model dynamically quantized")

        elif quantization_mode == 'static':
            # Static quantization requires calibration data
            print("Static quantization requires calibration data; not implemented yet")

        else:
            raise ValueError(f"Unknown quantization mode: {quantization_mode}")

    def prune(self, pruning_rate: float = 0.2):
        """Prune the model."""
        from torch.nn.utils import prune

        # Prune the linear and convolutional layers
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                prune.l1_unstructured(module, name='weight', amount=pruning_rate)
                # Make the pruning permanent by removing the reparametrization
                prune.remove(module, 'weight')

        print(f"Model pruned, pruning rate: {pruning_rate}")

    def get_model_size(self) -> Dict[str, float]:
        """Return the model's parameter counts and size."""
        # Count the parameters
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

        # Compute the model size (MB)
        param_size = 0
        for param in self.model.parameters():
            param_size += param.nelement() * param.element_size()

        buffer_size = 0
        for buffer in self.model.buffers():
            buffer_size += buffer.nelement() * buffer.element_size()

        size_mb = (param_size + buffer_size) / 1024**2

        return {
            'total_params': total_params,
            'trainable_params': trainable_params,
            'size_mb': size_mb
        }


class ONNXExporter:
    """ONNX exporter."""
    def __init__(self, model: nn.Module):
        self.model = model

    def export(
        self,
        output_path: str,
        input_shape: tuple = (1, 4, 64, 64),
        opset_version: int = 14,
        dynamic_axes: Optional[Dict] = None
    ):
        """Export the model to ONNX format."""
        # Switch to evaluation mode
        self.model.eval()

        # Create example inputs whose batch size matches input_shape
        batch_size = input_shape[0]
        dummy_input = torch.randn(*input_shape)
        dummy_timestep = torch.tensor([500] * batch_size)
        dummy_context = torch.randn(batch_size, 77, 768)

        # Default dynamic axes
        if dynamic_axes is None:
            dynamic_axes = {
                'input': {0: 'batch_size'},
                'timestep': {0: 'batch_size'},
                'context': {0: 'batch_size'},
                'output': {0: 'batch_size'}
            }

        # Export
        torch.onnx.export(
            self.model,
            (dummy_input, dummy_timestep, dummy_context),
            output_path,
            input_names=['input', 'timestep', 'context'],
            output_names=['output'],
            dynamic_axes=dynamic_axes,
            opset_version=opset_version,
            do_constant_folding=True,
            verbose=False
        )

        print(f"Model exported to ONNX: {output_path}")

        # Validate the ONNX model
        self._validate_onnx(output_path)

    def _validate_onnx(self, onnx_path: str):
        """Validate the ONNX model."""
        try:
            onnx_model = onnx.load(onnx_path)
            onnx.checker.check_model(onnx_model)
            print("ONNX model validated successfully")
        except Exception as e:
            print(f"ONNX model validation failed: {e}")

    def optimize_onnx(self, onnx_path: str, optimized_path: str):
        """Optimize an ONNX model."""
        try:
            # Optimize with ONNX Runtime
            sess_options = ort.SessionOptions()
            sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
            # ONNX Runtime writes the optimized graph to this path at session creation
            sess_options.optimized_model_filepath = optimized_path

            # Creating the session runs the optimizations and saves the result
            ort.InferenceSession(onnx_path, sess_options, providers=['CPUExecutionProvider'])

            print(f"ONNX model optimized: {optimized_path}")
        except Exception as e:
            print(f"ONNX optimization failed: {e}")


class MemoryEfficientInference:
    """Memory-efficient inference."""
    def __init__(self, model: nn.Module, chunk_size: int = 32):
        self.model = model
        self.chunk_size = chunk_size

    def chunked_inference(self, x: torch.Tensor, t: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """Run inference chunk by chunk to reduce memory use."""
        B, C, H, W = x.shape
        output = torch.zeros_like(x)

        # Process the input in spatial chunks
        for i in range(0, H, self.chunk_size):
            for j in range(0, W, self.chunk_size):
                # Extract the chunk
                chunk = x[:, :, i:i+self.chunk_size, j:j+self.chunk_size]

                # Run inference
                with torch.no_grad():
                    chunk_output = self.model(chunk, t, context)

                # Store the result
                output[:, :, i:i+self.chunk_size, j:j+self.chunk_size] = chunk_output

        return output

    def tiled_inference(self, x: torch.Tensor, t: torch.Tensor, context: torch.Tensor, tile_size: int = 512) -> torch.Tensor:
        """Run tiled inference, intended for large images."""
        B, C, H, W = x.shape

        # If the image is small enough, run inference directly
        if H <= tile_size and W <= tile_size:
            with torch.no_grad():
                return self.model(x, t, context)

        # Compute the number of tiles (ceiling division)
        n_tiles_h = (H + tile_size - 1) // tile_size
        n_tiles_w = (W + tile_size - 1) // tile_size

        output = torch.zeros_like(x)

        # Process each tile
        for i in range(n_tiles_h):
            for j in range(n_tiles_w):
                # Compute the tile position
                h_start = i * tile_size
                w_start = j * tile_size
                h_end = min(h_start + tile_size, H)
                w_end = min(w_start + tile_size, W)

                # Extract the tile
                tile = x[:, :, h_start:h_end, w_start:w_end]

                # Run inference
                with torch.no_grad():
                    tile_output = self.model(tile, t, context)

                # Store the result
                output[:, :, h_start:h_end, w_start:w_end] = tile_output

        return output


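The tile layout used by `tiled_inference` can be checked without a model. The sketch below reproduces only the index arithmetic (ceiling division for the tile count, clamped end indices for edge tiles). `tile_bounds` is a hypothetical helper introduced here for illustration and is not part of this repository.

```python
from typing import List, Tuple


def tile_bounds(height: int, width: int, tile_size: int) -> List[Tuple[int, int, int, int]]:
    """Return (h_start, h_end, w_start, w_end) for every tile, row by row."""
    n_tiles_h = (height + tile_size - 1) // tile_size  # ceiling division
    n_tiles_w = (width + tile_size - 1) // tile_size
    bounds = []
    for i in range(n_tiles_h):
        for j in range(n_tiles_w):
            h_start = i * tile_size
            w_start = j * tile_size
            # Edge tiles are clamped to the image border, so they may be smaller
            h_end = min(h_start + tile_size, height)
            w_end = min(w_start + tile_size, width)
            bounds.append((h_start, h_end, w_start, w_end))
    return bounds


tiles = tile_bounds(height=1280, width=1024, tile_size=512)
covered_area = sum((h1 - h0) * (w1 - w0) for h0, h1, w0, w1 in tiles)
```

The tiles here do not overlap, so together they cover each pixel exactly once. A likely consequence is that each tile is denoised without seeing its neighbours, which may leave visible seams. Overlapping tiles with blending are a common alternative, though that is a design change rather than something this file does.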
class InferenceBenchmark:
    """Inference benchmark."""
    def __init__(self, model: nn.Module, device: str = 'cuda'):
        self.model = model
        self.device = device

    def benchmark(
        self,
        input_shape: tuple = (1, 4, 64, 64),
        num_iterations: int = 100,
        warmup_iterations: int = 10
    ) -> Dict[str, float]:
        """Run the benchmark."""
        # Prepare inputs whose batch size matches input_shape
        batch_size = input_shape[0]
        x = torch.randn(*input_shape, device=self.device)
        t = torch.tensor([500] * batch_size, device=self.device)
        context = torch.randn(batch_size, 77, 768, device=self.device)

        # Warm up
        print("Warming up...")
        with torch.no_grad():
            for _ in range(warmup_iterations):
                _ = self.model(x, t, context)

        # Synchronize so warm-up work does not leak into the timings
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # Benchmark
        print("Running benchmark...")
        import time

        times = []

        for i in range(num_iterations):
            # perf_counter is a monotonic, high-resolution clock
            start_time = time.perf_counter()

            with torch.no_grad():
                _ = self.model(x, t, context)

            # Wait for queued GPU work before stopping the clock
            if torch.cuda.is_available():
                torch.cuda.synchronize()

            end_time = time.perf_counter()
            times.append(end_time - start_time)

        # Statistics
        times = torch.tensor(times)

        stats = {
            'mean_ms': times.mean().item() * 1000,
            'std_ms': times.std().item() * 1000,
            'min_ms': times.min().item() * 1000,
            'max_ms': times.max().item() * 1000,
            'fps': 1 / times.mean().item(),
            'num_iterations': num_iterations
        }

        # Print the results
        print("\n" + "="*50)
        print("Inference benchmark results:")
        print(f"Mean inference time: {stats['mean_ms']:.2f} ms")
        print(f"Standard deviation: {stats['std_ms']:.2f} ms")
        print(f"Min inference time: {stats['min_ms']:.2f} ms")
        print(f"Max inference time: {stats['max_ms']:.2f} ms")
        print(f"FPS: {stats['fps']:.2f}")
        print("="*50)

        return stats


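The measurement loop in `InferenceBenchmark.benchmark` (warm up, time each call, then summarise) is independent of PyTorch. The sketch below shows the same structure using only the standard library. `time_callable` is a hypothetical helper, and the workload is an arbitrary CPU-bound lambda standing in for a model forward pass. On a GPU, a device synchronization before reading the clock would still be needed, as the class above does.

```python
import statistics
import time
from typing import Callable, Dict


def time_callable(fn: Callable[[], object], num_iterations: int = 20, warmup: int = 3) -> Dict[str, float]:
    """Time repeated calls to fn and return summary statistics."""
    # Warm-up calls are executed but not measured
    for _ in range(warmup):
        fn()

    times = []
    for _ in range(num_iterations):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        times.append(time.perf_counter() - start)

    mean = statistics.mean(times)
    return {
        "mean_ms": mean * 1000,
        "std_ms": statistics.stdev(times) * 1000,  # sample standard deviation
        "min_ms": min(times) * 1000,
        "max_ms": max(times) * 1000,
        "fps": 1.0 / mean,
        "num_iterations": num_iterations,
    }


stats = time_callable(lambda: sum(i * i for i in range(10_000)))
```

The minimum is often a steadier figure than the mean when comparing runs, because outliers from scheduling noise only ever push timings upward.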
def optimize_model_for_p4(model: nn.Module) -> nn.Module:
    """Optimize the model for the P4 GPU."""
    optimizer = ModelOptimizer(model)

    # Get the model size
    size_info = optimizer.get_model_size()
    print(f"Model size before optimization: {size_info['size_mb']:.2f} MB")

    # Apply the optimizations
    optimized_model = optimizer.optimize_for_inference(
        use_jit=True,
        use_cuda_graph=False  # the P4 may not support this
    )

    # Quantize (optional)
    if size_info['size_mb'] > 500:  # quantize if the model is larger than 500 MB
        optimizer.quantize('dynamic')
        # Return the quantized model rather than the pre-quantization one
        optimized_model = optimizer.optimized_model

    # Get the model size after optimization
    size_info_after = optimizer.get_model_size()
    print(f"Model size after optimization: {size_info_after['size_mb']:.2f} MB")
    print(f"Compression ratio: {size_info['size_mb'] / size_info_after['size_mb']:.2f}x")

    return optimized_model


def test_optimization():
    """Test the optimization utilities."""
    import torch.nn as nn

    # Create a mock model
    class MockModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(4, 64, 3, padding=1)
            self.conv2 = nn.Conv2d(64, 4, 3, padding=1)

        def forward(self, x, t, context):
            x = self.conv1(x)
            x = nn.functional.relu(x)
            x = self.conv2(x)
            return x

    model = MockModel()

    # Test the optimizer
    optimizer = ModelOptimizer(model)
    optimized_model = optimizer.optimize_for_inference()

    # Test the benchmark
    benchmark = InferenceBenchmark(model)
    stats = benchmark.benchmark(num_iterations=10)

    # Test the ONNX export
    exporter = ONNXExporter(model)
    exporter.export('./test_model.onnx', input_shape=(1, 4, 64, 64))

    return optimized_model, stats


if __name__ == '__main__':
    optimized_model, stats = test_optimization()
src/inference/sampler.py
ADDED
|
@@ -0,0 +1,428 @@
import torch
import torch.nn as nn
from typing import Optional, Tuple, List, Union
import numpy as np
from tqdm import tqdm
import math


class DDIMSampler:
    """DDIM sampler."""
    def __init__(self, model: nn.Module, diffusion, num_inference_steps: int = 50):
        self.model = model
        self.diffusion = diffusion
        self.num_inference_steps = num_inference_steps

        # Set the timesteps
        self.set_timesteps(num_inference_steps)

    def set_timesteps(self, num_inference_steps: int):
        """Set the inference timesteps."""
        self.num_inference_steps = num_inference_steps

        # Select the timesteps
        if self.diffusion.num_train_timesteps == num_inference_steps:
            timesteps = np.arange(0, self.diffusion.num_train_timesteps)
        else:
            step_ratio = self.diffusion.num_train_timesteps // num_inference_steps
            # Multiply indices by the stride so that exactly num_inference_steps
            # entries are produced, even when the division is not exact
            timesteps = np.arange(0, num_inference_steps) * step_ratio

        # Reverse so that sampling runs from the noisiest timestep down to 0
        self.timesteps = torch.from_numpy(timesteps).long().flip(0)

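The timestep selection in `set_timesteps` amounts to taking every `step_ratio`-th training timestep and reversing the list. A small sketch in plain Python, with no NumPy or PyTorch, makes the resulting schedule easy to inspect. `ddim_timesteps` is a hypothetical helper written for this illustration. It builds exactly `num_inference_steps` entries by multiplying indices by the stride, which also holds when the training step count is not evenly divisible.

```python
from typing import List


def ddim_timesteps(num_train_timesteps: int, num_inference_steps: int) -> List[int]:
    """Return evenly strided timesteps in descending order."""
    step_ratio = num_train_timesteps // num_inference_steps
    # One entry per inference step: 0, step_ratio, 2 * step_ratio, ...
    timesteps = [i * step_ratio for i in range(num_inference_steps)]
    # Sampling runs from the noisiest timestep down to 0
    return timesteps[::-1]


schedule = ddim_timesteps(num_train_timesteps=1000, num_inference_steps=50)
uneven = ddim_timesteps(num_train_timesteps=1000, num_inference_steps=30)
```

With 1000 training steps and 50 inference steps the stride is 20, so the schedule starts at 980 and ends at 0. With 30 inference steps the stride is 33 and the schedule starts at 957.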
| 32 |
+
@torch.no_grad()
|
| 33 |
+
def step(
|
| 34 |
+
self,
|
| 35 |
+
model_output: torch.Tensor,
|
| 36 |
+
timestep: int,
|
| 37 |
+
sample: torch.Tensor,
|
| 38 |
+
eta: float = 0.0,
|
| 39 |
+
use_clipped_model_output: bool = False
|
| 40 |
+
) -> torch.Tensor:
|
| 41 |
+
"""DDIM单步采样"""
|
| 42 |
+
# 获取当前和上一个时间步
|
| 43 |
+
prev_timestep = timestep - self.diffusion.num_train_timesteps // self.num_inference_steps
|
| 44 |
+
|
| 45 |
+
# 提取alpha参数
|
| 46 |
+
alpha_prod_t = self.diffusion.extract(self.diffusion.alphas_cumprod, timestep, sample.shape)
|
| 47 |
+
alpha_prod_t_prev = self.diffusion.extract(
|
| 48 |
+
self.diffusion.alphas_cumprod,
|
| 49 |
+
prev_timestep,
|
| 50 |
+
sample.shape
|
| 51 |
+
) if prev_timestep >= 0 else torch.ones_like(alpha_prod_t)
|
| 52 |
+
|
| 53 |
+
# 根据预测类型处理模型输出
|
| 54 |
+
if self.diffusion.prediction_type == "epsilon":
|
| 55 |
+
pred_original_sample = (sample - (1 - alpha_prod_t) ** 0.5 * model_output) / alpha_prod_t ** 0.5
|
| 56 |
+
pred_epsilon = model_output
|
| 57 |
+
elif self.diffusion.prediction_type == "sample":
|
| 58 |
+
pred_original_sample = model_output
|
| 59 |
+
pred_epsilon = (sample - alpha_prod_t ** 0.5 * pred_original_sample) / (1 - alpha_prod_t) ** 0.5
|
| 60 |
+
elif self.diffusion.prediction_type == "v_prediction":
|
| 61 |
+
pred_original_sample = (alpha_prod_t ** 0.5) * sample - (1 - alpha_prod_t) ** 0.5 * model_output
|
| 62 |
+
pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (1 - alpha_prod_t) ** 0.5 * sample
|
| 63 |
+
else:
|
| 64 |
+
raise ValueError(f"Unsupported prediction type: {self.diffusion.prediction_type}")
|
| 65 |
+
|
| 66 |
+
# Clip the predicted original sample
|
| 67 |
+
if use_clipped_model_output:
|
| 68 |
+
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
|
| 69 |
+
|
| 70 |
+
# Compute the variance for x_{t-1}
|
| 71 |
+
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
| 72 |
+
std_dev_t = eta * variance ** 0.5
|
| 73 |
+
|
| 74 |
+
# Use stochastic sampling when eta > 0
|
| 75 |
+
if eta > 0:
|
| 76 |
+
noise = torch.randn_like(model_output)
|
| 77 |
+
variance = std_dev_t ** 2
|
| 78 |
+
else:
|
| 79 |
+
noise = 0
|
| 80 |
+
variance = 0
|
| 81 |
+
|
| 82 |
+
# Compute the mean of x_{t-1}
|
| 83 |
+
pred_sample_direction = (1 - alpha_prod_t_prev - variance) ** 0.5 * pred_epsilon
|
| 84 |
+
prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
|
| 85 |
+
|
| 86 |
+
# Add noise
|
| 87 |
+
if eta > 0:
|
| 88 |
+
prev_sample = prev_sample + std_dev_t * noise
|
| 89 |
+
|
| 90 |
+
return prev_sample
|
| 91 |
+
|
| 92 |
+
@torch.no_grad()
|
| 93 |
+
def sample(
|
| 94 |
+
self,
|
| 95 |
+
prompt_embeds: torch.Tensor,
|
| 96 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 97 |
+
height: int = 512,
|
| 98 |
+
width: int = 512,
|
| 99 |
+
num_images_per_prompt: int = 1,
|
| 100 |
+
guidance_scale: float = 7.5,
|
| 101 |
+
eta: float = 0.0,
|
| 102 |
+
generator: Optional[torch.Generator] = None,
|
| 103 |
+
progress_bar: bool = True
|
| 104 |
+
) -> torch.Tensor:
|
| 105 |
+
"""Generate samples."""
|
| 106 |
+
# Put the model in evaluation mode
|
| 107 |
+
self.model.eval()
|
| 108 |
+
|
| 109 |
+
# Batch size
|
| 110 |
+
batch_size = prompt_embeds.shape[0]
|
| 111 |
+
|
| 112 |
+
# Initialize the latents
|
| 113 |
+
latents = torch.randn(
|
| 114 |
+
(batch_size * num_images_per_prompt, self.model.in_channels, height // 8, width // 8),
|
| 115 |
+
device=prompt_embeds.device,
|
| 116 |
+
generator=generator
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
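# Classifier-free guidance needs an unconditional (negative) embedding to contrast against;
# without one, sampling runs unguided instead of failing on a batch-size mismatch
do_classifier_free_guidance = guidance_scale > 1.0 and negative_prompt_embeds is not None

# Repeat the embeddings so the conditioning batch matches the latent batch
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
    negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    # Unconditional embeddings come first, matching the chunk(2) order inside the loop.
    # The latents themselves are duplicated once per step inside the loop, not here.
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])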
|
| 127 |
+
|
| 128 |
+
# Sampling loop
|
| 129 |
+
timesteps = self.timesteps.to(latents.device)
|
| 130 |
+
|
| 131 |
+
if progress_bar:
|
| 132 |
+
timesteps = tqdm(timesteps, desc="DDIM Sampling")
|
| 133 |
+
|
| 134 |
+
for i, t in enumerate(timesteps):
|
| 135 |
+
# Duplicate the latents to match the guidance batch size
|
| 136 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 137 |
+
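if hasattr(self.diffusion, "scale_model_input"):  # DiffusionProcess in src/models/diffusion.py does not define this hook
    latent_model_input = self.diffusion.scale_model_input(latent_model_input, t)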
|
| 138 |
+
|
| 139 |
+
# Predict the noise
|
| 140 |
+
noise_pred = self.model(latent_model_input, t, prompt_embeds)
|
| 141 |
+
|
| 142 |
+
# Apply classifier-free guidance
|
| 143 |
+
if do_classifier_free_guidance:
|
| 144 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 145 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 146 |
+
|
| 147 |
+
# Compute the previous sample
|
| 148 |
+
latents = self.step(noise_pred, t, latents, eta)
|
| 149 |
+
|
| 150 |
+
return latents
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class DPMSampler:
|
| 154 |
+
"""DPM sampler (faster)."""
|
| 155 |
+
def __init__(self, model: nn.Module, diffusion, num_inference_steps: int = 20):
|
| 156 |
+
self.model = model
|
| 157 |
+
self.diffusion = diffusion
|
| 158 |
+
self.num_inference_steps = num_inference_steps
|
| 159 |
+
|
| 160 |
+
@torch.no_grad()
|
| 161 |
+
def sample(
|
| 162 |
+
self,
|
| 163 |
+
prompt_embeds: torch.Tensor,
|
| 164 |
+
height: int = 512,
|
| 165 |
+
width: int = 512,
|
| 166 |
+
guidance_scale: float = 7.5,
|
| 167 |
+
progress_bar: bool = True
|
| 168 |
+
) -> torch.Tensor:
|
| 169 |
+
"""Run DPM sampling."""
|
| 170 |
+
self.model.eval()
|
| 171 |
+
|
| 172 |
+
# Initialize the latents
|
| 173 |
+
latents = torch.randn(
|
| 174 |
+
(1, self.model.in_channels, height // 8, width // 8),
|
| 175 |
+
device=prompt_embeds.device
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
# Simplified DPM sampling
|
| 179 |
+
timesteps = torch.linspace(1, 0, self.num_inference_steps + 1, device=latents.device)
|
| 180 |
+
|
| 181 |
+
if progress_bar:
|
| 182 |
+
timesteps_iter = tqdm(enumerate(timesteps[:-1]), total=len(timesteps)-1, desc="DPM Sampling")
|
| 183 |
+
else:
|
| 184 |
+
timesteps_iter = enumerate(timesteps[:-1])
|
| 185 |
+
|
| 186 |
+
for i, t in timesteps_iter:
|
| 187 |
+
# Predict the noise
|
| 188 |
+
noise_pred = self.model(latents, t.unsqueeze(0) * 999, prompt_embeds)
|
| 189 |
+
|
| 190 |
+
# Apply guidance
|
| 191 |
+
if guidance_scale > 1.0:
|
| 192 |
+
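# Simple guidance: scales the prediction directly rather than contrasting conditional and unconditional outputs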
|
| 193 |
+
noise_pred = noise_pred * guidance_scale
|
| 194 |
+
|
| 195 |
+
# DPM update step
|
| 196 |
+
dt = timesteps[i + 1] - t
|
| 197 |
+
latents = latents + dt * noise_pred
|
| 198 |
+
|
| 199 |
+
return latents
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class LCMSampler:
|
| 203 |
+
"""LCM (Latent Consistency Model) sampler; very fast."""
|
| 204 |
+
def __init__(self, model: nn.Module, diffusion, num_inference_steps: int = 4):
|
| 205 |
+
self.model = model
|
| 206 |
+
self.diffusion = diffusion
|
| 207 |
+
self.num_inference_steps = num_inference_steps
|
| 208 |
+
|
| 209 |
+
# LCM-specific parameters
|
| 210 |
+
self.c_skip = 1.0
|
| 211 |
+
self.c_out = 1.0
|
| 212 |
+
self.c_in = 1.0
|
| 213 |
+
self.c_noise = 1.0
|
| 214 |
+
|
| 215 |
+
@torch.no_grad()
|
| 216 |
+
def sample(
|
| 217 |
+
self,
|
| 218 |
+
prompt_embeds: torch.Tensor,
|
| 219 |
+
height: int = 512,
|
| 220 |
+
width: int = 512,
|
| 221 |
+
guidance_scale: float = 7.5,
|
| 222 |
+
progress_bar: bool = True
|
| 223 |
+
) -> torch.Tensor:
|
| 224 |
+
"""Run LCM sampling (very fast; needs only 4-8 steps)."""
|
| 225 |
+
self.model.eval()
|
| 226 |
+
|
| 227 |
+
# Initialize the latents
|
| 228 |
+
latents = torch.randn(
|
| 229 |
+
(1, self.model.in_channels, height // 8, width // 8),
|
| 230 |
+
device=prompt_embeds.device
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
# LCM sampling loop
|
| 234 |
+
timesteps = torch.linspace(1, 0, self.num_inference_steps + 1, device=latents.device)
|
| 235 |
+
|
| 236 |
+
if progress_bar:
|
| 237 |
+
timesteps_iter = tqdm(enumerate(timesteps[:-1]), total=len(timesteps)-1, desc="LCM Sampling")
|
| 238 |
+
else:
|
| 239 |
+
timesteps_iter = enumerate(timesteps[:-1])
|
| 240 |
+
|
| 241 |
+
for i, t in timesteps_iter:
|
| 242 |
+
# LCM-specific scaling
|
| 243 |
+
c_skip = self.c_skip
|
| 244 |
+
c_out = self.c_out
|
| 245 |
+
c_in = self.c_in
|
| 246 |
+
c_noise = self.c_noise
|
| 247 |
+
|
| 248 |
+
# Scale the input
|
| 249 |
+
scaled_latents = c_in * latents
|
| 250 |
+
|
| 251 |
+
# Predict
|
| 252 |
+
noise_pred = self.model(scaled_latents, c_noise * t.unsqueeze(0), prompt_embeds)
|
| 253 |
+
|
| 254 |
+
# LCM update rule
|
| 255 |
+
denoised = c_skip * latents + c_out * noise_pred
|
| 256 |
+
|
| 257 |
+
# Update the latents
|
| 258 |
+
dt = timesteps[i + 1] - t
|
| 259 |
+
latents = denoised + dt * noise_pred
|
| 260 |
+
|
| 261 |
+
return latents
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
class SamplerFactory:
|
| 265 |
+
"""Sampler factory."""
|
| 266 |
+
@staticmethod
|
| 267 |
+
def create_sampler(
|
| 268 |
+
sampler_type: str,
|
| 269 |
+
model: nn.Module,
|
| 270 |
+
diffusion,
|
| 271 |
+
num_inference_steps: int = 50
|
| 272 |
+
):
|
| 273 |
+
"""Create a sampler."""
|
| 274 |
+
if sampler_type == "ddim":
|
| 275 |
+
return DDIMSampler(model, diffusion, num_inference_steps)
|
| 276 |
+
elif sampler_type == "dpm":
|
| 277 |
+
return DPMSampler(model, diffusion, num_inference_steps)
|
| 278 |
+
elif sampler_type == "lcm":
|
| 279 |
+
return LCMSampler(model, diffusion, num_inference_steps)
|
| 280 |
+
else:
|
| 281 |
+
raise ValueError(f"Unknown sampler type: {sampler_type}")
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class TextToImagePipeline:
|
| 285 |
+
"""Text-to-image pipeline."""
|
| 286 |
+
def __init__(
|
| 287 |
+
self,
|
| 288 |
+
model: nn.Module,
|
| 289 |
+
diffusion,
|
| 290 |
+
text_encoder,
|
| 291 |
+
vae_decoder,
|
| 292 |
+
sampler_type: str = "ddim",
|
| 293 |
+
device: str = "cuda"
|
| 294 |
+
):
|
| 295 |
+
self.model = model.to(device)
|
| 296 |
+
self.diffusion = diffusion
|
| 297 |
+
self.text_encoder = text_encoder
|
| 298 |
+
self.vae_decoder = vae_decoder
|
| 299 |
+
self.sampler_type = sampler_type
|
| 300 |
+
self.device = device
|
| 301 |
+
|
| 302 |
+
# Create the sampler
|
| 303 |
+
self.sampler = SamplerFactory.create_sampler(
|
| 304 |
+
sampler_type, model, diffusion
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
# Switch to evaluation mode
|
| 308 |
+
self.model.eval()
|
| 309 |
+
if self.vae_decoder is not None:
|
| 310 |
+
self.vae_decoder.eval()
|
| 311 |
+
|
| 312 |
+
@torch.no_grad()
|
| 313 |
+
def __call__(
|
| 314 |
+
self,
|
| 315 |
+
prompt: str,
|
| 316 |
+
negative_prompt: str = "",
|
| 317 |
+
height: int = 512,
|
| 318 |
+
width: int = 512,
|
| 319 |
+
num_inference_steps: int = 50,
|
| 320 |
+
guidance_scale: float = 7.5,
|
| 321 |
+
num_images: int = 1,
|
| 322 |
+
seed: Optional[int] = None,
|
| 323 |
+
progress_bar: bool = True
|
| 324 |
+
) -> List:
|
| 325 |
+
"""Generate images."""
|
| 326 |
+
# Set the random seed
|
| 327 |
+
if seed is not None:
|
| 328 |
+
torch.manual_seed(seed)
|
| 329 |
+
if torch.cuda.is_available():
|
| 330 |
+
torch.cuda.manual_seed(seed)
|
| 331 |
+
|
| 332 |
+
# Encode the prompts
|
| 333 |
+
prompt_embeds = self.text_encoder.encode([prompt]).to(self.device)
|
| 334 |
+
negative_prompt_embeds = None
|
| 335 |
+
|
| 336 |
+
if negative_prompt:
|
| 337 |
+
negative_prompt_embeds = self.text_encoder.encode([negative_prompt]).to(self.device)
|
| 338 |
+
|
| 339 |
+
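# Generate the latents (note: only DDIMSampler.sample accepts negative_prompt_embeds and num_images_per_prompt)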
|
| 340 |
+
latents = self.sampler.sample(
|
| 341 |
+
prompt_embeds=prompt_embeds,
|
| 342 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 343 |
+
height=height,
|
| 344 |
+
width=width,
|
| 345 |
+
num_images_per_prompt=num_images,
|
| 346 |
+
guidance_scale=guidance_scale,
|
| 347 |
+
progress_bar=progress_bar
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
# Decode the latents into images
|
| 351 |
+
images = []
|
| 352 |
+
for i in range(num_images):
|
| 353 |
+
latent = latents[i:i+1]
|
| 354 |
+
|
| 355 |
+
if self.vae_decoder is not None:
|
| 356 |
+
image = self.vae_decoder(latent)
|
| 357 |
+
else:
|
| 358 |
+
# Without a VAE decoder, return the latents as-is
|
| 359 |
+
image = latent
|
| 360 |
+
|
| 361 |
+
images.append(image.cpu())
|
| 362 |
+
|
| 363 |
+
return images
|
| 364 |
+
|
| 365 |
+
def generate_grid(
|
| 366 |
+
self,
|
| 367 |
+
prompts: List[str],
|
| 368 |
+
grid_size: Tuple[int, int] = (2, 2),
|
| 369 |
+
**kwargs
|
| 370 |
+
) -> torch.Tensor:
|
| 371 |
+
"""Generate an image grid."""
|
| 372 |
+
images = []
|
| 373 |
+
|
| 374 |
+
for prompt in prompts[:grid_size[0] * grid_size[1]]:
|
| 375 |
+
image = self(prompt, **kwargs)[0]
|
| 376 |
+
images.append(image)
|
| 377 |
+
|
| 378 |
+
# Build the grid
|
| 379 |
+
from torchvision.utils import make_grid
|
| 380 |
+
grid = make_grid(torch.cat(images, dim=0), nrow=grid_size[1])
|
| 381 |
+
|
| 382 |
+
return grid
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def test_sampler():
|
| 386 |
+
"""Test the sampler."""
|
| 387 |
+
import torch.nn as nn
|
| 388 |
+
|
| 389 |
+
# Create a mock model
|
| 390 |
+
class MockModel(nn.Module):
|
| 391 |
+
def __init__(self):
|
| 392 |
+
super().__init__()
|
| 393 |
+
self.in_channels = 4
|
| 394 |
+
|
| 395 |
+
def forward(self, x, t, context):
|
| 396 |
+
# Return random noise
|
| 397 |
+
return torch.randn_like(x)
|
| 398 |
+
|
| 399 |
+
# Create a mock diffusion process
|
| 400 |
+
class MockDiffusion:
|
| 401 |
+
def __init__(self):
|
| 402 |
+
self.num_train_timesteps = 1000
|
| 403 |
+
self.alphas_cumprod = torch.ones(1000)
|
| 404 |
+
self.prediction_type = "epsilon"
|
| 405 |
+
|
| 406 |
+
def extract(self, a, t, x_shape):
|
| 407 |
+
return torch.ones(x_shape[0], 1, 1, 1)
|
| 408 |
+
|
| 409 |
+
def scale_model_input(self, x, t):
|
| 410 |
+
return x
|
| 411 |
+
|
| 412 |
+
model = MockModel()
|
| 413 |
+
diffusion = MockDiffusion()
|
| 414 |
+
|
| 415 |
+
# Test the DDIM sampler
|
| 416 |
+
sampler = DDIMSampler(model, diffusion, num_inference_steps=10)
|
| 417 |
+
|
| 418 |
+
# Test sampling
|
| 419 |
+
prompt_embeds = torch.randn(1, 77, 768)
|
| 420 |
+
latents = sampler.sample(prompt_embeds, height=64, width=64, progress_bar=False)
|
| 421 |
+
|
| 422 |
+
print(f"DDIM sampling finished, latent shape: {latents.shape}")
|
| 423 |
+
|
| 424 |
+
return sampler, latents
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
if __name__ == '__main__':
|
| 428 |
+
test_sampler()
|
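The deterministic (eta = 0) path of `DDIMSampler.step` above can be checked in isolation. The following is a minimal sketch in plain Python on scalars, not the repository's implementation: the helper names `ddim_timesteps`, `scaled_linear_alphas_cumprod`, and `ddim_step` are hypothetical, and the schedule constants are assumed to match the scaled-linear defaults used in `src/models/diffusion.py`. It illustrates that when the predicted noise equals the true noise, stepping through the strided timesteps recovers the original sample.

```python
import math


def ddim_timesteps(num_train_timesteps: int, num_inference_steps: int) -> list:
    """Evenly strided timesteps, ordered from most to least noisy."""
    step_ratio = num_train_timesteps // num_inference_steps
    return [i * step_ratio for i in range(num_inference_steps)][::-1]


def scaled_linear_alphas_cumprod(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012):
    """Cumulative product of (1 - beta_t) for a scaled-linear beta schedule."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    out, prod = [], 1.0
    for i in range(num_train_timesteps):
        beta = (lo + (hi - lo) * i / (num_train_timesteps - 1)) ** 2
        prod *= 1.0 - beta
        out.append(prod)
    return out


def ddim_step(x_t, eps, alpha_t, alpha_prev):
    """Deterministic DDIM update for an epsilon-prediction model."""
    # Predicted clean sample, then re-noise it to the previous noise level
    x0 = (x_t - math.sqrt(1.0 - alpha_t) * eps) / math.sqrt(alpha_t)
    return math.sqrt(alpha_prev) * x0 + math.sqrt(1.0 - alpha_prev) * eps


# Sanity check: an oracle "model" that always returns the true noise
alphas = scaled_linear_alphas_cumprod()
x0_true, noise = 0.7, -1.3
ts = ddim_timesteps(1000, 10)
x = math.sqrt(alphas[ts[0]]) * x0_true + math.sqrt(1.0 - alphas[ts[0]]) * noise
for i, t in enumerate(ts):
    # Past the last timestep the cumulative alpha is taken to be 1 (no noise left)
    alpha_prev = alphas[ts[i + 1]] if i + 1 < len(ts) else 1.0
    x = ddim_step(x, noise, alphas[t], alpha_prev)
print(round(x, 6))
```

A real UNet only approximates the noise, so the recovered sample is approximate as well; the sketch checks only the algebra of the update rule.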
src/models/attention.py
ADDED
|
@@ -0,0 +1,143 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
from typing import Optional, Tuple
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MemoryEfficientAttention(nn.Module):
|
| 9 |
+
"""Memory-efficient multi-head attention."""
|
| 10 |
+
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dropout: float = 0.0):
|
| 11 |
+
super().__init__()
|
| 12 |
+
self.dim = dim
|
| 13 |
+
self.num_heads = num_heads
|
| 14 |
+
self.head_dim = dim // num_heads
|
| 15 |
+
self.scale = self.head_dim ** -0.5
|
| 16 |
+
|
| 17 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 18 |
+
self.proj = nn.Linear(dim, dim)
|
| 19 |
+
self.dropout = nn.Dropout(dropout)
|
| 20 |
+
|
| 21 |
+
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 22 |
+
B, N, C = x.shape
|
| 23 |
+
|
| 24 |
+
# Compute Q, K, V
|
| 25 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 26 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 27 |
+
|
| 28 |
+
# Compute attention in query chunks to avoid OOM
|
| 29 |
+
chunk_size = min(32, N)
|
| 30 |
+
attn_output = torch.zeros(B, self.num_heads, N, self.head_dim, device=x.device, dtype=x.dtype)  # match input dtype so fp16/bf16 inputs stay in that dtype
|
| 31 |
+
|
| 32 |
+
for i in range(0, N, chunk_size):
|
| 33 |
+
q_chunk = q[:, :, i:i+chunk_size, :]
|
| 34 |
+
|
| 35 |
+
# Compute the attention scores
|
| 36 |
+
attn_scores = torch.matmul(q_chunk, k.transpose(-2, -1)) * self.scale
|
| 37 |
+
|
| 38 |
+
if mask is not None:
|
| 39 |
+
attn_scores = attn_scores + mask
|
| 40 |
+
|
| 41 |
+
attn_probs = F.softmax(attn_scores, dim=-1)
|
| 42 |
+
attn_probs = self.dropout(attn_probs)
|
| 43 |
+
|
| 44 |
+
# Compute the output
|
| 45 |
+
attn_output[:, :, i:i+chunk_size, :] = torch.matmul(attn_probs, v)
|
| 46 |
+
|
| 47 |
+
# Merge the heads
|
| 48 |
+
attn_output = attn_output.transpose(1, 2).reshape(B, N, C)
|
| 49 |
+
|
| 50 |
+
# Output projection
|
| 51 |
+
output = self.proj(attn_output)
|
| 52 |
+
output = self.dropout(output)
|
| 53 |
+
|
| 54 |
+
return output
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class CrossAttention(nn.Module):
|
| 58 |
+
"""Cross-attention (for text conditioning)."""
|
| 59 |
+
def __init__(self, query_dim: int, context_dim: int, num_heads: int = 8, dropout: float = 0.0):
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.query_dim = query_dim
|
| 62 |
+
self.context_dim = context_dim
|
| 63 |
+
self.num_heads = num_heads
|
| 64 |
+
self.head_dim = query_dim // num_heads
|
| 65 |
+
self.scale = self.head_dim ** -0.5
|
| 66 |
+
|
| 67 |
+
self.to_q = nn.Linear(query_dim, query_dim)
|
| 68 |
+
self.to_k = nn.Linear(context_dim, query_dim)
|
| 69 |
+
self.to_v = nn.Linear(context_dim, query_dim)
|
| 70 |
+
self.proj = nn.Linear(query_dim, query_dim)
|
| 71 |
+
self.dropout = nn.Dropout(dropout)
|
| 72 |
+
|
| 73 |
+
def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 74 |
+
B, N, C = x.shape
|
| 75 |
+
|
| 76 |
+
# Compute Q
|
| 77 |
+
q = self.to_q(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
|
| 78 |
+
|
| 79 |
+
# Compute K and V
|
| 80 |
+
k = self.to_k(context).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
| 81 |
+
v = self.to_v(context).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
| 82 |
+
|
| 83 |
+
# Compute attention in query chunks
|
| 84 |
+
chunk_size = min(32, N)
|
| 85 |
+
attn_output = torch.zeros(B, self.num_heads, N, self.head_dim, device=x.device, dtype=x.dtype)  # match input dtype so fp16/bf16 inputs stay in that dtype
|
| 86 |
+
|
| 87 |
+
for i in range(0, N, chunk_size):
|
| 88 |
+
q_chunk = q[:, :, i:i+chunk_size, :]
|
| 89 |
+
|
| 90 |
+
# Compute the attention scores
|
| 91 |
+
attn_scores = torch.matmul(q_chunk, k.transpose(-2, -1)) * self.scale
|
| 92 |
+
|
| 93 |
+
if mask is not None:
|
| 94 |
+
attn_scores = attn_scores + mask
|
| 95 |
+
|
| 96 |
+
attn_probs = F.softmax(attn_scores, dim=-1)
|
| 97 |
+
attn_probs = self.dropout(attn_probs)
|
| 98 |
+
|
| 99 |
+
# Compute the output
|
| 100 |
+
attn_output[:, :, i:i+chunk_size, :] = torch.matmul(attn_probs, v)
|
| 101 |
+
|
| 102 |
+
# Merge the heads
|
| 103 |
+
attn_output = attn_output.transpose(1, 2).reshape(B, N, C)
|
| 104 |
+
|
| 105 |
+
# Output projection
|
| 106 |
+
output = self.proj(attn_output)
|
| 107 |
+
|
| 108 |
+
return output
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class FlashAttentionWrapper(nn.Module):
|
| 112 |
+
"""FlashAttention wrapper (used when available)."""
|
| 113 |
+
def __init__(self, dim: int, num_heads: int = 8):
|
| 114 |
+
super().__init__()
|
| 115 |
+
self.dim = dim
|
| 116 |
+
self.num_heads = num_heads
|
| 117 |
+
|
| 118 |
+
try:
|
| 119 |
+
from flash_attn import flash_attn_qkvpacked_func
|
| 120 |
+
self.use_flash = True
|
| 121 |
+
except ImportError:
|
| 122 |
+
self.use_flash = False
|
| 123 |
+
|
| 124 |
+
    if self.use_flash:
        # The fused kernel takes packed QKV, so this path needs its own projections
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
    else:
        self.attention = MemoryEfficientAttention(dim, num_heads)

def forward(self, x: torch.Tensor) -> torch.Tensor:
    if self.use_flash:
        return self._flash_attention(x)
    else:
        return self.attention(x)

def _flash_attention(self, x: torch.Tensor) -> torch.Tensor:
    # FlashAttention path (flash_attn expects fp16/bf16 tensors on a CUDA device)
    B, N, C = x.shape
    # Packed layout expected by flash_attn: [B, N, 3, num_heads, head_dim]
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

    from flash_attn import flash_attn_qkvpacked_func
    output = flash_attn_qkvpacked_func(qkv)  # [B, N, num_heads, head_dim]
    output = output.reshape(B, N, C)

    return self.proj(output)
|
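The chunked loops in `MemoryEfficientAttention` and `CrossAttention` above split only the query axis, while every chunk still attends to all keys. Because the softmax is taken per query row over the key axis, this should give the same result as full attention while holding only a `chunk_size x num_keys` score matrix at a time. The sketch below illustrates that equivalence for a single head in plain Python; `softmax`, `attention`, and `chunked_attention` are hypothetical helpers written for this illustration, not functions from the repository.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v):
    """Full softmax attention for one head; q: [Nq][d], k: [Nk][d], v: [Nk][dv]."""
    scale = len(q[0]) ** -0.5
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        probs = softmax(scores)
        out.append([sum(p * vj[c] for p, vj in zip(probs, v)) for c in range(len(v[0]))])
    return out


def chunked_attention(q, k, v, chunk_size):
    """Process queries chunk by chunk; each chunk still sees every key."""
    out = []
    for i in range(0, len(q), chunk_size):
        out.extend(attention(q[i:i + chunk_size], k, v))
    return out


q = [[0.1 * i, -0.2 * i + 0.5] for i in range(5)]
k = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

full = attention(q, k, v)
chunked = chunked_attention(q, k, v, chunk_size=2)
print(full == chunked)
```

Chunking over keys instead would not be this simple, since the softmax normalizer would then have to be accumulated across chunks.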
src/models/diffusion.py
ADDED
|
@@ -0,0 +1,263 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
import torch.nn.functional as F  # needed for F.pad and F.mse_loss below
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import Optional, Tuple, Union
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BetaScheduler:
|
| 9 |
+
"""Beta schedules."""
|
| 10 |
+
@staticmethod
|
| 11 |
+
def linear(num_timesteps: int, beta_start: float = 0.0001, beta_end: float = 0.02) -> np.ndarray:
|
| 12 |
+
"""Linear schedule."""
|
| 13 |
+
return np.linspace(beta_start, beta_end, num_timesteps, dtype=np.float32)
|
| 14 |
+
|
| 15 |
+
@staticmethod
|
| 16 |
+
def cosine(num_timesteps: int, s: float = 0.008) -> np.ndarray:
|
| 17 |
+
"""Cosine schedule."""
|
| 18 |
+
steps = num_timesteps + 1
|
| 19 |
+
x = np.linspace(0, num_timesteps, steps)
|
| 20 |
+
alphas_cumprod = np.cos(((x / num_timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
|
| 21 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
| 22 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
| 23 |
+
return np.clip(betas, 0, 0.999)
|
| 24 |
+
|
| 25 |
+
@staticmethod
|
| 26 |
+
def scaled_linear(num_timesteps: int) -> np.ndarray:
|
| 27 |
+
"""Scaled-linear schedule (the Stable Diffusion default)."""
|
| 28 |
+
beta_start = 0.00085
|
| 29 |
+
beta_end = 0.012
|
| 30 |
+
return np.linspace(beta_start**0.5, beta_end**0.5, num_timesteps) ** 2
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class DiffusionProcess:
|
| 34 |
+
"""Diffusion process manager."""
|
| 35 |
+
def __init__(self, config: dict):
|
| 36 |
+
self.config = config
|
| 37 |
+
diff_config = config.get('diffusion', {})
|
| 38 |
+
|
| 39 |
+
self.num_train_timesteps = diff_config.get('num_train_timesteps', 1000)
|
| 40 |
+
self.num_inference_timesteps = diff_config.get('num_inference_timesteps', 50)
|
| 41 |
+
self.beta_start = diff_config.get('beta_start', 0.00085)
|
| 42 |
+
self.beta_end = diff_config.get('beta_end', 0.012)
|
| 43 |
+
self.beta_schedule = diff_config.get('beta_schedule', 'scaled_linear')
|
| 44 |
+
self.prediction_type = diff_config.get('prediction_type', 'epsilon')
|
| 45 |
+
|
| 46 |
+
# Initialize the schedule parameters
|
| 47 |
+
self._init_schedule()
|
| 48 |
+
|
| 49 |
+
def _init_schedule(self):
|
| 50 |
+
"""Initialize the diffusion schedule parameters."""
|
| 51 |
+
# Compute the betas
|
| 52 |
+
if self.beta_schedule == "linear":
|
| 53 |
+
betas = BetaScheduler.linear(
|
| 54 |
+
self.num_train_timesteps,
|
| 55 |
+
self.beta_start,
|
| 56 |
+
self.beta_end
|
| 57 |
+
)
|
| 58 |
+
elif self.beta_schedule == "cosine":
|
| 59 |
+
betas = BetaScheduler.cosine(self.num_train_timesteps)
|
| 60 |
+
elif self.beta_schedule == "scaled_linear":
|
| 61 |
+
betas = BetaScheduler.scaled_linear(self.num_train_timesteps)
|
| 62 |
+
else:
|
| 63 |
+
raise ValueError(f"Unknown beta schedule: {self.beta_schedule}")
|
| 64 |
+
|
| 65 |
+
self.betas = torch.from_numpy(betas).float()
|
| 66 |
+
|
| 67 |
+
# Compute the alphas
|
| 68 |
+
self.alphas = 1.0 - self.betas
|
| 69 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
| 70 |
+
self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
|
| 71 |
+
|
| 72 |
+
# Compute the posterior variance of the diffusion process
|
| 73 |
+
self.variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
| 74 |
+
|
| 75 |
+
# Store as attributes (DiffusionProcess is not an nn.Module, so register_buffer here is a plain setattr shim)
|
| 76 |
+
self.register_buffer = lambda name, tensor: setattr(self, name, tensor)
|
| 77 |
+
self.register_buffer('betas', self.betas)
|
| 78 |
+
self.register_buffer('alphas', self.alphas)
|
| 79 |
+
self.register_buffer('alphas_cumprod', self.alphas_cumprod)
|
| 80 |
+
self.register_buffer('alphas_cumprod_prev', self.alphas_cumprod_prev)
|
| 81 |
+
self.register_buffer('variance', self.variance)
|
| 82 |
+
|
| 83 |
+
# Precompute the sampling coefficients
|
| 84 |
+
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
|
| 85 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - self.alphas_cumprod))
|
| 86 |
+
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1.0 - self.alphas_cumprod))
|
| 87 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1.0 / self.alphas_cumprod))
|
| 88 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1.0 / self.alphas_cumprod - 1))
|
| 89 |
+
|
| 90 |
+
def q_sample(self, x_start: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 91 |
+
"""Forward diffusion process: add noise to x_start."""
|
| 92 |
+
if noise is None:
|
| 93 |
+
noise = torch.randn_like(x_start)
|
| 94 |
+
|
| 95 |
+
sqrt_alphas_cumprod_t = self.extract(self.sqrt_alphas_cumprod, t, x_start.shape)
|
| 96 |
+
sqrt_one_minus_alphas_cumprod_t = self.extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
|
| 97 |
+
|
| 98 |
+
return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
|
| 99 |
+
|
| 100 |
+
def extract(self, a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int, ...]) -> torch.Tensor:
|
| 101 |
+
"""Extract the values of tensor a at indices t."""
|
| 102 |
+
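# Accept Python ints and 0-dim tensors (as passed by the samplers) as well as index batches
t = torch.as_tensor(t, dtype=torch.long).reshape(-1)
out = a.gather(-1, t.to(a.device))
return out.reshape(t.shape[0], *((1,) * (len(x_shape) - 1))).to(t.device)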
|
| 105 |
+
|
| 106 |
+
def get_loss_weight(self, snr: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
|
| 107 |
+
"""Compute the loss weight from the SNR."""
|
| 108 |
+
if gamma is None:
|
| 109 |
+
return torch.ones_like(snr)
|
| 110 |
+
|
| 111 |
+
snr = torch.clamp(snr, min=1e-8)
|
| 112 |
+
min_snr = torch.tensor(gamma, device=snr.device)
|
| 113 |
+
weight = torch.minimum(snr, min_snr) / snr
|
| 114 |
+
return weight
|
| 115 |
+
|
| 116 |
+
def compute_snr(self, timesteps: torch.Tensor) -> torch.Tensor:
|
| 117 |
+
"""Compute the signal-to-noise ratio (SNR)."""
|
| 118 |
+
alphas_cumprod = self.extract(self.alphas_cumprod, timesteps, timesteps.shape)
|
| 119 |
+
snr = alphas_cumprod / (1 - alphas_cumprod)
|
| 120 |
+
return snr
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class DDIMScheduler:
|
| 124 |
+
"""DDIM scheduler."""
|
| 125 |
+
def __init__(self, diffusion: DiffusionProcess):
|
| 126 |
+
self.diffusion = diffusion
|
| 127 |
+
self.num_train_timesteps = diffusion.num_train_timesteps
|
| 128 |
+
self.num_inference_timesteps = diffusion.num_inference_timesteps
|
| 129 |
+
|
| 130 |
+
# Set up the timesteps
|
| 131 |
+
self.set_timesteps(self.num_inference_timesteps)
|
| 132 |
+
|
| 133 |
+
def set_timesteps(self, num_inference_timesteps: int):
|
| 134 |
+
"""Set the inference timesteps."""
|
| 135 |
+
self.num_inference_timesteps = num_inference_timesteps
|
| 136 |
+
|
| 137 |
+
# Select the timesteps
|
| 138 |
+
if self.num_train_timesteps == self.num_inference_timesteps:
|
| 139 |
+
self.timesteps = torch.arange(0, self.num_train_timesteps).long()
|
| 140 |
+
else:
|
| 141 |
+
step_ratio = self.num_train_timesteps // self.num_inference_timesteps
|
| 142 |
+
self.timesteps = (torch.arange(0, self.num_inference_timesteps) * step_ratio).long()  # exactly num_inference_timesteps entries
|
| 143 |
+
|
| 144 |
+
self.timesteps = self.timesteps.flip(0)  # from T down to 0
|
| 145 |
+
|
| 146 |
+
@torch.no_grad()
|
| 147 |
+
def step(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, eta: float = 0.0) -> torch.Tensor:
|
| 148 |
+
"""Perform a single DDIM sampling step."""
|
| 149 |
+
# Get the parameters for the current timestep
|
| 150 |
+
prev_timestep = timestep - self.num_train_timesteps // self.num_inference_timesteps
|
| 151 |
+
|
| 152 |
+
# Extract the cumulative alpha products
|
| 153 |
+
alpha_prod_t = self.diffusion.extract(self.diffusion.alphas_cumprod, timestep, sample.shape)
|
| 154 |
+
alpha_prod_t_prev = self.diffusion.extract(
|
| 155 |
+
self.diffusion.alphas_cumprod,
|
| 156 |
+
prev_timestep,
|
| 157 |
+
sample.shape
|
| 158 |
+
) if prev_timestep >= 0 else torch.ones_like(alpha_prod_t)
|
| 159 |
+
|
| 160 |
+
# Convert the model output according to the prediction type
|
| 161 |
+
if self.diffusion.prediction_type == "epsilon":
|
| 162 |
+
pred_original_sample = (sample - (1 - alpha_prod_t) ** 0.5 * model_output) / alpha_prod_t ** 0.5
|
| 163 |
+
pred_epsilon = model_output
|
| 164 |
+
elif self.diffusion.prediction_type == "sample":
|
| 165 |
+
pred_original_sample = model_output
|
| 166 |
+
pred_epsilon = (sample - alpha_prod_t ** 0.5 * pred_original_sample) / (1 - alpha_prod_t) ** 0.5
|
| 167 |
+
elif self.diffusion.prediction_type == "v_prediction":
|
| 168 |
+
pred_original_sample = (alpha_prod_t ** 0.5) * sample - (1 - alpha_prod_t) ** 0.5 * model_output
|
| 169 |
+
pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (1 - alpha_prod_t) ** 0.5 * sample
|
| 170 |
+
else:
|
| 171 |
+
raise ValueError(f"Unsupported prediction type: {self.diffusion.prediction_type}")
|
| 172 |
+
|
| 173 |
+
# Compute the variance for x_{t-1}
|
| 174 |
+
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
| 175 |
+
std_dev_t = eta * variance ** 0.5
|
| 176 |
+
|
| 177 |
+
# Compute the mean of x_{t-1}
|
| 178 |
+
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon
|
| 179 |
+
prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
|
| 180 |
+
|
| 181 |
+
# Add noise
|
| 182 |
+
if eta > 0:
|
| 183 |
+
noise = torch.randn_like(model_output)
|
| 184 |
+
prev_sample = prev_sample + std_dev_t * noise
|
| 185 |
+
|
| 186 |
+
return prev_sample
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class DiffusionModel(nn.Module):
|
| 190 |
+
"""Diffusion model wrapper."""
|
| 191 |
+
def __init__(self, unet: nn.Module, diffusion: DiffusionProcess):
|
| 192 |
+
super().__init__()
|
| 193 |
+
self.unet = unet
|
| 194 |
+
self.diffusion = diffusion
|
| 195 |
+
self.scheduler = DDIMScheduler(diffusion)
|
| 196 |
+
|
| 197 |
+
def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
|
| 198 |
+
"""Forward pass: predict the noise."""
|
| 199 |
+
return self.unet(x, timesteps, context)
|
| 200 |
+
|
| 201 |
+
def compute_loss(self, x_start: torch.Tensor, context: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 202 |
+
"""Compute the diffusion loss."""
|
| 203 |
+
if noise is None:
|
| 204 |
+
noise = torch.randn_like(x_start)
|
| 205 |
+
|
| 206 |
+
# Sample random timesteps
|
| 207 |
+
batch_size = x_start.shape[0]
|
| 208 |
+
timesteps = torch.randint(
|
| 209 |
+
0, self.diffusion.num_train_timesteps,
|
| 210 |
+
(batch_size,), device=x_start.device
|
| 211 |
+
).long()
|
| 212 |
+
|
| 213 |
+
# Forward diffusion
|
| 214 |
+
x_noisy = self.diffusion.q_sample(x_start, timesteps, noise)
|
| 215 |
+
|
| 216 |
+
# Predict the noise
|
| 217 |
+
predicted_noise = self.unet(x_noisy, timesteps, context)
|
| 218 |
+
|
| 219 |
+
# Compute the loss
|
| 220 |
+
loss = F.mse_loss(predicted_noise, noise)
|
| 221 |
+
|
| 222 |
+
return loss
|
| 223 |
+
|
| 224 |
+
@torch.no_grad()
|
| 225 |
+
def generate(
|
| 226 |
+
self,
|
| 227 |
+
context: torch.Tensor,
|
| 228 |
+
num_samples: int = 1,
|
| 229 |
+
height: int = 512,
|
| 230 |
+
width: int = 512,
|
| 231 |
+
guidance_scale: float = 7.5
|
| 232 |
+
) -> torch.Tensor:
|
| 233 |
+
"""Generate images."""
|
| 234 |
+
# Initialize the noise
|
| 235 |
+
latents = torch.randn(
|
| 236 |
+
(num_samples, self.unet.in_channels, height // 8, width // 8),
|
| 237 |
+
device=next(self.unet.parameters()).device
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
# DDIM sampling
|
| 241 |
+
self.scheduler.set_timesteps(self.diffusion.num_inference_timesteps)
|
| 242 |
+
|
| 243 |
+
for t in self.scheduler.timesteps:
|
| 244 |
+
# Duplicate the latents to match the batch size
|
| 245 |
+
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
|
| 246 |
+
# DDIMScheduler defines no scale_model_input; the latents are passed to the UNet unscaled
|
| 247 |
+
|
| 248 |
+
# Predict the noise
|
| 249 |
+
timesteps = torch.full((num_samples,), t, device=latents.device).long()
|
| 250 |
+
if guidance_scale > 1.0:
|
| 251 |
+
timesteps = torch.cat([timesteps] * 2)
|
| 252 |
+
|
| 253 |
+
noise_pred = self.unet(latent_model_input, timesteps, context)
|
| 254 |
+
|
| 255 |
+
# 应用分类器自由引导
|
| 256 |
+
if guidance_scale > 1.0:
|
| 257 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 258 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 259 |
+
|
| 260 |
+
# 计算上一个样本
|
| 261 |
+
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
|
| 262 |
+
|
| 263 |
+
return latents
|
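The classifier-free guidance step in `generate` above can be looked at in isolation. What follows is a minimal sketch in plain Python, using lists in place of tensors; the helper name `apply_cfg` is hypothetical and does not exist in this repository. It only illustrates the formula `uncond + scale * (text - uncond)` used in the sampling loop.

```python
def apply_cfg(noise_uncond, noise_text, guidance_scale):
    """Blend unconditional and text-conditioned noise predictions element-wise.

    Hypothetical helper for illustration; mirrors the guidance formula in generate().
    """
    return [
        u + guidance_scale * (t - u)
        for u, t in zip(noise_uncond, noise_text)
    ]


if __name__ == "__main__":
    uncond = [0.0, 1.0]
    text = [1.0, 3.0]
    # scale 0 keeps the unconditional prediction, scale 1 keeps the conditional one,
    # and larger scales extrapolate past the conditional prediction.
    for scale in (0.0, 1.0, 2.0):
        print(scale, apply_cfg(uncond, text, scale))
```

Note that the formula assumes the two halves of the batch are ordered as (unconditional, conditional), which is the order `noise_pred.chunk(2)` is unpacked in above.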
src/models/unet_light.py
ADDED
@@ -0,0 +1,379 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # needed explicitly by the attention blocks below
from typing import Optional, Tuple, Union
import math


class TimestepEmbedding(nn.Module):
    """Timestep embedding."""
    def __init__(self, embedding_dim: int, time_embed_dim: int):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.time_embed_dim = time_embed_dim

        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )

    def forward(self, timestep: torch.Tensor) -> torch.Tensor:
        # Sinusoidal positional encoding
        half_dim = self.embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=timestep.device) * -emb)
        emb = timestep[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)

        # Pad with one zero column when the embedding dimension is odd
        if self.embedding_dim % 2 == 1:
            emb = F.pad(emb, (0, 1, 0, 0))

        return self.mlp(emb)


class AttentionBlock(nn.Module):
    """Memory-efficient self-attention block."""
    def __init__(self, channels: int, num_heads: int = 4, use_checkpoint: bool = True):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        self.use_checkpoint = use_checkpoint

        self.norm = nn.GroupNorm(32, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj_out = nn.Conv2d(channels, channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_checkpoint and self.training:
            return torch.utils.checkpoint.checkpoint(self._forward, x)
        return self._forward(x)

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        qkv = self.qkv(self.norm(x))
        q, k, v = qkv.chunk(3, dim=1)

        # Compute attention in chunks to avoid OOM
        chunk_size = min(32, H * W)
        q = q.view(B, self.num_heads, C // self.num_heads, H * W).permute(0, 1, 3, 2)
        # k stays as (B, heads, head_dim, H*W), i.e. already transposed for q @ k
        k = k.view(B, self.num_heads, C // self.num_heads, H * W)
        v = v.view(B, self.num_heads, C // self.num_heads, H * W).permute(0, 1, 3, 2)

        # Process the queries one chunk at a time
        attn_output = torch.zeros_like(q)
        for i in range(0, H * W, chunk_size):
            q_chunk = q[:, :, i:i+chunk_size, :]
            scores = torch.matmul(q_chunk, k) / math.sqrt(C // self.num_heads)
            attn = F.softmax(scores, dim=-1)
            attn_output[:, :, i:i+chunk_size, :] = torch.matmul(attn, v)

        attn_output = attn_output.permute(0, 1, 3, 2).reshape(B, C, H, W)
        return x + self.proj_out(attn_output)


class ResNetBlock(nn.Module):
    """Residual block."""
    def __init__(self, in_channels: int, out_channels: int, time_embed_dim: int, dropout: float = 0.0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # First normalization + activation + convolution
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.act1 = nn.SiLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        # Timestep-embedding projection
        self.time_emb_proj = nn.Linear(time_embed_dim, out_channels)

        # Second normalization + activation + convolution
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.act2 = nn.SiLU()
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        # Project the skip path when the channel count changes
        if in_channels != out_channels:
            self.skip_conv = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.skip_conv = nn.Identity()

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        h = self.conv1(self.act1(self.norm1(x)))

        # Add the timestep embedding
        time_emb = self.time_emb_proj(F.silu(time_emb))
        h = h + time_emb[:, :, None, None]

        h = self.conv2(self.dropout(self.act2(self.norm2(h))))

        return h + self.skip_conv(x)


class CrossAttentionBlock(nn.Module):
    """Cross-attention block (text conditioning)."""
    def __init__(self, query_dim: int, context_dim: int, num_heads: int = 4, use_checkpoint: bool = True):
        super().__init__()
        self.query_dim = query_dim
        self.context_dim = context_dim
        self.num_heads = num_heads
        self.use_checkpoint = use_checkpoint

        self.norm = nn.GroupNorm(32, query_dim)
        self.to_q = nn.Linear(query_dim, query_dim)
        self.to_k = nn.Linear(context_dim, query_dim)
        self.to_v = nn.Linear(context_dim, query_dim)
        self.to_out = nn.Linear(query_dim, query_dim)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        if self.use_checkpoint and self.training:
            return torch.utils.checkpoint.checkpoint(self._forward, x, context)
        return self._forward(x, context)

    def _forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        x_reshaped = x.reshape(B, C, H * W).permute(0, 2, 1)
        # GroupNorm expects channels in dim 1, so normalize before flattening to (B, H*W, C)
        x_norm = self.norm(x).reshape(B, C, H * W).permute(0, 2, 1)

        q = self.to_q(x_norm)
        k = self.to_k(context)
        v = self.to_v(context)

        # Split into heads
        q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2)
        k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2)
        v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2)

        # Compute attention in chunks
        chunk_size = min(32, q.shape[2])
        attn_output = torch.zeros_like(q)

        for i in range(0, q.shape[2], chunk_size):
            q_chunk = q[:, :, i:i+chunk_size, :]
            scores = torch.matmul(q_chunk, k.transpose(-2, -1)) / math.sqrt(C // self.num_heads)
            attn = F.softmax(scores, dim=-1)
            attn_output[:, :, i:i+chunk_size, :] = torch.matmul(attn, v)

        attn_output = attn_output.transpose(1, 2).reshape(B, -1, C)
        attn_output = self.to_out(attn_output)

        return (x_reshaped + attn_output).permute(0, 2, 1).reshape(B, C, H, W)


class DownsampleBlock(nn.Module):
    """Downsampling block."""
    def __init__(self, channels: int, use_conv: bool = True):
        super().__init__()
        if use_conv:
            self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
        else:
            self.op = nn.AvgPool2d(2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.op(x)


class UpsampleBlock(nn.Module):
    """Upsampling block."""
    def __init__(self, channels: int, use_conv: bool = True):
        super().__init__()
        self.channels = channels
        if use_conv:
            self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        else:
            self.conv = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x)


class UNetLight(nn.Module):
    """Lightweight UNet model."""
    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        model_config = config.get('model', {})

        # Basic parameters
        self.in_channels = model_config.get('in_channels', 4)
        self.out_channels = model_config.get('out_channels', 4)
        self.base_channels = model_config.get('base_channels', 64)
        self.channel_mults = model_config.get('channel_mults', [1, 2, 4, 8])
        self.num_res_blocks = model_config.get('num_res_blocks', 2)
        self.attention_resolutions = model_config.get('attention_resolutions', [8])
        self.dropout = model_config.get('dropout', 0.0)
        self.use_checkpoint = model_config.get('use_checkpoint', True)
        self.num_heads = model_config.get('num_heads', 4)

        # Conditioning parameters
        self.context_dim = model_config.get('context_dim', 768)
        self.use_linear_projection = model_config.get('use_linear_projection', True)

        # Timestep embedding
        self.time_embed_dim = model_config.get('time_embed_dim', 256)
        self.time_embed = TimestepEmbedding(self.base_channels, self.time_embed_dim)

        # Input convolution
        self.input_conv = nn.Conv2d(self.in_channels, self.base_channels, 3, padding=1)

        # Build the downsampling path
        self.down_blocks = nn.ModuleList()
        self.down_attention_blocks = nn.ModuleList()
        self.downsample_blocks = nn.ModuleList()

        in_ch = self.base_channels
        resolution = 1

        for i, mult in enumerate(self.channel_mults):
            out_ch = self.base_channels * mult
            resolution *= 2

            # Residual blocks (only the first one in a level changes the channel count)
            for j in range(self.num_res_blocks):
                block = ResNetBlock(in_ch if j == 0 else out_ch, out_ch, self.time_embed_dim, self.dropout)
                self.down_blocks.append(block)

            # Add attention at the configured resolutions
            if resolution in self.attention_resolutions:
                attn = CrossAttentionBlock(out_ch, self.context_dim, self.num_heads, self.use_checkpoint)
                self.down_attention_blocks.append(attn)
            else:
                self.down_attention_blocks.append(None)

            in_ch = out_ch

            # Add downsampling unless this is the last level
            if i != len(self.channel_mults) - 1:
                downsample = DownsampleBlock(in_ch, use_conv=True)
                self.downsample_blocks.append(downsample)

        # Middle blocks
        self.mid_block1 = ResNetBlock(in_ch, in_ch, self.time_embed_dim, self.dropout)
        self.mid_attention = CrossAttentionBlock(in_ch, self.context_dim, self.num_heads, self.use_checkpoint)
        self.mid_block2 = ResNetBlock(in_ch, in_ch, self.time_embed_dim, self.dropout)

        # Build the upsampling path
        self.up_blocks = nn.ModuleList()
        self.up_attention_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()

        for i, mult in enumerate(reversed(self.channel_mults)):
            out_ch = self.base_channels * mult
            resolution //= 2

            # Upsampling block
            if i > 0:
                upsample = UpsampleBlock(in_ch, use_conv=True)
                self.upsample_blocks.append(upsample)

            # Residual blocks
            for j in range(self.num_res_blocks + 1):
                # The first block sees h (in_ch) concatenated with the skip from
                # the matching down level, which has out_ch channels.
                block_in_ch = in_ch + out_ch if j == 0 else out_ch
                block = ResNetBlock(block_in_ch, out_ch, self.time_embed_dim, self.dropout)
                self.up_blocks.append(block)

            # Add attention at the configured resolutions
            if resolution in self.attention_resolutions:
                attn = CrossAttentionBlock(out_ch, self.context_dim, self.num_heads, self.use_checkpoint)
                self.up_attention_blocks.append(attn)
            else:
                self.up_attention_blocks.append(None)

            in_ch = out_ch

        # Output layers
        self.norm_out = nn.GroupNorm(32, self.base_channels)
        self.act_out = nn.SiLU()
        self.output_conv = nn.Conv2d(self.base_channels, self.out_channels, 3, padding=1)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.GroupNorm):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # Timestep embedding
        time_emb = self.time_embed(timesteps)

        # Initial convolution
        h = self.input_conv(x)

        # Store skip connections
        down_h = []

        # Downsampling path
        down_idx = 0
        attn_idx = 0

        for i, mult in enumerate(self.channel_mults):
            for j in range(self.num_res_blocks):
                block = self.down_blocks[down_idx]
                h = block(h, time_emb)
                down_idx += 1

            # Attention
            if self.down_attention_blocks[attn_idx] is not None:
                h = self.down_attention_blocks[attn_idx](h, context)
            attn_idx += 1

            down_h.append(h)

            # Downsample (except at the last level)
            if i != len(self.channel_mults) - 1:
                downsample = self.downsample_blocks[i]
                h = downsample(h)

        # Middle blocks
        h = self.mid_block1(h, time_emb)
        h = self.mid_attention(h, context)
        h = self.mid_block2(h, time_emb)

        # Upsampling path
        up_idx = 0
        attn_up_idx = 0
        upsample_idx = 0

        for i, mult in enumerate(reversed(self.channel_mults)):
            # Upsample (except at the first level)
            if i > 0:
                upsample = self.upsample_blocks[upsample_idx]
                h = upsample(h)
                upsample_idx += 1

            for j in range(self.num_res_blocks + 1):
                # Concatenate the skip connection
                if j == 0:
                    skip = down_h.pop()
                    h = torch.cat([h, skip], dim=1)

                block = self.up_blocks[up_idx]
                h = block(h, time_emb)
                up_idx += 1

            # Attention
            if self.up_attention_blocks[attn_up_idx] is not None:
                h = self.up_attention_blocks[attn_up_idx](h, context)
            attn_up_idx += 1

        # Output layers
        h = self.norm_out(h)
        h = self.act_out(h)
        h = self.output_conv(h)

        return h

    def enable_gradient_checkpointing(self):
        """Enable gradient checkpointing."""
        self.use_checkpoint = True
        for module in self.modules():
            if hasattr(module, 'use_checkpoint'):
                module.use_checkpoint = True
|
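The attention blocks in `unet_light.py` process queries in chunks of at most 32 rows to keep the score matrix small. Because softmax is applied per query row, chunking over queries should give the same result as computing all rows at once. The sketch below checks that idea in plain Python on nested lists; the function names `attention` and `chunked_attention` are hypothetical and are not part of the repository.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v):
    """Scaled dot-product attention; q is (Nq, d), k and v are (Nk, d)."""
    d = len(q[0])
    out = []
    for q_row in q:
        scores = [
            sum(a * b for a, b in zip(q_row, k_row)) / math.sqrt(d)
            for k_row in k
        ]
        weights = softmax(scores)
        out.append([
            sum(w * v_row[j] for w, v_row in zip(weights, v))
            for j in range(len(v[0]))
        ])
    return out


def chunked_attention(q, k, v, chunk_size):
    """Same computation, but only chunk_size query rows are scored at a time."""
    out = []
    for i in range(0, len(q), chunk_size):
        out.extend(attention(q[i:i + chunk_size], k, v))
    return out


if __name__ == "__main__":
    q = [[0.1, 0.2], [0.3, -0.4], [1.0, 0.5], [-0.7, 0.9], [0.0, 0.0]]
    k = [[0.5, 0.1], [-0.2, 0.3], [0.8, -0.6]]
    v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
    print(chunked_attention(q, k, v, chunk_size=2) == attention(q, k, v))
```

The trade-off is the usual one: peak memory for the score matrix drops from `Nq x Nk` to `chunk_size x Nk`, at the cost of a Python-level loop over chunks.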
src/training/callbacks.py
ADDED
@@ -0,0 +1,324 @@
import torch
from typing import Dict, Any, Optional, List
import os
from datetime import datetime
import numpy as np
from PIL import Image
import torchvision.transforms as T


class Callback:
    """Base class for callbacks."""
    def on_train_begin(self, trainer):
        pass

    def on_train_end(self, trainer):
        pass

    def on_epoch_begin(self, trainer, epoch):
        pass

    def on_epoch_end(self, trainer, epoch, train_loss, val_loss):
        pass

    def on_batch_begin(self, trainer, batch_idx, batch):
        pass

    def on_batch_end(self, trainer, batch_idx, batch, loss):
        pass

    def on_validation_begin(self, trainer):
        pass

    def on_validation_end(self, trainer, val_loss):
        pass


class EarlyStopping(Callback):
    """Early-stopping callback."""
    def __init__(self, patience: int = 10, min_delta: float = 1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0
        self.should_stop = False

    def on_validation_end(self, trainer, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1

        if self.counter >= self.patience:
            self.should_stop = True
            print(f"Early stopping triggered, best loss: {self.best_loss:.4f}")


class ModelCheckpoint(Callback):
    """Model checkpoint callback."""
    def __init__(
        self,
        save_dir: str = './checkpoints',
        save_best_only: bool = True,
        save_freq: int = 1,
        monitor: str = 'val_loss',
        mode: str = 'min'
    ):
        self.save_dir = save_dir
        self.save_best_only = save_best_only
        self.save_freq = save_freq
        self.monitor = monitor
        self.mode = mode

        os.makedirs(save_dir, exist_ok=True)

        self.best_value = float('inf') if mode == 'min' else -float('inf')

    def on_epoch_end(self, trainer, epoch, train_loss, val_loss):
        if epoch % self.save_freq != 0:
            return

        # Get the monitored value (unknown monitor names fall back to val_loss)
        if self.monitor == 'val_loss':
            value = val_loss
        elif self.monitor == 'train_loss':
            value = train_loss
        else:
            value = val_loss

        # Decide whether to save
        should_save = False
        if self.save_best_only:
            if self.mode == 'min' and value < self.best_value:
                self.best_value = value
                should_save = True
            elif self.mode == 'max' and value > self.best_value:
                self.best_value = value
                should_save = True
        else:
            should_save = True

        if should_save:
            # Save the checkpoint
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': trainer.model.state_dict(),
                'optimizer_state_dict': trainer.optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
            }

            if trainer.use_ema:
                checkpoint['ema_model_state_dict'] = trainer.ema_model.state_dict()

            filename = f'checkpoint_epoch_{epoch}.pt' if not self.save_best_only else 'best_model.pt'
            save_path = os.path.join(self.save_dir, filename)

            torch.save(checkpoint, save_path)
            print(f"Checkpoint saved: {save_path}")


class LearningRateSchedulerCallback(Callback):
    """Learning-rate scheduler callback."""
    def __init__(self, scheduler, update_on: str = 'epoch'):
        self.scheduler = scheduler
        self.update_on = update_on  # 'epoch' or 'batch'

    def on_epoch_end(self, trainer, epoch, train_loss, val_loss):
        if self.update_on == 'epoch':
            self.scheduler.step()

    def on_batch_end(self, trainer, batch_idx, batch, loss):
        if self.update_on == 'batch':
            self.scheduler.step()


class TensorBoardLogger(Callback):
    """TensorBoard logger."""
    def __init__(self, log_dir: str = './logs'):
        from torch.utils.tensorboard import SummaryWriter
        self.writer = SummaryWriter(log_dir=log_dir)
        self.global_step = 0

    def on_batch_end(self, trainer, batch_idx, batch, loss):
        self.writer.add_scalar('train/loss', loss, self.global_step)
        self.writer.add_scalar('train/lr', trainer.optimizer.param_groups[0]['lr'], self.global_step)
        self.global_step += 1

    def on_epoch_end(self, trainer, epoch, train_loss, val_loss):
        self.writer.add_scalar('epoch/train_loss', train_loss, epoch)
        self.writer.add_scalar('epoch/val_loss', val_loss, epoch)

    def on_train_end(self, trainer):
        self.writer.close()


class SampleGeneratorCallback(Callback):
    """Sample-generation callback."""
    def __init__(
        self,
        sample_freq: int = 500,
        num_samples: int = 4,
        save_dir: str = './samples'
    ):
        self.sample_freq = sample_freq
        self.num_samples = num_samples
        self.save_dir = save_dir

        os.makedirs(save_dir, exist_ok=True)

    def on_batch_end(self, trainer, batch_idx, batch, loss):
        if trainer.global_step % self.sample_freq != 0:
            return

        # Generate samples
        trainer.model.eval()

        with torch.no_grad():
            # Use prompts from the validation set
            sample_batch = next(iter(trainer.val_loader))
            text_embeddings = sample_batch['text_embeddings'][:self.num_samples].to(trainer.device)

            # Generate latents
            latents = trainer.diffusion.generate(
                context=text_embeddings,
                num_samples=self.num_samples,
                guidance_scale=7.5
            )

            # Save the samples (raw latent tensors, not decoded images)
            for i in range(self.num_samples):
                sample_path = os.path.join(
                    self.save_dir,
                    f'step_{trainer.global_step}_sample_{i}.pt'
                )
                torch.save(latents[i].cpu(), sample_path)

        trainer.model.train()


class MemoryMonitorCallback(Callback):
    """Memory-monitoring callback."""
    def __init__(self, monitor_freq: int = 100):
        self.monitor_freq = monitor_freq

    def on_batch_end(self, trainer, batch_idx, batch, loss):
        if trainer.global_step % self.monitor_freq == 0:
            if hasattr(trainer, 'memory_manager'):
                trainer.memory_manager.print_memory_stats()


class GradientMonitorCallback(Callback):
    """Gradient-monitoring callback."""
    def __init__(self, monitor_freq: int = 100):
        self.monitor_freq = monitor_freq

    def on_batch_end(self, trainer, batch_idx, batch, loss):
        if trainer.global_step % self.monitor_freq == 0:
            grad_norm = self._compute_gradient_norm(trainer.model)

            if hasattr(trainer, 'writer'):
                trainer.writer.add_scalar('train/grad_norm', grad_norm, trainer.global_step)

    def _compute_gradient_norm(self, model) -> float:
        # Global L2 norm over all parameter gradients
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.detach().norm(2)
                total_norm += param_norm.item() ** 2
        return total_norm ** 0.5


class CallbackHandler:
    """Dispatches each event to all registered callbacks."""
    def __init__(self):
        self.callbacks = []

    def add_callback(self, callback: Callback):
        self.callbacks.append(callback)

    def on_train_begin(self, trainer):
        for callback in self.callbacks:
            callback.on_train_begin(trainer)

    def on_train_end(self, trainer):
        for callback in self.callbacks:
            callback.on_train_end(trainer)

    def on_epoch_begin(self, trainer, epoch):
        for callback in self.callbacks:
            callback.on_epoch_begin(trainer, epoch)

    def on_epoch_end(self, trainer, epoch, train_loss, val_loss):
        for callback in self.callbacks:
            callback.on_epoch_end(trainer, epoch, train_loss, val_loss)

    def on_batch_begin(self, trainer, batch_idx, batch):
        for callback in self.callbacks:
            callback.on_batch_begin(trainer, batch_idx, batch)

    def on_batch_end(self, trainer, batch_idx, batch, loss):
        for callback in self.callbacks:
            callback.on_batch_end(trainer, batch_idx, batch, loss)

    def on_validation_begin(self, trainer):
        for callback in self.callbacks:
            callback.on_validation_begin(trainer)

    def on_validation_end(self, trainer, val_loss):
        for callback in self.callbacks:
            callback.on_validation_end(trainer, val_loss)


def create_default_callbacks(config: dict) -> CallbackHandler:
    """Create the default set of callbacks."""
    handler = CallbackHandler()

    # Model checkpointing
    checkpoint_callback = ModelCheckpoint(
        save_dir=config.get('checkpoint_dir', './checkpoints'),
        save_best_only=config.get('save_best_model', True),
        save_freq=config.get('save_checkpoint_every', 1),
        monitor='val_loss',
        mode='min'
    )
    handler.add_callback(checkpoint_callback)

    # TensorBoard logging
    if config.get('use_tensorboard', True):
        tb_logger = TensorBoardLogger(
            log_dir=config.get('log_dir', './logs')
        )
        handler.add_callback(tb_logger)

    # Sample generation
    if config.get('sample_steps', 500) > 0:
        sample_callback = SampleGeneratorCallback(
            sample_freq=config.get('sample_steps', 500),
            num_samples=4,
            save_dir=config.get('sample_dir', './samples')
        )
        handler.add_callback(sample_callback)

    # Memory monitoring
    memory_callback = MemoryMonitorCallback(
        monitor_freq=config.get('log_steps', 50)
    )
    handler.add_callback(memory_callback)

    # Gradient monitoring
    grad_callback = GradientMonitorCallback(
        monitor_freq=config.get('log_steps', 50)
    )
    handler.add_callback(grad_callback)

    # Early stopping
    if config.get('early_stopping', False):
        early_stop = EarlyStopping(
            patience=config.get('early_stopping_patience', 10),
            min_delta=config.get('early_stopping_min_delta', 1e-4)
        )
        handler.add_callback(early_stop)

    return handler
|
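The patience logic in `EarlyStopping` is easy to check by hand. The sketch below re-implements only that rule in plain Python, without a trainer or PyTorch; the names `EarlyStoppingSketch` and `first_stop_index` are hypothetical and exist only for this illustration. It assumes, as in the callback above, that a loss counts as an improvement only when it beats the best value by more than `min_delta`.

```python
class EarlyStoppingSketch:
    """Stand-alone version of the patience rule (illustrative, not the repo class)."""

    def __init__(self, patience=10, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.counter = 0
        self.should_stop = False

    def update(self, val_loss):
        """Record one validation loss and return whether training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
        if self.counter >= self.patience:
            self.should_stop = True
        return self.should_stop


def first_stop_index(losses, patience, min_delta=1e-4):
    """Return the index of the validation step that triggers a stop, or None."""
    stopper = EarlyStoppingSketch(patience=patience, min_delta=min_delta)
    for idx, loss in enumerate(losses):
        if stopper.update(loss):
            return idx
    return None


if __name__ == "__main__":
    # Improvement at steps 0 and 1, then two steps without improvement.
    print(first_stop_index([1.0, 0.9, 0.95, 0.93, 0.8], patience=2))
```

Note that the callback only sets `should_stop`; the training loop is assumed to read that flag and break, since nothing in `callbacks.py` itself interrupts training.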
src/training/memory_manager.py
ADDED
@@ -0,0 +1,245 @@
| 1 |
+
import torch
|
| 2 |
+
import gc
|
| 3 |
+
from typing import Optional
|
| 4 |
+
import psutil
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class CPUMemoryManager:
|
| 9 |
+
"""CPU memory manager."""
|
| 10 |
+
def __init__(self, warning_threshold: float = 0.9):
|
| 11 |
+
"""
|
| 12 |
+
Args:
|
| 13 |
+
warning_threshold: fraction of system memory in use (0-1) above which a warning is printed
|
| 14 |
+
"""
|
| 15 |
+
self.warning_threshold = warning_threshold
|
| 16 |
+
|
| 17 |
+
def get_memory_usage(self) -> dict:
|
| 18 |
+
"""Return process and system memory usage."""
|
| 19 |
+
process = psutil.Process(os.getpid())
|
| 20 |
+
memory_info = process.memory_info()
|
| 21 |
+
|
| 22 |
+
# Query system-wide memory information
|
| 23 |
+
system_memory = psutil.virtual_memory()
|
| 24 |
+
|
| 25 |
+
return {
|
| 26 |
+
'process_rss_mb': memory_info.rss / 1024 / 1024,
|
| 27 |
+
'process_vms_mb': memory_info.vms / 1024 / 1024,
|
| 28 |
+
'system_total_mb': system_memory.total / 1024 / 1024,
|
| 29 |
+
'system_available_mb': system_memory.available / 1024 / 1024,
|
| 30 |
+
'system_used_percent': system_memory.percent
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
def check_memory(self) -> bool:
|
| 34 |
+
"""Return True if system memory usage is below the warning threshold."""
|
| 35 |
+
memory_info = self.get_memory_usage()
|
| 36 |
+
|
| 37 |
+
if memory_info['system_used_percent'] > self.warning_threshold * 100:
|
| 38 |
+
print(f"Warning: system memory usage is too high: {memory_info['system_used_percent']:.1f}%")
|
| 39 |
+
return False
|
| 40 |
+
|
| 41 |
+
return True
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class OptimizerCPUOffload:
|
| 45 |
+
"""Offloads optimizer state to the CPU."""
|
| 46 |
+
def __init__(self, optimizer: torch.optim.Optimizer):
|
| 47 |
+
self.optimizer = optimizer
|
| 48 |
+
self.original_states = {}
|
| 49 |
+
|
| 50 |
+
def offload_to_cpu(self):
|
| 51 |
+
"""Move optimizer state tensors to the CPU."""
|
| 52 |
+
for param_group in self.optimizer.param_groups:
|
| 53 |
+
for param in param_group['params']:
|
| 54 |
+
if param in self.optimizer.state:
|
| 55 |
+
state = self.optimizer.state[param]
|
| 56 |
+
for key in list(state.keys()):
|
| 57 |
+
if torch.is_tensor(state[key]):
|
| 58 |
+
# Move to CPU and keep a reference (note: this reference keeps the original GPU tensor alive)
|
| 59 |
+
self.original_states[(param, key)] = state[key]
|
| 60 |
+
state[key] = state[key].cpu()
|
| 61 |
+
|
| 62 |
+
def load_to_gpu(self, device: torch.device):
|
| 63 |
+
"""Move optimizer state tensors back to the GPU."""
|
| 64 |
+
for param_group in self.optimizer.param_groups:
|
| 65 |
+
for param in param_group['params']:
|
| 66 |
+
if param in self.optimizer.state:
|
| 67 |
+
state = self.optimizer.state[param]
|
| 68 |
+
for key in list(state.keys()):
|
| 69 |
+
if (param, key) in self.original_states:
|
| 70 |
+
state[key] = self.original_states[(param, key)].to(device)
|
| 71 |
+
del self.original_states[(param, key)]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class ActivationCPUOffload:
|
| 75 |
+
"""Offloads activations to the CPU."""
|
| 76 |
+
def __init__(self, model: torch.nn.Module):
|
| 77 |
+
self.model = model
|
| 78 |
+
self.hooks = []
|
| 79 |
+
|
| 80 |
+
def register_hooks(self):
|
| 81 |
+
"""Register forward hooks that offload activations."""
|
| 82 |
+
def hook_fn(module, input, output):
|
| 83 |
+
if torch.is_tensor(output):
|
| 84 |
+
return output.cpu()
|
| 85 |
+
elif isinstance(output, tuple):
|
| 86 |
+
return tuple(x.cpu() if torch.is_tensor(x) else x for x in output)
|
| 87 |
+
return output
|
| 88 |
+
|
| 89 |
+
# Register a hook on every Conv2d, Linear, and GroupNorm module
|
| 90 |
+
for name, module in self.model.named_modules():
|
| 91 |
+
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm)):
|
| 92 |
+
hook = module.register_forward_hook(hook_fn)
|
| 93 |
+
self.hooks.append(hook)
|
| 94 |
+
|
| 95 |
+
def remove_hooks(self):
|
| 96 |
+
"""Remove all hooks."""
|
| 97 |
+
for hook in self.hooks:
|
| 98 |
+
hook.remove()
|
| 99 |
+
self.hooks = []
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class MemoryOptimizer:
|
| 103 |
+
"""Combined GPU and CPU memory optimizer."""
|
| 104 |
+
def __init__(self, config: dict):
|
| 105 |
+
self.config = config
|
| 106 |
+
|
| 107 |
+
# GPU memory management (thresholds converted from GB to bytes)
|
| 108 |
+
self.gpu_warning_threshold = config.get('warning_threshold_gb', 6.0) * 1024**3
|
| 109 |
+
self.gpu_critical_threshold = config.get('memory_threshold_gb', 6.5) * 1024**3
|
| 110 |
+
|
| 111 |
+
# CPU memory management
|
| 112 |
+
self.cpu_manager = CPUMemoryManager(
|
| 113 |
+
warning_threshold=config.get('cpu_warning_threshold', 0.85)
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# Cleanup frequency (in steps)
|
| 117 |
+
self.cleanup_frequency = config.get('cleanup_frequency', 100)
|
| 118 |
+
|
| 119 |
+
# State tracking
|
| 120 |
+
self.optimizer_offloader = None
|
| 121 |
+
self.activation_offloader = None
|
| 122 |
+
|
| 123 |
+
def setup_model_optimizations(self, model: torch.nn.Module, optimizer: Optional[torch.optim.Optimizer] = None):
|
| 124 |
+
"""Set up model-level memory optimizations."""
|
| 125 |
+
# Enable gradient checkpointing
|
| 126 |
+
if hasattr(model, 'enable_gradient_checkpointing'):
|
| 127 |
+
model.enable_gradient_checkpointing()
|
| 128 |
+
|
| 129 |
+
# Set up optimizer CPU offload
|
| 130 |
+
if optimizer is not None and self.config.get('optimizer_on_cpu', True):
|
| 131 |
+
self.optimizer_offloader = OptimizerCPUOffload(optimizer)
|
| 132 |
+
|
| 133 |
+
# Set up activation CPU offload
|
| 134 |
+
if self.config.get('cpu_offload', True):
|
| 135 |
+
self.activation_offloader = ActivationCPUOffload(model)
|
| 136 |
+
self.activation_offloader.register_hooks()
|
| 137 |
+
|
| 138 |
+
# Set up attention slicing
|
| 139 |
+
if self.config.get('attention_slicing', 'auto') == 'auto':
|
| 140 |
+
self._enable_attention_slicing(model)
|
| 141 |
+
|
| 142 |
+
def _enable_attention_slicing(self, model: torch.nn.Module):
|
| 143 |
+
"""Enable attention slicing."""
|
| 144 |
+
for module in model.modules():
|
| 145 |
+
if hasattr(module, 'set_attention_slice'):
|
| 146 |
+
module.set_attention_slice('auto')
|
| 147 |
+
|
| 148 |
+
def step_start(self):
|
| 149 |
+
"""Memory management at the start of a training step."""
|
| 150 |
+
# Load optimizer state back onto the GPU (if needed)
|
| 151 |
+
if self.optimizer_offloader is not None:
|
| 152 |
+
device = self.optimizer_offloader.optimizer.param_groups[0]['params'][0].device
|
| 153 |
+
self.optimizer_offloader.load_to_gpu(device)
|
| 154 |
+
|
| 155 |
+
# Check memory
|
| 156 |
+
self.check_all_memory()
|
| 157 |
+
|
| 158 |
+
def step_end(self, step: int):
|
| 159 |
+
"""Memory management at the end of a training step."""
|
| 160 |
+
# Offload optimizer state to the CPU
|
| 161 |
+
if self.optimizer_offloader is not None:
|
| 162 |
+
self.optimizer_offloader.offload_to_cpu()
|
| 163 |
+
|
| 164 |
+
# Periodic cleanup
|
| 165 |
+
if step % self.cleanup_frequency == 0:
|
| 166 |
+
self.cleanup()
|
| 167 |
+
|
| 168 |
+
# Check memory
|
| 169 |
+
self.check_all_memory()
|
| 170 |
+
|
| 171 |
+
def check_all_memory(self):
|
| 172 |
+
"""Check both GPU and CPU memory."""
|
| 173 |
+
# Check GPU memory
|
| 174 |
+
gpu_allocated = torch.cuda.memory_allocated()
|
| 175 |
+
if gpu_allocated > self.gpu_critical_threshold:
|
| 176 |
+
self._handle_gpu_oom()
|
| 177 |
+
elif gpu_allocated > self.gpu_warning_threshold:
|
| 178 |
+
print(f"GPU memory warning: {gpu_allocated / 1024**3:.2f} GB")
|
| 179 |
+
|
| 180 |
+
# Check CPU memory
|
| 181 |
+
if not self.cpu_manager.check_memory():
|
| 182 |
+
self._handle_cpu_oom()
|
| 183 |
+
|
| 184 |
+
def _handle_gpu_oom(self):
|
| 185 |
+
"""Handle a GPU out-of-memory condition."""
|
| 186 |
+
print("GPU memory is running low, attempting cleanup...")
|
| 187 |
+
self.cleanup(force=True)
|
| 188 |
+
|
| 189 |
+
# If usage is still above the critical threshold, raise an exception
|
| 190 |
+
if torch.cuda.memory_allocated() > self.gpu_critical_threshold:
|
| 191 |
+
raise RuntimeError("Insufficient GPU memory; cannot continue training")
|
| 192 |
+
|
| 193 |
+
def _handle_cpu_oom(self):
|
| 194 |
+
"""Handle a CPU out-of-memory condition."""
|
| 195 |
+
print("CPU memory is running low, attempting cleanup...")
|
| 196 |
+
gc.collect()
|
| 197 |
+
|
| 198 |
+
def cleanup(self, force: bool = False):
|
| 199 |
+
"""Free cached memory."""
|
| 200 |
+
gc.collect()
|
| 201 |
+
|
| 202 |
+
if torch.cuda.is_available():
|
| 203 |
+
torch.cuda.empty_cache()
|
| 204 |
+
|
| 205 |
+
# When forced, attempt a more aggressive cleanup
|
| 206 |
+
if force:
|
| 207 |
+
torch.cuda.synchronize()
|
| 208 |
+
torch.cuda.ipc_collect()
|
| 209 |
+
|
| 210 |
+
def get_memory_stats(self) -> dict:
|
| 211 |
+
"""Return memory statistics."""
|
| 212 |
+
stats = {}
|
| 213 |
+
|
| 214 |
+
# GPU statistics
|
| 215 |
+
if torch.cuda.is_available():
|
| 216 |
+
stats['gpu'] = {
|
| 217 |
+
'allocated_gb': torch.cuda.memory_allocated() / 1024**3,
|
| 218 |
+
'reserved_gb': torch.cuda.memory_reserved() / 1024**3,
|
| 219 |
+
'max_allocated_gb': torch.cuda.max_memory_allocated() / 1024**3,
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
# CPU statistics
|
| 223 |
+
cpu_stats = self.cpu_manager.get_memory_usage()
|
| 224 |
+
stats['cpu'] = cpu_stats
|
| 225 |
+
|
| 226 |
+
return stats
|
| 227 |
+
|
| 228 |
+
def print_memory_stats(self):
|
| 229 |
+
"""Print memory statistics."""
|
| 230 |
+
stats = self.get_memory_stats()
|
| 231 |
+
|
| 232 |
+
print("=" * 50)
|
| 233 |
+
print("Memory usage statistics:")
|
| 234 |
+
|
| 235 |
+
if 'gpu' in stats:
|
| 236 |
+
gpu = stats['gpu']
|
| 237 |
+
print(f"GPU - allocated: {gpu['allocated_gb']:.2f} GB, "
|
| 238 |
+
f"reserved: {gpu['reserved_gb']:.2f} GB, "
|
| 239 |
+
f"peak allocated: {gpu['max_allocated_gb']:.2f} GB")
|
| 240 |
+
|
| 241 |
+
if 'cpu' in stats:
|
| 242 |
+
cpu = stats['cpu']
|
| 243 |
+
print(f"CPU - process RSS: {cpu['process_rss_mb']:.1f} MB, "
|
| 244 |
+
f"system usage: {cpu['system_used_percent']:.1f}%")
|
| 245 |
+
print("=" * 50)
|
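`MemoryOptimizer` converts its GB thresholds to bytes and compares them with the allocated byte count, while `CPUMemoryManager` compares a percentage with a 0-1 fraction. The decision logic can be sketched without a GPU. `classify_gpu_usage` and `cpu_usage_is_safe` are hypothetical helper names; the 6.0 GB, 6.5 GB, and 0.85 defaults are the ones that appear in the diff.

```python
GIB = 1024 ** 3  # the diff multiplies GB values by 1024**3


def classify_gpu_usage(allocated_bytes: int,
                       warning_gb: float = 6.0,
                       critical_gb: float = 6.5) -> str:
    """Return 'critical', 'warning', or 'ok' using strict '>' comparisons,
    in the same order as check_all_memory (critical is tested first)."""
    if allocated_bytes > critical_gb * GIB:
        return "critical"
    if allocated_bytes > warning_gb * GIB:
        return "warning"
    return "ok"


def cpu_usage_is_safe(used_percent: float, warning_threshold: float = 0.85) -> bool:
    """used_percent is on a 0-100 scale, the threshold is a 0-1 fraction,
    so the threshold is scaled by 100 before comparing."""
    return not (used_percent > warning_threshold * 100)
```

Keeping the comparisons in pure functions like these makes the thresholds easy to unit-test on a machine with no CUDA device.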
src/training/trainer_p4.py
ADDED
|
@@ -0,0 +1,378 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
from typing import Optional, Dict, Any
|
| 6 |
+
import wandb
|
| 7 |
+
import os
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MemoryManager:
|
| 13 |
+
"""GPU memory manager for the P4."""
|
| 14 |
+
def __init__(self, config: dict):
|
| 15 |
+
self.config = config
|
| 16 |
+
self.warning_threshold = config.get('warning_threshold_gb', 6.0) * 1024**3
|
| 17 |
+
self.critical_threshold = config.get('memory_threshold_gb', 6.5) * 1024**3
|
| 18 |
+
self.cleanup_frequency = config.get('cleanup_frequency', 100)
|
| 19 |
+
|
| 20 |
+
def check_memory(self, step: int):
|
| 21 |
+
"""Check GPU memory usage."""
|
| 22 |
+
if step % self.cleanup_frequency == 0:
|
| 23 |
+
self.cleanup()
|
| 24 |
+
|
| 25 |
+
allocated = torch.cuda.memory_allocated()
|
| 26 |
+
if allocated > self.critical_threshold:
|
| 27 |
+
raise RuntimeError(f"GPU memory exceeded the critical threshold: {allocated / 1024**3:.2f} GB")
|
| 28 |
+
elif allocated > self.warning_threshold:
|
| 29 |
+
print(f"Warning: GPU memory usage is high: {allocated / 1024**3:.2f} GB")
|
| 30 |
+
|
| 31 |
+
def cleanup(self):
|
| 32 |
+
"""Free cached GPU memory."""
|
| 33 |
+
import gc
|
| 34 |
+
gc.collect()
|
| 35 |
+
torch.cuda.empty_cache()
|
| 36 |
+
|
| 37 |
+
def auto_adjust_batch_size(self, model: nn.Module, data_shape: tuple) -> int:
|
| 38 |
+
"""Find the largest probed batch size that fits in GPU memory."""
|
| 39 |
+
max_batch = 1
|
| 40 |
+
device = next(model.parameters()).device
|
| 41 |
+
|
| 42 |
+
for batch_size in [1, 2, 4, 8]:
|
| 43 |
+
try:
|
| 44 |
+
# Probe memory with dummy inputs
|
| 45 |
+
dummy_input = torch.randn(batch_size, *data_shape, device=device)
|
| 46 |
+
dummy_timestep = torch.randint(0, 1000, (batch_size,), device=device)
|
| 47 |
+
dummy_context = torch.randn(batch_size, 77, 768, device=device)
|
| 48 |
+
|
| 49 |
+
with torch.no_grad():
|
| 50 |
+
_ = model(dummy_input, dummy_timestep, dummy_context)
|
| 51 |
+
|
| 52 |
+
torch.cuda.empty_cache()
|
| 53 |
+
max_batch = batch_size
|
| 54 |
+
except RuntimeError as e:
|
| 55 |
+
if "CUDA out of memory" in str(e):
|
| 56 |
+
break
|
| 57 |
+
else:
|
| 58 |
+
raise e
|
| 59 |
+
|
| 60 |
+
return max_batch
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class GradientAccumulationScheduler:
|
| 64 |
+
"""Gradient accumulation scheduler."""
|
| 65 |
+
def __init__(self, config: dict):
|
| 66 |
+
self.initial_steps = config.get('gradient_accumulation_steps', 8)
|
| 67 |
+
self.current_steps = self.initial_steps
|
| 68 |
+
self.warmup_epochs = config.get('warmup_epochs', 5)
|
| 69 |
+
|
| 70 |
+
def update(self, epoch: int):
|
| 71 |
+
"""Update the number of accumulation steps for the given epoch."""
|
| 72 |
+
if epoch < self.warmup_epochs:
|
| 73 |
+
self.current_steps = self.initial_steps
|
| 74 |
+
else:
|
| 75 |
+
# After warmup, halve the accumulation steps (never below 4) to speed up training
|
| 76 |
+
self.current_steps = max(4, self.current_steps // 2)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class P4Trainer:
|
| 80 |
+
"""Trainer optimized for the P4."""
|
| 81 |
+
def __init__(
|
| 82 |
+
self,
|
| 83 |
+
model: nn.Module,
|
| 84 |
+
diffusion: "DiffusionProcess",  # defined in src.models.diffusion; quoted because it is not imported in this file
|
| 85 |
+
optimizer: torch.optim.Optimizer,
|
| 86 |
+
train_loader: DataLoader,
|
| 87 |
+
val_loader: Optional[DataLoader],
|
| 88 |
+
config: dict,
|
| 89 |
+
device: torch.device
|
| 90 |
+
):
|
| 91 |
+
self.model = model
|
| 92 |
+
self.diffusion = diffusion
|
| 93 |
+
self.optimizer = optimizer
|
| 94 |
+
self.train_loader = train_loader
|
| 95 |
+
self.val_loader = val_loader
|
| 96 |
+
self.config = config
|
| 97 |
+
self.device = device
|
| 98 |
+
|
| 99 |
+
# Training state
|
| 100 |
+
self.current_epoch = 0
|
| 101 |
+
self.global_step = 0
|
| 102 |
+
self.best_loss = float('inf')
|
| 103 |
+
|
| 104 |
+
# Initialize helpers
|
| 105 |
+
self.memory_manager = MemoryManager(config)
|
| 106 |
+
self.grad_scheduler = GradientAccumulationScheduler(config)
|
| 107 |
+
|
| 108 |
+
# Mixed-precision training
|
| 109 |
+
self.use_amp = config.get('mixed_precision', 'fp16') != 'no'
|
| 110 |
+
self.scaler = GradScaler(enabled=self.use_amp)
|
| 111 |
+
|
| 112 |
+
# Learning-rate scheduler
|
| 113 |
+
self.lr_scheduler = self._create_lr_scheduler(config)
|
| 114 |
+
|
| 115 |
+
# EMA model
|
| 116 |
+
self.use_ema = config.get('use_ema', True)
|
| 117 |
+
if self.use_ema:
|
| 118 |
+
self.ema_model = self._create_ema_model(model, config.get('ema_decay', 0.9999))
|
| 119 |
+
|
| 120 |
+
# Logging
|
| 121 |
+
self.use_wandb = config.get('use_wandb', False)
|
| 122 |
+
self.log_dir = config.get('log_dir', './logs')
|
| 123 |
+
os.makedirs(self.log_dir, exist_ok=True)
|
| 124 |
+
|
| 125 |
+
# Checkpoints
|
| 126 |
+
self.checkpoint_dir = config.get('checkpoint_dir', './checkpoints')
|
| 127 |
+
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
| 128 |
+
|
| 129 |
+
def _create_lr_scheduler(self, config: dict):
|
| 130 |
+
"""Create the learning-rate scheduler."""
|
| 131 |
+
scheduler_type = config.get('learning_rate_scheduler', 'cosine')
|
| 132 |
+
warmup_steps = config.get('warmup_steps', 1000)
|
| 133 |
+
|
| 134 |
+
if scheduler_type == 'cosine':
|
| 135 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
| 136 |
+
self.optimizer,
|
| 137 |
+
T_max=config.get('max_epochs', 50) * len(self.train_loader),
|
| 138 |
+
eta_min=1e-6
|
| 139 |
+
)
|
| 140 |
+
elif scheduler_type == 'linear':
|
| 141 |
+
scheduler = torch.optim.lr_scheduler.LinearLR(
|
| 142 |
+
self.optimizer,
|
| 143 |
+
start_factor=0.01,
|
| 144 |
+
total_iters=warmup_steps
|
| 145 |
+
)
|
| 146 |
+
else:
|
| 147 |
+
scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1.0)
|
| 148 |
+
|
| 149 |
+
return scheduler
|
| 150 |
+
|
| 151 |
+
def _create_ema_model(self, model: nn.Module, decay: float):
|
| 152 |
+
"""Create the EMA model."""
|
| 153 |
+
from torch.optim.swa_utils import AveragedModel
|
| 154 |
+
return AveragedModel(model, device=self.device, avg_fn=lambda avg, new, num_averaged: decay * avg + (1 - decay) * new)
|
| 155 |
+
|
| 156 |
+
def train_epoch(self) -> float:
|
| 157 |
+
"""Train for one epoch."""
|
| 158 |
+
self.model.train()
|
| 159 |
+
total_loss = 0.0
|
| 160 |
+
num_batches = len(self.train_loader)
|
| 161 |
+
|
| 162 |
+
# Gradient accumulation
|
| 163 |
+
accumulation_steps = self.grad_scheduler.current_steps
|
| 164 |
+
self.optimizer.zero_grad()
|
| 165 |
+
|
| 166 |
+
pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch}")
|
| 167 |
+
for batch_idx, batch in enumerate(pbar):
|
| 168 |
+
# Move the batch to the device
|
| 169 |
+
images = batch['images'].to(self.device)
|
| 170 |
+
text_embeddings = batch['text_embeddings'].to(self.device)
|
| 171 |
+
|
| 172 |
+
# Mixed-precision forward pass
|
| 173 |
+
with autocast(enabled=self.use_amp):
|
| 174 |
+
loss = self.diffusion.compute_loss(images, text_embeddings)
|
| 175 |
+
loss = loss / accumulation_steps
|
| 176 |
+
|
| 177 |
+
# Backward pass
|
| 178 |
+
self.scaler.scale(loss).backward()
|
| 179 |
+
|
| 180 |
+
# Step the optimizer once per accumulation window
|
| 181 |
+
if (batch_idx + 1) % accumulation_steps == 0:
|
| 182 |
+
# Gradient clipping (unscale first so the norm is measured on the true gradients)
|
| 183 |
+
self.scaler.unscale_(self.optimizer)
|
| 184 |
+
torch.nn.utils.clip_grad_norm_(
|
| 185 |
+
self.model.parameters(),
|
| 186 |
+
max_norm=self.config.get('gradient_clip', 1.0)
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# Update parameters
|
| 190 |
+
self.scaler.step(self.optimizer)
|
| 191 |
+
self.scaler.update()
|
| 192 |
+
self.optimizer.zero_grad()
|
| 193 |
+
|
| 194 |
+
# Update the EMA model
|
| 195 |
+
if self.use_ema:
|
| 196 |
+
self.ema_model.update_parameters(self.model)
|
| 197 |
+
|
| 198 |
+
# Update the learning rate
|
| 199 |
+
self.lr_scheduler.step()
|
| 200 |
+
|
| 201 |
+
self.global_step += 1
|
| 202 |
+
|
| 203 |
+
# Track the loss (undo the accumulation scaling)
|
| 204 |
+
total_loss += loss.item() * accumulation_steps
|
| 205 |
+
current_loss = total_loss / (batch_idx + 1)
|
| 206 |
+
|
| 207 |
+
# Update the progress bar
|
| 208 |
+
pbar.set_postfix({
|
| 209 |
+
'loss': f'{current_loss:.4f}',
|
| 210 |
+
'lr': f'{self.optimizer.param_groups[0]["lr"]:.2e}'
|
| 211 |
+
})
|
| 212 |
+
|
| 213 |
+
# Log metrics
|
| 214 |
+
if self.global_step % self.config.get('log_steps', 50) == 0:
|
| 215 |
+
self._log_metrics({
|
| 216 |
+
'train/loss': current_loss,
|
| 217 |
+
'train/lr': self.optimizer.param_groups[0]['lr'],
|
| 218 |
+
'train/grad_norm': self._get_grad_norm(),
|
| 219 |
+
})
|
| 220 |
+
|
| 221 |
+
# Generate samples
|
| 222 |
+
if self.global_step % self.config.get('sample_steps', 500) == 0:
|
| 223 |
+
self._generate_samples()
|
| 224 |
+
|
| 225 |
+
# GPU memory management
|
| 226 |
+
self.memory_manager.check_memory(self.global_step)
|
| 227 |
+
|
| 228 |
+
epoch_loss = total_loss / num_batches
|
| 229 |
+
return epoch_loss
|
| 230 |
+
|
| 231 |
+
@torch.no_grad()
|
| 232 |
+
def validate(self) -> float:
|
| 233 |
+
"""Run validation and return the mean loss."""
|
| 234 |
+
if self.val_loader is None:
|
| 235 |
+
return float('inf')
|
| 236 |
+
|
| 237 |
+
self.model.eval()
|
| 238 |
+
total_loss = 0.0
|
| 239 |
+
|
| 240 |
+
for batch in tqdm(self.val_loader, desc="Validation"):
|
| 241 |
+
images = batch['images'].to(self.device)
|
| 242 |
+
text_embeddings = batch['text_embeddings'].to(self.device)
|
| 243 |
+
|
| 244 |
+
with autocast(enabled=self.use_amp):
|
| 245 |
+
loss = self.diffusion.compute_loss(images, text_embeddings)
|
| 246 |
+
|
| 247 |
+
total_loss += loss.item()
|
| 248 |
+
|
| 249 |
+
val_loss = total_loss / len(self.val_loader)
|
| 250 |
+
|
| 251 |
+
# Log validation metrics
|
| 252 |
+
self._log_metrics({'val/loss': val_loss})
|
| 253 |
+
|
| 254 |
+
return val_loss
|
| 255 |
+
|
| 256 |
+
def train(self, num_epochs: Optional[int] = None):
|
| 257 |
+
"""Main training loop."""
|
| 258 |
+
if num_epochs is None:
|
| 259 |
+
num_epochs = self.config.get('max_epochs', 50)
|
| 260 |
+
|
| 261 |
+
for epoch in range(self.current_epoch, num_epochs):
|
| 262 |
+
self.current_epoch = epoch
|
| 263 |
+
|
| 264 |
+
# Update the gradient accumulation schedule
|
| 265 |
+
self.grad_scheduler.update(epoch)
|
| 266 |
+
|
| 267 |
+
# Train for one epoch
|
| 268 |
+
train_loss = self.train_epoch()
|
| 269 |
+
|
| 270 |
+
# Validate
|
| 271 |
+
val_loss = self.validate()
|
| 272 |
+
|
| 273 |
+
# Save the best model
|
| 274 |
+
if val_loss < self.best_loss:
|
| 275 |
+
self.best_loss = val_loss
|
| 276 |
+
self.save_checkpoint('best_model.pt')
|
| 277 |
+
|
| 278 |
+
# Save periodic checkpoints
|
| 279 |
+
if (epoch + 1) % self.config.get('save_checkpoint_every', 5) == 0:
|
| 280 |
+
self.save_checkpoint(f'checkpoint_epoch_{epoch+1}.pt')
|
| 281 |
+
|
| 282 |
+
# Log epoch metrics
|
| 283 |
+
self._log_metrics({
|
| 284 |
+
'epoch/train_loss': train_loss,
|
| 285 |
+
'epoch/val_loss': val_loss,
|
| 286 |
+
'epoch': epoch
|
| 287 |
+
})
|
| 288 |
+
|
| 289 |
+
print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
|
| 290 |
+
|
| 291 |
+
def _get_grad_norm(self) -> float:
|
| 292 |
+
"""Compute the global L2 norm of the gradients."""
|
| 293 |
+
total_norm = 0.0
|
| 294 |
+
for p in self.model.parameters():
|
| 295 |
+
if p.grad is not None:
|
| 296 |
+
param_norm = p.grad.data.norm(2)
|
| 297 |
+
total_norm += param_norm.item() ** 2
|
| 298 |
+
return total_norm ** 0.5
|
| 299 |
+
|
| 300 |
+
@torch.no_grad()
|
| 301 |
+
def _generate_samples(self):
|
| 302 |
+
"""Generate samples for monitoring."""
|
| 303 |
+
self.model.eval()
|
| 304 |
+
|
| 305 |
+
# Use the first few prompts from the validation set
|
| 306 |
+
sample_batch = next(iter(self.val_loader))
|
| 307 |
+
text_embeddings = sample_batch['text_embeddings'][:4].to(self.device)
|
| 308 |
+
|
| 309 |
+
# Generate samples
|
| 310 |
+
with autocast(enabled=self.use_amp):
|
| 311 |
+
latents = self.diffusion.generate(
|
| 312 |
+
context=text_embeddings,
|
| 313 |
+
num_samples=4,
|
| 314 |
+
guidance_scale=7.5
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
# Decode to images
|
| 318 |
+
# A VAE decoder is needed here; for now, log the latents directly
|
| 319 |
+
if self.use_wandb:
|
| 320 |
+
wandb.log({
|
| 321 |
+
'samples/latents': wandb.Image(latents[0].cpu().numpy())
|
| 322 |
+
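})
# Restore training mode (at method level, outside the wandb branch): this method runs mid-epoch from train_epoch
self.model.train()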
|
| 323 |
+
|
| 324 |
+
def _log_metrics(self, metrics: Dict[str, Any]):
|
| 325 |
+
"""Log metrics."""
|
| 326 |
+
if self.use_wandb:
|
| 327 |
+
wandb.log(metrics)
|
| 328 |
+
|
| 329 |
+
# Also append to a local CSV file
|
| 330 |
+
log_file = os.path.join(self.log_dir, 'training_log.csv')
|
| 331 |
+
with open(log_file, 'a') as f:
|
| 332 |
+
if os.path.getsize(log_file) == 0:  # write the header only when the file is still empty
|
| 333 |
+
header = ','.join(['step'] + list(metrics.keys()))
|
| 334 |
+
f.write(header + '\n')
|
| 335 |
+
|
| 336 |
+
values = ','.join([str(self.global_step)] + [str(v) for v in metrics.values()])
|
| 337 |
+
f.write(values + '\n')
|
| 338 |
+
|
| 339 |
+
def save_checkpoint(self, filename: str):
|
| 340 |
+
"""Save a checkpoint."""
|
| 341 |
+
checkpoint = {
|
| 342 |
+
'epoch': self.current_epoch,
|
| 343 |
+
'global_step': self.global_step,
|
| 344 |
+
'model_state_dict': self.model.state_dict(),
|
| 345 |
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
| 346 |
+
'scaler_state_dict': self.scaler.state_dict(),
|
| 347 |
+
'best_loss': self.best_loss,
|
| 348 |
+
'config': self.config
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
if self.use_ema:
|
| 352 |
+
checkpoint['ema_model_state_dict'] = self.ema_model.state_dict()
|
| 353 |
+
|
| 354 |
+
save_path = os.path.join(self.checkpoint_dir, filename)
|
| 355 |
+
torch.save(checkpoint, save_path)
|
| 356 |
+
|
| 357 |
+
# If enabled, save a second copy (the zipfile format is already PyTorch's default and does not compress)
|
| 358 |
+
if self.config.get('save_compressed', True):
|
| 359 |
+
torch.save(checkpoint, save_path.replace('.pt', '_compressed.pt'), _use_new_zipfile_serialization=True)
|
| 360 |
+
|
| 361 |
+
print(f"Checkpoint saved: {save_path}")
|
| 362 |
+
|
| 363 |
+
def load_checkpoint(self, checkpoint_path: str):
|
| 364 |
+
"""Load a checkpoint."""
|
| 365 |
+
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
| 366 |
+
|
| 367 |
+
self.model.load_state_dict(checkpoint['model_state_dict'])
|
| 368 |
+
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 369 |
+
self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
|
| 370 |
+
|
| 371 |
+
if self.use_ema and 'ema_model_state_dict' in checkpoint:
|
| 372 |
+
self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
| 373 |
+
|
| 374 |
+
self.current_epoch = checkpoint['epoch']
|
| 375 |
+
self.global_step = checkpoint['global_step']
|
| 376 |
+
self.best_loss = checkpoint['best_loss']
|
| 377 |
+
|
| 378 |
+
print(f"Checkpoint loaded: {checkpoint_path}")
|
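`train_epoch` steps the optimizer whenever `(batch_idx + 1) % accumulation_steps == 0`. When the number of batches is not a multiple of the window, the trailing gradients are never applied by the loop as written, and they appear to be discarded by the `zero_grad()` call at the start of the next epoch. The sketch below lists the batch indices at which a step happens. The function name and the `flush_last` option, which adds a step for the final partial window, are assumptions for illustration and not something the trainer does.

```python
from typing import List


def optimizer_step_indices(num_batches: int,
                           accumulation_steps: int,
                           flush_last: bool = True) -> List[int]:
    """Return the batch indices after which the optimizer would step."""
    if num_batches < 0 or accumulation_steps < 1:
        raise ValueError("num_batches must be >= 0 and accumulation_steps >= 1")

    # Same condition as the trainer: step at the end of every full window.
    steps = [i for i in range(num_batches) if (i + 1) % accumulation_steps == 0]

    # Optionally step once more so a trailing partial window is not lost.
    if flush_last and num_batches > 0 and num_batches % accumulation_steps != 0:
        steps.append(num_batches - 1)
    return steps
```

With 10 batches and a window of 4, the trainer's condition fires after batches 3 and 7 only, so the last two batches contribute nothing unless the partial window is flushed. If a flush is added, the loss scaling for that window would also need to account for its smaller size.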
tests/test_basic.py
ADDED
|
@@ -0,0 +1,250 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Basic tests
|
| 4 |
+
Tests the basic functionality of the project
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
# Add the project root to the Python path
|
| 14 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 15 |
+
|
| 16 |
+
from src.models.unet_light import UNetLight, TimestepEmbedding, ResNetBlock
|
| 17 |
+
from src.models.attention import MemoryEfficientAttention
|
| 18 |
+
from src.models.diffusion import DiffusionProcess
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def test_timestep_embedding():
|
| 22 |
+
"""Test the timestep embedding."""
|
| 23 |
+
print("Testing timestep embedding...")
|
| 24 |
+
|
| 25 |
+
embedding_dim = 256
|
| 26 |
+
time_embed_dim = 512
|
| 27 |
+
|
| 28 |
+
embedder = TimestepEmbedding(embedding_dim, time_embed_dim)
|
| 29 |
+
|
| 30 |
+
# Test the forward pass
|
| 31 |
+
timesteps = torch.tensor([100, 200, 300])
|
| 32 |
+
embeddings = embedder(timesteps)
|
| 33 |
+
|
| 34 |
+
assert embeddings.shape == (3, time_embed_dim)
|
| 35 |
+
print(f" Shape is correct: {embeddings.shape}")
|
| 36 |
+
|
| 37 |
+
return embedder
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_resnet_block():
|
| 41 |
+
"""Test the residual block."""
|
| 42 |
+
print("Testing residual block...")
|
| 43 |
+
|
| 44 |
+
in_channels = 64
|
| 45 |
+
out_channels = 128
|
| 46 |
+
time_embed_dim = 256
|
| 47 |
+
|
| 48 |
+
block = ResNetBlock(in_channels, out_channels, time_embed_dim)
|
| 49 |
+
|
| 50 |
+
# Test the forward pass
|
| 51 |
+
x = torch.randn(2, in_channels, 32, 32)
|
| 52 |
+
time_emb = torch.randn(2, time_embed_dim)
|
| 53 |
+
|
| 54 |
+
output = block(x, time_emb)
|
| 55 |
+
|
| 56 |
+
assert output.shape == (2, out_channels, 32, 32)
|
| 57 |
+
print(f" Shape is correct: {output.shape}")
|
| 58 |
+
|
| 59 |
+
# Test the skip connection
|
| 60 |
+
block_same = ResNetBlock(in_channels, in_channels, time_embed_dim)
|
| 61 |
+
output_same = block_same(x, time_emb)
|
| 62 |
+
|
| 63 |
+
assert output_same.shape == x.shape
|
| 64 |
+
print(" Skip connection is correct")
|
| 65 |
+
|
| 66 |
+
return block
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def test_attention():
|
| 70 |
+
"""Test the attention mechanism."""
|
| 71 |
+
print("Testing attention mechanism...")
|
| 72 |
+
|
| 73 |
+
dim = 256
|
| 74 |
+
num_heads = 8
|
| 75 |
+
|
| 76 |
+
attention = MemoryEfficientAttention(dim, num_heads)
|
| 77 |
+
|
| 78 |
+
# Test the forward pass
|
| 79 |
+
x = torch.randn(2, 16, dim) # [batch, seq_len, dim]
|
| 80 |
+
output = attention(x)
|
| 81 |
+
|
| 82 |
+
assert output.shape == x.shape
|
| 83 |
+
print(f" Shape is correct: {output.shape}")
|
| 84 |
+
|
| 85 |
+
return attention
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def test_unet_light():
|
| 89 |
+
"""Test the lightweight UNet."""
|
| 90 |
+
print("Testing lightweight UNet...")
|
| 91 |
+
|
| 92 |
+
config = {
|
| 93 |
+
'model': {
|
| 94 |
+
'in_channels': 4,
|
| 95 |
+
'out_channels': 4,
|
| 96 |
+
'base_channels': 32,  # small model for testing
|
| 97 |
+
'channel_mults': [1, 2, 4],
|
| 98 |
+
'num_res_blocks': 1,
|
| 99 |
+
'attention_resolutions': [8],
|
| 100 |
+
'dropout': 0.0,
|
| 101 |
+
'use_checkpoint': False,
|
| 102 |
+
'num_heads': 4,
|
| 103 |
+
'context_dim': 256,
|
| 104 |
+
'use_linear_projection': True,
|
| 105 |
+
'time_embed_dim': 128
|
| 106 |
+
}
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
model = UNetLight(config)
|
| 110 |
+
|
| 111 |
+
# Test the forward pass
|
| 112 |
+
batch_size = 2
|
| 113 |
+
x = torch.randn(batch_size, 4, 64, 64)
|
| 114 |
+
timesteps = torch.randint(0, 1000, (batch_size,))
|
| 115 |
+
context = torch.randn(batch_size, 77, 256)
|
| 116 |
+
|
| 117 |
+
output = model(x, timesteps, context)
|
| 118 |
+
|
| 119 |
+
assert output.shape == x.shape
|
| 120 |
+
print(f" Shape is correct: {output.shape}")
|
| 121 |
+
|
| 122 |
+
# Test gradient checkpointing
|
| 123 |
+
model.enable_gradient_checkpointing()
|
| 124 |
+
print(" Gradient checkpointing enabled")
|
| 125 |
+
|
| 126 |
+
return model
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def test_diffusion_process():
|
| 130 |
+
"""Test the diffusion process."""
|
| 131 |
+
print("Testing diffusion process...")
|
| 132 |
+
|
| 133 |
+
config = {
|
| 134 |
+
'diffusion': {
|
| 135 |
+
'beta_schedule': 'linear',
|
| 136 |
+
'beta_start': 0.0001,
|
| 137 |
+
'beta_end': 0.02,
|
| 138 |
+
'num_train_timesteps': 100,
|
| 139 |
+
'num_inference_timesteps': 20
|
| 140 |
+
}
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
diffusion = DiffusionProcess(config)
|
| 144 |
+
|
| 145 |
+
# Test forward diffusion
|
| 146 |
+
x_start = torch.randn(2, 3, 32, 32)
|
| 147 |
+
t = torch.randint(0, 100, (2,))
|
| 148 |
+
|
| 149 |
+
x_noisy = diffusion.q_sample(x_start, t)
|
| 150 |
+
|
| 151 |
+
assert x_noisy.shape == x_start.shape
|
| 152 |
+
print(f" Forward diffusion shape is correct: {x_noisy.shape}")
|
| 153 |
+
|
| 154 |
+
# Test coefficient extraction
|
| 155 |
+
extracted = diffusion.extract(diffusion.sqrt_alphas_cumprod, t, x_start.shape)
|
| 156 |
+
assert extracted.shape == (2, 1, 1, 1)
|
| 157 |
+
print(f" Parameter extraction shape correct: {extracted.shape}")
|
| 158 |
+
|
| 159 |
+
return diffusion
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def test_memory_efficiency():
|
| 163 |
+
"""Test memory efficiency."""
|
| 164 |
+
print("Testing memory efficiency...")
|
| 165 |
+
|
| 166 |
+
# Measure the model's memory usage at different batch sizes
|
| 167 |
+
config = {
|
| 168 |
+
'model': {
|
| 169 |
+
'in_channels': 4,
|
| 170 |
+
'out_channels': 4,
|
| 171 |
+
'base_channels': 32,
|
| 172 |
+
'channel_mults': [1, 2],
|
| 173 |
+
'num_res_blocks': 1,
|
| 174 |
+
'attention_resolutions': [],
|
| 175 |
+
'dropout': 0.0,
|
| 176 |
+
'use_checkpoint': False,
|
| 177 |
+
'num_heads': 4,
|
| 178 |
+
'context_dim': 256,
|
| 179 |
+
'use_linear_projection': True,
|
| 180 |
+
'time_embed_dim': 128
|
| 181 |
+
}
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
model = UNetLight(config)
|
| 185 |
+
model.eval()
|
| 186 |
+
|
| 187 |
+
if torch.cuda.is_available():
|
| 188 |
+
device = torch.device('cuda')
|
| 189 |
+
model = model.to(device)
|
| 190 |
+
|
| 191 |
+
print(" GPU memory test:")
|
| 192 |
+
|
| 193 |
+
for batch_size in [1, 2, 4]:
|
| 194 |
+
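# Clear the cache and reset the peak counter so each batch size is measured on its own
|
| 195 |
+
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()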
|
| 196 |
+
|
| 197 |
+
# Record the initial allocated memory
|
| 198 |
+
initial_memory = torch.cuda.memory_allocated()
|
| 199 |
+
|
| 200 |
+
# Forward pass
|
| 201 |
+
x = torch.randn(batch_size, 4, 64, 64, device=device)
|
| 202 |
+
t = torch.randint(0, 1000, (batch_size,), device=device)
|
| 203 |
+
context = torch.randn(batch_size, 77, 256, device=device)
|
| 204 |
+
|
| 205 |
+
with torch.no_grad():
|
| 206 |
+
_ = model(x, t, context)
|
| 207 |
+
|
| 208 |
+
# Record the peak allocated memory
|
| 209 |
+
peak_memory = torch.cuda.max_memory_allocated()
|
| 210 |
+
memory_used = (peak_memory - initial_memory) / 1024**3 # GB
|
| 211 |
+
|
| 212 |
+
print(f" Batch size {batch_size}: {memory_used:.2f} GB")
|
| 213 |
+
|
| 214 |
+
else:
|
| 215 |
+
print(" GPU not available, skipping memory test")
|
| 216 |
+
|
| 217 |
+
return model
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def run_all_tests():
|
| 221 |
+
"""Run all tests."""
|
| 222 |
+
print("=" * 60)
|
| 223 |
+
print("Running Lumina basic tests")
|
| 224 |
+
print("=" * 60)
|
| 225 |
+
|
| 226 |
+
try:
|
| 227 |
+
# Test each component in turn
|
| 228 |
+
test_timestep_embedding()
|
| 229 |
+
test_resnet_block()
|
| 230 |
+
test_attention()
|
| 231 |
+
test_unet_light()
|
| 232 |
+
test_diffusion_process()
|
| 233 |
+
test_memory_efficiency()
|
| 234 |
+
|
| 235 |
+
print("\n" + "=" * 60)
|
| 236 |
+
print("All tests passed!")
|
| 237 |
+
print("=" * 60)
|
| 238 |
+
|
| 239 |
+
return True
|
| 240 |
+
|
| 241 |
+
except Exception as e:
|
| 242 |
+
print(f"\nTest failed: {e}")
|
| 243 |
+
import traceback
|
| 244 |
+
traceback.print_exc()
|
| 245 |
+
return False
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
if __name__ == "__main__":
|
| 249 |
+
success = run_all_tests()
|
| 250 |
+
sys.exit(0 if success else 1)
|