TAI Research commited on
Commit
29691f6
·
0 Parent(s):

Initial commit: Lumina_Dev_Legacy (archived)

Browse files
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 TAI Research
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - image-generation
5
+ - legacy
6
+ - archived
7
+ - research
8
+ - pytorch
9
+ ---
10
+
11
+ # 🎨 Lumina_Dev_Legacy — [ARCHIVED]
12
+
13
+ > **⚠️ 项目状态:已归档 / 停止开发**
14
+ >
15
+ > 本项目为 TAI Research 早期探索性项目,开发阶段已终止,代码仅作存档用途。项目未完成完整训练,不提供预训练模型权重。
16
+
17
+ | 项目信息 | |
18
+ |---------|---|
19
+ | **状态** | 🔴 已废弃 (Archived) |
20
+ | **原因** | 项目方向调整,资源重新分配至其他研究领域 |
21
+ | **最后更新** | 2026 |
22
+ | **训练状态** | ❌ 未完成训练 |
23
+ | **可用资源** | 仅代码框架,无预训练权重 |
24
+
25
+ ---
26
+
27
+ ## 项目背景
28
+
29
+ Lumina 原本的目标是开发一个**针对有限硬件(特别是 NVIDIA P4,8GB 显存)优化的轻量级图像生成模型**,基于扩散模型架构,专注于文本到图像生成。
30
+
31
+ ### 原始设计目标
32
+
33
+ - **极致的显存优化**:专门为 8GB 显存的 GPU 优化
34
+ - **轻量级架构**:参数量 < 2000 万
35
+ - **完整训练管道**:从数据预处理到模型训练
36
+ - **高效推理**:支持多种采样器
37
+ - **模块化设计**:易于扩展和定制
38
+
39
+ ### 为什么废弃?
40
+
41
+ 1. **方向调整**:TAI Research 将重心转向语言模型和 AI 安全领域
42
+ 2. **资源限制**:图像生成模型训练需要大量计算资源
43
+ 3. **竞争环境**:开源社区已有成熟方案(Stable Diffusion、Flux 等)
44
+
45
+ ---
46
+
47
+ ## 代码内容
48
+
49
+ 本仓库仅包含开发阶段的代码框架:
50
+
51
+ ```
52
+ lumina_legacy/
53
+ ├── configs/ # 配置文件
54
+ │ ├── model/ # 模型配置(UNet架构)
55
+ │ ├── training/ # 训练配置(P4优化)
56
+ │ └── data/ # 数据配置
57
+ ├── src/ # 源代码
58
+ │ ├── models/ # UNet + 注意力机制
59
+ │ ├── training/ # 训练器 + 内存优化
60
+ │ ├── data/ # LAION数据处理
61
+ │ └── inference/ # 采样器(DDIM/DPM/LCM)
62
+ ├── scripts/ # 工具脚本
63
+ │ ├── train.py # 训练入口
64
+ │ ├── download_laion.py # 数据下载
65
+ │ └── webui.py # Gradio界面
66
+ └── tests/ # 单元测试
67
+ ```
68
+
69
+ **⚠️ 注意**:
70
+ - 代码未经完整测试,可能存在 bug
71
+ - 训练管道未验证
72
+ - 不保证可复现
73
+
74
+ ---
75
+
76
+ ## 技术栈
77
+
78
+ | 组件 | 技术 |
79
+ |------|------|
80
+ | 框架 | PyTorch 2.0+ |
81
+ | 架构 | 轻量级 UNet + Cross-Attention |
82
+ | 扩散 | DDPM / DDIM |
83
+ | 文本编码 | CLIP (预训练冻结) |
84
+ | 精度 | FP16 混合精度 |
85
+ | 优化 | 梯度检查点 + 梯度累积 |
86
+
87
+ ### 原始设计架构
88
+
89
+ ```
90
+ 输入 (4×64×64)
91
+
92
+ Conv2d (4→64)
93
+
94
+ [下采样块 × 4]
95
+
96
+ [注意力层 (8×8)]
97
+
98
+ [上采样块 × 4]
99
+
100
+ Conv2d (64→4)
101
+
102
+ 输出 (4×64×64)
103
+ ```
104
+
105
+ ---
106
+
107
+ ## 如何使用这些代码
108
+
109
+ > ⚠️ 代码仅为存档,不提供任何保证
110
+
111
+ ```bash
112
+ # 克隆仓库
113
+ git clone https://huggingface.co/TAI-Research/Lumina_Dev_Legacy
114
+ cd Lumina_Dev_Legacy
115
+
116
+ # 安装依赖(如需要)
117
+ pip install -r requirements.txt
118
+
119
+ # 尝试运行训练(仅测试代码是否可运行)
120
+ python scripts/train.py --config configs/training/p4_optimized.yaml --dummy
121
+ ```
122
+
123
+ ---
124
+
125
+ ## 已知问题
126
+
127
+ - [ ] 训练管道未完整验证
128
+ - [ ] 数据处理模块有性能问题
129
+ - [ ] 推理采样器可能有 bug
130
+ - [ ] 缺少单元测试覆盖
131
+
132
+ ---
133
+
134
+ ## 后续
135
+
136
+ 如果你对这个项目感兴趣,建议使用成熟的开源方案:
137
+
138
+ - [Stable Diffusion](https://github.com/CompVis/stable-diffusion)
139
+ - [Hugging Face Diffusers](https://github.com/huggingface/diffusers)
140
+ - [Flux](https://github.com/black-forest-labs/flux)
141
+
142
+ ---
143
+
144
+ ## 许可证
145
+
146
+ MIT License — 代码可自由使用,但无任何保证。
147
+
148
+ ---
149
+
150
+ ## 相关链接
151
+
152
+ - [TAI Research Hugging Face](https://huggingface.co/TAI-Research)
153
+ - [GTC-Guard-0](https://huggingface.co/TAI-Research/GTC-Guard-0)
154
+
155
+ ---
156
+
157
+ **最后更新**: 2026-05-23
158
+ **归档者**: TAI Research
SOLUTION_SUMMARY.md ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - image-generation
5
+ - legacy
6
+ - archived
7
+ - research
8
+ - pytorch
9
+ ---
10
+
11
+ # 🎨 Lumina_Dev_Legacy — [ARCHIVED]
12
+
13
+ > **⚠️ 项目状态:已归档 / 停止开发**
14
+ >
15
+ > 本项目为 TAI Research 早期探索性项目,开发阶段已终止,代码仅作存档用途。项目未完成完整训练,不提供预训练模型权重。
16
+
17
+ | 项目信息 | |
18
+ |---------|---|
19
+ | **状态** | 🔴 已废弃 (Archived) |
20
+ | **原因** | 项目方向调整,资源重新分配至其他研究领域 |
21
+ | **最后更新** | 2026 |
22
+ | **训练状态** | ❌ 未完成训练 |
23
+ | **可用资源** | 仅代码框架,无预训练权重 |
24
+
25
+ ---
26
+
27
+ ## 项目背景
28
+
29
+ Lumina 原本的目标是开发一个**针对有限硬件(特别是 NVIDIA P4,8GB 显存)优化的轻量级图像生成模型**,基于扩散模型架构,专注于文本到图像生成。
30
+
31
+ ### 原始设计目标
32
+
33
+ - **极致的显存优化**:专门为 8GB 显存的 GPU 优化
34
+ - **轻量级架构**:参数量 < 2000 万
35
+ - **完整训练管道**:从数据预处理到模型训练
36
+ - **高效推理**:支持多种采样器
37
+ - **模块化设计**:易于扩展和定制
38
+
39
+ ### 为什么废弃?
40
+
41
+ 1. **方向调整**:TAI Research 将重心转向语言模型和 AI 安全领域
42
+ 2. **资源限制**:图像生成模型训练需要大量计算资源
43
+ 3. **竞争环境**:开源社区已有成熟方案(Stable Diffusion、Flux 等)
44
+
45
+ ---
46
+
47
+ ## 代码内容
48
+
49
+ 本仓库仅包含开发阶段的代码框架:
configs/data/laion_filtered.yaml ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 数据处理配置
2
+ dataset:
3
+ name: "laion-aesthetic"
4
+ path: "./data/laion" # 数据存放路径
5
+ metadata_file: "./data/laion/metadata.parquet"
6
+
7
+ # 过滤条件
8
+ filters:
9
+ aesthetic_score: 6.0
10
+ watermark_prob: 0.5
11
+ nsfw: false
12
+
13
+ # 数据拆分
14
+ split:
15
+ train: 0.95
16
+ val: 0.05
17
+ seed: 42
18
+
19
+ # 处理配置
20
+ max_samples: 2000000 # 最大样本数
21
+ shuffle: true
22
+ shuffle_seed: 42
23
+
24
+ preprocessing:
25
+ # 图像处理
26
+ target_size: 512
27
+ resize_mode: "center_crop" # "center_crop", "random_crop", "resize"
28
+ random_crop: true
29
+ random_flip: true
30
+
31
+ # 归一化
32
+ normalize:
33
+ mean: [0.5, 0.5, 0.5]
34
+ std: [0.5, 0.5, 0.5]
35
+
36
+ # 文本处理
37
+ tokenizer: "openai/clip-vit-base-patch32"
38
+ max_length: 77
39
+ truncation: true
40
+ padding: "max_length"
41
+
42
+ # 缓存
43
+ use_cache: true
44
+ cache_dir: "./data/cache"
45
+ cache_compression: true
46
+
47
+ loader:
48
+ batch_size: 1
49
+ shuffle: true
50
+ num_workers: 2
51
+ prefetch_factor: 2
52
+ persistent_workers: true
configs/model/diffusion.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 扩散过程配置
2
+ diffusion:
3
+ beta_schedule: "scaled_linear"
4
+ beta_start: 0.00085
5
+ beta_end: 0.012
6
+ num_train_timesteps: 1000
7
+ num_inference_timesteps: 50
8
+
9
+ # 损失函数
10
+ loss_type: "l2" # 或 "l1"
11
+ snr_gamma: null
12
+
13
+ # 采样器
14
+ sampler_type: "ddim" # 或 "dpm++_2m"
15
+ prediction_type: "epsilon"
16
+
17
+ # 训练参数
18
+ vae_scale_factor: 0.18215
19
+ offset_noise_strength: 0.1
configs/model/unet_light.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 轻量UNet配置
2
+ model:
3
+ in_channels: 4 # 潜在空间通道数
4
+ out_channels: 4
5
+ base_channels: 64
6
+ channel_mults: [1, 2, 4, 8] # 4次下采样
7
+ num_res_blocks: 2
8
+ attention_resolutions: [8] # 仅在最低分辨率应用注意力
9
+ dropout: 0.0
10
+ use_checkpoint: true
11
+ num_heads: 4
12
+
13
+ # 文本条件
14
+ context_dim: 768 # CLIP文本编码维度
15
+ use_linear_projection: true
16
+
17
+ # 时间步嵌入
18
+ time_embed_dim: 256
19
+
20
+ # 优化配置
21
+ use_flash_attention: false # P4不支持,但保留选项
22
+ gradient_checkpointing: true
configs/training/p4_optimized.yaml ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # P4优化训练配置
2
+ hardware:
3
+ gpu_memory: 8 # GB
4
+ batch_size: 1 # 实际批大小
5
+ gradient_accumulation_steps: 8
6
+ num_workers: 2
7
+ pin_memory: true
8
+
9
+ # 显存优化策略
10
+ mixed_precision: "fp16"
11
+ gradient_checkpointing: true
12
+ attention_slicing: "auto"
13
+ cpu_offload: true
14
+ tiled_vae: false # 如果启用分块VAE解码
15
+
16
+ # 动态显存管理
17
+ memory_threshold_gb: 6.5
18
+ warning_threshold_gb: 6.0
19
+ cleanup_frequency: 100
20
+
21
+ training:
22
+ max_epochs: 50
23
+ learning_rate: 1e-4
24
+ learning_rate_scheduler: "cosine"
25
+ warmup_steps: 1000
26
+ weight_decay: 0.01
27
+ adam_beta1: 0.9
28
+ adam_beta2: 0.999
29
+ adam_epsilon: 1e-8
30
+
31
+ # 训练策略
32
+ gradient_clip: 1.0
33
+ use_ema: true
34
+ ema_decay: 0.9999
35
+ save_checkpoint_every: 1000
36
+ save_best_model: true
37
+
38
+ # 验证和监控
39
+ validation_steps: 500
40
+ sample_steps: 500
41
+ log_steps: 50
42
+
43
+ # 优化器状态
44
+ optimizer_on_cpu: true
45
+
46
+ data:
47
+ resolution: 512
48
+ center_crop: true
49
+ random_flip: true
50
+ cache_dataset: true
51
+
52
+ # 数据增强
53
+ augmentation:
54
+ random_crop: true
55
+ color_jitter: 0.05
56
+ random_rotation: 5.0 # 角度
57
+
58
+ logging:
59
+ use_wandb: false
60
+ use_tensorboard: true
61
+ log_dir: "./logs"
62
+ project_name: "lumina"
63
+ run_name: "lumina-v0.1"
64
+
65
+ checkpoint:
66
+ save_dir: "./checkpoints"
67
+ keep_last: 5
68
+ save_compressed: true
69
+ save_onnx: false
configs/training/schedule_256.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 256x256训练计划
2
+ phases:
3
+ - name: "phase1_warmup"
4
+ epochs: 5
5
+ resolution: 256
6
+ learning_rate: 1e-5
7
+ batch_size: 1
8
+ gradient_accumulation: 8
9
+ description: "预热阶段,低分辨率"
10
+
11
+ - name: "phase2_main"
12
+ epochs: 20
13
+ resolution: 256
14
+ learning_rate: 1e-4
15
+ batch_size: 1
16
+ gradient_accumulation: 8
17
+ description: "主训练阶段"
18
+
19
+ - name: "phase3_refine"
20
+ epochs: 10
21
+ resolution: 256
22
+ learning_rate: 5e-5
23
+ batch_size: 1
24
+ gradient_accumulation: 8
25
+ description: "精细调优"
26
+
27
+ - name: "phase4_upscale"
28
+ epochs: 15
29
+ resolution: 512
30
+ learning_rate: 2e-5
31
+ batch_size: 1
32
+ gradient_accumulation: 4
33
+ description: "升级到512分辨率"
data/laion/dataset_info.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "total_samples": 10,
3
+ "columns": [
4
+ "url",
5
+ "caption",
6
+ "aesthetic_score",
7
+ "watermark_prob",
8
+ "NSFW",
9
+ "image_file"
10
+ ],
11
+ "dataset": "dummy",
12
+ "description": "\u865a\u62df\u6d4b\u8bd5\u6570\u636e\u96c6 (10\u6761\u8bb0\u5f55)",
13
+ "note": "\u4e0d\u5305\u542b\u5b9e\u9645\u56fe\u50cf\u6587\u4ef6\uff0c\u4ec5\u7528\u4e8e\u6d4b\u8bd5\u4ee3\u7801\u6d41\u7a0b"
14
+ }
data/laion/metadata.parquet ADDED
Binary file (4.91 kB). View file
 
docs/LAION_DATASET_GUIDE.md ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LAION数据集下载指南
2
+
3
+ ## 问题描述
4
+
5
+ 运行代码时出现以下错误:
6
+ ```
7
+ FileNotFoundError: [Errno 2] No such file or directory: './data/laion/metadata.parquet'
8
+ ```
9
+
10
+ 这是因为项目需要LAION数据集,但数据集文件不存在。
11
+
12
+ ## LAION数据集简介
13
+
14
+ LAION(Large-scale Artificial Intelligence Open Network)是一个大规模的多模态数据集,包含:
15
+ - **LAION-5B**: 58.5亿个图像-文本对
16
+ - **LAION-Aesthetic**: 经过美学评分筛选的高质量子集
17
+ - **LAION-400M**: 4亿个图像-文本对
18
+
19
+ ## 解决方案
20
+
21
+ ### 方案1:使用虚拟数据集测试(推荐)
22
+
23
+ 对于测试和开发,可以使用虚拟数据集:
24
+
25
+ ```bash
26
+ # 创建虚拟数据集
27
+ python scripts/download_laion.py --dummy --dummy-size 100
28
+
29
+ # 或者直接运行
30
+ python scripts/download_laion.py --dummy
31
+ ```
32
+
33
+ 这将创建一个包含100条虚拟记录的元数据文件,用于测试代码流程。
34
+
35
+ ### 方案2:下载LAION-Aesthetic子集
36
+
37
+ LAION-Aesthetic是经过美学评分筛选的高质量数据集:
38
+
39
+ ```bash
40
+ # 下载LAION-Aesthetic 6.5+子集(默认)
41
+ python scripts/download_laion.py
42
+
43
+ # 下载LAION-Aesthetic 7.0+子集(更高质量)
44
+ python scripts/download_laion.py --subset "7.0+"
45
+ ```
46
+
47
+ ### 方案3:下载LAION-5B样本
48
+
49
+ ```bash
50
+ # 下载10,000条记录的样本
51
+ python scripts/download_laion.py --sample-size 10000
52
+ ```
53
+
54
+ ## 手动下载方法
55
+
56
+ ### 方法1:从Hugging Face下载
57
+
58
+ 1. **访问Hugging Face数据集页面**:
59
+ - LAION-Aesthetic 6.5+: https://huggingface.co/datasets/laion/laion-aesthetic-6.5plus
60
+ - LAION-Aesthetic 7.0+: https://huggingface.co/datasets/laion/laion-aesthetic-7.0plus
61
+ - LAION-5B: https://huggingface.co/datasets/laion/laion2b-en
62
+
63
+ 2. **下载元数据文件**:
64
+ ```bash
65
+ # 创建目录
66
+ mkdir -p data/laion
67
+
68
+ # 下载LAION-Aesthetic 6.5+元数据
69
+ wget https://huggingface.co/datasets/laion/laion-aesthetic-6.5plus/resolve/main/data/00000.parquet -O data/laion/metadata.parquet
70
+
71
+ # 或者使用curl
72
+ curl -L https://huggingface.co/datasets/laion/laion-aesthetic-6.5plus/resolve/main/data/00000.parquet -o data/laion/metadata.parquet
73
+ ```
74
+
75
+ ### 方法2:使用img2dataset工具
76
+
77
+ `img2dataset`是一个专门用于下载LAION数据集的工具:
78
+
79
+ ```bash
80
+ # 安装img2dataset
81
+ pip install img2dataset
82
+
83
+ # 下载LAION-400M子集
84
+ img2dataset \
85
+ --url_list "path/to/laion-400m.parquet" \
86
+ --input_format "parquet" \
87
+ --url_col "URL" \
88
+ --caption_col "TEXT" \
89
+ --output_folder "data/laion/images" \
90
+ --processes_count 16 \
91
+ --thread_count 64 \
92
+ --image_size 512 \
93
+ --resize_mode "keep_ratio" \
94
+ --output_format "webdataset"
95
+ ```
96
+
97
+ ### 方法3:使用官方脚本
98
+
99
+ LAION官方提供了一些下载脚本:
100
+
101
+ ```bash
102
+ # 克隆LAION工具仓库
103
+ git clone https://github.com/rom1504/img2dataset.git
104
+ cd img2dataset
105
+
106
+ # 查看使用说明
107
+ python -m img2dataset --help
108
+ ```
109
+
110
+ ## 数据集结构
111
+
112
+ 下载后,数据集目录结构应为:
113
+
114
+ ```
115
+ data/
116
+ └── laion/
117
+ ├── metadata.parquet # 元数据文件(必需)
118
+ ├── dataset_info.json # 数据集信息文件
119
+ └── images/ # 图像文件目录(可选)
120
+ ├── 00000.tar
121
+ ├── 00001.tar
122
+ └── ...
123
+ ```
124
+
125
+ ### 元数据文件格式
126
+
127
+ `metadata.parquet`文件通常包含以下列:
128
+ - `url`: 图像URL
129
+ - `caption` 或 `text`: 图像描述文本
130
+ - `aesthetic_score`: 美学评分(LAION-Aesthetic特有)
131
+ - `watermark_prob`: 水印概率
132
+ - `NSFW`: 成人内容标记
133
+ - `width`/`height`: 图像尺寸
134
+
135
+ ## 验证数据集
136
+
137
+ 下载完成后,验证数据集是否正确:
138
+
139
+ ```python
140
+ import pandas as pd
141
+ import os
142
+
143
+ # 检查文件是否存在
144
+ metadata_path = "./data/laion/metadata.parquet"
145
+ if os.path.exists(metadata_path):
146
+ print(f"元数据文件存在: {metadata_path}")
147
+
148
+ # 读取前几行
149
+ df = pd.read_parquet(metadata_path)
150
+ print(f"记录数: {len(df)}")
151
+ print(f"列名: {list(df.columns)}")
152
+ print("\n前5条记录:")
153
+ print(df.head())
154
+ else:
155
+ print(f"错误: 文件不存在 {metadata_path}")
156
+ ```
157
+
158
+ ## 常见问题
159
+
160
+ ### 问题1:下载速度慢
161
+ - **解决方案**:使用国内镜像或代理
162
+ - 可以尝试使用清华镜像:`pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple`
163
+
164
+ ### 问题2:存储空间不足
165
+ - **解决方案**:
166
+ 1. 下载较小的子集(如LAION-Aesthetic)
167
+ 2. 使用虚拟数据集进行测试
168
+ 3. 只下载元数据,不下载图像
169
+
170
+ ### 问题3:网络连接问题
171
+ - **解决方案**:
172
+ 1. 使用`--dummy`参数创建虚拟数据集
173
+ 2. 手动下载小样本文件
174
+ 3. 使用现有的本地数据集
175
+
176
+ ### 问题4:Parquet文件读取错误
177
+ - **解决方案**:
178
+ ```bash
179
+ # 安装正确版本的pandas和pyarrow
180
+ pip install pandas pyarrow fastparquet
181
+
182
+ # 或者使用dask读取
183
+ pip install dask[dataframe]
184
+ ```
185
+
186
+ ## 高级用法
187
+
188
+ ### 自定义数据集
189
+
190
+ 如��需要使用自定义数据集,可以修改配置文件:
191
+
192
+ ```yaml
193
+ # configs/data/laion_filtered.yaml
194
+ dataset:
195
+ name: "custom-dataset"
196
+ path: "./data/custom" # 修改路径
197
+ metadata_file: "./data/custom/metadata.parquet" # 修改元数据文件路径
198
+ ```
199
+
200
+ ### 数据集预处理
201
+
202
+ 项目包含数据预处理模块:
203
+
204
+ ```python
205
+ from src.data.preprocessing import get_transform
206
+
207
+ # 获取数据变换
208
+ transform = get_transform(config, mode='train')
209
+
210
+ # 创建数据集
211
+ from src.data.dataset import LAIONDataset
212
+ dataset = LAIONDataset(config, transform=transform, split='train')
213
+ ```
214
+
215
+ ### 批量下载图像
216
+
217
+ 如果需要下载实际图像文件:
218
+
219
+ ```python
220
+ import requests
221
+ from PIL import Image
222
+ from io import BytesIO
223
+ import pandas as pd
224
+
225
+ # 读取元数据
226
+ df = pd.read_parquet("./data/laion/metadata.parquet")
227
+
228
+ # 下载前N张图像
229
+ for i, row in df.head(10).iterrows():
230
+ try:
231
+ response = requests.get(row['url'], timeout=10)
232
+ img = Image.open(BytesIO(response.content))
233
+ img.save(f"./data/laion/images/image_{i:06d}.jpg")
234
+ print(f"下载完成: {i}")
235
+ except Exception as e:
236
+ print(f"下载失败 {row['url']}: {e}")
237
+ ```
238
+
239
+ ## 性能优化建议
240
+
241
+ 1. **使用缓存**:启用数据集缓存加速训练
242
+ ```yaml
243
+ preprocessing:
244
+ use_cache: true
245
+ cache_dir: "./data/cache"
246
+ ```
247
+
248
+ 2. **数据并行**:使用多个worker加载数据
249
+ ```yaml
250
+ loader:
251
+ num_workers: 4
252
+ prefetch_factor: 2
253
+ ```
254
+
255
+ 3. **内存映射**:对于大型数据集,使用内存映射文件
256
+ ```python
257
+ df = pd.read_parquet("metadata.parquet", memory_map=True)
258
+ ```
259
+
260
+ ## 参考资料
261
+
262
+ 1. [LAION官方网站](https://laion.ai/)
263
+ 2. [LAION数据集论文](https://arxiv.org/abs/2210.08402)
264
+ 3. [Hugging Face数据集](https://huggingface.co/datasets/laion)
265
+ 4. [img2dataset工具](https://github.com/rom1504/img2dataset)
266
+ 5. [WebDataset格式](https://github.com/webdataset/webdataset)
267
+
268
+ ## 技术支持
269
+
270
+ 如果遇到问题:
271
+ 1. 检查错误信息
272
+ 2. 查看日志文件
273
+ 3. 参考项目README
274
+ 4. 在GitHub Issues中搜索类似问题
275
+ 5. 创建新的Issue寻求帮助
276
+
277
+ ---
278
+
279
+ **注意**:LAION数据集受版权法保护,请确保遵守使用条款和许可证要求。
requirements.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 核心依赖
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ transformers>=4.30.0
5
+ diffusers>=0.20.0
6
+ accelerate>=0.21.0
7
+
8
+ # 数据处理
9
+ Pillow>=10.0.0
10
+ numpy>=1.24.0
11
+ pandas>=2.0.0
12
+ pyarrow>=12.0.0
13
+
14
+ # 训练工具
15
+ wandb>=0.15.0
16
+ tqdm>=4.65.0
17
+ matplotlib>=3.7.0
18
+ tensorboard>=2.13.0
19
+
20
+ # API和部署
21
+ gradio>=3.41.0
22
+ fastapi>=0.100.0
23
+ uvicorn>=0.23.0
24
+
25
+ # 开发工具
26
+ black>=23.7.0
27
+ flake8>=6.0.0
28
+ isort>=5.12.0
scripts/benchmark.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 性能基准测试脚本
4
+ 测试模型训练和推理性能
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import time
10
+ import argparse
11
+ import yaml
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.utils.data import DataLoader
15
+ import numpy as np
16
+ from tqdm import tqdm
17
+
18
+ # 添加项目根目录到Python路径
19
+ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
20
+
21
+ from src.models.unet_light import UNetLight
22
+ from src.models.diffusion import DiffusionProcess
23
+ from src.data.dataset import create_data_loaders
24
+ from src.inference.optimization import InferenceBenchmark
25
+ from src.inference.sampler import DDIMSampler
26
+
27
+
28
+ def load_config(config_path: str) -> dict:
29
+ """加载配置文件"""
30
+ with open(config_path, 'r') as f:
31
+ config = yaml.safe_load(f)
32
+ return config
33
+
34
+
35
+ def benchmark_training(config: dict):
36
+ """训练性能基准测试"""
37
+ print("=" * 60)
38
+ print("训练性能基准测试")
39
+ print("=" * 60)
40
+
41
+ # 设备
42
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
43
+
44
+ # 加载模型配置
45
+ model_config = load_config('configs/model/unet_light.yaml')
46
+
47
+ # 创建模型
48
+ model = UNetLight(model_config).to(device)
49
+
50
+ # 创建扩散过程
51
+ diffusion_config = load_config('configs/model/diffusion.yaml')
52
+ diffusion = DiffusionProcess(diffusion_config)
53
+
54
+ # 创建数据加载器
55
+ data_config = load_config('configs/data/laion_filtered.yaml')
56
+ train_loader, _ = create_data_loaders(data_config)
57
+
58
+ # 优化器
59
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
60
+
61
+ # 预热
62
+ print("预热...")
63
+ warmup_batches = 5
64
+ for i, batch in enumerate(train_loader):
65
+ if i >= warmup_batches:
66
+ break
67
+
68
+ images = batch['images'].to(device)
69
+ text_embeddings = batch['text_embeddings'].to(device)
70
+
71
+ # 前向传播
72
+ loss = diffusion.compute_loss(model, images, text_embeddings)
73
+
74
+ # 反向传播
75
+ loss.backward()
76
+ optimizer.zero_grad()
77
+
78
+ # 同步
79
+ if torch.cuda.is_available():
80
+ torch.cuda.synchronize()
81
+
82
+ # 基准测试
83
+ print("运行训练基准测试...")
84
+ num_batches = 20
85
+ batch_times = []
86
+ memory_usage = []
87
+
88
+ model.train()
89
+
90
+ for i, batch in enumerate(train_loader):
91
+ if i >= num_batches:
92
+ break
93
+
94
+ images = batch['images'].to(device)
95
+ text_embeddings = batch['text_embeddings'].to(device)
96
+
97
+ # 开始计时
98
+ start_time = time.time()
99
+
100
+ # 前向传播
101
+ loss = diffusion.compute_loss(model, images, text_embeddings)
102
+
103
+ # 反向传播
104
+ loss.backward()
105
+ optimizer.step()
106
+ optimizer.zero_grad()
107
+
108
+ # 同步
109
+ if torch.cuda.is_available():
110
+ torch.cuda.synchronize()
111
+
112
+ # 结束计时
113
+ end_time = time.time()
114
+ batch_time = end_time - start_time
115
+ batch_times.append(batch_time)
116
+
117
+ # 记录内存使用
118
+ if torch.cuda.is_available():
119
+ memory_allocated = torch.cuda.memory_allocated() / 1024**3
120
+ memory_usage.append(memory_allocated)
121
+
122
+ # 进度
123
+ print(f"批次 {i+1}/{num_batches}: {batch_time:.3f}s")
124
+
125
+ # 统计
126
+ batch_times = np.array(batch_times)
127
+
128
+ print("\n" + "=" * 60)
129
+ print("训练基准测试结果:")
130
+ print(f" 平均批次时间: {batch_times.mean():.3f} ± {batch_times.std():.3f} s")
131
+ print(f" 最小批次时间: {batch_times.min():.3f} s")
132
+ print(f" 最大批次时间: {batch_times.max():.3f} s")
133
+ print(f" 吞吐量: {1 / batch_times.mean():.2f} batches/s")
134
+
135
+ if memory_usage:
136
+ memory_usage = np.array(memory_usage)
137
+ print(f" 平均GPU内存使用: {memory_usage.mean():.2f} ± {memory_usage.std():.2f} GB")
138
+
139
+ print("=" * 60)
140
+
141
+ return batch_times.mean()
142
+
143
+
144
+ def benchmark_inference(config: dict):
145
+ """推理性能基准测试"""
146
+ print("\n" + "=" * 60)
147
+ print("推理性能基准测试")
148
+ print("=" * 60)
149
+
150
+ # 设备
151
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
152
+
153
+ # 加载模型配置
154
+ model_config = load_config('configs/model/unet_light.yaml')
155
+
156
+ # 创建模型
157
+ model = UNetLight(model_config).to(device)
158
+ model.eval()
159
+
160
+ # 创建扩散过程
161
+ diffusion_config = load_config('configs/model/diffusion.yaml')
162
+ diffusion = DiffusionProcess(diffusion_config)
163
+
164
+ # 创建基准测试器
165
+ benchmark = InferenceBenchmark(model, device)
166
+
167
+ # 测试不同分辨率
168
+ resolutions = [(256, 256), (512, 512), (768, 768)]
169
+ results = {}
170
+
171
+ for height, width in resolutions:
172
+ print(f"\n测试分辨率: {width}x{height}")
173
+
174
+ # 潜在空间大小
175
+ latent_height = height // 8
176
+ latent_width = width // 8
177
+
178
+ # 运行基准测试
179
+ stats = benchmark.benchmark(
180
+ input_shape=(1, model.in_channels, latent_height, latent_width),
181
+ num_iterations=10,
182
+ warmup_iterations=3
183
+ )
184
+
185
+ results[f"{width}x{height}"] = stats
186
+
187
+ # 打印总结
188
+ print("\n" + "=" * 60)
189
+ print("推理基准测试总结:")
190
+
191
+ for resolution, stats in results.items():
192
+ print(f"\n 分辨率 {resolution}:")
193
+ print(f" 平均时间: {stats['mean_ms']:.1f} ms")
194
+ print(f" FPS: {stats['fps']:.1f}")
195
+
196
+ print("=" * 60)
197
+
198
+ return results
199
+
200
+
201
+ def benchmark_sampling(config: dict):
202
+ """采样性能基准测试"""
203
+ print("\n" + "=" * 60)
204
+ print("采样性能基准测试")
205
+ print("=" * 60)
206
+
207
+ # 设备
208
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
209
+
210
+ # 加载模型配置
211
+ model_config = load_config('configs/model/unet_light.yaml')
212
+
213
+ # 创建模型
214
+ model = UNetLight(model_config).to(device)
215
+ model.eval()
216
+
217
+ # 创建扩散过程
218
+ diffusion_config = load_config('configs/model/diffusion.yaml')
219
+ diffusion = DiffusionProcess(diffusion_config)
220
+
221
+ # 创建采样器
222
+ sampler = DDIMSampler(model, diffusion, num_inference_steps=50)
223
+
224
+ # 测试不同采样步数
225
+ step_configs = [20, 30, 50]
226
+ results = {}
227
+
228
+ for num_steps in step_configs:
229
+ print(f"\n测试采样步数: {num_steps}")
230
+
231
+ # 设置采样步数
232
+ sampler.set_timesteps(num_steps)
233
+
234
+ # 准备输入
235
+ prompt_embeds = torch.randn(1, 77, 768, device=device)
236
+
237
+ # 预热
238
+ print(" 预热...")
239
+ with torch.no_grad():
240
+ for _ in range(3):
241
+ _ = sampler.sample(
242
+ prompt_embeds=prompt_embeds,
243
+ height=512,
244
+ width=512,
245
+ progress_bar=False
246
+ )
247
+
248
+ # 基准测试
249
+ print(" 运行基准测试...")
250
+ times = []
251
+
252
+ for i in range(5):
253
+ start_time = time.time()
254
+
255
+ with torch.no_grad():
256
+ _ = sampler.sample(
257
+ prompt_embeds=prompt_embeds,
258
+ height=512,
259
+ width=512,
260
+ progress_bar=False
261
+ )
262
+
263
+ if torch.cuda.is_available():
264
+ torch.cuda.synchronize()
265
+
266
+ end_time = time.time()
267
+ times.append(end_time - start_time)
268
+
269
+ print(f" 迭代 {i+1}: {times[-1]:.2f}s")
270
+
271
+ # 统计
272
+ times = np.array(times)
273
+
274
+ results[num_steps] = {
275
+ 'mean_time': times.mean(),
276
+ 'std_time': times.std(),
277
+ 'fps': 1 / times.mean()
278
+ }
279
+
280
+ # 打印总结
281
+ print("\n" + "=" * 60)
282
+ print("采样基准测试总结:")
283
+
284
+ for num_steps, stats in results.items():
285
+ print(f"\n 采样步数 {num_steps}:")
286
+ print(f" 平均时间: {stats['mean_time']:.2f} ± {stats['std_time']:.2f} s")
287
+ print(f" FPS: {stats['fps']:.2f}")
288
+
289
+ print("=" * 60)
290
+
291
+ return results
292
+
293
+
294
+ def benchmark_memory(config: dict):
295
+ """内存使用基准测试"""
296
+ print("\n" + "=" * 60)
297
+ print("内存使用基准测试")
298
+ print("=" * 60)
299
+
300
+ if not torch.cuda.is_available():
301
+ print("GPU不可用,跳过内存基准测试")
302
+ return {}
303
+
304
+ # 设备
305
+ device = torch.device('cuda')
306
+
307
+ # 加载模型配置
308
+ model_config = load_config('configs/model/unet_light.yaml')
309
+
310
+ # 测试不同批次大小
311
+ batch_sizes = [1, 2, 4, 8]
312
+ results = {}
313
+
314
+ for batch_size in batch_sizes:
315
+ print(f"\n测试批次大小: {batch_size}")
316
+
317
+ # 创建模型
318
+ model = UNetLight(model_config).to(device)
319
+ model.eval()
320
+
321
+ # 准备输入
322
+ input_shape = (batch_size, model.in_channels, 64, 64)
323
+ x = torch.randn(*input_shape, device=device)
324
+ t = torch.tensor([500] * batch_size, device=device)
325
+ context = torch.randn(batch_size, 77, 768, device=device)
326
+
327
+ # 清空缓存
328
+ torch.cuda.empty_cache()
329
+
330
+ # 记录初始内存
331
+ initial_memory = torch.cuda.memory_allocated()
332
+
333
+ # 前向传播
334
+ with torch.no_grad():
335
+ _ = model(x, t, context)
336
+
337
+ # 记录峰值内存
338
+ peak_memory = torch.cuda.max_memory_allocated()
339
+ current_memory = torch.cuda.memory_allocated()
340
+
341
+ # 计算内存使用
342
+ memory_used = peak_memory - initial_memory
343
+
344
+ results[batch_size] = {
345
+ 'initial_memory_gb': initial_memory / 1024**3,
346
+ 'peak_memory_gb': peak_memory / 1024**3,
347
+ 'current_memory_gb': current_memory / 1024**3,
348
+ 'memory_used_gb': memory_used / 1024**3,
349
+ 'memory_per_sample_gb': memory_used / (batch_size * 1024**3)
350
+ }
351
+
352
+ print(f" 初始内存: {initial_memory / 1024**3:.2f} GB")
353
+ print(f" 峰值内存: {peak_memory / 1024**3:.2f} GB")
354
+ print(f" 当前内存: {current_memory / 1024**3:.2f} GB")
355
+ print(f" 内存使用: {memory_used / 1024**3:.2f} GB")
356
+ print(f" 每样本内存: {memory_used / (batch_size * 1024**3):.2f} GB")
357
+
358
+ # 清理
359
+ del model
360
+ torch.cuda.empty_cache()
361
+
362
+ # 打印总结
363
+ print("\n" + "=" * 60)
364
+ print("内存基准测试总结:")
365
+
366
+ for batch_size, stats in results.items():
367
+ print(f"\n 批次大小 {batch_size}:")
368
+ print(f" 总内存使用: {stats['memory_used_gb']:.2f} GB")
369
+ print(f" 每样本内存: {stats['memory_per_sample_gb']:.2f} GB")
370
+
371
+ print("=" * 60)
372
+
373
+ return results
374
+
375
+
376
+ def generate_report(results: dict, output_file: str = "benchmark_report.md"):
377
+ """生成基准测试报告"""
378
+ print(f"\n生成报告: {output_file}")
379
+
380
+ with open(output_file, 'w') as f:
381
+ f.write("# Lumina 性能基准测试报告\n\n")
382
+ f.write(f"生成时间: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
383
+
384
+ f.write("## 系统信息\n")
385
+ f.write(f"- PyTorch版本: {torch.__version__}\n")
386
+ f.write(f"- CUDA可用: {torch.cuda.is_available()}\n")
387
+ if torch.cuda.is_available():
388
+ f.write(f"- GPU: {torch.cuda.get_device_name(0)}\n")
389
+ f.write(f"- CUDA版本: {torch.version.cuda}\n")
390
+ f.write(f"- 系统内存: {os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES') / 1024**3:.1f} GB\n\n")
391
+
392
+ if 'training' in results:
393
+ f.write("## 训练性能\n")
394
+ f.write(f"- 平均批次时间: {results['training']:.3f} s\n")
395
+ f.write(f"- 吞吐量: {1/results['training']:.2f} batches/s\n\n")
396
+
397
+ if 'inference' in results:
398
+ f.write("## 推理性能\n")
399
+ for resolution, stats in results['inference'].items():
400
+ f.write(f"### 分辨率 {resolution}\n")
401
+ f.write(f"- 平均推理时间: {stats['mean_ms']:.1f} ms\n")
402
+ f.write(f"- FPS: {stats['fps']:.1f}\n\n")
403
+
404
+ if 'sampling' in results:
405
+ f.write("## 采样性能\n")
406
+ for num_steps, stats in results['sampling'].items():
407
+ f.write(f"### 采样步数 {num_steps}\n")
408
+ f.write(f"- 平均采样时间: {stats['mean_time']:.2f} s\n")
409
+ f.write(f"- FPS: {stats['fps']:.2f}\n\n")
410
+
411
+ if 'memory' in results:
412
+ f.write("## 内存使用\n")
413
+ for batch_size, stats in results['memory'].items():
414
+ f.write(f"### 批次大小 {batch_size}\n")
415
+ f.write(f"- 总内存使用: {stats['memory_used_gb']:.2f} GB\n")
416
+ f.write(f"- 每样本内存: {stats['memory_per_sample_gb']:.2f} GB\n\n")
417
+
418
+ f.write("## 建议\n")
419
+ f.write("1. 根据GPU内存选择适当的批次大小\n")
420
+ f.write("2. 推理时使用适当的采样步数平衡质量和速度\n")
421
+ f.write("3. 训练时使用梯度累积来模拟大批次训练\n")
422
+
423
+ print(f"报告已保存: {output_file}")
424
+
425
+
426
+ def main():
427
+ """主函数"""
428
+ parser = argparse.ArgumentParser(description="Lumina性能基准测试")
429
+
430
+ parser.add_argument(
431
+ "--config",
432
+ type=str,
433
+ default="configs/training/p4_optimized.yaml",
434
+ help="配置文件路径"
435
+ )
436
+
437
+ parser.add_argument(
438
+ "--test",
439
+ type=str,
440
+ nargs="+",
441
+ default=['all'],
442
+ choices=['training', 'inference', 'sampling', 'memory', 'all'],
443
+ help="测试项目"
444
+ )
445
+
446
+ parser.add_argument(
447
+ "--output",
448
+ type=str,
449
+ default="benchmark_report.md",
450
+ help="输出报告文件"
451
+ )
452
+
453
+ args = parser.parse_args()
454
+
455
+ # 加载配置
456
+ config = load_config(args.config)
457
+
458
+ # 运行基准测试
459
+ results = {}
460
+
461
+ if 'all' in args.test or 'training' in args.test:
462
+ try:
463
+ results['training'] = benchmark_training(config)
464
+ except Exception as e:
465
+ print(f"训练基准测试失败: {e}")
466
+
467
+ if 'all' in args.test or 'inference' in args.test:
468
+ try:
469
+ results['inference'] = benchmark_inference(config)
470
+ except Exception as e:
471
+ print(f"推理基准测试失败: {e}")
472
+
473
+ if 'all' in args.test or 'sampling' in args.test:
474
+ try:
475
+ results['sampling'] = benchmark_sampling(config)
476
+ except Exception as e:
477
+ print(f"采样基准测试失败: {e}")
478
+
479
+ if 'all' in args.test or 'memory' in args.test:
480
+ try:
481
+ results['memory'] = benchmark_memory(config)
482
+ except Exception as e:
483
+ print(f"内存基准测试失败: {e}")
484
+
485
+ # 生成报告
486
+ generate_report(results, args.output)
487
+
488
+
489
+ if __name__ == "__main__":
490
+ main()
scripts/download_laion.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ LAION数据集下载脚本
4
+
5
+ 这个脚本帮助下载LAION数据集的不同版本。
6
+ LAION数据集很大,通常需要下载元数据文件和图像文件。
7
+
8
+ 注意:完整LAION数据集非常大(数TB),建议下载子集或使用现有缓存。
9
+ """
10
+
11
+ import os
12
+ import argparse
13
+ import subprocess
14
+ import pandas as pd
15
+ from pathlib import Path
16
+ import requests
17
+ import json
18
+ from tqdm import tqdm
19
+ import sys
20
+
21
+ def download_file(url, output_path, chunk_size=8192):
22
+ """下载文件并显示进度条"""
23
+ response = requests.get(url, stream=True)
24
+ response.raise_for_status()
25
+
26
+ total_size = int(response.headers.get('content-length', 0))
27
+
28
+ with open(output_path, 'wb') as f, tqdm(
29
+ desc=os.path.basename(output_path),
30
+ total=total_size,
31
+ unit='B',
32
+ unit_scale=True,
33
+ unit_divisor=1024,
34
+ ) as pbar:
35
+ for chunk in response.iter_content(chunk_size=chunk_size):
36
+ f.write(chunk)
37
+ pbar.update(len(chunk))
38
+
39
+ return output_path
40
+
41
+ def download_laion_aesthetic(output_dir="./data/laion", subset="6.5+"):
42
+ """
43
+ 下载LAION-Aesthetic数据集
44
+
45
+ Args:
46
+ output_dir: 输出目录
47
+ subset: 子集版本,可选 "6.5+" (6.5分以上), "7.0+" (7.0分以上)
48
+ """
49
+ os.makedirs(output_dir, exist_ok=True)
50
+
51
+ # LAION-Aesthetic数据集信息
52
+ datasets = {
53
+ "6.5+": {
54
+ "metadata": "https://huggingface.co/datasets/laion/laion-aesthetic-6.5plus/resolve/main/data/00000.parquet",
55
+ "description": "LAION-Aesthetic 6.5+ (美学评分6.5分以上)"
56
+ },
57
+ "7.0+": {
58
+ "metadata": "https://huggingface.co/datasets/laion/laion-aesthetic-7.0plus/resolve/main/data/00000.parquet",
59
+ "description": "LAION-Aesthetic 7.0+ (美学评分7.0分以上)"
60
+ }
61
+ }
62
+
63
+ if subset not in datasets:
64
+ print(f"错误: 不支持的子集 {subset}")
65
+ print(f"可用子集: {list(datasets.keys())}")
66
+ return False
67
+
68
+ dataset_info = datasets[subset]
69
+ print(f"下载 {dataset_info['description']}")
70
+
71
+ # 下载元数据文件
72
+ metadata_url = dataset_info["metadata"]
73
+ metadata_path = os.path.join(output_dir, "metadata.parquet")
74
+
75
+ print(f"下载元数据文件到: {metadata_path}")
76
+ try:
77
+ download_file(metadata_url, metadata_path)
78
+ print(f"元数据文件下载完成: {metadata_path}")
79
+
80
+ # 验证文件
81
+ df = pd.read_parquet(metadata_path)
82
+ print(f"元数据包含 {len(df)} 条记录")
83
+ print(f"列名: {list(df.columns)}")
84
+
85
+ # 保存样本信息
86
+ sample_info = {
87
+ "total_samples": len(df),
88
+ "columns": list(df.columns),
89
+ "subset": subset,
90
+ "description": dataset_info["description"]
91
+ }
92
+
93
+ with open(os.path.join(output_dir, "dataset_info.json"), "w") as f:
94
+ json.dump(sample_info, f, indent=2)
95
+
96
+ print(f"数据集信息已保存到: {os.path.join(output_dir, 'dataset_info.json')}")
97
+
98
+ return True
99
+
100
+ except Exception as e:
101
+ print(f"下载失败: {e}")
102
+ return False
103
+
104
+ def download_laion_5b_sample(output_dir="./data/laion", num_samples=10000):
105
+ """
106
+ 下载LAION-5B数据集的样本
107
+
108
+ Args:
109
+ output_dir: 输出目录
110
+ num_samples: 样本数量
111
+ """
112
+ os.makedirs(output_dir, exist_ok=True)
113
+
114
+ print(f"下载LAION-5B数据集样本 ({num_samples}条记录)")
115
+
116
+ # LAION-5B数据集分片URL示例
117
+ # 注意:完整数据集有数万个分片,这里只下载一个样本分片
118
+ sample_shard = "https://huggingface.co/datasets/laion/laion2b-en/resolve/main/part-00000-5b54c5d5-bbcf-484d-a2ce-0d6f73df1a36-c000.snappy.parquet"
119
+
120
+ metadata_path = os.path.join(output_dir, "metadata_sample.parquet")
121
+
122
+ print(f"下载样本分片到: {metadata_path}")
123
+ try:
124
+ download_file(sample_shard, metadata_path)
125
+ print(f"样本分片下载完成: {metadata_path}")
126
+
127
+ # 读取并采样
128
+ df = pd.read_parquet(metadata_path)
129
+
130
+ if len(df) > num_samples:
131
+ df = df.sample(num_samples, random_state=42)
132
+
133
+ # 保存采样后的数据
134
+ sampled_path = os.path.join(output_dir, "metadata.parquet")
135
+ df.to_parquet(sampled_path)
136
+
137
+ print(f"采样数据保存到: {sampled_path}")
138
+ print(f"采样后包含 {len(df)} 条记录")
139
+ print(f"列名: {list(df.columns)}")
140
+
141
+ # 保存样本信息
142
+ sample_info = {
143
+ "total_samples": len(df),
144
+ "original_samples": len(pd.read_parquet(metadata_path)),
145
+ "columns": list(df.columns),
146
+ "dataset": "LAION-5B-sample",
147
+ "description": f"LAION-5B数据集样本 ({num_samples}条记录)"
148
+ }
149
+
150
+ with open(os.path.join(output_dir, "dataset_info.json"), "w") as f:
151
+ json.dump(sample_info, f, indent=2)
152
+
153
+ print(f"数据集信息已保存到: {os.path.join(output_dir, 'dataset_info.json')}")
154
+
155
+ return True
156
+
157
+ except Exception as e:
158
+ print(f"下载失败: {e}")
159
+ return False
160
+
161
+ def create_dummy_dataset(output_dir="./data/laion", num_samples=100):
162
+ """
163
+ 创建虚拟数据集用于测试
164
+
165
+ Args:
166
+ output_dir: 输出目录
167
+ num_samples: 样本数量
168
+ """
169
+ os.makedirs(output_dir, exist_ok=True)
170
+
171
+ print(f"创建虚拟数据集 ({num_samples}条记录)")
172
+
173
+ import numpy as np
174
+
175
+ # 创建虚拟数据
176
+ data = {
177
+ 'url': [f'https://example.com/image_{i}.jpg' for i in range(num_samples)],
178
+ 'caption': [f'A beautiful image number {i}' for i in range(num_samples)],
179
+ 'aesthetic_score': np.random.uniform(5.0, 9.0, num_samples),
180
+ 'watermark_prob': np.random.uniform(0.0, 1.0, num_samples),
181
+ 'NSFW': ['UNLIKELY'] * num_samples,
182
+ 'image_file': [f'image_{i:06d}.jpg' for i in range(num_samples)]
183
+ }
184
+
185
+ df = pd.DataFrame(data)
186
+
187
+ # 保存元数据
188
+ metadata_path = os.path.join(output_dir, "metadata.parquet")
189
+ df.to_parquet(metadata_path)
190
+
191
+ print(f"虚拟元数据创建完成: {metadata_path}")
192
+ print(f"包含 {len(df)} 条记录")
193
+ print(f"列名: {list(df.columns)}")
194
+
195
+ # 创建虚拟图像目录
196
+ images_dir = os.path.join(output_dir, "images")
197
+ os.makedirs(images_dir, exist_ok=True)
198
+
199
+ print(f"虚拟图像目录: {images_dir}")
200
+ print("注意:虚拟数据集不包含实际图像文件,仅用于测试代码流程")
201
+
202
+ # 保存数据集信息
203
+ sample_info = {
204
+ "total_samples": len(df),
205
+ "columns": list(df.columns),
206
+ "dataset": "dummy",
207
+ "description": f"虚拟测试数据集 ({num_samples}条记录)",
208
+ "note": "不包含实际图像文件,仅用于测试代码流程"
209
+ }
210
+
211
+ with open(os.path.join(output_dir, "dataset_info.json"), "w") as f:
212
+ json.dump(sample_info, f, indent=2)
213
+
214
+ print(f"数据集信息已保存到: {os.path.join(output_dir, 'dataset_info.json')}")
215
+
216
+ return True
217
+
218
+ def main():
219
+ parser = argparse.ArgumentParser(description="下载LAION数据集")
220
+ parser.add_argument("--output-dir", default="./data/laion", help="输出目录")
221
+ parser.add_argument("--subset", default="6.5+", choices=["6.5+", "7.0+"],
222
+ help="LAION-Aesthetic子集版本")
223
+ parser.add_argument("--sample-size", type=int, default=10000,
224
+ help="LAION-5B样本大小")
225
+ parser.add_argument("--dummy", action="store_true",
226
+ help="创建虚拟数据集用于测试")
227
+ parser.add_argument("--dummy-size", type=int, default=100,
228
+ help="虚拟数据集大小")
229
+
230
+ args = parser.parse_args()
231
+
232
+ print("=" * 60)
233
+ print("LAION数据集下载工具")
234
+ print("=" * 60)
235
+
236
+ if args.dummy:
237
+ print("\n创建虚拟数据集模式...")
238
+ success = create_dummy_dataset(args.output_dir, args.dummy_size)
239
+ else:
240
+ print("\n下载LAION-Aesthetic数据集...")
241
+ print(f"输出目录: {args.output_dir}")
242
+ print(f"子集版本: {args.subset}")
243
+ print("\n注意:")
244
+ print("1. LAION数据集很大,下载需要时间和存储空间")
245
+ print("2. 元数据文件通常几百MB到几GB")
246
+ print("3. 图像文件需要额外下载")
247
+ print("4. 建议先使用虚拟数据集测试代码流程")
248
+
249
+ response = input("\n是否继续? (y/n): ")
250
+ if response.lower() != 'y':
251
+ print("取消下载")
252
+ return
253
+
254
+ success = download_laion_aesthetic(args.output_dir, args.subset)
255
+
256
+ if success:
257
+ print("\n" + "=" * 60)
258
+ print("下载完成!")
259
+ print("=" * 60)
260
+ print("\n下一步:")
261
+ print("1. 检查下载的文件:")
262
+ print(f" ls -lh {args.output_dir}/")
263
+ print("2. 测试数据集加载:")
264
+ print(" python -c \"import pandas as pd; df=pd.read_parquet('{}'); print('记录数:', len(df))\"".format(
265
+ os.path.join(args.output_dir, "metadata.parquet")))
266
+ print("3. 运行测试脚本:")
267
+ print(" python src/data/dataset.py")
268
+ else:
269
+ print("\n下载失败,请检查错误信息")
270
+
271
+ if __name__ == "__main__":
272
+ main()
scripts/export.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 模型导出脚本
4
+ 用于导出训练好的模型为不同格式
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import argparse
10
+ import yaml
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ # 添加项目根目录到Python路径
15
+ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
16
+
17
+ from src.models.unet_light import UNetLight
18
+ from src.models.diffusion import DiffusionProcess
19
+ from src.inference.optimization import ModelOptimizer, ONNXExporter, optimize_model_for_p4
20
+
21
+
22
+ def load_config(config_path: str) -> dict:
23
+ """加载配置文件"""
24
+ with open(config_path, 'r') as f:
25
+ config = yaml.safe_load(f)
26
+ return config
27
+
28
+
29
+ def load_model(checkpoint_path: str, config: dict, device: torch.device) -> nn.Module:
30
+ """加载模型"""
31
+ # 加载模型配置
32
+ model_config_path = config.get('model_config', 'configs/model/unet_light.yaml')
33
+ model_config = load_config(model_config_path)
34
+
35
+ # 创建模型
36
+ model = UNetLight(model_config)
37
+
38
+ # 加载检查点
39
+ print(f"加载检查点: {checkpoint_path}")
40
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
41
+
42
+ # 加载模型权重
43
+ if 'model_state_dict' in checkpoint:
44
+ model.load_state_dict(checkpoint['model_state_dict'])
45
+ elif 'state_dict' in checkpoint:
46
+ model.load_state_dict(checkpoint['state_dict'])
47
+ else:
48
+ model.load_state_dict(checkpoint)
49
+
50
+ # 移动到设备
51
+ model = model.to(device)
52
+ model.eval()
53
+
54
+ print(f"模型加载完成")
55
+
56
+ return model
57
+
58
+
59
+ def export_torchscript(model: nn.Module, output_path: str):
60
+ """导出为TorchScript格式"""
61
+ print(f"导出为TorchScript: {output_path}")
62
+
63
+ # 创建示例输入
64
+ example_input = torch.randn(1, model.in_channels, 64, 64)
65
+ example_timestep = torch.tensor([500])
66
+ example_context = torch.randn(1, 77, 768)
67
+
68
+ # 跟踪模型
69
+ traced_model = torch.jit.trace(
70
+ model,
71
+ (example_input, example_timestep, example_context),
72
+ check_trace=False
73
+ )
74
+
75
+ # 保存
76
+ traced_model.save(output_path)
77
+ print(f"TorchScript模型已保存: {output_path}")
78
+
79
+ return traced_model
80
+
81
+
82
+ def export_onnx(model: nn.Module, output_path: str, opset_version: int = 14):
83
+ """导出为ONNX格式"""
84
+ print(f"导出为ONNX: {output_path}")
85
+
86
+ # 创建示例输入
87
+ example_input = torch.randn(1, model.in_channels, 64, 64)
88
+ example_timestep = torch.tensor([500])
89
+ example_context = torch.randn(1, 77, 768)
90
+
91
+ # 设置动态轴
92
+ dynamic_axes = {
93
+ 'input': {0: 'batch_size'},
94
+ 'timestep': {0: 'batch_size'},
95
+ 'context': {0: 'batch_size'},
96
+ 'output': {0: 'batch_size'}
97
+ }
98
+
99
+ # 导出
100
+ torch.onnx.export(
101
+ model,
102
+ (example_input, example_timestep, example_context),
103
+ output_path,
104
+ input_names=['input', 'timestep', 'context'],
105
+ output_names=['output'],
106
+ dynamic_axes=dynamic_axes,
107
+ opset_version=opset_version,
108
+ do_constant_folding=True,
109
+ verbose=False
110
+ )
111
+
112
+ print(f"ONNX模型已保存: {output_path}")
113
+
114
+ # 验证ONNX模型
115
+ import onnx
116
+ onnx_model = onnx.load(output_path)
117
+ onnx.checker.check_model(onnx_model)
118
+ print("ONNX模型验证成功")
119
+
120
+
121
+ def export_safetensors(model: nn.Module, output_path: str):
122
+ """导出为safetensors格式"""
123
+ try:
124
+ from safetensors.torch import save_file
125
+
126
+ # 转换为safetensors格式
127
+ state_dict = model.state_dict()
128
+ save_file(state_dict, output_path)
129
+
130
+ print(f"Safetensors模型已保存: {output_path}")
131
+ except ImportError:
132
+ print("safetensors未安装,跳过safetensors导出")
133
+ print("安装: pip install safetensors")
134
+
135
+
136
+ def optimize_and_export(
137
+ checkpoint_path: str,
138
+ output_dir: str,
139
+ formats: list = ['torchscript', 'onnx', 'safetensors'],
140
+ optimize_for_p4: bool = True
141
+ ):
142
+ """优化并导出模型"""
143
+ # 创建输出目录
144
+ os.makedirs(output_dir, exist_ok=True)
145
+
146
+ # 加载配置
147
+ config_path = "configs/training/p4_optimized.yaml"
148
+ config = load_config(config_path)
149
+
150
+ # 设备
151
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
152
+
153
+ # 加载模型
154
+ model = load_model(checkpoint_path, config, device)
155
+
156
+ # 优化模型(针对P4)
157
+ if optimize_for_p4:
158
+ print("优化模型(针对P4)...")
159
+ model = optimize_model_for_p4(model)
160
+
161
+ # 获取模型信息
162
+ total_params = sum(p.numel() for p in model.parameters())
163
+ model_size_mb = total_params * 4 / 1024**2 # fp32
164
+
165
+ print(f"\n模型信息:")
166
+ print(f" 参数量: {total_params:,}")
167
+ print(f" 模型大小: {model_size_mb:.2f} MB (fp32)")
168
+
169
+ # 导���为不同格式
170
+ base_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
171
+
172
+ for fmt in formats:
173
+ if fmt == 'torchscript':
174
+ output_path = os.path.join(output_dir, f"{base_name}.torchscript.pt")
175
+ export_torchscript(model, output_path)
176
+
177
+ elif fmt == 'onnx':
178
+ output_path = os.path.join(output_dir, f"{base_name}.onnx")
179
+ export_onnx(model, output_path)
180
+
181
+ elif fmt == 'safetensors':
182
+ output_path = os.path.join(output_dir, f"{base_name}.safetensors")
183
+ export_safetensors(model, output_path)
184
+
185
+ elif fmt == 'pth':
186
+ output_path = os.path.join(output_dir, f"{base_name}.pth")
187
+ torch.save(model.state_dict(), output_path)
188
+ print(f"PyTorch模型已保存: {output_path}")
189
+
190
+ else:
191
+ print(f"未知的格式: {fmt}")
192
+
193
+ print(f"\n所有模型已导出到: {output_dir}")
194
+
195
+
196
+ def create_lite_version(model: nn.Module, reduction_factor: float = 0.5) -> nn.Module:
197
+ """创建轻量版本(通过减少通道数)"""
198
+ # 注意:这是一个示例,需要根据实际模型结构调整
199
+ print(f"创建轻量版本,减少因子: {reduction_factor}")
200
+
201
+ # 这里应该实现具体的轻量化逻辑
202
+ # 例如,减少UNet的通道数
203
+
204
+ return model
205
+
206
+
207
+ def main():
208
+ """主函数"""
209
+ parser = argparse.ArgumentParser(description="导出Lumina模型")
210
+
211
+ parser.add_argument(
212
+ "--checkpoint",
213
+ type=str,
214
+ required=True,
215
+ help="模型检查点路径"
216
+ )
217
+
218
+ parser.add_argument(
219
+ "--output-dir",
220
+ type=str,
221
+ default="./exported_models",
222
+ help="输出目录"
223
+ )
224
+
225
+ parser.add_argument(
226
+ "--formats",
227
+ type=str,
228
+ nargs="+",
229
+ default=['torchscript', 'onnx'],
230
+ choices=['torchscript', 'onnx', 'safetensors', 'pth'],
231
+ help="导出格式"
232
+ )
233
+
234
+ parser.add_argument(
235
+ "--optimize",
236
+ action="store_true",
237
+ help="优化模型(针对P4)"
238
+ )
239
+
240
+ parser.add_argument(
241
+ "--lite",
242
+ action="store_true",
243
+ help="创建轻量版本"
244
+ )
245
+
246
+ parser.add_argument(
247
+ "--lite-factor",
248
+ type=float,
249
+ default=0.5,
250
+ help="轻量化减少因子"
251
+ )
252
+
253
+ args = parser.parse_args()
254
+
255
+ # 检查输入文件
256
+ if not os.path.exists(args.checkpoint):
257
+ print(f"错误: 检查点文件不存在: {args.checkpoint}")
258
+ return
259
+
260
+ # 优化并导出
261
+ optimize_and_export(
262
+ checkpoint_path=args.checkpoint,
263
+ output_dir=args.output_dir,
264
+ formats=args.formats,
265
+ optimize_for_p4=args.optimize
266
+ )
267
+
268
+
269
+ if __name__ == "__main__":
270
+ main()
scripts/train.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Lumina训练脚本
4
+ 用于训练轻量级图像生成模型
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import argparse
10
+ import yaml
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.utils.data import DataLoader
14
+ import warnings
15
+
16
+ # 添加项目根目录到Python路径
17
+ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
18
+
19
+ from src.models.unet_light import UNetLight
20
+ from src.models.diffusion import DiffusionProcess, DiffusionModel
21
+ from src.data.dataset import create_data_loaders
22
+ from src.data.text_encoder import create_text_encoder
23
+ from src.training.trainer_p4 import P4Trainer
24
+ from src.training.memory_manager import MemoryOptimizer
25
+ from src.training.callbacks import create_default_callbacks
26
+
27
+
28
+ def load_config(config_path: str) -> dict:
29
+ """加载配置文件"""
30
+ with open(config_path, 'r') as f:
31
+ config = yaml.safe_load(f)
32
+ return config
33
+
34
+
35
+ def setup_environment(config: dict):
36
+ """设置训练环境"""
37
+ # 设置随机种子
38
+ seed = config.get('seed', 42)
39
+ torch.manual_seed(seed)
40
+ if torch.cuda.is_available():
41
+ torch.cuda.manual_seed(seed)
42
+
43
+ # 设置CUDA设备
44
+ device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
45
+ if device == 'cuda' and not torch.cuda.is_available():
46
+ warnings.warn("CUDA不可用,使用CPU")
47
+ device = 'cpu'
48
+
49
+ # 创建输出目录
50
+ output_dir = config.get('output_dir', './output')
51
+ os.makedirs(output_dir, exist_ok=True)
52
+
53
+ # 设置日志
54
+ log_dir = config.get('log_dir', './logs')
55
+ os.makedirs(log_dir, exist_ok=True)
56
+
57
+ print(f"环境设置完成:")
58
+ print(f" 设备: {device}")
59
+ print(f" 随机种子: {seed}")
60
+ print(f" 输出目录: {output_dir}")
61
+ print(f" 日志目录: {log_dir}")
62
+
63
+ return device
64
+
65
+
66
+ def create_model(config: dict, device: torch.device) -> nn.Module:
67
+ """创建模型"""
68
+ # 加载模型配置
69
+ model_config_path = config.get('model_config', 'configs/model/unet_light.yaml')
70
+ model_config = load_config(model_config_path)
71
+
72
+ # 创建UNet模型
73
+ model = UNetLight(model_config)
74
+
75
+ # 加载预训练权重(如果有)
76
+ pretrained_path = config.get('pretrained_path')
77
+ if pretrained_path and os.path.exists(pretrained_path):
78
+ print(f"加载预训练权重: {pretrained_path}")
79
+ checkpoint = torch.load(pretrained_path, map_location='cpu')
80
+ model.load_state_dict(checkpoint['model_state_dict'])
81
+
82
+ # 移动到设备
83
+ model = model.to(device)
84
+
85
+ # 打印模型信息
86
+ total_params = sum(p.numel() for p in model.parameters())
87
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
88
+
89
+ print(f"模型创建完成:")
90
+ print(f" 总参数量: {total_params:,}")
91
+ print(f" 可训练参数量: {trainable_params:,}")
92
+ print(f" 模型大小: {total_params * 4 / 1024**2:.2f} MB (fp32)")
93
+
94
+ return model
95
+
96
+
97
+ def create_diffusion(config: dict) -> DiffusionProcess:
98
+ """创建扩散过程"""
99
+ diffusion_config_path = config.get('diffusion_config', 'configs/model/diffusion.yaml')
100
+ diffusion_config = load_config(diffusion_config_path)
101
+
102
+ diffusion = DiffusionProcess(diffusion_config)
103
+
104
+ print(f"扩散过程创建完成:")
105
+ print(f" 训练时间步: {diffusion.num_train_timesteps}")
106
+ print(f" 推理时间步: {diffusion.num_inference_timesteps}")
107
+ print(f" Beta调度: {diffusion.beta_schedule}")
108
+
109
+ return diffusion
110
+
111
+
112
+ def create_data_pipeline(config: dict):
113
+ """创建数据管道"""
114
+ data_config_path = config.get('data_config', 'configs/data/laion_filtered.yaml')
115
+ data_config = load_config(data_config_path)
116
+
117
+ # 创建文本编码器
118
+ text_encoder = create_text_encoder(data_config)
119
+
120
+ # 创建数据加载器
121
+ train_loader, val_loader = create_data_loaders(data_config)
122
+
123
+ print(f"数据管道创建完成:")
124
+ print(f" 训练集大小: {len(train_loader.dataset)}")
125
+ print(f" 验证集大小: {len(val_loader.dataset) if val_loader else 0}")
126
+ print(f" 批次大小: {train_loader.batch_size}")
127
+ print(f" 梯度累积步数: {config.get('gradient_accumulation_steps', 8)}")
128
+
129
+ return train_loader, val_loader, text_encoder
130
+
131
+
132
+ def create_optimizer(model: nn.Module, config: dict):
133
+ """创建优化器"""
134
+ optimizer_config = config.get('optimizer', {})
135
+ optimizer_type = optimizer_config.get('type', 'AdamW')
136
+ learning_rate = optimizer_config.get('learning_rate', 1e-4)
137
+ weight_decay = optimizer_config.get('weight_decay', 0.01)
138
+
139
+ if optimizer_type == 'AdamW':
140
+ optimizer = torch.optim.AdamW(
141
+ model.parameters(),
142
+ lr=learning_rate,
143
+ weight_decay=weight_decay,
144
+ betas=(0.9, 0.999),
145
+ eps=1e-8
146
+ )
147
+ elif optimizer_type == 'Adam':
148
+ optimizer = torch.optim.Adam(
149
+ model.parameters(),
150
+ lr=learning_rate,
151
+ weight_decay=weight_decay
152
+ )
153
+ else:
154
+ raise ValueError(f"未知的优化器类型: {optimizer_type}")
155
+
156
+ print(f"优化器创建完成:")
157
+ print(f" 类型: {optimizer_type}")
158
+ print(f" 学习率: {learning_rate}")
159
+ print(f" 权重衰减: {weight_decay}")
160
+
161
+ return optimizer
162
+
163
+
164
+ def setup_memory_optimization(model: nn.Module, optimizer, config: dict):
165
+ """设置内存优化"""
166
+ memory_optimizer = MemoryOptimizer(config)
167
+ memory_optimizer.setup_model_optimizations(model, optimizer)
168
+
169
+ # 打印内存信息
170
+ if torch.cuda.is_available():
171
+ allocated = torch.cuda.memory_allocated() / 1024**3
172
+ reserved = torch.cuda.memory_reserved() / 1024**3
173
+ print(f"内存优化设置完成:")
174
+ print(f" GPU已分配: {allocated:.2f} GB")
175
+ print(f" GPU已保留: {reserved:.2f} GB")
176
+
177
+ return memory_optimizer
178
+
179
+
180
+ def train(config_path: str, resume_from: str = None):
181
+ """训练主函数"""
182
+ print("=" * 60)
183
+ print("Lumina 训练开始")
184
+ print("=" * 60)
185
+
186
+ # 加载配置
187
+ config = load_config(config_path)
188
+
189
+ # 设置环境
190
+ device = setup_environment(config)
191
+
192
+ # 创建模型
193
+ model = create_model(config, device)
194
+
195
+ # 创建扩散过程
196
+ diffusion = create_diffusion(config)
197
+
198
+ # 创建扩散模型
199
+ diffusion_model = DiffusionModel(model, diffusion)
200
+
201
+ # 创建数据管道
202
+ train_loader, val_loader, text_encoder = create_data_pipeline(config)
203
+
204
+ # 创建优化器
205
+ optimizer = create_optimizer(model, config)
206
+
207
+ # 设置内存优化
208
+ memory_optimizer = setup_memory_optimization(model, optimizer, config)
209
+
210
+ # 创建训练器
211
+ trainer = P4Trainer(
212
+ model=model,
213
+ diffusion=diffusion,
214
+ optimizer=optimizer,
215
+ train_loader=train_loader,
216
+ val_loader=val_loader,
217
+ config=config,
218
+ device=device
219
+ )
220
+
221
+ # 创建回调
222
+ callbacks = create_default_callbacks(config)
223
+
224
+ # 加载检查点(如果存在)
225
+ if resume_from and os.path.exists(resume_from):
226
+ print(f"从检查点恢复训练: {resume_from}")
227
+ trainer.load_checkpoint(resume_from)
228
+
229
+ # 开始训练
230
+ try:
231
+ print("\n开始训练...")
232
+ trainer.train()
233
+
234
+ print("\n" + "=" * 60)
235
+ print("训练完成!")
236
+ print(f"最佳验证损失: {trainer.best_loss:.4f}")
237
+ print(f"总训练步数: {trainer.global_step}")
238
+ print("=" * 60)
239
+
240
+ except KeyboardInterrupt:
241
+ print("\n训练被中断")
242
+ except Exception as e:
243
+ print(f"\n训练出错: {e}")
244
+ import traceback
245
+ traceback.print_exc()
246
+
247
+ finally:
248
+ # 保存最终检查点
249
+ final_checkpoint = os.path.join(
250
+ config.get('checkpoint_dir', './checkpoints'),
251
+ 'final_model.pt'
252
+ )
253
+ trainer.save_checkpoint(final_checkpoint)
254
+
255
+
256
+ def main():
257
+ """主函数"""
258
+ parser = argparse.ArgumentParser(description="训练Lumina图像生成模型")
259
+
260
+ parser.add_argument(
261
+ "--config",
262
+ type=str,
263
+ default="configs/training/p4_optimized.yaml",
264
+ help="训练配置文件路径"
265
+ )
266
+
267
+ parser.add_argument(
268
+ "--resume",
269
+ type=str,
270
+ help="从检查点恢复训练"
271
+ )
272
+
273
+ parser.add_argument(
274
+ "--debug",
275
+ action="store_true",
276
+ help="调试模式"
277
+ )
278
+
279
+ args = parser.parse_args()
280
+
281
+ # 调试模式设置
282
+ if args.debug:
283
+ import warnings
284
+ warnings.filterwarnings("always")
285
+ torch.autograd.set_detect_anomaly(True)
286
+ print("调试模式已启用")
287
+
288
+ # 开始训练
289
+ train(args.config, args.resume)
290
+
291
+
292
+ if __name__ == "__main__":
293
+ main()
src/data/__pycache__/dataset.cpython-313.pyc ADDED
Binary file (13.8 kB). View file
 
src/data/dataset.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ from PIL import Image
4
+ import pandas as pd
5
+ import os
6
+ from typing import Dict, List, Optional, Tuple
7
+ import numpy as np
8
+ import json
9
+ from pathlib import Path
10
+
11
+
12
+ class LAIONDataset(Dataset):
13
+ """LAION数据集"""
14
+ def __init__(self, config: dict, transform=None, split: str = 'train'):
15
+ self.config = config
16
+ self.transform = transform
17
+ self.split = split
18
+
19
+ # 加载元数据
20
+ metadata_path = config['dataset'].get('metadata_file', './data/laion/metadata.parquet')
21
+ self.metadata = pd.read_parquet(metadata_path)
22
+
23
+ # 应用过滤条件
24
+ self._apply_filters(config.get('filters', {}))
25
+
26
+ # 数据拆分
27
+ self._split_data(config.get('split', {}))
28
+
29
+ # 缓存
30
+ self.use_cache = config.get('use_cache', False)
31
+ self.cache_dir = config.get('cache_dir', './data/cache')
32
+ if self.use_cache:
33
+ os.makedirs(self.cache_dir, exist_ok=True)
34
+
35
+ # 限制样本数
36
+ max_samples = config.get('max_samples', None)
37
+ if max_samples is not None and len(self.metadata) > max_samples:
38
+ self.metadata = self.metadata.sample(max_samples, random_state=42)
39
+
40
+ # 文本缓存
41
+ self.text_cache = {}
42
+
43
+ print(f"数据集加载完成: {len(self.metadata)} 个样本 ({split}集)")
44
+
45
+ def _apply_filters(self, filters: dict):
46
+ """应用过滤条件"""
47
+ if 'aesthetic_score' in filters:
48
+ threshold = filters['aesthetic_score']
49
+ if 'aesthetic_score' in self.metadata.columns:
50
+ self.metadata = self.metadata[self.metadata['aesthetic_score'] >= threshold]
51
+
52
+ if 'watermark_prob' in filters:
53
+ threshold = filters['watermark_prob']
54
+ if 'watermark_prob' in self.metadata.columns:
55
+ self.metadata = self.metadata[self.metadata['watermark_prob'] <= threshold]
56
+
57
+ if 'nsfw' in filters and not filters['nsfw']:
58
+ if 'NSFW' in self.metadata.columns:
59
+ self.metadata = self.metadata[self.metadata['NSFW'] != 'NSFW']
60
+
61
+ def _split_data(self, split_config: dict):
62
+ """拆分数据集"""
63
+ if self.split not in ['train', 'val']:
64
+ return
65
+
66
+ train_ratio = split_config.get('train', 0.95)
67
+ val_ratio = split_config.get('val', 0.05)
68
+
69
+ # 确保拆分比例之和为1
70
+ total = train_ratio + val_ratio
71
+ train_ratio /= total
72
+ val_ratio /= total
73
+
74
+ # 随机拆分
75
+ seed = split_config.get('seed', 42)
76
+ shuffled = self.metadata.sample(frac=1, random_state=seed).reset_index(drop=True)
77
+
78
+ if self.split == 'train':
79
+ split_point = int(len(shuffled) * train_ratio)
80
+ self.metadata = shuffled[:split_point]
81
+ else:
82
+ split_point = int(len(shuffled) * train_ratio)
83
+ self.metadata = shuffled[split_point:]
84
+
85
+ def __len__(self) -> int:
86
+ return len(self.metadata)
87
+
88
+ def _get_image_path(self, row) -> str:
89
+ """获取图像路径"""
90
+ # 尝试不同的列名
91
+ for col in ['image_file', 'filepath', 'path', 'url_local']:
92
+ if col in row:
93
+ path = row[col]
94
+ # 如果是相对路径,添加基础路径
95
+ if not os.path.isabs(path):
96
+ base_path = self.config['dataset'].get('path', './data/laion')
97
+ path = os.path.join(base_path, path)
98
+ return path
99
+
100
+ # 如果没有找到路径,使用URL哈希
101
+ if 'url' in row:
102
+ import hashlib
103
+ url_hash = hashlib.md5(row['url'].encode()).hexdigest()
104
+ base_path = self.config['dataset'].get('path', './data/laion')
105
+ path = os.path.join(base_path, f"{url_hash}.jpg")
106
+ return path
107
+
108
+ raise ValueError(f"无法找到图像路径: {row}")
109
+
110
+ def __getitem__(self, idx: int) -> Dict:
111
+ row = self.metadata.iloc[idx]
112
+
113
+ # 缓存键
114
+ cache_key = f"{self.split}_{idx}"
115
+
116
+ # 检查缓存
117
+ if self.use_cache and cache_key in self.text_cache:
118
+ text_embedding = self.text_cache[cache_key]
119
+ else:
120
+ # 获取文本描述
121
+ text = row.get('caption', row.get('text', row.get('description', '')))
122
+
123
+ # 这里应该调用文本编码器,但为了简化,我们返回原始文本
124
+ # 在实际使用中,应该使用预训练的CLIP编码器
125
+ text_embedding = text
126
+
127
+ # 缓存
128
+ if self.use_cache:
129
+ self.text_cache[cache_key] = text_embedding
130
+
131
+ # 获取图像
132
+ try:
133
+ image_path = self._get_image_path(row)
134
+ image = Image.open(image_path).convert('RGB')
135
+
136
+ # 应用变换
137
+ if self.transform:
138
+ image = self.transform(image)
139
+ except Exception as e:
140
+ # 如果图像加载失败,返回一个空白图像
141
+ print(f"加载图像失败 {image_path}: {e}")
142
+ image = torch.zeros(3, 512, 512)
143
+ text = "invalid image"
144
+ text_embedding = text
145
+
146
+ return {
147
+ 'image': image,
148
+ 'text': text_embedding if isinstance(text_embedding, str) else '',
149
+ 'text_embedding': text_embedding if not isinstance(text_embedding, str) else None,
150
+ 'image_path': image_path if 'image_path' in locals() else '',
151
+ 'index': idx
152
+ }
153
+
154
+
155
+ class TextImageDataset(Dataset):
156
+ """文本-图像对数据集"""
157
+ def __init__(self, image_dir: str, caption_file: str, transform=None):
158
+ self.image_dir = image_dir
159
+ self.transform = transform
160
+
161
+ # 加载标注文件
162
+ if caption_file.endswith('.json'):
163
+ with open(caption_file, 'r') as f:
164
+ self.captions = json.load(f)
165
+ elif caption_file.endswith('.csv'):
166
+ self.captions = pd.read_csv(caption_file)
167
+ else:
168
+ raise ValueError(f"不支持的标注文件格式: {caption_file}")
169
+
170
+ # 验证图像文件是否存在
171
+ self.valid_samples = []
172
+ for item in self.captions:
173
+ if isinstance(item, dict):
174
+ image_name = item.get('image_name', item.get('file_name', ''))
175
+ caption = item.get('caption', '')
176
+ else:
177
+ image_name = item[0]
178
+ caption = item[1]
179
+
180
+ image_path = os.path.join(self.image_dir, image_name)
181
+ if os.path.exists(image_path):
182
+ self.valid_samples.append((image_path, caption))
183
+
184
+ print(f"找到 {len(self.valid_samples)} 个有效样本")
185
+
186
+ def __len__(self) -> int:
187
+ return len(self.valid_samples)
188
+
189
+ def __getitem__(self, idx: int) -> Dict:
190
+ image_path, caption = self.valid_samples[idx]
191
+
192
+ # 加载图像
193
+ image = Image.open(image_path).convert('RGB')
194
+
195
+ # 应用变换
196
+ if self.transform:
197
+ image = self.transform(image)
198
+
199
+ return {
200
+ 'image': image,
201
+ 'text': caption,
202
+ 'image_path': image_path
203
+ }
204
+
205
+
206
+ class CachedDataset(Dataset):
207
+ """缓存数据集,加速训练"""
208
+ def __init__(self, dataset: Dataset, cache_dir: str = './cache'):
209
+ self.dataset = dataset
210
+ self.cache_dir = cache_dir
211
+ os.makedirs(cache_dir, exist_ok=True)
212
+
213
+ self.cache_files = []
214
+ for i in range(len(dataset)):
215
+ cache_file = os.path.join(cache_dir, f'sample_{i}.pt')
216
+ self.cache_files.append(cache_file)
217
+
218
+ def __len__(self) -> int:
219
+ return len(self.dataset)
220
+
221
+ def __getitem__(self, idx: int) -> Dict:
222
+ cache_file = self.cache_files[idx]
223
+
224
+ # 如果缓存存在,直接加载
225
+ if os.path.exists(cache_file):
226
+ try:
227
+ return torch.load(cache_file)
228
+ except:
229
+ pass
230
+
231
+ # 否则,从原始数据集加载并缓存
232
+ sample = self.dataset[idx]
233
+ torch.save(sample, cache_file)
234
+
235
+ return sample
236
+
237
+
238
+ def create_data_loaders(config: dict) -> Tuple[DataLoader, Optional[DataLoader]]:
239
+ """创建数据加载器"""
240
+ from .preprocessing import get_transform
241
+
242
+ # 获取数据变换
243
+ train_transform = get_transform(config, mode='train')
244
+ val_transform = get_transform(config, mode='val')
245
+
246
+ # 创建数据集
247
+ train_dataset = LAIONDataset(config, transform=train_transform, split='train')
248
+ val_dataset = LAIONDataset(config, transform=val_transform, split='val')
249
+
250
+ # 可选:启用缓存
251
+ if config.get('cache_dataset', True):
252
+ train_dataset = CachedDataset(train_dataset, cache_dir='./data/cache/train')
253
+ val_dataset = CachedDataset(val_dataset, cache_dir='./data/cache/val')
254
+
255
+ # 创建数据加载器
256
+ train_loader = DataLoader(
257
+ train_dataset,
258
+ batch_size=config.get('batch_size', 1),
259
+ shuffle=config.get('shuffle', True),
260
+ num_workers=config.get('num_workers', 2),
261
+ pin_memory=config.get('pin_memory', True),
262
+ prefetch_factor=config.get('prefetch_factor', 2),
263
+ persistent_workers=config.get('persistent_workers', True)
264
+ )
265
+
266
+ val_loader = DataLoader(
267
+ val_dataset,
268
+ batch_size=1, # 验证时批次大小为1
269
+ shuffle=False,
270
+ num_workers=config.get('num_workers', 2),
271
+ pin_memory=True
272
+ )
273
+
274
+ return train_loader, val_loader
275
+
276
+
277
+ def test_dataset():
278
+ """测试数据集"""
279
+ import yaml
280
+
281
+ # 加载配置
282
+ with open('configs/data/laion_filtered.yaml', 'r') as f:
283
+ config = yaml.safe_load(f)
284
+
285
+ # 创建数据集
286
+ dataset = LAIONDataset(config, split='train')
287
+
288
+ # 测试样本
289
+ sample = dataset[0]
290
+ print(f"样本键: {list(sample.keys())}")
291
+ print(f"图像形状: {sample['image'].shape if hasattr(sample['image'], 'shape') else type(sample['image'])}")
292
+ print(f"文本: {sample['text'][:100]}...")
293
+
294
+ return dataset
295
+
296
+
297
+ if __name__ == '__main__':
298
+ test_dataset()
src/data/preprocessing.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as T
3
+ from PIL import Image
4
+ import numpy as np
5
+ from typing import Dict, List, Optional, Tuple
6
+ import random
7
+
8
+
9
+ def get_transform(config: dict, mode: str = 'train') -> T.Compose:
10
+ """获取数据预处理变换"""
11
+ preprocessing = config.get('preprocessing', {})
12
+ target_size = preprocessing.get('target_size', 512)
13
+ resize_mode = preprocessing.get('resize_mode', 'center_crop')
14
+
15
+ transforms_list = []
16
+
17
+ # 训练和验证的不同变换
18
+ if mode == 'train':
19
+ # 随机裁剪
20
+ if preprocessing.get('random_crop', True):
21
+ transforms_list.append(T.RandomResizedCrop(
22
+ target_size,
23
+ scale=(0.8, 1.0),
24
+ ratio=(0.8, 1.2)
25
+ ))
26
+ else:
27
+ transforms_list.append(T.Resize(target_size, interpolation=T.InterpolationMode.BILINEAR))
28
+
29
+ # 随机水平翻转
30
+ if preprocessing.get('random_flip', True):
31
+ transforms_list.append(T.RandomHorizontalFlip(p=0.5))
32
+
33
+ # 颜色抖动
34
+ if preprocessing.get('color_jitter', 0.05) > 0:
35
+ color_jitter = preprocessing['color_jitter']
36
+ transforms_list.append(T.ColorJitter(
37
+ brightness=color_jitter,
38
+ contrast=color_jitter,
39
+ saturation=color_jitter,
40
+ hue=min(0.1, color_jitter)
41
+ ))
42
+
43
+ # 随机旋转
44
+ if preprocessing.get('random_rotation', 0.0) > 0:
45
+ max_angle = preprocessing['random_rotation']
46
+ transforms_list.append(T.RandomRotation(degrees=(-max_angle, max_angle)))
47
+
48
+ else: # 验证/测试模式
49
+ # 中心裁剪
50
+ if resize_mode == 'center_crop':
51
+ transforms_list.extend([
52
+ T.Resize(target_size, interpolation=T.InterpolationMode.BILINEAR),
53
+ T.CenterCrop(target_size)
54
+ ])
55
+ elif resize_mode == 'resize':
56
+ transforms_list.append(T.Resize(target_size, interpolation=T.InterpolationMode.BILINEAR))
57
+ elif resize_mode == 'random_crop':
58
+ transforms_list.append(T.RandomCrop(target_size))
59
+ else:
60
+ raise ValueError(f"未知的resize_mode: {resize_mode}")
61
+
62
+ # 转换为Tensor
63
+ transforms_list.append(T.ToTensor())
64
+
65
+ # 归一化
66
+ normalize_config = preprocessing.get('normalize', {})
67
+ mean = normalize_config.get('mean', [0.5, 0.5, 0.5])
68
+ std = normalize_config.get('std', [0.5, 0.5, 0.5])
69
+ transforms_list.append(T.Normalize(mean=mean, std=std))
70
+
71
+ return T.Compose(transforms_list)
72
+
73
+
74
+ class TextPreprocessor:
75
+ """文本预处理器"""
76
+ def __init__(self, config: dict):
77
+ self.config = config.get('preprocessing', {})
78
+
79
+ # 文本处理参数
80
+ self.max_length = self.config.get('max_length', 77)
81
+ self.truncation = self.config.get('truncation', True)
82
+ self.padding = self.config.get('padding', 'max_length')
83
+
84
+ # 尝试加载tokenizer
85
+ self.tokenizer = None
86
+ self._init_tokenizer()
87
+
88
+ def _init_tokenizer(self):
89
+ """初始化tokenizer"""
90
+ try:
91
+ from transformers import CLIPTokenizer
92
+ tokenizer_name = self.config.get('tokenizer', 'openai/clip-vit-base-patch32')
93
+ self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name)
94
+ except ImportError:
95
+ print("警告: 未安装transformers,无法使用CLIP tokenizer")
96
+ except Exception as e:
97
+ print(f"加载tokenizer失败: {e}")
98
+
99
+ def preprocess_text(self, text: str) -> Dict:
100
+ """预处理文本"""
101
+ if self.tokenizer is not None:
102
+ # 使用CLIP tokenizer
103
+ inputs = self.tokenizer(
104
+ text,
105
+ max_length=self.max_length,
106
+ padding=self.padding,
107
+ truncation=self.truncation,
108
+ return_tensors="pt"
109
+ )
110
+ return {
111
+ 'input_ids': inputs['input_ids'].squeeze(0),
112
+ 'attention_mask': inputs['attention_mask'].squeeze(0)
113
+ }
114
+ else:
115
+ # 简单的文本处理
116
+ return {
117
+ 'text': text,
118
+ 'length': len(text)
119
+ }
120
+
121
+ def batch_preprocess(self, texts: List[str]) -> Dict:
122
+ """批量预处理文本"""
123
+ if self.tokenizer is not None:
124
+ inputs = self.tokenizer(
125
+ texts,
126
+ max_length=self.max_length,
127
+ padding=self.padding,
128
+ truncation=self.truncation,
129
+ return_tensors="pt"
130
+ )
131
+ return inputs
132
+ else:
133
+ return {'texts': texts}
134
+
135
+
136
+ class ImagePreprocessor:
137
+ """图像预处理器"""
138
+ def __init__(self, config: dict):
139
+ self.config = config.get('preprocessing', {})
140
+ self.transform = get_transform(config, mode='train')
141
+
142
+ def preprocess_image(self, image: Image.Image) -> torch.Tensor:
143
+ """预处理单张图像"""
144
+ return self.transform(image)
145
+
146
+ def batch_preprocess(self, images: List[Image.Image]) -> torch.Tensor:
147
+ """批量预处理图像"""
148
+ return torch.stack([self.transform(img) for img in images])
149
+
150
+ def preprocess_for_vae(self, image: torch.Tensor) -> torch.Tensor:
151
+ """为VAE编码预处理图像"""
152
+ # VAE期望输入在[-1, 1]范围内
153
+ return image * 2.0 - 1.0
154
+
155
+ def postprocess_from_vae(self, latents: torch.Tensor) -> torch.Tensor:
156
+ """从VAE解码后处理图像"""
157
+ # 将[-1, 1]范围转换回[0, 1]
158
+ return (latents + 1.0) / 2.0
159
+
160
+
161
+ class DataPreprocessor:
162
+ """数据预处理器(整合文本和图像处理)"""
163
+ def __init__(self, config: dict):
164
+ self.config = config
165
+ self.image_preprocessor = ImagePreprocessor(config)
166
+ self.text_preprocessor = TextPreprocessor(config)
167
+
168
+ # 文本编码器
169
+ self.text_encoder = None
170
+ self._init_text_encoder()
171
+
172
+ def _init_text_encoder(self):
173
+ """初始化文本编码器"""
174
+ try:
175
+ from transformers import CLIPTextModel
176
+ model_name = self.config.get('preprocessing', {}).get('text_encoder', 'openai/clip-vit-base-patch32')
177
+ self.text_encoder = CLIPTextModel.from_pretrained(model_name)
178
+
179
+ # 冻结参数
180
+ for param in self.text_encoder.parameters():
181
+ param.requires_grad = False
182
+
183
+ # 设置为评估模式
184
+ self.text_encoder.eval()
185
+
186
+ print(f"已加载文本编码器: {model_name}")
187
+ except Exception as e:
188
+ print(f"加载文本编码器失败: {e}")
189
+
190
+ def encode_text(self, text: str) -> torch.Tensor:
191
+ """编码文本为嵌入向量"""
192
+ if self.text_encoder is None:
193
+ raise ValueError("文本编码器未初始化")
194
+
195
+ # 预处理文本
196
+ inputs = self.text_preprocessor.preprocess_text(text)
197
+
198
+ # 编码
199
+ with torch.no_grad():
200
+ if 'input_ids' in inputs:
201
+ outputs = self.text_encoder(
202
+ input_ids=inputs['input_ids'].unsqueeze(0),
203
+ attention_mask=inputs['attention_mask'].unsqueeze(0) if 'attention_mask' in inputs else None
204
+ )
205
+ return outputs.last_hidden_state.squeeze(0)
206
+ else:
207
+ # 回退到简单的嵌入
208
+ return torch.randn(77, 768) # 默认维度
209
+
210
+ def batch_encode_text(self, texts: List[str]) -> torch.Tensor:
211
+ """批量编码文本"""
212
+ if self.text_encoder is None:
213
+ raise ValueError("文本编码器未初始化")
214
+
215
+ # 预处理文本
216
+ inputs = self.text_preprocessor.batch_preprocess(texts)
217
+
218
+ # 编码
219
+ with torch.no_grad():
220
+ if 'input_ids' in inputs:
221
+ outputs = self.text_encoder(
222
+ input_ids=inputs['input_ids'],
223
+ attention_mask=inputs.get('attention_mask', None)
224
+ )
225
+ return outputs.last_hidden_state
226
+ else:
227
+ # 回退到简单的嵌入
228
+ batch_size = len(texts)
229
+ return torch.randn(batch_size, 77, 768)
230
+
231
+ def preprocess_batch(self, batch: List[Dict]) -> Dict:
232
+ """预处理批次数据"""
233
+ images = [item['image'] for item in batch]
234
+ texts = [item['text'] for item in batch]
235
+
236
+ # 预处理图像
237
+ image_tensors = self.image_preprocessor.batch_preprocess(images)
238
+
239
+ # 编码文本
240
+ if self.text_encoder is not None:
241
+ text_embeddings = self.batch_encode_text(texts)
242
+ else:
243
+ text_embeddings = None
244
+
245
+ return {
246
+ 'images': image_tensors,
247
+ 'text_embeddings': text_embeddings,
248
+ 'texts': texts,
249
+ 'image_paths': [item.get('image_path', '') for item in batch]
250
+ }
251
+
252
+
253
+ def test_preprocessing():
254
+ """测试预处理"""
255
+ import yaml
256
+ from PIL import Image
257
+
258
+ # 创建测试图像
259
+ test_image = Image.new('RGB', (512, 512), color='red')
260
+
261
+ # 加载配置
262
+ with open('configs/data/laion_filtered.yaml', 'r') as f:
263
+ config = yaml.safe_load(f)
264
+
265
+ # 测试图像预处理
266
+ image_preprocessor = ImagePreprocessor(config)
267
+ processed_image = image_preprocessor.preprocess_image(test_image)
268
+
269
+ print(f"原始图像: {test_image.size}")
270
+ print(f"处理后图像形状: {processed_image.shape}")
271
+
272
+ # 测试文本预处理
273
+ text_preprocessor = TextPreprocessor(config)
274
+ processed_text = text_preprocessor.preprocess_text("A red square image")
275
+
276
+ print(f"文本处理结果: {processed_text}")
277
+
278
+ # 测试数据预处理器
279
+ data_preprocessor = DataPreprocessor(config)
280
+
281
+ test_batch = [
282
+ {'image': test_image, 'text': "A red square"},
283
+ {'image': test_image, 'text': "A blue circle"}
284
+ ]
285
+
286
+ processed_batch = data_preprocessor.preprocess_batch(test_batch)
287
+
288
+ print(f"批次图像形状: {processed_batch['images'].shape}")
289
+ print(f"文本嵌入形状: {processed_batch['text_embeddings'].shape if processed_batch['text_embeddings'] is not None else 'None'}")
290
+
291
+ return processed_batch
292
+
293
+
294
+ if __name__ == '__main__':
295
+ test_preprocessing()
src/data/text_encoder.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPConfig
4
+ from typing import Optional, Tuple, List
5
+ import os
6
+
7
+
8
+ class LightTextEncoder(nn.Module):
9
+ """轻量级文本编码器"""
10
+ def __init__(self, config: dict):
11
+ super().__init__()
12
+ self.config = config
13
+
14
+ # 编码器参数
15
+ self.vocab_size = config.get('vocab_size', 49408)
16
+ self.hidden_size = config.get('hidden_size', 512)
17
+ self.num_hidden_layers = config.get('num_hidden_layers', 8)
18
+ self.num_attention_heads = config.get('num_attention_heads', 8)
19
+ self.max_position_embeddings = config.get('max_position_embeddings', 77)
20
+
21
+ # 构建编码器
22
+ self.token_embedding = nn.Embedding(self.vocab_size, self.hidden_size)
23
+ self.position_embedding = nn.Embedding(self.max_position_embeddings, self.hidden_size)
24
+
25
+ # Transformer层
26
+ self.layers = nn.ModuleList([
27
+ TransformerLayer(self.hidden_size, self.num_attention_heads)
28
+ for _ in range(self.num_hidden_layers)
29
+ ])
30
+
31
+ self.final_layer_norm = nn.LayerNorm(self.hidden_size)
32
+
33
+ def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
34
+ # 嵌入
35
+ token_embeddings = self.token_embedding(input_ids)
36
+ position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0)
37
+ position_embeddings = self.position_embedding(position_ids)
38
+
39
+ hidden_states = token_embeddings + position_embeddings
40
+
41
+ # Transformer层
42
+ for layer in self.layers:
43
+ hidden_states = layer(hidden_states, attention_mask)
44
+
45
+ # 最终层归一化
46
+ hidden_states = self.final_layer_norm(hidden_states)
47
+
48
+ return hidden_states
49
+
50
+
51
+ class TransformerLayer(nn.Module):
52
+ """Transformer层"""
53
+ def __init__(self, hidden_size: int, num_heads: int):
54
+ super().__init__()
55
+ self.attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
56
+ self.attention_norm = nn.LayerNorm(hidden_size)
57
+
58
+ self.mlp = nn.Sequential(
59
+ nn.Linear(hidden_size, hidden_size * 4),
60
+ nn.GELU(),
61
+ nn.Linear(hidden_size * 4, hidden_size)
62
+ )
63
+ self.mlp_norm = nn.LayerNorm(hidden_size)
64
+
65
+ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
66
+ # 自注意力
67
+ attn_output, _ = self.attention(x, x, x, key_padding_mask=attention_mask)
68
+ x = self.attention_norm(x + attn_output)
69
+
70
+ # 前馈网络
71
+ mlp_output = self.mlp(x)
72
+ x = self.mlp_norm(x + mlp_output)
73
+
74
+ return x
75
+
76
+
77
+ class CLIPTextEncoderWrapper:
78
+ """CLIP文本编码器包装器"""
79
+ def __init__(self, model_name: str = 'openai/clip-vit-base-patch32', device: str = 'cuda'):
80
+ self.model_name = model_name
81
+ self.device = device
82
+
83
+ # 加载tokenizer和模型
84
+ self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
85
+
86
+ # 只加载文本模型
87
+ self.text_model = CLIPTextModel.from_pretrained(model_name).to(device)
88
+
89
+ # 冻结参数
90
+ for param in self.text_model.parameters():
91
+ param.requires_grad = False
92
+
93
+ # 设置为评估模式
94
+ self.text_model.eval()
95
+
96
+ print(f"已加载CLIP文本编码器: {model_name}")
97
+
98
+ def encode(self, texts: List[str], return_tensors: str = 'pt') -> torch.Tensor:
99
+ """编码文本"""
100
+ # Tokenize
101
+ inputs = self.tokenizer(
102
+ texts,
103
+ padding=True,
104
+ truncation=True,
105
+ max_length=77,
106
+ return_tensors=return_tensors
107
+ )
108
+
109
+ # 移动到设备
110
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
111
+
112
+ # 编码
113
+ with torch.no_grad():
114
+ outputs = self.text_model(**inputs)
115
+ return outputs.last_hidden_state
116
+
117
+ def encode_batch(self, texts: List[str], batch_size: int = 32) -> torch.Tensor:
118
+ """分批编码文本"""
119
+ all_embeddings = []
120
+
121
+ for i in range(0, len(texts), batch_size):
122
+ batch_texts = texts[i:i + batch_size]
123
+ batch_embeddings = self.encode(batch_texts)
124
+ all_embeddings.append(batch_embeddings.cpu())
125
+
126
+ return torch.cat(all_embeddings, dim=0)
127
+
128
+ def save_embeddings(self, texts: List[str], save_path: str):
129
+ """保存文本嵌入"""
130
+ embeddings = self.encode_batch(texts)
131
+ torch.save(embeddings, save_path)
132
+ print(f"嵌入已保存到: {save_path}")
133
+
134
+
135
+ class CachedTextEncoder:
136
+ """带缓存的文本编码器"""
137
+ def __init__(self, encoder, cache_dir: str = './text_cache'):
138
+ self.encoder = encoder
139
+ self.cache_dir = cache_dir
140
+ os.makedirs(cache_dir, exist_ok=True)
141
+
142
+ # 内存缓存
143
+ self.memory_cache = {}
144
+
145
+ def encode(self, text: str) -> torch.Tensor:
146
+ """编码文本,使用缓存"""
147
+ # 生成缓存键
148
+ import hashlib
149
+ cache_key = hashlib.md5(text.encode()).hexdigest()
150
+
151
+ # 检查内存缓存
152
+ if cache_key in self.memory_cache:
153
+ return self.memory_cache[cache_key]
154
+
155
+ # 检查磁盘缓存
156
+ cache_file = os.path.join(self.cache_dir, f"{cache_key}.pt")
157
+ if os.path.exists(cache_file):
158
+ embedding = torch.load(cache_file)
159
+ self.memory_cache[cache_key] = embedding
160
+ return embedding
161
+
162
+ # 编码并缓存
163
+ embedding = self.encoder.encode([text])[0]
164
+
165
+ # 保存到内存缓存
166
+ self.memory_cache[cache_key] = embedding
167
+
168
+ # 保存到磁盘缓存
169
+ torch.save(embedding, cache_file)
170
+
171
+ return embedding
172
+
173
+ def encode_batch(self, texts: List[str]) -> torch.Tensor:
174
+ """批量编码文本"""
175
+ embeddings = []
176
+
177
+ for text in texts:
178
+ embedding = self.encode(text)
179
+ embeddings.append(embedding.unsqueeze(0))
180
+
181
+ return torch.cat(embeddings, dim=0)
182
+
183
+
184
+ def create_text_encoder(config: dict) -> CLIPTextEncoderWrapper:
185
+ """创建文本编码器"""
186
+ model_name = config.get('preprocessing', {}).get('tokenizer', 'openai/clip-vit-base-patch32')
187
+ device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
188
+
189
+ encoder = CLIPTextEncoderWrapper(model_name, device)
190
+
191
+ # 如果需要缓存,包装一层
192
+ if config.get('use_cache', True):
193
+ cache_dir = config.get('cache_dir', './data/text_cache')
194
+ encoder = CachedTextEncoder(encoder, cache_dir)
195
+
196
+ return encoder
197
+
198
+
199
+ def test_text_encoder():
200
+ """测试文本编码器"""
201
+ config = {
202
+ 'preprocessing': {
203
+ 'tokenizer': 'openai/clip-vit-base-patch32'
204
+ },
205
+ 'device': 'cuda' if torch.cuda.is_available() else 'cpu'
206
+ }
207
+
208
+ encoder = create_text_encoder(config)
209
+
210
+ # 测试编码
211
+ texts = [
212
+ "A beautiful sunset over the mountains",
213
+ "A cute cat playing with a ball",
214
+ "An astronaut riding a horse on Mars"
215
+ ]
216
+
217
+ embeddings = encoder.encode(texts)
218
+
219
+ print(f"文本数量: {len(texts)}")
220
+ print(f"嵌入形状: {embeddings.shape}")
221
+ print(f"嵌入范围: [{embeddings.min():.4f}, {embeddings.max():.4f}]")
222
+
223
+ return encoder, embeddings
224
+
225
+
226
+ if __name__ == '__main__':
227
+ encoder, embeddings = test_text_encoder()
src/inference/api.py ADDED
@@ -0,0 +1,631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
2
+ from fastapi.responses import JSONResponse, FileResponse, StreamingResponse
3
+ from pydantic import BaseModel
4
+ from typing import Optional, List, Dict, Any
5
+ import torch
6
+ import io
7
+ from PIL import Image
8
+ import base64
9
+ import json
10
+ import time
11
+ from datetime import datetime
12
+ import asyncio
13
+ from concurrent.futures import ThreadPoolExecutor
14
+ import uuid
15
+ import os
16
+
17
+ from .sampler import TextToImagePipeline, SamplerFactory
18
+ from ..data.text_encoder import CLIPTextEncoderWrapper
19
+
20
+
21
+ # 请求/响应模型
22
+ class TextToImageRequest(BaseModel):
23
+ prompt: str
24
+ negative_prompt: Optional[str] = ""
25
+ width: int = 512
26
+ height: int = 512
27
+ num_steps: int = 50
28
+ guidance_scale: float = 7.5
29
+ num_images: int = 1
30
+ seed: Optional[int] = None
31
+ sampler: str = "ddim"
32
+
33
+
34
+ class ImageResponse(BaseModel):
35
+ images: List[str] # base64编码的图像
36
+ metadata: Dict[str, Any]
37
+ request_id: str
38
+ generation_time: float
39
+
40
+
41
+ class BatchRequest(BaseModel):
42
+ requests: List[TextToImageRequest]
43
+ priority: int = 0
44
+
45
+
46
+ class StatusResponse(BaseModel):
47
+ status: str
48
+ queue_length: int
49
+ active_tasks: int
50
+ gpu_memory_usage: float
51
+ uptime: float
52
+
53
+
54
+ # API应用
55
+ class LuminaAPI:
56
+ """Lumina API服务器"""
57
+ def __init__(
58
+ self,
59
+ model,
60
+ diffusion,
61
+ text_encoder,
62
+ vae_decoder=None,
63
+ host: str = "0.0.0.0",
64
+ port: int = 8000,
65
+ max_queue_size: int = 100,
66
+ max_workers: int = 2
67
+ ):
68
+ self.model = model
69
+ self.diffusion = diffusion
70
+ self.text_encoder = text_encoder
71
+ self.vae_decoder = vae_decoder
72
+ self.host = host
73
+ self.port = port
74
+
75
+ # 创建FastAPI应用
76
+ self.app = FastAPI(
77
+ title="Lumina Image Generation API",
78
+ description="轻量级图像生成模型API",
79
+ version="1.0.0"
80
+ )
81
+
82
+ # 任务队列
83
+ self.task_queue = asyncio.Queue(maxsize=max_queue_size)
84
+ self.executor = ThreadPoolExecutor(max_workers=max_workers)
85
+ self.active_tasks = 0
86
+
87
+ # 请求历史
88
+ self.request_history = []
89
+ self.max_history = 1000
90
+
91
+ # 统计信息
92
+ self.start_time = time.time()
93
+ self.total_requests = 0
94
+ self.total_images = 0
95
+
96
+ # 初始化管道
97
+ self.pipeline = TextToImagePipeline(
98
+ model=model,
99
+ diffusion=diffusion,
100
+ text_encoder=text_encoder,
101
+ vae_decoder=vae_decoder,
102
+ sampler_type="ddim"
103
+ )
104
+
105
+ # 设置路由
106
+ self._setup_routes()
107
+
108
+ def _setup_routes(self):
109
+ """设置API路由"""
110
+
111
+ @self.app.get("/")
112
+ async def root():
113
+ return {
114
+ "message": "Lumina Image Generation API",
115
+ "version": "1.0.0",
116
+ "docs": "/docs",
117
+ "endpoints": [
118
+ "/generate",
119
+ "/batch_generate",
120
+ "/status",
121
+ "/health"
122
+ ]
123
+ }
124
+
125
+ @self.app.get("/health")
126
+ async def health_check():
127
+ """健康检查"""
128
+ gpu_available = torch.cuda.is_available()
129
+ gpu_memory = torch.cuda.memory_allocated() / 1024**3 if gpu_available else 0
130
+
131
+ return {
132
+ "status": "healthy",
133
+ "gpu_available": gpu_available,
134
+ "gpu_memory_gb": gpu_memory,
135
+ "model_loaded": self.model is not None,
136
+ "text_encoder_loaded": self.text_encoder is not None
137
+ }
138
+
139
+ @self.app.get("/status")
140
+ async def get_status():
141
+ """获取服务状态"""
142
+ uptime = time.time() - self.start_time
143
+
144
+ # GPU内存使用
145
+ if torch.cuda.is_available():
146
+ gpu_memory = torch.cuda.memory_allocated() / 1024**3
147
+ else:
148
+ gpu_memory = 0
149
+
150
+ return StatusResponse(
151
+ status="running",
152
+ queue_length=self.task_queue.qsize(),
153
+ active_tasks=self.active_tasks,
154
+ gpu_memory_usage=gpu_memory,
155
+ uptime=uptime
156
+ )
157
+
158
+ @self.app.post("/generate", response_model=ImageResponse)
159
+ async def generate_image(request: TextToImageRequest):
160
+ """生成单个图像"""
161
+ request_id = str(uuid.uuid4())
162
+
163
+ # 记录请求
164
+ self.request_history.append({
165
+ "request_id": request_id,
166
+ "prompt": request.prompt,
167
+ "timestamp": datetime.now().isoformat()
168
+ })
169
+
170
+ # 限制历史记录大小
171
+ if len(self.request_history) > self.max_history:
172
+ self.request_history = self.request_history[-self.max_history:]
173
+
174
+ # 生成图像
175
+ start_time = time.time()
176
+
177
+ try:
178
+ # 在线程池中运行生成任务
179
+ loop = asyncio.get_event_loop()
180
+ images = await loop.run_in_executor(
181
+ self.executor,
182
+ self._generate_sync,
183
+ request
184
+ )
185
+
186
+ generation_time = time.time() - start_time
187
+
188
+ # 转换为base64
189
+ image_b64_list = []
190
+ for img_tensor in images:
191
+ if isinstance(img_tensor, torch.Tensor):
192
+ # 转换为PIL图像
193
+ if img_tensor.dim() == 4:
194
+ img_tensor = img_tensor.squeeze(0)
195
+
196
+ # 归一化到[0, 255]
197
+ img_tensor = torch.clamp(img_tensor * 255, 0, 255).byte()
198
+
199
+ # 转换为numpy数组
200
+ if img_tensor.shape[0] == 3: # CHW格式
201
+ img_array = img_tensor.permute(1, 2, 0).cpu().numpy()
202
+ else:
203
+ img_array = img_tensor.cpu().numpy()
204
+
205
+ # 转换为PIL图像
206
+ img = Image.fromarray(img_array)
207
+
208
+ # 转换为base64
209
+ buffered = io.BytesIO()
210
+ img.save(buffered, format="PNG")
211
+ img_b64 = base64.b64encode(buffered.getvalue()).decode()
212
+ image_b64_list.append(img_b64)
213
+ else:
214
+ image_b64_list.append("")
215
+
216
+ # 更新统计信息
217
+ self.total_requests += 1
218
+ self.total_images += len(images)
219
+
220
+ return ImageResponse(
221
+ images=image_b64_list,
222
+ metadata={
223
+ "prompt": request.prompt,
224
+ "negative_prompt": request.negative_prompt,
225
+ "width": request.width,
226
+ "height": request.height,
227
+ "num_steps": request.num_steps,
228
+ "guidance_scale": request.guidance_scale,
229
+ "seed": request.seed,
230
+ "sampler": request.sampler
231
+ },
232
+ request_id=request_id,
233
+ generation_time=generation_time
234
+ )
235
+
236
+ except Exception as e:
237
+ raise HTTPException(status_code=500, detail=str(e))
238
+
239
+ @self.app.post("/batch_generate")
240
+ async def batch_generate(batch_request: BatchRequest):
241
+ """批量生成图像"""
242
+ request_ids = [str(uuid.uuid4()) for _ in batch_request.requests]
243
+ results = []
244
+
245
+ # 为每个请求生成任务
246
+ tasks = []
247
+ for req, req_id in zip(batch_request.requests, request_ids):
248
+ task = asyncio.create_task(self._process_single_request(req, req_id))
249
+ tasks.append(task)
250
+
251
+ # 等待所有任务完成
252
+ results = await asyncio.gather(*tasks, return_exceptions=True)
253
+
254
+ # 处理结果
255
+ successful_results = []
256
+ failed_results = []
257
+
258
+ for result in results:
259
+ if isinstance(result, Exception):
260
+ failed_results.append({"error": str(result)})
261
+ else:
262
+ successful_results.append(result)
263
+
264
+ return {
265
+ "successful": successful_results,
266
+ "failed": failed_results,
267
+ "total_requests": len(batch_request.requests),
268
+ "successful_count": len(successful_results),
269
+ "failed_count": len(failed_results)
270
+ }
271
+
272
+ @self.app.get("/history")
273
+ async def get_history(limit: int = 50):
274
+ """获取请求历史"""
275
+ return self.request_history[-limit:]
276
+
277
+ @self.app.post("/txt2img") # Stable Diffusion兼容端点
278
+ async def txt2img(
279
+ prompt: str = Form(...),
280
+ negative_prompt: str = Form(""),
281
+ width: int = Form(512),
282
+ height: int = Form(512),
283
+ steps: int = Form(50),
284
+ cfg_scale: float = Form(7.5),
285
+ seed: int = Form(-1)
286
+ ):
287
+ """兼容Stable Diffusion WebUI的端点"""
288
+ request = TextToImageRequest(
289
+ prompt=prompt,
290
+ negative_prompt=negative_prompt,
291
+ width=width,
292
+ height=height,
293
+ num_steps=steps,
294
+ guidance_scale=cfg_scale,
295
+ seed=seed if seed != -1 else None
296
+ )
297
+
298
+ response = await generate_image(request)
299
+
300
+ # 返回第一个图像
301
+ if response.images:
302
+ return StreamingResponse(
303
+ io.BytesIO(base64.b64decode(response.images[0])),
304
+ media_type="image/png"
305
+ )
306
+ else:
307
+ raise HTTPException(status_code=500, detail="生成失败")
308
+
309
+ @self.app.get("/stats")
310
+ async def get_stats():
311
+ """获取统计信息"""
312
+ uptime = time.time() - self.start_time
313
+
314
+ return {
315
+ "total_requests": self.total_requests,
316
+ "total_images": self.total_images,
317
+ "requests_per_minute": self.total_requests / (uptime / 60) if uptime > 0 else 0,
318
+ "avg_generation_time": None, # 可以添加计时逻辑
319
+ "uptime_seconds": uptime,
320
+ "queue_size": self.task_queue.qsize(),
321
+ "active_workers": self.active_tasks
322
+ }
323
+
324
+ def _generate_sync(self, request: TextToImageRequest) -> List[torch.Tensor]:
325
+ """同步生成图像(在单独的线程中运行)"""
326
+ # 更新管道采样器
327
+ self.pipeline.sampler = SamplerFactory.create_sampler(
328
+ request.sampler,
329
+ self.model,
330
+ self.diffusion,
331
+ request.num_steps
332
+ )
333
+
334
+ # 生成图像
335
+ images = self.pipeline(
336
+ prompt=request.prompt,
337
+ negative_prompt=request.negative_prompt,
338
+ height=request.height,
339
+ width=request.width,
340
+ num_inference_steps=request.num_steps,
341
+ guidance_scale=request.guidance_scale,
342
+ num_images=request.num_images,
343
+ seed=request.seed,
344
+ progress_bar=False
345
+ )
346
+
347
+ return images
348
+
349
+ async def _process_single_request(self, request: TextToImageRequest, request_id: str) -> Dict:
350
+ """处理单个请求"""
351
+ try:
352
+ # 在队列中添加任务
353
+ await self.task_queue.put((request, request_id))
354
+
355
+ # 更新活动任务计数
356
+ self.active_tasks += 1
357
+
358
+ # 处理任务
359
+ loop = asyncio.get_event_loop()
360
+ images = await loop.run_in_executor(
361
+ self.executor,
362
+ self._generate_sync,
363
+ request
364
+ )
365
+
366
+ # 转换图像
367
+ image_b64_list = []
368
+ for img_tensor in images:
369
+ # 简化的转换逻辑
370
+ img_b64 = "placeholder" # 实际应该转换为base64
371
+ image_b64_list.append(img_b64)
372
+
373
+ # 更新活动任务计数
374
+ self.active_tasks -= 1
375
+
376
+ return {
377
+ "request_id": request_id,
378
+ "images": image_b64_list,
379
+ "success": True
380
+ }
381
+
382
+ except Exception as e:
383
+ self.active_tasks -= 1
384
+ return {
385
+ "request_id": request_id,
386
+ "error": str(e),
387
+ "success": False
388
+ }
389
+
390
+ def run(self):
391
+ """运行API服务器"""
392
+ import uvicorn
393
+ print(f"启动Lumina API服务器在 http://{self.host}:{self.port}")
394
+ print(f"API文档: http://{self.host}:{self.port}/docs")
395
+
396
+ uvicorn.run(
397
+ self.app,
398
+ host=self.host,
399
+ port=self.port,
400
+ log_level="info"
401
+ )
402
+
403
+
404
+ class SimpleWebUI:
405
+ """简单的Web UI(使用Gradio)"""
406
+ def __init__(self, pipeline: TextToImagePipeline):
407
+ self.pipeline = pipeline
408
+
409
+ def create_interface(self):
410
+ """创建Gradio界面"""
411
+ try:
412
+ import gradio as gr
413
+ except ImportError:
414
+ print("Gradio未安装,无法创建Web UI")
415
+ return None
416
+
417
+ def generate_image_ui(
418
+ prompt,
419
+ negative_prompt,
420
+ width,
421
+ height,
422
+ num_steps,
423
+ guidance_scale,
424
+ seed,
425
+ sampler
426
+ ):
427
+ """UI生成函数"""
428
+ # 设置种子
429
+ if seed and seed > 0:
430
+ torch.manual_seed(seed)
431
+ if torch.cuda.is_available():
432
+ torch.cuda.manual_seed(seed)
433
+
434
+ # 生成图像
435
+ images = self.pipeline(
436
+ prompt=prompt,
437
+ negative_prompt=negative_prompt,
438
+ height=height,
439
+ width=width,
440
+ num_inference_steps=num_steps,
441
+ guidance_scale=guidance_scale,
442
+ num_images=1,
443
+ seed=seed if seed > 0 else None,
444
+ progress_bar=False
445
+ )
446
+
447
+ # 转换为PIL图像
448
+ if images:
449
+ img_tensor = images[0]
450
+ if isinstance(img_tensor, torch.Tensor):
451
+ if img_tensor.dim() == 4:
452
+ img_tensor = img_tensor.squeeze(0)
453
+
454
+ # 归一化到[0, 255]
455
+ img_tensor = torch.clamp(img_tensor * 255, 0, 255).byte()
456
+
457
+ # 转换为numpy数组
458
+ if img_tensor.shape[0] == 3: # CHW格式
459
+ img_array = img_tensor.permute(1, 2, 0).cpu().numpy()
460
+ else:
461
+ img_array = img_tensor.cpu().numpy()
462
+
463
+ # 转换为PIL图像
464
+ from PIL import Image
465
+ img = Image.fromarray(img_array)
466
+ return img
467
+
468
+ return None
469
+
470
+ # 创建界面
471
+ with gr.Blocks(title="Lumina Image Generator") as interface:
472
+ gr.Markdown("# 🎨 Lumina - 轻量级图像生成")
473
+ gr.Markdown("基于扩散模型的文本到图像生成系统")
474
+
475
+ with gr.Row():
476
+ with gr.Column():
477
+ prompt = gr.Textbox(
478
+ label="提示词",
479
+ placeholder="输入描述图像的文本...",
480
+ lines=3
481
+ )
482
+ negative_prompt = gr.Textbox(
483
+ label="负面提示词",
484
+ placeholder="不想在图像中出现的内容...",
485
+ lines=2
486
+ )
487
+
488
+ with gr.Row():
489
+ width = gr.Slider(
490
+ minimum=256,
491
+ maximum=1024,
492
+ value=512,
493
+ step=64,
494
+ label="宽度"
495
+ )
496
+ height = gr.Slider(
497
+ minimum=256,
498
+ maximum=1024,
499
+ value=512,
500
+ step=64,
501
+ label="高度"
502
+ )
503
+
504
+ with gr.Row():
505
+ num_steps = gr.Slider(
506
+ minimum=1,
507
+ maximum=100,
508
+ value=30,
509
+ step=1,
510
+ label="采样步数"
511
+ )
512
+ guidance_scale = gr.Slider(
513
+ minimum=1.0,
514
+ maximum=20.0,
515
+ value=7.5,
516
+ step=0.5,
517
+ label="引导强度"
518
+ )
519
+
520
+ with gr.Row():
521
+ seed = gr.Number(
522
+ value=-1,
523
+ label="随机种子 (-1为随机)"
524
+ )
525
+ sampler = gr.Dropdown(
526
+ choices=["ddim", "dpm", "lcm"],
527
+ value="ddim",
528
+ label="采样器"
529
+ )
530
+
531
+ generate_btn = gr.Button("生成图像", variant="primary")
532
+
533
+ with gr.Column():
534
+ output_image = gr.Image(
535
+ label="生成的图像",
536
+ type="pil"
537
+ )
538
+
539
+ # 示例
540
+ gr.Markdown("### 示例提示词")
541
+ examples = gr.Examples(
542
+ examples=[
543
+ ["A beautiful sunset over mountains, digital art", "", 512, 512, 30, 7.5, -1],
544
+ ["A cute cat playing with a ball of yarn", "blurry, deformed", 512, 512, 25, 8.0, -1],
545
+ ["An astronaut riding a horse on Mars", "cartoon, anime", 512, 512, 40, 7.0, -1]
546
+ ],
547
+ inputs=[prompt, negative_prompt, width, height, num_steps, guidance_scale, seed]
548
+ )
549
+
550
+ # 事件处理
551
+ generate_btn.click(
552
+ fn=generate_image_ui,
553
+ inputs=[prompt, negative_prompt, width, height, num_steps, guidance_scale, seed, sampler],
554
+ outputs=output_image
555
+ )
556
+
557
+ return interface
558
+
559
+ def launch(self, share: bool = False, server_name: str = "0.0.0.0", server_port: int = 7860):
560
+ """启动Web UI"""
561
+ interface = self.create_interface()
562
+ if interface:
563
+ interface.launch(
564
+ share=share,
565
+ server_name=server_name,
566
+ server_port=server_port
567
+ )
568
+
569
+
570
+ def create_api_server(config: dict, model, diffusion, text_encoder, vae_decoder=None):
571
+ """创建API服务器"""
572
+ # 确定主机和端口
573
+ host = config.get('host', '0.0.0.0')
574
+ port = config.get('port', 8000)
575
+
576
+ # 创建API服务器
577
+ api_server = LuminaAPI(
578
+ model=model,
579
+ diffusion=diffusion,
580
+ text_encoder=text_encoder,
581
+ vae_decoder=vae_decoder,
582
+ host=host,
583
+ port=port,
584
+ max_queue_size=config.get('max_queue_size', 100),
585
+ max_workers=config.get('max_workers', 2)
586
+ )
587
+
588
+ return api_server
589
+
590
+
591
+ def test_api():
592
+ """测试API"""
593
+ import torch.nn as nn
594
+
595
+ # 创建模拟组件
596
+ class MockModel(nn.Module):
597
+ def forward(self, x, t, context):
598
+ return torch.randn_like(x)
599
+
600
+ class MockDiffusion:
601
+ pass
602
+
603
+ class MockTextEncoder:
604
+ def encode(self, texts):
605
+ return torch.randn(len(texts), 77, 768)
606
+
607
+ model = MockModel()
608
+ diffusion = MockDiffusion()
609
+ text_encoder = MockTextEncoder()
610
+
611
+ # 创建API服务器
612
+ api = LuminaAPI(
613
+ model=model,
614
+ diffusion=diffusion,
615
+ text_encoder=text_encoder
616
+ )
617
+
618
+ print("API服务器创建成功")
619
+ print("端点:")
620
+ print(" POST /generate - 生成图像")
621
+ print(" GET /health - 健康检查")
622
+ print(" GET /status - 状态信息")
623
+
624
+ return api
625
+
626
+
627
+ if __name__ == '__main__':
628
+ # 测试API
629
+ api = test_api()
630
+
631
+ # 注意:实际运行需要调用 api.run()
src/inference/optimization.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Optional, Dict, Any
4
+ import onnx
5
+ import onnxruntime as ort
6
+ import os
7
+
8
+
9
+ class ModelOptimizer:
10
+ """模型优化器"""
11
+ def __init__(self, model: nn.Module, device: str = 'cuda'):
12
+ self.model = model
13
+ self.device = device
14
+ self.optimized_model = None
15
+
16
+ def optimize_for_inference(self, use_jit: bool = True, use_cuda_graph: bool = False):
17
+ """优化模型用于推理"""
18
+ self.model.eval()
19
+
20
+ # 应用一系列优化
21
+ if use_jit:
22
+ self._jit_compile()
23
+
24
+ if use_cuda_graph and torch.cuda.is_available():
25
+ self._capture_cuda_graph()
26
+
27
+ # 应用其他优化
28
+ self._apply_inference_optimizations()
29
+
30
+ return self.optimized_model or self.model
31
+
32
+ def _jit_compile(self):
33
+ """使用TorchScript编译模型"""
34
+ try:
35
+ # 创建示例输入
36
+ example_input = torch.randn(1, 4, 64, 64, device=self.device)
37
+ example_timestep = torch.tensor([500], device=self.device)
38
+ example_context = torch.randn(1, 77, 768, device=self.device)
39
+
40
+ # 脚本编译
41
+ scripted_model = torch.jit.trace(
42
+ self.model,
43
+ (example_input, example_timestep, example_context),
44
+ check_trace=False
45
+ )
46
+
47
+ self.optimized_model = scripted_model
48
+ print("模型已使用TorchScript编译")
49
+ except Exception as e:
50
+ print(f"TorchScript编译失败: {e}")
51
+
52
+ def _capture_cuda_graph(self):
53
+ """捕获CUDA图(用于重复推理)"""
54
+ if not torch.cuda.is_available():
55
+ return
56
+
57
+ # 创建静态输入
58
+ static_input = torch.randn(1, 4, 64, 64, device='cuda', dtype=torch.float16)
59
+ static_timestep = torch.tensor([500], device='cuda')
60
+ static_context = torch.randn(1, 77, 768, device='cuda', dtype=torch.float16)
61
+
62
+ # 预热
63
+ with torch.no_grad():
64
+ for _ in range(3):
65
+ _ = self.model(static_input, static_timestep, static_context)
66
+
67
+ # 捕获图
68
+ graph = torch.cuda.CUDAGraph()
69
+ with torch.cuda.graph(graph):
70
+ static_output = self.model(static_input, static_timestep, static_context)
71
+
72
+ # 创建包装函数
73
+ def graph_executor(input_tensor, timestep, context):
74
+ static_input.copy_(input_tensor)
75
+ static_timestep.copy_(timestep)
76
+ static_context.copy_(context)
77
+ graph.replay()
78
+ return static_output.clone()
79
+
80
+ self.optimized_model = graph_executor
81
+ print("已捕获CUDA图")
82
+
83
+ def _apply_inference_optimizations(self):
84
+ """应用推理优化"""
85
+ # 设置为评估模式
86
+ self.model.eval()
87
+
88
+ # 融合操作(如果可用)
89
+ if hasattr(torch, 'compile'):
90
+ try:
91
+ self.model = torch.compile(self.model, mode='max-autotune')
92
+ print("模型已使用torch.compile优化")
93
+ except:
94
+ pass
95
+
96
+ # 使用半精度
97
+ if self.device == 'cuda':
98
+ self.model.half()
99
+ print("模型已转换为半精度")
100
+
101
+ def quantize(self, quantization_mode: str = 'dynamic'):
102
+ """量化模型"""
103
+ if quantization_mode == 'dynamic':
104
+ # 动态量化
105
+ quantized_model = torch.quantization.quantize_dynamic(
106
+ self.model,
107
+ {nn.Linear, nn.Conv2d},
108
+ dtype=torch.qint8
109
+ )
110
+ self.optimized_model = quantized_model
111
+ print("模型已动态量化")
112
+
113
+ elif quantization_mode == 'static':
114
+ # 静态量化需要校准数据
115
+ print("静态量化需要校准数据,暂未实现")
116
+
117
+ else:
118
+ raise ValueError(f"未知的量化模式: {quantization_mode}")
119
+
120
+ def prune(self, pruning_rate: float = 0.2):
121
+ """剪枝模型"""
122
+ from torch.nn.utils import prune
123
+
124
+ # 对线性层和卷积层进行剪枝
125
+ for name, module in self.model.named_modules():
126
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
127
+ prune.l1_unstructured(module, name='weight', amount=pruning_rate)
128
+ prune.remove(module, 'weight')
129
+
130
+ print(f"模型已剪枝,剪枝率: {pruning_rate}")
131
+
132
+ def get_model_size(self) -> Dict[str, float]:
133
+ """获取模型大小"""
134
+ # 计算参数量
135
+ total_params = sum(p.numel() for p in self.model.parameters())
136
+ trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
137
+
138
+ # 计算模型大小(MB)
139
+ param_size = 0
140
+ for param in self.model.parameters():
141
+ param_size += param.nelement() * param.element_size()
142
+
143
+ buffer_size = 0
144
+ for buffer in self.model.buffers():
145
+ buffer_size += buffer.nelement() * buffer.element_size()
146
+
147
+ size_mb = (param_size + buffer_size) / 1024**2
148
+
149
+ return {
150
+ 'total_params': total_params,
151
+ 'trainable_params': trainable_params,
152
+ 'size_mb': size_mb
153
+ }
154
+
155
+
156
+ class ONNXExporter:
157
+ """ONNX导出器"""
158
+ def __init__(self, model: nn.Module):
159
+ self.model = model
160
+
161
+ def export(
162
+ self,
163
+ output_path: str,
164
+ input_shape: tuple = (1, 4, 64, 64),
165
+ opset_version: int = 14,
166
+ dynamic_axes: Optional[Dict] = None
167
+ ):
168
+ """导出为ONNX格式"""
169
+ # 设置为评估模式
170
+ self.model.eval()
171
+
172
+ # 创建示例输入
173
+ dummy_input = torch.randn(*input_shape)
174
+ dummy_timestep = torch.tensor([500])
175
+ dummy_context = torch.randn(1, 77, 768)
176
+
177
+ # 默认动态轴
178
+ if dynamic_axes is None:
179
+ dynamic_axes = {
180
+ 'input': {0: 'batch_size'},
181
+ 'timestep': {0: 'batch_size'},
182
+ 'context': {0: 'batch_size'},
183
+ 'output': {0: 'batch_size'}
184
+ }
185
+
186
+ # 导出
187
+ torch.onnx.export(
188
+ self.model,
189
+ (dummy_input, dummy_timestep, dummy_context),
190
+ output_path,
191
+ input_names=['input', 'timestep', 'context'],
192
+ output_names=['output'],
193
+ dynamic_axes=dynamic_axes,
194
+ opset_version=opset_version,
195
+ do_constant_folding=True,
196
+ verbose=False
197
+ )
198
+
199
+ print(f"模型已导出为ONNX: {output_path}")
200
+
201
+ # 验证ONNX模型
202
+ self._validate_onnx(output_path)
203
+
204
+ def _validate_onnx(self, onnx_path: str):
205
+ """验证ONNX模型"""
206
+ try:
207
+ onnx_model = onnx.load(onnx_path)
208
+ onnx.checker.check_model(onnx_model)
209
+ print("ONNX模型验证成功")
210
+ except Exception as e:
211
+ print(f"ONNX模型验证失败: {e}")
212
+
213
+ def optimize_onnx(self, onnx_path: str, optimized_path: str):
214
+ """优化ONNX模型"""
215
+ try:
216
+ # 使用ONNX Runtime优化
217
+ sess_options = ort.SessionOptions()
218
+ sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
219
+
220
+ # 创建优化会话
221
+ ort_session = ort.InferenceSession(onnx_path, sess_options)
222
+
223
+ # 保存优化后的模型
224
+ optimized_model = ort_session.get_model()
225
+ onnx.save(optimized_model, optimized_path)
226
+
227
+ print(f"ONNX模型已优化: {optimized_path}")
228
+ except Exception as e:
229
+ print(f"ONNX优化失败: {e}")
230
+
231
+
232
+ class MemoryEfficientInference:
233
+ """内存高效推理"""
234
+ def __init__(self, model: nn.Module, chunk_size: int = 32):
235
+ self.model = model
236
+ self.chunk_size = chunk_size
237
+
238
+ def chunked_inference(self, x: torch.Tensor, t: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
239
+ """分块推理,减少内存使用"""
240
+ B, C, H, W = x.shape
241
+ output = torch.zeros_like(x)
242
+
243
+ # 分块处理
244
+ for i in range(0, H, self.chunk_size):
245
+ for j in range(0, W, self.chunk_size):
246
+ # 提取块
247
+ chunk = x[:, :, i:i+self.chunk_size, j:j+self.chunk_size]
248
+
249
+ # 推理
250
+ with torch.no_grad():
251
+ chunk_output = self.model(chunk, t, context)
252
+
253
+ # 存储结果
254
+ output[:, :, i:i+self.chunk_size, j:j+self.chunk_size] = chunk_output
255
+
256
+ return output
257
+
258
+ def tiled_inference(self, x: torch.Tensor, t: torch.Tensor, context: torch.Tensor, tile_size: int = 512) -> torch.Tensor:
259
+ """平铺推理,用于大图像"""
260
+ B, C, H, W = x.shape
261
+
262
+ # 如果图像不大,直接推理
263
+ if H <= tile_size and W <= tile_size:
264
+ with torch.no_grad():
265
+ return self.model(x, t, context)
266
+
267
+ # 计算平铺数量
268
+ n_tiles_h = (H + tile_size - 1) // tile_size
269
+ n_tiles_w = (W + tile_size - 1) // tile_size
270
+
271
+ output = torch.zeros_like(x)
272
+
273
+ # 处理每个平铺
274
+ for i in range(n_tiles_h):
275
+ for j in range(n_tiles_w):
276
+ # 计算平铺位置
277
+ h_start = i * tile_size
278
+ w_start = j * tile_size
279
+ h_end = min(h_start + tile_size, H)
280
+ w_end = min(w_start + tile_size, W)
281
+
282
+ # 提取平铺
283
+ tile = x[:, :, h_start:h_end, w_start:w_end]
284
+
285
+ # 推理
286
+ with torch.no_grad():
287
+ tile_output = self.model(tile, t, context)
288
+
289
+ # 存储结果
290
+ output[:, :, h_start:h_end, w_start:w_end] = tile_output
291
+
292
+ return output
293
+
294
+
295
+ class InferenceBenchmark:
296
+ """推理基准测试"""
297
+ def __init__(self, model: nn.Module, device: str = 'cuda'):
298
+ self.model = model
299
+ self.device = device
300
+
301
+ def benchmark(
302
+ self,
303
+ input_shape: tuple = (1, 4, 64, 64),
304
+ num_iterations: int = 100,
305
+ warmup_iterations: int = 10
306
+ ) -> Dict[str, float]:
307
+ """运行基准测试"""
308
+ # 准备输入
309
+ x = torch.randn(*input_shape, device=self.device)
310
+ t = torch.tensor([500], device=self.device)
311
+ context = torch.randn(1, 77, 768, device=self.device)
312
+
313
+ # 预热
314
+ print("预热...")
315
+ with torch.no_grad():
316
+ for _ in range(warmup_iterations):
317
+ _ = self.model(x, t, context)
318
+
319
+ # 同步
320
+ if torch.cuda.is_available():
321
+ torch.cuda.synchronize()
322
+
323
+ # 基准测试
324
+ print("运行基准测试...")
325
+ import time
326
+
327
+ times = []
328
+
329
+ for i in range(num_iterations):
330
+ start_time = time.time()
331
+
332
+ with torch.no_grad():
333
+ _ = self.model(x, t, context)
334
+
335
+ if torch.cuda.is_available():
336
+ torch.cuda.synchronize()
337
+
338
+ end_time = time.time()
339
+ times.append(end_time - start_time)
340
+
341
+ # 统计
342
+ times = torch.tensor(times)
343
+
344
+ stats = {
345
+ 'mean_ms': times.mean().item() * 1000,
346
+ 'std_ms': times.std().item() * 1000,
347
+ 'min_ms': times.min().item() * 1000,
348
+ 'max_ms': times.max().item() * 1000,
349
+ 'fps': 1 / times.mean().item(),
350
+ 'num_iterations': num_iterations
351
+ }
352
+
353
+ # 打印结果
354
+ print("\n" + "="*50)
355
+ print("推理基准测试结果:")
356
+ print(f"平均推理时间: {stats['mean_ms']:.2f} ms")
357
+ print(f"标准差: {stats['std_ms']:.2f} ms")
358
+ print(f"最小推理时间: {stats['min_ms']:.2f} ms")
359
+ print(f"最大推理时间: {stats['max_ms']:.2f} ms")
360
+ print(f"FPS: {stats['fps']:.2f}")
361
+ print("="*50)
362
+
363
+ return stats
364
+
365
+
366
+ def optimize_model_for_p4(model: nn.Module) -> nn.Module:
367
+ """为P4优化模型"""
368
+ optimizer = ModelOptimizer(model)
369
+
370
+ # 获取模型大小
371
+ size_info = optimizer.get_model_size()
372
+ print(f"优化前模型大小: {size_info['size_mb']:.2f} MB")
373
+
374
+ # 应用优化
375
+ optimized_model = optimizer.optimize_for_inference(
376
+ use_jit=True,
377
+ use_cuda_graph=False # P4可能不支持
378
+ )
379
+
380
+ # 量化(可选)
381
+ if size_info['size_mb'] > 500: # 如果模型大于500MB,进行量化
382
+ optimizer.quantize('dynamic')
383
+
384
+ # 获取优化后的模型大小
385
+ size_info_after = optimizer.get_model_size()
386
+ print(f"优化后模型大小: {size_info_after['size_mb']:.2f} MB")
387
+ print(f"压缩比: {size_info['size_mb'] / size_info_after['size_mb']:.2f}x")
388
+
389
+ return optimized_model
390
+
391
+
392
+ def test_optimization():
393
+ """测试优化"""
394
+ import torch.nn as nn
395
+
396
+ # 创建模拟模型
397
+ class MockModel(nn.Module):
398
+ def __init__(self):
399
+ super().__init__()
400
+ self.conv1 = nn.Conv2d(4, 64, 3, padding=1)
401
+ self.conv2 = nn.Conv2d(64, 4, 3, padding=1)
402
+
403
+ def forward(self, x, t, context):
404
+ x = self.conv1(x)
405
+ x = nn.functional.relu(x)
406
+ x = self.conv2(x)
407
+ return x
408
+
409
+ model = MockModel()
410
+
411
+ # 测试优化器
412
+ optimizer = ModelOptimizer(model)
413
+ optimized_model = optimizer.optimize_for_inference()
414
+
415
+ # 测试基准测试
416
+ benchmark = InferenceBenchmark(model)
417
+ stats = benchmark.benchmark(num_iterations=10)
418
+
419
+ # 测试ONNX导出
420
+ exporter = ONNXExporter(model)
421
+ exporter.export('./test_model.onnx', input_shape=(1, 4, 64, 64))
422
+
423
+ return optimized_model, stats
424
+
425
+
426
+ if __name__ == '__main__':
427
+ optimized_model, stats = test_optimization()
src/inference/sampler.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Optional, Tuple, List, Union
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ import math
7
+
8
+
9
+ class DDIMSampler:
10
+ """DDIM采样器"""
11
+ def __init__(self, model: nn.Module, diffusion, num_inference_steps: int = 50):
12
+ self.model = model
13
+ self.diffusion = diffusion
14
+ self.num_inference_steps = num_inference_steps
15
+
16
+ # 设置时间步
17
+ self.set_timesteps(num_inference_steps)
18
+
19
+ def set_timesteps(self, num_inference_steps: int):
20
+ """设置推理时间步"""
21
+ self.num_inference_steps = num_inference_steps
22
+
23
+ # 选择时间步
24
+ if self.diffusion.num_train_timesteps == num_inference_steps:
25
+ timesteps = np.arange(0, self.diffusion.num_train_timesteps)
26
+ else:
27
+ step_ratio = self.diffusion.num_train_timesteps // num_inference_steps
28
+ timesteps = np.arange(0, self.diffusion.num_train_timesteps, step_ratio)
29
+
30
+ self.timesteps = torch.from_numpy(timesteps).long().flip(0)
31
+
32
+ @torch.no_grad()
33
+ def step(
34
+ self,
35
+ model_output: torch.Tensor,
36
+ timestep: int,
37
+ sample: torch.Tensor,
38
+ eta: float = 0.0,
39
+ use_clipped_model_output: bool = False
40
+ ) -> torch.Tensor:
41
+ """DDIM单步采样"""
42
+ # 获取当前和上一个时间步
43
+ prev_timestep = timestep - self.diffusion.num_train_timesteps // self.num_inference_steps
44
+
45
+ # 提取alpha参数
46
+ alpha_prod_t = self.diffusion.extract(self.diffusion.alphas_cumprod, timestep, sample.shape)
47
+ alpha_prod_t_prev = self.diffusion.extract(
48
+ self.diffusion.alphas_cumprod,
49
+ prev_timestep,
50
+ sample.shape
51
+ ) if prev_timestep >= 0 else torch.ones_like(alpha_prod_t)
52
+
53
+ # 根据预测类型处理模型输出
54
+ if self.diffusion.prediction_type == "epsilon":
55
+ pred_original_sample = (sample - (1 - alpha_prod_t) ** 0.5 * model_output) / alpha_prod_t ** 0.5
56
+ pred_epsilon = model_output
57
+ elif self.diffusion.prediction_type == "sample":
58
+ pred_original_sample = model_output
59
+ pred_epsilon = (sample - alpha_prod_t ** 0.5 * pred_original_sample) / (1 - alpha_prod_t) ** 0.5
60
+ elif self.diffusion.prediction_type == "v_prediction":
61
+ pred_original_sample = (alpha_prod_t ** 0.5) * sample - (1 - alpha_prod_t) ** 0.5 * model_output
62
+ pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (1 - alpha_prod_t) ** 0.5 * sample
63
+ else:
64
+ raise ValueError(f"Unsupported prediction type: {self.diffusion.prediction_type}")
65
+
66
+ # 裁剪预测的原始样本
67
+ if use_clipped_model_output:
68
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
69
+
70
+ # 计算x_t-1的方差
71
+ variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
72
+ std_dev_t = eta * variance ** 0.5
73
+
74
+ # 当eta > 0时,使用随机采样
75
+ if eta > 0:
76
+ noise = torch.randn_like(model_output)
77
+ variance = std_dev_t ** 2
78
+ else:
79
+ noise = 0
80
+ variance = 0
81
+
82
+ # 计算x_t-1的均值
83
+ pred_sample_direction = (1 - alpha_prod_t_prev - variance) ** 0.5 * pred_epsilon
84
+ prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
85
+
86
+ # 添加噪声
87
+ if eta > 0:
88
+ prev_sample = prev_sample + std_dev_t * noise
89
+
90
+ return prev_sample
91
+
92
+ @torch.no_grad()
93
+ def sample(
94
+ self,
95
+ prompt_embeds: torch.Tensor,
96
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
97
+ height: int = 512,
98
+ width: int = 512,
99
+ num_images_per_prompt: int = 1,
100
+ guidance_scale: float = 7.5,
101
+ eta: float = 0.0,
102
+ generator: Optional[torch.Generator] = None,
103
+ progress_bar: bool = True
104
+ ) -> torch.Tensor:
105
+ """生成样本"""
106
+ # 设置模型为评估模式
107
+ self.model.eval()
108
+
109
+ # 批次大小
110
+ batch_size = prompt_embeds.shape[0]
111
+
112
+ # 初始化潜在表示
113
+ latents = torch.randn(
114
+ (batch_size * num_images_per_prompt, self.model.in_channels, height // 8, width // 8),
115
+ device=prompt_embeds.device,
116
+ generator=generator
117
+ )
118
+
119
+ # 准备额外的条件
120
+ if negative_prompt_embeds is not None:
121
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
122
+
123
+ # 分类器自由引导的缩放因子
124
+ do_classifier_free_guidance = guidance_scale > 1.0
125
+ if do_classifier_free_guidance:
126
+ latents = torch.cat([latents] * 2)
127
+
128
+ # 采样循环
129
+ timesteps = self.timesteps.to(latents.device)
130
+
131
+ if progress_bar:
132
+ timesteps = tqdm(timesteps, desc="DDIM Sampling")
133
+
134
+ for i, t in enumerate(timesteps):
135
+ # 扩展潜在表示以匹配引导的批次大小
136
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
137
+ latent_model_input = self.diffusion.scale_model_input(latent_model_input, t)
138
+
139
+ # 预测噪声
140
+ noise_pred = self.model(latent_model_input, t, prompt_embeds)
141
+
142
+ # 执行分类器自由引导
143
+ if do_classifier_free_guidance:
144
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
145
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
146
+
147
+ # 计算上一个样本
148
+ latents = self.step(noise_pred, t, latents, eta)
149
+
150
+ return latents
151
+
152
+
153
+ class DPMSampler:
154
+ """DPM采样器(更快)"""
155
+ def __init__(self, model: nn.Module, diffusion, num_inference_steps: int = 20):
156
+ self.model = model
157
+ self.diffusion = diffusion
158
+ self.num_inference_steps = num_inference_steps
159
+
160
+ @torch.no_grad()
161
+ def sample(
162
+ self,
163
+ prompt_embeds: torch.Tensor,
164
+ height: int = 512,
165
+ width: int = 512,
166
+ guidance_scale: float = 7.5,
167
+ progress_bar: bool = True
168
+ ) -> torch.Tensor:
169
+ """DPM采样"""
170
+ self.model.eval()
171
+
172
+ # 初始化潜在表示
173
+ latents = torch.randn(
174
+ (1, self.model.in_channels, height // 8, width // 8),
175
+ device=prompt_embeds.device
176
+ )
177
+
178
+ # 简化的DPM采样
179
+ timesteps = torch.linspace(1, 0, self.num_inference_steps + 1, device=latents.device)
180
+
181
+ if progress_bar:
182
+ timesteps_iter = tqdm(enumerate(timesteps[:-1]), total=len(timesteps)-1, desc="DPM Sampling")
183
+ else:
184
+ timesteps_iter = enumerate(timesteps[:-1])
185
+
186
+ for i, t in timesteps_iter:
187
+ # 预测噪声
188
+ noise_pred = self.model(latents, t.unsqueeze(0) * 999, prompt_embeds)
189
+
190
+ # 应用引导
191
+ if guidance_scale > 1.0:
192
+ # 简单引导
193
+ noise_pred = noise_pred * guidance_scale
194
+
195
+ # DPM更新步骤
196
+ dt = timesteps[i + 1] - t
197
+ latents = latents + dt * noise_pred
198
+
199
+ return latents
200
+
201
+
202
+ class LCMSampler:
203
+ """LCM(潜在一致性模型)采样器,极快"""
204
+ def __init__(self, model: nn.Module, diffusion, num_inference_steps: int = 4):
205
+ self.model = model
206
+ self.diffusion = diffusion
207
+ self.num_inference_steps = num_inference_steps
208
+
209
+ # LCM特定的参数
210
+ self.c_skip = 1.0
211
+ self.c_out = 1.0
212
+ self.c_in = 1.0
213
+ self.c_noise = 1.0
214
+
215
+ @torch.no_grad()
216
+ def sample(
217
+ self,
218
+ prompt_embeds: torch.Tensor,
219
+ height: int = 512,
220
+ width: int = 512,
221
+ guidance_scale: float = 7.5,
222
+ progress_bar: bool = True
223
+ ) -> torch.Tensor:
224
+ """LCM采样(极快,只需要4-8步)"""
225
+ self.model.eval()
226
+
227
+ # 初始化潜在表示
228
+ latents = torch.randn(
229
+ (1, self.model.in_channels, height // 8, width // 8),
230
+ device=prompt_embeds.device
231
+ )
232
+
233
+ # LCM采样循环
234
+ timesteps = torch.linspace(1, 0, self.num_inference_steps + 1, device=latents.device)
235
+
236
+ if progress_bar:
237
+ timesteps_iter = tqdm(enumerate(timesteps[:-1]), total=len(timesteps)-1, desc="LCM Sampling")
238
+ else:
239
+ timesteps_iter = enumerate(timesteps[:-1])
240
+
241
+ for i, t in timesteps_iter:
242
+ # LCM特定的缩放
243
+ c_skip = self.c_skip
244
+ c_out = self.c_out
245
+ c_in = self.c_in
246
+ c_noise = self.c_noise
247
+
248
+ # 缩放输入
249
+ scaled_latents = c_in * latents
250
+
251
+ # 预测
252
+ noise_pred = self.model(scaled_latents, c_noise * t.unsqueeze(0), prompt_embeds)
253
+
254
+ # LCM更新规则
255
+ denoised = c_skip * latents + c_out * noise_pred
256
+
257
+ # 更新潜在表示
258
+ dt = timesteps[i + 1] - t
259
+ latents = denoised + dt * noise_pred
260
+
261
+ return latents
262
+
263
+
264
+ class SamplerFactory:
265
+ """采样器工厂"""
266
+ @staticmethod
267
+ def create_sampler(
268
+ sampler_type: str,
269
+ model: nn.Module,
270
+ diffusion,
271
+ num_inference_steps: int = 50
272
+ ):
273
+ """创建采样器"""
274
+ if sampler_type == "ddim":
275
+ return DDIMSampler(model, diffusion, num_inference_steps)
276
+ elif sampler_type == "dpm":
277
+ return DPMSampler(model, diffusion, num_inference_steps)
278
+ elif sampler_type == "lcm":
279
+ return LCMSampler(model, diffusion, num_inference_steps)
280
+ else:
281
+ raise ValueError(f"未知的采样器类型: {sampler_type}")
282
+
283
+
284
+ class TextToImagePipeline:
285
+ """文本到图像管道"""
286
+ def __init__(
287
+ self,
288
+ model: nn.Module,
289
+ diffusion,
290
+ text_encoder,
291
+ vae_decoder,
292
+ sampler_type: str = "ddim",
293
+ device: str = "cuda"
294
+ ):
295
+ self.model = model.to(device)
296
+ self.diffusion = diffusion
297
+ self.text_encoder = text_encoder
298
+ self.vae_decoder = vae_decoder
299
+ self.sampler_type = sampler_type
300
+ self.device = device
301
+
302
+ # 创建采样器
303
+ self.sampler = SamplerFactory.create_sampler(
304
+ sampler_type, model, diffusion
305
+ )
306
+
307
+ # 设置为评估模式
308
+ self.model.eval()
309
+ if self.vae_decoder is not None:
310
+ self.vae_decoder.eval()
311
+
312
+ @torch.no_grad()
313
+ def __call__(
314
+ self,
315
+ prompt: str,
316
+ negative_prompt: str = "",
317
+ height: int = 512,
318
+ width: int = 512,
319
+ num_inference_steps: int = 50,
320
+ guidance_scale: float = 7.5,
321
+ num_images: int = 1,
322
+ seed: Optional[int] = None,
323
+ progress_bar: bool = True
324
+ ) -> List:
325
+ """生成图像"""
326
+ # 设置随机种子
327
+ if seed is not None:
328
+ torch.manual_seed(seed)
329
+ if torch.cuda.is_available():
330
+ torch.cuda.manual_seed(seed)
331
+
332
+ # 编码提示
333
+ prompt_embeds = self.text_encoder.encode([prompt]).to(self.device)
334
+ negative_prompt_embeds = None
335
+
336
+ if negative_prompt:
337
+ negative_prompt_embeds = self.text_encoder.encode([negative_prompt]).to(self.device)
338
+
339
+ # 生成潜在表示
340
+ latents = self.sampler.sample(
341
+ prompt_embeds=prompt_embeds,
342
+ negative_prompt_embeds=negative_prompt_embeds,
343
+ height=height,
344
+ width=width,
345
+ num_images_per_prompt=num_images,
346
+ guidance_scale=guidance_scale,
347
+ progress_bar=progress_bar
348
+ )
349
+
350
+ # 解码为图像
351
+ images = []
352
+ for i in range(num_images):
353
+ latent = latents[i:i+1]
354
+
355
+ if self.vae_decoder is not None:
356
+ image = self.vae_decoder(latent)
357
+ else:
358
+ # 如果没有VAE解码器,返回潜在表示
359
+ image = latent
360
+
361
+ images.append(image.cpu())
362
+
363
+ return images
364
+
365
+ def generate_grid(
366
+ self,
367
+ prompts: List[str],
368
+ grid_size: Tuple[int, int] = (2, 2),
369
+ **kwargs
370
+ ) -> torch.Tensor:
371
+ """生成图像网格"""
372
+ images = []
373
+
374
+ for prompt in prompts[:grid_size[0] * grid_size[1]]:
375
+ image = self(prompt, **kwargs)[0]
376
+ images.append(image)
377
+
378
+ # 创建网格
379
+ from torchvision.utils import make_grid
380
+ grid = make_grid(torch.cat(images, dim=0), nrow=grid_size[1])
381
+
382
+ return grid
383
+
384
+
385
+ def test_sampler():
386
+ """测试采样器"""
387
+ import torch.nn as nn
388
+
389
+ # 创建模拟模型
390
+ class MockModel(nn.Module):
391
+ def __init__(self):
392
+ super().__init__()
393
+ self.in_channels = 4
394
+
395
+ def forward(self, x, t, context):
396
+ # 返回随机噪声
397
+ return torch.randn_like(x)
398
+
399
+ # 创建模拟扩散过程
400
+ class MockDiffusion:
401
+ def __init__(self):
402
+ self.num_train_timesteps = 1000
403
+ self.alphas_cumprod = torch.ones(1000)
404
+ self.prediction_type = "epsilon"
405
+
406
+ def extract(self, a, t, x_shape):
407
+ return torch.ones(x_shape[0], 1, 1, 1)
408
+
409
+ def scale_model_input(self, x, t):
410
+ return x
411
+
412
+ model = MockModel()
413
+ diffusion = MockDiffusion()
414
+
415
+ # 测试DDIM采样器
416
+ sampler = DDIMSampler(model, diffusion, num_inference_steps=10)
417
+
418
+ # 测试采样
419
+ prompt_embeds = torch.randn(1, 77, 768)
420
+ latents = sampler.sample(prompt_embeds, height=64, width=64, progress_bar=False)
421
+
422
+ print(f"DDIM采样完成,潜在表示形状: {latents.shape}")
423
+
424
+ return sampler, latents
425
+
426
+
427
+ if __name__ == '__main__':
428
+ test_sampler()
src/models/attention.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from typing import Optional, Tuple
6
+
7
+
8
+ class MemoryEfficientAttention(nn.Module):
9
+ """内存高效的多头注意力"""
10
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dropout: float = 0.0):
11
+ super().__init__()
12
+ self.dim = dim
13
+ self.num_heads = num_heads
14
+ self.head_dim = dim // num_heads
15
+ self.scale = self.head_dim ** -0.5
16
+
17
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
18
+ self.proj = nn.Linear(dim, dim)
19
+ self.dropout = nn.Dropout(dropout)
20
+
21
+ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
22
+ B, N, C = x.shape
23
+
24
+ # 计算QKV
25
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
26
+ q, k, v = qkv[0], qkv[1], qkv[2]
27
+
28
+ # 分块计算注意力,避免OOM
29
+ chunk_size = min(32, N)
30
+ attn_output = torch.zeros(B, self.num_heads, N, self.head_dim, device=x.device)
31
+
32
+ for i in range(0, N, chunk_size):
33
+ q_chunk = q[:, :, i:i+chunk_size, :]
34
+
35
+ # 计算注意力分数
36
+ attn_scores = torch.matmul(q_chunk, k.transpose(-2, -1)) * self.scale
37
+
38
+ if mask is not None:
39
+ attn_scores = attn_scores + mask
40
+
41
+ attn_probs = F.softmax(attn_scores, dim=-1)
42
+ attn_probs = self.dropout(attn_probs)
43
+
44
+ # 计算输出
45
+ attn_output[:, :, i:i+chunk_size, :] = torch.matmul(attn_probs, v)
46
+
47
+ # 合并多头
48
+ attn_output = attn_output.transpose(1, 2).reshape(B, N, C)
49
+
50
+ # 输出投影
51
+ output = self.proj(attn_output)
52
+ output = self.dropout(output)
53
+
54
+ return output
55
+
56
+
57
+ class CrossAttention(nn.Module):
58
+ """交叉注意力(用于文本条件)"""
59
+ def __init__(self, query_dim: int, context_dim: int, num_heads: int = 8, dropout: float = 0.0):
60
+ super().__init__()
61
+ self.query_dim = query_dim
62
+ self.context_dim = context_dim
63
+ self.num_heads = num_heads
64
+ self.head_dim = query_dim // num_heads
65
+ self.scale = self.head_dim ** -0.5
66
+
67
+ self.to_q = nn.Linear(query_dim, query_dim)
68
+ self.to_k = nn.Linear(context_dim, query_dim)
69
+ self.to_v = nn.Linear(context_dim, query_dim)
70
+ self.proj = nn.Linear(query_dim, query_dim)
71
+ self.dropout = nn.Dropout(dropout)
72
+
73
+ def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
74
+ B, N, C = x.shape
75
+
76
+ # 计算Q
77
+ q = self.to_q(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
78
+
79
+ # 计算K, V
80
+ k = self.to_k(context).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
81
+ v = self.to_v(context).reshape(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
82
+
83
+ # 分块计算注意力
84
+ chunk_size = min(32, N)
85
+ attn_output = torch.zeros(B, self.num_heads, N, self.head_dim, device=x.device)
86
+
87
+ for i in range(0, N, chunk_size):
88
+ q_chunk = q[:, :, i:i+chunk_size, :]
89
+
90
+ # 计算注意力分数
91
+ attn_scores = torch.matmul(q_chunk, k.transpose(-2, -1)) * self.scale
92
+
93
+ if mask is not None:
94
+ attn_scores = attn_scores + mask
95
+
96
+ attn_probs = F.softmax(attn_scores, dim=-1)
97
+ attn_probs = self.dropout(attn_probs)
98
+
99
+ # 计算输出
100
+ attn_output[:, :, i:i+chunk_size, :] = torch.matmul(attn_probs, v)
101
+
102
+ # 合并多头
103
+ attn_output = attn_output.transpose(1, 2).reshape(B, N, C)
104
+
105
+ # 输出投影
106
+ output = self.proj(attn_output)
107
+
108
+ return output
109
+
110
+
111
+ class FlashAttentionWrapper(nn.Module):
112
+ """FlashAttention包装器(如果可用)"""
113
+ def __init__(self, dim: int, num_heads: int = 8):
114
+ super().__init__()
115
+ self.dim = dim
116
+ self.num_heads = num_heads
117
+
118
+ try:
119
+ from flash_attn import flash_attn_qkvpacked_func
120
+ self.use_flash = True
121
+ except ImportError:
122
+ self.use_flash = False
123
+
124
+ if not self.use_flash:
125
+ self.attention = MemoryEfficientAttention(dim, num_heads)
126
+
127
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
128
+ if self.use_flash:
129
+ return self._flash_attention(x)
130
+ else:
131
+ return self.attention(x)
132
+
133
+ def _flash_attention(self, x: torch.Tensor) -> torch.Tensor:
134
+ # FlashAttention实现
135
+ B, N, C = x.shape
136
+ qkv = x.reshape(B, N, 3, self.num_heads, C // self.num_heads)
137
+ qkv = qkv.permute(2, 0, 3, 1, 4) # [3, B, num_heads, N, head_dim]
138
+
139
+ from flash_attn import flash_attn_qkvpacked_func
140
+ output = flash_attn_qkvpacked_func(qkv)
141
+ output = output.reshape(B, N, C)
142
+
143
+ return output
src/models/diffusion.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from typing import Optional, Tuple, Union
5
+ import math
6
+
7
+
8
+ class BetaScheduler:
9
+ """Beta调度器"""
10
+ @staticmethod
11
+ def linear(num_timesteps: int, beta_start: float = 0.0001, beta_end: float = 0.02) -> np.ndarray:
12
+ """线性调度"""
13
+ return np.linspace(beta_start, beta_end, num_timesteps, dtype=np.float32)
14
+
15
+ @staticmethod
16
+ def cosine(num_timesteps: int, s: float = 0.008) -> np.ndarray:
17
+ """余弦调度"""
18
+ steps = num_timesteps + 1
19
+ x = np.linspace(0, num_timesteps, steps)
20
+ alphas_cumprod = np.cos(((x / num_timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
21
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
22
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
23
+ return np.clip(betas, 0, 0.999)
24
+
25
+ @staticmethod
26
+ def scaled_linear(num_timesteps: int) -> np.ndarray:
27
+ """缩放线性调度(Stable Diffusion默认)"""
28
+ beta_start = 0.00085
29
+ beta_end = 0.012
30
+ return np.linspace(beta_start**0.5, beta_end**0.5, num_timesteps) ** 2
31
+
32
+
33
+ class DiffusionProcess:
34
+ """扩散过程管理"""
35
+ def __init__(self, config: dict):
36
+ self.config = config
37
+ diff_config = config.get('diffusion', {})
38
+
39
+ self.num_train_timesteps = diff_config.get('num_train_timesteps', 1000)
40
+ self.num_inference_timesteps = diff_config.get('num_inference_timesteps', 50)
41
+ self.beta_start = diff_config.get('beta_start', 0.00085)
42
+ self.beta_end = diff_config.get('beta_end', 0.012)
43
+ self.beta_schedule = diff_config.get('beta_schedule', 'scaled_linear')
44
+ self.prediction_type = diff_config.get('prediction_type', 'epsilon')
45
+
46
+ # 初始化调度参数
47
+ self._init_schedule()
48
+
49
+ def _init_schedule(self):
50
+ """初始化扩散调度参数"""
51
+ # 计算betas
52
+ if self.beta_schedule == "linear":
53
+ betas = BetaScheduler.linear(
54
+ self.num_train_timesteps,
55
+ self.beta_start,
56
+ self.beta_end
57
+ )
58
+ elif self.beta_schedule == "cosine":
59
+ betas = BetaScheduler.cosine(self.num_train_timesteps)
60
+ elif self.beta_schedule == "scaled_linear":
61
+ betas = BetaScheduler.scaled_linear(self.num_train_timesteps)
62
+ else:
63
+ raise ValueError(f"Unknown beta schedule: {self.beta_schedule}")
64
+
65
+ self.betas = torch.from_numpy(betas).float()
66
+
67
+ # 计算alphas
68
+ self.alphas = 1.0 - self.betas
69
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
70
+ self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
71
+
72
+ # 计算扩散后验方差
73
+ self.variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
74
+
75
+ # 注册为buffer
76
+ self.register_buffer = lambda name, tensor: setattr(self, name, tensor)
77
+ self.register_buffer('betas', self.betas)
78
+ self.register_buffer('alphas', self.alphas)
79
+ self.register_buffer('alphas_cumprod', self.alphas_cumprod)
80
+ self.register_buffer('alphas_cumprod_prev', self.alphas_cumprod_prev)
81
+ self.register_buffer('variance', self.variance)
82
+
83
+ # 计算采样系数
84
+ self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
85
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - self.alphas_cumprod))
86
+ self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1.0 - self.alphas_cumprod))
87
+ self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1.0 / self.alphas_cumprod))
88
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1.0 / self.alphas_cumprod - 1))
89
+
90
+ def q_sample(self, x_start: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
91
+ """前向扩散过程:加噪"""
92
+ if noise is None:
93
+ noise = torch.randn_like(x_start)
94
+
95
+ sqrt_alphas_cumprod_t = self.extract(self.sqrt_alphas_cumprod, t, x_start.shape)
96
+ sqrt_one_minus_alphas_cumprod_t = self.extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
97
+
98
+ return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
99
+
100
+ def extract(self, a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int, ...]) -> torch.Tensor:
101
+ """从张量a中提取索引t处的值"""
102
+ batch_size = t.shape[0]
103
+ out = a.gather(-1, t.cpu())
104
+ return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
105
+
106
+ def get_loss_weight(self, snr: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
107
+ """根据SNR计算损失权重"""
108
+ if gamma is None:
109
+ return torch.ones_like(snr)
110
+
111
+ snr = torch.clamp(snr, min=1e-8)
112
+ min_snr = torch.tensor(gamma, device=snr.device)
113
+ weight = torch.minimum(snr, min_snr) / snr
114
+ return weight
115
+
116
+ def compute_snr(self, timesteps: torch.Tensor) -> torch.Tensor:
117
+ """计算信噪比(SNR)"""
118
+ alphas_cumprod = self.extract(self.alphas_cumprod, timesteps, timesteps.shape)
119
+ snr = alphas_cumprod / (1 - alphas_cumprod)
120
+ return snr
121
+
122
+
123
+ class DDIMScheduler:
124
+ """DDIM采样器"""
125
+ def __init__(self, diffusion: DiffusionProcess):
126
+ self.diffusion = diffusion
127
+ self.num_train_timesteps = diffusion.num_train_timesteps
128
+ self.num_inference_timesteps = diffusion.num_inference_timesteps
129
+
130
+ # 设置时间步
131
+ self.set_timesteps(self.num_inference_timesteps)
132
+
133
+ def set_timesteps(self, num_inference_timesteps: int):
134
+ """设置推理时间步"""
135
+ self.num_inference_timesteps = num_inference_timesteps
136
+
137
+ # 选择时间步
138
+ if self.num_train_timesteps == self.num_inference_timesteps:
139
+ self.timesteps = torch.arange(0, self.num_train_timesteps).long()
140
+ else:
141
+ step_ratio = self.num_train_timesteps // self.num_inference_timesteps
142
+ self.timesteps = torch.arange(0, self.num_train_timesteps, step_ratio).long()
143
+
144
+ self.timesteps = self.timesteps.flip(0) # 从T到0
145
+
146
+ @torch.no_grad()
147
+ def step(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, eta: float = 0.0) -> torch.Tensor:
148
+ """DDIM单步采样"""
149
+ # 获取当前时间步的参数
150
+ prev_timestep = timestep - self.num_train_timesteps // self.num_inference_timesteps
151
+
152
+ # 提取alpha参数
153
+ alpha_prod_t = self.diffusion.extract(self.diffusion.alphas_cumprod, timestep, sample.shape)
154
+ alpha_prod_t_prev = self.diffusion.extract(
155
+ self.diffusion.alphas_cumprod,
156
+ prev_timestep,
157
+ sample.shape
158
+ ) if prev_timestep >= 0 else torch.ones_like(alpha_prod_t)
159
+
160
+ # 根据预测类型处理模型输出
161
+ if self.diffusion.prediction_type == "epsilon":
162
+ pred_original_sample = (sample - (1 - alpha_prod_t) ** 0.5 * model_output) / alpha_prod_t ** 0.5
163
+ pred_epsilon = model_output
164
+ elif self.diffusion.prediction_type == "sample":
165
+ pred_original_sample = model_output
166
+ pred_epsilon = (sample - alpha_prod_t ** 0.5 * pred_original_sample) / (1 - alpha_prod_t) ** 0.5
167
+ elif self.diffusion.prediction_type == "v_prediction":
168
+ pred_original_sample = (alpha_prod_t ** 0.5) * sample - (1 - alpha_prod_t) ** 0.5 * model_output
169
+ pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (1 - alpha_prod_t) ** 0.5 * sample
170
+ else:
171
+ raise ValueError(f"Unsupported prediction type: {self.diffusion.prediction_type}")
172
+
173
+ # 计算x_t-1的方差
174
+ variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
175
+ std_dev_t = eta * variance ** 0.5
176
+
177
+ # 计算x_t-1的均值
178
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon
179
+ prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
180
+
181
+ # 添加噪声
182
+ if eta > 0:
183
+ noise = torch.randn_like(model_output)
184
+ prev_sample = prev_sample + std_dev_t * noise
185
+
186
+ return prev_sample
187
+
188
+
189
+ class DiffusionModel(nn.Module):
190
+ """扩散模型封装"""
191
+ def __init__(self, unet: nn.Module, diffusion: DiffusionProcess):
192
+ super().__init__()
193
+ self.unet = unet
194
+ self.diffusion = diffusion
195
+ self.scheduler = DDIMScheduler(diffusion)
196
+
197
+ def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
198
+ """前向传播:预测噪声"""
199
+ return self.unet(x, timesteps, context)
200
+
201
+ def compute_loss(self, x_start: torch.Tensor, context: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
202
+ """计算扩散损失"""
203
+ if noise is None:
204
+ noise = torch.randn_like(x_start)
205
+
206
+ # 随机采样时间步
207
+ batch_size = x_start.shape[0]
208
+ timesteps = torch.randint(
209
+ 0, self.diffusion.num_train_timesteps,
210
+ (batch_size,), device=x_start.device
211
+ ).long()
212
+
213
+ # 前向扩散
214
+ x_noisy = self.diffusion.q_sample(x_start, timesteps, noise)
215
+
216
+ # 预测噪声
217
+ predicted_noise = self.unet(x_noisy, timesteps, context)
218
+
219
+ # 计算损失
220
+ loss = F.mse_loss(predicted_noise, noise)
221
+
222
+ return loss
223
+
224
+ @torch.no_grad()
225
+ def generate(
226
+ self,
227
+ context: torch.Tensor,
228
+ num_samples: int = 1,
229
+ height: int = 512,
230
+ width: int = 512,
231
+ guidance_scale: float = 7.5
232
+ ) -> torch.Tensor:
233
+ """生成图像"""
234
+ # 初始化噪声
235
+ latents = torch.randn(
236
+ (num_samples, self.unet.in_channels, height // 8, width // 8),
237
+ device=next(self.unet.parameters()).device
238
+ )
239
+
240
+ # DDIM采样
241
+ self.scheduler.set_timesteps(self.diffusion.num_inference_timesteps)
242
+
243
+ for t in self.scheduler.timesteps:
244
+ # 扩展latents以匹配批大小
245
+ latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
246
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
247
+
248
+ # 预测噪声
249
+ timesteps = torch.full((num_samples,), t, device=latents.device).long()
250
+ if guidance_scale > 1.0:
251
+ timesteps = torch.cat([timesteps] * 2)
252
+
253
+ noise_pred = self.unet(latent_model_input, timesteps, context)
254
+
255
+ # 应用分类器自由引导
256
+ if guidance_scale > 1.0:
257
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
258
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
259
+
260
+ # 计算上一个样本
261
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
262
+
263
+ return latents
src/models/unet_light.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Optional, Tuple, Union
5
+ import math
6
+
7
+
8
+ class TimestepEmbedding(nn.Module):
9
+ """时间步嵌入"""
10
+ def __init__(self, embedding_dim: int, time_embed_dim: int):
11
+ super().__init__()
12
+ self.embedding_dim = embedding_dim
13
+ self.time_embed_dim = time_embed_dim
14
+
15
+ self.mlp = nn.Sequential(
16
+ nn.Linear(embedding_dim, time_embed_dim),
17
+ nn.SiLU(),
18
+ nn.Linear(time_embed_dim, time_embed_dim)
19
+ )
20
+
21
+ def forward(self, timestep: torch.Tensor) -> torch.Tensor:
22
+ # 正弦位置编码
23
+ half_dim = self.embedding_dim // 2
24
+ emb = math.log(10000) / (half_dim - 1)
25
+ emb = torch.exp(torch.arange(half_dim, device=timestep.device) * -emb)
26
+ emb = timestep[:, None] * emb[None, :]
27
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
28
+
29
+ if self.embedding_dim % 2 == 1:
30
+ emb = F.pad(emb, (0, 1, 0, 0))
31
+
32
+ return self.mlp(emb)
33
+
34
+
35
+ class AttentionBlock(nn.Module):
36
+ """内存高效的注意力块"""
37
+ def __init__(self, channels: int, num_heads: int = 4, use_checkpoint: bool = True):
38
+ super().__init__()
39
+ self.channels = channels
40
+ self.num_heads = num_heads
41
+ self.use_checkpoint = use_checkpoint
42
+
43
+ self.norm = nn.GroupNorm(32, channels)
44
+ self.qkv = nn.Conv2d(channels, channels * 3, 1)
45
+ self.proj_out = nn.Conv2d(channels, channels, 1)
46
+
47
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
48
+ if self.use_checkpoint and self.training:
49
+ return torch.utils.checkpoint.checkpoint(self._forward, x)
50
+ return self._forward(x)
51
+
52
+ def _forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ B, C, H, W = x.shape
54
+ qkv = self.qkv(self.norm(x))
55
+ q, k, v = qkv.chunk(3, dim=1)
56
+
57
+ # 分块计算注意力,避免OOM
58
+ chunk_size = min(32, H * W)
59
+ q = q.view(B, self.num_heads, C // self.num_heads, H * W).permute(0, 1, 3, 2)
60
+ k = k.view(B, self.num_heads, C // self.num_heads, H * W).permute(0, 1, 2, 3)
61
+ v = v.view(B, self.num_heads, C // self.num_heads, H * W).permute(0, 1, 3, 2)
62
+
63
+ # 分块计算
64
+ attn_output = torch.zeros_like(q)
65
+ for i in range(0, H * W, chunk_size):
66
+ q_chunk = q[:, :, i:i+chunk_size, :]
67
+ scores = torch.matmul(q_chunk, k) / math.sqrt(C // self.num_heads)
68
+ attn = F.softmax(scores, dim=-1)
69
+ attn_output[:, :, i:i+chunk_size, :] = torch.matmul(attn, v)
70
+
71
+ attn_output = attn_output.permute(0, 1, 3, 2).reshape(B, C, H, W)
72
+ return x + self.proj_out(attn_output)
73
+
74
+
75
+ class ResNetBlock(nn.Module):
76
+ """残差块"""
77
+ def __init__(self, in_channels: int, out_channels: int, time_embed_dim: int, dropout: float = 0.0):
78
+ super().__init__()
79
+ self.in_channels = in_channels
80
+ self.out_channels = out_channels
81
+
82
+ # 第一个归一化+激活+卷积
83
+ self.norm1 = nn.GroupNorm(32, in_channels)
84
+ self.act1 = nn.SiLU()
85
+ self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
86
+
87
+ # 时间嵌入投影
88
+ self.time_emb_proj = nn.Linear(time_embed_dim, out_channels)
89
+
90
+ # 第二个归一化+激活+卷积
91
+ self.norm2 = nn.GroupNorm(32, out_channels)
92
+ self.act2 = nn.SiLU()
93
+ self.dropout = nn.Dropout(dropout)
94
+ self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
95
+
96
+ # 如果需要调整通道数
97
+ if in_channels != out_channels:
98
+ self.skip_conv = nn.Conv2d(in_channels, out_channels, 1)
99
+ else:
100
+ self.skip_conv = nn.Identity()
101
+
102
+ def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
103
+ h = self.conv1(self.act1(self.norm1(x)))
104
+
105
+ # 添加时间嵌入
106
+ time_emb = self.time_emb_proj(F.silu(time_emb))
107
+ h = h + time_emb[:, :, None, None]
108
+
109
+ h = self.conv2(self.dropout(self.act2(self.norm2(h))))
110
+
111
+ return h + self.skip_conv(x)
112
+
113
+
114
+ class CrossAttentionBlock(nn.Module):
115
+ """交叉注意力块(文本条件)"""
116
+ def __init__(self, query_dim: int, context_dim: int, num_heads: int = 4, use_checkpoint: bool = True):
117
+ super().__init__()
118
+ self.query_dim = query_dim
119
+ self.context_dim = context_dim
120
+ self.num_heads = num_heads
121
+ self.use_checkpoint = use_checkpoint
122
+
123
+ self.norm = nn.GroupNorm(32, query_dim)
124
+ self.to_q = nn.Linear(query_dim, query_dim)
125
+ self.to_k = nn.Linear(context_dim, query_dim)
126
+ self.to_v = nn.Linear(context_dim, query_dim)
127
+ self.to_out = nn.Linear(query_dim, query_dim)
128
+
129
+ def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
130
+ if self.use_checkpoint and self.training:
131
+ return torch.utils.checkpoint.checkpoint(self._forward, x, context)
132
+ return self._forward(x, context)
133
+
134
+ def _forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
135
+ B, C, H, W = x.shape
136
+ x_reshaped = x.reshape(B, C, H * W).permute(0, 2, 1)
137
+ x_norm = self.norm(x_reshaped)
138
+
139
+ q = self.to_q(x_norm)
140
+ k = self.to_k(context)
141
+ v = self.to_v(context)
142
+
143
+ # 分头
144
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2)
145
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2)
146
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2)
147
+
148
+ # 分块计算注意力
149
+ chunk_size = min(32, q.shape[2])
150
+ attn_output = torch.zeros_like(q)
151
+
152
+ for i in range(0, q.shape[2], chunk_size):
153
+ q_chunk = q[:, :, i:i+chunk_size, :]
154
+ scores = torch.matmul(q_chunk, k.transpose(-2, -1)) / math.sqrt(C // self.num_heads)
155
+ attn = F.softmax(scores, dim=-1)
156
+ attn_output[:, :, i:i+chunk_size, :] = torch.matmul(attn, v)
157
+
158
+ attn_output = attn_output.transpose(1, 2).reshape(B, -1, C)
159
+ attn_output = self.to_out(attn_output)
160
+
161
+ return (x_reshaped + attn_output).permute(0, 2, 1).reshape(B, C, H, W)
162
+
163
+
164
+ class DownsampleBlock(nn.Module):
165
+ """下采样块"""
166
+ def __init__(self, channels: int, use_conv: bool = True):
167
+ super().__init__()
168
+ if use_conv:
169
+ self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
170
+ else:
171
+ self.op = nn.AvgPool2d(2)
172
+
173
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
174
+ return self.op(x)
175
+
176
+
177
+ class UpsampleBlock(nn.Module):
178
+ """上采样块"""
179
+ def __init__(self, channels: int, use_conv: bool = True):
180
+ super().__init__()
181
+ self.channels = channels
182
+ if use_conv:
183
+ self.conv = nn.Conv2d(channels, channels, 3, padding=1)
184
+ else:
185
+ self.conv = nn.Identity()
186
+
187
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
188
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
189
+ return self.conv(x)
190
+
191
+
192
+ class UNetLight(nn.Module):
193
+ """轻量级UNet模型"""
194
+ def __init__(self, config: dict):
195
+ super().__init__()
196
+ self.config = config
197
+ model_config = config.get('model', {})
198
+
199
+ # 基本参数
200
+ self.in_channels = model_config.get('in_channels', 4)
201
+ self.out_channels = model_config.get('out_channels', 4)
202
+ self.base_channels = model_config.get('base_channels', 64)
203
+ self.channel_mults = model_config.get('channel_mults', [1, 2, 4, 8])
204
+ self.num_res_blocks = model_config.get('num_res_blocks', 2)
205
+ self.attention_resolutions = model_config.get('attention_resolutions', [8])
206
+ self.dropout = model_config.get('dropout', 0.0)
207
+ self.use_checkpoint = model_config.get('use_checkpoint', True)
208
+ self.num_heads = model_config.get('num_heads', 4)
209
+
210
+ # 条件参数
211
+ self.context_dim = model_config.get('context_dim', 768)
212
+ self.use_linear_projection = model_config.get('use_linear_projection', True)
213
+
214
+ # 时间嵌入
215
+ self.time_embed_dim = model_config.get('time_embed_dim', 256)
216
+ self.time_embed = TimestepEmbedding(self.base_channels, self.time_embed_dim)
217
+
218
+ # 输入卷积
219
+ self.input_conv = nn.Conv2d(self.in_channels, self.base_channels, 3, padding=1)
220
+
221
+ # 构建下采样路径
222
+ self.down_blocks = nn.ModuleList()
223
+ self.down_attention_blocks = nn.ModuleList()
224
+ self.downsample_blocks = nn.ModuleList()
225
+
226
+ in_ch = self.base_channels
227
+ resolution = 1
228
+
229
+ for i, mult in enumerate(self.channel_mults):
230
+ out_ch = self.base_channels * mult
231
+ resolution *= 2
232
+
233
+ # 残差块
234
+ for _ in range(self.num_res_blocks):
235
+ block = ResNetBlock(in_ch, out_ch, self.time_embed_dim, self.dropout)
236
+ self.down_blocks.append(block)
237
+
238
+ # 在指定分辨率添加注意力
239
+ if resolution in self.attention_resolutions:
240
+ attn = CrossAttentionBlock(out_ch, self.context_dim, self.num_heads, self.use_checkpoint)
241
+ self.down_attention_blocks.append(attn)
242
+ else:
243
+ self.down_attention_blocks.append(None)
244
+
245
+ in_ch = out_ch
246
+
247
+ # 如果不是最后一层,添加下采样
248
+ if i != len(self.channel_mults) - 1:
249
+ downsample = DownsampleBlock(in_ch, use_conv=True)
250
+ self.downsample_blocks.append(downsample)
251
+
252
+ # 中间层
253
+ self.mid_block1 = ResNetBlock(in_ch, in_ch, self.time_embed_dim, self.dropout)
254
+ self.mid_attention = CrossAttentionBlock(in_ch, self.context_dim, self.num_heads, self.use_checkpoint)
255
+ self.mid_block2 = ResNetBlock(in_ch, in_ch, self.time_embed_dim, self.dropout)
256
+
257
+ # 构建上采样路径
258
+ self.up_blocks = nn.ModuleList()
259
+ self.up_attention_blocks = nn.ModuleList()
260
+ self.upsample_blocks = nn.ModuleList()
261
+
262
+ for i, mult in enumerate(reversed(self.channel_mults)):
263
+ out_ch = self.base_channels * mult
264
+ resolution //= 2
265
+
266
+ # 上采样块
267
+ if i > 0:
268
+ upsample = UpsampleBlock(in_ch, use_conv=True)
269
+ self.upsample_blocks.append(upsample)
270
+
271
+ # 残差块
272
+ for j in range(self.num_res_blocks + 1):
273
+ block_in_ch = in_ch * 2 if j == 0 else out_ch
274
+ block = ResNetBlock(block_in_ch, out_ch, self.time_embed_dim, self.dropout)
275
+ self.up_blocks.append(block)
276
+
277
+ # 在指定分辨率添加注意力
278
+ if resolution in self.attention_resolutions:
279
+ attn = CrossAttentionBlock(out_ch, self.context_dim, self.num_heads, self.use_checkpoint)
280
+ self.up_attention_blocks.append(attn)
281
+ else:
282
+ self.up_attention_blocks.append(None)
283
+
284
+ in_ch = out_ch
285
+
286
+ # 输出层
287
+ self.norm_out = nn.GroupNorm(32, self.base_channels)
288
+ self.act_out = nn.SiLU()
289
+ self.output_conv = nn.Conv2d(self.base_channels, self.out_channels, 3, padding=1)
290
+
291
+ # 如果需要,初始化权重
292
+ self.apply(self._init_weights)
293
+
294
+ def _init_weights(self, module):
295
+ if isinstance(module, (nn.Conv2d, nn.Linear)):
296
+ nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
297
+ if module.bias is not None:
298
+ nn.init.constant_(module.bias, 0)
299
+ elif isinstance(module, nn.GroupNorm):
300
+ nn.init.constant_(module.weight, 1)
301
+ nn.init.constant_(module.bias, 0)
302
+
303
+ def forward(self, x: torch.Tensor, timesteps: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
304
+ # 时间嵌入
305
+ time_emb = self.time_embed(timesteps)
306
+
307
+ # 初始卷积
308
+ h = self.input_conv(x)
309
+
310
+ # 存储跳跃连接
311
+ down_h = []
312
+
313
+ # 下采样路径
314
+ down_idx = 0
315
+ attn_idx = 0
316
+
317
+ for i, mult in enumerate(self.channel_mults):
318
+ for j in range(self.num_res_blocks):
319
+ block = self.down_blocks[down_idx]
320
+ h = block(h, time_emb)
321
+ down_idx += 1
322
+
323
+ # 注意力
324
+ if self.down_attention_blocks[attn_idx] is not None:
325
+ h = self.down_attention_blocks[attn_idx](h, context)
326
+ attn_idx += 1
327
+
328
+ down_h.append(h)
329
+
330
+ # 下采样(除了最后一层)
331
+ if i != len(self.channel_mults) - 1:
332
+ downsample = self.downsample_blocks[i]
333
+ h = downsample(h)
334
+
335
+ # 中间层
336
+ h = self.mid_block1(h, time_emb)
337
+ h = self.mid_attention(h, context)
338
+ h = self.mid_block2(h, time_emb)
339
+
340
+ # 上采样路径
341
+ up_idx = 0
342
+ attn_up_idx = 0
343
+ upsample_idx = 0
344
+
345
+ for i, mult in enumerate(reversed(self.channel_mults)):
346
+ # 上采样(除了第一层)
347
+ if i > 0:
348
+ upsample = self.upsample_blocks[upsample_idx]
349
+ h = upsample(h)
350
+ upsample_idx += 1
351
+
352
+ for j in range(self.num_res_blocks + 1):
353
+ # 拼接跳跃连接
354
+ if j == 0:
355
+ skip = down_h.pop()
356
+ h = torch.cat([h, skip], dim=1)
357
+
358
+ block = self.up_blocks[up_idx]
359
+ h = block(h, time_emb)
360
+ up_idx += 1
361
+
362
+ # 注意力
363
+ if self.up_attention_blocks[attn_up_idx] is not None:
364
+ h = self.up_attention_blocks[attn_up_idx](h, context)
365
+ attn_up_idx += 1
366
+
367
+ # 输出层
368
+ h = self.norm_out(h)
369
+ h = self.act_out(h)
370
+ h = self.output_conv(h)
371
+
372
+ return h
373
+
374
+ def enable_gradient_checkpointing(self):
375
+ """启用梯度检查点"""
376
+ self.use_checkpoint = True
377
+ for module in self.modules():
378
+ if hasattr(module, 'use_checkpoint'):
379
+ module.use_checkpoint = True
src/training/callbacks.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Dict, Any, Optional, List
3
+ import os
4
+ from datetime import datetime
5
+ import numpy as np
6
+ from PIL import Image
7
+ import torchvision.transforms as T
8
+
9
+
10
+ class Callback:
11
+ """回调基类"""
12
+ def on_train_begin(self, trainer):
13
+ pass
14
+
15
+ def on_train_end(self, trainer):
16
+ pass
17
+
18
+ def on_epoch_begin(self, trainer, epoch):
19
+ pass
20
+
21
+ def on_epoch_end(self, trainer, epoch, train_loss, val_loss):
22
+ pass
23
+
24
+ def on_batch_begin(self, trainer, batch_idx, batch):
25
+ pass
26
+
27
+ def on_batch_end(self, trainer, batch_idx, batch, loss):
28
+ pass
29
+
30
+ def on_validation_begin(self, trainer):
31
+ pass
32
+
33
+ def on_validation_end(self, trainer, val_loss):
34
+ pass
35
+
36
+
37
+ class EarlyStopping(Callback):
38
+ """早停回调"""
39
+ def __init__(self, patience: int = 10, min_delta: float = 1e-4):
40
+ self.patience = patience
41
+ self.min_delta = min_delta
42
+ self.best_loss = float('inf')
43
+ self.counter = 0
44
+ self.should_stop = False
45
+
46
+ def on_validation_end(self, trainer, val_loss):
47
+ if val_loss < self.best_loss - self.min_delta:
48
+ self.best_loss = val_loss
49
+ self.counter = 0
50
+ else:
51
+ self.counter += 1
52
+
53
+ if self.counter >= self.patience:
54
+ self.should_stop = True
55
+ print(f"早停触发,最佳损失: {self.best_loss:.4f}")
56
+
57
+
58
+ class ModelCheckpoint(Callback):
59
+ """模型检查点回调"""
60
+ def __init__(
61
+ self,
62
+ save_dir: str = './checkpoints',
63
+ save_best_only: bool = True,
64
+ save_freq: int = 1,
65
+ monitor: str = 'val_loss',
66
+ mode: str = 'min'
67
+ ):
68
+ self.save_dir = save_dir
69
+ self.save_best_only = save_best_only
70
+ self.save_freq = save_freq
71
+ self.monitor = monitor
72
+ self.mode = mode
73
+
74
+ os.makedirs(save_dir, exist_ok=True)
75
+
76
+ self.best_value = float('inf') if mode == 'min' else -float('inf')
77
+
78
+ def on_epoch_end(self, trainer, epoch, train_loss, val_loss):
79
+ if epoch % self.save_freq != 0:
80
+ return
81
+
82
+ # 获取监控的值
83
+ if self.monitor == 'val_loss':
84
+ value = val_loss
85
+ elif self.monitor == 'train_loss':
86
+ value = train_loss
87
+ else:
88
+ value = val_loss
89
+
90
+ # 检查是否需要保存
91
+ should_save = False
92
+ if self.save_best_only:
93
+ if self.mode == 'min' and value < self.best_value:
94
+ self.best_value = value
95
+ should_save = True
96
+ elif self.mode == 'max' and value > self.best_value:
97
+ self.best_value = value
98
+ should_save = True
99
+ else:
100
+ should_save = True
101
+
102
+ if should_save:
103
+ # 保存检查点
104
+ checkpoint = {
105
+ 'epoch': epoch,
106
+ 'model_state_dict': trainer.model.state_dict(),
107
+ 'optimizer_state_dict': trainer.optimizer.state_dict(),
108
+ 'train_loss': train_loss,
109
+ 'val_loss': val_loss,
110
+ }
111
+
112
+ if trainer.use_ema:
113
+ checkpoint['ema_model_state_dict'] = trainer.ema_model.state_dict()
114
+
115
+ filename = f'checkpoint_epoch_{epoch}.pt' if not self.save_best_only else 'best_model.pt'
116
+ save_path = os.path.join(self.save_dir, filename)
117
+
118
+ torch.save(checkpoint, save_path)
119
+ print(f"检查点已保存: {save_path}")
120
+
121
+
122
+ class LearningRateSchedulerCallback(Callback):
123
+ """学习率调度回调"""
124
+ def __init__(self, scheduler, update_on: str = 'epoch'):
125
+ self.scheduler = scheduler
126
+ self.update_on = update_on # 'epoch' 或 'batch'
127
+
128
+ def on_epoch_end(self, trainer, epoch, train_loss, val_loss):
129
+ if self.update_on == 'epoch':
130
+ self.scheduler.step()
131
+
132
+ def on_batch_end(self, trainer, batch_idx, batch, loss):
133
+ if self.update_on == 'batch':
134
+ self.scheduler.step()
135
+
136
+
137
+ class TensorBoardLogger(Callback):
138
+ """TensorBoard日志记录器"""
139
+ def __init__(self, log_dir: str = './logs'):
140
+ from torch.utils.tensorboard import SummaryWriter
141
+ self.writer = SummaryWriter(log_dir=log_dir)
142
+ self.global_step = 0
143
+
144
+ def on_batch_end(self, trainer, batch_idx, batch, loss):
145
+ self.writer.add_scalar('train/loss', loss, self.global_step)
146
+ self.writer.add_scalar('train/lr', trainer.optimizer.param_groups[0]['lr'], self.global_step)
147
+ self.global_step += 1
148
+
149
+ def on_epoch_end(self, trainer, epoch, train_loss, val_loss):
150
+ self.writer.add_scalar('epoch/train_loss', train_loss, epoch)
151
+ self.writer.add_scalar('epoch/val_loss', val_loss, epoch)
152
+
153
+ def on_train_end(self, trainer):
154
+ self.writer.close()
155
+
156
+
157
+ class SampleGeneratorCallback(Callback):
158
+ """样本生成回调"""
159
+ def __init__(
160
+ self,
161
+ sample_freq: int = 500,
162
+ num_samples: int = 4,
163
+ save_dir: str = './samples'
164
+ ):
165
+ self.sample_freq = sample_freq
166
+ self.num_samples = num_samples
167
+ self.save_dir = save_dir
168
+
169
+ os.makedirs(save_dir, exist_ok=True)
170
+
171
+ def on_batch_end(self, trainer, batch_idx, batch, loss):
172
+ if trainer.global_step % self.sample_freq != 0:
173
+ return
174
+
175
+ # 生成样本
176
+ trainer.model.eval()
177
+
178
+ with torch.no_grad():
179
+ # 使用验证集的提示
180
+ sample_batch = next(iter(trainer.val_loader))
181
+ text_embeddings = sample_batch['text_embeddings'][:self.num_samples].to(trainer.device)
182
+
183
+ # 生成潜在表示
184
+ latents = trainer.diffusion.generate(
185
+ context=text_embeddings,
186
+ num_samples=self.num_samples,
187
+ guidance_scale=7.5
188
+ )
189
+
190
+ # 保存样本
191
+ for i in range(self.num_samples):
192
+ sample_path = os.path.join(
193
+ self.save_dir,
194
+ f'step_{trainer.global_step}_sample_{i}.pt'
195
+ )
196
+ torch.save(latents[i].cpu(), sample_path)
197
+
198
+ trainer.model.train()
199
+
200
+
201
+ class MemoryMonitorCallback(Callback):
202
+ """内存监控回调"""
203
+ def __init__(self, monitor_freq: int = 100):
204
+ self.monitor_freq = monitor_freq
205
+
206
+ def on_batch_end(self, trainer, batch_idx, batch, loss):
207
+ if trainer.global_step % self.monitor_freq == 0:
208
+ if hasattr(trainer, 'memory_manager'):
209
+ trainer.memory_manager.print_memory_stats()
210
+
211
+
212
+ class GradientMonitorCallback(Callback):
213
+ """梯度监控回调"""
214
+ def __init__(self, monitor_freq: int = 100):
215
+ self.monitor_freq = monitor_freq
216
+
217
+ def on_batch_end(self, trainer, batch_idx, batch, loss):
218
+ if trainer.global_step % self.monitor_freq == 0:
219
+ grad_norm = self._compute_gradient_norm(trainer.model)
220
+
221
+ if hasattr(trainer, 'writer'):
222
+ trainer.writer.add_scalar('train/grad_norm', grad_norm, trainer.global_step)
223
+
224
+ def _compute_gradient_norm(self, model) -> float:
225
+ total_norm = 0.0
226
+ for p in model.parameters():
227
+ if p.grad is not None:
228
+ param_norm = p.grad.data.norm(2)
229
+ total_norm += param_norm.item() ** 2
230
+ return total_norm ** 0.5
231
+
232
+
233
+ class CallbackHandler:
234
+ """回调处理器"""
235
+ def __init__(self):
236
+ self.callbacks = []
237
+
238
+ def add_callback(self, callback: Callback):
239
+ self.callbacks.append(callback)
240
+
241
+ def on_train_begin(self, trainer):
242
+ for callback in self.callbacks:
243
+ callback.on_train_begin(trainer)
244
+
245
+ def on_train_end(self, trainer):
246
+ for callback in self.callbacks:
247
+ callback.on_train_end(trainer)
248
+
249
+ def on_epoch_begin(self, trainer, epoch):
250
+ for callback in self.callbacks:
251
+ callback.on_epoch_begin(trainer, epoch)
252
+
253
+ def on_epoch_end(self, trainer, epoch, train_loss, val_loss):
254
+ for callback in self.callbacks:
255
+ callback.on_epoch_end(trainer, epoch, train_loss, val_loss)
256
+
257
+ def on_batch_begin(self, trainer, batch_idx, batch):
258
+ for callback in self.callbacks:
259
+ callback.on_batch_begin(trainer, batch_idx, batch)
260
+
261
+ def on_batch_end(self, trainer, batch_idx, batch, loss):
262
+ for callback in self.callbacks:
263
+ callback.on_batch_end(trainer, batch_idx, batch, loss)
264
+
265
+ def on_validation_begin(self, trainer):
266
+ for callback in self.callbacks:
267
+ callback.on_validation_begin(trainer)
268
+
269
+ def on_validation_end(self, trainer, val_loss):
270
+ for callback in self.callbacks:
271
+ callback.on_validation_end(trainer, val_loss)
272
+
273
+
274
+ def create_default_callbacks(config: dict) -> CallbackHandler:
275
+ """创建默认回调"""
276
+ handler = CallbackHandler()
277
+
278
+ # 模型检查点
279
+ checkpoint_callback = ModelCheckpoint(
280
+ save_dir=config.get('checkpoint_dir', './checkpoints'),
281
+ save_best_only=config.get('save_best_model', True),
282
+ save_freq=config.get('save_checkpoint_every', 1),
283
+ monitor='val_loss',
284
+ mode='min'
285
+ )
286
+ handler.add_callback(checkpoint_callback)
287
+
288
+ # TensorBoard日志
289
+ if config.get('use_tensorboard', True):
290
+ tb_logger = TensorBoardLogger(
291
+ log_dir=config.get('log_dir', './logs')
292
+ )
293
+ handler.add_callback(tb_logger)
294
+
295
+ # 样本生成
296
+ if config.get('sample_steps', 500) > 0:
297
+ sample_callback = SampleGeneratorCallback(
298
+ sample_freq=config.get('sample_steps', 500),
299
+ num_samples=4,
300
+ save_dir=config.get('sample_dir', './samples')
301
+ )
302
+ handler.add_callback(sample_callback)
303
+
304
+ # 内存监控
305
+ memory_callback = MemoryMonitorCallback(
306
+ monitor_freq=config.get('log_steps', 50)
307
+ )
308
+ handler.add_callback(memory_callback)
309
+
310
+ # 梯度监控
311
+ grad_callback = GradientMonitorCallback(
312
+ monitor_freq=config.get('log_steps', 50)
313
+ )
314
+ handler.add_callback(grad_callback)
315
+
316
+ # 早停
317
+ if config.get('early_stopping', False):
318
+ early_stop = EarlyStopping(
319
+ patience=config.get('early_stopping_patience', 10),
320
+ min_delta=config.get('early_stopping_min_delta', 1e-4)
321
+ )
322
+ handler.add_callback(early_stop)
323
+
324
+ return handler
src/training/memory_manager.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gc
3
+ from typing import Optional
4
+ import psutil
5
+ import os
6
+
7
+
8
+ class CPUMemoryManager:
9
+ """CPU内存管理器"""
10
+ def __init__(self, warning_threshold: float = 0.9):
11
+ """
12
+ 参数:
13
+ warning_threshold: 内存使用率警告阈值 (0-1)
14
+ """
15
+ self.warning_threshold = warning_threshold
16
+
17
+ def get_memory_usage(self) -> tuple:
18
+ """获取内存使用情况"""
19
+ process = psutil.Process(os.getpid())
20
+ memory_info = process.memory_info()
21
+
22
+ # 获取系统内存信息
23
+ system_memory = psutil.virtual_memory()
24
+
25
+ return {
26
+ 'process_rss_mb': memory_info.rss / 1024 / 1024,
27
+ 'process_vms_mb': memory_info.vms / 1024 / 1024,
28
+ 'system_total_mb': system_memory.total / 1024 / 1024,
29
+ 'system_available_mb': system_memory.available / 1024 / 1024,
30
+ 'system_used_percent': system_memory.percent
31
+ }
32
+
33
+ def check_memory(self) -> bool:
34
+ """检查内存使用是否安全"""
35
+ memory_info = self.get_memory_usage()
36
+
37
+ if memory_info['system_used_percent'] > self.warning_threshold * 100:
38
+ print(f"警告: 系统内存使用率过高: {memory_info['system_used_percent']:.1f}%")
39
+ return False
40
+
41
+ return True
42
+
43
+
44
+ class OptimizerCPUOffload:
45
+ """优化器状态CPU卸载"""
46
+ def __init__(self, optimizer: torch.optim.Optimizer):
47
+ self.optimizer = optimizer
48
+ self.original_states = {}
49
+
50
+ def offload_to_cpu(self):
51
+ """将优化器状态卸载到CPU"""
52
+ for param_group in self.optimizer.param_groups:
53
+ for param in param_group['params']:
54
+ if param in self.optimizer.state:
55
+ state = self.optimizer.state[param]
56
+ for key in list(state.keys()):
57
+ if torch.is_tensor(state[key]):
58
+ # 移动到CPU并保留引用
59
+ self.original_states[(param, key)] = state[key]
60
+ state[key] = state[key].cpu()
61
+
62
+ def load_to_gpu(self, device: torch.device):
63
+ """将优化器状态加载回GPU"""
64
+ for param_group in self.optimizer.param_groups:
65
+ for param in param_group['params']:
66
+ if param in self.optimizer.state:
67
+ state = self.optimizer.state[param]
68
+ for key in list(state.keys()):
69
+ if (param, key) in self.original_states:
70
+ state[key] = self.original_states[(param, key)].to(device)
71
+ del self.original_states[(param, key)]
72
+
73
+
74
+ class ActivationCPUOffload:
75
+ """激活值CPU卸载"""
76
+ def __init__(self, model: torch.nn.Module):
77
+ self.model = model
78
+ self.hooks = []
79
+
80
+ def register_hooks(self):
81
+ """注册前向钩子来卸载激活值"""
82
+ def hook_fn(module, input, output):
83
+ if torch.is_tensor(output):
84
+ return output.cpu()
85
+ elif isinstance(output, tuple):
86
+ return tuple(x.cpu() if torch.is_tensor(x) else x for x in output)
87
+ return output
88
+
89
+ # 为每个模块注册钩子
90
+ for name, module in self.model.named_modules():
91
+ if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm)):
92
+ hook = module.register_forward_hook(hook_fn)
93
+ self.hooks.append(hook)
94
+
95
+ def remove_hooks(self):
96
+ """移除所有钩子"""
97
+ for hook in self.hooks:
98
+ hook.remove()
99
+ self.hooks = []
100
+
101
+
102
+ class MemoryOptimizer:
103
+ """综合内存优化器"""
104
+ def __init__(self, config: dict):
105
+ self.config = config
106
+
107
+ # GPU内存管理
108
+ self.gpu_warning_threshold = config.get('warning_threshold_gb', 6.0) * 1024**3
109
+ self.gpu_critical_threshold = config.get('memory_threshold_gb', 6.5) * 1024**3
110
+
111
+ # CPU内存管理
112
+ self.cpu_manager = CPUMemoryManager(
113
+ warning_threshold=config.get('cpu_warning_threshold', 0.85)
114
+ )
115
+
116
+ # 清理频率
117
+ self.cleanup_frequency = config.get('cleanup_frequency', 100)
118
+
119
+ # 状态跟踪
120
+ self.optimizer_offloader = None
121
+ self.activation_offloader = None
122
+
123
+ def setup_model_optimizations(self, model: torch.nn.Module, optimizer: Optional[torch.optim.Optimizer] = None):
124
+ """设置模型优化"""
125
+ # 启用梯度检查点
126
+ if hasattr(model, 'enable_gradient_checkpointing'):
127
+ model.enable_gradient_checkpointing()
128
+
129
+ # 设置优化器CPU卸载
130
+ if optimizer is not None and self.config.get('optimizer_on_cpu', True):
131
+ self.optimizer_offloader = OptimizerCPUOffload(optimizer)
132
+
133
+ # 设置激活值CPU卸载
134
+ if self.config.get('cpu_offload', True):
135
+ self.activation_offloader = ActivationCPUOffload(model)
136
+ self.activation_offloader.register_hooks()
137
+
138
+ # 设置注意力分片
139
+ if self.config.get('attention_slicing', 'auto') == 'auto':
140
+ self._enable_attention_slicing(model)
141
+
142
+ def _enable_attention_slicing(self, model: torch.nn.Module):
143
+ """启用注意力分片"""
144
+ for module in model.modules():
145
+ if hasattr(module, 'set_attention_slice'):
146
+ module.set_attention_slice('auto')
147
+
148
+ def step_start(self):
149
+ """训练步骤开始时的内存管理"""
150
+ # 将优化器状态加载到GPU(如果需要)
151
+ if self.optimizer_offloader is not None:
152
+ device = next(self.optimizer_offloader.optimizer.param_groups[0]['params'][0].device)
153
+ self.optimizer_offloader.load_to_gpu(device)
154
+
155
+ # 检查内存
156
+ self.check_all_memory()
157
+
158
+ def step_end(self, step: int):
159
+ """训练步骤结束时的内存管理"""
160
+ # 将优化器状态卸载到CPU
161
+ if self.optimizer_offloader is not None:
162
+ self.optimizer_offloader.offload_to_cpu()
163
+
164
+ # 定期清理
165
+ if step % self.cleanup_frequency == 0:
166
+ self.cleanup()
167
+
168
+ # 检查内存
169
+ self.check_all_memory()
170
+
171
+ def check_all_memory(self):
172
+ """检查所有内存"""
173
+ # 检查GPU内存
174
+ gpu_allocated = torch.cuda.memory_allocated()
175
+ if gpu_allocated > self.gpu_critical_threshold:
176
+ self._handle_gpu_oom()
177
+ elif gpu_allocated > self.gpu_warning_threshold:
178
+ print(f"GPU内存警告: {gpu_allocated / 1024**3:.2f} GB")
179
+
180
+ # 检查CPU内存
181
+ if not self.cpu_manager.check_memory():
182
+ self._handle_cpu_oom()
183
+
184
+ def _handle_gpu_oom(self):
185
+ """处理GPU OOM"""
186
+ print("GPU内存不足,尝试清理...")
187
+ self.cleanup(force=True)
188
+
189
+ # 如果仍然不足,抛出异常
190
+ if torch.cuda.memory_allocated() > self.gpu_critical_threshold:
191
+ raise RuntimeError("GPU内存不足,无法继续训练")
192
+
193
+ def _handle_cpu_oom(self):
194
+ """处理CPU OOM"""
195
+ print("CPU内存不足,尝试清理...")
196
+ gc.collect()
197
+
198
+ def cleanup(self, force: bool = False):
199
+ """清理内存"""
200
+ gc.collect()
201
+
202
+ if torch.cuda.is_available():
203
+ torch.cuda.empty_cache()
204
+
205
+ # 如果强制清理,尝试更激进的清理
206
+ if force:
207
+ torch.cuda.synchronize()
208
+ torch.cuda.ipc_collect()
209
+
210
+ def get_memory_stats(self) -> dict:
211
+ """获取内存统计信息"""
212
+ stats = {}
213
+
214
+ # GPU统计
215
+ if torch.cuda.is_available():
216
+ stats['gpu'] = {
217
+ 'allocated_gb': torch.cuda.memory_allocated() / 1024**3,
218
+ 'reserved_gb': torch.cuda.memory_reserved() / 1024**3,
219
+ 'max_allocated_gb': torch.cuda.max_memory_allocated() / 1024**3,
220
+ }
221
+
222
+ # CPU统计
223
+ cpu_stats = self.cpu_manager.get_memory_usage()
224
+ stats['cpu'] = cpu_stats
225
+
226
+ return stats
227
+
228
+ def print_memory_stats(self):
229
+ """打印内存统计信息"""
230
+ stats = self.get_memory_stats()
231
+
232
+ print("=" * 50)
233
+ print("内存使用统计:")
234
+
235
+ if 'gpu' in stats:
236
+ gpu = stats['gpu']
237
+ print(f"GPU - 已分配: {gpu['allocated_gb']:.2f} GB, "
238
+ f"已保留: {gpu['reserved_gb']:.2f} GB, "
239
+ f"最大分配: {gpu['max_allocated_gb']:.2f} GB")
240
+
241
+ if 'cpu' in stats:
242
+ cpu = stats['cpu']
243
+ print(f"CPU - 进程RSS: {cpu['process_rss_mb']:.1f} MB, "
244
+ f"系统使用率: {cpu['system_used_percent']:.1f}%")
245
+ print("=" * 50)
src/training/trainer_p4.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.cuda.amp import autocast, GradScaler
4
+ from torch.utils.data import DataLoader
5
+ from typing import Optional, Dict, Any
6
+ import wandb
7
+ import os
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+
11
+
12
+ class MemoryManager:
13
+ """P4显存管理器"""
14
+ def __init__(self, config: dict):
15
+ self.config = config
16
+ self.warning_threshold = config.get('warning_threshold_gb', 6.0) * 1024**3
17
+ self.critical_threshold = config.get('memory_threshold_gb', 6.5) * 1024**3
18
+ self.cleanup_frequency = config.get('cleanup_frequency', 100)
19
+
20
+ def check_memory(self, step: int):
21
+ """检查显存使用情况"""
22
+ if step % self.cleanup_frequency == 0:
23
+ self.cleanup()
24
+
25
+ allocated = torch.cuda.memory_allocated()
26
+ if allocated > self.critical_threshold:
27
+ raise RuntimeError(f"显存超出临界阈值: {allocated / 1024**3:.2f} GB")
28
+ elif allocated > self.warning_threshold:
29
+ print(f"警告: 显存使用较高: {allocated / 1024**3:.2f} GB")
30
+
31
+ def cleanup(self):
32
+ """清理显存"""
33
+ import gc
34
+ gc.collect()
35
+ torch.cuda.empty_cache()
36
+
37
+ def auto_adjust_batch_size(self, model: nn.Module, data_shape: tuple) -> int:
38
+ """自动调整批次大小"""
39
+ max_batch = 1
40
+ device = next(model.parameters()).device
41
+
42
+ for batch_size in [1, 2, 4, 8]:
43
+ try:
44
+ # 测试内存
45
+ dummy_input = torch.randn(batch_size, *data_shape, device=device)
46
+ dummy_timestep = torch.randint(0, 1000, (batch_size,), device=device)
47
+ dummy_context = torch.randn(batch_size, 77, 768, device=device)
48
+
49
+ with torch.no_grad():
50
+ _ = model(dummy_input, dummy_timestep, dummy_context)
51
+
52
+ torch.cuda.empty_cache()
53
+ max_batch = batch_size
54
+ except RuntimeError as e:
55
+ if "CUDA out of memory" in str(e):
56
+ break
57
+ else:
58
+ raise e
59
+
60
+ return max_batch
61
+
62
+
63
+ class GradientAccumulationScheduler:
64
+ """梯度累积调度器"""
65
+ def __init__(self, config: dict):
66
+ self.initial_steps = config.get('gradient_accumulation_steps', 8)
67
+ self.current_steps = self.initial_steps
68
+ self.warmup_epochs = config.get('warmup_epochs', 5)
69
+
70
+ def update(self, epoch: int):
71
+ """根据epoch更新累积步数"""
72
+ if epoch < self.warmup_epochs:
73
+ self.current_steps = self.initial_steps
74
+ else:
75
+ # 逐步减少累积步数以加快训练
76
+ self.current_steps = max(4, self.current_steps // 2)
77
+
78
+
79
+ class P4Trainer:
80
+ """针对P4优化的训练器"""
81
+ def __init__(
82
+ self,
83
+ model: nn.Module,
84
+ diffusion: DiffusionProcess,
85
+ optimizer: torch.optim.Optimizer,
86
+ train_loader: DataLoader,
87
+ val_loader: Optional[DataLoader],
88
+ config: dict,
89
+ device: torch.device
90
+ ):
91
+ self.model = model
92
+ self.diffusion = diffusion
93
+ self.optimizer = optimizer
94
+ self.train_loader = train_loader
95
+ self.val_loader = val_loader
96
+ self.config = config
97
+ self.device = device
98
+
99
+ # 训练状态
100
+ self.current_epoch = 0
101
+ self.global_step = 0
102
+ self.best_loss = float('inf')
103
+
104
+ # 初始化工具
105
+ self.memory_manager = MemoryManager(config)
106
+ self.grad_scheduler = GradientAccumulationScheduler(config)
107
+
108
+ # 混合精度训练
109
+ self.use_amp = config.get('mixed_precision', 'fp16') != 'no'
110
+ self.scaler = GradScaler(enabled=self.use_amp)
111
+
112
+ # 学习率调度器
113
+ self.lr_scheduler = self._create_lr_scheduler(config)
114
+
115
+ # EMA模型
116
+ self.use_ema = config.get('use_ema', True)
117
+ if self.use_ema:
118
+ self.ema_model = self._create_ema_model(model, config.get('ema_decay', 0.9999))
119
+
120
+ # 日志记录
121
+ self.use_wandb = config.get('use_wandb', False)
122
+ self.log_dir = config.get('log_dir', './logs')
123
+ os.makedirs(self.log_dir, exist_ok=True)
124
+
125
+ # 检查点
126
+ self.checkpoint_dir = config.get('checkpoint_dir', './checkpoints')
127
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
128
+
129
+ def _create_lr_scheduler(self, config: dict):
130
+ """创建学习率调度器"""
131
+ scheduler_type = config.get('learning_rate_scheduler', 'cosine')
132
+ warmup_steps = config.get('warmup_steps', 1000)
133
+
134
+ if scheduler_type == 'cosine':
135
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
136
+ self.optimizer,
137
+ T_max=config.get('max_epochs', 50) * len(self.train_loader),
138
+ eta_min=1e-6
139
+ )
140
+ elif scheduler_type == 'linear':
141
+ scheduler = torch.optim.lr_scheduler.LinearLR(
142
+ self.optimizer,
143
+ start_factor=0.01,
144
+ total_iters=warmup_steps
145
+ )
146
+ else:
147
+ scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1.0)
148
+
149
+ return scheduler
150
+
151
+ def _create_ema_model(self, model: nn.Module, decay: float):
152
+ """创建EMA模型"""
153
+ from torch.optim.swa_utils import AveragedModel
154
+ return AveragedModel(model, device=self.device, avg_fn=lambda avg, new, decay=decay: decay * avg + (1 - decay) * new)
155
+
156
+ def train_epoch(self) -> float:
157
+ """训练一个epoch"""
158
+ self.model.train()
159
+ total_loss = 0.0
160
+ num_batches = len(self.train_loader)
161
+
162
+ # 梯度累积
163
+ accumulation_steps = self.grad_scheduler.current_steps
164
+ self.optimizer.zero_grad()
165
+
166
+ pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch}")
167
+ for batch_idx, batch in enumerate(pbar):
168
+ # 将数据移到设备
169
+ images = batch['images'].to(self.device)
170
+ text_embeddings = batch['text_embeddings'].to(self.device)
171
+
172
+ # 混合精度前向传播
173
+ with autocast(enabled=self.use_amp):
174
+ loss = self.diffusion.compute_loss(images, text_embeddings)
175
+ loss = loss / accumulation_steps
176
+
177
+ # 反向传播
178
+ self.scaler.scale(loss).backward()
179
+
180
+ # 梯度累积更新
181
+ if (batch_idx + 1) % accumulation_steps == 0:
182
+ # 梯度裁剪
183
+ self.scaler.unscale_(self.optimizer)
184
+ torch.nn.utils.clip_grad_norm_(
185
+ self.model.parameters(),
186
+ max_norm=self.config.get('gradient_clip', 1.0)
187
+ )
188
+
189
+ # 更新参数
190
+ self.scaler.step(self.optimizer)
191
+ self.scaler.update()
192
+ self.optimizer.zero_grad()
193
+
194
+ # 更新EMA模型
195
+ if self.use_ema:
196
+ self.ema_model.update_parameters(self.model)
197
+
198
+ # 更新学习率
199
+ self.lr_scheduler.step()
200
+
201
+ self.global_step += 1
202
+
203
+ # 记录损失
204
+ total_loss += loss.item() * accumulation_steps
205
+ current_loss = total_loss / (batch_idx + 1)
206
+
207
+ # 更新进度条
208
+ pbar.set_postfix({
209
+ 'loss': f'{current_loss:.4f}',
210
+ 'lr': f'{self.optimizer.param_groups[0]["lr"]:.2e}'
211
+ })
212
+
213
+ # 记录日志
214
+ if self.global_step % self.config.get('log_steps', 50) == 0:
215
+ self._log_metrics({
216
+ 'train/loss': current_loss,
217
+ 'train/lr': self.optimizer.param_groups[0]['lr'],
218
+ 'train/grad_norm': self._get_grad_norm(),
219
+ })
220
+
221
+ # 生成样本
222
+ if self.global_step % self.config.get('sample_steps', 500) == 0:
223
+ self._generate_samples()
224
+
225
+ # 显存管理
226
+ self.memory_manager.check_memory(self.global_step)
227
+
228
+ epoch_loss = total_loss / num_batches
229
+ return epoch_loss
230
+
231
+ @torch.no_grad()
232
+ def validate(self) -> float:
233
+ """验证"""
234
+ if self.val_loader is None:
235
+ return float('inf')
236
+
237
+ self.model.eval()
238
+ total_loss = 0.0
239
+
240
+ for batch in tqdm(self.val_loader, desc="Validation"):
241
+ images = batch['images'].to(self.device)
242
+ text_embeddings = batch['text_embeddings'].to(self.device)
243
+
244
+ with autocast(enabled=self.use_amp):
245
+ loss = self.diffusion.compute_loss(images, text_embeddings)
246
+
247
+ total_loss += loss.item()
248
+
249
+ val_loss = total_loss / len(self.val_loader)
250
+
251
+ # 记录验证指标
252
+ self._log_metrics({'val/loss': val_loss})
253
+
254
+ return val_loss
255
+
256
+ def train(self, num_epochs: Optional[int] = None):
257
+ """训练循环"""
258
+ if num_epochs is None:
259
+ num_epochs = self.config.get('max_epochs', 50)
260
+
261
+ for epoch in range(self.current_epoch, num_epochs):
262
+ self.current_epoch = epoch
263
+
264
+ # 更新梯度累积策略
265
+ self.grad_scheduler.update(epoch)
266
+
267
+ # 训练一个epoch
268
+ train_loss = self.train_epoch()
269
+
270
+ # 验证
271
+ val_loss = self.validate()
272
+
273
+ # 保存最佳模型
274
+ if val_loss < self.best_loss:
275
+ self.best_loss = val_loss
276
+ self.save_checkpoint('best_model.pt')
277
+
278
+ # 定期保存检查点
279
+ if (epoch + 1) % self.config.get('save_checkpoint_every', 5) == 0:
280
+ self.save_checkpoint(f'checkpoint_epoch_{epoch+1}.pt')
281
+
282
+ # 记录epoch指标
283
+ self._log_metrics({
284
+ 'epoch/train_loss': train_loss,
285
+ 'epoch/val_loss': val_loss,
286
+ 'epoch': epoch
287
+ })
288
+
289
+ print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
290
+
291
+ def _get_grad_norm(self) -> float:
292
+ """计算梯度范数"""
293
+ total_norm = 0.0
294
+ for p in self.model.parameters():
295
+ if p.grad is not None:
296
+ param_norm = p.grad.data.norm(2)
297
+ total_norm += param_norm.item() ** 2
298
+ return total_norm ** 0.5
299
+
300
+ @torch.no_grad()
301
+ def _generate_samples(self):
302
+ """生成样本用于监控"""
303
+ self.model.eval()
304
+
305
+ # 使用验证集的前几个提示
306
+ sample_batch = next(iter(self.val_loader))
307
+ text_embeddings = sample_batch['text_embeddings'][:4].to(self.device)
308
+
309
+ # 生成样本
310
+ with autocast(enabled=self.use_amp):
311
+ latents = self.diffusion.generate(
312
+ context=text_embeddings,
313
+ num_samples=4,
314
+ guidance_scale=7.5
315
+ )
316
+
317
+ # 解码为图像
318
+ # 这里需要VAE解码器,暂时保存潜在表示
319
+ if self.use_wandb:
320
+ wandb.log({
321
+ 'samples/latents': wandb.Image(latents[0].cpu().numpy())
322
+ })
323
+
324
+ def _log_metrics(self, metrics: Dict[str, Any]):
325
+ """记录指标"""
326
+ if self.use_wandb:
327
+ wandb.log(metrics)
328
+
329
+ # 同时记录到本地文件
330
+ log_file = os.path.join(self.log_dir, 'training_log.csv')
331
+ with open(log_file, 'a') as f:
332
+ if self.global_step == 0:
333
+ header = ','.join(['step'] + list(metrics.keys()))
334
+ f.write(header + '\n')
335
+
336
+ values = ','.join([str(self.global_step)] + [str(v) for v in metrics.values()])
337
+ f.write(values + '\n')
338
+
339
+ def save_checkpoint(self, filename: str):
340
+ """保存检查点"""
341
+ checkpoint = {
342
+ 'epoch': self.current_epoch,
343
+ 'global_step': self.global_step,
344
+ 'model_state_dict': self.model.state_dict(),
345
+ 'optimizer_state_dict': self.optimizer.state_dict(),
346
+ 'scaler_state_dict': self.scaler.state_dict(),
347
+ 'best_loss': self.best_loss,
348
+ 'config': self.config
349
+ }
350
+
351
+ if self.use_ema:
352
+ checkpoint['ema_model_state_dict'] = self.ema_model.state_dict()
353
+
354
+ save_path = os.path.join(self.checkpoint_dir, filename)
355
+ torch.save(checkpoint, save_path)
356
+
357
+ # 如果启用压缩,保存压缩版本
358
+ if self.config.get('save_compressed', True):
359
+ torch.save(checkpoint, save_path.replace('.pt', '_compressed.pt'), _use_new_zipfile_serialization=True)
360
+
361
+ print(f"检查点已保存: {save_path}")
362
+
363
+ def load_checkpoint(self, checkpoint_path: str):
364
+ """加载检查点"""
365
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
366
+
367
+ self.model.load_state_dict(checkpoint['model_state_dict'])
368
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
369
+ self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
370
+
371
+ if self.use_ema and 'ema_model_state_dict' in checkpoint:
372
+ self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
373
+
374
+ self.current_epoch = checkpoint['epoch']
375
+ self.global_step = checkpoint['global_step']
376
+ self.best_loss = checkpoint['best_loss']
377
+
378
+ print(f"已加载检查点: {checkpoint_path}")
tests/test_basic.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ 基础测试
4
+ 测试项目的基本功能
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+
13
+ # 添加项目根目录到Python路径
14
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
15
+
16
+ from src.models.unet_light import UNetLight, TimestepEmbedding, ResNetBlock
17
+ from src.models.attention import MemoryEfficientAttention
18
+ from src.models.diffusion import DiffusionProcess
19
+
20
+
21
+ def test_timestep_embedding():
22
+ """测试时间步嵌入"""
23
+ print("测试时间步嵌入...")
24
+
25
+ embedding_dim = 256
26
+ time_embed_dim = 512
27
+
28
+ embedder = TimestepEmbedding(embedding_dim, time_embed_dim)
29
+
30
+ # 测试前向传播
31
+ timesteps = torch.tensor([100, 200, 300])
32
+ embeddings = embedder(timesteps)
33
+
34
+ assert embeddings.shape == (3, time_embed_dim)
35
+ print(f" 形状正确: {embeddings.shape}")
36
+
37
+ return embedder
38
+
39
+
40
+ def test_resnet_block():
41
+ """测试残差块"""
42
+ print("测试残差块...")
43
+
44
+ in_channels = 64
45
+ out_channels = 128
46
+ time_embed_dim = 256
47
+
48
+ block = ResNetBlock(in_channels, out_channels, time_embed_dim)
49
+
50
+ # 测试前向传播
51
+ x = torch.randn(2, in_channels, 32, 32)
52
+ time_emb = torch.randn(2, time_embed_dim)
53
+
54
+ output = block(x, time_emb)
55
+
56
+ assert output.shape == (2, out_channels, 32, 32)
57
+ print(f" 形状正确: {output.shape}")
58
+
59
+ # 测试跳跃连接
60
+ block_same = ResNetBlock(in_channels, in_channels, time_embed_dim)
61
+ output_same = block_same(x, time_emb)
62
+
63
+ assert output_same.shape == x.shape
64
+ print(f" 跳跃连接正确")
65
+
66
+ return block
67
+
68
+
69
+ def test_attention():
70
+ """测试注意力机制"""
71
+ print("测试注意力机制...")
72
+
73
+ dim = 256
74
+ num_heads = 8
75
+
76
+ attention = MemoryEfficientAttention(dim, num_heads)
77
+
78
+ # 测试前向传播
79
+ x = torch.randn(2, 16, dim) # [batch, seq_len, dim]
80
+ output = attention(x)
81
+
82
+ assert output.shape == x.shape
83
+ print(f" 形状正确: {output.shape}")
84
+
85
+ return attention
86
+
87
+
88
+ def test_unet_light():
89
+ """测试轻量UNet"""
90
+ print("测试轻量UNet...")
91
+
92
+ config = {
93
+ 'model': {
94
+ 'in_channels': 4,
95
+ 'out_channels': 4,
96
+ 'base_channels': 32, # 测试用小模型
97
+ 'channel_mults': [1, 2, 4],
98
+ 'num_res_blocks': 1,
99
+ 'attention_resolutions': [8],
100
+ 'dropout': 0.0,
101
+ 'use_checkpoint': False,
102
+ 'num_heads': 4,
103
+ 'context_dim': 256,
104
+ 'use_linear_projection': True,
105
+ 'time_embed_dim': 128
106
+ }
107
+ }
108
+
109
+ model = UNetLight(config)
110
+
111
+ # 测试前向传播
112
+ batch_size = 2
113
+ x = torch.randn(batch_size, 4, 64, 64)
114
+ timesteps = torch.randint(0, 1000, (batch_size,))
115
+ context = torch.randn(batch_size, 77, 256)
116
+
117
+ output = model(x, timesteps, context)
118
+
119
+ assert output.shape == x.shape
120
+ print(f" 形状正确: {output.shape}")
121
+
122
+ # 测试梯度检查点
123
+ model.enable_gradient_checkpointing()
124
+ print(f" 梯度检查点已启用")
125
+
126
+ return model
127
+
128
+
129
+ def test_diffusion_process():
130
+ """测试扩散过程"""
131
+ print("测试扩散过程...")
132
+
133
+ config = {
134
+ 'diffusion': {
135
+ 'beta_schedule': 'linear',
136
+ 'beta_start': 0.0001,
137
+ 'beta_end': 0.02,
138
+ 'num_train_timesteps': 100,
139
+ 'num_inference_timesteps': 20
140
+ }
141
+ }
142
+
143
+ diffusion = DiffusionProcess(config)
144
+
145
+ # 测试前向扩散
146
+ x_start = torch.randn(2, 3, 32, 32)
147
+ t = torch.randint(0, 100, (2,))
148
+
149
+ x_noisy = diffusion.q_sample(x_start, t)
150
+
151
+ assert x_noisy.shape == x_start.shape
152
+ print(f" 前向扩散形状正确: {x_noisy.shape}")
153
+
154
+ # 测试参数提取
155
+ extracted = diffusion.extract(diffusion.sqrt_alphas_cumprod, t, x_start.shape)
156
+ assert extracted.shape == (2, 1, 1, 1)
157
+ print(f" 参数提取形状正确: {extracted.shape}")
158
+
159
+ return diffusion
160
+
161
+
162
+ def test_memory_efficiency():
163
+ """测试内存效率"""
164
+ print("测试内存效率...")
165
+
166
+ # 测试模型在不同批次大小下的内存使用
167
+ config = {
168
+ 'model': {
169
+ 'in_channels': 4,
170
+ 'out_channels': 4,
171
+ 'base_channels': 32,
172
+ 'channel_mults': [1, 2],
173
+ 'num_res_blocks': 1,
174
+ 'attention_resolutions': [],
175
+ 'dropout': 0.0,
176
+ 'use_checkpoint': False,
177
+ 'num_heads': 4,
178
+ 'context_dim': 256,
179
+ 'use_linear_projection': True,
180
+ 'time_embed_dim': 128
181
+ }
182
+ }
183
+
184
+ model = UNetLight(config)
185
+ model.eval()
186
+
187
+ if torch.cuda.is_available():
188
+ device = torch.device('cuda')
189
+ model = model.to(device)
190
+
191
+ print(" GPU内存测试:")
192
+
193
+ for batch_size in [1, 2, 4]:
194
+ # 清空缓存
195
+ torch.cuda.empty_cache()
196
+
197
+ # 记录初始内存
198
+ initial_memory = torch.cuda.memory_allocated()
199
+
200
+ # 前向传播
201
+ x = torch.randn(batch_size, 4, 64, 64, device=device)
202
+ t = torch.randint(0, 1000, (batch_size,), device=device)
203
+ context = torch.randn(batch_size, 77, 256, device=device)
204
+
205
+ with torch.no_grad():
206
+ _ = model(x, t, context)
207
+
208
+ # 记录峰值内存
209
+ peak_memory = torch.cuda.max_memory_allocated()
210
+ memory_used = (peak_memory - initial_memory) / 1024**3 # GB
211
+
212
+ print(f" 批次大小 {batch_size}: {memory_used:.2f} GB")
213
+
214
+ else:
215
+ print(" GPU不可用,跳过内存测试")
216
+
217
+ return model
218
+
219
+
220
+ def run_all_tests():
221
+ """运行所有测试"""
222
+ print("=" * 60)
223
+ print("运行Lumina基础测试")
224
+ print("=" * 60)
225
+
226
+ try:
227
+ # 测试各个组件
228
+ test_timestep_embedding()
229
+ test_resnet_block()
230
+ test_attention()
231
+ test_unet_light()
232
+ test_diffusion_process()
233
+ test_memory_efficiency()
234
+
235
+ print("\n" + "=" * 60)
236
+ print("所有测试通过!")
237
+ print("=" * 60)
238
+
239
+ return True
240
+
241
+ except Exception as e:
242
+ print(f"\n测试失败: {e}")
243
+ import traceback
244
+ traceback.print_exc()
245
+ return False
246
+
247
+
248
+ if __name__ == "__main__":
249
+ success = run_all_tests()
250
+ sys.exit(0 if success else 1)