#!/usr/bin/env python3
"""
Model export script.

Loads a trained checkpoint and exports the model to one or more formats
(TorchScript, ONNX, safetensors, or a plain PyTorch state dict).
"""

import os
import sys
import argparse
from typing import Optional

import yaml
import torch
import torch.nn as nn

# Add the project root to the Python path. abspath() keeps the result correct
# when the script is started through a relative path, where dirname() of a
# bare relative __file__ could otherwise collapse to an empty string.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.models.unet_light import UNetLight
from src.models.diffusion import DiffusionProcess
from src.inference.optimization import ModelOptimizer, ONNXExporter, optimize_model_for_p4

# Formats exported when the caller of optimize_and_export() does not pick any.
_DEFAULT_FORMATS = ('torchscript', 'onnx', 'safetensors')

# Training configuration read when no explicit path is supplied.
_DEFAULT_TRAINING_CONFIG = "configs/training/p4_optimized.yaml"


def load_config(config_path: str) -> dict:
    """Load a YAML configuration file.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The parsed configuration. An empty file yields an empty dict so that
        callers can always use ``.get`` on the result.
    """
    # Explicit UTF-8: the platform default encoding may not be able to decode
    # non-ASCII text in the configuration files.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config if config is not None else {}


def load_model(checkpoint_path: str, config: dict, device: torch.device) -> nn.Module:
    """Build a ``UNetLight`` and load its weights from a checkpoint.

    Args:
        checkpoint_path: Path of the checkpoint file.
        config: Training configuration; its optional ``model_config`` entry
            names the model YAML file.
        device: Device the model is moved to after loading.

    Returns:
        The model on ``device``, switched to eval mode.
    """
    # Load the model configuration.
    model_config_path = config.get('model_config', 'configs/model/unet_light.yaml')
    model_config = load_config(model_config_path)

    # Create the model.
    model = UNetLight(model_config)

    # Load the checkpoint.
    print(f"加载检查点: {checkpoint_path}")
    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code; only load checkpoints from trusted sources (or consider
    # weights_only=True where the installed torch version supports it).
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # The weights may sit under one of two well-known keys; otherwise the
    # checkpoint object itself is treated as the state dict.
    state_dict = checkpoint
    for key in ('model_state_dict', 'state_dict'):
        if key in checkpoint:
            state_dict = checkpoint[key]
            break
    model.load_state_dict(state_dict)

    # Move to the target device.
    model = model.to(device)
    model.eval()

    print("模型加载完成")

    return model


def _example_inputs(model: nn.Module) -> tuple:
    """Build the (input, timestep, context) example tensors for tracing/export.

    The tensors are placed on the same device as the model's parameters, and
    the floating-point ones use the parameters' dtype, so that running the
    model on them does not fail with a device or dtype mismatch.
    """
    first_param = next(model.parameters(), None)
    device = torch.device('cpu')
    dtype = torch.float32
    if first_param is not None:
        device = first_param.device
        if first_param.dtype.is_floating_point:
            dtype = first_param.dtype

    # Sample in float32 on the CPU first, then convert: random sampling
    # directly in reduced precision is not available on every backend.
    example_input = torch.randn(1, model.in_channels, 64, 64).to(device=device, dtype=dtype)
    example_timestep = torch.tensor([500], device=device)
    example_context = torch.randn(1, 77, 768).to(device=device, dtype=dtype)
    return example_input, example_timestep, example_context


def export_torchscript(model: nn.Module, output_path: str):
    """Export the model as TorchScript by tracing it.

    Args:
        model: Model to trace; must expose ``in_channels``.
        output_path: Destination file of the traced module.

    Returns:
        The traced module.
    """
    print(f"导出为TorchScript: {output_path}")

    # Trace the model; gradients are not needed for export.
    with torch.no_grad():
        traced_model = torch.jit.trace(
            model,
            _example_inputs(model),
            check_trace=False
        )

    # Save.
    traced_model.save(output_path)
    print(f"TorchScript模型已保存: {output_path}")

    return traced_model


def export_onnx(model: nn.Module, output_path: str, opset_version: int = 14):
    """Export the model to ONNX and validate the result when possible.

    Args:
        model: Model to export; must expose ``in_channels``.
        output_path: Destination ``.onnx`` file.
        opset_version: ONNX opset version passed to the exporter.
    """
    print(f"导出为ONNX: {output_path}")

    # Only the batch dimension is marked as dynamic.
    dynamic_axes = {
        'input': {0: 'batch_size'},
        'timestep': {0: 'batch_size'},
        'context': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }

    # Export; gradients are not needed.
    with torch.no_grad():
        torch.onnx.export(
            model,
            _example_inputs(model),
            output_path,
            input_names=['input', 'timestep', 'context'],
            output_names=['output'],
            dynamic_axes=dynamic_axes,
            opset_version=opset_version,
            do_constant_folding=True,
            verbose=False
        )

    print(f"ONNX模型已保存: {output_path}")

    # Validate the exported file. The onnx package is only needed for this
    # check, so a missing installation skips validation instead of failing
    # after the export already succeeded.
    try:
        import onnx
    except ImportError:
        print("onnx未安装,跳过ONNX模型验证")
        print("安装: pip install onnx")
        return

    onnx_model = onnx.load(output_path)
    onnx.checker.check_model(onnx_model)
    print("ONNX模型验证成功")


def export_safetensors(model: nn.Module, output_path: str):
    """Save the model's state dict in safetensors format.

    The export is skipped with a message when the ``safetensors`` package is
    not installed.

    Args:
        model: Model whose state dict is saved.
        output_path: Destination ``.safetensors`` file.
    """
    try:
        from safetensors.torch import save_file
    except ImportError:
        print("safetensors未安装,跳过safetensors导出")
        print("安装: pip install safetensors")
        return

    # Write the state dict in safetensors format.
    state_dict = model.state_dict()
    save_file(state_dict, output_path)
    print(f"Safetensors模型已保存: {output_path}")


def optimize_and_export(
    checkpoint_path: str,
    output_dir: str,
    formats: Optional[list] = None,
    optimize_for_p4: bool = True,
    config_path: str = _DEFAULT_TRAINING_CONFIG
):
    """Load a checkpoint, optionally optimize the model, and export it.

    Args:
        checkpoint_path: Path of the checkpoint to load.
        output_dir: Directory the exported files are written to (created if
            it does not exist).
        formats: Formats to export, any of ``'torchscript'``, ``'onnx'``,
            ``'safetensors'``, ``'pth'``. ``None`` selects torchscript, onnx
            and safetensors. Unknown names are reported and skipped.
        optimize_for_p4: Whether to run ``optimize_model_for_p4`` first.
        config_path: Training configuration YAML to read.
    """
    # A None default avoids sharing one mutable list between calls.
    if formats is None:
        formats = list(_DEFAULT_FORMATS)

    # Create the output directory.
    os.makedirs(output_dir, exist_ok=True)

    # Load the configuration.
    config = load_config(config_path)

    # Device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the model.
    model = load_model(checkpoint_path, config, device)

    # Optimize the model (for P4).
    if optimize_for_p4:
        print("优化模型(针对P4)...")
        model = optimize_model_for_p4(model)

    # Report model statistics; the size estimate assumes 4 bytes per parameter.
    total_params = sum(p.numel() for p in model.parameters())
    model_size_mb = total_params * 4 / 1024**2  # fp32

    print("\n模型信息:")
    print(f"  参数量: {total_params:,}")
    print(f"  模型大小: {model_size_mb:.2f} MB (fp32)")

    # Export to each requested format; output names derive from the
    # checkpoint's base name.
    base_name = os.path.splitext(os.path.basename(checkpoint_path))[0]

    for fmt in formats:
        if fmt == 'torchscript':
            output_path = os.path.join(output_dir, f"{base_name}.torchscript.pt")
            export_torchscript(model, output_path)

        elif fmt == 'onnx':
            output_path = os.path.join(output_dir, f"{base_name}.onnx")
            export_onnx(model, output_path)

        elif fmt == 'safetensors':
            output_path = os.path.join(output_dir, f"{base_name}.safetensors")
            export_safetensors(model, output_path)

        elif fmt == 'pth':
            output_path = os.path.join(output_dir, f"{base_name}.pth")
            torch.save(model.state_dict(), output_path)
            print(f"PyTorch模型已保存: {output_path}")

        else:
            print(f"未知的格式: {fmt}")

    print(f"\n所有模型已导出到: {output_dir}")


def create_lite_version(model: nn.Module, reduction_factor: float = 0.5) -> nn.Module:
    """Create a lite version of the model (by reducing channel counts).

    This is currently a placeholder: it prints the requested factor and
    returns ``model`` unchanged. The actual reduction logic (for example,
    shrinking the UNet channel counts) still has to be implemented against
    the concrete model structure.

    Args:
        model: Model to slim down.
        reduction_factor: Requested reduction factor (only reported for now).

    Returns:
        The same ``model`` object, unmodified.
    """
    print(f"创建轻量版本,减少因子: {reduction_factor}")

    return model


def main():
    """Parse the command line and run the export.

    Returns:
        Process exit status: 0 on success, 1 when the checkpoint is missing.
    """
    parser = argparse.ArgumentParser(description="导出Lumina模型")

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="模型检查点路径"
    )

    parser.add_argument(
        "--output-dir",
        type=str,
        default="./exported_models",
        help="输出目录"
    )

    parser.add_argument(
        "--formats",
        type=str,
        nargs="+",
        default=['torchscript', 'onnx'],
        choices=['torchscript', 'onnx', 'safetensors', 'pth'],
        help="导出格式"
    )

    parser.add_argument(
        "--optimize",
        action="store_true",
        help="优化模型(针对P4)"
    )

    parser.add_argument(
        "--lite",
        action="store_true",
        help="创建轻量版本"
    )

    parser.add_argument(
        "--lite-factor",
        type=float,
        default=0.5,
        help="轻量化减少因子"
    )

    parser.add_argument(
        "--config",
        type=str,
        default=_DEFAULT_TRAINING_CONFIG,
        help="训练配置文件路径"
    )

    args = parser.parse_args()

    # Check the input file; a missing checkpoint is reported on stderr and
    # signalled through a non-zero exit status.
    if not os.path.exists(args.checkpoint):
        print(f"错误: 检查点文件不存在: {args.checkpoint}", file=sys.stderr)
        return 1

    # create_lite_version() is still a placeholder, so say explicitly that the
    # flag has no effect instead of ignoring it silently.
    if args.lite:
        print("警告: --lite 尚未实现,已忽略该选项", file=sys.stderr)

    # Optimize and export.
    optimize_and_export(
        checkpoint_path=args.checkpoint,
        output_dir=args.output_dir,
        formats=args.formats,
        optimize_for_p4=args.optimize,
        config_path=args.config
    )
    return 0


if __name__ == "__main__":
    sys.exit(main())