pormungtai commited on
Commit
dea8b1b
·
verified ·
1 Parent(s): 159bfa1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +260 -148
app.py CHANGED
@@ -1,158 +1,270 @@
 
1
  import os
 
2
  import sys
3
- import subprocess
4
  import shutil
5
- import spaces
6
  import gradio as gr
7
- import torch
8
- from huggingface_hub import snapshot_download
9
-
10
- # ── Patch wan/modules/t5.py before importing wan ─────────────────────────────
11
- def clone_and_patch_wan():
12
- if not os.path.exists("./Wan2.2"):
13
- subprocess.run(
14
- ["git", "clone", "https://github.com/Wan-Video/Wan2.2.git", "./Wan2.2"],
15
- check=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  )
17
- t5_path = "./Wan2.2/wan/modules/t5.py"
18
- with open(t5_path, "r") as f:
19
- src = f.read()
20
- if "device=torch.cuda.current_device()," in src:
21
- src = src.replace("device=torch.cuda.current_device(),", "device=0,")
22
- with open(t5_path, "w") as f:
23
- f.write(src)
24
- print("[patch] t5.py patched: replaced current_device() with 0")
25
-
26
- clone_and_patch_wan()
27
-
28
- if "./Wan2.2" not in sys.path:
29
- sys.path.insert(0, "./Wan2.2")
30
-
31
- # ── Download SAM2 CPU model ───────────────────────────────────────────────────
32
- if not os.path.exists("./process_checkpoint/sam2"):
33
- snapshot_download(
34
- repo_id="alexnasa/sam2_C_cpu",
35
- local_dir="./process_checkpoint/sam2",
36
- )
37
- print("[init] SAM2 CPU model downloaded")
38
-
39
- # ── Download Wan2.2-Animate-14B (skip large unused files) ────────────────────
40
- if not os.path.exists("./Wan2.2-Animate-14B"):
41
- snapshot_download(
42
- repo_id="Wan-AI/Wan2.2-Animate-14B",
43
- local_dir="./Wan2.2-Animate-14B",
44
- ignore_patterns=[
45
- "models_t5_*",
46
- "google/*",
47
- "tokenizer*",
48
- "special_tokens_map.json",
49
- "xlm-roberta-large/*",
50
- "relighting_lora.ckpt",
51
- "relighting_lora/*",
52
- "process_checkpoint/sam2/*",
53
  ]
54
- )
55
- print("[init] Wan2.2-Animate-14B downloaded")
56
-
57
- # ── Symlink SAM2 into model's expected path ───────────────────────────────────
58
- sam2_dst = "./Wan2.2-Animate-14B/process_checkpoint/sam2"
59
- sam2_src = "./process_checkpoint/sam2"
60
- if not os.path.exists(sam2_dst) and os.path.exists(sam2_src):
61
- os.makedirs(os.path.dirname(sam2_dst), exist_ok=True)
62
- os.symlink(os.path.abspath(sam2_src), sam2_dst)
63
- print("[init] SAM2 symlink created")
64
-
65
- # ── Copy helper scripts ───────────────────────────────────────────────────────
66
- for fname in ["generate.py", "preprocess_data.py"]:
67
- if os.path.exists(f"./{fname}") and not os.path.exists(f"./Wan2.2/{fname}"):
68
- shutil.copy(f"./{fname}", f"./Wan2.2/{fname}")
69
-
70
- # ── Lazy model init ───────────────────────────────────────────────────────────
71
- _wan_animate = None
72
-
73
- def get_wan_animate():
74
- global _wan_animate
75
- if _wan_animate is None:
76
- sys.path.insert(0, "./Wan2.2")
77
- from generate import load_model
78
- _wan_animate = load_model(False)
79
- return _wan_animate
80
-
81
- # ── Inference ─────────────────────────────────────────────────────────────────
82
- @spaces.GPU(duration=300)
83
- def run_animate(ref_image, template_video, mode, quality, max_duration):
84
- import uuid
85
- from generate import generate
86
-
87
- wan_animate = get_wan_animate()
88
-
89
- uid = str(uuid.uuid4())[:8]
90
- work_dir = f"/tmp/wan_{uid}"
91
- os.makedirs(work_dir, exist_ok=True)
92
-
93
- try:
94
- ref_path = os.path.join(work_dir, "ref.jpg")
95
- tmpl_path = os.path.join(work_dir, "template.mp4")
96
-
97
- import numpy as np
98
- from PIL import Image
99
- if isinstance(ref_image, np.ndarray):
100
- Image.fromarray(ref_image).save(ref_path)
101
- else:
102
- shutil.copy(ref_image, ref_path)
103
- shutil.copy(template_video, tmpl_path)
104
-
105
- pose_path = os.path.join(work_dir, "pose.mp4")
106
- face_path = os.path.join(work_dir, "face.png")
107
- bg_path = os.path.join(work_dir, "bg.png")
108
- mask_path = os.path.join(work_dir, "mask.png")
109
-
110
- from preprocess_data import preprocess
111
- preprocess(
112
- ref_image=ref_path,
113
- template_video=tmpl_path,
114
- output_pose=pose_path,
115
- output_face=face_path,
116
- output_bg=bg_path,
117
- output_mask=mask_path,
118
- mode=mode,
119
- )
120
 
121
- out_path = os.path.join(work_dir, "output.mp4")
122
- generate(
123
- wan_animate=wan_animate,
124
- src_pose_path=pose_path,
125
- src_face_path=face_path,
126
- src_bg_path=bg_path,
127
- src_mask_path=mask_path,
128
- src_ref_path=ref_path,
129
- save_file=out_path,
130
- )
131
 
132
- return out_path, "Done!"
133
- except Exception as e:
134
- return None, f"Error: {e}"
135
-
136
- # ── UI ────────────────────────────────────────────────────────────────────────
137
- with gr.Blocks(title="Wan2.2 Animate") as demo:
138
- gr.Markdown("## Wan2.2 Animate — ZeroGPU (Free A100)")
139
- with gr.Row():
140
- with gr.Column():
141
- ref_image = gr.Image(label="Reference Image", type="numpy")
142
- template_video = gr.Video(label="Template Video")
143
- mode = gr.Dropdown(["normal", "tiktok"], value="normal", label="Mode")
144
- quality = gr.Dropdown(["standard", "high"], value="standard", label="Quality")
145
- max_duration = gr.Slider(1, 10, value=5, step=1, label="Max Duration (s)")
146
- btn = gr.Button("Generate", variant="primary")
147
- with gr.Column():
148
- out_video = gr.Video(label="Output Video")
149
- status = gr.Textbox(label="Status", interactive=False)
150
-
151
- btn.click(
152
- run_animate,
153
- inputs=[ref_image, template_video, mode, quality, max_duration],
154
- outputs=[out_video, status],
155
- )
156
 
157
  if __name__ == "__main__":
158
- demo.launch()
 
1
+ # app.py
2
  import os
3
+ import oss2
4
  import sys
5
+ import uuid
6
  import shutil
7
+ import time
8
  import gradio as gr
9
+ import requests
10
+ from pathlib import Path
11
+ from datetime import datetime, timedelta
12
+ import dashscope
13
+ # from dashscope.utils.oss_utils import check_and_upload_local
14
+
15
+ DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")
16
+ dashscope.api_key = DASHSCOPE_API_KEY
17
+
18
+ def get_upload_policy(api_key, model_name):
19
+ """获取文件上传凭证"""
20
+ url = "https://dashscope.aliyuncs.com/api/v1/uploads"
21
+ headers = {
22
+ "Authorization": f"Bearer {api_key}",
23
+ "Content-Type": "application/json"
24
+ }
25
+ params = {
26
+ "action": "getPolicy",
27
+ "model": model_name
28
+ }
29
+ response = requests.get(url, headers=headers, params=params)
30
+ if response.status_code != 200:
31
+ raise Exception(f"Failed to get upload policy: {response.text}")
32
+ return response.json()['data']
33
+
34
+ def upload_file_to_oss(policy_data, file_path):
35
+ """将文件上传到临时存储OSS"""
36
+ file_name = Path(file_path).name
37
+ key = f"{policy_data['upload_dir']}/{file_name}"
38
+ with open(file_path, 'rb') as file:
39
+ files = {
40
+ 'OSSAccessKeyId': (None, policy_data['oss_access_key_id']),
41
+ 'Signature': (None, policy_data['signature']),
42
+ 'policy': (None, policy_data['policy']),
43
+ 'x-oss-object-acl': (None, policy_data['x_oss_object_acl']),
44
+ 'x-oss-forbid-overwrite': (None, policy_data['x_oss_forbid_overwrite']),
45
+ 'key': (None, key),
46
+ 'success_action_status': (None, '200'),
47
+ 'file': (file_name, file)
48
+ }
49
+ response = requests.post(policy_data['upload_host'], files=files)
50
+ if response.status_code != 200:
51
+ raise Exception(f"Failed to upload file: {response.text}")
52
+ return f"oss://{key}"
53
+
54
+ def upload_file_and_get_url(api_key, model_name, file_path):
55
+ """上传文件并获取URL"""
56
+ # 1. 获取上传凭证,上传凭证接口有限流,超出限流将导致请求失败
57
+ policy_data = get_upload_policy(api_key, model_name)
58
+ # 2. 上传文件到OSS
59
+ oss_url = upload_file_to_oss(policy_data, file_path)
60
+ return oss_url
61
+
62
+
63
+ class WanAnimateApp:
64
+ def __init__(self, url, get_url):
65
+ self.url = url
66
+ self.get_url = get_url
67
+
68
+ def predict(
69
+ self,
70
+ ref_img,
71
+ video,
72
+ model_id,
73
+ model,
74
+ ):
75
+ # Upload files to OSS if needed and get URLs
76
+ image_url = upload_file_and_get_url(DASHSCOPE_API_KEY, model_id, ref_img)
77
+ video_url = upload_file_and_get_url(DASHSCOPE_API_KEY, model_id, video)
78
+
79
+ # Prepare the request payload
80
+ payload = {
81
+ "model": model_id,
82
+ "input": {
83
+ "image_url": image_url,
84
+ "video_url": video_url
85
+ },
86
+ "parameters": {
87
+ "check_image": True,
88
+ "mode": model,
89
+ }
90
+ }
91
+
92
+ # Set up headers
93
+ headers = {
94
+ "X-DashScope-Async": "enable",
95
+ "X-DashScope-OssResourceResolve": "enable",
96
+ "Authorization": f"Bearer {DASHSCOPE_API_KEY}",
97
+ "Content-Type": "application/json"
98
+ }
99
+
100
+ # Make the initial API request
101
+ url = self.url
102
+ response = requests.post(url, json=payload, headers=headers, timeout=60)
103
+
104
+ # Check if request was successful
105
+ if response.status_code != 200:
106
+ raise Exception(f"Initial request failed with status code {response.status_code}: {response.text}")
107
+
108
+ # Get the task ID from response
109
+ result = response.json()
110
+ task_id = result.get("output", {}).get("task_id")
111
+ if not task_id:
112
+ raise Exception("Failed to get task ID from response")
113
+
114
+ # Poll for results
115
+ get_url = f"{self.get_url}/{task_id}"
116
+ headers = {
117
+ "Authorization": f"Bearer {DASHSCOPE_API_KEY}",
118
+ "Content-Type": "application/json"
119
+ }
120
+
121
+ while True:
122
+ response = requests.get(get_url, headers=headers, timeout=60)
123
+ if response.status_code != 200:
124
+ raise Exception(f"Failed to get task status: {response.status_code}: {response.text}")
125
+
126
+ result = response.json()
127
+ print(result)
128
+ task_status = result.get("output", {}).get("task_status")
129
+
130
+ if task_status == "SUCCEEDED":
131
+ # Task completed successfully, return video URL
132
+ video_url = result["output"]["results"]["video_url"]
133
+ return video_url, "SUCCEEDED"
134
+ elif task_status == "PENDING" or task_status == "RUNNING":
135
+ # Task is still running, wait and retry
136
+ time.sleep(10) # Wait 10 seconds before polling again
137
+ else:
138
+ # Task failed or unknown, raise an exception with error message
139
+ error_msg = result.get("output", {}).get("message", "Unknown error")
140
+ code_msg = result.get("output", {}).get("code", "Unknown code")
141
+ print(f"\n\nTask failed: {error_msg} Code: {code_msg} TaskId: {task_id}\n\n")
142
+ return None, f"Task failed: {error_msg} Code: {code_msg} TaskId: {task_id}"
143
+
144
+
145
+ def start_app():
146
+ import argparse
147
+ parser = argparse.ArgumentParser(description="Wan2.2-Animate 视频生成工具")
148
+ args = parser.parse_args()
149
+
150
+ url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2video/video-synthesis/"
151
+ get_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/"
152
+
153
+ app = WanAnimateApp(url=url, get_url=get_url)
154
+
155
+ with gr.Blocks(title="Wan2.2-Animate 视频生成") as demo:
156
+ gr.HTML("""
157
+ <div style="padding: 2rem; text-align: center; max-width: 1200px; margin: 0 auto; font-family: Arial, sans-serif;">
158
+ <h1 style="font-size: 2.5rem; font-weight: bold; margin-bottom: 0.5rem; color: #333;">
159
+ Wan2.2-Animate: Unified Character Animation and Replacement with Holistic Replication
160
+ </h1>
161
+ <h3 style="font-size: 2.5rem; font-weight: bold; margin-bottom: 0.5rem; color: #333;">
162
+ Wan2.2-Animate: 统一的角色动画和视频人物替换模型
163
+ </h3>
164
+ <div style="font-size: 1.25rem; margin-bottom: 1.5rem; color: #555;">
165
+ Tongyi Lab, Alibaba
166
+ </div>
167
+ <div style="display: flex; flex-wrap: wrap; justify-content: center; gap: 1rem; margin-bottom: 1rem;">
168
+ <a href="https://arxiv.org/abs/2509.14055" target="_blank" style="display: inline-flex; align-items: center; padding: 0.5rem 1rem; background-color: #f0f0f0; color: #333; text-decoration: none; border-radius: 9999px; font-weight: 500;">
169
+ <span style="margin-right: 0.5rem;">📄</span><span>Paper</span>
170
+ </a>
171
+ <a href="https://github.com/Wan-Video/Wan2.2" target="_blank" style="display: inline-flex; align-items: center; padding: 0.5rem 1rem; background-color: #f0f0f0; color: #333; text-decoration: none; border-radius: 9999px; font-weight: 500;">
172
+ <span style="margin-right: 0.5rem;">💻</span><span>GitHub</span>
173
+ </a>
174
+ <a href="https://huggingface.co/Wan-AI/Wan2.2-Animate-14B" target="_blank" style="display: inline-flex; align-items: center; padding: 0.5rem 1rem; background-color: #f0f0f0; color: #333; text-decoration: none; border-radius: 9999px; font-weight: 500;">
175
+ <span style="margin-right: 0.5rem;">🤗</span><span>HF Model</span>
176
+ </a>
177
+ <a href="https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B" target="_blank" style="display: inline-flex; align-items: center; padding: 0.5rem 1rem; background-color: #f0f0f0; color: #333; text-decoration: none; border-radius: 9999px; font-weight: 500;">
178
+ <span style="margin-right: 0.5rem;">🤖</span><span>MS Model</span>
179
+ </a>
180
+ </div>
181
+ </div>
182
+ """)
183
+ gr.HTML("""
184
+ <details>
185
+ <summary>‼️Usage (使用说明)</summary>
186
+ Wan-Animate supports two mode:
187
+ <ul>
188
+ <li>Move Mode: animate the character in input image with movements from the input video</li>
189
+ <li>Mix Mode: replace the character in input video with the character in input image</li>
190
+ </ul>
191
+ Currently, the following restrictions apply to inputs:
192
+ <ul>
193
+ <li>Video file size: Less than 200MB</li>
194
+ <li>Video resolution: The shorter side must be greater than 200, and the longer side must be less than 2048</li>
195
+ <li>Video duration: 2s to 30s</li>
196
+ <li>Video aspect ratio: 1:3 to 3:1</li>
197
+ <li>Video formats: mp4, avi, mov</li>
198
+ <li>Image file size: Less than 5MB</li>
199
+ <li>Image resolution: The shorter side must be greater than 200, and the longer side must be less than 4096</li>
200
+ <li>Image formats: jpg, png, jpeg, webp, bmp</li>
201
+ </ul>
202
+ <ul>
203
+ <li> wan-pro: 25fps, 720p </li>
204
+ <li> wan-std: 15fps, 720p </li>
205
+ </ul>
206
+ </details>
207
+ """)
208
+ with gr.Row():
209
+ with gr.Column():
210
+ ref_img = gr.Image(
211
+ label="Reference Image(参考图像)",
212
+ type="filepath",
213
+ sources=["upload"],
214
+ )
215
+ video = gr.Video(
216
+ label="Template Video(模版视频)",
217
+ sources=["upload"],
218
+ )
219
+ with gr.Row():
220
+ model_id = gr.Dropdown(
221
+ label="Mode(模式)",
222
+ choices=["wan2.2-animate-move", "wan2.2-animate-mix"],
223
+ value="wan2.2-animate-move",
224
+ info=""
225
+ )
226
+ model = gr.Dropdown(
227
+ label="推理质量(Inference Quality)",
228
+ choices=["wan-pro", "wan-std"],
229
+ value="wan-pro",
230
+ )
231
+ run_button = gr.Button("Generate Video(生成视频)")
232
+ with gr.Column():
233
+ output_video = gr.Video(label="Output Video(输出视频)")
234
+ output_status = gr.Textbox(label="Status(状态)")
235
+
236
+ run_button.click(
237
+ fn=app.predict,
238
+ inputs=[
239
+ ref_img,
240
+ video,
241
+ model_id,
242
+ model,
243
+ ],
244
+ outputs=[output_video, output_status],
245
  )
246
+
247
+ example_data = [
248
+ ['./examples/mov/1/1.jpeg', './examples/mov/1/1.mp4', 'wan2.2-animate-move', 'wan-pro'],
249
+ ['./examples/mov/2/2.jpeg', './examples/mov/2/2.mp4', 'wan2.2-animate-move', 'wan-pro'],
250
+ ['./examples/mix/1/1.jpeg', './examples/mix/1/1.mp4', 'wan2.2-animate-mix', 'wan-pro'],
251
+ ['./examples/mix/2/2.jpeg', './examples/mix/2/2.mp4', 'wan2.2-animate-mix', 'wan-pro']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
+ if example_data:
255
+ gr.Examples(
256
+ examples=example_data,
257
+ inputs=[ref_img, video, model_id, model],
258
+ outputs=[output_video, output_status],
259
+ fn=app.predict,
260
+ cache_examples="lazy",
261
+ )
 
 
262
 
263
+ demo.queue(default_concurrency_limit=100)
264
+ demo.launch(
265
+ server_name="0.0.0.0",
266
+ server_port=7860
267
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
  if __name__ == "__main__":
270
+ start_app()