vinesnt committed on
Commit
7875e3a
·
verified ·
1 Parent(s): f9fb75b

Upload wan_transformer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. wan_transformer.py +135 -0
wan_transformer.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional, Union
2
+
3
+ import torch
4
+ from diffusers import WanTransformer3DModel
5
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
6
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
7
+ from wan_teacache import TeaCache
8
+
9
+
10
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
11
+
12
+
13
class CustomWanTransformer3DModel(WanTransformer3DModel):
    """`WanTransformer3DModel` whose forward pass also accepts ControlNet residuals and a TeaCache.

    Two things are added on top of the parent's forward computation:

    * ControlNet: after selected transformer blocks, an entry of `controlnet_states`
      (scaled by `controlnet_weight`) is added to the hidden states.
    * TeaCache: when a `teacache` object with a positive `treshold` is supplied, the
      transformer blocks may be skipped and replaced by `teacache.use_cache(...)`.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_states: Optional[torch.Tensor] = None,
        controlnet_weight: Optional[float] = 1.0,
        controlnet_stride: Optional[int] = 1,
        teacache: Optional[TeaCache] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Run the transformer, optionally injecting ControlNet residuals and using a TeaCache.

        Args:
            hidden_states: 5-D input unpacked as
                `(batch_size, num_channels, num_frames, height, width)`.
            timestep: 1-D `(batch_size,)` or 2-D `(batch_size, seq_len)` timesteps.
                A 2-D input is flattened before being passed to the condition embedder.
            encoder_hidden_states: Conditioning passed to `self.condition_embedder`.
            encoder_hidden_states_image: Optional image conditioning. If the condition
                embedder returns a non-None image embedding, it is concatenated in front
                of `encoder_hidden_states` along dim 1.
            return_dict: If False, return a 1-tuple `(output,)` instead of a
                `Transformer2DModelOutput`.
            attention_kwargs: Optional dict. Only the key `"scale"` is read here (the
                LoRA scale, default 1.0). The caller's dict is never mutated.
            controlnet_states: Optional indexable collection of residuals. Entry `k` is
                added after block `k * controlnet_stride`. Presumably each entry has the
                same shape as the patchified hidden states -- confirm against the
                ControlNet that produces them.
            controlnet_weight: Multiplier applied to every ControlNet residual.
                `None` is treated as 1.0.
            controlnet_stride: A residual is injected after every
                `controlnet_stride`-th block. `None` is treated as 1.
            teacache: Optional cache object. It is only consulted when its `treshold`
                attribute is greater than 0.

        Returns:
            `Transformer2DModelOutput(sample=output)` when `return_dict` is True,
            otherwise the tuple `(output,)`. `output` has five dimensions: batch,
            channels, and the frame/height/width patch grids re-expanded by the patch
            size.

        Raises:
            ValueError: If `controlnet_states` is given and `controlnet_stride` is
                smaller than 1.
        """
        # Validate the ControlNet arguments before touching any module state, so a bad
        # argument can never leave the LoRA layers scaled.
        if controlnet_states is not None:
            if controlnet_weight is None:
                controlnet_weight = 1.0
            if controlnet_stride is None:
                controlnet_stride = 1
            if controlnet_stride < 1:
                raise ValueError(f"`controlnet_stride` must be >= 1, got {controlnet_stride}.")

        scale_was_passed = False
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            # Record this *before* popping; checking afterwards would always see None.
            scale_was_passed = attention_kwargs.get("scale", None) is not None
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif scale_was_passed:
            logger.warning(
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )

        try:
            batch_size, _, num_frames, height, width = hidden_states.shape
            p_t, p_h, p_w = self.config.patch_size
            post_patch_num_frames = num_frames // p_t
            post_patch_height = height // p_h
            post_patch_width = width // p_w

            # Rotary embeddings are computed from the un-patchified input.
            rotary_emb = self.rope(hidden_states)

            hidden_states = self.patch_embedding(hidden_states)
            hidden_states = hidden_states.flatten(2).transpose(1, 2)

            # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
            if timestep.ndim == 2:
                ts_seq_len = timestep.shape[1]
                timestep = timestep.flatten()  # batch_size * seq_len
            else:
                ts_seq_len = None

            temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
                timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
            )
            if ts_seq_len is not None:
                # batch_size, seq_len, 6, inner_dim
                timestep_proj = timestep_proj.unflatten(2, (6, -1))
            else:
                # batch_size, 6, inner_dim
                timestep_proj = timestep_proj.unflatten(1, (6, -1))

            if encoder_hidden_states_image is not None:
                encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

            # `treshold` (sic) is the attribute name read from the TeaCache object;
            # the spelling must stay as-is.
            teacache_enabled = (teacache is not None) and (teacache.treshold > 0.0)
            use_cached_value = False
            original_hidden_states = None
            if teacache_enabled:
                # Keep the pre-block activations so the residual produced by the
                # blocks can be handed to the cache afterwards.
                original_hidden_states = hidden_states.clone()
                use_cached_value = teacache.check_for_using_cached_value(temb)

            if use_cached_value:
                hidden_states = teacache.use_cache(hidden_states)
            else:
                # 4. Transformer blocks
                hidden_states = self._run_blocks_with_controlnet(
                    hidden_states,
                    encoder_hidden_states,
                    timestep_proj,
                    rotary_emb,
                    controlnet_states,
                    controlnet_weight,
                    controlnet_stride,
                )
                if teacache_enabled:
                    teacache.update(hidden_states - original_hidden_states)

            # 5. Output norm, projection & unpatchify
            if temb.ndim == 3:
                # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
                shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
                shift = shift.squeeze(2)
                scale = scale.squeeze(2)
            else:
                # batch_size, inner_dim
                shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)

            # Move the shift and scale tensors to the same device as hidden_states.
            # When using multi-GPU inference via accelerate these will be on the
            # first device rather than the last device, which hidden_states ends up
            # on.
            shift = shift.to(hidden_states.device)
            scale = scale.to(hidden_states.device)

            hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
            hidden_states = self.proj_out(hidden_states)

            # Undo the patchification: split the token axis back into the patch grid
            # plus the per-patch offsets, then interleave and merge them.
            hidden_states = hidden_states.reshape(
                batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
            )
            hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
            output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
        finally:
            if USE_PEFT_BACKEND:
                # remove `lora_scale` from each PEFT layer, even if the pass above raised
                unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)

    def _run_blocks_with_controlnet(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep_proj: torch.Tensor,
        rotary_emb: Any,
        controlnet_states: Optional[torch.Tensor],
        controlnet_weight: float,
        controlnet_stride: int,
    ) -> torch.Tensor:
        """Apply every block in `self.blocks`, adding ControlNet residuals after selected blocks."""
        use_checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing

        for index, block in enumerate(self.blocks):
            if use_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
            else:
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

            # Residual `k` goes after block `k * controlnet_stride`; blocks beyond the
            # last available residual are left untouched.
            if controlnet_states is not None and index % controlnet_stride == 0:
                slot = index // controlnet_stride
                if slot < len(controlnet_states):
                    hidden_states = hidden_states + controlnet_states[slot] * controlnet_weight

        return hidden_states