pormungtai commited on
Commit
159bfa1
·
verified ·
1 Parent(s): dd5894d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -173
app.py CHANGED
@@ -1,199 +1,158 @@
 
 
 
 
1
  import spaces
2
- import os, sys, subprocess, importlib, site
3
- from PIL import Image
4
- import cv2, gradio as gr, gc, numpy as np, tempfile
5
  from huggingface_hub import snapshot_download
6
 
7
- # Clone Wan2.2 source code from GitHub (contains wan/ module) —
8
- WAN_REPO = "https://github.com/Wan-Video/Wan2.2.git"
9
- WAN_DIR = os.path.join(os.getcwd(), "Wan2.2")
10
- if not os.path.exists(WAN_DIR):
11
- print("Cloning Wan2.2 from GitHub...")
12
- subprocess.run(["git", "clone", "--depth", "1", WAN_REPO, WAN_DIR], check=True)
13
- print("Clone complete.")
14
-
15
- # — Patch wan/modules/t5.py: calls torch.cuda.current_device() at class
16
- # definition time which fails at startup (no GPU yet in ZeroGPU). —
17
- t5_path = os.path.join(WAN_DIR, "wan", "modules", "t5.py")
18
- if os.path.exists(t5_path):
19
- with open(t5_path) as f:
20
- t5_code = f.read()
21
- if "torch.cuda.current_device()" in t5_code:
22
- t5_code = t5_code.replace(
23
- "device=torch.cuda.current_device(),",
24
- "device=0, # patched for ZeroGPU"
25
  )
 
 
 
 
 
26
  with open(t5_path, "w") as f:
27
- f.write(t5_code)
28
- print("Patched wan/modules/t5.py for ZeroGPU compatibility.")
29
-
30
- sys.path.insert(0, WAN_DIR)
31
- PREPROCESS_DIR = os.path.join(WAN_DIR, "wan", "modules", "animate", "preprocess")
32
- sys.path.append(PREPROCESS_DIR)
33
- for sitedir in site.getsitepackages():
34
- site.addsitedir(sitedir)
35
- importlib.invalidate_caches()
36
-
37
- # — Download SAM2 weights (small, just files) —
38
- try:
39
- snapshot_download(repo_id="alexnasa/sam2_C_cpu", local_dir=os.getcwd())
40
- print("sam2 weights downloaded successfully.")
41
- except Exception as e:
42
- print(f"Warning: sam2 download failed: {e}")
43
-
44
- # Download Wan2.2-Animate-14B weights —
45
- # ZeroGPU free tier has 50GB storage limit. Full model is ~51.5GB.
46
- # We skip the T5 encoder (11.4GB) + tokenizer since the animate task
47
- # is video-to-video motion transfer and uses internal text conditioning.
48
- # DiT (~34.5GB) + CLIP (~4.8GB) + VAE (~0.5GB) ≈ 40GB → fits under 50GB.
49
- print("Downloading Wan2.2-Animate-14B model weights (DiT + CLIP + VAE)...")
50
- snapshot_download(
51
- repo_id="Wan-AI/Wan2.2-Animate-14B",
52
- local_dir="./Wan2.2-Animate-14B",
53
- ignore_patterns=[
54
- "models_t5_*", # T5 text encoder (11.4GB) — skipped to fit 50GB
55
- "google/*", # umt5-xxl tokenizer files
56
- "tokenizer*",
57
- "special_tokens_map.json",
58
- ]
59
- )
60
- print("Model weights downloaded.")
61
-
62
- # Now safe to import wan (t5.py is patched) —
63
- import torch
64
- from generate import generate, load_model
65
- from preprocess_data import run as run_preprocess
66
- from preprocess_data import load_preprocess_models
67
-
68
- # — Lazy model init: load inside @spaces.GPU on first call —
 
 
 
 
 
 
69
  _wan_animate = None
70
 
71
  def get_wan_animate():
72
  global _wan_animate
73
  if _wan_animate is None:
74
- print("Loading WanAnimate model (first call)...")
75
- _wan_animate = load_model(True)
76
- print("WanAnimate model loaded.")
77
  return _wan_animate
78
 
 
 
 
 
 
79
 
80
- def clip_and_set_fps(input_video_path, output_video_path, duration_s=3, target_fps=8):
81
- cmd = [
82
- "ffmpeg", "-nostdin", "-hide_banner", "-y",
83
- "-i", input_video_path, "-t", str(duration_s),
84
- "-vf", f"fps={target_fps}",
85
- "-c:v", "libx264", "-pix_fmt", "yuv420p",
86
- "-preset", "veryfast", "-crf", "18",
87
- "-c:a", "aac", "-movflags", "+faststart",
88
- output_video_path,
89
- ]
90
- subprocess.run(cmd, check=True, capture_output=True)
91
-
92
-
93
- def preprocess_video(path, duration):
94
- out = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
95
- clip_and_set_fps(path, out, duration_s=duration)
96
- return out
97
 
 
 
 
98
 
99
- def is_portrait(video_file):
100
- cap = cv2.VideoCapture(video_file)
101
- w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
102
- h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
103
- cap.release()
104
- return w < h
105
-
106
-
107
- @spaces.GPU(duration=500)
108
- def predict(ref_img, video, mode, quality, max_duration_s):
109
  try:
110
- if ref_img is None or video is None:
111
- return None, "Error: Please provide both Reference Image and Template Video."
112
-
113
- wan_animate = get_wan_animate()
114
-
115
- replace_flag = (mode == "wan2.2-animate-mix")
116
- tag_string = "replace_flag" if replace_flag else "retarget_flag"
117
-
118
- input_video = preprocess_video(video, int(max_duration_s))
119
- w, h = (480, 832) if is_portrait(input_video) else (832, 480)
120
-
121
- edited_frame_png = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name
122
- Image.open(ref_img).save(edited_frame_png)
123
-
124
- output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
125
- tmpdir = tempfile.mkdtemp()
126
-
127
- preprocess_model = load_preprocess_models(int(max_duration_s))
128
- src_pose_path, src_face_path, src_bg_path, src_mask_path, src_ref_path = run_preprocess(
129
- preprocess_model, input_video, edited_frame_png, tmpdir, w, h, tag_string, {}, {})
130
-
131
- generate(wan_animate, src_pose_path, src_face_path, src_bg_path,
132
- src_mask_path, src_ref_path, output_video_path, replace_flag)
 
 
 
133
 
134
- gc.collect()
135
- torch.cuda.empty_cache()
136
- return output_video_path, "SUCCEEDED - Video generated successfully!"
 
 
 
 
 
 
 
137
 
 
138
  except Exception as e:
139
- return None, f"Error: {str(e)}"
140
-
141
-
142
- # ——— Vivek957 UI ———
143
- HEAD = """
144
- <div style="text-align:center; margin-bottom:10px">
145
- <h1 style="font-size:2em; font-weight:700">Wan2.2 Animate (ZeroGPU)</h1>
146
- <p>Motion Transfer · Free ZeroGPU A100</p>
147
- <div style="display:flex; gap:8px; justify-content:center; margin-top:8px">
148
- <a href="https://arxiv.org/abs/2503.20314" target="_blank">
149
- <button style="padding:6px 14px; border-radius:6px; border:1px solid #aaa; cursor:pointer">📄 Paper</button>
150
- </a>
151
- <a href="https://github.com/Wan-Video/Wan2.2" target="_blank">
152
- <button style="padding:6px 14px; border-radius:6px; border:1px solid #aaa; cursor:pointer">💻 GitHub</button>
153
- </a>
154
- <a href="https://huggingface.co/Wan-AI/Wan2.2-Animate-14B" target="_blank">
155
- <button style="padding:6px 14px; border-radius:6px; border:1px solid #aaa; cursor:pointer">🤗 HF Model</button>
156
- </a>
157
- <a href="https://modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B" target="_blank">
158
- <button style="padding:6px 14px; border-radius:6px; border:1px solid #aaa; cursor:pointer">🔮 ModelScope</button>
159
- </a>
160
- </div>
161
- </div>
162
- """
163
 
 
164
  with gr.Blocks(title="Wan2.2 Animate") as demo:
165
- gr.HTML(HEAD)
166
-
167
- with gr.Accordion("📖 Usage", open=False):
168
- gr.Markdown("""
169
- **How to use:**
170
- 1. Upload a **Reference Image** (the character/person you want to animate)
171
- 2. Upload a **Template Video** (the motion source)
172
- 3. Choose **Mode** and **Quality**
173
- 4. Click **Generate Video**
174
- """)
175
-
176
  with gr.Row():
177
  with gr.Column():
178
- ref_img = gr.Image(label="Reference Image(参考图片)", type="filepath")
179
- video = gr.Video(label="Template Video(模板视频)")
180
- mode = gr.Dropdown(
181
- label="&#25512;&#29702;&#27169;&#24335;(Inference Mode)",
182
- choices=["wan2.2-animate", "wan2.2-animate-mix"],
183
- value="wan2.2-animate"
184
- )
185
- quality = gr.Dropdown(label="&#25512;&#29702;&#36136;&#37327;(Inference Quality)",
186
- choices=["wan-pro", "wan-std"], value="wan-pro")
187
- max_dur = gr.Slider(label="Max Duration (sec)", minimum=1, maximum=5,
188
- step=1, value=3)
189
- run_button = gr.Button("Generate Video(&#29983;&#25104;&#35270;&#39057;)", variant="primary")
190
  with gr.Column():
191
- output_video = gr.Video(label="Output Video(&#36755;&#20986;&#35270;&#39057;)")
192
- output_status = gr.Textbox(label="Status(&#29366;&#24577;)", lines=5)
193
 
194
- run_button.click(fn=predict,
195
- inputs=[ref_img, video, mode, quality, max_dur],
196
- outputs=[output_video, output_status])
 
 
197
 
198
- demo.queue(default_concurrency_limit=5)
199
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ import os
2
+ import sys
3
+ import subprocess
4
+ import shutil
5
  import spaces
6
+ import gradio as gr
7
+ import torch
 
8
  from huggingface_hub import snapshot_download
9
 
10
+ # ── Patch wan/modules/t5.py before importing wan ─────────────────────────────
11
+ def clone_and_patch_wan():
12
+ if not os.path.exists("./Wan2.2"):
13
+ subprocess.run(
14
+ ["git", "clone", "https://github.com/Wan-Video/Wan2.2.git", "./Wan2.2"],
15
+ check=True
 
 
 
 
 
 
 
 
 
 
 
 
16
  )
17
+ t5_path = "./Wan2.2/wan/modules/t5.py"
18
+ with open(t5_path, "r") as f:
19
+ src = f.read()
20
+ if "device=torch.cuda.current_device()," in src:
21
+ src = src.replace("device=torch.cuda.current_device(),", "device=0,")
22
  with open(t5_path, "w") as f:
23
+ f.write(src)
24
+ print("[patch] t5.py patched: replaced current_device() with 0")
25
+
26
+ clone_and_patch_wan()
27
+
28
+ if "./Wan2.2" not in sys.path:
29
+ sys.path.insert(0, "./Wan2.2")
30
+
31
+ # ── Download SAM2 CPU model ───────────────────────────────────────────────────
32
+ if not os.path.exists("./process_checkpoint/sam2"):
33
+ snapshot_download(
34
+ repo_id="alexnasa/sam2_C_cpu",
35
+ local_dir="./process_checkpoint/sam2",
36
+ )
37
+ print("[init] SAM2 CPU model downloaded")
38
+
39
+ # ── Download Wan2.2-Animate-14B (skip large unused files) ────────────────────
40
+ if not os.path.exists("./Wan2.2-Animate-14B"):
41
+ snapshot_download(
42
+ repo_id="Wan-AI/Wan2.2-Animate-14B",
43
+ local_dir="./Wan2.2-Animate-14B",
44
+ ignore_patterns=[
45
+ "models_t5_*",
46
+ "google/*",
47
+ "tokenizer*",
48
+ "special_tokens_map.json",
49
+ "xlm-roberta-large/*",
50
+ "relighting_lora.ckpt",
51
+ "relighting_lora/*",
52
+ "process_checkpoint/sam2/*",
53
+ ]
54
+ )
55
+ print("[init] Wan2.2-Animate-14B downloaded")
56
+
57
+ # ── Symlink SAM2 into model's expected path ───────────────────────────────────
58
+ sam2_dst = "./Wan2.2-Animate-14B/process_checkpoint/sam2"
59
+ sam2_src = "./process_checkpoint/sam2"
60
+ if not os.path.exists(sam2_dst) and os.path.exists(sam2_src):
61
+ os.makedirs(os.path.dirname(sam2_dst), exist_ok=True)
62
+ os.symlink(os.path.abspath(sam2_src), sam2_dst)
63
+ print("[init] SAM2 symlink created")
64
+
65
+ # ── Copy helper scripts ───────────────────────────────────────────────────────
66
+ for fname in ["generate.py", "preprocess_data.py"]:
67
+ if os.path.exists(f"./{fname}") and not os.path.exists(f"./Wan2.2/{fname}"):
68
+ shutil.copy(f"./{fname}", f"./Wan2.2/{fname}")
69
+
70
+ # ── Lazy model init ───────────────────────────────────────────────────────────
71
  _wan_animate = None
72
 
73
  def get_wan_animate():
74
  global _wan_animate
75
  if _wan_animate is None:
76
+ sys.path.insert(0, "./Wan2.2")
77
+ from generate import load_model
78
+ _wan_animate = load_model(False)
79
  return _wan_animate
80
 
81
+ # ── Inference ─────────────────────────────────────────────────────────────────
82
+ @spaces.GPU(duration=300)
83
+ def run_animate(ref_image, template_video, mode, quality, max_duration):
84
+ import uuid
85
+ from generate import generate
86
 
87
+ wan_animate = get_wan_animate()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
+ uid = str(uuid.uuid4())[:8]
90
+ work_dir = f"/tmp/wan_{uid}"
91
+ os.makedirs(work_dir, exist_ok=True)
92
 
 
 
 
 
 
 
 
 
 
 
93
  try:
94
+ ref_path = os.path.join(work_dir, "ref.jpg")
95
+ tmpl_path = os.path.join(work_dir, "template.mp4")
96
+
97
+ import numpy as np
98
+ from PIL import Image
99
+ if isinstance(ref_image, np.ndarray):
100
+ Image.fromarray(ref_image).save(ref_path)
101
+ else:
102
+ shutil.copy(ref_image, ref_path)
103
+ shutil.copy(template_video, tmpl_path)
104
+
105
+ pose_path = os.path.join(work_dir, "pose.mp4")
106
+ face_path = os.path.join(work_dir, "face.png")
107
+ bg_path = os.path.join(work_dir, "bg.png")
108
+ mask_path = os.path.join(work_dir, "mask.png")
109
+
110
+ from preprocess_data import preprocess
111
+ preprocess(
112
+ ref_image=ref_path,
113
+ template_video=tmpl_path,
114
+ output_pose=pose_path,
115
+ output_face=face_path,
116
+ output_bg=bg_path,
117
+ output_mask=mask_path,
118
+ mode=mode,
119
+ )
120
 
121
+ out_path = os.path.join(work_dir, "output.mp4")
122
+ generate(
123
+ wan_animate=wan_animate,
124
+ src_pose_path=pose_path,
125
+ src_face_path=face_path,
126
+ src_bg_path=bg_path,
127
+ src_mask_path=mask_path,
128
+ src_ref_path=ref_path,
129
+ save_file=out_path,
130
+ )
131
 
132
+ return out_path, "Done!"
133
  except Exception as e:
134
+ return None, f"Error: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
+ # ── UI ────────────────────────────────────────────────────────────────────────
137
  with gr.Blocks(title="Wan2.2 Animate") as demo:
138
+ gr.Markdown("## Wan2.2 Animate — ZeroGPU (Free A100)")
 
 
 
 
 
 
 
 
 
 
139
  with gr.Row():
140
  with gr.Column():
141
+ ref_image = gr.Image(label="Reference Image", type="numpy")
142
+ template_video = gr.Video(label="Template Video")
143
+ mode = gr.Dropdown(["normal", "tiktok"], value="normal", label="Mode")
144
+ quality = gr.Dropdown(["standard", "high"], value="standard", label="Quality")
145
+ max_duration = gr.Slider(1, 10, value=5, step=1, label="Max Duration (s)")
146
+ btn = gr.Button("Generate", variant="primary")
 
 
 
 
 
 
147
  with gr.Column():
148
+ out_video = gr.Video(label="Output Video")
149
+ status = gr.Textbox(label="Status", interactive=False)
150
 
151
+ btn.click(
152
+ run_animate,
153
+ inputs=[ref_image, template_video, mode, quality, max_duration],
154
+ outputs=[out_video, status],
155
+ )
156
 
157
+ if __name__ == "__main__":
158
+ demo.launch()