Spaces:
Running on Zero
Running on Zero
Create inference.py
Browse files- inference.py +143 -0
inference.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
from omegaconf import OmegaConf
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.jit
|
| 12 |
+
from torchvision.datasets.folder import pil_loader
|
| 13 |
+
from torchvision.transforms.functional import pil_to_tensor, resize, center_crop
|
| 14 |
+
from torchvision.transforms.functional import to_pil_image
|
| 15 |
+
|
| 16 |
+
from mimicmotion.utils.geglu_patch import patch_geglu_inplace
|
| 17 |
+
patch_geglu_inplace()
|
| 18 |
+
|
| 19 |
+
from constants import ASPECT_RATIO
|
| 20 |
+
|
| 21 |
+
from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline
|
| 22 |
+
from mimicmotion.utils.loader import create_pipeline
|
| 23 |
+
from mimicmotion.utils.utils import save_to_mp4
|
| 24 |
+
from mimicmotion.dwpose.preprocess import get_video_pose, get_image_pose
|
| 25 |
+
|
| 26 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s: [%(levelname)s] %(message)s")
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def preprocess(video_path, image_path, resolution=576, sample_stride=2):
|
| 32 |
+
"""preprocess ref image pose and video pose
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
video_path (str): input video pose path
|
| 36 |
+
image_path (str): reference image path
|
| 37 |
+
resolution (int, optional): Defaults to 576.
|
| 38 |
+
sample_stride (int, optional): Defaults to 2.
|
| 39 |
+
"""
|
| 40 |
+
image_pixels = pil_loader(image_path)
|
| 41 |
+
image_pixels = pil_to_tensor(image_pixels) # (c, h, w)
|
| 42 |
+
h, w = image_pixels.shape[-2:]
|
| 43 |
+
############################ compute target h/w according to original aspect ratio ###############################
|
| 44 |
+
if h>w:
|
| 45 |
+
w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64
|
| 46 |
+
else:
|
| 47 |
+
w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolution
|
| 48 |
+
h_w_ratio = float(h) / float(w)
|
| 49 |
+
if h_w_ratio < h_target / w_target:
|
| 50 |
+
h_resize, w_resize = h_target, math.ceil(h_target / h_w_ratio)
|
| 51 |
+
else:
|
| 52 |
+
h_resize, w_resize = math.ceil(w_target * h_w_ratio), w_target
|
| 53 |
+
image_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None)
|
| 54 |
+
image_pixels = center_crop(image_pixels, [h_target, w_target])
|
| 55 |
+
image_pixels = image_pixels.permute((1, 2, 0)).numpy()
|
| 56 |
+
##################################### get image&video pose value #################################################
|
| 57 |
+
image_pose = get_image_pose(image_pixels)
|
| 58 |
+
video_pose = get_video_pose(video_path, image_pixels, sample_stride=sample_stride)
|
| 59 |
+
pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose])
|
| 60 |
+
image_pixels = np.transpose(np.expand_dims(image_pixels, 0), (0, 3, 1, 2))
|
| 61 |
+
return torch.from_numpy(pose_pixels.copy()) / 127.5 - 1, torch.from_numpy(image_pixels) / 127.5 - 1
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def run_pipeline(pipeline: MimicMotionPipeline, image_pixels, pose_pixels, device, task_config):
|
| 65 |
+
image_pixels = [to_pil_image(img.to(torch.uint8)) for img in (image_pixels + 1.0) * 127.5]
|
| 66 |
+
generator = torch.Generator(device=device)
|
| 67 |
+
generator.manual_seed(task_config.seed)
|
| 68 |
+
frames = pipeline(
|
| 69 |
+
image_pixels,
|
| 70 |
+
image_pose=pose_pixels,
|
| 71 |
+
num_frames=pose_pixels.size(0),
|
| 72 |
+
tile_size=task_config.num_frames,
|
| 73 |
+
tile_overlap=task_config.frames_overlap,
|
| 74 |
+
height=pose_pixels.shape[-2],
|
| 75 |
+
width=pose_pixels.shape[-1],
|
| 76 |
+
fps=7,
|
| 77 |
+
noise_aug_strength=task_config.noise_aug_strength,
|
| 78 |
+
num_inference_steps=task_config.num_inference_steps,
|
| 79 |
+
generator=generator,
|
| 80 |
+
min_guidance_scale=task_config.guidance_scale,
|
| 81 |
+
max_guidance_scale=task_config.guidance_scale,
|
| 82 |
+
decode_chunk_size=8,
|
| 83 |
+
output_type="pt",
|
| 84 |
+
device=device
|
| 85 |
+
).frames.cpu()
|
| 86 |
+
video_frames = (frames * 255.0).to(torch.uint8)
|
| 87 |
+
for vid_idx in range(video_frames.shape[0]):
|
| 88 |
+
# deprecated first frame because of ref image
|
| 89 |
+
_video_frames = video_frames[vid_idx, 1:]
|
| 90 |
+
return _video_frames
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@torch.no_grad()
|
| 94 |
+
def main(args):
|
| 95 |
+
if not args.no_use_float16:
|
| 96 |
+
torch.set_default_dtype(torch.float16)
|
| 97 |
+
infer_config = OmegaConf.load(args.inference_config)
|
| 98 |
+
pipeline = create_pipeline(infer_config, device)
|
| 99 |
+
for task in infer_config.test_case:
|
| 100 |
+
pose_pixels, image_pixels = preprocess(
|
| 101 |
+
task.ref_video_path,
|
| 102 |
+
task.ref_image_path,
|
| 103 |
+
resolution=task.resolution,
|
| 104 |
+
sample_stride=task.sample_stride
|
| 105 |
+
)
|
| 106 |
+
_video_frames = run_pipeline(
|
| 107 |
+
pipeline,
|
| 108 |
+
image_pixels,
|
| 109 |
+
pose_pixels,
|
| 110 |
+
device,
|
| 111 |
+
task
|
| 112 |
+
)
|
| 113 |
+
save_to_mp4(
|
| 114 |
+
_video_frames,
|
| 115 |
+
f"{args.output_dir}/{os.path.basename(task.ref_video_path).split('.')[0]}"
|
| 116 |
+
f"_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4",
|
| 117 |
+
fps=task.fps,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def set_logger(log_file=None, log_level=logging.INFO):
|
| 122 |
+
log_handler = logging.FileHandler(log_file, "w")
|
| 123 |
+
log_handler.setFormatter(
|
| 124 |
+
logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s")
|
| 125 |
+
)
|
| 126 |
+
log_handler.setLevel(log_level)
|
| 127 |
+
logger.addHandler(log_handler)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
if __name__ == "__main__":
|
| 131 |
+
parser = argparse.ArgumentParser()
|
| 132 |
+
parser.add_argument("--log_file", type=str, default=None)
|
| 133 |
+
parser.add_argument("--inference_config", type=str, default="configs/test.yaml")
|
| 134 |
+
parser.add_argument("--output_dir", type=str, default="outputs/", help="path to output")
|
| 135 |
+
parser.add_argument("--no_use_float16",
|
| 136 |
+
action="store_true",
|
| 137 |
+
help="Whether use float16 to speed up inference",
|
| 138 |
+
)
|
| 139 |
+
args = parser.parse_args()
|
| 140 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
| 141 |
+
set_logger(args.log_file if args.log_file is not None else f"{args.output_dir}/{datetime.now().strftime('%Y%m%d%H%M%S')}.log")
|
| 142 |
+
main(args)
|
| 143 |
+
logger.info(f"--- Finished ---")
|