""" Paper Consistency Validation for Cosmos Models Reference: arXiv 2511.00062 - World Simulation with Video Foundation Models for Physical AI This module implements minimal reproducible tests to verify consistency with paper claims. """ import sys import os import json import time from pathlib import Path from typing import List, Dict, Any from datetime import datetime # Add parent directory to path sys.path.insert(0, str(Path(__file__).parent.parent)) # Paper Reference Points for Validation PAPER_CLAIMS = { "predict_unified_generation": { "section": "Section 3.1 - Unified World Generation", "claim": "Cosmos-Predict2.5 unifies Text2World, Image2World, and Video2World into a single model", "validation": "Verify model can handle all three input modalities" }, "predict_temporal_consistency": { "section": "Section 4.2 - Temporal Coherence", "claim": "Generated videos maintain reasonable spatiotemporal continuity in short-term prediction", "validation": "Measure frame-to-frame pixel difference to verify smooth transitions" }, "predict_reproducibility": { "section": "Section 5 - Reproducibility", "claim": "Fixed random seeds produce deterministic outputs", "validation": "Run same prompt with same seed twice, verify identical outputs" }, "transfer_structure_preservation": { "section": "Section 3.2 - World-to-World Translation", "claim": "Cosmos-Transfer2.5 preserves structural consistency during domain transfer", "validation": "Compare edge maps between input and output to verify structure preservation" }, "transfer_domain_adaptation": { "section": "Section 4.3 - Domain Transfer", "claim": "Model can perform world-to-world translation (e.g., day->night, clear->rain)", "validation": "Verify output differs from input while maintaining structure" } } def validate_predict_temporal_consistency( num_samples: int = 3, seed_base: int = 42 ) -> Dict[str, Any]: """ Validate temporal consistency of Predict2.5 outputs Paper claim (Section 4.2): Generated videos maintain spatiotemporal continuity Validation method: - Generate N videos with different seeds - Compute mean frame-to-frame pixel difference - Check that differences are smooth (not explosive) """ print("\n" + "=" * 60) print("VALIDATION: Predict2.5 Temporal Consistency") print("Paper Reference: Section 4.2 - Temporal Coherence") print("=" * 60) from cosmos.utils_video import compute_temporal_smoothness, load_video_frames from cosmos.infer_predict import predict_text2world import tempfile results = { "test_name": "predict_temporal_consistency", "paper_section": "Section 4.2", "paper_claim": "Generated videos maintain reasonable spatiotemporal continuity", "num_samples": num_samples, "samples": [], "overall_pass": False, "timestamp": datetime.now().isoformat() } prompt = "A ball rolling slowly across a flat surface, simple motion, smooth" for i in range(num_samples): seed = seed_base + i output_path = tempfile.mktemp(suffix=".mp4") print(f"\nSample {i+1}/{num_samples} (seed={seed})") try: result = predict_text2world( prompt=prompt, num_frames=25, # ~1.5s at 16fps height=480, width=720, num_inference_steps=15, # Faster for validation seed=seed, output_path=output_path ) # Load frames and compute smoothness frames = load_video_frames(output_path) smoothness = compute_temporal_smoothness(frames) sample_result = { "seed": seed, "num_frames": result['num_frames'], "smoothness": smoothness, "video_path": output_path, "pass": smoothness['mean_diff'] < 50 # Threshold for "smooth" } print(f" Mean frame diff: {smoothness['mean_diff']:.2f}") print(f" Max frame diff: {smoothness['max_diff']:.2f}") print(f" Pass: {sample_result['pass']}") results['samples'].append(sample_result) except Exception as e: print(f" ERROR: {e}") results['samples'].append({ "seed": seed, "error": str(e), "pass": False }) # Overall assessment passed = sum(1 for s in results['samples'] if s.get('pass', False)) results['passed_samples'] = passed results['overall_pass'] = passed >= num_samples // 2 + 1 # Majority pass print(f"\nOverall: {passed}/{num_samples} samples passed") print(f"Validation {'PASSED' if results['overall_pass'] else 'FAILED'}") return results def validate_predict_reproducibility(seed: int = 42) -> Dict[str, Any]: """ Validate reproducibility of Predict2.5 outputs Paper claim (Section 5): Fixed seeds produce deterministic outputs Validation method: - Run same prompt with same seed twice - Compare output frame by frame - Verify outputs are identical (or near-identical) """ print("\n" + "=" * 60) print("VALIDATION: Predict2.5 Reproducibility") print("Paper Reference: Section 5 - Reproducibility") print("=" * 60) from cosmos.utils_video import load_video_frames, compute_ssim from cosmos.infer_predict import predict_text2world import tempfile import numpy as np results = { "test_name": "predict_reproducibility", "paper_section": "Section 5", "paper_claim": "Fixed random seeds produce deterministic outputs", "seed": seed, "overall_pass": False, "timestamp": datetime.now().isoformat() } prompt = "A simple test scene with a single red sphere" # Run 1 print("\nRun 1...") output1 = tempfile.mktemp(suffix="_run1.mp4") result1 = predict_text2world( prompt=prompt, num_frames=17, height=480, width=720, num_inference_steps=10, seed=seed, output_path=output1 ) # Run 2 (same parameters) print("\nRun 2...") output2 = tempfile.mktemp(suffix="_run2.mp4") result2 = predict_text2world( prompt=prompt, num_frames=17, height=480, width=720, num_inference_steps=10, seed=seed, output_path=output2 ) # Compare outputs print("\nComparing outputs...") frames1 = load_video_frames(output1) frames2 = load_video_frames(output2) if len(frames1) != len(frames2): print(f" Frame count mismatch: {len(frames1)} vs {len(frames2)}") results['error'] = "Frame count mismatch" return results # Compute SSIM for each frame pair ssim_scores = [] for i, (f1, f2) in enumerate(zip(frames1, frames2)): ssim = compute_ssim(f1, f2) ssim_scores.append(ssim) mean_ssim = np.mean(ssim_scores) min_ssim = np.min(ssim_scores) results['mean_ssim'] = float(mean_ssim) results['min_ssim'] = float(min_ssim) results['overall_pass'] = mean_ssim > 0.95 # Very high similarity expected print(f" Mean SSIM: {mean_ssim:.4f}") print(f" Min SSIM: {min_ssim:.4f}") print(f" Validation {'PASSED' if results['overall_pass'] else 'FAILED'}") return results def validate_transfer_structure_preservation( control_type: str = "edge" ) -> Dict[str, Any]: """ Validate structure preservation in Transfer2.5 outputs Paper claim (Section 3.2): Transfer preserves structural consistency Validation method: - Create test video with known structure - Apply style transfer - Compare edge maps of input and output - Verify edges are preserved (high correlation) """ print("\n" + "=" * 60) print("VALIDATION: Transfer2.5 Structure Preservation") print("Paper Reference: Section 3.2 - World-to-World Translation") print("=" * 60) from cosmos.utils_video import ( create_test_video, load_video_frames, extract_edges, compute_ssim ) from cosmos.infer_transfer import transfer_video import tempfile import numpy as np results = { "test_name": "transfer_structure_preservation", "paper_section": "Section 3.2", "paper_claim": "Cosmos-Transfer2.5 preserves structural consistency during domain transfer", "control_type": control_type, "overall_pass": False, "timestamp": datetime.now().isoformat() } # Create test video with clear structure test_video = create_test_video(num_frames=17, width=320, height=240) output_path = tempfile.mktemp(suffix="_transfer.mp4") try: print("\nRunning transfer...") result = transfer_video( input_video=test_video, prompt="Transform to nighttime scene with city lights", control_type=control_type, num_inference_steps=10, seed=42, output_path=output_path ) # Compare edge maps print("\nComparing structure (edge maps)...") input_frames = load_video_frames(test_video) output_frames = load_video_frames(output_path) edge_similarities = [] for i, (inp, out) in enumerate(zip(input_frames[:5], output_frames[:5])): inp_edges = extract_edges(inp) out_edges = extract_edges(out) ssim = compute_ssim(inp_edges, out_edges) edge_similarities.append(ssim) mean_edge_ssim = np.mean(edge_similarities) results['mean_edge_ssim'] = float(mean_edge_ssim) results['edge_similarities'] = [float(s) for s in edge_similarities] # Structure is preserved if edge SSIM > 0.3 (some similarity expected) results['overall_pass'] = mean_edge_ssim > 0.3 print(f" Mean edge SSIM: {mean_edge_ssim:.4f}") print(f" Validation {'PASSED' if results['overall_pass'] else 'FAILED'}") except Exception as e: print(f" ERROR: {e}") results['error'] = str(e) finally: # Cleanup if os.path.exists(test_video): os.remove(test_video) return results def validate_transfer_domain_change() -> Dict[str, Any]: """ Validate that Transfer2.5 actually changes the domain Paper claim (Section 4.3): Model performs world-to-world translation Validation method: - Apply day->night transfer - Verify output differs from input (different domain) - While still maintaining some structure """ print("\n" + "=" * 60) print("VALIDATION: Transfer2.5 Domain Change") print("Paper Reference: Section 4.3 - Domain Transfer") print("=" * 60) from cosmos.utils_video import ( create_test_video, load_video_frames, compute_ssim ) from cosmos.infer_transfer import transfer_style import tempfile import numpy as np results = { "test_name": "transfer_domain_change", "paper_section": "Section 4.3", "paper_claim": "Model can perform world-to-world translation (e.g., day->night)", "overall_pass": False, "timestamp": datetime.now().isoformat() } test_video = create_test_video(num_frames=17, width=320, height=240) output_path = tempfile.mktemp(suffix="_domain.mp4") try: print("\nApplying day -> night transfer...") result = transfer_style( input_video=test_video, source_style="daytime", target_style="nighttime with city lights", control_type="blur", num_inference_steps=10, seed=42, output_path=output_path ) # Compare input and output input_frames = load_video_frames(test_video) output_frames = load_video_frames(output_path) # Compute pixel-level similarity (should be different but not completely) similarities = [] for inp, out in zip(input_frames[:5], output_frames[:5]): ssim = compute_ssim(inp, out) similarities.append(ssim) mean_ssim = np.mean(similarities) results['mean_ssim'] = float(mean_ssim) # Domain change successful if: # - Output is different from input (SSIM < 0.9) # - But not completely random (SSIM > 0.1) results['overall_pass'] = 0.1 < mean_ssim < 0.9 print(f" Mean SSIM (input vs output): {mean_ssim:.4f}") print(f" Domain change detected: {results['overall_pass']}") print(f" Validation {'PASSED' if results['overall_pass'] else 'FAILED'}") except Exception as e: print(f" ERROR: {e}") results['error'] = str(e) finally: if os.path.exists(test_video): os.remove(test_video) return results def run_all_validations(skip_transfer: bool = False) -> Dict[str, Any]: """ Run all paper consistency validations Args: skip_transfer: Skip Transfer2.5 tests if VRAM is limited Returns: Combined validation results """ print("\n" + "=" * 70) print("PAPER CONSISTENCY VALIDATION") print("Reference: arXiv 2511.00062") print("=" * 70) all_results = { "paper": "arXiv 2511.00062 - World Simulation with Video Foundation Models for Physical AI", "timestamp": datetime.now().isoformat(), "validations": {} } # Predict2.5 Validations print("\n[PREDICT2.5 VALIDATIONS]") all_results['validations']['predict_temporal'] = validate_predict_temporal_consistency( num_samples=2 # Reduced for speed ) all_results['validations']['predict_reproducibility'] = validate_predict_reproducibility() # Transfer2.5 Validations if not skip_transfer: print("\n[TRANSFER2.5 VALIDATIONS]") # Clear Predict model first from cosmos.loaders import clear_model_cache clear_model_cache() all_results['validations']['transfer_structure'] = validate_transfer_structure_preservation() all_results['validations']['transfer_domain'] = validate_transfer_domain_change() else: print("\n[TRANSFER2.5 VALIDATIONS - SKIPPED]") all_results['validations']['transfer_skipped'] = True # Summary print("\n" + "=" * 70) print("VALIDATION SUMMARY") print("=" * 70) passed_count = 0 total_count = 0 for name, result in all_results['validations'].items(): if isinstance(result, dict) and 'overall_pass' in result: total_count += 1 if result['overall_pass']: passed_count += 1 status = "PASSED" if result['overall_pass'] else "FAILED" print(f" {name}: {status}") all_results['summary'] = { "passed": passed_count, "total": total_count, "overall_pass": passed_count == total_count } print(f"\nOverall: {passed_count}/{total_count} validations passed") return all_results def save_validation_report(results: Dict[str, Any], output_path: str): """Save validation results to JSON file""" with open(output_path, 'w') as f: json.dump(results, f, indent=2) print(f"\nValidation report saved to: {output_path}") if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="Run paper consistency validations") parser.add_argument("--skip-transfer", action="store_true", help="Skip Transfer2.5 tests") parser.add_argument("--output", default="validation_results.json", help="Output file") args = parser.parse_args() results = run_all_validations(skip_transfer=args.skip_transfer) save_validation_report(results, args.output)