import spaces
import os, sys, subprocess, importlib, site
from PIL import Image
import cv2, gradio as gr, gc, numpy as np, tempfile
from huggingface_hub import snapshot_download

# — Clone Wan2.2 source code from GitHub (contains wan/ module) —
WAN_REPO = "https://github.com/Wan-Video/Wan2.2.git"
WAN_DIR = os.path.join(os.getcwd(), "Wan2.2")

if not os.path.exists(WAN_DIR):
    print("Cloning Wan2.2 from GitHub...")
    # Shallow clone: only the current source tree is needed, not history.
    subprocess.run(["git", "clone", "--depth", "1", WAN_REPO, WAN_DIR], check=True)
    print("Clone complete.")

# — Patch wan/modules/t5.py: calls torch.cuda.current_device() at class
# definition time which fails at startup (no GPU yet in ZeroGPU). —
# The patch is idempotent: once applied, the call no longer appears in the
# file, so later startups skip it. If the call is present in a form the
# replacement does not match, warn instead of falsely reporting success.
t5_path = os.path.join(WAN_DIR, "wan", "modules", "t5.py")
if os.path.exists(t5_path):
    with open(t5_path, encoding="utf-8") as f:
        t5_code = f.read()
    if "torch.cuda.current_device()" in t5_code:
        patched_t5_code = t5_code.replace(
            "device=torch.cuda.current_device(),",
            "device=0, # patched for ZeroGPU"
        )
        if patched_t5_code != t5_code:
            with open(t5_path, "w", encoding="utf-8") as f:
                f.write(patched_t5_code)
            print("Patched wan/modules/t5.py for ZeroGPU compatibility.")
        else:
            print(
                "Warning: wan/modules/t5.py calls torch.cuda.current_device() in an "
                "unexpected form; ZeroGPU patch NOT applied, importing wan may fail."
            )

# Make the cloned `wan` package and its animate preprocessing scripts importable.
sys.path.insert(0, WAN_DIR)
PREPROCESS_DIR = os.path.join(WAN_DIR, "wan", "modules", "animate", "preprocess")
sys.path.append(PREPROCESS_DIR)
# Re-register site-packages (processing any .pth files) and drop stale import
# finder caches so the modules added above are picked up.
for sitedir in site.getsitepackages():
    site.addsitedir(sitedir)
importlib.invalidate_caches()

# — Download SAM2 weights (small, just files) —
# Best-effort: a failure is logged and startup continues.
try:
    snapshot_download(repo_id="alexnasa/sam2_C_cpu", local_dir=os.getcwd())
    print("sam2 weights downloaded successfully.")
except Exception as e:
    print(f"Warning: sam2 download failed: {e}")

# — Download Wan2.2-Animate-14B weights —
# ZeroGPU free tier has 50GB storage limit. Full model is ~51.5GB.
# We skip the T5 encoder (11.4GB) + tokenizer since the animate task
# is video-to-video motion transfer and uses internal text conditioning.
# DiT (~34.5GB) + CLIP (~4.8GB) + VAE (~0.5GB) ≈ 40GB → fits under 50GB.
print("Downloading Wan2.2-Animate-14B model weights (DiT + CLIP + VAE)...")
snapshot_download(
    repo_id="Wan-AI/Wan2.2-Animate-14B",
    local_dir="./Wan2.2-Animate-14B",
    ignore_patterns=[
        "models_t5_*",  # T5 text encoder (11.4GB) — skipped to fit 50GB
        "google/*",  # umt5-xxl tokenizer files
        "tokenizer*",
        "special_tokens_map.json",
    ],
)
print("Model weights downloaded.")

# Standard-library helpers used by the request handler below.
import shutil
import traceback

# — Now safe to import wan (t5.py is patched) —
import torch
from generate import generate, load_model
from preprocess_data import run as run_preprocess
from preprocess_data import load_preprocess_models

# — Lazy model init: load inside @spaces.GPU on first call —
_wan_animate = None


def get_wan_animate():
    """Return the WanAnimate model, loading it with ``load_model(True)`` on first use.

    The model is cached in the module-level ``_wan_animate`` so later calls in
    the same process reuse it.

    NOTE(review): @spaces.GPU may execute calls in a separate worker process;
    confirm the cache actually survives between requests on ZeroGPU.
    """
    global _wan_animate
    if _wan_animate is None:
        print("Loading WanAnimate model (first call)...")
        _wan_animate = load_model(True)
        print("WanAnimate model loaded.")
    return _wan_animate


def _new_temp_path(suffix):
    """Create an empty, already-closed temp file ending in *suffix*; return its path."""
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    return path


def _remove_quietly(path):
    """Best-effort removal of a temp file or directory; ``None``/empty is ignored."""
    if not path:
        return
    if os.path.isdir(path):
        shutil.rmtree(path, ignore_errors=True)
        return
    try:
        os.remove(path)
    except OSError:
        # Deliberately best-effort: a leftover temp file must not fail a request.
        pass


def _release_memory():
    """Run the garbage collector and release cached CUDA memory; never raises."""
    gc.collect()
    try:
        torch.cuda.empty_cache()
    except Exception:
        # CUDA may be in an error state after a failed run; log, don't mask
        # the caller's result.
        traceback.print_exc()


def clip_and_set_fps(input_video_path, output_video_path, duration_s=3, target_fps=8):
    """Re-encode the first *duration_s* seconds of a video at *target_fps*.

    The output is H.264 (yuv420p, CRF 18, preset veryfast) with AAC audio and
    ``+faststart``. ``-y`` overwrites *output_video_path* if it exists.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits non-zero. ffmpeg's stderr
            is printed first, because ``capture_output=True`` would otherwise
            hide the actual reason.
    """
    cmd = [
        "ffmpeg", "-nostdin", "-hide_banner", "-y",
        "-i", input_video_path,
        "-t", str(duration_s),
        "-vf", f"fps={target_fps}",
        "-c:v", "libx264", "-pix_fmt", "yuv420p", "-preset", "veryfast", "-crf", "18",
        "-c:a", "aac",
        "-movflags", "+faststart",
        output_video_path,
    ]
    try:
        subprocess.run(cmd, check=True, capture_output=True)
    except subprocess.CalledProcessError as err:
        stderr = (err.stderr or b"").decode("utf-8", errors="replace")
        print(f"ffmpeg failed (exit code {err.returncode}):\n{stderr[-2000:]}")
        raise


def preprocess_video(path, duration):
    """Clip *path* to *duration* seconds at the default frame rate into a new temp .mp4.

    Returns the path of the new clip. On failure, the temp file is removed and
    the ffmpeg error propagates.
    """
    out = _new_temp_path(".mp4")
    try:
        clip_and_set_fps(path, out, duration_s=duration)
    except Exception:
        _remove_quietly(out)  # don't leave an empty/partial clip behind
        raise
    return out


def is_portrait(video_file):
    """Return True if the video's frame width is smaller than its frame height."""
    cap = cv2.VideoCapture(video_file)
    try:
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    finally:
        cap.release()
    return w < h


@spaces.GPU(duration=500)
def predict(ref_img, video, mode, quality, max_duration_s):
    """Animate the reference image with the motion of the template video.

    Args:
        ref_img: filepath of the reference image (from ``gr.Image(type="filepath")``).
        video: filepath of the template (motion source) video.
        mode: ``"wan2.2-animate-mix"`` selects replace mode; any other value
            selects retarget mode.
        quality: NOTE(review): accepted from the UI but currently not passed to
            preprocessing or generation — confirm whether it should be.
        max_duration_s: seconds of the template video to use (truncated to int).

    Returns:
        ``(output_video_path, status_message)``. On any failure this is
        ``(None, "Error: ...")``; exceptions are logged, not propagated.
    """
    input_video = edited_frame_png = output_video_path = tmpdir = None
    try:
        if ref_img is None or video is None:
            return None, "Error: Please provide both Reference Image and Template Video."

        wan_animate = get_wan_animate()
        replace_flag = (mode == "wan2.2-animate-mix")
        tag_string = "replace_flag" if replace_flag else "retarget_flag"

        input_video = preprocess_video(video, int(max_duration_s))
        # Fixed working resolution, oriented to match the clip.
        w, h = (480, 832) if is_portrait(input_video) else (832, 480)

        # Re-save the reference image as PNG for the preprocessing pipeline.
        edited_frame_png = _new_temp_path(".png")
        with Image.open(ref_img) as img:
            img.save(edited_frame_png)

        output_video_path = _new_temp_path(".mp4")
        tmpdir = tempfile.mkdtemp()

        preprocess_model = load_preprocess_models(int(max_duration_s))
        src_pose_path, src_face_path, src_bg_path, src_mask_path, src_ref_path = run_preprocess(
            preprocess_model, input_video, edited_frame_png, tmpdir, w, h, tag_string, {}, {})

        generate(wan_animate, src_pose_path, src_face_path, src_bg_path, src_mask_path,
                 src_ref_path, output_video_path, replace_flag)

        return output_video_path, "SUCCEEDED - Video generated successfully!"
    except Exception as e:
        traceback.print_exc()
        _remove_quietly(output_video_path)  # partial/empty output is useless
        return None, f"Error: {str(e)}"
    finally:
        # Intermediate artifacts are consumed by the time generate() returns;
        # the output video is kept for Gradio to serve.
        for path in (input_video, edited_frame_png, tmpdir):
            _remove_quietly(path)
        _release_memory()


# ——— Vivek957 UI ———
# NOTE(review): HEAD holds only a newline — presumably a header HTML snippet
# was meant to go here.
HEAD = """
"""

with gr.Blocks(title="Wan2.2 Animate") as demo:
    gr.HTML(HEAD)

    with gr.Accordion("📖 Usage", open=False):
        gr.Markdown("""
**How to use:**
1. Upload a **Reference Image** (the character/person you want to animate)
2. Upload a **Template Video** (the motion source)
3. Choose **Mode** and **Quality**
4. Click **Generate Video**
""")

    with gr.Row():
        with gr.Column():
            ref_img = gr.Image(label="Reference Image(参考图片)", type="filepath")
            video = gr.Video(label="Template Video(模板视频)")
            mode = gr.Dropdown(
                label="推理模式(Inference Mode)",
                choices=["wan2.2-animate", "wan2.2-animate-mix"],
                value="wan2.2-animate"
            )
            quality = gr.Dropdown(label="推理质量(Inference Quality)", choices=["wan-pro", "wan-std"], value="wan-pro")
            max_dur = gr.Slider(label="Max Duration (sec)", minimum=1, maximum=5, step=1, value=3)
            run_button = gr.Button("Generate Video(生成视频)", variant="primary")
        with gr.Column():
            output_video = gr.Video(label="Output Video(输出视频)")
            output_status = gr.Textbox(label="Status(状态)", lines=5)

    run_button.click(fn=predict,
                     inputs=[ref_img, video, mode, quality, max_dur],
                     outputs=[output_video, output_status])

demo.queue(default_concurrency_limit=5)
demo.launch(server_name="0.0.0.0", server_port=7860)