"""
Cosmos Predict2.5 inference.

Unified Text2World, Image2World and Video2World generation helpers built on the
pipeline returned by ``load_predict_pipeline()``.
"""

import time
from pathlib import Path
from typing import Any, List, Optional, Union

import torch
from PIL import Image

from .loaders import load_predict_pipeline, get_device_info
from .utils_video import save_video, load_video_frames

# Frame rate used when writing output videos (previously hard-coded to 16).
_DEFAULT_FPS = 16

# Number of prompt characters echoed in the console banner.
_PROMPT_PREVIEW_CHARS = 100


def _make_generator(seed: int) -> torch.Generator:
    """Return a ``torch.Generator`` seeded with *seed*.

    The generator is placed on CUDA when it is available, so a given seed
    produces the same random stream as before this helper existed.
    Otherwise it falls back to CPU instead of raising on CUDA-less hosts.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.Generator(device=device).manual_seed(seed)


def _prompt_preview(prompt: str) -> str:
    """Return *prompt* shortened for logging; append "..." only if truncated."""
    if len(prompt) <= _PROMPT_PREVIEW_CHARS:
        return prompt
    return f"{prompt[:_PROMPT_PREVIEW_CHARS]}..."


def _save_frames(
    frames: Any,
    output_path: Optional[Union[str, Path]],
    default_name: str,
    fps: int,
) -> str:
    """Write *frames* via ``save_video`` and return the path used, as ``str``.

    *default_name* is used when *output_path* is ``None``.
    """
    path = default_name if output_path is None else str(output_path)
    save_video(frames, path, fps=fps)
    return path


def predict_text2world(
    prompt: str,
    negative_prompt: str = "",
    num_frames: int = 49,
    height: int = 480,
    width: int = 720,
    num_inference_steps: int = 30,
    guidance_scale: float = 7.0,
    seed: int = 42,
    output_path: Optional[str] = None,
    fps: int = _DEFAULT_FPS,
) -> dict:
    """Generate a video from a text prompt (Text2World).

    Args:
        prompt: Text description of the world to generate.
        negative_prompt: What to avoid in generation.
        num_frames: Number of output frames (default 49, ~3s at 16 fps).
        height: Output height (480 or 720).
        width: Output width (720 or 1280).
        num_inference_steps: Diffusion steps (more = better quality).
        guidance_scale: CFG scale (higher = more prompt adherence).
        seed: Random seed for reproducibility.
        output_path: Where to save the video. Defaults to
            ``output_predict_text2world_seed{seed}.mp4``.
        fps: Frame rate of the written video (default 16).

    Returns:
        dict with keys ``video_path``, ``num_frames``, ``resolution``,
        ``inference_time_s``, ``seed``, ``prompt`` and ``device``.
        ``inference_time_s`` is measured from the start of this call until the
        pipeline returns. It therefore includes pipeline loading, but not
        video saving.
    """
    start_time = time.perf_counter()

    pipe = load_predict_pipeline()
    # Seeded generator so the same seed reproduces the same video.
    generator = _make_generator(seed)
    device_info = get_device_info()

    print("\n=== Predict2.5 Text2World Inference ===")
    print(f"Prompt: {_prompt_preview(prompt)}")
    print(f"Resolution: {width}x{height}, Frames: {num_frames}")
    print(f"Steps: {num_inference_steps}, CFG: {guidance_scale}, Seed: {seed}")
    print(f"Device: {device_info['name']}")

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_frames=num_frames,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )
    # The pipeline output is batched; take the first (only) video.
    frames = output.frames[0]
    inference_time = time.perf_counter() - start_time

    output_path = _save_frames(
        frames, output_path, f"output_predict_text2world_seed{seed}.mp4", fps
    )

    result = {
        "video_path": output_path,
        "num_frames": len(frames),
        "resolution": f"{width}x{height}",
        "inference_time_s": round(inference_time, 2),
        "seed": seed,
        "prompt": prompt,
        "device": device_info["name"],
    }

    print(f"Video saved: {output_path} ({len(frames)} frames in {inference_time:.2f}s)")
    return result


def predict_image2world(
    image: Union[str, Path, Image.Image],
    prompt: str,
    negative_prompt: str = "",
    num_frames: int = 49,
    num_inference_steps: int = 30,
    guidance_scale: float = 7.0,
    seed: int = 42,
    output_path: Optional[str] = None,
    fps: int = _DEFAULT_FPS,
) -> dict:
    """Generate a video from an image plus a prompt (Image2World).

    The input image is converted to RGB. Its width and height are rounded
    down to multiples of 8 and it is resized to that size before being passed
    to the pipeline. The pipeline is not given an explicit height or width.

    Args:
        image: Input image, as a file path (``str`` or ``Path``) or a PIL Image.
        prompt: Text description of the world evolution.
        negative_prompt: What to avoid.
        num_frames: Number of output frames.
        num_inference_steps: Diffusion steps.
        guidance_scale: CFG scale.
        seed: Random seed.
        output_path: Where to save the video. Defaults to
            ``output_predict_image2world_seed{seed}.mp4``.
        fps: Frame rate of the written video (default 16).

    Returns:
        dict with keys ``video_path``, ``num_frames``, ``resolution`` (the
        snapped input-image size), ``inference_time_s``, ``seed``, ``prompt``
        and ``device``. ``inference_time_s`` covers image loading, pipeline
        loading and inference, but not video saving.

    Raises:
        ValueError: if either image side is smaller than 8 pixels.
    """
    start_time = time.perf_counter()

    if isinstance(image, (str, Path)):
        # Context manager closes the file handle once the pixels are decoded.
        with Image.open(image) as img:
            image = img.convert("RGB")
    elif image.mode != "RGB":
        # Normalise PIL inputs the same way as path inputs (e.g. drop alpha).
        image = image.convert("RGB")

    # Snap dimensions down to multiples of 8.
    orig_width, orig_height = image.size
    width = (orig_width // 8) * 8
    height = (orig_height // 8) * 8
    if width == 0 or height == 0:
        raise ValueError(
            f"Input image is too small ({orig_width}x{orig_height}); "
            "both sides must be at least 8 pixels"
        )
    if (width, height) != image.size:
        image = image.resize((width, height))

    pipe = load_predict_pipeline()
    generator = _make_generator(seed)
    device_info = get_device_info()

    print("\n=== Predict2.5 Image2World Inference ===")
    print(f"Input image: {width}x{height}")
    print(f"Prompt: {_prompt_preview(prompt)}")
    print(f"Frames: {num_frames}, Steps: {num_inference_steps}, Seed: {seed}")

    output = pipe(
        image=image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )
    frames = output.frames[0]
    inference_time = time.perf_counter() - start_time

    output_path = _save_frames(
        frames, output_path, f"output_predict_image2world_seed{seed}.mp4", fps
    )

    result = {
        "video_path": output_path,
        "num_frames": len(frames),
        "resolution": f"{width}x{height}",
        "inference_time_s": round(inference_time, 2),
        "seed": seed,
        "prompt": prompt,
        "device": device_info["name"],
    }

    print(f"Video saved: {output_path} ({len(frames)} frames in {inference_time:.2f}s)")
    return result


def predict_video2world(
    video: Union[str, Path, List[Image.Image]],
    prompt: str,
    negative_prompt: str = "",
    num_frames: int = 49,
    num_inference_steps: int = 30,
    guidance_scale: float = 7.0,
    seed: int = 42,
    output_path: Optional[str] = None,
    fps: int = _DEFAULT_FPS,
) -> dict:
    """Generate an extended video from an input video plus a prompt (Video2World).

    Args:
        video: Input video file path (``str`` or ``Path``, loaded with
            ``load_video_frames``) or an already-loaded list of frames.
        prompt: Text description of world evolution.
        negative_prompt: What to avoid.
        num_frames: Number of output frames.
        num_inference_steps: Diffusion steps.
        guidance_scale: CFG scale.
        seed: Random seed.
        output_path: Where to save the video. Defaults to
            ``output_predict_video2world_seed{seed}.mp4``.
        fps: Frame rate of the written video (default 16).

    Returns:
        dict with keys ``video_path``, ``input_frames``, ``output_frames``,
        ``inference_time_s``, ``seed``, ``prompt`` and ``device``.
        ``inference_time_s`` covers frame loading, pipeline loading and
        inference, but not video saving.

    Raises:
        ValueError: if the input video contains no frames.
    """
    start_time = time.perf_counter()

    if isinstance(video, (str, Path)):
        input_frames = load_video_frames(str(video))
    else:
        input_frames = video
    # len() rather than truthiness: frames may be an array-like whose truth
    # value is ambiguous.
    if len(input_frames) == 0:
        raise ValueError("Video2World requires at least one input frame")

    pipe = load_predict_pipeline()
    generator = _make_generator(seed)
    device_info = get_device_info()

    print("\n=== Predict2.5 Video2World Inference ===")
    print(f"Input frames: {len(input_frames)}")
    print(f"Prompt: {_prompt_preview(prompt)}")
    print(f"Output frames: {num_frames}, Steps: {num_inference_steps}, Seed: {seed}")

    output = pipe(
        video=input_frames,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )
    frames = output.frames[0]
    inference_time = time.perf_counter() - start_time

    output_path = _save_frames(
        frames, output_path, f"output_predict_video2world_seed{seed}.mp4", fps
    )

    result = {
        "video_path": output_path,
        "input_frames": len(input_frames),
        "output_frames": len(frames),
        "inference_time_s": round(inference_time, 2),
        "seed": seed,
        "prompt": prompt,
        "device": device_info["name"],
    }

    print(f"Video saved: {output_path} ({len(frames)} frames in {inference_time:.2f}s)")
    return result