wbw2000 committed on
Commit
ef68d21
·
verified ·
1 Parent(s): 585c8d6

Upload tests/paper_validation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. tests/paper_validation.py +477 -0
tests/paper_validation.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Paper Consistency Validation for Cosmos Models
3
+ Reference: arXiv 2511.00062 - World Simulation with Video Foundation Models for Physical AI
4
+
5
+ This module implements minimal reproducible tests to verify consistency with paper claims.
6
+ """
7
+ import sys
8
+ import os
9
+ import json
10
+ import time
11
+ from pathlib import Path
12
+ from typing import List, Dict, Any
13
+ from datetime import datetime
14
+
15
+ # Add parent directory to path
16
+ sys.path.insert(0, str(Path(__file__).parent.parent))
17
+
18
+
19
+ # Paper Reference Points for Validation
20
# Each entry is keyed by a short claim identifier and records where the claim
# appears in the paper ("section"), what it states ("claim"), and how it is
# meant to be checked ("validation").
PAPER_CLAIMS = dict(
    predict_unified_generation=dict(
        section="Section 3.1 - Unified World Generation",
        claim="Cosmos-Predict2.5 unifies Text2World, Image2World, and Video2World into a single model",
        validation="Verify model can handle all three input modalities",
    ),
    predict_temporal_consistency=dict(
        section="Section 4.2 - Temporal Coherence",
        claim="Generated videos maintain reasonable spatiotemporal continuity in short-term prediction",
        validation="Measure frame-to-frame pixel difference to verify smooth transitions",
    ),
    predict_reproducibility=dict(
        section="Section 5 - Reproducibility",
        claim="Fixed random seeds produce deterministic outputs",
        validation="Run same prompt with same seed twice, verify identical outputs",
    ),
    transfer_structure_preservation=dict(
        section="Section 3.2 - World-to-World Translation",
        claim="Cosmos-Transfer2.5 preserves structural consistency during domain transfer",
        validation="Compare edge maps between input and output to verify structure preservation",
    ),
    transfer_domain_adaptation=dict(
        section="Section 4.3 - Domain Transfer",
        claim="Model can perform world-to-world translation (e.g., day->night, clear->rain)",
        validation="Verify output differs from input while maintaining structure",
    ),
)
47
+
48
+
49
def validate_predict_temporal_consistency(
    num_samples: int = 3,
    seed_base: int = 42,
    max_mean_diff: float = 50.0,
) -> Dict[str, Any]:
    """
    Validate temporal consistency of Predict2.5 outputs

    Paper claim (Section 4.2): Generated videos maintain spatiotemporal continuity

    Validation method:
    - Generate N videos with different seeds
    - Compute mean frame-to-frame pixel difference
    - Check that differences are smooth (not explosive)

    Args:
        num_samples: Number of videos to generate; sample ``i`` uses seed
            ``seed_base + i``.
        seed_base: Seed of the first sample.
        max_mean_diff: A sample passes when its mean frame-to-frame
            difference is strictly below this value.

    Returns:
        Result dict with one entry per sample under "samples", the number of
        passing samples under "passed_samples", and "overall_pass", which is
        True when a strict majority of the samples passed. Errors raised
        while generating or analysing a sample are recorded on that sample
        (with "pass" set to False) instead of being propagated.
    """
    print("\n" + "=" * 60)
    print("VALIDATION: Predict2.5 Temporal Consistency")
    print("Paper Reference: Section 4.2 - Temporal Coherence")
    print("=" * 60)

    from cosmos.utils_video import compute_temporal_smoothness, load_video_frames
    from cosmos.infer_predict import predict_text2world
    import tempfile

    results = {
        "test_name": "predict_temporal_consistency",
        "paper_section": "Section 4.2",
        "paper_claim": "Generated videos maintain reasonable spatiotemporal continuity",
        "num_samples": num_samples,
        "samples": [],
        "overall_pass": False,
        "timestamp": datetime.now().isoformat()
    }

    prompt = "A ball rolling slowly across a flat surface, simple motion, smooth"

    # A private directory replaces the deprecated, race-prone tempfile.mktemp().
    # The clips are left on disk on purpose: each sample reports its
    # "video_path" so the video can be inspected after the run.
    output_dir = tempfile.mkdtemp(prefix="cosmos_temporal_")

    for i in range(num_samples):
        seed = seed_base + i
        output_path = os.path.join(output_dir, f"sample_seed{seed}.mp4")

        print(f"\nSample {i+1}/{num_samples} (seed={seed})")

        try:
            result = predict_text2world(
                prompt=prompt,
                num_frames=25,  # ~1.5s at 16fps
                height=480,
                width=720,
                num_inference_steps=15,  # Faster for validation
                seed=seed,
                output_path=output_path
            )

            # Load frames and compute smoothness
            frames = load_video_frames(output_path)
            smoothness = compute_temporal_smoothness(frames)

            sample_result = {
                "seed": seed,
                "num_frames": result['num_frames'],
                "smoothness": smoothness,
                "video_path": output_path,
                # bool(): a numpy comparison yields numpy.bool_, which
                # json.dump cannot serialise.
                "pass": bool(smoothness['mean_diff'] < max_mean_diff)
            }

            print(f" Mean frame diff: {smoothness['mean_diff']:.2f}")
            print(f" Max frame diff: {smoothness['max_diff']:.2f}")
            print(f" Pass: {sample_result['pass']}")

            results['samples'].append(sample_result)

        except Exception as e:
            # One failing sample must not abort the remaining samples.
            print(f" ERROR: {e}")
            results['samples'].append({
                "seed": seed,
                "error": str(e),
                "pass": False
            })

    # Overall assessment
    passed = sum(1 for s in results['samples'] if s.get('pass', False))
    results['passed_samples'] = passed
    results['overall_pass'] = passed >= num_samples // 2 + 1  # Majority pass

    print(f"\nOverall: {passed}/{num_samples} samples passed")
    print(f"Validation {'PASSED' if results['overall_pass'] else 'FAILED'}")

    return results
136
+
137
+
138
def validate_predict_reproducibility(seed: int = 42) -> Dict[str, Any]:
    """
    Validate reproducibility of Predict2.5 outputs

    Paper claim (Section 5): Fixed seeds produce deterministic outputs

    Validation method:
    - Run same prompt with same seed twice
    - Compare output frame by frame
    - Verify outputs are identical (or near-identical)

    Args:
        seed: Random seed used for both generation runs.

    Returns:
        Result dict. On a completed comparison it contains "mean_ssim",
        "min_ssim" and "overall_pass" (True when the mean SSIM exceeds 0.95).
        If generation or comparison fails, "overall_pass" stays False and the
        reason is stored under "error" instead of an exception being raised.
    """
    print("\n" + "=" * 60)
    print("VALIDATION: Predict2.5 Reproducibility")
    print("Paper Reference: Section 5 - Reproducibility")
    print("=" * 60)

    from cosmos.utils_video import load_video_frames, compute_ssim
    from cosmos.infer_predict import predict_text2world
    import tempfile
    import numpy as np

    results = {
        "test_name": "predict_reproducibility",
        "paper_section": "Section 5",
        "paper_claim": "Fixed random seeds produce deterministic outputs",
        "seed": seed,
        "overall_pass": False,
        "timestamp": datetime.now().isoformat()
    }

    prompt = "A simple test scene with a single red sphere"

    # Both runs must use exactly the same settings; only the output file differs.
    generation_kwargs = dict(
        prompt=prompt,
        num_frames=17,
        height=480,
        width=720,
        num_inference_steps=10,
        seed=seed,
    )

    try:
        # The two clips are only needed for the comparison, so they live in a
        # temporary directory that is removed when the block exits. This also
        # replaces the deprecated, race-prone tempfile.mktemp().
        with tempfile.TemporaryDirectory(prefix="cosmos_repro_") as workdir:
            clip_paths = []
            for run in (1, 2):
                print(f"\nRun {run}...")
                clip_path = os.path.join(workdir, f"run{run}.mp4")
                predict_text2world(output_path=clip_path, **generation_kwargs)
                clip_paths.append(clip_path)

            # Compare outputs
            print("\nComparing outputs...")
            frames1 = load_video_frames(clip_paths[0])
            frames2 = load_video_frames(clip_paths[1])

            if len(frames1) != len(frames2):
                print(f" Frame count mismatch: {len(frames1)} vs {len(frames2)}")
                results['error'] = "Frame count mismatch"
                return results

            if len(frames1) == 0:
                # np.min([]) raises and np.mean([]) is nan; report it clearly.
                print(" No frames decoded from generated videos")
                results['error'] = "No frames decoded"
                return results

            # Compute SSIM for each frame pair
            ssim_scores = [compute_ssim(f1, f2) for f1, f2 in zip(frames1, frames2)]

        mean_ssim = float(np.mean(ssim_scores))
        min_ssim = float(np.min(ssim_scores))

        results['mean_ssim'] = mean_ssim
        results['min_ssim'] = min_ssim
        # Plain Python floats are compared so the flag is a built-in bool;
        # numpy.bool_ is not JSON serialisable.
        results['overall_pass'] = mean_ssim > 0.95  # Very high similarity expected

        print(f" Mean SSIM: {mean_ssim:.4f}")
        print(f" Min SSIM: {min_ssim:.4f}")
        print(f" Validation {'PASSED' if results['overall_pass'] else 'FAILED'}")

    except Exception as e:
        # Same error-recording convention as the other validators, so one
        # failure does not abort the whole validation run.
        print(f" ERROR: {e}")
        results['error'] = str(e)

    return results
224
+
225
+
226
def validate_transfer_structure_preservation(
    control_type: str = "edge"
) -> Dict[str, Any]:
    """
    Validate structure preservation in Transfer2.5 outputs

    Paper claim (Section 3.2): Transfer preserves structural consistency

    Validation method:
    - Create test video with known structure
    - Apply style transfer
    - Compare edge maps of input and output
    - Verify edges are preserved (high correlation)

    Args:
        control_type: Control signal name forwarded to ``transfer_video``.

    Returns:
        Result dict. On a completed comparison it contains "mean_edge_ssim",
        the per-frame "edge_similarities" and "overall_pass" (True when the
        mean edge SSIM exceeds 0.3). On any failure "overall_pass" stays
        False and the message is stored under "error".
    """
    print("\n" + "=" * 60)
    print("VALIDATION: Transfer2.5 Structure Preservation")
    print("Paper Reference: Section 3.2 - World-to-World Translation")
    print("=" * 60)

    from cosmos.utils_video import (
        create_test_video, load_video_frames, extract_edges, compute_ssim
    )
    from cosmos.infer_transfer import transfer_video
    import tempfile
    import numpy as np

    results = {
        "test_name": "transfer_structure_preservation",
        "paper_section": "Section 3.2",
        "paper_claim": "Cosmos-Transfer2.5 preserves structural consistency during domain transfer",
        "control_type": control_type,
        "overall_pass": False,
        "timestamp": datetime.now().isoformat()
    }

    max_compared_frames = 5  # only the leading frames are compared
    min_edge_ssim = 0.3      # some similarity expected

    test_video = None
    try:
        # Create test video with clear structure. This happens inside the
        # try block so a failure is recorded like any other error.
        test_video = create_test_video(num_frames=17, width=320, height=240)

        # The transferred clip is only needed for the comparison; a temporary
        # directory removes it afterwards and avoids tempfile.mktemp().
        with tempfile.TemporaryDirectory(prefix="cosmos_transfer_") as workdir:
            output_path = os.path.join(workdir, "transfer.mp4")

            print("\nRunning transfer...")
            transfer_video(
                input_video=test_video,
                prompt="Transform to nighttime scene with city lights",
                control_type=control_type,
                num_inference_steps=10,
                seed=42,
                output_path=output_path
            )

            # Compare edge maps
            print("\nComparing structure (edge maps)...")
            input_frames = load_video_frames(test_video)
            output_frames = load_video_frames(output_path)

            edge_similarities = [
                float(compute_ssim(extract_edges(inp), extract_edges(out)))
                for inp, out in zip(
                    input_frames[:max_compared_frames],
                    output_frames[:max_compared_frames],
                )
            ]

        if not edge_similarities:
            # Without this check np.mean([]) would silently produce nan.
            raise ValueError("No frames available for comparison")

        mean_edge_ssim = float(np.mean(edge_similarities))
        results['mean_edge_ssim'] = mean_edge_ssim
        results['edge_similarities'] = edge_similarities

        # Structure is preserved if edge SSIM > 0.3 (some similarity expected).
        # Comparing Python floats keeps the flag a JSON-serialisable bool.
        results['overall_pass'] = mean_edge_ssim > min_edge_ssim

        print(f" Mean edge SSIM: {mean_edge_ssim:.4f}")
        print(f" Validation {'PASSED' if results['overall_pass'] else 'FAILED'}")

    except Exception as e:
        print(f" ERROR: {e}")
        results['error'] = str(e)

    finally:
        # Cleanup
        if test_video and os.path.exists(test_video):
            os.remove(test_video)

    return results
308
+
309
+
310
def validate_transfer_domain_change() -> Dict[str, Any]:
    """
    Validate that Transfer2.5 actually changes the domain

    Paper claim (Section 4.3): Model performs world-to-world translation

    Validation method:
    - Apply day->night transfer
    - Verify output differs from input (different domain)
    - While still maintaining some structure

    Returns:
        Result dict. On a completed comparison it contains "mean_ssim" and
        "overall_pass" (True when the mean input/output SSIM lies strictly
        between 0.1 and 0.9). On any failure "overall_pass" stays False and
        the message is stored under "error".
    """
    print("\n" + "=" * 60)
    print("VALIDATION: Transfer2.5 Domain Change")
    print("Paper Reference: Section 4.3 - Domain Transfer")
    print("=" * 60)

    from cosmos.utils_video import (
        create_test_video, load_video_frames, compute_ssim
    )
    from cosmos.infer_transfer import transfer_style
    import tempfile
    import numpy as np

    results = {
        "test_name": "transfer_domain_change",
        "paper_section": "Section 4.3",
        "paper_claim": "Model can perform world-to-world translation (e.g., day->night)",
        "overall_pass": False,
        "timestamp": datetime.now().isoformat()
    }

    max_compared_frames = 5  # only the leading frames are compared

    test_video = None
    try:
        # Created inside the try block so a failure is recorded, not raised.
        test_video = create_test_video(num_frames=17, width=320, height=240)

        # The transferred clip is only needed for the comparison; a temporary
        # directory removes it afterwards and avoids tempfile.mktemp().
        with tempfile.TemporaryDirectory(prefix="cosmos_domain_") as workdir:
            output_path = os.path.join(workdir, "domain.mp4")

            print("\nApplying day -> night transfer...")
            transfer_style(
                input_video=test_video,
                source_style="daytime",
                target_style="nighttime with city lights",
                control_type="blur",
                num_inference_steps=10,
                seed=42,
                output_path=output_path
            )

            # Compare input and output
            input_frames = load_video_frames(test_video)
            output_frames = load_video_frames(output_path)

            # Compute pixel-level similarity (should be different but not completely)
            similarities = [
                compute_ssim(inp, out)
                for inp, out in zip(
                    input_frames[:max_compared_frames],
                    output_frames[:max_compared_frames],
                )
            ]

        if not similarities:
            # Without this check np.mean([]) would silently produce nan.
            raise ValueError("No frames available for comparison")

        mean_ssim = float(np.mean(similarities))
        results['mean_ssim'] = mean_ssim

        # Domain change successful if:
        # - Output is different from input (SSIM < 0.9)
        # - But not completely random (SSIM > 0.1)
        # Comparing a Python float keeps the flag a JSON-serialisable bool.
        results['overall_pass'] = 0.1 < mean_ssim < 0.9

        print(f" Mean SSIM (input vs output): {mean_ssim:.4f}")
        print(f" Domain change detected: {results['overall_pass']}")
        print(f" Validation {'PASSED' if results['overall_pass'] else 'FAILED'}")

    except Exception as e:
        print(f" ERROR: {e}")
        results['error'] = str(e)

    finally:
        if test_video and os.path.exists(test_video):
            os.remove(test_video)

    return results
387
+
388
+
389
def _run_validation_safely(name: str, validator, **kwargs) -> Dict[str, Any]:
    """Run one validator, turning an unexpected exception into a failed result."""
    try:
        return validator(**kwargs)
    except Exception as e:
        # Covers errors raised outside the validator's own error handling,
        # e.g. its lazy project imports.
        print(f" ERROR: {name} aborted: {e}")
        return {
            "test_name": name,
            "overall_pass": False,
            "error": str(e),
            "timestamp": datetime.now().isoformat()
        }


def run_all_validations(skip_transfer: bool = False) -> Dict[str, Any]:
    """
    Run all paper consistency validations

    A validator that raises is recorded as a failed validation (with an
    "error" message) so the remaining validators still run and the summary
    is always produced.

    Args:
        skip_transfer: Skip Transfer2.5 tests if VRAM is limited

    Returns:
        Combined validation results: the per-validator results under
        "validations" and the pass counts under "summary".
    """
    print("\n" + "=" * 70)
    print("PAPER CONSISTENCY VALIDATION")
    print("Reference: arXiv 2511.00062")
    print("=" * 70)

    all_results = {
        "paper": "arXiv 2511.00062 - World Simulation with Video Foundation Models for Physical AI",
        "timestamp": datetime.now().isoformat(),
        "validations": {}
    }
    validations = all_results['validations']

    # Predict2.5 Validations
    print("\n[PREDICT2.5 VALIDATIONS]")

    validations['predict_temporal'] = _run_validation_safely(
        'predict_temporal',
        validate_predict_temporal_consistency,
        num_samples=2  # Reduced for speed
    )

    validations['predict_reproducibility'] = _run_validation_safely(
        'predict_reproducibility',
        validate_predict_reproducibility
    )

    # Transfer2.5 Validations
    if not skip_transfer:
        print("\n[TRANSFER2.5 VALIDATIONS]")

        # Clear Predict model first. A failure here is reported but does not
        # abort the run; the transfer validators record their own errors.
        try:
            from cosmos.loaders import clear_model_cache
            clear_model_cache()
        except Exception as e:
            print(f" WARNING: could not clear model cache: {e}")

        validations['transfer_structure'] = _run_validation_safely(
            'transfer_structure',
            validate_transfer_structure_preservation
        )
        validations['transfer_domain'] = _run_validation_safely(
            'transfer_domain',
            validate_transfer_domain_change
        )
    else:
        print("\n[TRANSFER2.5 VALIDATIONS - SKIPPED]")
        validations['transfer_skipped'] = True

    # Summary
    print("\n" + "=" * 70)
    print("VALIDATION SUMMARY")
    print("=" * 70)

    passed_count = 0
    total_count = 0

    for name, result in validations.items():
        # Non-dict markers such as 'transfer_skipped' are not validations.
        if isinstance(result, dict) and 'overall_pass' in result:
            total_count += 1
            if result['overall_pass']:
                passed_count += 1
            status = "PASSED" if result['overall_pass'] else "FAILED"
            print(f" {name}: {status}")

    all_results['summary'] = {
        "passed": passed_count,
        "total": total_count,
        "overall_pass": passed_count == total_count
    }

    print(f"\nOverall: {passed_count}/{total_count} validations passed")

    return all_results
458
+
459
+
460
def _json_default(value: Any) -> Any:
    """Convert values the json module cannot encode natively."""
    # numpy scalars and arrays expose tolist(), which returns built-in
    # Python scalars or (nested) lists.
    to_list = getattr(value, "tolist", None)
    if callable(to_list):
        return to_list()
    if isinstance(value, os.PathLike):
        return os.fspath(value)
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def save_validation_report(results: Dict[str, Any], output_path: str):
    """
    Save validation results to JSON file

    Validation results can contain numpy scalars (for example the outcome of
    comparing a numpy float with a threshold), which ``json.dump`` rejects by
    default; they are converted to built-in Python values here.

    Args:
        results: Validation results to serialise.
        output_path: Destination file; missing parent directories are created.

    Raises:
        TypeError: If ``results`` contains a value that cannot be converted.
    """
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, default=_json_default)
    print(f"\nValidation report saved to: {output_path}")
465
+
466
+
467
if __name__ == "__main__":
    import argparse

    # Command-line entry point: run the validations, then write the JSON report.
    cli = argparse.ArgumentParser(description="Run paper consistency validations")
    cli.add_argument(
        "--skip-transfer",
        action="store_true",
        help="Skip Transfer2.5 tests",
    )
    cli.add_argument(
        "--output",
        default="validation_results.json",
        help="Output file",
    )
    options = cli.parse_args()

    report = run_all_validations(skip_transfer=options.skip_transfer)
    save_validation_report(report, options.output)