#!/usr/bin/env python3
"""Lumina training script.

Trains the lightweight image-generation model (``UNetLight`` wrapped in a
diffusion process) with the P4 trainer. All settings come from a YAML
training config, which in turn names the model / diffusion / data configs.
"""

import argparse
import os
import sys
import traceback
import warnings
from typing import Optional, Union

import torch
import torch.nn as nn
import yaml
from torch.utils.data import DataLoader

# Make the project root importable. Resolve __file__ to an absolute path first:
# when the script is launched as ``python train.py`` from its own directory,
# ``__file__`` is a bare filename and a double ``dirname`` would yield ``''``
# (i.e. the current directory) instead of the project root.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.models.unet_light import UNetLight  # noqa: E402
from src.models.diffusion import DiffusionProcess, DiffusionModel  # noqa: E402
from src.data.dataset import create_data_loaders  # noqa: E402
from src.data.text_encoder import create_text_encoder  # noqa: E402
from src.training.trainer_p4 import P4Trainer  # noqa: E402
from src.training.memory_manager import MemoryOptimizer  # noqa: E402
from src.training.callbacks import create_default_callbacks  # noqa: E402


def load_config(config_path: str) -> dict:
    """Load a YAML configuration file.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The document parsed by ``yaml.safe_load``. An empty file yields ``{}``
        rather than ``None`` so callers can use ``.get`` unconditionally.
    """
    # Explicit UTF-8 so configs containing non-ASCII text load identically
    # regardless of the platform's locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config if config is not None else {}


def setup_environment(config: dict) -> str:
    """Seed RNGs, pick the device and create output/log directories.

    Reads ``seed`` (default 42), ``device``, ``output_dir`` (default
    ``./output``) and ``log_dir`` (default ``./logs``) from ``config``.

    Returns:
        The device string to train on. It falls back to ``'cpu'`` (with a
        warning) when a CUDA device is requested but CUDA is unavailable.
    """
    seed = config.get('seed', 42)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seed every visible GPU, not just the current one.
        torch.cuda.manual_seed_all(seed)

    device = config.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
    # startswith() also covers indexed devices such as 'cuda:0'.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        warnings.warn("CUDA不可用,使用CPU")
        device = 'cpu'

    output_dir = config.get('output_dir', './output')
    os.makedirs(output_dir, exist_ok=True)

    log_dir = config.get('log_dir', './logs')
    os.makedirs(log_dir, exist_ok=True)

    print("环境设置完成:")
    print(f" 设备: {device}")
    print(f" 随机种子: {seed}")
    print(f" 输出目录: {output_dir}")
    print(f" 日志目录: {log_dir}")

    return device


def create_model(config: dict, device: Union[str, torch.device]) -> nn.Module:
    """Build the UNet, optionally load pretrained weights, and move it to ``device``.

    Reads ``model_config`` (path to the model YAML) and the optional
    ``pretrained_path`` from ``config``.

    Returns:
        The ``UNetLight`` model placed on ``device``.
    """
    model_config_path = config.get('model_config', 'configs/model/unet_light.yaml')
    model_config = load_config(model_config_path)

    model = UNetLight(model_config)

    pretrained_path = config.get('pretrained_path')
    if pretrained_path:
        if os.path.exists(pretrained_path):
            print(f"加载预训练权重: {pretrained_path}")
            # Map to CPU so GPU-saved checkpoints load on any host; the model is
            # moved to the target device below. torch.load unpickles the file,
            # so only load trusted checkpoints.
            checkpoint = torch.load(pretrained_path, map_location='cpu')
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            # A configured-but-missing path is almost always a mistake; do not
            # silently train from scratch.
            warnings.warn(f"预训练权重文件不存在,将不加载: {pretrained_path}")

    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("模型创建完成:")
    print(f" 总参数量: {total_params:,}")
    print(f" 可训练参数量: {trainable_params:,}")
    # Size estimate assumes 4 bytes per parameter (fp32).
    print(f" 模型大小: {total_params * 4 / 1024**2:.2f} MB (fp32)")

    return model


def create_diffusion(config: dict) -> DiffusionProcess:
    """Build the diffusion process from the YAML named by ``diffusion_config``."""
    diffusion_config_path = config.get('diffusion_config', 'configs/model/diffusion.yaml')
    diffusion_config = load_config(diffusion_config_path)

    diffusion = DiffusionProcess(diffusion_config)

    print("扩散过程创建完成:")
    print(f" 训练时间步: {diffusion.num_train_timesteps}")
    print(f" 推理时间步: {diffusion.num_inference_timesteps}")
    print(f" Beta调度: {diffusion.beta_schedule}")

    return diffusion


def create_data_pipeline(config: dict):
    """Build the text encoder and the train/validation data loaders.

    Reads ``data_config`` (path to the data YAML) from ``config``.

    Returns:
        ``(train_loader, val_loader, text_encoder)``. ``val_loader`` may be
        falsy, in which case a validation size of 0 is reported.
    """
    data_config_path = config.get('data_config', 'configs/data/laion_filtered.yaml')
    data_config = load_config(data_config_path)

    text_encoder = create_text_encoder(data_config)
    train_loader, val_loader = create_data_loaders(data_config)

    print("数据管道创建完成:")
    print(f" 训练集大小: {len(train_loader.dataset)}")
    print(f" 验证集大小: {len(val_loader.dataset) if val_loader else 0}")
    print(f" 批次大小: {train_loader.batch_size}")
    # Only reported here; read from the top-level training config.
    print(f" 梯度累积步数: {config.get('gradient_accumulation_steps', 8)}")

    return train_loader, val_loader, text_encoder


def create_optimizer(model: nn.Module, config: dict):
    """Create the optimizer described by ``config['optimizer']``.

    Supported ``type`` values are ``'AdamW'`` (default) and ``'Adam'``.
    ``learning_rate`` defaults to 1e-4 and ``weight_decay`` to 0.01.

    Raises:
        ValueError: if ``type`` is not one of the supported optimizers.
    """
    optimizer_config = config.get('optimizer', {})
    optimizer_type = optimizer_config.get('type', 'AdamW')
    learning_rate = optimizer_config.get('learning_rate', 1e-4)
    weight_decay = optimizer_config.get('weight_decay', 0.01)

    if optimizer_type == 'AdamW':
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.999),
            eps=1e-8,
        )
    elif optimizer_type == 'Adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
        )
    else:
        raise ValueError(f"未知的优化器类型: {optimizer_type}")

    print("优化器创建完成:")
    print(f" 类型: {optimizer_type}")
    print(f" 学习率: {learning_rate}")
    print(f" 权重衰减: {weight_decay}")

    return optimizer


def setup_memory_optimization(model: nn.Module, optimizer, config: dict):
    """Apply ``MemoryOptimizer`` settings to the model/optimizer and report GPU memory.

    Returns:
        The ``MemoryOptimizer`` instance.
    """
    memory_optimizer = MemoryOptimizer(config)
    memory_optimizer.setup_model_optimizations(model, optimizer)

    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        print("内存优化设置完成:")
        print(f" GPU已分配: {allocated:.2f} GB")
        print(f" GPU已保留: {reserved:.2f} GB")

    return memory_optimizer


def train(config_path: str, resume_from: Optional[str] = None) -> bool:
    """Run a full training session.

    Args:
        config_path: Path to the top-level training YAML.
        resume_from: Optional checkpoint to resume from. A path that does not
            exist is reported with a warning and training starts fresh.

    Returns:
        True if training completed and the final checkpoint was saved. False
        if training was interrupted, raised an error, or the final save failed.
        Errors are printed rather than propagated.

    Side effects:
        A final checkpoint is always attempted at
        ``<checkpoint_dir>/final_model.pt``, even after an interrupt or error.
    """
    print("=" * 60)
    print("Lumina 训练开始")
    print("=" * 60)

    config = load_config(config_path)
    device = setup_environment(config)

    model = create_model(config, device)
    diffusion = create_diffusion(config)
    # NOTE(review): diffusion_model, text_encoder, memory_optimizer and
    # callbacks are built here but not passed to P4Trainer below. Confirm
    # whether the trainer is expected to receive them.
    diffusion_model = DiffusionModel(model, diffusion)

    train_loader, val_loader, text_encoder = create_data_pipeline(config)
    optimizer = create_optimizer(model, config)
    memory_optimizer = setup_memory_optimization(model, optimizer, config)

    trainer = P4Trainer(
        model=model,
        diffusion=diffusion,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        device=device,
    )

    callbacks = create_default_callbacks(config)

    if resume_from:
        if os.path.exists(resume_from):
            print(f"从检查点恢复训练: {resume_from}")
            trainer.load_checkpoint(resume_from)
        else:
            # The user explicitly asked to resume; make the fallback visible.
            warnings.warn(f"检查点文件不存在,将从头开始训练: {resume_from}")

    succeeded = False
    try:
        print("\n开始训练...")
        trainer.train()

        print("\n" + "=" * 60)
        print("训练完成!")
        print(f"最佳验证损失: {trainer.best_loss:.4f}")
        print(f"总训练步数: {trainer.global_step}")
        print("=" * 60)
        succeeded = True
    except KeyboardInterrupt:
        print("\n训练被中断")
    except Exception as e:
        print(f"\n训练出错: {e}")
        traceback.print_exc()
    finally:
        # Always try to persist the final state, even after an interrupt/error.
        checkpoint_dir = config.get('checkpoint_dir', './checkpoints')
        final_checkpoint = os.path.join(checkpoint_dir, 'final_model.pt')
        try:
            os.makedirs(checkpoint_dir, exist_ok=True)
            trainer.save_checkpoint(final_checkpoint)
        except Exception as e:
            # Do not let a failed save mask the training outcome or escape
            # uncaught from the finally clause.
            print(f"\n保存最终检查点失败: {e}")
            traceback.print_exc()
            succeeded = False

    return succeeded


def main():
    """Parse command-line arguments and run training.

    Exits with status 1 when training does not complete successfully, so
    job schedulers and CI can detect failed or interrupted runs.
    """
    parser = argparse.ArgumentParser(description="训练Lumina图像生成模型")
    parser.add_argument(
        "--config",
        type=str,
        default="configs/training/p4_optimized.yaml",
        help="训练配置文件路径",
    )
    parser.add_argument(
        "--resume",
        type=str,
        help="从检查点恢复训练",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="调试模式",
    )

    args = parser.parse_args()

    if args.debug:
        # Show every warning occurrence and enable autograd anomaly detection
        # (slow; intended for debugging only).
        warnings.filterwarnings("always")
        torch.autograd.set_detect_anomaly(True)
        print("调试模式已启用")

    if not train(args.config, args.resume):
        sys.exit(1)


if __name__ == "__main__":
    main()