Fred808 commited on
Commit
1d84ccc
·
verified ·
1 Parent(s): 5952a56

Update utils/model_util.py

Browse files
Files changed (1) hide show
  1. utils/model_util.py +124 -132
utils/model_util.py CHANGED
@@ -1,132 +1,124 @@
1
- import torch
2
- from model.mdm import MDM
3
- from diffusion import gaussian_diffusion as gd
4
- from diffusion.respace import SpacedDiffusion, space_timesteps
5
- from utils.parser_util import get_cond_mode
6
- from data_loaders.humanml_utils import HML_EE_JOINT_NAMES
7
-
8
- def load_model_wo_clip(model, state_dict):
9
- # assert (state_dict['sequence_pos_encoder.pe'][:model.sequence_pos_encoder.pe.shape[0]] == model.sequence_pos_encoder.pe).all() # TEST
10
- # assert (state_dict['embed_timestep.sequence_pos_encoder.pe'][:model.embed_timestep.sequence_pos_encoder.pe.shape[0]] == model.embed_timestep.sequence_pos_encoder.pe).all() # TEST
11
- del state_dict['sequence_pos_encoder.pe'] # no need to load it (fixed), and causes size mismatch for older models
12
- del state_dict['embed_timestep.sequence_pos_encoder.pe'] # no need to load it (fixed), and causes size mismatch for older models
13
- missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
14
- assert len(unexpected_keys) == 0
15
- assert all([k.startswith('clip_model.') or 'sequence_pos_encoder' in k for k in missing_keys])
16
-
17
-
18
- def create_model_and_diffusion(args, data):
19
- model = MDM(**get_model_args(args, data))
20
- diffusion = create_gaussian_diffusion(args)
21
- return model, diffusion
22
-
23
-
24
- def get_model_args(args, data):
25
-
26
- # default args
27
- clip_version = 'ViT-B/32'
28
- action_emb = 'tensor'
29
- cond_mode = get_cond_mode(args)
30
- if hasattr(data.dataset, 'num_actions'):
31
- num_actions = data.dataset.num_actions
32
- else:
33
- num_actions = 1
34
-
35
- # SMPL defaults
36
- data_rep = 'rot6d'
37
- njoints = 25
38
- nfeats = 6
39
- all_goal_joint_names = []
40
-
41
- if args.dataset == 'humanml':
42
- data_rep = 'hml_vec'
43
- njoints = 263
44
- nfeats = 1
45
- all_goal_joint_names = ['pelvis'] + HML_EE_JOINT_NAMES
46
- elif args.dataset == 'kit':
47
- data_rep = 'hml_vec'
48
- njoints = 251
49
- nfeats = 1
50
-
51
- # Compatibility with old models
52
- if not hasattr(args, 'pred_len'):
53
- args.pred_len = 0
54
- args.context_len = 0
55
-
56
- emb_policy = args.__dict__.get('emb_policy', 'add')
57
- multi_target_cond = args.__dict__.get('multi_target_cond', False)
58
- multi_encoder_type = args.__dict__.get('multi_encoder_type', 'multi')
59
- target_enc_layers = args.__dict__.get('target_enc_layers', 1)
60
-
61
- return {'modeltype': '', 'njoints': njoints, 'nfeats': nfeats, 'num_actions': num_actions,
62
- 'translation': True, 'pose_rep': 'rot6d', 'glob': True, 'glob_rot': True,
63
- 'latent_dim': args.latent_dim, 'ff_size': 1024, 'num_layers': args.layers, 'num_heads': 4,
64
- 'dropout': 0.1, 'activation': "gelu", 'data_rep': data_rep, 'cond_mode': cond_mode,
65
- 'cond_mask_prob': args.cond_mask_prob, 'action_emb': action_emb, 'arch': args.arch,
66
- 'emb_trans_dec': args.emb_trans_dec, 'clip_version': clip_version, 'dataset': args.dataset,
67
- 'text_encoder_type': args.text_encoder_type,
68
- 'pos_embed_max_len': args.pos_embed_max_len, 'mask_frames': args.mask_frames,
69
- 'pred_len': args.pred_len, 'context_len': args.context_len, 'emb_policy': emb_policy,
70
- 'all_goal_joint_names': all_goal_joint_names, 'multi_target_cond': multi_target_cond, 'multi_encoder_type': multi_encoder_type, 'target_enc_layers': target_enc_layers,
71
- }
72
-
73
-
74
-
75
- def create_gaussian_diffusion(args):
76
- # default params
77
- predict_xstart = True # we always predict x_start (a.k.a. x0), that's our deal!
78
- steps = args.diffusion_steps
79
- scale_beta = 1. # no scaling
80
- timestep_respacing = '' # can be used for ddim sampling, we don't use it.
81
- learn_sigma = False
82
- rescale_timesteps = False
83
-
84
- betas = gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta)
85
- loss_type = gd.LossType.MSE
86
-
87
- if not timestep_respacing:
88
- timestep_respacing = [steps]
89
-
90
- if hasattr(args, 'lambda_target_loc'):
91
- lambda_target_loc = args.lambda_target_loc
92
- else:
93
- lambda_target_loc = 0.
94
-
95
- return SpacedDiffusion(
96
- use_timesteps=space_timesteps(steps, timestep_respacing),
97
- betas=betas,
98
- model_mean_type=(
99
- gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
100
- ),
101
- model_var_type=(
102
- (
103
- gd.ModelVarType.FIXED_LARGE
104
- if not args.sigma_small
105
- else gd.ModelVarType.FIXED_SMALL
106
- )
107
- if not learn_sigma
108
- else gd.ModelVarType.LEARNED_RANGE
109
- ),
110
- loss_type=loss_type,
111
- rescale_timesteps=rescale_timesteps,
112
- lambda_vel=args.lambda_vel,
113
- lambda_rcxyz=args.lambda_rcxyz,
114
- lambda_fc=args.lambda_fc,
115
- lambda_target_loc=lambda_target_loc,
116
- )
117
-
118
- def load_saved_model(model, model_path, use_avg: bool=False): # use_avg_model
119
- state_dict = torch.load(model_path, map_location='cpu')
120
- # Use average model when possible
121
- if use_avg and 'model_avg' in state_dict.keys():
122
- # if use_avg_model:
123
- print('loading avg model')
124
- state_dict = state_dict['model_avg']
125
- else:
126
- if 'model' in state_dict:
127
- print('loading model without avg')
128
- state_dict = state_dict['model']
129
- else:
130
- print('checkpoint has no avg model, loading as usual.')
131
- load_model_wo_clip(model, state_dict)
132
- return model
 
1
+ import torch
2
+ from model.mdm import MDM
3
+ from diffusers import DDPMScheduler
4
+ from utils.parser_util import get_cond_mode
5
+ from data_loaders.humanml_utils import HML_EE_JOINT_NAMES
6
+
7
+ def load_model_wo_clip(model, state_dict):
8
+ """
9
+ Load model weights, skipping positional encodings from CLIP to avoid mismatches.
10
+ """
11
+ # Remove fixed positional encodings to avoid size mismatches
12
+ state_dict.pop('sequence_pos_encoder.pe', None)
13
+ state_dict.pop('embed_timestep.sequence_pos_encoder.pe', None)
14
+
15
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
16
+ assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
17
+ assert all([k.startswith('clip_model.') or 'sequence_pos_encoder' in k for k in missing_keys]), \
18
+ f"Missing keys: {missing_keys}"
19
+
20
+
21
+ def create_model_and_diffusion(args, data):
22
+ """
23
+ Instantiate the MDM model and the diffusion scheduler.
24
+ """
25
+ model = MDM(**get_model_args(args, data))
26
+ scheduler = create_diffusion_scheduler(args)
27
+ return model, scheduler
28
+
29
+
30
+ def get_model_args(args, data):
31
+ # Default configuration
32
+ clip_version = 'ViT-B/32'
33
+ action_emb = 'tensor'
34
+ cond_mode = get_cond_mode(args)
35
+ num_actions = getattr(data.dataset, 'num_actions', 1)
36
+
37
+ # Data representation defaults
38
+ if args.dataset == 'humanml':
39
+ data_rep = 'hml_vec'
40
+ njoints, nfeats = 263, 1
41
+ all_goal_joint_names = ['pelvis'] + HML_EE_JOINT_NAMES
42
+ elif args.dataset == 'kit':
43
+ data_rep = 'hml_vec'
44
+ njoints, nfeats = 251, 1
45
+ all_goal_joint_names = []
46
+ else:
47
+ data_rep = 'rot6d'
48
+ njoints, nfeats = 25, 6
49
+ all_goal_joint_names = []
50
+
51
+ # Ensure backward compatibility
52
+ args.pred_len = getattr(args, 'pred_len', 0)
53
+ args.context_len = getattr(args, 'context_len', 0)
54
+
55
+ return {
56
+ 'modeltype': '',
57
+ 'njoints': njoints,
58
+ 'nfeats': nfeats,
59
+ 'num_actions': num_actions,
60
+ 'translation': True,
61
+ 'pose_rep': 'rot6d',
62
+ 'glob': True,
63
+ 'glob_rot': True,
64
+ 'latent_dim': args.latent_dim,
65
+ 'ff_size': 1024,
66
+ 'num_layers': args.layers,
67
+ 'num_heads': 4,
68
+ 'dropout': 0.1,
69
+ 'activation': "gelu",
70
+ 'data_rep': data_rep,
71
+ 'cond_mode': cond_mode,
72
+ 'cond_mask_prob': args.cond_mask_prob,
73
+ 'action_emb': action_emb,
74
+ 'arch': args.arch,
75
+ 'emb_trans_dec': args.emb_trans_dec,
76
+ 'clip_version': clip_version,
77
+ 'dataset': args.dataset,
78
+ 'text_encoder_type': args.text_encoder_type,
79
+ 'pos_embed_max_len': args.pos_embed_max_len,
80
+ 'mask_frames': args.mask_frames,
81
+ 'pred_len': args.pred_len,
82
+ 'context_len': args.context_len,
83
+ 'emb_policy': getattr(args, 'emb_policy', 'add'),
84
+ 'all_goal_joint_names': all_goal_joint_names,
85
+ 'multi_target_cond': getattr(args, 'multi_target_cond', False),
86
+ 'multi_encoder_type': getattr(args, 'multi_encoder_type', 'multi'),
87
+ 'target_enc_layers': getattr(args, 'target_enc_layers', 1),
88
+ }
89
+
90
+
91
+ def create_diffusion_scheduler(args):
92
+ """
93
+ Create a DDPM scheduler using Hugging Face's `diffusers` library.
94
+ """
95
+ # Define beta schedule parameters
96
+ beta_start = getattr(args, 'beta_start', 1e-4)
97
+ beta_end = getattr(args, 'beta_end', 0.02)
98
+ beta_schedule = getattr(args, 'noise_schedule', 'linear')
99
+
100
+ scheduler = DDPMScheduler(
101
+ num_train_timesteps=args.diffusion_steps,
102
+ beta_start=beta_start,
103
+ beta_end=beta_end,
104
+ beta_schedule=beta_schedule,
105
+ )
106
+ # Initialize scheduler timesteps
107
+ scheduler.set_timesteps(args.diffusion_steps)
108
+ return scheduler
109
+
110
+
111
+ def load_saved_model(model, model_path, use_avg: bool=False):
112
+ """
113
+ Load weights from a checkpoint, optionally using an averaged model.
114
+ """
115
+ checkpoint = torch.load(model_path, map_location='cpu')
116
+ if use_avg and 'model_avg' in checkpoint:
117
+ state_dict = checkpoint['model_avg']
118
+ elif 'model' in checkpoint:
119
+ state_dict = checkpoint['model']
120
+ else:
121
+ state_dict = checkpoint
122
+
123
+ load_model_wo_clip(model, state_dict)
124
+ return model