| import ast |
| import gc |
| import random |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
| from diffusers.models.attention_processor import AttnProcessor2_0 |
| from diffusers.models.attention import BasicTransformerBlock |
| from decord import VideoReader |
| import wandb |
|
|
|
|
def extract_into_tensor(a, t, x_shape):
    """Gather entries of ``a`` at indices ``t`` and reshape them for broadcasting.

    Args:
        a: Tensor indexed along its last dimension.
        t: Index tensor handed to ``gather``; the size of its first
            dimension becomes the first dimension of the result.
        x_shape: Shape whose length decides how many dimensions the result
            has in total.

    Returns:
        The gathered values reshaped to ``(t.shape[0], 1, ..., 1)`` with
        ``len(x_shape)`` dimensions.
    """
    batch_size, *_ = t.shape
    gathered = torch.gather(a, -1, t)
    # One singleton axis for every dimension of x_shape after the first.
    trailing_ones = [1] * (len(x_shape) - 1)
    return gathered.reshape(batch_size, *trailing_ones)
|
|
|
|
def is_attn(name):
    """Return True when the last dotted component of ``name`` is ``attn1`` or ``attn2``.

    The previous expression ``"attn1" or "attn2" == ...`` always evaluated to
    the truthy string ``"attn1"``, so it never filtered anything.

    Args:
        name: Dotted module name, e.g. as yielded by ``named_modules()``.

    Returns:
        bool: Whether the final name component is an attention slot name.
    """
    return name.split(".")[-1] in ("attn1", "attn2")


def set_processors(attentions):
    """Install a fresh ``AttnProcessor2_0`` on every module in ``attentions``."""
    for attn in attentions:
        attn.set_processor(AttnProcessor2_0())


def set_torch_2_attn(unet):
    """Switch the attention layers of every transformer block to ``AttnProcessor2_0``.

    Walks all ``torch.nn.ModuleList`` containers in ``unet`` and, for each
    ``BasicTransformerBlock`` found inside one, replaces the processors of its
    ``attn1`` / ``attn2`` layers. Prints a summary when at least one block was
    updated.

    Args:
        unet: Module to scan; only ``named_modules()`` is required of it.
    """
    optim_count = 0

    # The containers that hold transformer blocks are not themselves named
    # attn1/attn2, so the scan must not be gated on is_attn(); the old gate
    # was always true and is dropped here to keep the same set of blocks.
    for _, module in unet.named_modules():
        if not isinstance(module, torch.nn.ModuleList):
            continue
        for block in module:
            if not isinstance(block, BasicTransformerBlock):
                continue
            # NOTE(review): attn2 can be None on blocks built without
            # cross-attention (per diffusers' BasicTransformerBlock); skip
            # missing slots instead of crashing on None.set_processor.
            present = [a for a in (block.attn1, block.attn2) if a is not None]
            set_processors(present)
            optim_count += 1

    if optim_count > 0:
        print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")
|
|
|
|
| |
# NOTE(review): this function is defined again, verbatim, later in this file;
# the later definition is the one left bound to the name at import time.
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """
    Build sinusoidal embeddings for a batch of guidance-scale values.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scales to embed (multiplied by 1000 before
            the sinusoids are evaluated).
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`,
        created on the same device as `w`.
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    # max(..., 1) avoids a division by zero (and NaN output) when half_dim is 1;
    # for every half_dim > 1 the divisor is unchanged.
    log_step = torch.log(torch.tensor(10000.0)) / max(half_dim - 1, 1)
    # Build the frequency table on w's device so inputs that do not live on
    # the CPU do not hit a device mismatch in the product below.
    freqs = torch.exp(
        torch.arange(half_dim, dtype=dtype, device=w.device) * -log_step
    )
    emb = w.to(dtype)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        # Odd sizes get one trailing zero column.
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
|
|
|
|
# NOTE(review): this function is defined again, verbatim, later in this file;
# the later definition is the one left bound to the name at import time.
def append_dims(x, target_dims):
    """Pad ``x`` with trailing singleton axes until it has ``target_dims`` dimensions.

    Raises:
        ValueError: If ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    # Ellipsis keeps every existing axis; each None adds one new axis at the end.
    index = (Ellipsis, *([None] * missing))
    return x[index]
|
|
|
|
| |
# NOTE(review): this function is defined again, verbatim, later in this file;
# the later definition is the one left bound to the name at import time.
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    """Compute the ``(c_skip, c_out)`` scaling pair for ``timestep``.

    With ``s = timestep_scaling * timestep``:
    ``c_skip = sigma_data^2 / (s^2 + sigma_data^2)`` and
    ``c_out = s / sqrt(s^2 + sigma_data^2)``.

    Returns:
        Tuple ``(c_skip, c_out)``.
    """
    scaled = timestep_scaling * timestep
    denominator = scaled**2 + sigma_data**2
    c_skip = sigma_data**2 / denominator
    c_out = scaled / denominator**0.5
    return c_skip, c_out
|
|
|
|
| |
# NOTE(review): this function is defined again, verbatim, later in this file;
# the later definition is the one left bound to the name at import time.
def get_predicted_original_sample(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    """Convert a model output into a prediction of the clean sample.

    Args:
        model_output: Raw model output, interpreted per ``prediction_type``.
        timesteps: Indices used to look up ``alphas`` / ``sigmas``.
        sample: The noisy sample the model was evaluated on.
        prediction_type: One of ``"epsilon"``, ``"sample"`` or
            ``"v_prediction"``.
        alphas: Per-timestep alpha table, indexed via ``extract_into_tensor``.
        sigmas: Per-timestep sigma table, indexed via ``extract_into_tensor``.

    Returns:
        The predicted original sample.

    Raises:
        ValueError: If ``prediction_type`` is not one of the supported values.
    """
    alpha_t = extract_into_tensor(alphas, timesteps, sample.shape)
    sigma_t = extract_into_tensor(sigmas, timesteps, sample.shape)

    if prediction_type == "epsilon":
        return (sample - sigma_t * model_output) / alpha_t
    if prediction_type == "sample":
        return model_output
    if prediction_type == "v_prediction":
        return alpha_t * sample - sigma_t * model_output

    raise ValueError(
        f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
        f" are supported."
    )
|
|
|
|
| |
# NOTE(review): this function is defined again, verbatim, later in this file;
# the later definition is the one left bound to the name at import time.
def get_predicted_noise(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    """Convert a model output into a prediction of the noise.

    Args:
        model_output: Raw model output, interpreted per ``prediction_type``.
        timesteps: Indices used to look up ``alphas`` / ``sigmas``.
        sample: The noisy sample the model was evaluated on.
        prediction_type: One of ``"epsilon"``, ``"sample"`` or
            ``"v_prediction"``.
        alphas: Per-timestep alpha table, indexed via ``extract_into_tensor``.
        sigmas: Per-timestep sigma table, indexed via ``extract_into_tensor``.

    Returns:
        The predicted noise.

    Raises:
        ValueError: If ``prediction_type`` is not one of the supported values.
    """
    alpha_t = extract_into_tensor(alphas, timesteps, sample.shape)
    sigma_t = extract_into_tensor(sigmas, timesteps, sample.shape)

    if prediction_type == "epsilon":
        return model_output
    if prediction_type == "sample":
        return (sample - alpha_t * model_output) / sigma_t
    if prediction_type == "v_prediction":
        return alpha_t * model_output + sigma_t * sample

    raise ValueError(
        f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
        f" are supported."
    )
|
|
|
|
| |
# NOTE(review): this repeats an identical definition earlier in this file;
# being the later one, it is the definition bound to the name at import time.
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """
    Build sinusoidal embeddings for a batch of guidance-scale values.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scales to embed (multiplied by 1000 before
            the sinusoids are evaluated).
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`,
        created on the same device as `w`.
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    # max(..., 1) avoids a division by zero (and NaN output) when half_dim is 1;
    # for every half_dim > 1 the divisor is unchanged.
    log_step = torch.log(torch.tensor(10000.0)) / max(half_dim - 1, 1)
    # Build the frequency table on w's device so inputs that do not live on
    # the CPU do not hit a device mismatch in the product below.
    freqs = torch.exp(
        torch.arange(half_dim, dtype=dtype, device=w.device) * -log_step
    )
    emb = w.to(dtype)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        # Odd sizes get one trailing zero column.
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
|
|
|
|
# NOTE(review): this repeats an identical definition earlier in this file;
# being the later one, it is the definition bound to the name at import time.
def append_dims(x, target_dims):
    """Pad ``x`` with trailing singleton axes until it has ``target_dims`` dimensions.

    Raises:
        ValueError: If ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    # Ellipsis keeps every existing axis; each None adds one new axis at the end.
    index = (Ellipsis, *([None] * missing))
    return x[index]
|
|
|
|
| |
# NOTE(review): this repeats an identical definition earlier in this file;
# being the later one, it is the definition bound to the name at import time.
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    """Compute the ``(c_skip, c_out)`` scaling pair for ``timestep``.

    With ``s = timestep_scaling * timestep``:
    ``c_skip = sigma_data^2 / (s^2 + sigma_data^2)`` and
    ``c_out = s / sqrt(s^2 + sigma_data^2)``.

    Returns:
        Tuple ``(c_skip, c_out)``.
    """
    scaled = timestep_scaling * timestep
    denominator = scaled**2 + sigma_data**2
    c_skip = sigma_data**2 / denominator
    c_out = scaled / denominator**0.5
    return c_skip, c_out
|
|
|
|
| |
# NOTE(review): this repeats an identical definition earlier in this file;
# being the later one, it is the definition bound to the name at import time.
def get_predicted_original_sample(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    """Convert a model output into a prediction of the clean sample.

    Args:
        model_output: Raw model output, interpreted per ``prediction_type``.
        timesteps: Indices used to look up ``alphas`` / ``sigmas``.
        sample: The noisy sample the model was evaluated on.
        prediction_type: One of ``"epsilon"``, ``"sample"`` or
            ``"v_prediction"``.
        alphas: Per-timestep alpha table, indexed via ``extract_into_tensor``.
        sigmas: Per-timestep sigma table, indexed via ``extract_into_tensor``.

    Returns:
        The predicted original sample.

    Raises:
        ValueError: If ``prediction_type`` is not one of the supported values.
    """
    alpha_t = extract_into_tensor(alphas, timesteps, sample.shape)
    sigma_t = extract_into_tensor(sigmas, timesteps, sample.shape)

    if prediction_type == "epsilon":
        return (sample - sigma_t * model_output) / alpha_t
    if prediction_type == "sample":
        return model_output
    if prediction_type == "v_prediction":
        return alpha_t * sample - sigma_t * model_output

    raise ValueError(
        f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
        f" are supported."
    )
|
|
|
|
| |
# NOTE(review): this repeats an identical definition earlier in this file;
# being the later one, it is the definition bound to the name at import time.
def get_predicted_noise(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    """Convert a model output into a prediction of the noise.

    Args:
        model_output: Raw model output, interpreted per ``prediction_type``.
        timesteps: Indices used to look up ``alphas`` / ``sigmas``.
        sample: The noisy sample the model was evaluated on.
        prediction_type: One of ``"epsilon"``, ``"sample"`` or
            ``"v_prediction"``.
        alphas: Per-timestep alpha table, indexed via ``extract_into_tensor``.
        sigmas: Per-timestep sigma table, indexed via ``extract_into_tensor``.

    Returns:
        The predicted noise.

    Raises:
        ValueError: If ``prediction_type`` is not one of the supported values.
    """
    alpha_t = extract_into_tensor(alphas, timesteps, sample.shape)
    sigma_t = extract_into_tensor(sigmas, timesteps, sample.shape)

    if prediction_type == "epsilon":
        return model_output
    if prediction_type == "sample":
        return (sample - alpha_t * model_output) / sigma_t
    if prediction_type == "v_prediction":
        return alpha_t * model_output + sigma_t * sample

    raise ValueError(
        f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
        f" are supported."
    )
|
|
|
|
def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
    """Bundle a model and its optimizer options into a plain dict.

    Args:
        model: Model (or other object) the entry refers to.
        condition: Flag stored under ``"condition"``.
        extra_params: Optional mapping of extra optimizer options. ``None``
            and empty mappings are both stored as ``None``.
        is_lora: Flag stored under ``"is_lora"``.
        negation: Value stored under ``"negation"``.

    Returns:
        dict with the keys ``model``, ``condition``, ``extra_params``,
        ``is_lora`` and ``negation``, in that order.
    """
    # Truthiness covers both None (the default, which used to crash on
    # `.keys()`) and an empty mapping.
    extra_params = extra_params if extra_params else None
    return {
        "model": model,
        "condition": condition,
        "extra_params": extra_params,
        "is_lora": is_lora,
        "negation": negation,
    }
|
|
|
|
def create_optim_params(name="param", params=None, lr=5e-6, extra_params=None):
    """Build one optimizer parameter-group dict.

    Args:
        name: Value stored under ``"name"``.
        params: Value stored under ``"params"``.
        lr: Value stored under ``"lr"``.
        extra_params: Optional mapping merged into the group last, so its
            entries override ``name`` / ``params`` / ``lr`` on key clashes.

    Returns:
        dict describing the parameter group.
    """
    group = {"name": name, "params": params, "lr": lr}
    if extra_params is not None:
        group.update(extra_params)
    return group
|
|
|
|
def create_optimizer_params(model_list, lr):
    """Expand a list of model entries into optimizer parameter groups.

    Args:
        model_list: Iterable of dicts whose values are, in order,
            ``model, condition, extra_params, is_lora, negation``.
        lr: Learning rate given to every per-parameter group.

    Returns:
        list of parameter-group dicts built by ``create_optim_params``.
    """
    import itertools

    param_groups = []

    for entry in model_list:
        # Positional unpacking: depends on the insertion order of the entry's keys.
        model, condition, extra_params, is_lora, negation = entry.values()

        # Entries whose condition is falsy contribute nothing.
        if not condition:
            continue

        if is_lora:
            if isinstance(model, list):
                # A list of parameter iterables becomes a single chained group.
                # NOTE(review): `lr` is not forwarded here, so this group gets
                # create_optim_params' default lr unless extra_params sets one
                # -- confirm this is intended.
                param_groups.append(
                    create_optim_params(
                        params=itertools.chain(*model), extra_params=extra_params
                    )
                )
            else:
                # Only parameters whose name contains "lora" are trained.
                param_groups.extend(
                    create_optim_params(n, p, lr, extra_params)
                    for n, p in model.named_parameters()
                    if "lora" in n
                )
            continue

        # Non-LoRA entry: one group per parameter, leaving out any "lora" ones.
        for n, p in model.named_parameters():
            if "lora" in n:
                continue
            param_groups.append(create_optim_params(n, p, lr, extra_params))

    return param_groups
|
|
|
|
def handle_trainable_modules(
    model, trainable_modules=None, is_enabled=True, negation=None
):
    """Set ``requires_grad`` on ``model`` according to ``trainable_modules``.

    Args:
        model: Module providing ``requires_grad_`` and ``named_parameters``.
        trainable_modules: Iterable of name fragments, or ``None`` to leave
            the model untouched. If any entry equals ``"all"`` every
            parameter is made trainable.
        is_enabled: Value assigned to ``requires_grad`` of matching
            parameters.
        negation: Unused; kept for interface compatibility.

    Returns:
        None. The model is modified in place.
    """
    if trainable_modules is None:
        return

    if any(name == "all" for name in trainable_modules):
        model.requires_grad_(True)
        return

    # Freeze everything, then switch on the parameters whose name contains one
    # of the requested fragments. Parameters with "lora" in their name are
    # never touched here.
    model.requires_grad_(False)
    for name, param in model.named_parameters():
        if "lora" in name:
            continue
        if any(fragment in name for fragment in trainable_modules):
            param.requires_grad_(is_enabled)
|
|
|
|
def huber_loss(pred, target, huber_c=0.001):
    """Return the mean of ``sqrt((pred - target)^2 + huber_c^2) - huber_c``.

    Both inputs are converted with ``.float()`` before the difference is taken.
    """
    residual = pred.float() - target.float()
    per_element = torch.sqrt(residual**2 + huber_c**2) - huber_c
    return per_element.mean()
|
|
|
|
@torch.no_grad()
def update_ema(target_params, source_params, rate=0.99):
    """
    Move each target parameter towards its source parameter with an
    exponential moving average, in place:
    ``target = rate * target + (1 - rate) * source``.

    :param target_params: the target parameter sequence (updated in place).
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for ema_param, live_param in zip(target_params, source_params):
        # Cast the source to the target's dtype before blending.
        ema_param.detach().mul_(rate).add_(
            live_param.to(ema_param.dtype), alpha=1 - rate
        )
|
|
|
|
def _generate_validation_clip(pipeline, args, generator, prompt, **overrides):
    """Run ``pipeline`` once for ``prompt`` and return the result as a uint8 array."""
    with torch.autocast("cuda"):
        videos = pipeline(
            prompt=prompt,
            frames=args.n_frames,
            num_videos_per_prompt=1,
            fps=args.fps,
            lcm_origin_steps=args.num_ddim_timesteps,
            generator=generator,
            **overrides,
        )
    # Map values from [-1, 1] to [0, 1], then to bytes.
    videos = (videos.clamp(-1.0, 1.0) + 1.0) / 2.0
    # Swap axes 1 and 2 of the 5-D output before handing it to the logger.
    # NOTE(review): presumably (B, C, T, H, W) -> (B, T, C, H, W); confirm
    # against the pipeline's output layout.
    return (videos * 255).to(torch.uint8).permute(0, 2, 1, 3, 4).cpu().numpy()


def log_validation_video(pipeline, args, accelerator, save_fps):
    """Generate validation videos for a fixed prompt list and log them to wandb.

    When ``args.use_motion_cond`` is truthy, each prompt is rendered three
    times with 8 inference steps and ``motion_gs`` set to 0, 1 and 2 times
    ``(args.motion_gs_max - args.motion_gs_min) / 2``. Otherwise each prompt
    is rendered twice, with 4 and 8 inference steps.

    Args:
        pipeline: Callable video pipeline, invoked with keyword arguments.
        args: Namespace providing ``seed``, ``n_frames``, ``fps`` and
            ``num_ddim_timesteps``, plus ``motion_gs_max`` /
            ``motion_gs_min`` when motion conditioning is enabled.
        accelerator: Object providing ``device`` and ``trackers``.
        save_fps: Frame rate passed to ``wandb.Video``.

    Returns:
        None.
    """
    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    validation_prompts = [
        "An astronaut riding a horse.",
        "Darth vader surfing in waves.",
        "Robot dancing in times square.",
        "Clown fish swimming through the coral reef.",
        "A child excitedly swings on a rusty swing set, laughter filling the air.",
        "With the style of van gogh, A young couple dances under the moonlight by the lake.",
        "A young woman with glasses is jogging in the park wearing a pink headband.",
        "Impressionist style, a yellow rubber duck floating on the wave on the sunset",
        "Wolf, turns its head, in the wild",
        "Iron man, walks, on the moon, 8k, high detailed, best quality",
        "With the style of low-poly game art, A majestic, white horse gallops gracefully",
        "a rabbit, low-poly game art style",
    ]

    video_logs = []
    use_motion_cond = bool(getattr(args, "use_motion_cond", False))

    if use_motion_cond:
        # Step between the three motion guidance values that are sampled.
        # NOTE(review): the sampled values start at 0 rather than at
        # args.motion_gs_min -- confirm this is intended.
        motion_gs_unit = (args.motion_gs_max - args.motion_gs_min) / 2

    for prompt in validation_prompts:
        if use_motion_cond:
            for i in range(3):
                videos = _generate_validation_clip(
                    pipeline,
                    args,
                    generator,
                    prompt,
                    num_inference_steps=8,
                    use_motion_cond=True,
                    motion_gs=motion_gs_unit * i,
                )
                video_logs.append(
                    {
                        "validation_prompt": f"GS={i * motion_gs_unit}, {prompt}",
                        "videos": videos,
                    }
                )
        else:
            for i in range(2):
                steps = 4 * (i + 1)
                videos = _generate_validation_clip(
                    pipeline,
                    args,
                    generator,
                    prompt,
                    num_inference_steps=steps,
                    use_motion_cond=False,
                )
                video_logs.append(
                    {
                        "validation_prompt": f"Steps={steps}, {prompt}",
                        "videos": videos,
                    }
                )

    for tracker in accelerator.trackers:
        if tracker.name != "wandb":
            continue
        formatted_videos = [
            wandb.Video(video, caption=log["validation_prompt"], fps=save_fps)
            for log in video_logs
            for video in log["videos"]
        ]
        tracker.log({"validation": formatted_videos})

    # Drop this function's reference to the pipeline and run a collection pass.
    del pipeline
    gc.collect()
|
|
|
|
def tuple_type(s):
    """Return ``s`` as a tuple, parsing it as a Python literal when needed.

    Args:
        s: Either a tuple (returned unchanged) or a value accepted by
            ``ast.literal_eval``.

    Returns:
        The tuple.

    Raises:
        TypeError: If the parsed literal is not a tuple.
    """
    parsed = s if isinstance(s, tuple) else ast.literal_eval(s)
    if not isinstance(parsed, tuple):
        raise TypeError("Argument must be a tuple")
    return parsed
|
|
|
|
def load_model_checkpoint(model, ckpt):
    """Load the weights stored at ``ckpt`` into ``model`` with strict key matching.

    The file is read onto the CPU with ``weights_only=True``. When the loaded
    object has a ``"state_dict"`` entry, that entry is used as the state dict.

    Args:
        model: Module whose ``load_state_dict`` is called.
        ckpt: Path (or file-like object) accepted by ``torch.load``.

    Returns:
        The same ``model`` instance, after loading.
    """
    weights = torch.load(ckpt, map_location="cpu", weights_only=True)
    if "state_dict" in weights:
        weights = weights["state_dict"]
    model.load_state_dict(weights, strict=True)

    # Release the loaded copy of the weights before returning.
    del weights
    gc.collect()

    print(">>> model checkpoint loaded.")
    return model
|
|
|
|
def read_video_to_tensor(
    path_to_video, sample_fps, sample_frames, uniform_sampling=False
):
    """Read ``sample_frames`` frames from a video file into a float tensor.

    Frames are spread evenly over the whole video when ``uniform_sampling``
    is set or when the video is not longer than the requested clip
    (``sample_frames / sample_fps`` seconds). Otherwise a clip with a random
    start frame is sampled at ``sample_fps``.

    Args:
        path_to_video: Path passed to ``decord.VideoReader``.
        sample_fps: Frame rate at which the clip is sampled.
        sample_frames: Number of frames to return.
        uniform_sampling: Force even sampling over the full video.

    Returns:
        ``torch.Tensor`` holding the sampled frames divided by 255, with the
        last axis of the decoded batch moved to position 1.
    """
    video_reader = VideoReader(path_to_video)
    video_fps = video_reader.get_avg_fps()
    # Public way to get the frame count (instead of the private `_num_frame`).
    video_frames = len(video_reader)
    video_duration = video_frames / video_fps
    sample_duration = sample_frames / sample_fps
    # Source frames advanced per sampled frame; may be fractional.
    stride = video_fps / sample_fps

    if uniform_sampling or video_duration <= sample_duration:
        index_range = np.linspace(0, video_frames - 1, sample_frames).astype(np.int32)
    else:
        # Latest start frame that still leaves room for the whole clip.
        max_start_frame = video_frames - int(np.ceil(sample_frames * stride))
        start_frame = random.randint(0, max_start_frame) if max_start_frame > 0 else 0

        index_range = start_frame + np.arange(sample_frames) * stride
        index_range = np.round(index_range).astype(np.int32)
        # With a stride of 0.5 or less, rounding can push the last index up to
        # `video_frames`; clamp so every index is a valid frame number.
        index_range = np.clip(index_range, 0, video_frames - 1)

    sampled_frames = video_reader.get_batch(index_range).asnumpy()
    # NOTE(review): assumes the decoded batch is (N, H, W, C); the permute
    # moves the last axis to position 1 -- confirm against decord's output.
    pixel_values = torch.from_numpy(sampled_frames).permute(0, 3, 1, 2).contiguous()
    pixel_values = pixel_values / 255.0
    del video_reader

    return pixel_values
|
|
|
|
def calculate_motion_rank_new(tensor_ref, tensor_gen, rank_k=1):
    """MSE between two tensors, restricted to the reference's top-k entries.

    For every slice along the last dimension, the positions of the ``rank_k``
    largest values of ``tensor_ref`` are selected, and the mean squared error
    between ``tensor_ref`` (detached) and ``tensor_gen`` is computed over
    those positions only.

    Args:
        tensor_ref: Reference tensor; it decides which positions are compared.
        tensor_gen: Tensor compared against the reference at those positions.
        rank_k: Number of largest entries to keep per slice. ``0`` yields a
            zero loss.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If ``rank_k`` exceeds the size of the last dimension.
    """
    if rank_k == 0:
        return torch.tensor(0.0, device=tensor_ref.device)
    if rank_k > tensor_ref.shape[-1]:
        raise ValueError(
            "The value of rank_k cannot be larger than the number of frames"
        )

    # Ascending sort: the last rank_k indices point at the largest values.
    top_indices = torch.sort(tensor_ref, dim=-1).indices[..., -rank_k:]
    selected = torch.zeros_like(tensor_ref, dtype=torch.bool)
    selected.scatter_(-1, top_indices, True)

    return F.mse_loss(tensor_ref[selected].detach(), tensor_gen[selected])
|
|
|
|
def compute_temp_loss(attention_prob, attention_prob_example):
    """Average the per-module top-1 attention loss, scaled by 100.

    Args:
        attention_prob: Mapping from module name to attention tensor.
        attention_prob_example: Mapping with the same names holding the
            reference tensors; these are detached before the comparison.

    Returns:
        Scalar tensor: the mean over modules of
        ``100 * calculate_motion_rank_new(reference, attention, rank_k=1)``.
    """
    per_module_losses = [
        calculate_motion_rank_new(
            attention_prob_example[name].detach(), attention_prob[name], rank_k=1
        )
        for name in attention_prob.keys()
    ]
    scaled_losses = torch.stack(per_module_losses) * 100
    return scaled_losses.mean()
|
|