""" Cosmos Model Loaders Lazy loading with caching for Predict2.5 and Transfer2.5 models """ import os import torch from typing import Optional, Dict, Any import gc # Global model cache _model_cache: Dict[str, Any] = {} # Model configurations MODEL_CONFIGS = { "predict2.5-2b": { "model_id": "nvidia/Cosmos-Predict2.5-2B", "vram_gb": 32.54, "description": "Text/Image/Video to World generation" }, "transfer2.5-2b": { "model_id": "nvidia/Cosmos-Transfer2.5-2B", "vram_gb": 65.4, "description": "World-to-world translation with control inputs" } } def get_device_info() -> Dict[str, Any]: """Get current device and memory information""" if torch.cuda.is_available(): device = torch.cuda.current_device() props = torch.cuda.get_device_properties(device) free_mem, total_mem = torch.cuda.mem_get_info(device) return { "device": f"cuda:{device}", "name": props.name, "total_vram_gb": total_mem / (1024**3), "free_vram_gb": free_mem / (1024**3), "compute_capability": f"{props.major}.{props.minor}" } return {"device": "cpu", "name": "CPU", "total_vram_gb": 0, "free_vram_gb": 0} def clear_model_cache(model_name: Optional[str] = None): """Clear cached models to free VRAM""" global _model_cache if model_name: if model_name in _model_cache: del _model_cache[model_name] print(f"Cleared {model_name} from cache") else: _model_cache.clear() print("Cleared all models from cache") gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() def load_predict_pipeline( model_size: str = "2b", torch_dtype: torch.dtype = torch.bfloat16, force_reload: bool = False ): """ Load Cosmos Predict2.5 pipeline Args: model_size: "2b" or "14b" (only 2b supported on ZeroGPU) torch_dtype: torch.bfloat16 recommended force_reload: Force reload even if cached Returns: Diffusers pipeline for video generation """ global _model_cache cache_key = f"predict_{model_size}" if cache_key in _model_cache and not force_reload: print(f"Using cached Predict2.5-{model_size.upper()} model") return _model_cache[cache_key] # Clear other models to free VRAM clear_model_cache() from diffusers import DiffusionPipeline model_id = f"nvidia/Cosmos-Predict2.5-{model_size.upper()}" print(f"Loading {model_id}...") device_info = get_device_info() print(f"Device: {device_info['name']}, Free VRAM: {device_info['free_vram_gb']:.2f} GB") pipe = DiffusionPipeline.from_pretrained( model_id, torch_dtype=torch_dtype, trust_remote_code=True ) pipe.to("cuda") _model_cache[cache_key] = pipe print(f"Predict2.5-{model_size.upper()} loaded successfully!") return pipe def load_transfer_pipeline( model_size: str = "2b", torch_dtype: torch.dtype = torch.bfloat16, force_reload: bool = False ): """ Load Cosmos Transfer2.5 pipeline Args: model_size: "2b" only (larger models require more VRAM) torch_dtype: torch.bfloat16 required (FP16/FP32 not supported) force_reload: Force reload even if cached Returns: Diffusers pipeline for world-to-world translation """ global _model_cache cache_key = f"transfer_{model_size}" if cache_key in _model_cache and not force_reload: print(f"Using cached Transfer2.5-{model_size.upper()} model") return _model_cache[cache_key] # Clear other models to free VRAM (Transfer needs more VRAM) clear_model_cache() from diffusers import DiffusionPipeline model_id = f"nvidia/Cosmos-Transfer2.5-{model_size.upper()}" print(f"Loading {model_id}...") device_info = get_device_info() print(f"Device: {device_info['name']}, Free VRAM: {device_info['free_vram_gb']:.2f} GB") if device_info['free_vram_gb'] < 60: print("WARNING: Transfer2.5 requires ~65GB VRAM. May fail on current hardware.") pipe = DiffusionPipeline.from_pretrained( model_id, torch_dtype=torch_dtype, trust_remote_code=True ) pipe.to("cuda") _model_cache[cache_key] = pipe print(f"Transfer2.5-{model_size.upper()} loaded successfully!") return pipe