import gc
import os
from typing import Optional

try:
    import psutil
except ImportError:  # Only the CPU statistics need psutil; fail on first use instead.
    psutil = None

import torch


class CPUMemoryManager:
    """Report host (CPU) memory usage and warn when the system runs low."""

    def __init__(self, warning_threshold: float = 0.9):
        """
        Args:
            warning_threshold: Fraction (0-1) of system memory in use above
                which ``check_memory`` reports the situation as unsafe.
        """
        self.warning_threshold = warning_threshold

    def get_memory_usage(self) -> dict:
        """Return process and system memory figures.

        Returns:
            A dict with the keys ``process_rss_mb``, ``process_vms_mb``,
            ``system_total_mb``, ``system_available_mb`` (all in MiB) and
            ``system_used_percent`` (0-100).

        Raises:
            ImportError: If ``psutil`` is not installed.
        """
        if psutil is None:
            raise ImportError("psutil is required to query CPU memory usage")

        process_memory = psutil.Process(os.getpid()).memory_info()
        system_memory = psutil.virtual_memory()
        bytes_per_mb = 1024 * 1024

        return {
            'process_rss_mb': process_memory.rss / bytes_per_mb,
            'process_vms_mb': process_memory.vms / bytes_per_mb,
            'system_total_mb': system_memory.total / bytes_per_mb,
            'system_available_mb': system_memory.available / bytes_per_mb,
            'system_used_percent': system_memory.percent,
        }

    def check_memory(self) -> bool:
        """Return True while system memory usage is at or below the threshold.

        Prints a warning and returns False once usage exceeds it.
        """
        memory_info = self.get_memory_usage()

        # psutil reports a percentage, the threshold is a fraction.
        if memory_info['system_used_percent'] > self.warning_threshold * 100:
            print(f"警告: 系统内存使用率过高: {memory_info['system_used_percent']:.1f}%")
            return False

        return True


class OptimizerCPUOffload:
    """Move an optimizer's tensor state to the CPU between training steps."""

    def __init__(self, optimizer: torch.optim.Optimizer):
        self.optimizer = optimizer
        # Maps (param, state_key) -> device the tensor lived on before it was
        # offloaded. Only the device is kept: holding the tensor itself would
        # keep its accelerator memory alive and defeat the offload.
        self.original_states = {}

    def offload_to_cpu(self):
        """Replace every non-CPU state tensor with a CPU copy.

        Tensors that already live on the CPU are left untouched and are not
        recorded, so ``load_to_gpu`` will not move them later.
        """
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                if param not in self.optimizer.state:
                    continue
                state = self.optimizer.state[param]
                for key, value in list(state.items()):
                    if torch.is_tensor(value) and value.device.type != 'cpu':
                        self.original_states[(param, key)] = value.device
                        state[key] = value.cpu()

    def load_to_gpu(self, device: torch.device):
        """Move every previously offloaded state tensor to ``device``.

        Args:
            device: Target device for the tensors that ``offload_to_cpu``
                moved off an accelerator.
        """
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                if param not in self.optimizer.state:
                    continue
                state = self.optimizer.state[param]
                for key in list(state.keys()):
                    if self.original_states.pop((param, key), None) is None:
                        continue
                    if torch.is_tensor(state[key]):
                        state[key] = state[key].to(device)


class ActivationCPUOffload:
    """Offload activations saved for backward to the CPU.

    Each Conv2d / Linear / GroupNorm forward runs inside
    ``torch.autograd.graph.save_on_cpu``: the tensors autograd saves for the
    backward pass are stored on the CPU and copied back when backward needs
    them. Module outputs themselves stay on their original device, so the
    following layers keep working.
    """

    def __init__(self, model: torch.nn.Module):
        self.model = model
        self.hooks = []
        # Stack of save_on_cpu contexts entered by a pre-hook and not yet
        # exited by the matching forward hook.
        self._active_contexts = []

    def register_hooks(self):
        """Register a pre-forward and a forward hook on each supported module."""

        def enter_offload(module, inputs):
            context = torch.autograd.graph.save_on_cpu()
            context.__enter__()
            self._active_contexts.append(context)

        def exit_offload(module, inputs, output):
            if self._active_contexts:
                self._active_contexts.pop().__exit__(None, None, None)

        offloaded_types = (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm)
        for module in self.model.modules():
            if isinstance(module, offloaded_types):
                self.hooks.append(module.register_forward_pre_hook(enter_offload))
                self.hooks.append(module.register_forward_hook(exit_offload))

    def remove_hooks(self):
        """Remove all hooks and unwind any context left open.

        A context stays open only if a hooked module's forward raised, since
        the forward hook is then never called.
        """
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

        while self._active_contexts:
            self._active_contexts.pop().__exit__(None, None, None)


class MemoryOptimizer:
    """Combine GPU/CPU memory checks, periodic cleanup and CPU offloading."""

    def __init__(self, config: dict):
        """
        Args:
            config: Options read with ``dict.get``. Keys used by this class:
                ``warning_threshold_gb``, ``memory_threshold_gb``,
                ``cpu_warning_threshold``, ``cleanup_frequency``,
                ``optimizer_on_cpu``, ``cpu_offload``, ``attention_slicing``.
        """
        self.config = config

        # GPU thresholds are configured in GiB and stored in bytes.
        self.gpu_warning_threshold = config.get('warning_threshold_gb', 6.0) * 1024**3
        self.gpu_critical_threshold = config.get('memory_threshold_gb', 6.5) * 1024**3

        self.cpu_manager = CPUMemoryManager(
            warning_threshold=config.get('cpu_warning_threshold', 0.85)
        )

        # Run cleanup() every this many steps; non-positive disables it.
        self.cleanup_frequency = config.get('cleanup_frequency', 100)

        # Created by setup_model_optimizations when enabled in the config.
        self.optimizer_offloader = None
        self.activation_offloader = None

    def setup_model_optimizations(self, model: torch.nn.Module,
                                  optimizer: Optional[torch.optim.Optimizer] = None):
        """Enable the optimizations selected in the config for ``model``.

        Args:
            model: Model to optimize.
            optimizer: Optimizer whose state should be offloaded to the CPU
                between steps; skipped when None.
        """
        # Gradient checkpointing, when the model exposes it.
        if hasattr(model, 'enable_gradient_checkpointing'):
            model.enable_gradient_checkpointing()

        if optimizer is not None and self.config.get('optimizer_on_cpu', True):
            self.optimizer_offloader = OptimizerCPUOffload(optimizer)

        if self.config.get('cpu_offload', True):
            self.activation_offloader = ActivationCPUOffload(model)
            self.activation_offloader.register_hooks()

        if self.config.get('attention_slicing', 'auto') == 'auto':
            self._enable_attention_slicing(model)

    def _enable_attention_slicing(self, model: torch.nn.Module):
        """Call ``set_attention_slice('auto')`` on every submodule that has it."""
        for module in model.modules():
            if hasattr(module, 'set_attention_slice'):
                module.set_attention_slice('auto')

    def _optimizer_device(self) -> Optional[torch.device]:
        """Return the device of the first optimizer parameter, or None if there is none."""
        for param_group in self.optimizer_offloader.optimizer.param_groups:
            for param in param_group['params']:
                return param.device
        return None

    def step_start(self):
        """Reload offloaded optimizer state and check memory before a step."""
        if self.optimizer_offloader is not None:
            device = self._optimizer_device()
            if device is not None:
                self.optimizer_offloader.load_to_gpu(device)

        self.check_all_memory()

    def step_end(self, step: int):
        """Offload optimizer state, clean up periodically and check memory.

        Args:
            step: Current training step; cleanup runs when it is a multiple
                of ``cleanup_frequency``.
        """
        if self.optimizer_offloader is not None:
            self.optimizer_offloader.offload_to_cpu()

        # Guard against a zero frequency, which would divide by zero.
        if self.cleanup_frequency > 0 and step % self.cleanup_frequency == 0:
            self.cleanup()

        self.check_all_memory()

    def check_all_memory(self):
        """Check GPU and CPU memory against their thresholds.

        Raises:
            RuntimeError: If GPU memory is still above the critical threshold
                after a forced cleanup.
        """
        if torch.cuda.is_available():
            gpu_allocated = torch.cuda.memory_allocated()
            if gpu_allocated > self.gpu_critical_threshold:
                self._handle_gpu_oom()
            elif gpu_allocated > self.gpu_warning_threshold:
                print(f"GPU内存警告: {gpu_allocated / 1024**3:.2f} GB")

        if not self.cpu_manager.check_memory():
            self._handle_cpu_oom()

    def _handle_gpu_oom(self):
        """Force a cleanup and raise if GPU memory is still above the critical threshold."""
        print("GPU内存不足,尝试清理...")
        self.cleanup(force=True)

        if torch.cuda.memory_allocated() > self.gpu_critical_threshold:
            raise RuntimeError("GPU内存不足,无法继续训练")

    def _handle_cpu_oom(self):
        """Run the garbage collector when CPU memory is running low."""
        print("CPU内存不足,尝试清理...")
        gc.collect()

    def cleanup(self, force: bool = False):
        """Run garbage collection and release cached GPU memory.

        Args:
            force: Also synchronize the device and collect CUDA IPC memory.
        """
        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

            if force:
                torch.cuda.synchronize()
                torch.cuda.ipc_collect()

    def get_memory_stats(self) -> dict:
        """Return memory statistics.

        Returns:
            A dict with a ``cpu`` entry (see
            ``CPUMemoryManager.get_memory_usage``) and, when CUDA is
            available, a ``gpu`` entry with ``allocated_gb``, ``reserved_gb``
            and ``max_allocated_gb``.
        """
        stats = {}

        if torch.cuda.is_available():
            stats['gpu'] = {
                'allocated_gb': torch.cuda.memory_allocated() / 1024**3,
                'reserved_gb': torch.cuda.memory_reserved() / 1024**3,
                'max_allocated_gb': torch.cuda.max_memory_allocated() / 1024**3,
            }

        stats['cpu'] = self.cpu_manager.get_memory_usage()

        return stats

    def print_memory_stats(self):
        """Print a short summary of GPU and CPU memory usage."""
        stats = self.get_memory_stats()

        print("=" * 50)
        print("内存使用统计:")

        if 'gpu' in stats:
            gpu = stats['gpu']
            print(f"GPU - 已分配: {gpu['allocated_gb']:.2f} GB, "
                  f"已保留: {gpu['reserved_gb']:.2f} GB, "
                  f"最大分配: {gpu['max_allocated_gb']:.2f} GB")

        if 'cpu' in stats:
            cpu = stats['cpu']
            print(f"CPU - 进程RSS: {cpu['process_rss_mb']:.1f} MB, "
                  f"系统使用率: {cpu['system_used_percent']:.1f}%")

        print("=" * 50)