"""
Smoke Tests for Cosmos Predict2.5
Validates model loading and basic inference
"""

import os
import shutil
import sys
import tempfile
import traceback
from pathlib import Path

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

# Horizontal rule printed above and below every section title.
_BANNER = "=" * 60


def _print_banner(title):
    """Print *title* framed by two separator lines, preceded by a blank line."""
    print("\n" + _BANNER)
    print(title)
    print(_BANNER)


def test_predict_model_loading():
    """Test that Predict2.5 model can be loaded.

    Prints the device name and free/total VRAM reported by
    ``get_device_info()``, warns when less than 30 GB is free, then calls
    ``load_predict_pipeline()``.

    Returns:
        bool: True if the pipeline loaded; False if any step raised,
        including a failed import of the ``cosmos`` package.
    """
    _print_banner("TEST: Predict2.5 Model Loading")

    try:
        # Imported inside the try so that a missing or broken dependency is
        # reported as a failed test instead of aborting the whole run.
        from cosmos.loaders import load_predict_pipeline, get_device_info

        device_info = get_device_info()
        print(f"Device: {device_info['name']}")
        print(
            f"VRAM: {device_info['free_vram_gb']:.2f} / "
            f"{device_info['total_vram_gb']:.2f} GB"
        )

        if device_info['free_vram_gb'] < 30:
            print("WARNING: Less than 30GB VRAM available, model may fail to load")

        pipe = load_predict_pipeline()
        print("Model loaded successfully!")
        print(f"Pipeline type: {type(pipe).__name__}")
        return True
    except Exception as e:
        print(f"FAILED: {e}")
        return False


def test_predict_text2world_inference():
    """Test Text2World inference with minimal parameters.

    Runs ``predict_text2world`` once with a fixed prompt and seed, then checks
    that the reported video file exists, is larger than 1000 bytes and has at
    least 8 frames.

    The video is written into a freshly created private temporary directory.
    On success the directory is left in place so the printed video path can be
    inspected; on failure it is removed.

    Returns:
        bool: True if inference ran and every check passed, False otherwise.
    """
    _print_banner("TEST: Predict2.5 Text2World Inference")

    output_dir = None
    succeeded = False
    try:
        from cosmos.infer_predict import predict_text2world

        # mkdtemp creates the directory atomically with owner-only access,
        # unlike the deprecated, race-prone tempfile.mktemp. The output file
        # itself is not pre-created, so the callee still sees a fresh path.
        output_dir = tempfile.mkdtemp(prefix="cosmos_predict_smoke_")
        output_path = os.path.join(output_dir, "text2world.mp4")

        # Minimal inference parameters for smoke test
        result = predict_text2world(
            prompt="A simple test scene with a red ball on a white background",
            num_frames=17,  # Minimum frames
            height=480,
            width=720,
            num_inference_steps=10,  # Minimum steps for speed
            guidance_scale=7.0,
            seed=42,
            output_path=output_path,
        )

        # Validate output
        assert result is not None, "Result is None"
        assert os.path.exists(result['video_path']), "Output video does not exist"
        assert result['num_frames'] >= 8, f"Too few frames: {result['num_frames']}"

        # Check video file size
        file_size = os.path.getsize(result['video_path'])
        assert file_size > 1000, f"Video file too small: {file_size} bytes"

        print(f"Output video: {result['video_path']}")
        print(f"Frames: {result['num_frames']}")
        print(f"Resolution: {result['resolution']}")
        print(f"Inference time: {result['inference_time_s']}s")

        print("TEST PASSED!")
        succeeded = True
        return True
    except Exception as e:
        print(f"FAILED: {e}")
        traceback.print_exc()
        return False
    finally:
        # Do not leak scratch directories from failed runs.
        if output_dir is not None and not succeeded:
            shutil.rmtree(output_dir, ignore_errors=True)


def test_predict_output_validation():
    """Validate output video properties.

    Exercises the video utilities on a synthetic 16-frame 320x240 clip made by
    ``create_test_video``: metadata extraction, frame loading and the temporal
    smoothness metric. The clip is deleted afterwards whether or not the
    checks pass.

    Returns:
        bool: True if every check passed, False otherwise.
    """
    _print_banner("TEST: Predict2.5 Output Validation")

    test_video = None
    try:
        from cosmos.utils_video import (
            create_test_video,
            get_video_info,
            compute_temporal_smoothness,
            load_video_frames,
        )

        # Create a simple test video to validate utilities
        test_video = create_test_video(num_frames=16, width=320, height=240)

        # Test video info extraction
        info = get_video_info(test_video)
        print(f"Test video info: {info}")

        assert info['width'] == 320, f"Width mismatch: {info['width']}"
        assert info['height'] == 240, f"Height mismatch: {info['height']}"
        assert info['fps'] > 0, f"Invalid FPS: {info['fps']}"

        # Test frame loading
        frames = load_video_frames(test_video)
        assert len(frames) == 16, f"Frame count mismatch: {len(frames)}"

        # Test temporal smoothness
        smoothness = compute_temporal_smoothness(frames)
        print(f"Temporal smoothness: {smoothness}")

        assert smoothness['mean_diff'] < 100, "Frames too different (smoothness check)"
        assert smoothness['num_frames'] == 16, "Frame count mismatch in smoothness"

        print("TEST PASSED!")
        return True
    except Exception as e:
        print(f"FAILED: {e}")
        traceback.print_exc()
        return False
    finally:
        # Cleanup; test_video stays None when creation itself failed.
        if test_video is not None and os.path.exists(test_video):
            os.remove(test_video)


def run_all_predict_tests():
    """Run all Predict2.5 smoke tests and print a summary.

    The inference test only runs when model loading succeeded; otherwise it is
    recorded as failed and a skip notice is printed.

    Returns:
        bool: True only if every test passed.
    """
    _print_banner("COSMOS PREDICT2.5 SMOKE TESTS")

    results = {}

    # Test 1: Output validation (no model needed)
    results['output_validation'] = test_predict_output_validation()

    # Test 2: Model loading (requires GPU)
    results['model_loading'] = test_predict_model_loading()

    # Test 3: Inference (requires GPU + loaded model)
    if results['model_loading']:
        results['text2world_inference'] = test_predict_text2world_inference()
    else:
        results['text2world_inference'] = False
        print("\nSkipping inference test due to model loading failure")

    # Summary
    _print_banner("PREDICT2.5 TEST SUMMARY")

    passed = sum(1 for v in results.values() if v)
    total = len(results)

    for test_name, passed_flag in results.items():
        status = "PASSED" if passed_flag else "FAILED"
        print(f"  {test_name}: {status}")

    print(f"\nTotal: {passed}/{total} tests passed")

    return passed == total


if __name__ == "__main__":
    success = run_all_predict_tests()
    sys.exit(0 if success else 1)