import os
from datetime import datetime
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image


def _require_positive(name: str, value: int) -> int:
    """Return *value* unchanged, or raise ValueError if it is smaller than 1.

    Frequencies in this module are used as the right-hand side of ``%``;
    rejecting zero here gives a clear error at construction time instead of
    a ZeroDivisionError in the middle of training.
    """
    if value < 1:
        raise ValueError(f"{name} must be a positive integer, got {value!r}")
    return value


class Callback:
    """Base class for training callbacks.

    Every hook is a no-op; subclasses override only the hooks they need.
    """

    def on_train_begin(self, trainer):
        """Hook invoked when training starts."""

    def on_train_end(self, trainer):
        """Hook invoked when training finishes."""

    def on_epoch_begin(self, trainer, epoch):
        """Hook invoked at the start of an epoch."""

    def on_epoch_end(self, trainer, epoch, train_loss, val_loss):
        """Hook invoked at the end of an epoch with the epoch losses."""

    def on_batch_begin(self, trainer, batch_idx, batch):
        """Hook invoked before a batch is processed."""

    def on_batch_end(self, trainer, batch_idx, batch, loss):
        """Hook invoked after a batch is processed with the batch loss."""

    def on_validation_begin(self, trainer):
        """Hook invoked when validation starts."""

    def on_validation_end(self, trainer, val_loss):
        """Hook invoked when validation finishes with the validation loss."""


class EarlyStopping(Callback):
    """Early-stopping callback.

    Sets ``should_stop`` to True once the validation loss has failed to
    improve by more than ``min_delta`` for ``patience`` consecutive
    validations. The flag is only set here; whoever drives training is
    responsible for reading it.
    """

    def __init__(self, patience: int = 10, min_delta: float = 1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0
        self.should_stop = False

    def on_validation_end(self, trainer, val_loss):
        """Update the patience counter from the latest validation loss."""
        if val_loss is None:
            # No validation result to compare; leave the state untouched
            # rather than failing on a None comparison.
            return

        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.should_stop = True
            print(f"早停触发,最佳损失: {self.best_loss:.4f}")


class ModelCheckpoint(Callback):
    """Checkpoint-saving callback.

    Args:
        save_dir: Directory the checkpoints are written to (created if
            missing).
        save_best_only: When True, save only when the monitored value
            improves, always to ``best_model.pt``; otherwise save to
            ``checkpoint_epoch_{epoch}.pt`` on every eligible epoch.
        save_freq: Save only on epochs where ``epoch % save_freq == 0``.
        monitor: ``'train_loss'`` monitors the training loss; any other
            value (including the default ``'val_loss'``) monitors the
            validation loss.
        mode: ``'min'`` or ``'max'`` — the direction counted as improvement.

    Raises:
        ValueError: If ``mode`` is not ``'min'``/``'max'`` or ``save_freq``
            is smaller than 1.
    """

    def __init__(
        self,
        save_dir: str = './checkpoints',
        save_best_only: bool = True,
        save_freq: int = 1,
        monitor: str = 'val_loss',
        mode: str = 'min'
    ):
        if mode not in ('min', 'max'):
            # An unknown mode would otherwise make save_best_only never save.
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")

        self.save_dir = save_dir
        self.save_best_only = save_best_only
        self.save_freq = _require_positive('save_freq', save_freq)
        self.monitor = monitor
        self.mode = mode

        os.makedirs(save_dir, exist_ok=True)

        self.best_value = float('inf') if mode == 'min' else -float('inf')

    def _is_improvement(self, value) -> bool:
        """Return True when *value* beats the best value seen so far."""
        if value is None:
            return False
        if self.mode == 'min':
            return value < self.best_value
        return value > self.best_value

    def on_epoch_end(self, trainer, epoch, train_loss, val_loss):
        """Save a checkpoint for this epoch if the save policy allows it."""
        if epoch % self.save_freq != 0:
            return

        # Anything other than 'train_loss' falls back to the validation loss.
        value = train_loss if self.monitor == 'train_loss' else val_loss

        if self.save_best_only and not self._is_improvement(value):
            return

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': trainer.model.state_dict(),
            'optimizer_state_dict': trainer.optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
        }

        # Trainers without EMA support may not define use_ema at all.
        if getattr(trainer, 'use_ema', False):
            checkpoint['ema_model_state_dict'] = trainer.ema_model.state_dict()

        filename = f'checkpoint_epoch_{epoch}.pt' if not self.save_best_only else 'best_model.pt'
        save_path = os.path.join(self.save_dir, filename)
        torch.save(checkpoint, save_path)

        if self.save_best_only:
            # Record the new best only after the file was actually written,
            # so a failed save does not mask a later equally good epoch.
            self.best_value = value

        print(f"检查点已保存: {save_path}")


class LearningRateSchedulerCallback(Callback):
    """Learning-rate scheduler callback.

    Calls ``scheduler.step()`` (with no arguments) once per epoch or once
    per batch.

    Args:
        scheduler: Object exposing a ``step()`` method.
        update_on: ``'epoch'`` or ``'batch'``.

    Raises:
        ValueError: If ``update_on`` is neither ``'epoch'`` nor ``'batch'``;
            such a value would otherwise silently never step the scheduler.
    """

    def __init__(self, scheduler, update_on: str = 'epoch'):
        if update_on not in ('epoch', 'batch'):
            raise ValueError(
                f"update_on must be 'epoch' or 'batch', got {update_on!r}"
            )
        self.scheduler = scheduler
        self.update_on = update_on

    def on_epoch_end(self, trainer, epoch, train_loss, val_loss):
        """Step the scheduler when configured for per-epoch updates."""
        if self.update_on == 'epoch':
            self.scheduler.step()

    def on_batch_end(self, trainer, batch_idx, batch, loss):
        """Step the scheduler when configured for per-batch updates."""
        if self.update_on == 'batch':
            self.scheduler.step()


class TensorBoardLogger(Callback):
    """TensorBoard logger.

    Writes the per-batch loss and the first param group's learning rate,
    plus the per-epoch train/validation losses. ``SummaryWriter`` is
    imported inside ``__init__``, so the import only happens when a logger
    is actually constructed.
    """

    def __init__(self, log_dir: str = './logs'):
        from torch.utils.tensorboard import SummaryWriter

        self.writer = SummaryWriter(log_dir=log_dir)
        self.global_step = 0

    def on_batch_end(self, trainer, batch_idx, batch, loss):
        """Log batch loss and learning rate, then advance the step counter."""
        self.writer.add_scalar('train/loss', loss, self.global_step)
        self.writer.add_scalar(
            'train/lr',
            trainer.optimizer.param_groups[0]['lr'],
            self.global_step,
        )
        self.global_step += 1

    def on_epoch_end(self, trainer, epoch, train_loss, val_loss):
        """Log the epoch losses, skipping any that are None."""
        for tag, value in (
            ('epoch/train_loss', train_loss),
            ('epoch/val_loss', val_loss),
        ):
            if value is not None:
                self.writer.add_scalar(tag, value, epoch)

    def on_train_end(self, trainer):
        """Close the underlying writer."""
        self.writer.close()


class SampleGeneratorCallback(Callback):
    """Sample-generation callback.

    Every ``sample_freq`` global steps, generates latents from the first
    validation batch's ``'text_embeddings'`` and saves each one as a ``.pt``
    file under ``save_dir``.

    Raises:
        ValueError: If ``sample_freq`` is smaller than 1.
    """

    def __init__(
        self,
        sample_freq: int = 500,
        num_samples: int = 4,
        save_dir: str = './samples'
    ):
        self.sample_freq = _require_positive('sample_freq', sample_freq)
        self.num_samples = num_samples
        self.save_dir = save_dir

        os.makedirs(save_dir, exist_ok=True)

    def on_batch_end(self, trainer, batch_idx, batch, loss):
        """Generate and save samples when the global step hits the frequency."""
        if trainer.global_step % self.sample_freq != 0:
            return

        model = trainer.model
        # Remember the current mode so it can be restored afterwards.
        was_training = getattr(model, 'training', True)
        model.eval()
        try:
            with torch.no_grad():
                # Use prompts from the validation set.
                sample_batch = next(iter(trainer.val_loader))
                text_embeddings = sample_batch['text_embeddings'][:self.num_samples].to(trainer.device)

                latents = trainer.diffusion.generate(
                    context=text_embeddings,
                    num_samples=self.num_samples,
                    guidance_scale=7.5
                )

                # Fewer latents than requested may come back (e.g. a small
                # validation batch); never index past what was returned.
                for i in range(min(self.num_samples, len(latents))):
                    sample_path = os.path.join(
                        self.save_dir,
                        f'step_{trainer.global_step}_sample_{i}.pt'
                    )
                    torch.save(latents[i].cpu(), sample_path)
        finally:
            # Restore the mode even if generation or saving raised, so a
            # failure here cannot leave the model stuck in eval mode.
            model.train(was_training)


class MemoryMonitorCallback(Callback):
    """Memory-monitoring callback.

    Every ``monitor_freq`` global steps, calls
    ``trainer.memory_manager.print_memory_stats()`` if the trainer has a
    ``memory_manager`` attribute.

    Raises:
        ValueError: If ``monitor_freq`` is smaller than 1.
    """

    def __init__(self, monitor_freq: int = 100):
        self.monitor_freq = _require_positive('monitor_freq', monitor_freq)

    def on_batch_end(self, trainer, batch_idx, batch, loss):
        """Print memory statistics on monitored steps."""
        if trainer.global_step % self.monitor_freq != 0:
            return
        if hasattr(trainer, 'memory_manager'):
            trainer.memory_manager.print_memory_stats()


class GradientMonitorCallback(Callback):
    """Gradient-monitoring callback.

    Every ``monitor_freq`` global steps, logs the global L2 gradient norm to
    ``trainer.writer`` if the trainer has a ``writer`` attribute.

    Raises:
        ValueError: If ``monitor_freq`` is smaller than 1.
    """

    def __init__(self, monitor_freq: int = 100):
        self.monitor_freq = _require_positive('monitor_freq', monitor_freq)

    def on_batch_end(self, trainer, batch_idx, batch, loss):
        """Log the gradient norm on monitored steps."""
        if trainer.global_step % self.monitor_freq != 0:
            return
        # Check for the writer first: without one the norm would be
        # computed and then thrown away.
        if not hasattr(trainer, 'writer'):
            return
        grad_norm = self._compute_gradient_norm(trainer.model)
        trainer.writer.add_scalar('train/grad_norm', grad_norm, trainer.global_step)

    def _compute_gradient_norm(self, model) -> float:
        """Return the L2 norm over all parameter gradients that are set.

        Parameters whose ``grad`` is None are skipped; the result is 0.0
        when no parameter has a gradient.
        """
        squared_sum = sum(
            (
                p.grad.detach().norm(2).item() ** 2
                for p in model.parameters()
                if p.grad is not None
            ),
            0.0,
        )
        return squared_sum ** 0.5


class CallbackHandler:
    """Callback dispatcher.

    Holds an ordered list of callbacks and forwards every hook to each of
    them in registration order.
    """

    def __init__(self):
        self.callbacks = []

    def add_callback(self, callback: Callback):
        """Register *callback* after the ones already registered."""
        self.callbacks.append(callback)

    @property
    def should_stop(self) -> bool:
        """True if any registered callback has a truthy ``should_stop``."""
        return any(getattr(cb, 'should_stop', False) for cb in self.callbacks)

    def _dispatch(self, hook: str, *args):
        """Call the method named *hook* on every registered callback."""
        for callback in self.callbacks:
            getattr(callback, hook)(*args)

    def on_train_begin(self, trainer):
        self._dispatch('on_train_begin', trainer)

    def on_train_end(self, trainer):
        self._dispatch('on_train_end', trainer)

    def on_epoch_begin(self, trainer, epoch):
        self._dispatch('on_epoch_begin', trainer, epoch)

    def on_epoch_end(self, trainer, epoch, train_loss, val_loss):
        self._dispatch('on_epoch_end', trainer, epoch, train_loss, val_loss)

    def on_batch_begin(self, trainer, batch_idx, batch):
        self._dispatch('on_batch_begin', trainer, batch_idx, batch)

    def on_batch_end(self, trainer, batch_idx, batch, loss):
        self._dispatch('on_batch_end', trainer, batch_idx, batch, loss)

    def on_validation_begin(self, trainer):
        self._dispatch('on_validation_begin', trainer)

    def on_validation_end(self, trainer, val_loss):
        self._dispatch('on_validation_end', trainer, val_loss)


def create_default_callbacks(config: dict) -> CallbackHandler:
    """Build a CallbackHandler populated from *config*.

    Keys read (with defaults): ``checkpoint_dir``, ``save_best_model``,
    ``save_checkpoint_every``, ``use_tensorboard``, ``log_dir``,
    ``sample_steps``, ``sample_dir``, ``log_steps``, ``early_stopping``,
    ``early_stopping_patience``, ``early_stopping_min_delta``.

    Checkpointing, memory monitoring and gradient monitoring are always
    added; TensorBoard logging, sample generation and early stopping are
    added only when enabled by the config.
    """
    handler = CallbackHandler()

    # Model checkpointing.
    handler.add_callback(
        ModelCheckpoint(
            save_dir=config.get('checkpoint_dir', './checkpoints'),
            save_best_only=config.get('save_best_model', True),
            save_freq=config.get('save_checkpoint_every', 1),
            monitor='val_loss',
            mode='min'
        )
    )

    # TensorBoard logging.
    if config.get('use_tensorboard', True):
        handler.add_callback(
            TensorBoardLogger(log_dir=config.get('log_dir', './logs'))
        )

    # Sample generation; a missing/None/non-positive value disables it.
    sample_steps = config.get('sample_steps', 500)
    if sample_steps is not None and sample_steps > 0:
        handler.add_callback(
            SampleGeneratorCallback(
                sample_freq=sample_steps,
                num_samples=4,
                save_dir=config.get('sample_dir', './samples')
            )
        )

    # Memory and gradient monitoring share the logging frequency.
    handler.add_callback(
        MemoryMonitorCallback(monitor_freq=config.get('log_steps', 50))
    )
    handler.add_callback(
        GradientMonitorCallback(monitor_freq=config.get('log_steps', 50))
    )

    # Early stopping.
    if config.get('early_stopping', False):
        handler.add_callback(
            EarlyStopping(
                patience=config.get('early_stopping_patience', 10),
                min_delta=config.get('early_stopping_min_delta', 1e-4)
            )
        )

    return handler