import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from typing import Optional, Dict, Any
import os
import numpy as np

# wandb is only needed when config['use_wandb'] is true, so treat it as optional.
try:
    import wandb
except ImportError:  # environment-dependent
    wandb = None

# tqdm only draws progress bars; fall back to a silent wrapper when it is missing.
try:
    from tqdm import tqdm
except ImportError:  # environment-dependent
    class tqdm:
        """Minimal stand-in for ``tqdm.tqdm``: iteration, ``len`` and ``set_postfix``."""

        def __init__(self, iterable=None, *args, **kwargs):
            self._iterable = iterable

        def __iter__(self):
            return iter(self._iterable)

        def __len__(self):
            return len(self._iterable)

        def set_postfix(self, *args, **kwargs):
            pass


class MemoryManager:
    """GPU memory manager for the P4: periodic cleanup plus threshold checks.

    Config keys (defaults): ``warning_threshold_gb`` (6.0),
    ``memory_threshold_gb`` (6.5), ``cleanup_frequency`` (100 steps; a value
    <= 0 disables periodic cleanup).
    """

    def __init__(self, config: dict):
        self.config = config
        self.warning_threshold = config.get('warning_threshold_gb', 6.0) * 1024**3
        self.critical_threshold = config.get('memory_threshold_gb', 6.5) * 1024**3
        self.cleanup_frequency = config.get('cleanup_frequency', 100)

    def check_memory(self, step: int):
        """Check allocated GPU memory, cleaning up every ``cleanup_frequency`` steps.

        Raises:
            RuntimeError: if ``torch.cuda.memory_allocated()`` exceeds the critical threshold.
        """
        # Guard the modulo so cleanup_frequency=0 disables cleanup instead of crashing.
        if self.cleanup_frequency > 0 and step % self.cleanup_frequency == 0:
            self.cleanup()

        allocated = torch.cuda.memory_allocated()
        if allocated > self.critical_threshold:
            raise RuntimeError(f"显存超出临界阈值: {allocated / 1024**3:.2f} GB")
        elif allocated > self.warning_threshold:
            print(f"警告: 显存使用较高: {allocated / 1024**3:.2f} GB")

    def cleanup(self):
        """Run Python garbage collection and release cached CUDA memory."""
        import gc
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def _is_oom(err: BaseException) -> bool:
        """Return True if ``err`` is a CUDA out-of-memory error."""
        oom_type = getattr(torch.cuda, 'OutOfMemoryError', ())
        return isinstance(err, oom_type) or "CUDA out of memory" in str(err)

    def auto_adjust_batch_size(
        self,
        model: nn.Module,
        data_shape: tuple,
        context_shape: tuple = (77, 768),
        candidate_sizes: tuple = (1, 2, 4, 8),
    ) -> int:
        """Return the largest candidate batch size whose dummy forward pass fits in memory.

        Calls ``model(x, timestep, context)`` under ``torch.no_grad()`` with random
        tensors: ``x`` of shape ``(batch, *data_shape)``, integer timesteps in
        ``[0, 1000)`` and ``context`` of shape ``(batch, *context_shape)``.
        Candidates are tried in order and probing stops at the first OOM.
        Returns 1 if even the first candidate fails with OOM.

        Raises:
            RuntimeError: any non-OOM error raised by the model is propagated.
        """
        max_batch = 1
        device = next(model.parameters()).device

        for batch_size in candidate_sizes:
            try:
                dummy_input = torch.randn(batch_size, *data_shape, device=device)
                dummy_timestep = torch.randint(0, 1000, (batch_size,), device=device)
                dummy_context = torch.randn(batch_size, *context_shape, device=device)
                with torch.no_grad():
                    _ = model(dummy_input, dummy_timestep, dummy_context)
            except RuntimeError as err:
                if not self._is_oom(err):
                    raise
                break
            # Drop the probe tensors before trying a larger size.
            del dummy_input, dummy_timestep, dummy_context, _
            torch.cuda.empty_cache()
            max_batch = batch_size

        # Also frees anything kept alive by a caught OOM traceback.
        self.cleanup()
        return max_batch


class GradientAccumulationScheduler:
    """Schedule for how many micro-batches are accumulated per optimizer step.

    For the first ``warmup_epochs`` epochs the configured
    ``gradient_accumulation_steps`` is used. Afterwards it is halved once per
    post-warmup epoch. It never drops below ``min_accumulation_steps``
    (default 4) and never rises above the configured value. The result depends
    only on ``epoch``, so resumed runs get the same value as uninterrupted ones.
    """

    def __init__(self, config: dict):
        self.initial_steps = config.get('gradient_accumulation_steps', 8)
        if self.initial_steps < 1:
            raise ValueError(
                f"gradient_accumulation_steps must be >= 1, got {self.initial_steps}"
            )
        self.current_steps = self.initial_steps
        self.warmup_epochs = config.get('warmup_epochs', 5)
        self.min_steps = config.get('min_accumulation_steps', 4)

    def update(self, epoch: int):
        """Set ``current_steps`` for ``epoch`` (0-based)."""
        if epoch < self.warmup_epochs:
            self.current_steps = self.initial_steps
            return
        # Halve gradually after warmup to speed up training.
        halvings = epoch - self.warmup_epochs + 1
        # The floor never exceeds the configured value, so the schedule only decreases.
        floor = max(1, min(self.min_steps, self.initial_steps))
        self.current_steps = max(floor, self.initial_steps // (2 ** halvings))
class P4Trainer:
    """Diffusion-model trainer tuned for low-memory GPUs such as the NVIDIA P4.

    Features:
        - mixed precision (``GradScaler`` + ``autocast``);
        - gradient accumulation with clipping;
        - optional EMA weights;
        - an LR scheduler stepped once per optimizer step;
        - periodic sampling;
        - CSV (and optional wandb) logging;
        - checkpointing with resume support.
    """

    def __init__(
        self,
        model: nn.Module,
        diffusion: 'DiffusionProcess',
        optimizer: torch.optim.Optimizer,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        config: dict,
        device: torch.device
    ):
        """Set up training state, AMP, LR scheduler, EMA and output directories.

        Args:
            model: Network being trained.
            diffusion: Object providing ``compute_loss(images, text_embeddings)`` and
                ``generate(context=..., num_samples=..., guidance_scale=...)``.
            optimizer: Optimizer over ``model``'s parameters.
            train_loader: Yields dicts with ``'images'`` and ``'text_embeddings'`` tensors.
            val_loader: Same format as ``train_loader``, or ``None`` to skip validation.
            config: Flat dict of options, read via ``config.get`` with defaults.
            device: Device batches are moved to.
        """
        self.model = model
        self.diffusion = diffusion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device

        # Training state. global_step counts optimizer steps, not batches.
        self.current_epoch = 0
        self.global_step = 0
        self.best_loss = float('inf')
        # True once self.current_epoch has fully finished (train + validate).
        # A checkpoint saved afterwards then resumes at the next epoch.
        self._epoch_finished = False

        self.memory_manager = MemoryManager(config)
        self.grad_scheduler = GradientAccumulationScheduler(config)

        # Any value other than 'no' enables autocast and loss scaling.
        self.use_amp = config.get('mixed_precision', 'fp16') != 'no'
        self.scaler = GradScaler(enabled=self.use_amp)

        self.lr_scheduler = self._create_lr_scheduler(config)

        self.use_ema = config.get('use_ema', True)
        if self.use_ema:
            self.ema_model = self._create_ema_model(model, config.get('ema_decay', 0.9999))

        self.use_wandb = config.get('use_wandb', False)
        if self.use_wandb and wandb is None:
            # Fail early instead of with an AttributeError at the first log call.
            raise ImportError("use_wandb is enabled but the 'wandb' package is not available")
        self.log_dir = config.get('log_dir', './logs')
        os.makedirs(self.log_dir, exist_ok=True)

        self.checkpoint_dir = config.get('checkpoint_dir', './checkpoints')
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def _create_lr_scheduler(self, config: dict):
        """Build the LR scheduler named by ``config['learning_rate_scheduler']``.

        - ``'cosine'`` (default): cosine decay to 1e-6 over all optimizer steps.
        - ``'linear'``: ramp from 1% of the LR over ``warmup_steps`` steps.
        - anything else: constant LR.
        """
        scheduler_type = config.get('learning_rate_scheduler', 'cosine')
        warmup_steps = config.get('warmup_steps', 1000)

        if scheduler_type == 'cosine':
            return torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=self._estimate_total_optimizer_steps(config),
                eta_min=1e-6,
            )
        if scheduler_type == 'linear':
            return torch.optim.lr_scheduler.LinearLR(
                self.optimizer,
                start_factor=0.01,
                total_iters=warmup_steps,
            )
        return torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1.0)

    def _estimate_total_optimizer_steps(self, config: dict) -> int:
        """Count the optimizer steps ``max_epochs`` epochs will take.

        The scheduler is stepped once per optimizer step, not once per batch,
        so its horizon must follow the gradient-accumulation schedule.
        """
        batches_per_epoch = len(self.train_loader)
        # Fresh scheduler instance so self.grad_scheduler's state is untouched.
        schedule = GradientAccumulationScheduler(config)
        total = 0
        for epoch in range(config.get('max_epochs', 50)):
            schedule.update(epoch)
            # Ceiling division: a trailing partial group also takes a step.
            total += -(-batches_per_epoch // schedule.current_steps)
        return max(1, int(total))

    def _create_ema_model(self, model: nn.Module, decay: float):
        """Return an ``AveragedModel`` copy of ``model`` updated as an EMA with ``decay``."""
        from torch.optim.swa_utils import AveragedModel

        def ema_avg(averaged_param, model_param, num_averaged):
            # AveragedModel passes num_averaged as the third positional argument.
            # It must not shadow `decay`, and EMA does not use it.
            return decay * averaged_param + (1 - decay) * model_param

        return AveragedModel(model, device=self.device, avg_fn=ema_avg)

    def _optimizer_step(self) -> float:
        """Apply accumulated gradients: unscale, clip, step, update EMA and LR.

        Returns:
            Total gradient norm before clipping, as reported by ``clip_grad_norm_``.
        """
        self.scaler.unscale_(self.optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            max_norm=self.config.get('gradient_clip', 1.0)
        )
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if self.use_ema:
            self.ema_model.update_parameters(self.model)

        self.lr_scheduler.step()
        return float(grad_norm)

    def train_epoch(self) -> float:
        """Train for one epoch and return the mean (unscaled) batch loss.

        Gradients are accumulated over ``self.grad_scheduler.current_steps``
        batches per optimizer step. The last, possibly shorter, group of the
        epoch also steps, so no gradients are discarded. Logging, sampling and
        memory checks run once per optimizer step.

        Returns:
            Mean loss over the epoch's batches, or ``nan`` for an empty loader.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = len(self.train_loader)

        accumulation_steps = self.grad_scheduler.current_steps
        # First batch index of the trailing group, which may be shorter.
        last_group_start = (num_batches // accumulation_steps) * accumulation_steps
        log_steps = self.config.get('log_steps', 50)
        sample_steps = self.config.get('sample_steps', 500)

        self.optimizer.zero_grad()

        pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch}")
        for batch_idx, batch in enumerate(pbar):
            images = batch['images'].to(self.device)
            text_embeddings = batch['text_embeddings'].to(self.device)

            # Divide by the real size of this batch's group so the accumulated
            # gradient is a mean over the group, including the short last group.
            if batch_idx < last_group_start:
                group_size = accumulation_steps
            else:
                group_size = num_batches - last_group_start

            with autocast(enabled=self.use_amp):
                loss = self.diffusion.compute_loss(images, text_embeddings)
                loss = loss / group_size

            self.scaler.scale(loss).backward()

            total_loss += loss.item() * group_size
            current_loss = total_loss / (batch_idx + 1)

            is_group_end = (
                (batch_idx + 1) % accumulation_steps == 0
                or batch_idx + 1 == num_batches
            )
            if is_group_end:
                grad_norm = self._optimizer_step()
                self.global_step += 1

                # Per-step work: runs exactly once per optimizer step.
                if self.global_step % log_steps == 0:
                    self._log_metrics({
                        'train/loss': current_loss,
                        'train/lr': self.optimizer.param_groups[0]['lr'],
                        'train/grad_norm': grad_norm,
                    })
                if self.global_step % sample_steps == 0:
                    self._generate_samples()
                self.memory_manager.check_memory(self.global_step)

            pbar.set_postfix({
                'loss': f'{current_loss:.4f}',
                'lr': f'{self.optimizer.param_groups[0]["lr"]:.2e}'
            })

        if num_batches == 0:
            return float('nan')
        return total_loss / num_batches

    @torch.no_grad()
    def validate(self) -> float:
        """Return the mean validation loss (``inf`` without validation data) and log it."""
        if self.val_loader is None:
            return float('inf')

        was_training = self.model.training
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        try:
            for batch in tqdm(self.val_loader, desc="Validation"):
                images = batch['images'].to(self.device)
                text_embeddings = batch['text_embeddings'].to(self.device)

                with autocast(enabled=self.use_amp):
                    loss = self.diffusion.compute_loss(images, text_embeddings)

                total_loss += loss.item()
                num_batches += 1
        finally:
            self.model.train(was_training)

        if num_batches == 0:
            return float('inf')
        val_loss = total_loss / num_batches
        self._log_metrics({'val/loss': val_loss})
        return val_loss

    def train(self, num_epochs: Optional[int] = None):
        """Train from ``self.current_epoch`` up to ``num_epochs`` (default ``max_epochs``).

        Each epoch: update the accumulation schedule, train, validate, save
        ``best_model.pt`` on improvement, and save a periodic checkpoint every
        ``save_checkpoint_every`` epochs.
        """
        if num_epochs is None:
            num_epochs = self.config.get('max_epochs', 50)

        for epoch in range(self.current_epoch, num_epochs):
            self.current_epoch = epoch
            self._epoch_finished = False

            self.grad_scheduler.update(epoch)
            train_loss = self.train_epoch()
            val_loss = self.validate()
            self._epoch_finished = True

            # Without a validation set validate() returns inf, which never beats
            # best_loss. Track the training loss instead so a best model is kept.
            monitored_loss = val_loss if self.val_loader is not None else train_loss
            if monitored_loss < self.best_loss:
                self.best_loss = monitored_loss
                self.save_checkpoint('best_model.pt')

            if (epoch + 1) % self.config.get('save_checkpoint_every', 5) == 0:
                self.save_checkpoint(f'checkpoint_epoch_{epoch+1}.pt')

            self._log_metrics({
                'epoch/train_loss': train_loss,
                'epoch/val_loss': val_loss,
                'epoch': epoch
            })

            print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

    def _get_grad_norm(self) -> float:
        """Return the global L2 norm of the model's current ``.grad`` tensors.

        Parameters without a gradient are skipped. After ``optimizer.zero_grad()``
        this is typically 0.0, because gradients are reset.
        """
        total_norm = 0.0
        for p in self.model.parameters():
            if p.grad is not None:
                param_norm = p.grad.detach().norm(2)
                total_norm += param_norm.item() ** 2
        return total_norm ** 0.5

    @torch.no_grad()
    def _generate_samples(self):
        """Generate latents from up to 4 validation prompts and log the first to wandb.

        Does nothing without a validation loader. Restores the model's training mode afterwards.
        """
        if self.val_loader is None:
            return

        was_training = self.model.training
        self.model.eval()
        try:
            sample_batch = next(iter(self.val_loader))
            text_embeddings = sample_batch['text_embeddings'][:4].to(self.device)

            with autocast(enabled=self.use_amp):
                latents = self.diffusion.generate(
                    context=text_embeddings,
                    num_samples=4,
                    guidance_scale=7.5
                )

            # Decoding to images would need a VAE decoder; log the raw latent for now.
            if self.use_wandb:
                wandb.log({
                    'samples/latents': wandb.Image(latents[0].cpu().numpy())
                })
        finally:
            self.model.train(was_training)

    def _log_metrics(self, metrics: Dict[str, Any]):
        """Log ``metrics`` to wandb (if enabled) and append them to ``training_log.csv``.

        The CSV uses long format (``step,metric,value``, one row per metric).
        Calls with different metric sets therefore share one well-formed file.
        The header is written only when the file is new or empty.
        """
        import csv

        if self.use_wandb:
            wandb.log(metrics)

        log_file = os.path.join(self.log_dir, 'training_log.csv')
        write_header = not os.path.exists(log_file) or os.path.getsize(log_file) == 0
        with open(log_file, 'a', newline='', encoding='utf-8') as f:
            writer = csv.writer(f, lineterminator='\n')
            if write_header:
                writer.writerow(['step', 'metric', 'value'])
            writer.writerows(
                [self.global_step, name, value] for name, value in metrics.items()
            )

    def save_checkpoint(self, filename: str):
        """Save model, optimizer, scaler, LR-scheduler and EMA state to ``checkpoint_dir/filename``."""
        checkpoint = {
            'epoch': self.current_epoch,
            # Epoch to resume from: the next one if current_epoch already finished.
            'next_epoch': self.current_epoch + 1 if self._epoch_finished else self.current_epoch,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),
            'lr_scheduler_state_dict': self.lr_scheduler.state_dict(),
            'grad_accumulation_steps': self.grad_scheduler.current_steps,
            'best_loss': self.best_loss,
            'config': self.config
        }

        if self.use_ema:
            checkpoint['ema_model_state_dict'] = self.ema_model.state_dict()

        save_path = os.path.join(self.checkpoint_dir, filename)
        torch.save(checkpoint, save_path)

        if self.config.get('save_compressed', True):
            # NOTE(review): this writes an identical copy. torch.save already uses
            # the zipfile format by default and does not compress tensor data.
            # splitext only touches the extension, not '.pt' inside directory names.
            root, ext = os.path.splitext(save_path)
            torch.save(checkpoint, f"{root}_compressed{ext}")

        print(f"检查点已保存: {save_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Restore training state saved by ``save_checkpoint`` (older checkpoints are accepted)."""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        if self.use_ema and 'ema_model_state_dict' in checkpoint:
            self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
        if 'lr_scheduler_state_dict' in checkpoint:
            self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
        if 'grad_accumulation_steps' in checkpoint:
            self.grad_scheduler.current_steps = checkpoint['grad_accumulation_steps']

        # Older checkpoints lack 'next_epoch'; keep their original behaviour.
        self.current_epoch = checkpoint.get('next_epoch', checkpoint['epoch'])
        self._epoch_finished = False
        self.global_step = checkpoint['global_step']
        self.best_loss = checkpoint['best_loss']

        print(f"已加载检查点: {checkpoint_path}")