# Lumina_Dev_Legacy / src/training/trainer_p4.py
# TAI Research
# Initial commit: Lumina_Dev_Legacy (archived), revision 29691f6
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from typing import Optional, Dict, Any
import wandb
import os
from tqdm import tqdm
import numpy as np
class MemoryManager:
    """GPU-memory guard for training on a P4 card.

    Periodically runs ``gc.collect()`` + ``torch.cuda.empty_cache()`` and
    compares ``torch.cuda.memory_allocated()`` against a warning and a
    critical threshold taken from the config.
    """

    def __init__(self, config: dict):
        """Read thresholds from ``config``.

        Keys: ``warning_threshold_gb`` (default 6.0), ``memory_threshold_gb``
        (critical limit, default 6.5) and ``cleanup_frequency`` (steps between
        cleanups, default 100; a value <= 0 disables periodic cleanup).
        """
        self.config = config
        # Thresholds are kept in bytes so they compare directly with memory_allocated().
        self.warning_threshold = config.get('warning_threshold_gb', 6.0) * 1024**3
        self.critical_threshold = config.get('memory_threshold_gb', 6.5) * 1024**3
        self.cleanup_frequency = config.get('cleanup_frequency', 100)

    def check_memory(self, step: int):
        """Run periodic cleanup and enforce the memory thresholds.

        Args:
            step: Current step counter; cleanup runs when it is a multiple of
                ``cleanup_frequency``.

        Raises:
            RuntimeError: If allocated memory exceeds the critical threshold.
        """
        # Guard the modulo: cleanup_frequency == 0 used to raise ZeroDivisionError.
        if self.cleanup_frequency > 0 and step % self.cleanup_frequency == 0:
            self.cleanup()

        allocated = torch.cuda.memory_allocated()
        if allocated > self.critical_threshold:
            raise RuntimeError(f"显存超出临界阈值: {allocated / 1024**3:.2f} GB")
        elif allocated > self.warning_threshold:
            print(f"警告: 显存使用较高: {allocated / 1024**3:.2f} GB")

    def cleanup(self):
        """Collect unreferenced Python objects and release unoccupied cached CUDA memory."""
        import gc
        gc.collect()
        torch.cuda.empty_cache()

    def auto_adjust_batch_size(
        self,
        model: nn.Module,
        data_shape: tuple,
        context_shape: tuple = (77, 768),
        candidates: tuple = (1, 2, 4, 8),
        num_timesteps: int = 1000,
    ) -> int:
        """Probe ``model`` with dummy inputs and return the largest batch size that fits.

        The model is called as ``model(x, t, context)`` under ``torch.no_grad()``.
        Here ``x`` has shape ``(batch, *data_shape)``, ``t`` holds random ints in
        ``[0, num_timesteps)``, and ``context`` has shape ``(batch, *context_shape)``.

        Args:
            model: Module whose first parameter's device is used for the probes.
            data_shape: Per-sample input shape.
            context_shape: Per-sample conditioning shape (default matches the
                previously hard-coded 77x768).
            candidates: Batch sizes to try, in ascending order; probing stops at
                the first CUDA out-of-memory error.
            num_timesteps: Exclusive upper bound for the dummy timesteps.

        Returns:
            The largest candidate that ran successfully, or 1 if none did.

        Raises:
            RuntimeError: Any RuntimeError that is not a CUDA OOM is re-raised.
        """
        max_batch = 1
        device = next(model.parameters()).device
        for batch_size in candidates:
            outputs = dummy_input = dummy_timestep = dummy_context = None
            try:
                dummy_input = torch.randn(batch_size, *data_shape, device=device)
                dummy_timestep = torch.randint(0, num_timesteps, (batch_size,), device=device)
                dummy_context = torch.randn(batch_size, *context_shape, device=device)
                with torch.no_grad():
                    outputs = model(dummy_input, dummy_timestep, dummy_context)
                max_batch = batch_size
            except RuntimeError as e:
                if "CUDA out of memory" not in str(e):
                    raise
                break
            finally:
                # Drop probe references *before* emptying the cache, otherwise the
                # blocks are still occupied and cannot be released (also after OOM).
                outputs = dummy_input = dummy_timestep = dummy_context = None
                torch.cuda.empty_cache()
        return max_batch
class GradientAccumulationScheduler:
    """Per-epoch schedule for the number of gradient-accumulation micro-batches."""

    def __init__(self, config: dict):
        """Read the schedule from ``config``.

        Keys: ``gradient_accumulation_steps`` (initial value, default 8),
        ``warmup_epochs`` (epochs that keep the initial value, default 5) and
        ``min_gradient_accumulation_steps`` (floor for the post-warmup decay,
        default 4).
        """
        self.initial_steps = config.get('gradient_accumulation_steps', 8)
        self.current_steps = self.initial_steps
        self.warmup_epochs = config.get('warmup_epochs', 5)
        # Cap the floor at the initial value: with a fixed floor of 4, an initial
        # value below 4 was *raised* to 4 after warmup, contradicting the intent
        # of reducing accumulation. Never go below 1.
        self.min_steps = max(
            1, min(config.get('min_gradient_accumulation_steps', 4), self.initial_steps)
        )

    def update(self, epoch: int):
        """Set ``current_steps`` for ``epoch``.

        During warmup the initial value is used; afterwards every call halves
        the current value, never dropping below ``min_steps``.
        """
        if epoch < self.warmup_epochs:
            self.current_steps = self.initial_steps
        else:
            # Gradually reduce accumulation to speed up training.
            self.current_steps = max(self.min_steps, self.current_steps // 2)
class P4Trainer:
    """Diffusion-model trainer tuned for a single NVIDIA P4 GPU.

    Combines mixed precision (``GradScaler`` + ``autocast``), gradient
    accumulation with a per-epoch schedule, gradient clipping, an optional EMA
    copy of the model, periodic GPU-memory checks, CSV (and optional wandb)
    logging, and checkpointing.
    """

    def __init__(
        self,
        model: nn.Module,
        # Quoted: DiffusionProcess is not imported in this module, and an unquoted
        # annotation is evaluated when ``def`` runs, raising NameError.
        diffusion: "DiffusionProcess",
        optimizer: torch.optim.Optimizer,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        config: dict,
        device: torch.device
    ):
        """Set up training state, helpers, schedulers and output directories.

        Args:
            model: Network being trained (clipped, optionally EMA-tracked).
            diffusion: Object providing ``compute_loss(images, text_embeddings)``
                and ``generate(context=..., num_samples=..., guidance_scale=...)``.
            optimizer: Optimizer over ``model``'s parameters.
            train_loader: Yields mappings with ``'images'`` and ``'text_embeddings'``.
            val_loader: Same batch format, or ``None`` to skip validation/sampling.
            config: Flat configuration dict; all keys are read with ``config.get``.
            device: Device that batches are moved to.
        """
        self.model = model
        self.diffusion = diffusion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device

        # Training state
        self.current_epoch = 0
        self.global_step = 0  # number of optimizer steps taken
        self.best_loss = float('inf')
        # Pre-clipping gradient norm of the latest optimizer step; kept for
        # logging because gradients are cleared right after the step.
        self._last_grad_norm = 0.0

        # Helpers
        self.memory_manager = MemoryManager(config)
        self.grad_scheduler = GradientAccumulationScheduler(config)

        # Mixed precision: any value other than 'no' enables AMP.
        self.use_amp = config.get('mixed_precision', 'fp16') != 'no'
        self.scaler = GradScaler(enabled=self.use_amp)

        # Learning-rate scheduler (reads self.optimizer and self.train_loader).
        self.lr_scheduler = self._create_lr_scheduler(config)

        # EMA model
        self.use_ema = config.get('use_ema', True)
        if self.use_ema:
            self.ema_model = self._create_ema_model(model, config.get('ema_decay', 0.9999))

        # Logging
        self.use_wandb = config.get('use_wandb', False)
        self.log_dir = config.get('log_dir', './logs')
        os.makedirs(self.log_dir, exist_ok=True)

        # Checkpoints
        self.checkpoint_dir = config.get('checkpoint_dir', './checkpoints')
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def _create_lr_scheduler(self, config: dict):
        """Build the LR scheduler named by ``learning_rate_scheduler``.

        'cosine' -> CosineAnnealingLR down to 1e-6; 'linear' -> LinearLR warmup
        from 1% over ``warmup_steps``; anything else -> constant LR.
        """
        scheduler_type = config.get('learning_rate_scheduler', 'cosine')
        warmup_steps = config.get('warmup_steps', 1000)

        if scheduler_type == 'cosine':
            # NOTE(review): T_max counts every batch, but step() is called once per
            # optimizer step (every `accumulation_steps` batches), so the cosine
            # curve only covers a fraction of its span — confirm intent.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=config.get('max_epochs', 50) * len(self.train_loader),
                eta_min=1e-6
            )
        elif scheduler_type == 'linear':
            scheduler = torch.optim.lr_scheduler.LinearLR(
                self.optimizer,
                start_factor=0.01,
                total_iters=warmup_steps
            )
        else:
            scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1.0)

        return scheduler

    def _create_ema_model(self, model: nn.Module, decay: float):
        """Wrap ``model`` in an AveragedModel that keeps an exponential moving average."""
        from torch.optim.swa_utils import AveragedModel

        def ema_avg(averaged_param, model_param, num_averaged):
            # AveragedModel passes the number of averaged models as the third
            # positional argument. The old lambda bound it to `decay`, turning the
            # update into n*avg + (1-n)*new, which diverges.
            return decay * averaged_param + (1 - decay) * model_param

        return AveragedModel(model, device=self.device, avg_fn=ema_avg)

    def _optimizer_step(self):
        """Apply one update from the accumulated (scaled) gradients and advance global_step."""
        # Unscale first so clipping sees the true gradient magnitudes.
        self.scaler.unscale_(self.optimizer)
        # clip_grad_norm_ returns the total norm before clipping; keep it for logging.
        self._last_grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            max_norm=self.config.get('gradient_clip', 1.0)
        ).detach()

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

        if self.use_ema:
            self.ema_model.update_parameters(self.model)

        self.lr_scheduler.step()
        self.global_step += 1

    def _after_optimizer_step(self, current_loss: float):
        """Run step-interval hooks (logging, sampling, memory check) once per optimizer step."""
        if self.global_step % self.config.get('log_steps', 50) == 0:
            self._log_metrics({
                'train/loss': current_loss,
                'train/lr': self.optimizer.param_groups[0]['lr'],
                'train/grad_norm': float(self._last_grad_norm),
            })

        if self.global_step % self.config.get('sample_steps', 500) == 0:
            self._generate_samples()

        self.memory_manager.check_memory(self.global_step)

    def train_epoch(self) -> float:
        """Run one pass over ``train_loader``.

        Gradients are accumulated over ``grad_scheduler.current_steps``
        micro-batches per optimizer step.

        Returns:
            Mean (un-scaled) training loss over all batches of the epoch.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = len(self.train_loader)

        accumulation_steps = self.grad_scheduler.current_steps
        self.optimizer.zero_grad()

        pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch}")
        for batch_idx, batch in enumerate(pbar):
            images = batch['images'].to(self.device)
            text_embeddings = batch['text_embeddings'].to(self.device)

            # Mixed-precision forward pass
            with autocast(enabled=self.use_amp):
                loss = self.diffusion.compute_loss(images, text_embeddings)
                # Divide so the accumulated gradient is the mean over micro-batches.
                loss = loss / accumulation_steps

            self.scaler.scale(loss).backward()

            # Undo the accumulation scaling for reporting.
            total_loss += loss.item() * accumulation_steps
            current_loss = total_loss / (batch_idx + 1)

            if (batch_idx + 1) % accumulation_steps == 0:
                self._optimizer_step()
                # Interval hooks run only right after an optimizer step. Checking
                # them on every micro-batch re-fired them (sampling included) for
                # each batch while global_step sat on a multiple, e.g. at step 0.
                self._after_optimizer_step(current_loss)

            pbar.set_postfix({
                'loss': f'{current_loss:.4f}',
                'lr': f'{self.optimizer.param_groups[0]["lr"]:.2e}'
            })

        # NOTE(review): gradients from a trailing partial accumulation window
        # (num_batches % accumulation_steps micro-batches) are discarded by the
        # zero_grad() at the start of the next epoch.
        return total_loss / num_batches

    @torch.no_grad()
    def validate(self) -> float:
        """Return the mean validation loss, or ``inf`` when there is nothing to validate."""
        if self.val_loader is None:
            return float('inf')

        was_training = self.model.training
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        for batch in tqdm(self.val_loader, desc="Validation"):
            images = batch['images'].to(self.device)
            text_embeddings = batch['text_embeddings'].to(self.device)

            with autocast(enabled=self.use_amp):
                loss = self.diffusion.compute_loss(images, text_embeddings)

            total_loss += loss.item()
            num_batches += 1

        self.model.train(was_training)

        if num_batches == 0:
            # An empty loader previously raised ZeroDivisionError.
            return float('inf')

        val_loss = total_loss / num_batches
        self._log_metrics({'val/loss': val_loss})
        return val_loss

    def train(self, num_epochs: Optional[int] = None):
        """Train from ``current_epoch`` up to ``num_epochs`` (default: ``max_epochs``).

        After each epoch: validate, save ``best_model.pt`` on improvement, save a
        periodic checkpoint every ``save_checkpoint_every`` epochs, and log.
        """
        if num_epochs is None:
            num_epochs = self.config.get('max_epochs', 50)

        for epoch in range(self.current_epoch, num_epochs):
            self.current_epoch = epoch

            # Update the gradient-accumulation schedule for this epoch.
            self.grad_scheduler.update(epoch)

            train_loss = self.train_epoch()
            val_loss = self.validate()

            if val_loss < self.best_loss:
                self.best_loss = val_loss
                self.save_checkpoint('best_model.pt')

            if (epoch + 1) % self.config.get('save_checkpoint_every', 5) == 0:
                self.save_checkpoint(f'checkpoint_epoch_{epoch+1}.pt')

            self._log_metrics({
                'epoch/train_loss': train_loss,
                'epoch/val_loss': val_loss,
                'epoch': epoch
            })

            print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

    def _get_grad_norm(self) -> float:
        """Return the global L2 norm of the model's current ``.grad`` tensors (0.0 if none).

        Reads gradients as they are right now, so mid-accumulation they are
        still multiplied by the GradScaler scale factor.
        """
        total_norm = 0.0
        for p in self.model.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        return total_norm ** 0.5

    @torch.no_grad()
    def _generate_samples(self):
        """Generate latents from the first validation prompts for monitoring.

        Does nothing without a validation loader. Restores the model's previous
        train/eval mode afterwards (it used to stay in eval mode mid-epoch).
        """
        if self.val_loader is None:
            return

        was_training = self.model.training
        self.model.eval()
        try:
            sample_batch = next(iter(self.val_loader))
            text_embeddings = sample_batch['text_embeddings'][:4].to(self.device)

            with autocast(enabled=self.use_amp):
                # NOTE(review): num_samples is fixed at 4 even if the batch holds
                # fewer prompts — confirm generate()'s expectations.
                latents = self.diffusion.generate(
                    context=text_embeddings,
                    num_samples=4,
                    guidance_scale=7.5
                )
        finally:
            self.model.train(was_training)

        # A VAE decoder is needed to get images; for now only the latents are logged.
        if self.use_wandb:
            wandb.log({
                'samples/latents': wandb.Image(latents[0].cpu().numpy())
            })

    def _log_metrics(self, metrics: Dict[str, Any]):
        """Log ``metrics`` to wandb (if enabled) and append them to ``training_log.csv``."""
        if self.use_wandb:
            wandb.log(metrics)

        log_file = os.path.join(self.log_dir, 'training_log.csv')
        # Write a header only for a new/empty file. Keying it on global_step == 0
        # repeated it on every early call and skipped it on resumed runs.
        # NOTE(review): different callers log different keys, so rows do not all
        # share the header's columns.
        write_header = not os.path.exists(log_file) or os.path.getsize(log_file) == 0
        with open(log_file, 'a', encoding='utf-8') as f:
            if write_header:
                header = ','.join(['step'] + list(metrics.keys()))
                f.write(header + '\n')

            values = ','.join([str(self.global_step)] + [str(v) for v in metrics.values()])
            f.write(values + '\n')

    def save_checkpoint(self, filename: str):
        """Save model/optimizer/scaler/scheduler state to ``checkpoint_dir/filename``."""
        checkpoint = {
            'epoch': self.current_epoch,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),
            # Needed to resume the LR and accumulation schedules where they stopped.
            'lr_scheduler_state_dict': self.lr_scheduler.state_dict(),
            'grad_accumulation_steps': self.grad_scheduler.current_steps,
            'best_loss': self.best_loss,
            'config': self.config
        }

        if self.use_ema:
            checkpoint['ema_model_state_dict'] = self.ema_model.state_dict()

        save_path = os.path.join(self.checkpoint_dir, filename)
        torch.save(checkpoint, save_path)

        if self.config.get('save_compressed', True):
            # splitext only touches the extension; str.replace('.pt', ...) also
            # rewrote matching text in directory names.
            # NOTE(review): the zipfile format is torch.save's default and does not
            # compress tensor data, so this is effectively a second full copy.
            root, ext = os.path.splitext(save_path)
            torch.save(checkpoint, f"{root}_compressed{ext}", _use_new_zipfile_serialization=True)

        print(f"检查点已保存: {save_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Restore trainer state from a checkpoint written by ``save_checkpoint``.

        Checkpoints from before scheduler state was saved are still accepted.
        """
        # SECURITY: torch.load unpickles arbitrary objects; only load trusted files.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        if 'lr_scheduler_state_dict' in checkpoint:
            self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
        if 'grad_accumulation_steps' in checkpoint:
            self.grad_scheduler.current_steps = checkpoint['grad_accumulation_steps']

        if self.use_ema and 'ema_model_state_dict' in checkpoint:
            self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])

        # train() saves checkpoints after epoch `epoch` has finished, so resume
        # with the next one; restoring `epoch` itself repeated a completed epoch.
        self.current_epoch = checkpoint['epoch'] + 1
        self.global_step = checkpoint['global_step']
        self.best_loss = checkpoint['best_loss']

        print(f"已加载检查点: {checkpoint_path}")