"""
WAN 2.2 Multi-Task Video Generation - Bilingual UI

I2V: Lightning 14B (6 steps, FP8+AoT)
T2V: Lightning 14B (4 steps, Lightning LoRA fused into BF16 weights, CPU offload)
V2V: ControlNet Depth (preserves motion/pose from source video)
LoRA: from lkzd7/WAN2.2_LoraSet_NSFW (I2V only)
"""
import os
import spaces
import shutil
import subprocess
import copy
import random
import tempfile
import warnings
import time
import gc
import uuid
from tqdm import tqdm

import cv2
import numpy as np
import torch
from torch.nn import functional as F
from PIL import Image

import gradio as gr
from diffusers import (
    AutoencoderKLWan,
    FlowMatchEulerDiscreteScheduler,
    WanPipeline,
    SASolverScheduler,
    DEISMultistepScheduler,
    DPMSolverMultistepInverseScheduler,
    UniPCMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
)
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.utils.export_utils import export_to_video
from diffusers.utils import load_video
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
import aoti
import lora_loader

# V2V ControlNet imports
from wan_controlnet import WanControlnet
from wan_transformer import CustomWanTransformer3DModel
from wan_t2v_controlnet_pipeline import WanTextToVideoControlnetPipeline
from controlnet_aux import MidasDetector

os.environ["TOKENIZERS_PARALLELISM"] = "true"
warnings.filterwarnings("ignore")


def clear_vram():
    """Run the Python garbage collector and release cached CUDA allocator blocks."""
    gc.collect()
    torch.cuda.empty_cache()


# ============ RIFE ============

# Browser-side snippet: returns the current playback position (seconds) of the
# generated-video player, or 0 if the player is not present.
get_timestamp_js = """
function() {
    const video = document.querySelector('#generated-video video');
    if (video) { return video.currentTime; }
    return 0;
}
"""


def extract_frame(video_path, timestamp):
    """Return the RGB frame of ``video_path`` nearest to ``timestamp`` seconds.

    Args:
        video_path: Path to a video file readable by OpenCV.
        timestamp: Playback position in seconds (None is treated as 0).

    Returns:
        An (H, W, 3) uint8 RGB numpy array, or None if the video cannot be
        opened or the frame cannot be read.
    """
    if not video_path:
        return None
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        target_frame_num = int(float(timestamp or 0) * fps)
        if total_frames > 0:
            target_frame_num = min(target_frame_num, total_frames - 1)
        # Containers that report no frame count / fps would otherwise yield a
        # negative seek position.
        target_frame_num = max(0, target_frame_num)
        cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame_num)
        ret, frame = cap.read()
    finally:
        cap.release()
    if ret:
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return None


_RIFE_ZIP = "RIFEv4.26_0921.zip"
if not os.path.exists(_RIFE_ZIP):
    print("Downloading RIFE Model...")
    subprocess.run(["wget", "-q", "https://huggingface.co/r3gm/RIFE/resolve/main/RIFEv4.26_0921.zip", "-O", _RIFE_ZIP], check=True)
    subprocess.run(["unzip", "-o", _RIFE_ZIP], check=True)
elif not os.path.isdir("train_log"):
    # Archive is cached but was never extracted (e.g. an interrupted earlier
    # start); without this, the import below fails.
    subprocess.run(["unzip", "-o", _RIFE_ZIP], check=True)

from train_log.RIFE_HDv3 import Model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rife_model = Model()
rife_model.load_model("train_log", -1)
rife_model.eval()


@torch.no_grad()
def interpolate_bits(frames_np, multiplier=2, scale=1.0):
    """Insert RIFE-interpolated frames between each pair of consecutive frames.

    Args:
        frames_np: A list of (H, W, C) arrays or one (T, H, W, C) array.
            Presumably float in [0, 1] as produced by ``output_type="np"``
            (TODO confirm for other callers).
        multiplier: Output frames per input interval; ``multiplier - 1``
            frames are synthesized between each pair. Values < 2 return the
            input unchanged (an ndarray input is converted to a list).
        scale: RIFE inference scale; also sets the padding granularity.

    Returns:
        A list of (H, W, C) float32 arrays: every original frame, with the
        interpolated frames placed between neighbours.
    """
    if len(frames_np) == 0:
        return []
    if isinstance(frames_np, list):
        T = len(frames_np)
        H, W, C = frames_np[0].shape
    else:
        T, H, W, C = frames_np.shape

    if multiplier < 2:
        return list(frames_np) if isinstance(frames_np, np.ndarray) else frames_np

    n_interp = multiplier - 1

    # Pad H/W up to a multiple of `tmp` so the flow network's downsampling
    # stages divide evenly; the padding is cropped off again in from_tensor.
    tmp = max(128, int(128 / scale))
    ph = ((H - 1) // tmp + 1) * tmp
    pw = ((W - 1) // tmp + 1) * tmp
    padding = (0, pw - W, 0, ph - H)

    def to_tensor(frame_np):
        t = torch.from_numpy(frame_np).to(device)
        t = t.permute(2, 0, 1).unsqueeze(0)
        return F.pad(t, padding).half()

    def from_tensor(tensor):
        t = tensor[0, :, :H, :W]
        return t.permute(1, 2, 0).float().cpu().numpy()

    def make_inference(I0, I1, n):
        # RIFE >= 3.9 accepts an arbitrary timestep; older models only predict
        # the midpoint, so those are subdivided recursively.
        if rife_model.version >= 3.9:
            return [rife_model.inference(I0, I1, (i+1) * 1. / (n+1), scale) for i in range(n)]
        else:
            middle = rife_model.inference(I0, I1, scale)
            if n == 1:
                return [middle]
            first_half = make_inference(I0, middle, n//2)
            second_half = make_inference(middle, I1, n//2)
            return [*first_half, middle, *second_half] if n % 2 else [*first_half, *second_half]

    output_frames = []
    # Bound up front so the cleanup `del` below also works when T == 1.
    I0 = None
    I1 = to_tensor(frames_np[0])
    with tqdm(total=T - 1, desc="Interpolating", unit="frame") as pbar:
        for i in range(T - 1):
            I0 = I1
            output_frames.append(from_tensor(I0))
            I1 = to_tensor(frames_np[i + 1])
            for mid in make_inference(I0, I1, n_interp):
                output_frames.append(from_tensor(mid))
            if (i + 1) % 50 == 0:
                pbar.update(50)
        pbar.update((T - 1) % 50)

    output_frames.append(from_tensor(I1))
    del I0, I1
    torch.cuda.empty_cache()
    return output_frames


# ============ Config ============
FIXED_FPS = 16
MAX_FRAMES_MODEL = 241  # ~15s@16fps, requires more VRAM/time
MAX_SEED = np.iinfo(np.int32).max

SCHEDULER_MAP = {
    "FlowMatchEulerDiscrete": FlowMatchEulerDiscreteScheduler,
    "SASolver": SASolverScheduler,
    "DEISMultistep": DEISMultistepScheduler,
    "DPMSolverMultistepInverse": DPMSolverMultistepInverseScheduler,
    "UniPCMultistep": UniPCMultistepScheduler,
    "DPMSolverMultistep": DPMSolverMultistepScheduler,
    "DPMSolverSinglestep": DPMSolverSinglestepScheduler,
}

default_negative_prompt = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, "
    "still image, overall gray, worst quality, low quality, JPEG artifacts, ugly, incomplete, "
    "extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, "
    "malformed limbs, fused fingers, still frame, messy background, three legs, "
    "many people in background, walking backwards, watermark, text, signature"
)

# ============ Load I2V Pipeline (Lightning, AoT compiled) ============
print("Loading I2V Pipeline (Lightning 14B)...")
i2v_pipe = WanImageToVideoPipeline.from_pretrained(
    "TestOrganizationPleaseIgnore/WAMU_v2_WAN2.2_I2V_LIGHTNING",
    torch_dtype=torch.bfloat16,
).to('cuda')
# Pristine copy of the shipped scheduler; user-selected schedulers are always
# rebuilt from this config rather than from whatever is currently installed.
i2v_original_scheduler = copy.deepcopy(i2v_pipe.scheduler)

quantize_(i2v_pipe.text_encoder, Int8WeightOnlyConfig())

# FP8 quantization + AoT blocks are used only on compute capability >= 8.9;
# other GPUs fall back to int8 weight-only quantization.
major, minor = torch.cuda.get_device_capability()
supports_fp8 = (major > 8) or (major == 8 and minor >= 9)
if supports_fp8:
    quantize_(i2v_pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
    quantize_(i2v_pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
    aoti.aoti_blocks_load(i2v_pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
    aoti.aoti_blocks_load(i2v_pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
else:
    quantize_(i2v_pipe.transformer, Int8WeightOnlyConfig())
    quantize_(i2v_pipe.transformer_2, Int8WeightOnlyConfig())

# ============ T2V Pipeline (on-demand, 14B + Wan22 Lightning LoRA) ============
# Use T2V-A14B + Wan22 Lightning LoRA (separate HIGH/LOW for dual transformer)
# Load on-demand with CPU offload to avoid OOM alongside I2V
T2V_MODEL_ID = "Wan-AI/Wan2.2-T2V-A14B-Diffusers"
T2V_LORA_REPO = "Kijai/WanVideo_comfy"
T2V_LORA_HIGH = "LoRAs/Wan22-Lightning/Wan22_A14B_T2V_HIGH_Lightning_4steps_lora_250928_rank128_fp16.safetensors"
T2V_LORA_LOW = "LoRAs/Wan22-Lightning/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors"

t2v_pipe = None
t2v_ready = False


def load_t2v_pipeline():
    """Return the T2V 14B pipeline, building it on first use.

    The I2V pipeline is moved to CPU on EVERY call (not only the first), so
    the offloaded T2V models never have to share the GPU with I2V. On first
    use the HIGH Lightning LoRA is fused into ``transformer`` and the LOW one
    into ``transformer_2``, then model CPU offload is enabled. The pipeline
    is cached in the module globals for reuse.
    """
    global t2v_pipe, t2v_ready

    # Move I2V components to CPU to make room
    i2v_pipe.to('cpu')
    clear_vram()

    if t2v_pipe is not None and t2v_ready:
        print("T2V pipeline reused from memory")
        return t2v_pipe

    print("Loading T2V Pipeline (14B + Lightning LoRA) first time...")

    t2v_vae = AutoencoderKLWan.from_pretrained(T2V_MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
    t2v_pipe = WanPipeline.from_pretrained(
        T2V_MODEL_ID,
        transformer=WanTransformer3DModel.from_pretrained(
            'linoyts/Wan2.2-T2V-A14B-Diffusers-BF16',
            subfolder='transformer',
            torch_dtype=torch.bfloat16,
        ),
        transformer_2=WanTransformer3DModel.from_pretrained(
            'linoyts/Wan2.2-T2V-A14B-Diffusers-BF16',
            subfolder='transformer_2',
            torch_dtype=torch.bfloat16,
        ),
        vae=t2v_vae,
        torch_dtype=torch.bfloat16,
    )

    # Load and fuse Lightning LoRAs (HIGH for transformer, LOW for transformer_2)
    print("Fusing Lightning LoRA HIGH (transformer)...")
    from huggingface_hub import hf_hub_download

    # Download LoRA files
    high_path = hf_hub_download(T2V_LORA_REPO, T2V_LORA_HIGH)
    low_path = hf_hub_download(T2V_LORA_REPO, T2V_LORA_LOW)

    # Load HIGH LoRA into transformer
    t2v_pipe.load_lora_weights(high_path, adapter_name="lightning_high")
    t2v_pipe.set_adapters(["lightning_high"], adapter_weights=[1.0])
    t2v_pipe.fuse_lora(adapter_names=["lightning_high"], lora_scale=1.0, components=["transformer"])
    t2v_pipe.unload_lora_weights()

    # Load LOW LoRA into transformer_2
    print("Fusing Lightning LoRA LOW (transformer_2)...")
    t2v_pipe.load_lora_weights(low_path, adapter_name="lightning_low", load_into_transformer_2=True)
    t2v_pipe.set_adapters(["lightning_low"], adapter_weights=[1.0])
    t2v_pipe.fuse_lora(adapter_names=["lightning_low"], lora_scale=1.0, components=["transformer_2"])
    t2v_pipe.unload_lora_weights()

    # Use model CPU offload — only one component on GPU at a time
    t2v_pipe.enable_model_cpu_offload()

    t2v_ready = True
    print("T2V pipeline ready (14B + Lightning + CPU offload)")
    return t2v_pipe


def unload_t2v_pipeline():
    """Restore I2V to GPU after T2V is done (the T2V pipeline stays cached)."""
    clear_vram()
    i2v_pipe.to('cuda')
    print("I2V restored to GPU")


# ============ V2V Pipeline (ControlNet Depth, on-demand) ============
V2V_BASE_MODEL = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
V2V_CONTROLNET_MODEL = "TheDenk/wan2.2-ti2v-5b-controlnet-depth-v1"
v2v_pipe = None
v2v_ready = False
depth_processor = None


def load_v2v_pipeline():
    """Return the V2V ControlNet pipeline (TI2V-5B + Depth ControlNet), building it on first use.

    I2V is moved to CPU first. The pipeline uses model CPU offload, so on reuse
    it is NOT moved to the GPU wholesale; the offload hooks move each component
    to the GPU as it runs. The MiDaS depth processor is also created here if
    needed.
    """
    global v2v_pipe, v2v_ready, depth_processor

    # Move I2V to CPU
    i2v_pipe.to('cpu')
    clear_vram()

    if v2v_pipe is not None and v2v_ready:
        print("V2V ControlNet pipeline reused from memory")
        return v2v_pipe

    print("Loading V2V ControlNet Pipeline (5B + Depth) first time...")
    v2v_vae = AutoencoderKLWan.from_pretrained(V2V_BASE_MODEL, subfolder="vae", torch_dtype=torch.float32)
    v2v_transformer = CustomWanTransformer3DModel.from_pretrained(V2V_BASE_MODEL, subfolder="transformer", torch_dtype=torch.bfloat16)
    v2v_controlnet = WanControlnet.from_pretrained(V2V_CONTROLNET_MODEL, torch_dtype=torch.bfloat16)

    v2v_pipe = WanTextToVideoControlnetPipeline.from_pretrained(
        pretrained_model_name_or_path=V2V_BASE_MODEL,
        controlnet=v2v_controlnet,
        transformer=v2v_transformer,
        vae=v2v_vae,
        torch_dtype=torch.bfloat16,
    )
    v2v_pipe.scheduler = UniPCMultistepScheduler.from_config(v2v_pipe.scheduler.config, flow_shift=5.0)
    v2v_pipe.enable_model_cpu_offload()

    # Load depth processor
    if depth_processor is None:
        print("Loading MidasDetector for depth estimation...")
        depth_processor = MidasDetector.from_pretrained('lllyasviel/Annotators')

    v2v_ready = True
    print("V2V ControlNet pipeline ready (5B + Depth)")
    return v2v_pipe


def unload_v2v_pipeline():
    """Park the V2V pipeline (if any) on CPU and restore I2V to the GPU."""
    if v2v_pipe is not None:
        v2v_pipe.to('cpu')
    clear_vram()
    i2v_pipe.to('cuda')
    print("V2V → CPU, I2V → GPU")


def extract_depth_frames(video_path, num_frames, target_h, target_w):
    """Load up to ``num_frames`` frames from ``video_path`` and compute MiDaS depth maps.

    Source frames are resized to (target_w, target_h) before estimation. Each
    PIL depth map is resized to that same size as well, in case the detector
    returns a different resolution.

    Returns:
        A list of depth maps, one per loaded frame (possibly fewer than
        ``num_frames`` if the video is shorter).
    """
    global depth_processor
    if depth_processor is None:
        depth_processor = MidasDetector.from_pretrained('lllyasviel/Annotators')

    frames = load_video(video_path)[:num_frames]
    frames = [f.resize((target_w, target_h)) for f in frames]
    print(f"Extracting depth for {len(frames)} frames ({target_w}x{target_h})...")

    depth_frames = []
    for i, frame in enumerate(frames):
        depth = depth_processor(frame)
        if isinstance(depth, Image.Image) and depth.size != (target_w, target_h):
            depth = depth.resize((target_w, target_h))
        depth_frames.append(depth)
        if (i + 1) % 10 == 0:
            print(f" Depth: {i+1}/{len(frames)}")

    print(f"Depth extraction done: {len(depth_frames)} frames")
    return depth_frames


# ============ Utils ============

def resize_image(image, max_dim=832, min_dim=480, square_dim=640, multiple_of=16):
    """Center-crop/resize a PIL image to a model-friendly size.

    Square inputs become ``square_dim`` x ``square_dim``. Aspect ratios
    outside [min_dim/max_dim, max_dim/min_dim] are center-cropped to that
    bound. Otherwise the long side is set to ``max_dim``, and both sides are
    snapped to ``multiple_of`` and clamped to [min_dim, max_dim].
    """
    width, height = image.size
    if width == height:
        return image.resize((square_dim, square_dim), Image.LANCZOS)

    aspect_ratio = width / height
    max_ar = max_dim / min_dim
    min_ar = min_dim / max_dim

    if aspect_ratio > max_ar:
        crop_width = int(round(height * max_ar))
        left = (width - crop_width) // 2
        image = image.crop((left, 0, left + crop_width, height))
        target_w, target_h = max_dim, min_dim
    elif aspect_ratio < min_ar:
        crop_height = int(round(width / min_ar))
        top = (height - crop_height) // 2
        image = image.crop((0, top, width, top + crop_height))
        target_w, target_h = min_dim, max_dim
    else:
        if width > height:
            target_w = max_dim
            target_h = int(round(target_w / aspect_ratio))
        else:
            target_h = max_dim
            target_w = int(round(target_h * aspect_ratio))

    final_w = max(min_dim, min(max_dim, round(target_w / multiple_of) * multiple_of))
    final_h = max(min_dim, min(max_dim, round(target_h / multiple_of) * multiple_of))
    return image.resize((final_w, final_h), Image.LANCZOS)


def resize_and_crop_to_match(target_image, reference_image):
    """Scale ``target_image`` to cover ``reference_image``'s size, then center-crop to it exactly."""
    ref_w, ref_h = reference_image.size
    tgt_w, tgt_h = target_image.size
    scale = max(ref_w / tgt_w, ref_h / tgt_h)
    new_w, new_h = int(tgt_w * scale), int(tgt_h * scale)
    resized = target_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
    left, top = (new_w - ref_w) // 2, (new_h - ref_h) // 2
    return resized.crop((left, top, left + ref_w, top + ref_h))


def get_num_frames(duration_seconds):
    """Convert a duration into a frame count of the form 4k+1, clipped to [9, MAX_FRAMES_MODEL]."""
    raw = int(round(duration_seconds * FIXED_FPS))
    raw = ((raw - 1) // 4) * 4 + 1
    return int(np.clip(raw, 9, MAX_FRAMES_MODEL))


def extract_video_path(input_video):
    """Best-effort extraction of a filesystem path from a Gradio video value."""
    if input_video is None:
        return None
    if isinstance(input_video, str):
        return input_video
    if isinstance(input_video, dict):
        # Gradio 5.x format: {'video': filepath, ...} or {'name': filepath, ...} or {'path': filepath}
        return input_video.get("video", input_video.get("path", input_video.get("name", None)))
    # Could be a Gradio VideoData object
    if hasattr(input_video, 'video'):
        return input_video.video
    if hasattr(input_video, 'path'):
        return input_video.path
    if hasattr(input_video, 'name'):
        return input_video.name
    return str(input_video)


def extract_first_frame(video_input):
    """Return the first frame of a Gradio video value as a PIL RGB image, or None."""
    path = extract_video_path(video_input)
    if not path or not os.path.exists(path):
        return None
    cap = cv2.VideoCapture(path)
    ret, frame = cap.read()
    cap.release()
    if ret:
        return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    return None


# ============ Inference ============

def _run_t2v(prompt, negative_prompt, num_frames, steps, guidance_scale, guidance_scale_2, seed):
    """Run T2V 14B + Lightning; I2V is restored to the GPU even if generation fails."""
    t2v_steps = max(int(steps), 4)
    print(f"T2V: steps={t2v_steps}, guidance={guidance_scale}/{guidance_scale_2}, frames={num_frames}")
    try:
        pipe = load_t2v_pipeline()
        return pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=480,
            width=832,
            num_frames=num_frames,
            guidance_scale=float(guidance_scale),
            guidance_scale_2=float(guidance_scale_2),
            num_inference_steps=t2v_steps,
            generator=torch.Generator(device="cpu").manual_seed(int(seed)),
            output_type="np",
        )
    finally:
        unload_t2v_pipeline()


def _run_v2v(input_video, prompt, negative_prompt, num_frames, steps, guidance_scale, seed):
    """Run depth-ControlNet V2V; I2V is restored to the GPU even if generation fails."""
    # V2V uses ControlNet Depth pipeline — completely separate from I2V
    print(f"V2V: input_video type={type(input_video)}, value={input_video}")
    video_path = extract_video_path(input_video)
    if not video_path or not os.path.exists(video_path):
        raise gr.Error("Upload a video for V2V / V2V请上传视频")

    # Extract depth maps from source video
    target_h, target_w = 480, 832
    depth_frames = extract_depth_frames(video_path, num_frames, target_h, target_w)
    if not depth_frames:
        raise gr.Error("Failed to extract depth from video / 无法提取视频深度图")

    # A short source video can yield a count that is not 4k+1 (the form
    # get_num_frames enforces); round down and keep the control frames in step.
    v2v_num_frames = ((min(num_frames, len(depth_frames)) - 1) // 4) * 4 + 1
    depth_frames = depth_frames[:v2v_num_frames]

    v2v_steps = max(int(steps), 20)  # ControlNet needs more steps (no Lightning)
    try:
        pipe = load_v2v_pipeline()
        print(f"V2V ControlNet: steps={v2v_steps}, guidance={guidance_scale}, frames={num_frames}, depth_frames={len(depth_frames)}")
        return pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=target_h,
            width=target_w,
            num_frames=v2v_num_frames,
            guidance_scale=max(float(guidance_scale), 5.0),
            num_inference_steps=v2v_steps,
            generator=torch.Generator(device="cpu").manual_seed(int(seed)),
            output_type="np",
            controlnet_frames=depth_frames,
            controlnet_weight=0.8,
            controlnet_guidance_start=0.0,
            controlnet_guidance_end=0.8,
        )
    finally:
        unload_v2v_pipeline()


def _apply_i2v_scheduler(scheduler_name, flow_shift):
    """Install ``scheduler_name`` with ``flow_shift`` on the I2V pipeline.

    The scheduler is rebuilt from ``i2v_original_scheduler``'s config whenever
    EITHER the class OR its shift value differs from the installed one, so
    changing only the Flow Shift slider takes effect. Unknown names leave the
    current scheduler untouched.
    """
    scheduler_class = SCHEDULER_MAP.get(scheduler_name)
    if scheduler_class is None:
        return
    # FlowMatchEulerDiscreteScheduler names the parameter `shift`; the others use `flow_shift`.
    shift_key = 'shift' if scheduler_class == FlowMatchEulerDiscreteScheduler else 'flow_shift'
    current_config = i2v_pipe.scheduler.config
    if scheduler_class.__name__ == current_config._class_name and current_config.get(shift_key) == flow_shift:
        return
    config = dict(copy.deepcopy(i2v_original_scheduler.config))
    config[shift_key] = flow_shift
    i2v_pipe.scheduler = scheduler_class.from_config(config)


def _run_i2v(input_image, last_image_input, prompt, negative_prompt, num_frames, steps,
             guidance_scale, guidance_scale_2, seed, scheduler_name, flow_shift, lora_groups):
    """Run Lightning I2V; any LoRAs loaded for this call are unloaded even if generation fails."""
    if input_image is None:
        raise gr.Error("Upload an image / 请上传图片")

    _apply_i2v_scheduler(scheduler_name, flow_shift)

    lora_loaded = False
    try:
        if lora_groups:
            # Best-effort: a failing LoRA is reported and skipped, not fatal.
            try:
                for idx, name in enumerate(lora_groups):
                    if name and name != "(None)":
                        lora_loader.load_lora_to_pipe(i2v_pipe, name, adapter_name=f"lora_{idx}")
                        lora_loaded = True
            except Exception as e:
                print(f"LoRA warning: {e}")

        resized_image = resize_image(input_image)
        processed_last = None
        if last_image_input:
            processed_last = resize_and_crop_to_match(last_image_input, resized_image)

        print(f"I2V: size={resized_image.size}, steps={int(steps)}, guidance={guidance_scale}/{guidance_scale_2}")
        return i2v_pipe(
            image=resized_image,
            last_image=processed_last,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=resized_image.height,
            width=resized_image.width,
            num_frames=num_frames,
            guidance_scale=float(guidance_scale),
            guidance_scale_2=float(guidance_scale_2),
            num_inference_steps=int(steps),
            generator=torch.Generator(device="cuda").manual_seed(int(seed)),
            output_type="np",
        )
    finally:
        # Never let this request's LoRAs leak into the next one.
        if lora_loaded:
            lora_loader.unload_lora(i2v_pipe)


@spaces.GPU(duration=1200)
def run_inference(
    task_type,
    input_image,
    input_video,
    prompt,
    negative_prompt,
    duration_seconds,
    steps,
    guidance_scale,
    guidance_scale_2,
    current_seed,
    scheduler_name,
    flow_shift,
    frame_multiplier,
    quality,
    last_image_input,
    lora_groups,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video for the selected task and export it to a temp .mp4.

    Dispatch is by substring: "T2V" in ``task_type`` selects text-to-video;
    "V2V" selects depth-ControlNet video-to-video; anything else selects
    image-to-video. When ``frame_multiplier`` (the target output fps) exceeds
    FIXED_FPS, RIFE interpolation is applied.

    Returns:
        (video_path, task_id): path to the exported .mp4, and a short id for logs.

    Raises:
        gr.Error: when the required image/video input is missing or unusable.
    """
    clear_vram()
    num_frames = get_num_frames(duration_seconds)
    task_id = str(uuid.uuid4())[:8]
    print(f"Task: {task_id}, type={task_type}, duration={duration_seconds}s, frames={num_frames}")
    start = time.time()

    if "T2V" in task_type:
        # ====== T2V: 14B + Lightning LoRA (4 steps, dual guidance) ======
        result = _run_t2v(prompt, negative_prompt, num_frames, steps,
                          guidance_scale, guidance_scale_2, current_seed)
    elif "V2V" in task_type:
        # ====== V2V: ControlNet Depth ======
        result = _run_v2v(input_video, prompt, negative_prompt, num_frames, steps,
                          guidance_scale, current_seed)
    else:
        # ====== I2V ======
        result = _run_i2v(input_image, last_image_input, prompt, negative_prompt, num_frames,
                          steps, guidance_scale, guidance_scale_2, current_seed,
                          scheduler_name, flow_shift, lora_groups)

    raw_frames = result.frames[0]
    elapsed = time.time() - start
    print(f"Generation took {elapsed:.1f}s ({len(raw_frames)} frames)")

    frame_factor = int(frame_multiplier) // FIXED_FPS
    if frame_factor > 1:
        rife_model.device()
        rife_model.flownet = rife_model.flownet.half()
        final_frames = interpolate_bits(raw_frames, multiplier=int(frame_factor))
    else:
        final_frames = list(raw_frames)

    final_fps = FIXED_FPS * max(1, frame_factor)
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    export_to_video(final_frames, video_path, fps=final_fps, quality=quality)
    return video_path, task_id


# ============ Generate ============

def generate_video(
    task_type,
    input_image,
    input_video,
    prompt,
    lora_groups,
    duration_seconds,
    frame_multiplier,
    steps,
    guidance_scale,
    guidance_scale_2,
    negative_prompt,
    quality,
    seed,
    randomize_seed,
    scheduler,
    flow_shift,
    last_image,
    display_result,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio click handler: validate the prompt, pick the seed, and run inference.

    Returns:
        (video for the player or None when display is off, video file path, seed used).
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Enter a prompt / 请输入提示词")

    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    video_path, task_id = run_inference(
        task_type, input_image, input_video, prompt, negative_prompt,
        duration_seconds, steps, guidance_scale, guidance_scale_2, current_seed,
        scheduler, flow_shift, frame_multiplier, quality, last_image, lora_groups,
    )
    print(f"Done: {task_id}")
    return (video_path if display_result else None), video_path, current_seed
# ============ UI ============

CSS = """
#hidden-timestamp { opacity: 0; height: 0; width: 0; margin: 0; padding: 0; overflow: hidden; position: absolute; }
"""

with gr.Blocks(theme=gr.themes.Soft(), css=CSS, delete_cache=(3600, 10800)) as demo:
    gr.Markdown("## WAN 2.2 Multi-Task Video Generation / 多任务视频生成")
    gr.Markdown("#### I2V (Lightning 6-step) · T2V (Lightning 14B 4-step) · V2V (ControlNet Depth) · LoRA")
    gr.Markdown("---")

    task_type = gr.Radio(
        choices=[
            "I2V (图生视频 / Image-to-Video)",
            "T2V (文生视频 / Text-to-Video)",
            "V2V (视频生视频 / Video-to-Video)",
        ],
        value="I2V (图生视频 / Image-to-Video)",
        label="Task Type / 任务类型",
    )

    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_image = gr.Image(type="pil", label="Input Image / 输入图片 (I2V/V2V)", sources=["upload", "clipboard"])
            with gr.Group():
                input_video = gr.Video(label="Input Video / 输入视频 (V2V)", sources=["upload"], visible=False, interactive=True)

            prompt_input = gr.Textbox(
                label="Prompt / 提示词",
                value="",
                placeholder="Describe the video... / 描述你想生成的视频...",
                lines=3,
            )
            duration_slider = gr.Slider(
                minimum=0.5,
                maximum=15,
                step=0.5,
                value=3,
                label="Duration / 时长 (seconds/秒)",
                info="Max ~15s (241 frames @16fps) / 最大约15秒",
            )
            frame_multi = gr.Dropdown(choices=[16, 32, 64], value=16, label="Output FPS / 输出帧率", info="RIFE interpolation / RIFE插帧")

            with gr.Accordion("⚙️ Advanced Settings / 高级设置", open=False):
                last_image = gr.Image(type="pil", label="Last Frame / 末帧 (Optional)", sources=["upload", "clipboard"])
                negative_prompt_input = gr.Textbox(label="Negative Prompt / 负面提示词", value=default_negative_prompt, lines=3)
                with gr.Row():
                    steps_slider = gr.Slider(minimum=1, maximum=50, step=1, value=6, label="Steps / 步数", info="I2V: 4-8 | T2V: 4-8 | V2V: 25-50")
                    quality_sl = gr.Slider(minimum=1, maximum=10, step=1, value=6, label="Quality / 质量")
                with gr.Row():
                    guidance_h = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance High / 引导(高噪声)")
                    guidance_l = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Low / 引导(低噪声)")
                with gr.Row():
                    scheduler_dd = gr.Dropdown(choices=list(SCHEDULER_MAP.keys()), value="UniPCMultistep", label="Scheduler / 调度器")
                    flow_shift_sl = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift / 流偏移")
                with gr.Row():
                    seed_sl = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=42, label="Seed / 种子")
                    random_seed_cb = gr.Checkbox(label="Random / 随机", value=True)

            lora_dd = gr.Dropdown(choices=lora_loader.get_lora_choices(), label="LoRA (I2V only / 仅I2V)", multiselect=True, info="From WAN2.2_LoraSet_NSFW")
            display_cb = gr.Checkbox(label="Display / 显示", value=True)
            generate_btn = gr.Button("🎬 Generate / 生成视频", variant="primary", size="lg")

        with gr.Column():
            video_output = gr.Video(label="Generated Video / 生成的视频", autoplay=True, sources=["upload"], show_download_button=True, show_share_button=True, interactive=False, elem_id="generated-video")
            with gr.Row():
                grab_frame_btn = gr.Button("📸 Use Frame / 使用帧", variant="secondary")
                timestamp_box = gr.Number(value=0, label="Timestamp", visible=False, elem_id="hidden-timestamp")
            file_output = gr.File(label="Download / 下载")

    def update_task_ui(task):
        """Show the matching input widget and reset steps/guidance to per-task defaults.

        Returns updates for (input_image, input_video, steps_slider, guidance_h, guidance_l).
        """
        is_v2v = "V2V" in task
        is_t2v = "T2V" in task
        if is_t2v:
            return gr.update(visible=False), gr.update(visible=False), gr.update(value=4), gr.update(value=1.0), gr.update(value=1.0)
        elif is_v2v:
            return gr.update(visible=False), gr.update(visible=True), gr.update(value=30), gr.update(value=5.0), gr.update(value=1.0)
        else:
            return gr.update(visible=True), gr.update(visible=False), gr.update(value=6), gr.update(value=1.0), gr.update(value=1.0)

    task_type.change(update_task_ui, inputs=[task_type], outputs=[input_image, input_video, steps_slider, guidance_h, guidance_l])

    generate_btn.click(
        fn=generate_video,
        inputs=[task_type, input_image, input_video, prompt_input, lora_dd, duration_slider, frame_multi,
                steps_slider, guidance_h, guidance_l, negative_prompt_input, quality_sl,
                seed_sl, random_seed_cb, scheduler_dd, flow_shift_sl, last_image, display_cb],
        outputs=[video_output, file_output, seed_sl],
    )

    # Read the player's playback position in the browser, then extract that frame.
    # Chained with .then() rather than timestamp_box.change: a change event
    # does not fire when the timestamp is unchanged. Grabbing frame 0 of an
    # unplayed video, or grabbing the same position twice, therefore did nothing.
    grab_frame_btn.click(
        fn=None, inputs=None, outputs=[timestamp_box], js=get_timestamp_js,
    ).then(
        fn=extract_frame, inputs=[video_output, timestamp_box], outputs=[input_image],
    )

if __name__ == "__main__":
    demo.queue().launch(mcp_server=True, show_error=True)