import csv
import gc
import math
import os
from typing import Optional, Dict, Any

import numpy as np
import torch
import torch.nn as nn
import wandb
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from tqdm import tqdm


class MemoryManager:
    """GPU memory manager for P4 training runs.

    Reads three optional keys from ``config``:

    * ``warning_threshold_gb`` (default 6.0): allocated memory above which a
      warning is printed.
    * ``memory_threshold_gb`` (default 6.5): allocated memory above which
      ``check_memory`` raises.
    * ``cleanup_frequency`` (default 100): run ``cleanup`` whenever the step
      passed to ``check_memory`` is a multiple of this value. A value that is
      zero, negative or ``None`` disables the periodic cleanup.
    """

    def __init__(self, config: dict):
        self.config = config
        bytes_per_gib = 1024 ** 3
        # Thresholds are stored in bytes so they compare directly against
        # torch.cuda.memory_allocated().
        self.warning_threshold = config.get('warning_threshold_gb', 6.0) * bytes_per_gib
        self.critical_threshold = config.get('memory_threshold_gb', 6.5) * bytes_per_gib
        self.cleanup_frequency = config.get('cleanup_frequency', 100)

    def check_memory(self, step: int):
        """Run the periodic cleanup and compare allocated memory to the thresholds.

        Args:
            step: Current step counter; used only to decide whether the
                periodic cleanup is due.

        Raises:
            RuntimeError: If the allocated CUDA memory exceeds the critical
                threshold.
        """
        frequency = self.cleanup_frequency
        # Guard the modulo: a frequency of 0 would raise ZeroDivisionError.
        if frequency and frequency > 0 and step % frequency == 0:
            self.cleanup()

        if not torch.cuda.is_available():
            # Nothing to measure on a CPU-only host.
            return

        allocated = torch.cuda.memory_allocated()
        if allocated > self.critical_threshold:
            raise RuntimeError(f"显存超出临界阈值: {allocated / 1024**3:.2f} GB")
        if allocated > self.warning_threshold:
            print(f"警告: 显存使用较高: {allocated / 1024**3:.2f} GB")

    def cleanup(self):
        """Collect Python garbage, then release cached CUDA blocks."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def auto_adjust_batch_size(
        self,
        model: nn.Module,
        data_shape: tuple,
        candidate_batch_sizes: tuple = (1, 2, 4, 8),
        context_shape: tuple = (77, 768),
        num_timesteps: int = 1000,
    ) -> int:
        """Probe which candidate batch size still fits in memory.

        Runs ``model(x, timestep, context)`` on random inputs for each
        candidate in order and stops at the first out-of-memory error.

        The probe is forward-only under ``torch.no_grad()``, so it does not
        account for the activations and gradients kept during training.
        NOTE(review): treat the result as an upper bound for training.

        Args:
            model: Module called as ``model(x, timestep, context)``.
            data_shape: Per-sample shape of ``x`` (without the batch dim).
            candidate_batch_sizes: Batch sizes to try, in the given order.
            context_shape: Per-sample shape of the random context tensor.
            num_timesteps: Exclusive upper bound for the random timesteps.

        Returns:
            The last candidate whose forward pass succeeded, or 1 if none did.

        Raises:
            RuntimeError: Any runtime error that is not an out-of-memory error.
        """
        max_batch = 1
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

        # Probe in eval mode so random inputs cannot update BatchNorm running
        # statistics; the caller's mode is restored afterwards.
        was_training = model.training
        model.eval()
        try:
            for batch_size in candidate_batch_sizes:
                dummy_input = dummy_timestep = dummy_context = None
                try:
                    dummy_input = torch.randn(batch_size, *data_shape, device=device)
                    dummy_timestep = torch.randint(0, num_timesteps, (batch_size,), device=device)
                    dummy_context = torch.randn(batch_size, *context_shape, device=device)
                    with torch.no_grad():
                        model(dummy_input, dummy_timestep, dummy_context)
                except RuntimeError as err:
                    if "out of memory" not in str(err).lower():
                        raise
                    break
                else:
                    max_batch = batch_size
                finally:
                    # Drop the probe tensors before releasing the cache so an
                    # OOM attempt does not leave memory pinned.
                    dummy_input = dummy_timestep = dummy_context = None
                    self.cleanup()
        finally:
            model.train(was_training)

        return max_batch


class GradientAccumulationScheduler:
    """Gradient accumulation scheduler.

    Keeps the configured number of accumulation steps during the warm-up
    epochs, then halves it once per epoch down to a floor.

    Config keys read:

    * ``gradient_accumulation_steps`` (default 8): initial step count,
      clamped to at least 1.
    * ``min_gradient_accumulation_steps`` (default 4): floor for the decay.
      The floor never exceeds the initial count, so a small configured value
      is not silently increased after warm-up.
    * ``warmup_epochs`` (default 5): epochs during which the initial count is
      kept unchanged.
    """

    def __init__(self, config: dict):
        self.initial_steps = max(1, int(config.get('gradient_accumulation_steps', 8)))
        floor = int(config.get('min_gradient_accumulation_steps', 4))
        self.min_steps = max(1, min(floor, self.initial_steps))
        self.current_steps = self.initial_steps
        self.warmup_epochs = config.get('warmup_epochs', 5)

    def update(self, epoch: int):
        """Set ``current_steps`` for the given epoch.

        The value depends only on ``epoch`` (not on how often ``update`` was
        called before), so resuming at a late epoch yields the same count as
        an uninterrupted run.

        Args:
            epoch: Zero-based epoch index.
        """
        if epoch < self.warmup_epochs:
            self.current_steps = self.initial_steps
            return

        # One halving for the first post-warm-up epoch, two for the next, ...
        halvings = int(epoch - self.warmup_epochs) + 1
        self.current_steps = max(self.min_steps, self.initial_steps >> halvings)


class P4Trainer:
    """Trainer optimised for the P4.

    Drives a diffusion training loop with optional mixed precision, gradient
    accumulation, gradient clipping, an optional EMA copy of the weights,
    periodic sample generation, CSV / Weights & Biases logging and
    checkpointing.

    Config keys read by this class (defaults in parentheses):
    ``mixed_precision`` ('fp16'; 'no' disables AMP), ``use_ema`` (True),
    ``ema_decay`` (0.9999), ``use_wandb`` (False), ``log_dir`` ('./logs'),
    ``checkpoint_dir`` ('./checkpoints'), ``learning_rate_scheduler``
    ('cosine'), ``warmup_steps`` (1000), ``max_epochs`` (50),
    ``gradient_clip`` (1.0), ``log_steps`` (50), ``sample_steps`` (500),
    ``save_checkpoint_every`` (5), ``save_compressed`` (True).
    """

    def __init__(
        self,
        model: nn.Module,
        # String annotation: DiffusionProcess is not imported in this module,
        # so a bare name would raise NameError when the class is defined.
        diffusion: "DiffusionProcess",
        optimizer: torch.optim.Optimizer,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        config: dict,
        device: torch.device
    ):
        """Store the collaborators and build the training helpers.

        Args:
            model: Network being trained.
            diffusion: Object providing ``compute_loss(images, text_embeddings)``
                and ``generate(context=..., num_samples=..., guidance_scale=...)``.
            optimizer: Optimizer over ``model``'s parameters.
            train_loader: Yields dicts with ``'images'`` and
                ``'text_embeddings'`` tensors.
            val_loader: Same batch format as ``train_loader``; may be ``None``.
            config: Flat configuration dict (see the class docstring).
            device: Device the batches are moved to.
        """
        self.model = model
        self.diffusion = diffusion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device

        # Training state. ``current_epoch`` is the next epoch to run.
        self.current_epoch = 0
        self.global_step = 0
        self.best_loss = float('inf')

        # Memory watchdog and accumulation schedule.
        self.memory_manager = MemoryManager(config)
        self.grad_scheduler = GradientAccumulationScheduler(config)

        # Mixed precision. The CUDA autocast/GradScaler pair only applies to
        # CUDA devices, so AMP is switched off elsewhere.
        amp_requested = config.get('mixed_precision', 'fp16') != 'no'
        self.use_amp = amp_requested and torch.device(device).type == 'cuda'
        self.scaler = GradScaler(enabled=self.use_amp)

        self.lr_scheduler = self._create_lr_scheduler(config)

        # Exponential moving average of the weights.
        self.use_ema = config.get('use_ema', True)
        if self.use_ema:
            self.ema_model = self._create_ema_model(model, config.get('ema_decay', 0.9999))

        # Logging.
        self.use_wandb = config.get('use_wandb', False)
        self.log_dir = config.get('log_dir', './logs')
        os.makedirs(self.log_dir, exist_ok=True)

        # Checkpoints.
        self.checkpoint_dir = config.get('checkpoint_dir', './checkpoints')
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def _estimate_total_optimizer_steps(self, config: dict) -> int:
        """Count the optimizer steps of a full ``max_epochs`` run.

        The LR scheduler is stepped once per optimizer step, not once per
        batch, so its horizon has to be expressed in optimizer steps. A
        throwaway accumulation scheduler replays the per-epoch schedule.
        """
        batches_per_epoch = len(self.train_loader)
        replay = GradientAccumulationScheduler(config)
        total_steps = 0
        for epoch in range(config.get('max_epochs', 50)):
            replay.update(epoch)
            # ceil: a trailing partial accumulation window still triggers a step.
            total_steps += math.ceil(batches_per_epoch / max(1, replay.current_steps))
        return max(1, total_steps)

    def _create_lr_scheduler(self, config: dict):
        """Create the learning-rate scheduler.

        ``learning_rate_scheduler`` selects ``'cosine'`` (cosine annealing to
        1e-6 over the estimated number of optimizer steps), ``'linear'``
        (linear ramp from 1% of the base LR over ``warmup_steps``) or, for
        any other value, a constant factor of 1.0.
        """
        scheduler_type = config.get('learning_rate_scheduler', 'cosine')
        warmup_steps = config.get('warmup_steps', 1000)

        if scheduler_type == 'cosine':
            return torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=self._estimate_total_optimizer_steps(config),
                eta_min=1e-6
            )
        if scheduler_type == 'linear':
            return torch.optim.lr_scheduler.LinearLR(
                self.optimizer,
                start_factor=0.01,
                total_iters=warmup_steps
            )
        return torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1.0)

    def _create_ema_model(self, model: nn.Module, decay: float):
        """Create the EMA copy of ``model``.

        Args:
            model: Model whose parameters are averaged.
            decay: Weight given to the running average on each update.
        """
        from torch.optim.swa_utils import AveragedModel

        def ema_avg(averaged_param, current_param, num_averaged):
            # AveragedModel calls avg_fn with three positional arguments; the
            # third (num_averaged) is accepted but deliberately unused so it
            # cannot overwrite the decay.
            return decay * averaged_param + (1 - decay) * current_param

        return AveragedModel(model, device=self.device, avg_fn=ema_avg)

    def _optimizer_step(self) -> float:
        """Apply the accumulated gradients and advance the step-level state.

        Returns:
            The total gradient norm reported by ``clip_grad_norm_`` (measured
            on the unscaled gradients before clipping).
        """
        # Unscale first so the clipping threshold applies to true gradients.
        self.scaler.unscale_(self.optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            max_norm=self.config.get('gradient_clip', 1.0)
        )

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if self.use_ema:
            self.ema_model.update_parameters(self.model)

        self.lr_scheduler.step()
        self.global_step += 1
        return float(grad_norm)

    def train_epoch(self) -> float:
        """Train for one epoch.

        Returns:
            Mean per-batch loss over the epoch (0.0 for an empty loader).
        """
        self.model.train()
        total_loss = 0.0
        num_batches = len(self.train_loader)

        accumulation_steps = max(1, self.grad_scheduler.current_steps)
        log_steps = self.config.get('log_steps', 50)
        sample_steps = self.config.get('sample_steps', 500)
        self.optimizer.zero_grad()

        pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch}")
        for batch_idx, batch in enumerate(pbar):
            images = batch['images'].to(self.device)
            text_embeddings = batch['text_embeddings'].to(self.device)

            with autocast(enabled=self.use_amp):
                loss = self.diffusion.compute_loss(images, text_embeddings)
                # Scale so the accumulated gradient is an average over the window.
                loss = loss / accumulation_steps

            self.scaler.scale(loss).backward()

            # Undo the accumulation scaling for reporting.
            total_loss += loss.item() * accumulation_steps
            current_loss = total_loss / (batch_idx + 1)

            # Step at the end of every accumulation window, and also on the
            # last batch so a trailing partial window is not thrown away by
            # the next epoch's zero_grad(). That last update uses a gradient
            # averaged over the full window size, i.e. it is down-weighted.
            window_complete = (batch_idx + 1) % accumulation_steps == 0
            last_batch = (batch_idx + 1) == num_batches
            if window_complete or last_batch:
                grad_norm = self._optimizer_step()

                # Step-based hooks run once per optimizer step; testing them
                # per batch would fire repeatedly while global_step is still.
                if log_steps and self.global_step % log_steps == 0:
                    self._log_metrics({
                        'train/loss': current_loss,
                        'train/lr': self.optimizer.param_groups[0]['lr'],
                        'train/grad_norm': grad_norm,
                    })
                if sample_steps and self.global_step % sample_steps == 0:
                    self._generate_samples()

            pbar.set_postfix({
                'loss': f'{current_loss:.4f}',
                'lr': f'{self.optimizer.param_groups[0]["lr"]:.2e}'
            })

            self.memory_manager.check_memory(self.global_step)

        return total_loss / max(num_batches, 1)

    @torch.no_grad()
    def validate(self) -> float:
        """Compute the mean validation loss.

        Leaves the model in eval mode; ``train_epoch`` switches back to
        train mode when it starts.

        Returns:
            Mean loss over the validation batches, or ``inf`` when there is
            no validation loader or it yields no batches.
        """
        if self.val_loader is None:
            return float('inf')

        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        for batch in tqdm(self.val_loader, desc="Validation"):
            images = batch['images'].to(self.device)
            text_embeddings = batch['text_embeddings'].to(self.device)

            with autocast(enabled=self.use_amp):
                loss = self.diffusion.compute_loss(images, text_embeddings)

            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            return float('inf')

        val_loss = total_loss / num_batches
        self._log_metrics({'val/loss': val_loss})
        return val_loss

    def train(self, num_epochs: Optional[int] = None):
        """Run the training loop.

        Starts at ``self.current_epoch`` so a loaded checkpoint resumes where
        it left off.

        Args:
            num_epochs: Total number of epochs to reach; defaults to
                ``config['max_epochs']`` (50).
        """
        if num_epochs is None:
            num_epochs = self.config.get('max_epochs', 50)

        save_every = self.config.get('save_checkpoint_every', 5)

        for epoch in range(self.current_epoch, num_epochs):
            self.current_epoch = epoch

            self.grad_scheduler.update(epoch)
            train_loss = self.train_epoch()
            val_loss = self.validate()

            # The epoch is complete: checkpoints written below record the
            # next epoch to run, so resuming does not repeat this one.
            self.current_epoch = epoch + 1

            if val_loss < self.best_loss:
                self.best_loss = val_loss
                self.save_checkpoint('best_model.pt')

            if save_every and (epoch + 1) % save_every == 0:
                self.save_checkpoint(f'checkpoint_epoch_{epoch+1}.pt')

            self._log_metrics({
                'epoch/train_loss': train_loss,
                'epoch/val_loss': val_loss,
                'epoch': epoch
            })

            print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

    def _get_grad_norm(self) -> float:
        """Return the L2 norm over all parameter gradients currently set.

        Parameters without a gradient are skipped, so the result is 0.0 right
        after ``optimizer.zero_grad()``.
        """
        squared_sum = sum(
            (p.grad.detach().norm(2).item() ** 2
             for p in self.model.parameters() if p.grad is not None),
            0.0,
        )
        return squared_sum ** 0.5

    @torch.no_grad()
    def _generate_samples(self):
        """Generate a few samples for monitoring and log them to W&B.

        Does nothing when there is no validation loader or it is empty. The
        model's train/eval mode is restored afterwards because this runs in
        the middle of a training epoch.
        """
        if self.val_loader is None:
            return

        was_training = self.model.training
        self.model.eval()
        try:
            sample_batch = next(iter(self.val_loader), None)
            if sample_batch is None:
                return
            text_embeddings = sample_batch['text_embeddings'][:4].to(self.device)

            with autocast(enabled=self.use_amp):
                latents = self.diffusion.generate(
                    context=text_embeddings,
                    # Match the number of contexts actually available (4
                    # unless the validation batch is smaller).
                    # NOTE(review): assumes generate() expects one context
                    # row per sample; confirm against DiffusionProcess.
                    num_samples=text_embeddings.shape[0],
                    guidance_scale=7.5
                )

            if self.use_wandb:
                wandb.log({
                    # .float(): NumPy cannot represent every autocast dtype.
                    'samples/latents': wandb.Image(latents[0].float().cpu().numpy())
                })
        finally:
            self.model.train(was_training)

    def _log_metrics(self, metrics: Dict[str, Any]):
        """Record metrics to W&B (if enabled) and to ``training_log.csv``.

        The CSV uses one row per metric (``step,metric,value``) because the
        callers log different metric sets; a header is written whenever the
        file is new or empty.

        Args:
            metrics: Mapping from metric name to value.
        """
        if self.use_wandb:
            wandb.log(metrics)

        log_file = os.path.join(self.log_dir, 'training_log.csv')
        needs_header = not os.path.exists(log_file) or os.path.getsize(log_file) == 0
        with open(log_file, 'a', newline='', encoding='utf-8') as f:
            writer = csv.writer(f, lineterminator='\n')
            if needs_header:
                writer.writerow(['step', 'metric', 'value'])
            writer.writerows(
                [self.global_step, name, value] for name, value in metrics.items()
            )

    def save_checkpoint(self, filename: str):
        """Save a checkpoint into ``checkpoint_dir``.

        Stores model, optimizer, GradScaler and LR-scheduler state plus the
        training counters and the config; the EMA weights are added when EMA
        is enabled. When called from ``train`` the ``'epoch'`` entry is the
        number of completed epochs.

        Args:
            filename: File name inside ``checkpoint_dir``.
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),
            'lr_scheduler_state_dict': self.lr_scheduler.state_dict(),
            'best_loss': self.best_loss,
            'config': self.config
        }

        if self.use_ema:
            checkpoint['ema_model_state_dict'] = self.ema_model.state_dict()

        save_path = os.path.join(self.checkpoint_dir, filename)
        torch.save(checkpoint, save_path)

        if self.config.get('save_compressed', True):
            # Only touch the file extension; str.replace would also rewrite a
            # '.pt' occurring in a directory name.
            stem, ext = os.path.splitext(save_path)
            # NOTE(review): the zipfile format is already torch.save's
            # default, so this writes a second copy of the same size rather
            # than a compressed one. Confirm what 'save_compressed' should do.
            torch.save(checkpoint, f"{stem}_compressed{ext}", _use_new_zipfile_serialization=True)

        print(f"检查点已保存: {save_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Restore training state from a checkpoint file.

        Args:
            checkpoint_path: Path to a file written by ``save_checkpoint``.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        # Absent in checkpoints written before the scheduler state was saved.
        if 'lr_scheduler_state_dict' in checkpoint:
            self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])

        if self.use_ema and 'ema_model_state_dict' in checkpoint:
            self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])

        self.current_epoch = checkpoint['epoch']
        self.global_step = checkpoint['global_step']
        self.best_loss = checkpoint['best_loss']

        print(f"已加载检查点: {checkpoint_path}")