"""
Video I/O utilities: load frames + audio from a video file, save frames back
to video with audio, and trim audio to match video duration.
"""

import os
import subprocess
import tempfile
from pathlib import Path

import numpy as np
from PIL import Image


# ---------------------------------------------------------------------------
# Loading
# ---------------------------------------------------------------------------


def load_video_frames(
    path: str,
    fps: float = 24.0,
    max_frames: int | None = None,
) -> tuple[np.ndarray, float]:
    """
    Decode video frames to a uint8 RGB numpy array [N, H, W, 3].

    Uses decord when it is importable; otherwise falls back to OpenCV.

    Args:
        path: Video file to decode.
        fps: Fallback frame rate, returned only when the OpenCV backend cannot
            read a rate from the container. Frames are never resampled.
        max_frames: If given, decode at most this many frames from the start.

    Returns:
        (frames, actual_fps)

    Raises:
        ValueError: if ``max_frames`` < 1, the file cannot be opened (OpenCV
            backend), or no frames could be decoded.
    """
    if max_frames is not None and max_frames < 1:
        raise ValueError(f"max_frames must be >= 1 or None, got {max_frames}")

    # Only the import is guarded: decode errors inside decord must surface,
    # not silently trigger the OpenCV fallback.
    try:
        import decord
    except ImportError:
        return _load_frames_opencv(path, fps, max_frames)
    return _load_frames_decord(decord, path, max_frames)


def _load_frames_decord(decord, path: str, max_frames: int | None) -> tuple[np.ndarray, float]:
    """Decode the leading frames of ``path`` with an already-imported decord."""
    decord.bridge.set_bridge("native")
    vr = decord.VideoReader(path, ctx=decord.cpu(0))
    actual_fps = float(vr.get_avg_fps())
    total = len(vr) if max_frames is None else min(len(vr), max_frames)
    if total == 0:
        raise ValueError(f"no frames decoded from {path!r}")
    frames = vr.get_batch(list(range(total))).asnumpy()  # [N, H, W, 3]
    return frames, actual_fps


def _load_frames_opencv(
    path: str, fallback_fps: float, max_frames: int | None
) -> tuple[np.ndarray, float]:
    """Decode the leading frames of ``path`` with OpenCV, converting BGR to RGB."""
    import cv2

    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise ValueError(f"could not open video {path!r}")
    frames = []
    try:
        # CAP_PROP_FPS reads as 0 when unknown; use the caller's fallback then.
        actual_fps = float(cap.get(cv2.CAP_PROP_FPS) or fallback_fps)
        # Check the limit before reading so we never decode one frame too many.
        while max_frames is None or len(frames) < max_frames:
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        cap.release()
    if not frames:
        raise ValueError(f"no frames decoded from {path!r}")
    return np.stack(frames, axis=0), actual_fps


def extract_audio(video_path: str, output_path: str) -> bool:
    """
    Extract the audio track of ``video_path`` to ``output_path`` as WAV.

    The WAV is 16-bit little-endian PCM, 44100 Hz, 2 channels.

    Returns:
        True after a successful extraction. False (nothing written) when
        ffprobe reports no audio stream. If ffprobe itself fails, its stdout
        is empty, so that case also returns False.

    Raises:
        subprocess.CalledProcessError: if the ffmpeg extraction fails.
    """
    probe = subprocess.run(
        [
            "ffprobe", "-v", "quiet",
            "-select_streams", "a",
            "-show_entries", "stream=codec_type",
            "-of", "csv=p=0",
            video_path,
        ],
        capture_output=True,
        text=True,
    )
    if "audio" not in probe.stdout:
        return False

    subprocess.run(
        [
            "ffmpeg", "-y",
            "-i", video_path,
            "-vn",
            "-acodec", "pcm_s16le",
            "-ar", "44100",
            "-ac", "2",
            output_path,
        ],
        capture_output=True,
        check=True,
    )
    return True


# ---------------------------------------------------------------------------
# Saving
# ---------------------------------------------------------------------------


def save_video(
    frames: np.ndarray,
    fps: float,
    output_path: str,
    audio_path: str | None = None,
    audio_duration: float | None = None,
    crf: int = 19,
) -> str:
    """
    Encode frames [N, H, W, 3] uint8 RGB to an H.264 (libx264, yuv420p) mp4.

    If ``audio_path`` exists, it is muxed in as AAC at 192 kbps. The output
    is cut to ``audio_duration`` seconds when that is truthy, and to the
    shorter stream (``-shortest``). If ``audio_path`` is given but missing,
    the video is written without audio.

    NOTE(review): yuv420p output generally needs even H and W. The 32-aligned
    sizes from compute_target_size satisfy this.

    Returns:
        ``output_path``.

    Raises:
        ValueError: if ``frames`` is empty, not [N, H, W, 3], or not uint8.
        subprocess.CalledProcessError: if ffmpeg encoding or muxing fails.
    """
    frames = np.asarray(frames)
    if frames.ndim != 4 or frames.shape[-1] != 3:
        raise ValueError(f"expected frames of shape [N, H, W, 3], got {frames.shape}")
    if frames.dtype != np.uint8:
        raise ValueError(f"expected uint8 frames, got dtype {frames.dtype}")
    if len(frames) == 0:
        raise ValueError("no frames to encode")

    tmp_video = output_path + ".noaudio.mp4"
    try:
        _encode_frames_to_mp4(frames, fps, crf, tmp_video)
        if audio_path and os.path.exists(audio_path):
            duration_flag = ["-t", str(audio_duration)] if audio_duration else []
            subprocess.run(
                [
                    "ffmpeg", "-y",
                    "-i", tmp_video,
                    "-i", audio_path,
                    *duration_flag,
                    "-c:v", "copy",
                    "-c:a", "aac", "-b:a", "192k",
                    "-shortest",
                    output_path,
                ],
                capture_output=True,
                check=True,
            )
        else:
            # os.replace (unlike os.rename) also overwrites an existing
            # output on Windows.
            os.replace(tmp_video, output_path)
    finally:
        # Never leave the intermediate file behind, even when ffmpeg fails.
        if os.path.exists(tmp_video):
            os.remove(tmp_video)
    return output_path


def _encode_frames_to_mp4(frames: np.ndarray, fps: float, crf: int, dst: str) -> None:
    """Pipe raw RGB frames [N, H, W, 3] through ffmpeg/libx264 into ``dst``."""
    _, height, width, _ = frames.shape
    cmd = [
        "ffmpeg", "-y",
        "-f", "rawvideo",
        "-vcodec", "rawvideo",
        "-s", f"{width}x{height}",
        "-pix_fmt", "rgb24",
        "-r", str(fps),
        "-i", "pipe:0",
        "-vcodec", "libx264",
        "-pix_fmt", "yuv420p",
        "-crf", str(crf),
        "-preset", "fast",
        dst,
    ]
    # stderr goes to a temp file rather than a pipe: ffmpeg logs continuously,
    # and an unread pipe could fill up and deadlock us while we write stdin.
    with tempfile.TemporaryFile() as err_log:
        proc = subprocess.Popen(
            cmd, stdin=subprocess.PIPE, stdout=subprocess.DEVNULL, stderr=err_log
        )
        try:
            for frame in frames:
                proc.stdin.write(frame.tobytes())
        except BrokenPipeError:
            pass  # ffmpeg exited early; its exit code is checked below.
        finally:
            try:
                proc.stdin.close()
            except BrokenPipeError:
                pass
        returncode = proc.wait()
        if returncode != 0:
            err_log.seek(0)
            raise subprocess.CalledProcessError(returncode, cmd, stderr=err_log.read())


# ---------------------------------------------------------------------------
# Resolution helpers
# ---------------------------------------------------------------------------


def align_to(value: int, multiple: int = 32) -> int:
    """
    Round ``value`` up to the nearest multiple of ``multiple``.

    Raises:
        ValueError: if ``multiple`` is not positive.
    """
    if multiple <= 0:
        raise ValueError(f"multiple must be positive, got {multiple}")
    return ((value + multiple - 1) // multiple) * multiple


def compute_target_size(
    orig_w: int,
    orig_h: int,
    base_resolution: int = 768,
    multiple: int = 32,
) -> tuple[int, int]:
    """
    Scale the longer edge to base_resolution, preserving aspect ratio, then
    round both dimensions up to a multiple of ``multiple``.

    Each returned dimension is at least ``multiple``, so extreme aspect
    ratios never yield a zero-sized edge.

    Returns:
        (new_w, new_h)

    Raises:
        ValueError: if either input dimension is not positive.
    """
    if orig_w <= 0 or orig_h <= 0:
        raise ValueError(f"frame size must be positive, got {orig_w}x{orig_h}")
    scale = base_resolution / max(orig_w, orig_h)
    # max(1, ...) keeps a very thin edge from truncating to 0, since
    # align_to(0) == 0 would produce an unusable size.
    new_w = align_to(max(1, int(orig_w * scale)), multiple)
    new_h = align_to(max(1, int(orig_h * scale)), multiple)
    return new_w, new_h


def resize_frames(frames: np.ndarray, target_w: int, target_h: int) -> np.ndarray:
    """
    Resize [N, H, W, 3] frames to target_w x target_h with Lanczos resampling.

    Returns the input array itself, not a copy, when it already has the
    target size. Otherwise returns a new uint8 array.
    """
    if frames.shape[1:3] == (target_h, target_w):
        return frames
    out = np.empty((len(frames), target_h, target_w, 3), dtype=np.uint8)
    for i, frame in enumerate(frames):
        out[i] = np.array(Image.fromarray(frame).resize((target_w, target_h), Image.LANCZOS))
    return out


def frames_for_duration(fps: float, duration: float) -> int:
    """
    Return the frame count for ``duration`` seconds at ``fps``, aligned to the
    8k + 1 form LTX-2.3 requires:

        (floor(duration * fps) // 8) * 8 + 1

    Zero or negative durations give the minimum of 1 frame.
    """
    # The small epsilon absorbs float error such as
    # 4.35 * 100 == 434.99999999999994, which plain int() would truncate
    # one frame short.
    raw = max(0, int(duration * fps + 1e-6))
    return (raw // 8) * 8 + 1