pormungtai committed on
Commit
12b25cc
·
verified ·
1 Parent(s): 40fe309

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +143 -0
inference.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import logging
4
+ import math
5
+ from omegaconf import OmegaConf
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.jit
12
+ from torchvision.datasets.folder import pil_loader
13
+ from torchvision.transforms.functional import pil_to_tensor, resize, center_crop
14
+ from torchvision.transforms.functional import to_pil_image
15
+
16
+ from mimicmotion.utils.geglu_patch import patch_geglu_inplace
17
+ patch_geglu_inplace()
18
+
19
+ from constants import ASPECT_RATIO
20
+
21
+ from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline
22
+ from mimicmotion.utils.loader import create_pipeline
23
+ from mimicmotion.utils.utils import save_to_mp4
24
+ from mimicmotion.dwpose.preprocess import get_video_pose, get_image_pose
25
+
26
# Configure root logging once at import time (INFO level, timestamped lines).
logging.basicConfig(format="%(asctime)s: [%(levelname)s] %(message)s", level=logging.INFO)
# Module-level logger used by the rest of this script.
logger = logging.getLogger(__name__)
# Run on CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
29
+
30
+
31
def preprocess(video_path, image_path, resolution=576, sample_stride=2):
    """Build the normalized pose sequence and reference image for one task.

    The reference image is loaded, resized so that it covers the target
    canvas, center-cropped to that canvas, and then used both to extract its
    own pose and as the reference handed to the video pose extractor.

    Args:
        video_path (str): input video pose path
        image_path (str): reference image path
        resolution (int, optional): size of one target side (the width when
            the image is taller than wide, the height otherwise). The other
            side is derived from ``ASPECT_RATIO`` and floored to a multiple
            of 64. Defaults to 576.
        sample_stride (int, optional): forwarded to ``get_video_pose``.
            Defaults to 2.

    Returns:
        tuple: ``(pose_pixels, image_pixels)`` tensors, both mapped through
        ``x / 127.5 - 1``. The pose tensor holds the reference-image pose
        first, followed by the video poses; the image tensor is the cropped
        reference image with a leading batch axis, channels first.
    """
    ref_tensor = pil_to_tensor(pil_loader(image_path))  # (c, h, w)
    src_h, src_w = ref_tensor.shape[-2:]

    # ---- target canvas: depends on whether the source is taller than wide ----
    scaled_side = int(resolution / ASPECT_RATIO // 64) * 64
    if src_h > src_w:
        w_target, h_target = resolution, scaled_side
    else:
        w_target, h_target = scaled_side, resolution

    # ---- resize so the image covers the canvas, then crop the centre ----
    src_ratio = float(src_h) / float(src_w)
    if src_ratio >= h_target / w_target:
        # source is relatively taller than the canvas: match widths
        h_resize, w_resize = math.ceil(w_target * src_ratio), w_target
    else:
        # source is relatively wider than the canvas: match heights
        h_resize, w_resize = h_target, math.ceil(h_target / src_ratio)
    ref_tensor = resize(ref_tensor, [h_resize, w_resize], antialias=None)
    ref_tensor = center_crop(ref_tensor, [h_target, w_target])
    ref_hwc = ref_tensor.permute((1, 2, 0)).numpy()  # (h, w, c) for the pose helpers

    # ---- pose extraction for the reference image and the driving video ----
    image_pose = get_image_pose(ref_hwc)
    video_pose = get_video_pose(video_path, ref_hwc, sample_stride=sample_stride)
    pose_stack = np.concatenate([np.expand_dims(image_pose, 0), video_pose])

    # back to (1, c, h, w) for the reference image
    ref_nchw = np.transpose(np.expand_dims(ref_hwc, 0), (0, 3, 1, 2))

    pose_tensor = torch.from_numpy(pose_stack.copy()) / 127.5 - 1
    image_tensor = torch.from_numpy(ref_nchw) / 127.5 - 1
    return pose_tensor, image_tensor
62
+
63
+
64
def run_pipeline(pipeline: MimicMotionPipeline, image_pixels, pose_pixels, device, task_config):
    """Run the MimicMotion pipeline for one task and return its frames as uint8.

    Args:
        pipeline (MimicMotionPipeline): pipeline object that is called to
            generate the frames.
        image_pixels (torch.Tensor): batch of reference images scaled as
            ``x / 127.5 - 1``; it is mapped back to uint8 PIL images here.
        pose_pixels (torch.Tensor): pose frames; dimension 0 is used as the
            number of frames and the last two dimensions as height and width.
        device: device used for the random generator and passed to the pipeline.
        task_config: object providing ``seed``, ``num_frames``,
            ``frames_overlap``, ``noise_aug_strength``,
            ``num_inference_steps`` and ``guidance_scale``.

    Returns:
        torch.Tensor: uint8 frames of the last video in the returned batch,
        without its first frame.

    Raises:
        ValueError: if the pipeline returns an empty batch of videos.
    """
    # Undo the ``x / 127.5 - 1`` scaling. A bare uint8 cast truncates, so the
    # float error of the round trip (notably with a float16 default dtype)
    # would turn e.g. 0.99999 into 0. Round in float32 and clamp before casting.
    restored = ((image_pixels.float() + 1.0) * 127.5).round().clamp(0, 255)
    ref_images = [to_pil_image(img.to(torch.uint8)) for img in restored]

    # Seeded generator so that a task is reproducible for a given seed.
    generator = torch.Generator(device=device)
    generator.manual_seed(task_config.seed)

    frames = pipeline(
        ref_images,
        image_pose=pose_pixels,
        num_frames=pose_pixels.size(0),
        tile_size=task_config.num_frames,
        tile_overlap=task_config.frames_overlap,
        height=pose_pixels.shape[-2],
        width=pose_pixels.shape[-1],
        fps=7,
        noise_aug_strength=task_config.noise_aug_strength,
        num_inference_steps=task_config.num_inference_steps,
        generator=generator,
        min_guidance_scale=task_config.guidance_scale,
        max_guidance_scale=task_config.guidance_scale,
        decode_chunk_size=8,
        output_type="pt",
        device=device,
    ).frames.cpu()
    video_frames = (frames * 255.0).to(torch.uint8)

    if video_frames.shape[0] == 0:
        # Previously this surfaced as an UnboundLocalError at the return.
        raise ValueError("pipeline returned no videos")
    # The first frame is dropped because of the reference image. As before,
    # only the last video of the batch is returned.
    return video_frames[-1, 1:]
91
+
92
+
93
@torch.no_grad()
def main(args):
    """Run every test case of the inference config and save each result as mp4.

    Args:
        args: parsed command-line namespace; ``no_use_float16``,
            ``inference_config`` and ``output_dir`` are read from it.
    """
    use_half = not args.no_use_float16
    if use_half:
        # Changes the process-wide default dtype for newly created tensors.
        torch.set_default_dtype(torch.float16)

    infer_config = OmegaConf.load(args.inference_config)
    pipeline = create_pipeline(infer_config, device)

    for task in infer_config.test_case:
        # Pose sequence and reference image for this test case.
        pose_pixels, image_pixels = preprocess(
            task.ref_video_path,
            task.ref_image_path,
            resolution=task.resolution,
            sample_stride=task.sample_stride,
        )

        generated = run_pipeline(pipeline, image_pixels, pose_pixels, device, task)

        # Output name: video base name (text before the first dot) + timestamp.
        video_name = os.path.basename(task.ref_video_path).split('.')[0]
        stamp = datetime.now().strftime('%Y%m%d%H%M%S')
        save_to_mp4(
            generated,
            f"{args.output_dir}/{video_name}_{stamp}.mp4",
            fps=task.fps,
        )
119
+
120
+
121
def set_logger(log_file=None, log_level=logging.INFO):
    """Attach a file handler to this module's logger.

    Args:
        log_file (str, optional): path of the log file; it is opened in
            ``"w"`` mode, so an existing file is overwritten. When ``None``
            (the default) no handler is added.
        log_level (int, optional): level set on the file handler.
            Defaults to ``logging.INFO``.
    """
    if log_file is None:
        # logging.FileHandler(None) raises TypeError, so the declared default
        # used to crash; with no file requested there is nothing to attach.
        return
    log_handler = logging.FileHandler(log_file, "w")
    log_handler.setFormatter(
        logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s")
    )
    log_handler.setLevel(log_level)
    logger.addHandler(log_handler)
128
+
129
+
130
if __name__ == "__main__":
    # Command-line entry point: parse options, prepare the output directory
    # and the log file, then run all configured test cases.
    parser = argparse.ArgumentParser()
    parser.add_argument("--log_file", type=str, default=None)
    parser.add_argument("--inference_config", type=str, default="configs/test.yaml")
    parser.add_argument("--output_dir", type=str, default="outputs/", help="path to output")
    parser.add_argument(
        "--no_use_float16",
        action="store_true",
        # Passing the flag turns float16 OFF; the old help text said the opposite.
        help="Disable float16 inference (float16 is used by default to speed up inference)",
    )
    args = parser.parse_args()

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Without --log_file, log to a timestamped file inside the output directory.
    log_file = args.log_file
    if log_file is None:
        log_file = f"{args.output_dir}/{datetime.now().strftime('%Y%m%d%H%M%S')}.log"
    set_logger(log_file)

    main(args)
    logger.info("--- Finished ---")