tracyjxm committed on
Commit
b31d63a
·
verified ·
1 Parent(s): 1bf77b5

Upload folder using huggingface_hub

Browse files
Files changed (35) hide show
  1. kandinsky3/.ipynb_checkpoints/__init__-checkpoint.py +267 -0
  2. kandinsky3/.ipynb_checkpoints/condition_encoders-checkpoint.py +40 -0
  3. kandinsky3/.ipynb_checkpoints/condition_processors-checkpoint.py +34 -0
  4. kandinsky3/.ipynb_checkpoints/inpainting_pipeline-checkpoint.py +168 -0
  5. kandinsky3/.ipynb_checkpoints/movq-checkpoint.py +431 -0
  6. kandinsky3/.ipynb_checkpoints/t2i_pipeline-checkpoint.py +109 -0
  7. kandinsky3/.ipynb_checkpoints/utils-checkpoint.py +71 -0
  8. kandinsky3/__init__.py +267 -0
  9. kandinsky3/__pycache__/__init__.cpython-310.pyc +0 -0
  10. kandinsky3/__pycache__/condition_encoders.cpython-310.pyc +0 -0
  11. kandinsky3/__pycache__/condition_processors.cpython-310.pyc +0 -0
  12. kandinsky3/__pycache__/inpainting_pipeline.cpython-310.pyc +0 -0
  13. kandinsky3/__pycache__/movq.cpython-310.pyc +0 -0
  14. kandinsky3/__pycache__/t2i_pipeline.cpython-310.pyc +0 -0
  15. kandinsky3/__pycache__/utils.cpython-310.pyc +0 -0
  16. kandinsky3/condition_encoders.py +40 -0
  17. kandinsky3/condition_processors.py +34 -0
  18. kandinsky3/inpainting_pipeline.py +168 -0
  19. kandinsky3/model/.ipynb_checkpoints/diffusion-checkpoint.py +200 -0
  20. kandinsky3/model/.ipynb_checkpoints/unet-checkpoint.py +516 -0
  21. kandinsky3/model/__init__.py +0 -0
  22. kandinsky3/model/__pycache__/__init__.cpython-310.pyc +0 -0
  23. kandinsky3/model/__pycache__/diffusion.cpython-310.pyc +0 -0
  24. kandinsky3/model/__pycache__/nn.cpython-310.pyc +0 -0
  25. kandinsky3/model/__pycache__/unet.cpython-310.pyc +0 -0
  26. kandinsky3/model/__pycache__/utils.cpython-310.pyc +0 -0
  27. kandinsky3/model/diffusion.py +200 -0
  28. kandinsky3/model/nn.py +84 -0
  29. kandinsky3/model/unet.py +516 -0
  30. kandinsky3/model/utils.py +62 -0
  31. kandinsky3/movq.py +431 -0
  32. kandinsky3/setup.py +38 -0
  33. kandinsky3/t2i_pipeline.py +109 -0
  34. kandinsky3/utils.py +71 -0
  35. unet_model_checkpoint.pt +3 -0
kandinsky3/.ipynb_checkpoints/__init__-checkpoint.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional, Union
3
+
4
+ import torch
5
+ from huggingface_hub import hf_hub_download, snapshot_download
6
+
7
+ from kandinsky3.model.unet import UNet
8
+ from kandinsky3.movq import MoVQ
9
+ from kandinsky3.condition_encoders import T5TextConditionEncoder
10
+ from kandinsky3.condition_processors import T5TextConditionProcessor
11
+ from kandinsky3.model.diffusion import BaseDiffusion, get_named_beta_schedule
12
+
13
+ from .t2i_pipeline import Kandinsky3T2IPipeline
14
+ from .inpainting_pipeline import Kandinsky3InpaintingPipeline
15
+
16
+
17
def get_T2I_unet(
        device: Union[str, torch.device],
        weights_path: Optional[str] = None,
        dtype: Union[str, torch.dtype] = torch.float32,
) -> (UNet, Optional[torch.Tensor]):
    """Build the text-to-image UNet and optionally load its checkpoint.

    Args:
        device: Device the UNet is moved to.
        weights_path: Checkpoint file holding the keys ``'unet'`` and
            ``'null_embedding'``. When falsy the UNet keeps its initial
            weights and no null embedding is returned.
        dtype: Parameter dtype the UNet is cast to.

    Returns:
        ``(unet, null_embedding)``. The UNet is in eval mode on ``device``;
        ``null_embedding`` is whatever the checkpoint stores under
        ``'null_embedding'`` (loaded onto the CPU and not moved), or ``None``
        when no checkpoint was given.
    """
    unet = UNet(
        model_channels=384,
        num_channels=4,
        init_channels=192,
        time_embed_dim=1536,
        context_dim=4096,
        groups=32,
        head_dim=64,
        expansion_ratio=4,
        compression_ratio=2,
        dim_mult=(1, 2, 4, 8),
        num_blocks=(3, 3, 3, 3),
        add_cross_attention=(False, True, True, True),
        add_self_attention=(False, True, True, True),
    )

    null_embedding = None
    if weights_path:
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        null_embedding = state_dict['null_embedding']
        unet.load_state_dict(state_dict['unet'])

    unet.to(device=device, dtype=dtype).eval()
    return unet, null_embedding
46
+
47
+
48
def get_T5encoder(
        device: Union[str, torch.device],
        weights_path: str,
        projection_name: str,
        dtype: Union[str, torch.dtype] = torch.float32,
        low_cpu_mem_usage: bool = True,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
) -> (T5TextConditionProcessor, T5TextConditionEncoder):
    """Create the text processor and the T5 condition encoder.

    Both objects are built from ``weights_path``. When ``weights_path`` is
    truthy, the projection head is loaded from the file ``projection_name``
    inside that directory. The projection is then moved to ``device`` /
    ``dtype`` and put in eval mode.

    Returns:
        ``(processor, condition_encoder)``.
    """
    max_tokens = 128
    projected_dim = 4096

    processor = T5TextConditionProcessor(max_tokens, weights_path)
    condition_encoder = T5TextConditionEncoder(
        weights_path,
        projected_dim,
        low_cpu_mem_usage=low_cpu_mem_usage,
        device=device,
        dtype=dtype,
        load_in_8bit=load_in_8bit,
        load_in_4bit=load_in_4bit,
    )

    if weights_path:
        projection_file = os.path.join(weights_path, projection_name)
        projection_state = torch.load(projection_file, map_location=torch.device('cpu'))
        condition_encoder.projection.load_state_dict(projection_state)

    projection = condition_encoder.projection
    projection.to(device=device, dtype=dtype).eval()
    return processor, condition_encoder
72
+
73
+
74
def get_movq(
        device: Union[str, torch.device],
        weights_path: Optional[str] = None,
        dtype: Union[str, torch.dtype] = torch.float32,
) -> MoVQ:
    """Build the MoVQ autoencoder and optionally load its weights.

    The model is constructed from a fixed generator configuration; when
    ``weights_path`` is truthy its state dict is loaded from that file. The
    result is moved to ``device`` / ``dtype`` and switched to eval mode.
    """
    generator_config = dict(
        double_z=False,
        z_channels=4,
        resolution=256,
        in_channels=3,
        out_ch=3,
        ch=256,
        ch_mult=[1, 2, 2, 4],
        num_res_blocks=2,
        attn_resolutions=[32],
        dropout=0.0,
    )
    movq = MoVQ(generator_config)

    if weights_path:
        movq.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))

    movq.to(device=device, dtype=dtype).eval()
    return movq
99
+
100
+
101
def get_inpainting_unet(
        device: Union[str, torch.device],
        weights_path: Optional[str] = None,
        dtype: Union[str, torch.dtype] = torch.float32,
) -> (UNet, Optional[torch.Tensor]):
    """Build the inpainting UNet (9 input channels) and optionally load it.

    Args:
        device: Device the UNet is moved to.
        weights_path: Checkpoint file holding the keys ``'unet'`` and
            ``'null_embedding'``. When falsy the UNet keeps its initial
            weights and no null embedding is returned.
        dtype: Parameter dtype the UNet is cast to.

    Returns:
        ``(unet, null_embedding)``. The UNet is in eval mode on ``device``;
        ``null_embedding`` is whatever the checkpoint stores under
        ``'null_embedding'`` (loaded onto the CPU and not moved), or ``None``
        when no checkpoint was given.
    """
    unet = UNet(
        model_channels=384,
        num_channels=9,
        init_channels=192,
        time_embed_dim=1536,
        context_dim=4096,
        groups=32,
        head_dim=64,
        expansion_ratio=4,
        compression_ratio=2,
        dim_mult=(1, 2, 4, 8),
        num_blocks=(3, 3, 3, 3),
        add_cross_attention=(False, True, True, True),
        add_self_attention=(False, True, True, True),
    )

    null_embedding = None
    if weights_path:
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        null_embedding = state_dict['null_embedding']
        unet.load_state_dict(state_dict['unet'])

    unet.to(device=device, dtype=dtype).eval()
    return unet, null_embedding
130
+
131
+
132
def get_T2I_pipeline(
        device_map: Union[str, torch.device, dict],
        dtype_map: Union[str, torch.dtype, dict] = torch.float32,
        low_cpu_mem_usage: bool = True,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
        cache_dir: str = '/tmp/kandinsky3/',
        unet_path: str = None,
        text_encoder_path: str = None,
        movq_path: str = None,
) -> Kandinsky3T2IPipeline:
    """Assemble the text-to-image pipeline.

    ``device_map`` / ``dtype_map`` may be a single value, which is then used
    for all of ``'unet'``, ``'text_encoder'`` and ``'movq'``, or a dict with
    those keys. Any component path left as ``None`` is fetched from the
    ``ai-forever/Kandinsky3.1`` hub repository into ``cache_dir``.
    """
    components = ('unet', 'text_encoder', 'movq')
    if not isinstance(device_map, dict):
        device_map = {name: device_map for name in components}
    if not isinstance(dtype_map, dict):
        dtype_map = {name: dtype_map for name in components}

    repo_id = "ai-forever/Kandinsky3.1"
    if unet_path is None:
        unet_path = hf_hub_download(
            repo_id=repo_id, filename='weights/kandinsky3.pt', cache_dir=cache_dir
        )
    if text_encoder_path is None:
        snapshot_root = snapshot_download(
            repo_id=repo_id, allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        text_encoder_path = os.path.join(snapshot_root, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(
            repo_id=repo_id, filename='weights/movq.pt', cache_dir=cache_dir
        )

    unet, null_embedding = get_T2I_unet(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'],
        text_encoder_path,
        'projection.pt',
        dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage,
        load_in_8bit=load_in_8bit,
        load_in_4bit=load_in_4bit,
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])

    # Last positional argument is False here and True in the Flash variant.
    return Kandinsky3T2IPipeline(
        device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq, False
    )
176
+
177
+
178
def get_T2I_Flash_pipeline(
        device_map: Union[str, torch.device, dict],
        dtype_map: Union[str, torch.dtype, dict] = torch.float32,
        low_cpu_mem_usage: bool = True,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
        cache_dir: str = '/tmp/kandinsky3/',
        unet_path: str = None,
        text_encoder_path: str = None,
        movq_path: str = None,
) -> Kandinsky3T2IPipeline:
    """Assemble the "Flash" text-to-image pipeline.

    Same assembly as the regular text-to-image pipeline, but it uses the
    ``kandinsky3_flash.pt`` UNet checkpoint and the ``projection_flash.pt``
    projection, and passes ``True`` as the last pipeline argument.

    ``device_map`` / ``dtype_map`` may be a single value, which is then used
    for all of ``'unet'``, ``'text_encoder'`` and ``'movq'``, or a dict with
    those keys. Any component path left as ``None`` is fetched from the
    ``ai-forever/Kandinsky3.1`` hub repository into ``cache_dir``.
    """
    components = ('unet', 'text_encoder', 'movq')
    if not isinstance(device_map, dict):
        device_map = {name: device_map for name in components}
    if not isinstance(dtype_map, dict):
        dtype_map = {name: dtype_map for name in components}

    repo_id = "ai-forever/Kandinsky3.1"
    if unet_path is None:
        unet_path = hf_hub_download(
            repo_id=repo_id, filename='weights/kandinsky3_flash.pt', cache_dir=cache_dir
        )
    if text_encoder_path is None:
        snapshot_root = snapshot_download(
            repo_id=repo_id, allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        text_encoder_path = os.path.join(snapshot_root, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(
            repo_id=repo_id, filename='weights/movq.pt', cache_dir=cache_dir
        )

    unet, null_embedding = get_T2I_unet(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'],
        text_encoder_path,
        'projection_flash.pt',
        dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage,
        load_in_8bit=load_in_8bit,
        load_in_4bit=load_in_4bit,
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])

    return Kandinsky3T2IPipeline(
        device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq, True
    )
222
+
223
+
224
def get_inpainting_pipeline(
        device_map: Union[str, torch.device, dict],
        dtype_map: Union[str, torch.dtype, dict] = torch.float32,
        low_cpu_mem_usage: bool = True,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
        cache_dir: str = '/tmp/kandinsky3/',
        unet_path: str = None,
        text_encoder_path: str = None,
        movq_path: str = None,
) -> Kandinsky3InpaintingPipeline:
    """Assemble the inpainting pipeline.

    ``device_map`` / ``dtype_map`` may be a single value, which is then used
    for all of ``'unet'``, ``'text_encoder'`` and ``'movq'``, or a dict with
    those keys. Any component path left as ``None`` is fetched from the
    ``ai-forever/Kandinsky3.1`` hub repository into ``cache_dir``.
    """
    components = ('unet', 'text_encoder', 'movq')
    if not isinstance(device_map, dict):
        device_map = {name: device_map for name in components}
    if not isinstance(dtype_map, dict):
        dtype_map = {name: dtype_map for name in components}

    repo_id = "ai-forever/Kandinsky3.1"
    if unet_path is None:
        unet_path = hf_hub_download(
            repo_id=repo_id, filename='weights/kandinsky3_inpainting.pt', cache_dir=cache_dir
        )
    if text_encoder_path is None:
        snapshot_root = snapshot_download(
            repo_id=repo_id, allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        text_encoder_path = os.path.join(snapshot_root, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(
            repo_id=repo_id, filename='weights/movq.pt', cache_dir=cache_dir
        )

    unet, null_embedding = get_inpainting_unet(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'],
        text_encoder_path,
        'projection_inpainting.pt',
        dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage,
        load_in_8bit=load_in_8bit,
        load_in_4bit=load_in_4bit,
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])

    return Kandinsky3InpaintingPipeline(
        device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq
    )
kandinsky3/.ipynb_checkpoints/condition_encoders-checkpoint.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import T5EncoderModel
4
+ from typing import Optional, Union
5
+
6
+
7
class T5TextConditionEncoder(nn.Module):
    """T5 encoder stack followed by a linear + LayerNorm projection."""

    def __init__(
            self, model_path, context_dim,
            low_cpu_mem_usage: bool = True, device: Optional[str] = None,
            dtype: Union[str, torch.dtype] = torch.float32, load_in_4bit: bool = False, load_in_8bit: bool = False
    ):
        super().__init__()
        pretrained = T5EncoderModel.from_pretrained(
            model_path,
            low_cpu_mem_usage=low_cpu_mem_usage,
            device_map=device,
            torch_dtype=dtype,
            load_in_8bit=load_in_8bit,
            load_in_4bit=load_in_4bit,
        )
        # Only the inner encoder stack of the loaded model is kept.
        self.encoder = pretrained.encoder
        hidden_dim = self.encoder.config.d_model
        self.projection = nn.Sequential(
            nn.Linear(hidden_dim, context_dim, bias=False),
            nn.LayerNorm(context_dim),
        )

    def forward(self, model_input):
        """Encode ``model_input`` and return ``(context, context_mask)``.

        ``model_input`` is forwarded to the encoder as keyword arguments.
        When it contains ``'attention_mask'``, projected positions whose mask
        is 0 are zeroed and both outputs are cut along the sequence axis to
        the longest mask sum in the batch plus one. Without a mask, an
        all-ones mask matching the embeddings' leading dimensions is returned.
        """
        embeddings = self.encoder(**model_input).last_hidden_state
        context = self.projection(embeddings)

        if 'attention_mask' not in model_input:
            full_mask = torch.ones(*embeddings.shape[:-1], dtype=torch.long, device=embeddings.device)
            return context, full_mask

        context_mask = model_input['attention_mask']
        padding = context_mask == 0
        context[padding] = torch.zeros_like(context[padding])
        keep = context_mask.sum(-1).max() + 1
        return context[:, :keep], context_mask[:, :keep]
36
+
37
+
38
def get_condition_encoder(conf):
    """Instantiate a T5TextConditionEncoder from a keyword-argument mapping."""
    encoder = T5TextConditionEncoder(**conf)
    return encoder
40
+
kandinsky3/.ipynb_checkpoints/condition_processors-checkpoint.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import T5Tokenizer
3
+
4
+
5
class T5TextConditionProcessor:
    """Tokenizes prompts into fixed-length id / attention-mask tensors."""

    def __init__(self, tokens_length, processor_path):
        self.tokens_length = tokens_length
        self.processor = T5Tokenizer.from_pretrained(processor_path)

    def encode(self, text=None, negative_text=None):
        """Tokenize ``text`` (and optionally ``negative_text``).

        Returns ``(condition_model_input, negative_condition_model_input)``.
        Each is a dict with ``'input_ids'`` and ``'attention_mask'`` long
        tensors padded to ``tokens_length``; the second item is ``None`` when
        no negative text is given.

        The negative ids are cut to the positive prompt's token count, their
        last id is overwritten with the EOS id, and they reuse the positive
        prompt's attention mask.
        """
        tokenizer = self.processor
        limit = self.tokens_length

        encoded = tokenizer(text, max_length=limit, truncation=True)
        prompt_ids = encoded['input_ids']
        pad_length = limit - len(prompt_ids)
        padded_mask = encoded['attention_mask'] + [0] * pad_length
        condition_model_input = {
            'input_ids': torch.tensor(prompt_ids + [tokenizer.pad_token_id] * pad_length, dtype=torch.long),
            'attention_mask': torch.tensor(padded_mask, dtype=torch.long),
        }

        if negative_text is None:
            return condition_model_input, None

        negative_encoded = tokenizer(negative_text, max_length=limit, truncation=True)
        negative_ids = negative_encoded['input_ids'][:len(prompt_ids)]
        negative_ids[-1] = tokenizer.eos_token_id
        negative_ids = negative_ids + [tokenizer.pad_token_id] * (limit - len(negative_ids))
        negative_condition_model_input = {
            'input_ids': torch.tensor(negative_ids, dtype=torch.long),
            'attention_mask': torch.tensor(padded_mask, dtype=torch.long),
        }
        return condition_model_input, negative_condition_model_input
kandinsky3/.ipynb_checkpoints/inpainting_pipeline-checkpoint.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, List
2
+ import PIL
3
+ import numpy as np
4
+
5
+ import torch
6
+ import torchvision.transforms as T
7
+ from einops import repeat
8
+
9
+ from kandinsky3.model.unet import UNet
10
+ from kandinsky3.movq import MoVQ
11
+ from kandinsky3.condition_encoders import T5TextConditionEncoder
12
+ from kandinsky3.condition_processors import T5TextConditionProcessor
13
+ from kandinsky3.model.diffusion import BaseDiffusion, get_named_beta_schedule
14
+ from kandinsky3.utils import resize_image_for_diffusion, resize_mask_for_diffusion
15
+
16
+
17
class Kandinsky3InpaintingPipeline:
    """Text-guided inpainting: T5 text encoder + UNet diffusion + MoVQ codec."""

    def __init__(
            self,
            device_map: Union[str, torch.device, dict],
            dtype_map: Union[str, torch.dtype, dict],
            unet: UNet,
            null_embedding: torch.Tensor,
            t5_processor: T5TextConditionProcessor,
            t5_encoder: T5TextConditionEncoder,
            movq: MoVQ,
    ):
        # The methods below index both maps with the keys 'unet',
        # 'text_encoder' and 'movq'.
        self.device_map = device_map
        self.dtype_map = dtype_map
        self.to_pil = T.ToPILImage()
        self.to_tensor = T.ToTensor()

        self.unet = unet
        self.null_embedding = null_embedding
        self.t5_processor = t5_processor
        self.t5_encoder = t5_encoder
        self.movq = movq

    def shared_step(self, batch: dict) -> dict:
        """Encode the masked image and the prompts of a prepared batch.

        Args:
            batch: Dict with ``'image'``, ``'mask'``, ``'text'`` and
                ``'negative_text'``; an optional ``'masked_image'`` overrides
                the masked image computed here.

        Returns:
            Dict with the text contexts and masks, the original image, the
            MoVQ-encoded masked latent and the mask resized to the latent's
            spatial size.

        Raises:
            ValueError: If the batch has no ``'masked_image'`` and the UNet's
                input layer does not take 9 channels.
        """
        image = batch['image']
        condition_model_input = batch['text']
        negative_condition_model_input = batch['negative_text']
        mask = batch['mask']

        if 'masked_image' in batch:
            masked_latent = batch['masked_image']
        elif self.unet.in_layer.in_channels == 9:
            # Zero every pixel where (1 - mask) is non-zero.
            masked_latent = image.masked_fill((1 - mask).bool(), 0)
        else:
            raise ValueError(
                "batch has no 'masked_image' and the UNet input layer does not take 9 channels"
            )

        with torch.cuda.amp.autocast(dtype=self.dtype_map['movq']):
            masked_latent = self.movq.encode(masked_latent)
            mask = torch.nn.functional.interpolate(
                mask, size=(masked_latent.shape[2], masked_latent.shape[3])
            )

        with torch.cuda.amp.autocast(dtype=self.dtype_map['text_encoder']):
            context, context_mask = self.t5_encoder(condition_model_input)

            if negative_condition_model_input is not None:
                negative_context, negative_context_mask = self.t5_encoder(negative_condition_model_input)
            else:
                negative_context, negative_context_mask = None, None

        return {
            'context': context,
            'context_mask': context_mask,
            'negative_context': negative_context,
            'negative_context_mask': negative_context_mask,
            'image': image,
            'masked_latent': masked_latent,
            'mask': mask
        }

    def prepare_batch(
            self,
            text: str,
            negative_text: str,
            image: PIL.Image.Image,
            mask: np.ndarray,
    ) -> dict:
        """Turn raw inputs into a single-sample batch on the right devices.

        The image is converted to RGB, resized, and scaled with ``* 2 - 1``;
        the mask is resized and inverted (``1 - mask``). Every tensor gets a
        leading batch dimension of 1.
        """
        condition_model_input, negative_condition_model_input = self.t5_processor.encode(
            text=text, negative_text=negative_text
        )
        movq_device = self.device_map['movq']
        text_device = self.device_map['text_encoder']

        batch = {
            'image': self.to_tensor(resize_image_for_diffusion(image.convert("RGB"))) * 2 - 1,
            'mask': 1 - self.to_tensor(resize_mask_for_diffusion(mask)),
            'text': condition_model_input,
            'negative_text': negative_condition_model_input
        }
        batch['mask'] = batch['mask'].type(self.dtype_map['movq'])

        batch['image'] = batch['image'].unsqueeze(0).to(movq_device)
        batch['text']['input_ids'] = batch['text']['input_ids'].unsqueeze(0).to(text_device)
        batch['text']['attention_mask'] = batch['text']['attention_mask'].unsqueeze(0).to(text_device)
        batch['mask'] = batch['mask'].unsqueeze(0).to(movq_device)

        if negative_condition_model_input is not None:
            # The negative prompt needs the same leading batch dimension as
            # the positive one: the encoder slices its mask as [:, :n] and
            # __call__ repeats the contexts with the pattern '1 n d -> b n d'.
            batch['negative_text']['input_ids'] = (
                batch['negative_text']['input_ids'].unsqueeze(0).to(text_device)
            )
            batch['negative_text']['attention_mask'] = (
                batch['negative_text']['attention_mask'].unsqueeze(0).to(text_device)
            )

        return batch

    def __call__(
            self,
            text: str,
            image: PIL.Image.Image,
            mask: np.ndarray,
            negative_text: str = None,
            images_num: int = 1,
            bs: int = 1,
            steps: int = 50,
            guidance_weight_text: float = 4,
            eta=1.0
    ) -> List[PIL.Image.Image]:
        """Generate ``images_num`` inpainted images in mini-batches of ``bs``.

        Args:
            text: Prompt.
            image: Source image.
            mask: Inpainting mask.
            negative_text: Optional negative prompt.
            images_num: Total number of images to produce.
            bs: Mini-batch size used for sampling.
            steps: Sets the stride (``1000 // steps``) of the descending
                timestep list that starts at 999.
            guidance_weight_text: Guidance weight passed to the sampler.
            eta: Passed through to the sampler.

        Returns:
            The generated images as PIL images.
        """
        with torch.no_grad():
            batch = self.prepare_batch(text, negative_text, image, mask)
            processed = self.shared_step(batch)

        betas = get_named_beta_schedule('cosine', 1000)
        base_diffusion = BaseDiffusion(betas, percentile=0.95)
        times = list(range(999, 0, -1000 // steps))

        pil_images = []
        # k full mini-batches plus one remainder batch of size m (skipped if 0).
        k, m = images_num // bs, images_num % bs
        for minibatch in [bs] * k + [m]:
            if minibatch == 0:
                continue

            bs_context = repeat(processed['context'], '1 n d -> b n d', b=minibatch)
            bs_context_mask = repeat(processed['context_mask'], '1 n -> b n', b=minibatch)

            if processed['negative_context'] is not None:
                bs_negative_context = repeat(processed['negative_context'], '1 n d -> b n d', b=minibatch)
                bs_negative_context_mask = repeat(processed['negative_context_mask'], '1 n -> b n', b=minibatch)
            else:
                bs_negative_context, bs_negative_context_mask = None, None

            batch_mask = processed['mask'].repeat_interleave(minibatch, dim=0)
            masked_latent = processed['masked_latent'].repeat_interleave(minibatch, dim=0)

            minibatch = masked_latent.shape[0]

            with torch.cuda.amp.autocast(dtype=self.dtype_map['unet']):
                with torch.no_grad():
                    images = base_diffusion.p_sample_loop(
                        self.unet, (minibatch, 4, masked_latent.shape[2], masked_latent.shape[3]), times,
                        self.device_map['unet'],
                        bs_context, bs_context_mask, self.null_embedding, guidance_weight_text, eta,
                        negative_context=bs_negative_context, negative_context_mask=bs_negative_context_mask,
                        mask=batch_mask, masked_latent=masked_latent, gan=False
                    )

            with torch.cuda.amp.autocast(dtype=self.dtype_map['movq']):
                # Decode in up to two chunks, then map from [-1, 1] to [0, 1].
                images = torch.cat([self.movq.decode(latents) for latents in images.chunk(2)])
                images = torch.clip((images + 1.) / 2., 0., 1.).cpu()

            for images_chunk in images.chunk(1):
                pil_images += [self.to_pil(decoded) for decoded in images_chunk]

        return pil_images
kandinsky3/.ipynb_checkpoints/movq-checkpoint.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+
7
+ from .utils import freeze
8
+
9
+
10
def nonlinearity(x):
    """Return ``x`` multiplied elementwise by its sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate
12
+
13
+
14
class SpatialNorm(nn.Module):
    """Normalization layer with optional modulation by a second feature map.

    Without ``zq_channels`` this is just ``norm_layer``. With it, the
    normalized features are scaled and shifted by 1x1 convolutions of ``zq``
    after ``zq`` is resized to the feature map's spatial size.

    Args:
        f_channels: Channel count of the features being normalized.
        zq_channels: Channel count of the conditioning map; ``None`` disables
            the conditioning branch.
        norm_layer: Factory called as ``norm_layer(num_channels=..., **params)``.
        freeze_norm_layer: When true (and ``zq_channels`` is given), the norm
            layer's parameters are excluded from gradient updates.
        add_conv: When true, ``zq`` first passes through a 3x3 convolution.
        **norm_layer_params: Extra keyword arguments for ``norm_layer``.
    """

    def __init__(
            self, f_channels, zq_channels=None, norm_layer=nn.GroupNorm, freeze_norm_layer=False, add_conv=False, **norm_layer_params
    ):
        super().__init__()
        self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
        if zq_channels is not None:
            if freeze_norm_layer:
                # parameters is a method and must be called to get the
                # iterable of tensors.
                for p in self.norm_layer.parameters():
                    p.requires_grad = False
            self.add_conv = add_conv
            if self.add_conv:
                self.conv = nn.Conv2d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1)
            self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
            self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, f, zq=None):
        """Normalize ``f``; if ``zq`` is given, modulate the result with it."""
        norm_f = self.norm_layer(f)
        if zq is not None:
            # Match zq to f's last two dimensions before the convolutions.
            f_size = f.shape[-2:]
            zq = torch.nn.functional.interpolate(zq, size=f_size, mode="nearest")
            if self.add_conv:
                zq = self.conv(zq)
            norm_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        return norm_f
38
+
39
+
40
def Normalize(in_channels, zq_ch=None, add_conv=None):
    """Create a SpatialNorm over a 32-group GroupNorm (eps 1e-6, affine)."""
    group_norm_params = dict(num_groups=32, eps=1e-6, affine=True)
    return SpatialNorm(
        in_channels,
        zq_ch,
        norm_layer=nn.GroupNorm,
        freeze_norm_layer=False,
        add_conv=add_conv,
        **group_norm_params,
    )
45
+
46
+
47
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Double the spatial size of ``x``; apply the conv if enabled."""
        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        if not self.with_conv:
            return upsampled
        return self.conv(upsampled)
63
+
64
+
65
class Downsample(nn.Module):
    """2x downsampling by a strided 3x3 conv or by 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        """Halve the spatial size of ``x``."""
        if not self.with_conv:
            return F.avg_pool2d(x, kernel_size=2, stride=2)
        # The conv has no padding of its own: add one zero column on the
        # right and one zero row at the bottom before applying it.
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
84
+
85
+
86
class ResnetBlock(nn.Module):
    """Residual block: two norm -> nonlinearity -> 3x3 conv stages.

    An optional embedding ``temb`` is projected and added between the two
    stages. When input and output channel counts differ, the skip path goes
    through a 3x3 conv (``conv_shortcut=True``) or a 1x1 conv.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512, zq_ch=None, add_conv=False):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels, zq_ch, add_conv=add_conv)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels, zq_ch, add_conv=add_conv)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb, zq=None):
        """Return ``skip(x) + residual(x)``; ``temb`` may be ``None``."""
        h = self.conv1(nonlinearity(self.norm1(x, zq)))

        if temb is not None:
            # Broadcast the projected embedding over the two spatial axes.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.conv2(self.dropout(nonlinearity(self.norm2(h, zq))))

        if self.in_channels == self.out_channels:
            return x + h
        skip = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
        return skip(x) + h
146
+
147
+
148
class AttnBlock(nn.Module):
    """Single-head self-attention over spatial positions with a residual."""

    def __init__(self, in_channels, zq_ch=None, add_conv=False):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels, zq_ch, add_conv=add_conv)
        # 1x1 convolutions act as per-position linear projections.
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, zq=None):
        """Return ``x`` plus the projected attention output."""
        normed = self.norm(x, zq)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        # Attention weights: (b, hw, hw), scaled by c ** -0.5 and
        # softmax-normalized over the key axis.
        b, c, h, w = q.shape
        positions = h * w
        q = q.reshape(b, c, positions).permute(0, 2, 1)
        k = k.reshape(b, c, positions)
        scores = torch.bmm(q, k)
        scores = scores * (int(c) ** (-0.5))
        weights = F.softmax(scores, dim=2)

        # Weighted sum of values, folded back to (b, c, h, w).
        v = v.reshape(b, c, positions)
        attended = torch.bmm(v, weights.permute(0, 2, 1))
        attended = attended.reshape(b, c, h, w)

        return x + self.proj_out(attended)
201
+
202
+
203
class Encoder(nn.Module):
    """Convolutional encoder: input conv, downsampling stages, middle, output conv.

    Each stage holds ``num_res_blocks`` residual blocks, with an attention
    block after each one when the stage's resolution is listed in
    ``attn_resolutions``. Every stage but the last ends with a Downsample.
    The output has ``2 * z_channels`` channels when ``double_z`` is true,
    otherwise ``z_channels``. No timestep embedding is used (``temb_ch`` is 0).
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # Input convolution.
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        # Downsampling stages.
        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        last_level = self.num_resolutions - 1
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                res_blocks.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn_blocks.append(AttnBlock(block_in))
            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if i_level != last_level:
                stage.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(stage)

        # Middle: residual block, attention, residual block.
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )

        # Output head.
        self.norm_out = Normalize(block_in)
        latent_channels = 2 * z_channels if double_z else z_channels
        self.conv_out = nn.Conv2d(block_in, latent_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Encode ``x`` into the latent feature map."""
        temb = None

        # Downsampling: each step consumes the most recent feature map.
        hs = [self.conv_in(x)]
        last_level = self.num_resolutions - 1
        for i_level, stage in enumerate(self.down):
            for i_block in range(self.num_res_blocks):
                h = stage.block[i_block](hs[-1], temb)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h)
                hs.append(h)
            if i_level != last_level:
                hs.append(stage.downsample(hs[-1]))

        # Middle.
        h = self.mid.block_1(hs[-1], temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # Output head.
        h = nonlinearity(self.norm_out(h))
        return self.conv_out(h)
292
+
293
+
294
class Decoder(nn.Module):
    """Decoder whose blocks all receive a second tensor ``zq`` besides the features.

    Every ``ResnetBlock``, ``AttnBlock`` and the final ``Normalize`` layer is
    constructed with ``zq_ch`` / ``add_conv`` and called with ``zq``.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, zq_ch=None, add_conv=False, **ignorekwargs):
        """Build the decoder: input conv, middle blocks, up stages, output conv.

        Args:
            ch: Base channel width; stage ``i`` outputs ``ch * ch_mult[i]`` channels.
            out_ch: Channel count of the final convolution.
            ch_mult: Per-stage channel multipliers; its length is the stage count.
            num_res_blocks: Each stage holds ``num_res_blocks + 1`` resnet blocks.
            attn_resolutions: Resolutions at which each resnet block of a
                stage is followed by an ``AttnBlock``.
            dropout: Dropout value forwarded to every ``ResnetBlock``.
            resamp_with_conv: Forwarded to ``Upsample``.
            in_channels: Stored on the instance; not otherwise used here.
            resolution: Nominal output resolution used to derive the lowest
                resolution and to track ``attn_resolutions`` membership.
            z_channels: Channel count expected by the input convolution.
            give_pre_end: When true, ``forward`` returns before the output head.
            zq_ch: Forwarded to every block and to ``Normalize``.
            add_conv: Forwarded to every block and to ``Normalize``.
            **ignorekwargs: Extra configuration entries; ignored.
        """
        super().__init__()
        self.ch = ch
        self.temb_ch = 0  # no timestep embedding is used by this decoder
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # Width and resolution at the lowest (innermost) level.
        deepest = self.num_resolutions - 1
        block_in = ch * ch_mult[deepest]
        level_res = resolution // 2 ** deepest
        self.z_shape = (1, z_channels, level_res, level_res)

        # Keyword arguments shared by every ResnetBlock of this decoder.
        res_kwargs = dict(
            temb_channels=self.temb_ch,
            dropout=dropout,
            zq_ch=zq_ch,
            add_conv=add_conv,
        )

        # Latent -> feature projection.
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # Middle: resnet -> attention -> resnet.
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, **res_kwargs)
        self.mid.attn_1 = AttnBlock(block_in, zq_ch, add_conv=add_conv)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, **res_kwargs)

        # Up stages, built from the deepest level outwards.
        self.up = nn.ModuleList()
        for level in reversed(range(self.num_resolutions)):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            block_out = ch * ch_mult[level]
            for _ in range(self.num_res_blocks + 1):
                res_blocks.append(
                    ResnetBlock(in_channels=block_in, out_channels=block_out, **res_kwargs)
                )
                block_in = block_out
                if level_res in attn_resolutions:
                    attn_blocks.append(AttnBlock(block_in, zq_ch, add_conv=add_conv))

            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if level != 0:
                stage.upsample = Upsample(block_in, resamp_with_conv)
                level_res = level_res * 2
            # Prepend so that self.up[i] corresponds to level i.
            self.up.insert(0, stage)

        # Output head.
        self.norm_out = Normalize(block_in, zq_ch, add_conv=add_conv)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z, zq):
        """Decode ``z`` while passing ``zq`` to every block.

        Args:
            z: Tensor fed to ``self.conv_in``.
            zq: Tensor handed to every resnet/attention block and to the
                final normalisation.

        Returns:
            The output of ``self.conv_out``, or the pre-head features when
            ``self.give_pre_end`` is true.
        """
        # Remember the latent shape of the most recent call.
        self.last_z_shape = z.shape

        temb = None  # the blocks accept a timestep embedding; none is used here

        h = self.conv_in(z)

        # Middle.
        h = self.mid.block_1(h, temb, zq)
        h = self.mid.attn_1(h, zq)
        h = self.mid.block_2(h, temb, zq)

        # Up path, deepest level first.
        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            for idx in range(self.num_res_blocks + 1):
                h = stage.block[idx](h, temb, zq)
                if len(stage.attn) > 0:
                    h = stage.attn[idx](h, zq)
            if level != 0:
                h = stage.upsample(h)

        if self.give_pre_end:
            return h

        # Output head.
        return self.conv_out(nonlinearity(self.norm_out(h, zq)))
400
+
401
+
402
class MoVQ(nn.Module):
    """Encoder/decoder pair with 1x1 convolutions around the latent.

    Neither ``encode`` nor ``decode`` disables gradient tracking; callers
    control the grad mode themselves.
    """

    def __init__(self, generator_params):
        """Create the sub-modules from one shared configuration mapping.

        Args:
            generator_params: Mapping expanded as keyword arguments into both
                ``Encoder`` and ``Decoder``; must contain ``"z_channels"``.
        """
        super().__init__()
        latent_channels = generator_params["z_channels"]
        self.encoder = Encoder(**generator_params)
        self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
        self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
        # The decoder receives the latent channel count as ``zq_ch``.
        self.decoder = Decoder(zq_ch=latent_channels, **generator_params)

    def encode(self, x):
        """Return ``quant_conv(encoder(x))``."""
        return self.quant_conv(self.encoder(x))

    def decode(self, quant):
        """Decode ``quant``.

        The decoder receives ``post_quant_conv(quant)`` as its input and the
        untouched ``quant`` as its second (``zq``) argument.
        """
        projected = self.post_quant_conv(quant)
        return self.decoder(projected, quant)
423
+
424
+
425
def get_vae(conf):
    """Build a frozen ``MoVQ`` model from a configuration object.

    Args:
        conf: Object exposing ``params`` (passed to ``MoVQ``) and
            ``checkpoint`` (path to a state dict, or ``None`` to skip loading).

    Returns:
        The ``MoVQ`` instance after ``freeze`` has been applied to it.
    """
    movq = MoVQ(conf.params)
    if conf.checkpoint is not None:
        # Deserialize onto the CPU so a checkpoint saved from a GPU still
        # loads on a host without that device; load_state_dict copies the
        # values into the model's own parameters afterwards.
        # NOTE: torch.load unpickles the file - only use trusted checkpoints.
        movq_state_dict = torch.load(conf.checkpoint, map_location='cpu')
        movq.load_state_dict(movq_state_dict)
    movq = freeze(movq)
    return movq
kandinsky3/.ipynb_checkpoints/t2i_pipeline-checkpoint.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, List
2
+ import PIL
3
+
4
+ import torch
5
+ import torchvision.transforms as T
6
+ from einops import repeat
7
+
8
+ from kandinsky3.model.unet import UNet
9
+ from kandinsky3.movq import MoVQ
10
+ from kandinsky3.condition_encoders import T5TextConditionEncoder
11
+ from kandinsky3.condition_processors import T5TextConditionProcessor
12
+ from kandinsky3.model.diffusion import BaseDiffusion, get_named_beta_schedule
13
+
14
+
15
class Kandinsky3T2IPipeline:
    """Text-to-image pipeline: T5 text encoding, UNet sampling, MoVQ decoding."""

    def __init__(
        self,
        device_map: Union[str, torch.device, dict],
        dtype_map: Union[str, torch.dtype, dict],
        unet: UNet,
        null_embedding: torch.Tensor,
        t5_processor: T5TextConditionProcessor,
        t5_encoder: T5TextConditionEncoder,
        movq: MoVQ,
        gan: bool,
    ):
        """Store the components.

        ``device_map`` and ``dtype_map`` are indexed by ``'text_encoder'``,
        ``'unet'`` and ``'movq'`` in ``__call__``. ``gan`` selects the short
        fixed timestep list and is forwarded to the sampler.
        """
        self.device_map = device_map
        self.dtype_map = dtype_map
        self.to_pil = T.ToPILImage()

        self.unet = unet
        self.null_embedding = null_embedding
        self.t5_processor = t5_processor
        self.t5_encoder = t5_encoder
        self.movq = movq

        self.gan = gan

    def __call__(
        self,
        text: str,
        negative_text: str = None,
        images_num: int = 1,
        bs: int = 1,
        width: int = 1024,
        height: int = 1024,
        guidance_scale: float = 3.0,
        steps: int = 50,
        eta: float = 1.0
    ) -> List[PIL.Image.Image]:
        """Generate ``images_num`` images for ``text``.

        Args:
            text: Prompt.
            negative_text: Optional negative prompt.
            images_num: Total number of images to produce.
            bs: Maximum number of images sampled per batch.
            width: Output width; the latent width is ``width // 8``.
            height: Output height; the latent height is ``height // 8``.
            guidance_scale: Forwarded to the sampler.
            steps: Controls the timestep stride (``1000 // steps``); ignored
                for the timestep list when ``self.gan`` is true.
            eta: Forwarded to the sampler.

        Returns:
            A list of PIL images.
        """
        base_diffusion = BaseDiffusion(get_named_beta_schedule('cosine', 1000), 0.99)
        times = list(range(999, 0, -1000 // steps))
        if self.gan:
            # The GAN variant always samples on this fixed timestep list.
            times = list(range(979, 0, -250))

        # Tokenize, add a leading batch axis and move to the text-encoder device.
        text_device = self.device_map['text_encoder']
        positive_inputs, negative_inputs = self.t5_processor.encode(text, negative_text)
        for key in positive_inputs:
            positive_inputs[key] = positive_inputs[key][None].to(text_device)
        if negative_inputs is not None:
            for key in negative_inputs:
                negative_inputs[key] = negative_inputs[key][None].to(text_device)

        pil_images = []
        with torch.no_grad():
            with torch.cuda.amp.autocast(dtype=self.dtype_map['text_encoder']):
                context, context_mask = self.t5_encoder(positive_inputs)
                if negative_inputs is None:
                    negative_context, negative_context_mask = None, None
                else:
                    negative_context, negative_context_mask = self.t5_encoder(negative_inputs)

            # Full batches of ``bs`` followed by one remainder batch.
            full_batches, remainder = images_num // bs, images_num % bs
            for minibatch in [bs] * full_batches + [remainder]:
                if minibatch == 0:
                    continue

                # Broadcast the single encoded prompt over the batch.
                bs_context = repeat(context, '1 n d -> b n d', b=minibatch)
                bs_context_mask = repeat(context_mask, '1 n -> b n', b=minibatch)
                bs_negative_context, bs_negative_context_mask = None, None
                if negative_context is not None:
                    bs_negative_context = repeat(negative_context, '1 n d -> b n d', b=minibatch)
                    bs_negative_context_mask = repeat(negative_context_mask, '1 n -> b n', b=minibatch)

                latent_shape = (minibatch, 4, height // 8, width // 8)
                with torch.cuda.amp.autocast(dtype=self.dtype_map['unet']):
                    images = base_diffusion.p_sample_loop(
                        self.unet, latent_shape, times, self.device_map['unet'],
                        bs_context, bs_context_mask, self.null_embedding, guidance_scale, eta,
                        negative_context=bs_negative_context, negative_context_mask=bs_negative_context_mask,
                        gan=self.gan
                    )

                with torch.cuda.amp.autocast(dtype=self.dtype_map['movq']):
                    # Decode in (at most) two chunks.
                    decoded = [self.movq.decode(part) for part in images.chunk(2)]
                    images = torch.cat(decoded)
                    # Map from [-1, 1] to [0, 1] and clamp.
                    images = torch.clip((images + 1.) / 2., 0., 1.)
                    pil_images.extend(self.to_pil(single) for single in images)

        return pil_images
kandinsky3/.ipynb_checkpoints/utils-checkpoint.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from omegaconf import OmegaConf
2
+ import numpy as np
3
+ from scipy import ndimage
4
+ import torch.nn as nn
5
+ from skimage.transform import resize
6
+
7
+
8
def load_conf(config_path):
    """Load an OmegaConf file and copy shared values into dependent sections.

    Values from ``common`` (and a few from ``model``) are copied into the
    ``data``, ``trainer``, ``scheduler``, ``logger`` and ``model.encoders``
    sections so they only need to be written once in the file.

    Args:
        config_path: Path handed to ``OmegaConf.load``.

    Returns:
        The loaded configuration with the copied entries filled in.
    """
    conf = OmegaConf.load(config_path)
    common = conf.common

    # Data section.
    conf.data.tokens_length = common.tokens_length
    conf.data.processor_names = conf.model.encoders.model_names
    conf.data.dataset.seed = common.seed
    conf.data.dataset.image_size = common.image_size

    # Training length and experiment naming.
    conf.trainer.trainer_params.max_steps = common.train_steps
    conf.scheduler.params.total_steps = common.train_steps
    conf.logger.tensorboard.name = common.experiment_name

    # The encoders use the UNet's context width.
    conf.model.encoders.context_dim = conf.model.unet_params.context_dim
    return conf
21
+
22
+
23
def freeze(model):
    """Disable gradient tracking on every parameter of ``model``; return ``model``."""
    for param in model.parameters():
        param.requires_grad_(False)
    return model
27
+
28
def unfreeze(model):
    """Enable gradient tracking on every parameter of ``model``; return ``model``."""
    for param in model.parameters():
        param.requires_grad_(True)
    return model
32
+
33
def zero_module(module):
    """Overwrite every parameter of ``module`` with zeros in place; return ``module``."""
    for param in module.parameters():
        nn.init.constant_(param, 0.0)
    return module
37
+
38
def resize_mask_for_diffusion(mask):
    """Resize ``mask`` so each side is a multiple of 64 and the area is capped.

    The shrink factor is ``max(1, sqrt(mask.size / 1024**2))``, so masks with
    at most ``1024**2`` elements are not shrunk before the rounding. Each of
    the first two dimensions is then rounded down to a multiple of 64.

    Args:
        mask: Array exposing ``size`` and ``shape``.

    Returns:
        The resized array, with value range preserved and no anti-aliasing.
    """
    shrink = max(1, (mask.size / 1024**2)**0.5)
    new_rows = (round(mask.shape[0] / shrink) // 64) * 64
    new_cols = (round(mask.shape[1] / shrink) // 64) * 64
    return resize(
        mask,
        (new_rows, new_cols),
        preserve_range=True,
        anti_aliasing=False
    )
51
+
52
def resize_image_for_diffusion(image):
    """Resize ``image`` so each side is a multiple of 64 and the area is capped.

    The shrink factor is ``max(1, sqrt(width * height / 1024**2))``; each
    side is divided by it, rounded, then rounded down to a multiple of 64.

    Args:
        image: Object exposing a two-element ``size`` and a ``resize`` method.

    Returns:
        Whatever ``image.resize`` returns for the computed size.
    """
    width, height = image.size[0], image.size[1]
    shrink = max(1, (width * height / 1024**2)**0.5)
    new_width = (round(width / shrink) // 64) * 64
    new_height = (round(height / shrink) // 64) * 64
    return image.resize((new_width, new_height))
59
+
60
def prepare_mask(mask):
    """Grow the non-zero region of ``mask`` and return it as a 0/1 integer array.

    The mask is convolved three times with a positive 5x5 kernel whose
    weights sum to 1, then thresholded at ``> 0``.

    Args:
        mask: Array accepted by ``scipy.ndimage.convolve``.

    Returns:
        Integer array that is 1 wherever the smoothed mask is positive.
    """
    kernel = np.array([[1, 1, 1, 1, 1],
                       [1, 5, 5, 5, 1],
                       [1, 5, 44, 5, 1],
                       [1, 5, 5, 5, 1],
                       [1, 1, 1, 1, 1]]) / 100

    smoothed = mask
    for _ in range(3):
        smoothed = ndimage.convolve(smoothed, kernel)

    return (smoothed > 0).astype(int)
kandinsky3/__init__.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional, Union
3
+
4
+ import torch
5
+ from huggingface_hub import hf_hub_download, snapshot_download
6
+
7
+ from kandinsky3.model.unet import UNet
8
+ from kandinsky3.movq import MoVQ
9
+ from kandinsky3.condition_encoders import T5TextConditionEncoder
10
+ from kandinsky3.condition_processors import T5TextConditionProcessor
11
+ from kandinsky3.model.diffusion import BaseDiffusion, get_named_beta_schedule
12
+
13
+ from .t2i_pipeline import Kandinsky3T2IPipeline
14
+ from .inpainting_pipeline import Kandinsky3InpaintingPipeline
15
+
16
+
17
def get_T2I_unet(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> (UNet, Optional[torch.Tensor]):
    """Build the text-to-image UNet (4 input channels) and optionally load weights.

    Args:
        device: Device the UNet is moved to.
        weights_path: Optional checkpoint path. The file must hold a mapping
            with the keys ``'unet'`` (state dict) and ``'null_embedding'``.
        dtype: Dtype the UNet is cast to.

    Returns:
        ``(unet, null_embedding)``. The UNet is in eval mode and
        ``null_embedding`` is ``None`` when no checkpoint was given. The
        embedding is returned exactly as loaded (on CPU).
    """
    unet = UNet(
        model_channels=384,
        num_channels=4,
        init_channels=192,
        time_embed_dim=1536,
        context_dim=4096,
        groups=32,
        head_dim=64,
        expansion_ratio=4,
        compression_ratio=2,
        dim_mult=(1, 2, 4, 8),
        num_blocks=(3, 3, 3, 3),
        add_cross_attention=(False, True, True, True),
        add_self_attention=(False, True, True, True),
    )

    null_embedding = None
    if weights_path:
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        null_embedding = state_dict['null_embedding']
        unet.load_state_dict(state_dict['unet'])

    unet.to(device=device, dtype=dtype).eval()
    return unet, null_embedding
46
+
47
+
48
def get_T5encoder(
    device: Union[str, torch.device],
    weights_path: str,
    projection_name: str,
    dtype: Union[str, torch.dtype] = torch.float32,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
) -> (T5TextConditionProcessor, T5TextConditionEncoder):
    """Create the T5 tokenizer wrapper and text encoder, loading the projection.

    Args:
        device: Device for the encoder and its projection head.
        weights_path: Directory holding the tokenizer/encoder files and the
            projection checkpoint.
        projection_name: File name of the projection checkpoint inside
            ``weights_path``.
        dtype: Dtype for the encoder and its projection head.
        low_cpu_mem_usage: Forwarded to ``T5TextConditionEncoder``.
        load_in_8bit: Forwarded to ``T5TextConditionEncoder``.
        load_in_4bit: Forwarded to ``T5TextConditionEncoder``.

    Returns:
        ``(processor, condition_encoder)`` with the projection in eval mode.
    """
    # Fixed token budget and context width used by this pipeline.
    max_tokens, context_width = 128, 4096

    processor = T5TextConditionProcessor(max_tokens, weights_path)
    condition_encoder = T5TextConditionEncoder(
        weights_path,
        context_width,
        low_cpu_mem_usage=low_cpu_mem_usage,
        device=device,
        dtype=dtype,
        load_in_8bit=load_in_8bit,
        load_in_4bit=load_in_4bit,
    )

    if weights_path:
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        projection_file = os.path.join(weights_path, projection_name)
        projection_state = torch.load(projection_file, map_location=torch.device('cpu'))
        condition_encoder.projection.load_state_dict(projection_state)

    projection = condition_encoder.projection
    projection.to(device=device, dtype=dtype).eval()
    return processor, condition_encoder
72
+
73
+
74
def get_movq(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> MoVQ:
    """Build the MoVQ autoencoder with its fixed configuration.

    Args:
        device: Device the model is moved to.
        weights_path: Optional path to a state dict for the whole model.
        dtype: Dtype the model is cast to.

    Returns:
        The model in eval mode.
    """
    generator_config = dict(
        double_z=False,
        z_channels=4,
        resolution=256,
        in_channels=3,
        out_ch=3,
        ch=256,
        ch_mult=[1, 2, 2, 4],
        num_res_blocks=2,
        attn_resolutions=[32],
        dropout=0.0,
    )
    movq = MoVQ(generator_config)

    if weights_path:
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        weights = torch.load(weights_path, map_location=torch.device('cpu'))
        movq.load_state_dict(weights)

    movq.to(device=device, dtype=dtype).eval()
    return movq
99
+
100
+
101
def get_inpainting_unet(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> (UNet, Optional[torch.Tensor]):
    """Build the inpainting UNet (9 input channels) and optionally load weights.

    Args:
        device: Device the UNet is moved to.
        weights_path: Optional checkpoint path. The file must hold a mapping
            with the keys ``'unet'`` (state dict) and ``'null_embedding'``.
        dtype: Dtype the UNet is cast to.

    Returns:
        ``(unet, null_embedding)``. The UNet is in eval mode and
        ``null_embedding`` is ``None`` when no checkpoint was given. The
        embedding is returned exactly as loaded (on CPU).
    """
    unet = UNet(
        model_channels=384,
        num_channels=9,
        init_channels=192,
        time_embed_dim=1536,
        context_dim=4096,
        groups=32,
        head_dim=64,
        expansion_ratio=4,
        compression_ratio=2,
        dim_mult=(1, 2, 4, 8),
        num_blocks=(3, 3, 3, 3),
        add_cross_attention=(False, True, True, True),
        add_self_attention=(False, True, True, True),
    )

    null_embedding = None
    if weights_path:
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        null_embedding = state_dict['null_embedding']
        unet.load_state_dict(state_dict['unet'])

    unet.to(device=device, dtype=dtype).eval()
    return unet, null_embedding
130
+
131
+
132
def get_T2I_pipeline(
    device_map: Union[str, torch.device, dict],
    dtype_map: Union[str, torch.dtype, dict] = torch.float32,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
    cache_dir: str = '/tmp/kandinsky3/',
    unet_path: str = None,
    text_encoder_path: str = None,
    movq_path: str = None,
) -> Kandinsky3T2IPipeline:
    """Assemble the text-to-image pipeline, downloading any missing weights.

    Args:
        device_map: One device for all components, or a dict keyed by
            ``'unet'``, ``'text_encoder'`` and ``'movq'``.
        dtype_map: One dtype for all components, or a dict with the same keys.
        low_cpu_mem_usage: Forwarded to the text-encoder loader.
        load_in_8bit: Forwarded to the text-encoder loader.
        load_in_4bit: Forwarded to the text-encoder loader.
        cache_dir: Download cache directory for the Hugging Face Hub.
        unet_path: Local UNet checkpoint; downloaded when ``None``.
        text_encoder_path: Local text-encoder directory; downloaded when ``None``.
        movq_path: Local MoVQ checkpoint; downloaded when ``None``.

    Returns:
        A ``Kandinsky3T2IPipeline`` built with ``gan=False``.
    """
    component_names = ('unet', 'text_encoder', 'movq')
    # A single device / dtype applies to every component.
    if not isinstance(device_map, dict):
        device_map = {name: device_map for name in component_names}
    if not isinstance(dtype_map, dict):
        dtype_map = {name: dtype_map for name in component_names}

    hub_repo = "ai-forever/Kandinsky3.1"
    if unet_path is None:
        unet_path = hf_hub_download(
            repo_id=hub_repo, filename='weights/kandinsky3.pt', cache_dir=cache_dir
        )
    if text_encoder_path is None:
        snapshot_root = snapshot_download(
            repo_id=hub_repo, allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        text_encoder_path = os.path.join(snapshot_root, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(
            repo_id=hub_repo, filename='weights/movq.pt', cache_dir=cache_dir
        )

    unet, null_embedding = get_T2I_unet(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'], text_encoder_path, 'projection.pt', dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])

    use_gan = False
    return Kandinsky3T2IPipeline(
        device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq, use_gan
    )
176
+
177
+
178
def get_T2I_Flash_pipeline(
    device_map: Union[str, torch.device, dict],
    dtype_map: Union[str, torch.dtype, dict] = torch.float32,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
    cache_dir: str = '/tmp/kandinsky3/',
    unet_path: str = None,
    text_encoder_path: str = None,
    movq_path: str = None,
) -> Kandinsky3T2IPipeline:
    """Assemble the "Flash" text-to-image pipeline, downloading missing weights.

    Uses the ``kandinsky3_flash.pt`` UNet checkpoint and the
    ``projection_flash.pt`` text projection.

    Args:
        device_map: One device for all components, or a dict keyed by
            ``'unet'``, ``'text_encoder'`` and ``'movq'``.
        dtype_map: One dtype for all components, or a dict with the same keys.
        low_cpu_mem_usage: Forwarded to the text-encoder loader.
        load_in_8bit: Forwarded to the text-encoder loader.
        load_in_4bit: Forwarded to the text-encoder loader.
        cache_dir: Download cache directory for the Hugging Face Hub.
        unet_path: Local UNet checkpoint; downloaded when ``None``.
        text_encoder_path: Local text-encoder directory; downloaded when ``None``.
        movq_path: Local MoVQ checkpoint; downloaded when ``None``.

    Returns:
        A ``Kandinsky3T2IPipeline`` built with ``gan=True``.
    """
    component_names = ('unet', 'text_encoder', 'movq')
    # A single device / dtype applies to every component.
    if not isinstance(device_map, dict):
        device_map = {name: device_map for name in component_names}
    if not isinstance(dtype_map, dict):
        dtype_map = {name: dtype_map for name in component_names}

    hub_repo = "ai-forever/Kandinsky3.1"
    if unet_path is None:
        unet_path = hf_hub_download(
            repo_id=hub_repo, filename='weights/kandinsky3_flash.pt', cache_dir=cache_dir
        )
    if text_encoder_path is None:
        snapshot_root = snapshot_download(
            repo_id=hub_repo, allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        text_encoder_path = os.path.join(snapshot_root, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(
            repo_id=hub_repo, filename='weights/movq.pt', cache_dir=cache_dir
        )

    unet, null_embedding = get_T2I_unet(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'], text_encoder_path, 'projection_flash.pt', dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])

    use_gan = True
    return Kandinsky3T2IPipeline(
        device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq, use_gan
    )
222
+
223
+
224
def get_inpainting_pipeline(
    device_map: Union[str, torch.device, dict],
    dtype_map: Union[str, torch.dtype, dict] = torch.float32,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
    cache_dir: str = '/tmp/kandinsky3/',
    unet_path: str = None,
    text_encoder_path: str = None,
    movq_path: str = None,
) -> Kandinsky3InpaintingPipeline:
    """Assemble the inpainting pipeline, downloading any missing weights.

    Uses the ``kandinsky3_inpainting.pt`` UNet checkpoint and the
    ``projection_inpainting.pt`` text projection.

    Args:
        device_map: One device for all components, or a dict keyed by
            ``'unet'``, ``'text_encoder'`` and ``'movq'``.
        dtype_map: One dtype for all components, or a dict with the same keys.
        low_cpu_mem_usage: Forwarded to the text-encoder loader.
        load_in_8bit: Forwarded to the text-encoder loader.
        load_in_4bit: Forwarded to the text-encoder loader.
        cache_dir: Download cache directory for the Hugging Face Hub.
        unet_path: Local UNet checkpoint; downloaded when ``None``.
        text_encoder_path: Local text-encoder directory; downloaded when ``None``.
        movq_path: Local MoVQ checkpoint; downloaded when ``None``.

    Returns:
        A ready ``Kandinsky3InpaintingPipeline``.
    """
    component_names = ('unet', 'text_encoder', 'movq')
    # A single device / dtype applies to every component.
    if not isinstance(device_map, dict):
        device_map = {name: device_map for name in component_names}
    if not isinstance(dtype_map, dict):
        dtype_map = {name: dtype_map for name in component_names}

    hub_repo = "ai-forever/Kandinsky3.1"
    if unet_path is None:
        unet_path = hf_hub_download(
            repo_id=hub_repo, filename='weights/kandinsky3_inpainting.pt', cache_dir=cache_dir
        )
    if text_encoder_path is None:
        snapshot_root = snapshot_download(
            repo_id=hub_repo, allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        text_encoder_path = os.path.join(snapshot_root, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(
            repo_id=hub_repo, filename='weights/movq.pt', cache_dir=cache_dir
        )

    unet, null_embedding = get_inpainting_unet(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'], text_encoder_path, 'projection_inpainting.pt', dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])

    return Kandinsky3InpaintingPipeline(
        device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq
    )
kandinsky3/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (5.38 kB). View file
 
kandinsky3/__pycache__/condition_encoders.cpython-310.pyc ADDED
Binary file (1.84 kB). View file
 
kandinsky3/__pycache__/condition_processors.cpython-310.pyc ADDED
Binary file (1.46 kB). View file
 
kandinsky3/__pycache__/inpainting_pipeline.cpython-310.pyc ADDED
Binary file (5.32 kB). View file
 
kandinsky3/__pycache__/movq.cpython-310.pyc ADDED
Binary file (10 kB). View file
 
kandinsky3/__pycache__/t2i_pipeline.cpython-310.pyc ADDED
Binary file (3.63 kB). View file
 
kandinsky3/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.31 kB). View file
 
kandinsky3/condition_encoders.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import T5EncoderModel
4
+ from typing import Optional, Union
5
+
6
+
7
class T5TextConditionEncoder(nn.Module):
    """T5 encoder followed by a bias-free linear projection and a LayerNorm."""

    def __init__(
            self, model_path, context_dim,
            low_cpu_mem_usage: bool = True, device: Optional[str] = None,
            dtype: Union[str, torch.dtype] = torch.float32, load_in_4bit: bool = False, load_in_8bit: bool = False
    ):
        """Load the pretrained encoder and create the projection head.

        Args:
            model_path: Passed to ``T5EncoderModel.from_pretrained``.
            context_dim: Output width of the projection head.
            low_cpu_mem_usage: Forwarded to ``from_pretrained``.
            device: Forwarded to ``from_pretrained`` as ``device_map``.
            dtype: Forwarded to ``from_pretrained`` as ``torch_dtype``.
            load_in_4bit: Forwarded to ``from_pretrained``.
            load_in_8bit: Forwarded to ``from_pretrained``.
        """
        super().__init__()
        pretrained = T5EncoderModel.from_pretrained(
            model_path, low_cpu_mem_usage=low_cpu_mem_usage, device_map=device,
            torch_dtype=dtype, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit,
        )
        # Only the encoder stack of the loaded model is kept.
        self.encoder = pretrained.encoder
        self.projection = nn.Sequential(
            nn.Linear(self.encoder.config.d_model, context_dim, bias=False),
            nn.LayerNorm(context_dim)
        )

    def forward(self, model_input):
        """Encode tokens and return ``(context, context_mask)``.

        Args:
            model_input: Mapping expanded as keyword arguments into the
                encoder; may contain ``'attention_mask'``.

        Returns:
            The projected hidden states and a mask. With an attention mask,
            masked positions are zeroed and both outputs are cut along
            dimension 1 to the longest attended length plus one. Without one,
            the mask is all ones over the leading dimensions of the hidden
            states.
        """
        hidden = self.encoder(**model_input).last_hidden_state
        context = self.projection(hidden)

        if 'attention_mask' not in model_input:
            all_ones = torch.ones(*hidden.shape[:-1], dtype=torch.long, device=hidden.device)
            return context, all_ones

        context_mask = model_input['attention_mask']
        # Zero the projected features at masked positions (in place).
        context[context_mask == 0] = 0
        # Keep one position beyond the longest attended sequence.
        keep = context_mask.sum(-1).max() + 1
        return context[:, :keep], context_mask[:, :keep]
36
+
37
+
38
def get_condition_encoder(conf):
    """Build a ``T5TextConditionEncoder`` from a mapping of constructor arguments."""
    encoder = T5TextConditionEncoder(**conf)
    return encoder
40
+
kandinsky3/condition_processors.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import T5Tokenizer
3
+
4
+
5
class T5TextConditionProcessor:
    """Tokenizes prompts into fixed-length ``input_ids`` / ``attention_mask`` tensors."""

    def __init__(self, tokens_length, processor_path):
        """Load the tokenizer.

        Args:
            tokens_length: Length every encoded prompt is truncated/padded to.
            processor_path: Passed to ``T5Tokenizer.from_pretrained``.
        """
        self.tokens_length = tokens_length
        self.processor = T5Tokenizer.from_pretrained(processor_path)

    def encode(self, text=None, negative_text=None):
        """Encode a prompt and, optionally, a negative prompt.

        Args:
            text: Prompt passed to the tokenizer.
            negative_text: Optional negative prompt.

        Returns:
            ``(condition_model_input, negative_condition_model_input)``. Each
            is a dict with 1-D long tensors ``'input_ids'`` and
            ``'attention_mask'`` of length ``tokens_length``; the second is
            ``None`` when ``negative_text`` is ``None``.
        """
        limit = self.tokens_length
        encoded = self.processor(text, max_length=limit, truncation=True)
        token_ids = encoded['input_ids']
        pad_id = self.processor.pad_token_id

        # Right-pad ids and mask up to the fixed length.
        pad_count = limit - len(token_ids)
        padded_mask = encoded['attention_mask'] + [0] * pad_count
        condition_model_input = {
            'input_ids': torch.tensor(token_ids + [pad_id] * pad_count, dtype=torch.long),
            'attention_mask': torch.tensor(padded_mask, dtype=torch.long)
        }

        if negative_text is None:
            return condition_model_input, None

        negative_encoded = self.processor(negative_text, max_length=limit, truncation=True)
        # The negative prompt is cut to at most the positive prompt's token
        # count and always ends with EOS.
        negative_ids = negative_encoded['input_ids'][:len(token_ids)]
        negative_ids[-1] = self.processor.eos_token_id
        negative_ids = negative_ids + [pad_id] * (limit - len(negative_ids))
        # The attention mask is the POSITIVE prompt's mask, not the negative
        # prompt's own - presumably so both contexts are trimmed to the same
        # length downstream; TODO confirm against the sampler.
        negative_condition_model_input = {
            'input_ids': torch.tensor(negative_ids, dtype=torch.long),
            'attention_mask': torch.tensor(padded_mask, dtype=torch.long)
        }
        return condition_model_input, negative_condition_model_input
kandinsky3/inpainting_pipeline.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, List
2
+ import PIL
3
+ import numpy as np
4
+
5
+ import torch
6
+ import torchvision.transforms as T
7
+ from einops import repeat
8
+
9
+ from kandinsky3.model.unet import UNet
10
+ from kandinsky3.movq import MoVQ
11
+ from kandinsky3.condition_encoders import T5TextConditionEncoder
12
+ from kandinsky3.condition_processors import T5TextConditionProcessor
13
+ from kandinsky3.model.diffusion import BaseDiffusion, get_named_beta_schedule
14
+ from kandinsky3.utils import resize_image_for_diffusion, resize_mask_for_diffusion
15
+
16
+
17
class Kandinsky3InpaintingPipeline:
    """Inpainting pipeline: T5 text encoding, masked-latent UNet sampling, MoVQ decoding."""

    def __init__(
        self,
        device_map: Union[str, torch.device, dict],
        dtype_map: Union[str, torch.dtype, dict],
        unet: UNet,
        null_embedding: torch.Tensor,
        t5_processor: T5TextConditionProcessor,
        t5_encoder: T5TextConditionEncoder,
        movq: MoVQ,
    ):
        """Store the components.

        ``device_map`` and ``dtype_map`` are indexed by ``'text_encoder'``,
        ``'unet'`` and ``'movq'`` by the other methods.
        """
        self.device_map = device_map
        self.dtype_map = dtype_map
        self.to_pil = T.ToPILImage()
        self.to_tensor = T.ToTensor()

        self.unet = unet
        self.null_embedding = null_embedding
        self.t5_processor = t5_processor
        self.t5_encoder = t5_encoder
        self.movq = movq

    def shared_step(self, batch: dict) -> dict:
        """Encode the masked image and the prompts of a prepared batch.

        Args:
            batch: Dict with ``'image'``, ``'mask'``, ``'text'``,
                ``'negative_text'`` and optionally ``'masked_image'``.

        Returns:
            Dict with the text contexts and masks, the input image, the
            encoded masked latent and the mask resized to the latent's
            spatial size.

        Raises:
            ValueError: If the batch has no ``'masked_image'`` and the UNet
                input layer does not take 9 channels.
        """
        image = batch['image']
        condition_model_input = batch['text']
        negative_condition_model_input = batch['negative_text']
        mask = batch['mask']

        if 'masked_image' in batch:
            masked_latent = batch['masked_image']
        elif self.unet.in_layer.in_channels == 9:
            # Blank out the pixels where the mask is 0.
            masked_latent = image.masked_fill((1 - mask).bool(), 0)
        else:
            raise ValueError(
                "batch has no 'masked_image' and the UNet input layer does not take 9 channels"
            )

        with torch.cuda.amp.autocast(dtype=self.dtype_map['movq']):
            masked_latent = self.movq.encode(masked_latent)
            # Bring the mask to the latent's spatial size.
            mask = torch.nn.functional.interpolate(
                mask, size=(masked_latent.shape[2], masked_latent.shape[3])
            )

        with torch.cuda.amp.autocast(dtype=self.dtype_map['text_encoder']):
            context, context_mask = self.t5_encoder(condition_model_input)
            if negative_condition_model_input is not None:
                negative_context, negative_context_mask = self.t5_encoder(negative_condition_model_input)
            else:
                negative_context, negative_context_mask = None, None

        return {
            'context': context,
            'context_mask': context_mask,
            'negative_context': negative_context,
            'negative_context_mask': negative_context_mask,
            'image': image,
            'masked_latent': masked_latent,
            'mask': mask
        }

    def prepare_batch(
        self,
        text: str,
        negative_text: str,
        image: PIL.Image.Image,
        mask: np.ndarray,
    ) -> dict:
        """Turn raw inputs into a single-sample batch of tensors on their devices.

        Args:
            text: Prompt.
            negative_text: Negative prompt, or ``None``.
            image: Source image; converted to RGB and resized.
            mask: Mask array; resized and inverted (``1 - mask``).

        Returns:
            Dict with ``'image'`` (scaled to [-1, 1]), ``'mask'``, ``'text'``
            and ``'negative_text'``; every tensor has a leading batch axis
            of size 1.
        """
        condition_model_input, negative_condition_model_input = self.t5_processor.encode(
            text=text, negative_text=negative_text
        )
        text_device = self.device_map['text_encoder']
        movq_device = self.device_map['movq']

        image_tensor = self.to_tensor(resize_image_for_diffusion(image.convert("RGB"))) * 2 - 1
        mask_tensor = 1 - self.to_tensor(resize_mask_for_diffusion(mask))
        mask_tensor = mask_tensor.type(self.dtype_map['movq'])

        batch = {
            'image': image_tensor.unsqueeze(0).to(movq_device),
            'mask': mask_tensor.unsqueeze(0).to(movq_device),
            'text': condition_model_input,
            'negative_text': negative_condition_model_input
        }
        for key in ('input_ids', 'attention_mask'):
            batch['text'][key] = batch['text'][key].unsqueeze(0).to(text_device)

        if negative_condition_model_input is not None:
            # The negative prompt needs the same leading batch axis as the
            # positive one: the contexts are later expanded with
            # '1 n d -> b n d', which requires that axis to exist.
            for key in ('input_ids', 'attention_mask'):
                batch['negative_text'][key] = batch['negative_text'][key].unsqueeze(0).to(text_device)

        return batch

    def __call__(
        self,
        text: str,
        image: PIL.Image.Image,
        mask: np.ndarray,
        negative_text: str = None,
        images_num: int = 1,
        bs: int = 1,
        steps: int = 50,
        guidance_weight_text: float = 4,
        eta=1.0
    ) -> List[PIL.Image.Image]:
        """Inpaint ``image`` guided by ``text``.

        Args:
            text: Prompt.
            image: Source image.
            mask: Mask array passed to ``prepare_batch``.
            negative_text: Optional negative prompt.
            images_num: Total number of images to produce.
            bs: Maximum number of images sampled per batch.
            steps: Controls the timestep stride (``1000 // steps``).
            guidance_weight_text: Forwarded to the sampler.
            eta: Forwarded to the sampler.

        Returns:
            A list of PIL images.
        """
        with torch.no_grad():
            batch = self.prepare_batch(text, negative_text, image, mask)
            processed = self.shared_step(batch)

        betas = get_named_beta_schedule('cosine', 1000)
        base_diffusion = BaseDiffusion(betas, percentile=0.95)
        times = list(range(999, 0, -1000 // steps))

        pil_images = []
        # Full batches of ``bs`` followed by one remainder batch.
        k, m = images_num // bs, images_num % bs
        for minibatch in [bs] * k + [m]:
            if minibatch == 0:
                continue

            # Broadcast the single encoded sample over the batch.
            bs_context = repeat(processed['context'], '1 n d -> b n d', b=minibatch)
            bs_context_mask = repeat(processed['context_mask'], '1 n -> b n', b=minibatch)

            if processed['negative_context'] is not None:
                bs_negative_context = repeat(processed['negative_context'], '1 n d -> b n d', b=minibatch)
                bs_negative_context_mask = repeat(processed['negative_context_mask'], '1 n -> b n', b=minibatch)
            else:
                bs_negative_context, bs_negative_context_mask = None, None

            batch_mask = processed['mask'].repeat_interleave(minibatch, dim=0)
            masked_latent = processed['masked_latent'].repeat_interleave(minibatch, dim=0)

            minibatch = masked_latent.shape[0]

            # Inference only: the results are converted to PIL images, so no
            # autograd graph is needed for sampling or decoding.
            with torch.no_grad():
                with torch.cuda.amp.autocast(dtype=self.dtype_map['unet']):
                    images = base_diffusion.p_sample_loop(
                        self.unet, (minibatch, 4, masked_latent.shape[2], masked_latent.shape[3]), times,
                        self.device_map['unet'],
                        bs_context, bs_context_mask, self.null_embedding, guidance_weight_text, eta,
                        negative_context=bs_negative_context, negative_context_mask=bs_negative_context_mask,
                        mask=batch_mask, masked_latent=masked_latent, gan=False
                    )

                with torch.cuda.amp.autocast(dtype=self.dtype_map['movq']):
                    # Decode in (at most) two chunks.
                    images = torch.cat([self.movq.decode(part) for part in images.chunk(2)])
                    # Map from [-1, 1] to [0, 1], clamp, and move to the CPU.
                    images = torch.clip((images + 1.) / 2., 0., 1.).cpu()

            pil_images.extend(self.to_pil(single) for single in images)

        return pil_images
kandinsky3/model/.ipynb_checkpoints/diffusion-checkpoint.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from tqdm import tqdm
6
+
7
+ from .utils import get_tensor_items
8
+
9
+
10
def get_named_beta_schedule(schedule_name, timesteps):
    """Return the beta (noise variance) schedule for a named preset.

    Args:
        schedule_name: ``"linear"`` or ``"cosine"``.
        timesteps: number of diffusion steps, i.e. the length of the result.

    Returns:
        A 1-D ``torch.float32`` tensor holding ``timesteps`` betas.

    Raises:
        ValueError: if ``schedule_name`` is not a known preset. Previously the
            function fell through and returned ``None``, which only failed
            later and far away from the actual mistake.
    """
    if schedule_name == "linear":
        # The 1e-4 .. 2e-2 end points are defined for 1000 steps; they are
        # rescaled so that other step counts keep a comparable noise level.
        scale = 1000 / timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return torch.linspace(
            beta_start, beta_end, timesteps, dtype=torch.float32
        )

    if schedule_name == "cosine":
        def alpha_bar(t):
            # Cumulative signal level at normalised time t in [0, 1].
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at 0.999 so
        # the final steps never remove the signal completely.
        betas = [
            min(1 - alpha_bar((i + 1) / timesteps) / alpha_bar(i / timesteps), 0.999)
            for i in range(timesteps)
        ]
        return torch.tensor(betas, dtype=torch.float32)

    raise ValueError(
        f"unknown beta schedule: {schedule_name!r} (expected 'linear' or 'cosine')"
    )
26
+
27
+
28
class BaseDiffusion:
    """Diffusion schedule tables plus a text-guided sampling loop.

    Every table is a 1-D tensor indexed by timestep and is derived once from
    ``betas`` in ``__init__``. Per-timestep values are read through the project
    helper ``get_tensor_items`` (imported from ``.utils``); it presumably
    gathers the entries at ``t`` and reshapes them so they broadcast against
    the given shape -- confirm against that helper.
    """

    def __init__(self, betas, percentile=None, gen_noise=torch.randn_like):
        """Precompute the schedule tables.

        Args:
            betas: 1-D tensor of per-step noise variances.
            percentile: if not None, quantile used by ``process_x_start`` for
                dynamic clipping; otherwise a fixed [-1, 1] clip is applied.
            gen_noise: callable ``tensor -> noise`` used by ``q_sample`` when
                no noise is supplied.
        """
        self.betas = betas
        self.num_timesteps = betas.shape[0]

        alphas = 1. - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Same table shifted right by one step, with 1.0 in front.
        leading_one = torch.ones(1, dtype=betas.dtype)
        self.alphas_cumprod_prev = torch.cat([leading_one, self.alphas_cumprod[:-1]])

        # Coefficients used by q_sample / get_x_start / get_noise.
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

        # Coefficients of the posterior q(x_{t-1} | x_t, x_0).
        self.posterior_mean_coef_1 = (
            torch.sqrt(self.alphas_cumprod_prev) * betas / (1. - self.alphas_cumprod)
        )
        self.posterior_mean_coef_2 = (
            torch.sqrt(alphas) * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        )
        self.posterior_variance = (
            betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        )
        # Entry 0 of posterior_variance is 0 (alphas_cumprod_prev[0] == 1), so
        # entry 1 is reused in its place before taking the log.
        clipped_variance = torch.cat(
            [self.posterior_variance[1].unsqueeze(0), self.posterior_variance[1:]]
        )
        self.posterior_log_variance = torch.log(clipped_variance)

        self.percentile = percentile
        # Factor applied to timesteps before they are handed to the model.
        self.time_scale = 1000 // self.num_timesteps
        self.gen_noise = gen_noise
        self.jump_length = 3

    def process_x_start(self, x_start):
        """Clip a predicted clean sample.

        Without a percentile the values are clamped to [-1, 1]. With one, each
        batch element is clamped to +/- its own ``percentile`` quantile of the
        absolute values (never below 1) and divided by that quantile.
        """
        bs, ndims = x_start.shape[0], len(x_start.shape[1:])
        if self.percentile is None:
            return torch.clip(x_start, -1., 1.)

        flat_abs = rearrange(x_start, 'b ... -> b (...)').abs()
        quantile = torch.quantile(flat_abs, self.percentile, dim=-1)
        quantile = torch.clip(quantile, min=1.)
        # One threshold per batch element, broadcastable over the other dims.
        quantile = quantile.reshape(bs, *((1,) * ndims))
        return torch.clip(x_start, -quantile, quantile) / quantile

    def get_x_start(self, x, t, noise):
        """Recover the clean sample from noisy ``x`` at step ``t`` and ``noise``."""
        sigma = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
        scale = get_tensor_items(self.sqrt_alphas_cumprod, t, noise.shape)
        return (x - sigma * noise) / scale

    def get_noise(self, x, t, x_start):
        """Recover the noise from noisy ``x`` at step ``t`` and clean ``x_start``."""
        sigma = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        scale = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
        return (x - scale * x_start) / sigma

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to step ``t``; draws noise via ``gen_noise`` if omitted."""
        if noise is None:
            noise = self.gen_noise(x_start)
        scale = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
        sigma = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
        return scale * x_start + sigma * noise

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Return ``(mean, variance, log_variance)`` of q(x_{t-1} | x_t, x_0)."""
        coef_start = get_tensor_items(self.posterior_mean_coef_1, t, x_start.shape)
        coef_t = get_tensor_items(self.posterior_mean_coef_2, t, x_t.shape)
        posterior_mean = coef_start * x_start + coef_t * x_t

        posterior_variance = get_tensor_items(self.posterior_variance, t, x_start.shape)
        posterior_log_variance = get_tensor_items(self.posterior_log_variance, t, x_start.shape)
        return posterior_mean, posterior_variance, posterior_log_variance

    def q_posterior_variance(self, t, prev_t, shape, eta=1., ):
        """Noise scale for a step from ``t`` to ``prev_t``.

        Despite the name, the value returned is the square root of the
        variance term; ``p_mean_variance`` squares it again and ``p_sample``
        multiplies the noise by it directly.
        """
        alphas_cumprod = get_tensor_items(self.alphas_cumprod, t, shape)
        prev_alphas_cumprod = get_tensor_items(self.alphas_cumprod, prev_t, shape)

        return torch.sqrt(
            eta * (1. - alphas_cumprod / prev_alphas_cumprod) * (1. - prev_alphas_cumprod) / (1. - alphas_cumprod)
        )

    def text_guidance(
            self, model, x, t, context, context_mask, null_embedding, guidance_weight_text,
            uncondition_context=None, uncondition_context_mask=None, mask=None, masked_latent=None
    ):
        """Predict noise with text guidance in a single batched model call.

        The conditional and unconditional inputs are stacked along the batch
        axis, run through ``model`` once, and combined as
        ``(w + 1) * cond - w * uncond`` with ``w = guidance_weight_text``.
        When no unconditional context is supplied, one is built from zeros
        with ``null_embedding`` written at position 0 and only that position
        unmasked.
        """
        large_x = x.repeat(2, 1, 1, 1)
        large_t = t.repeat(2).to(x.dtype)

        if uncondition_context is None:
            uncondition_context = torch.zeros_like(context)
            uncondition_context_mask = torch.zeros_like(context_mask)
            uncondition_context[:, 0] = null_embedding
            uncondition_context_mask[:, 0] = 1
        large_context = torch.cat([context, uncondition_context])
        large_context_mask = torch.cat([context_mask, uncondition_context_mask])

        if mask is not None:
            mask = mask.repeat(2, 1, 1, 1)
        if masked_latent is not None:
            masked_latent = masked_latent.repeat(2, 1, 1, 1)

        # A model with 9 input channels additionally receives the mask and the
        # masked latent concatenated along the channel axis.
        if model.in_layer.in_channels == 9:
            large_x = torch.cat([large_x, mask, masked_latent], dim=1)

        pred_large_noise = model(large_x, large_t * self.time_scale, large_context, large_context_mask.bool())
        cond_noise, uncond_noise = torch.chunk(pred_large_noise, 2)
        return (guidance_weight_text + 1.) * cond_noise - guidance_weight_text * uncond_noise

    def p_mean_variance(
            self, model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
            negative_context=None, negative_context_mask=None, mask=None, masked_latent=None
    ):
        """Return the mean and noise scale for the step ``t -> prev_t``.

        The guided noise prediction is turned into a clean-sample estimate,
        clipped by ``process_x_start``, and the noise is then re-derived from
        the clipped estimate so both stay consistent.
        """
        guided_noise = self.text_guidance(
            model, x, t, context, context_mask, null_embedding, guidance_weight_text,
            negative_context, negative_context_mask, mask, masked_latent
        )

        pred_x_start = self.process_x_start(self.get_x_start(x, t, guided_noise))
        pred_noise = self.get_noise(x, t, pred_x_start)
        pred_var = self.q_posterior_variance(t, prev_t, x.shape, eta)

        prev_alphas_cumprod = get_tensor_items(self.alphas_cumprod, prev_t, x.shape)
        pred_mean = torch.sqrt(prev_alphas_cumprod) * pred_x_start
        pred_mean += torch.sqrt(1. - prev_alphas_cumprod - pred_var ** 2) * pred_noise
        return pred_mean, pred_var

    # NOTE: the no_grad decorator is left disabled here, as in the original;
    # gradient tracking is whatever the caller has set up.
    def p_sample(
            self, model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
            negative_context=None, negative_context_mask=None, mask=None, masked_latent=None
    ):
        """Draw the sample at ``prev_t`` given ``x`` at ``t``.

        Fresh Gaussian noise is added everywhere except for batch elements
        whose ``prev_t`` is 0.
        """
        bs = x.shape[0]
        ndims = len(x.shape[1:])
        pred_mean, pred_var = self.p_mean_variance(
            model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta,
            negative_context=negative_context, negative_context_mask=negative_context_mask,
            mask=mask, masked_latent=masked_latent
        )
        noise = torch.randn_like(x)
        # Per-element switch, broadcastable over the non-batch dims.
        add_noise = (prev_t != 0).reshape(bs, *((1,) * ndims))
        return pred_mean + add_noise * pred_var * noise

    # NOTE: the no_grad decorator is left disabled here as well.
    def p_sample_loop(
            self, model, shape, times, device, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
            negative_context=None, negative_context_mask=None, mask=None, masked_latent=None, gan=False,
    ):
        """Run the whole sampling chain starting from Gaussian noise.

        Args:
            shape: shape of the tensor to generate; ``shape[0]`` is the batch.
            times: list of timesteps to visit, in visiting order; a final 0 is
                appended so the last step lands on timestep 0.
            gan: if true, each iteration re-noises the current image to the
                step and takes the model's one-shot clean-sample estimate
                instead of calling ``p_sample``.

        Returns:
            The final generated tensor.
        """
        img = torch.randn(*shape, device=device)
        schedule = times + [0, ]
        # Consecutive (current, next) timestep pairs.
        step_pairs = list(zip(schedule[:-1], schedule[1:]))

        for time, prev_time in tqdm(step_pairs):
            time = torch.tensor([time] * shape[0], device=device)
            if gan:
                x_t = self.q_sample(img, time)
                pred_noise = model(x_t, time.type(x_t.dtype), context, context_mask.bool())
                img = self.get_x_start(x_t, time, pred_noise)
                continue

            prev_time = torch.tensor([prev_time] * shape[0], device=device)
            img = self.p_sample(
                model, img, time, prev_time, context, context_mask, null_embedding, guidance_weight_text, eta,
                negative_context=negative_context, negative_context_mask=negative_context_mask,
                mask=mask, masked_latent=masked_latent
            )
        return img
195
+
196
+
197
def get_diffusion(conf):
    """Build a ``BaseDiffusion`` from a config object.

    ``conf.schedule_params`` is expanded into ``get_named_beta_schedule`` and
    ``conf.diffusion_params`` into the ``BaseDiffusion`` constructor.
    """
    return BaseDiffusion(
        get_named_beta_schedule(**conf.schedule_params),
        **conf.diffusion_params
    )
kandinsky3/model/.ipynb_checkpoints/unet-checkpoint.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn, einsum
3
+ from einops import rearrange
4
+
5
+ from .nn import Identity, Attention, SinusoidalPosEmb, ConditionalGroupNorm
6
+ from .utils import exist, set_default_item, set_default_layer
7
+ import torch.nn.functional as F
8
+
9
+
10
class Block(nn.Module):
    """Conditional norm -> SiLU -> optional upsample -> conv -> optional downsample.

    ``up_resolution`` selects the resampling: ``True`` enables the transposed
    convolution before the projection, ``False`` enables the strided
    convolution after it, and ``None`` enables neither. The disabled stage is
    whatever ``set_default_layer`` returns as its fallback (presumably an
    identity layer -- confirm in ``.utils``).
    """

    def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
        super().__init__()
        self.group_norm = ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
        self.activation = nn.SiLU()

        wants_upsample = exist(up_resolution) and up_resolution
        self.up_sample = set_default_layer(
            wants_upsample,
            nn.ConvTranspose2d, (in_channels, in_channels), {'kernel_size': 2, 'stride': 2}
        )

        # 1x1 kernels need no padding; other kernels get a padding of 1.
        padding = set_default_item(kernel_size == 1, 0, 1)
        self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)

        wants_downsample = exist(up_resolution) and not up_resolution
        self.down_sample = set_default_layer(
            wants_downsample,
            nn.Conv2d, (out_channels, out_channels), {'kernel_size': 2, 'stride': 2}
        )

    def forward(self, x, time_embed):
        """Apply the block to ``x`` conditioned on ``time_embed``."""
        hidden = self.activation(self.group_norm(x, time_embed))
        hidden = self.up_sample(hidden)
        hidden = self.projection(hidden)
        return self.down_sample(hidden)
34
+
35
+
36
class ResNetBlock(nn.Module):
    """Bottleneck residual block built from four ``Block`` stages (1x1, 3x3, 3x3, 1x1).

    The inner width is ``max(in_channels, out_channels) // compression_ratio``.
    ``up_resolutions`` holds one entry per stage and is forwarded to the
    matching ``Block``. If any entry is ``True`` the shortcut gets a
    transposed convolution; if any is ``False`` it gets a strided
    convolution, so the shortcut keeps the same resolution as the main path.
    """

    def __init__(
            # The default is an immutable tuple instead of the former ``4*[None]``
            # list, so no mutable object is shared between calls.
            self, in_channels, out_channels, time_embed_dim, norm_groups=32, compression_ratio=2,
            up_resolutions=(None,) * 4
    ):
        super().__init__()
        kernel_sizes = [1, 3, 3, 1]
        hidden_channel = max(in_channels, out_channels) // compression_ratio
        # (in, out) channel pair for each of the four stages.
        hidden_channels = (
            [(in_channels, hidden_channel)]
            + [(hidden_channel, hidden_channel)] * 2
            + [(hidden_channel, out_channels)]
        )
        self.resnet_blocks = nn.ModuleList([
            Block(in_channel, out_channel, time_embed_dim, kernel_size, norm_groups, up_resolution)
            for (in_channel, out_channel), kernel_size, up_resolution
            in zip(hidden_channels, kernel_sizes, up_resolutions)
        ])

        self.shortcut_up_sample = set_default_layer(
            True in up_resolutions,
            nn.ConvTranspose2d, (in_channels, in_channels), {'kernel_size': 2, 'stride': 2}
        )
        self.shortcut_projection = set_default_layer(
            in_channels != out_channels,
            nn.Conv2d, (in_channels, out_channels), {'kernel_size': 1}
        )
        self.shortcut_down_sample = set_default_layer(
            False in up_resolutions,
            nn.Conv2d, (out_channels, out_channels), {'kernel_size': 2, 'stride': 2}
        )

    def forward(self, x, time_embed):
        """Return ``shortcut(x) + stages(x)``; every stage receives ``time_embed``."""
        out = x
        for resnet_block in self.resnet_blocks:
            out = resnet_block(out, time_embed)

        x = self.shortcut_up_sample(x)
        x = self.shortcut_projection(x)
        x = self.shortcut_down_sample(x)
        return x + out
73
+
74
+
75
class AttentionPolling(nn.Module):
    """Pool a context sequence with attention and add the result to ``x``.

    The mean of ``context`` over dim 1 is passed to the attention layer as its
    first input and the full ``context`` as its second (presumably query and
    key/value respectively -- confirm against ``Attention`` in ``.nn``).
    """

    def __init__(self, num_channels, context_dim, head_dim=64):
        super().__init__()
        self.attention = Attention(context_dim, num_channels, context_dim, head_dim)

    def forward(self, x, context, context_mask=None):
        """Return ``x`` plus the pooled context with its length-1 dim squeezed."""
        summary = context.mean(dim=1, keepdim=True)
        pooled = self.attention(summary, context, context_mask)
        return x + pooled.squeeze(1)
84
+
85
+
86
class AttentionBlock(nn.Module):
    """Residual attention over spatial positions followed by a residual 1x1-conv MLP.

    When ``context`` is given to ``forward`` the feature map attends to it;
    otherwise it attends to itself.
    """

    def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
        super().__init__()
        self.in_norm = ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
        self.attention = Attention(num_channels, num_channels, context_dim or num_channels, head_dim)

        hidden_channels = expansion_ratio * num_channels
        self.out_norm = ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Conv2d(num_channels, hidden_channels, kernel_size=1, bias=False),
            nn.SiLU(),
            nn.Conv2d(hidden_channels, num_channels, kernel_size=1, bias=False),
        )

    def forward(self, x, time_embed, context=None, context_mask=None):
        """Apply the attention and feed-forward sub-blocks to a ``b c h w`` map."""
        height, width = x.shape[-2:]

        # Attention sub-block: flatten the spatial grid into a token sequence.
        tokens = self.in_norm(x, time_embed)
        tokens = rearrange(tokens, 'b c h w -> b (h w) c', h=height, w=width)
        context = set_default_item(exist(context), context, tokens)
        attended = self.attention(tokens, context, context_mask)
        attended = rearrange(attended, 'b (h w) c -> b c h w', h=height, w=width)
        x = x + attended

        # Feed-forward sub-block.
        x = x + self.feed_forward(self.out_norm(x, time_embed))
        return x
114
+
115
+
116
class DownSampleBlock(nn.Module):
    """One encoder level: optional self-attention, then ``num_blocks`` triples of
    (ResNetBlock, optional cross-attention, ResNetBlock).

    Only the final triple receives a resampling flag in its second
    ``ResNetBlock`` (in the third stage slot); the flag comes from
    ``set_default_item(down_sample, False)``.
    """

    def __init__(
            self, in_channels, out_channels, time_embed_dim, context_dim=None,
            num_blocks=3, groups=32, head_dim=64, expansion_ratio=4, compression_ratio=2,
            down_sample=True, self_attention=True
    ):
        super().__init__()
        self.self_attention_block = set_default_layer(
            self_attention,
            AttentionBlock,
            (in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
            layer_2=Identity
        )

        no_resampling = [None] * 4
        last_triple = [None, None, set_default_item(down_sample, False), None]
        up_resolutions = [no_resampling] * (num_blocks - 1) + [last_triple]
        # The first triple changes the channel count; the rest keep it.
        hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1)

        triples = []
        for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
            triples.append(nn.ModuleList([
                ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio),
                set_default_layer(
                    exist(context_dim),
                    AttentionBlock,
                    (out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
                    layer_2=Identity
                ),
                ResNetBlock(out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution),
            ]))
        self.resnet_attn_blocks = nn.ModuleList(triples)

    def forward(self, x, time_embed, context=None, context_mask=None, control_net_residual=None):
        """Run the level. ``control_net_residual`` is accepted but not used here."""
        x = self.self_attention_block(x, time_embed)
        for first_resnet, cross_attention, second_resnet in self.resnet_attn_blocks:
            x = first_resnet(x, time_embed)
            x = cross_attention(x, time_embed, context, context_mask)
            x = second_resnet(x, time_embed)
        return x
153
+
154
+
155
class UpSampleBlock(nn.Module):
    """One decoder level: ``num_blocks`` triples of (ResNetBlock, optional
    cross-attention, ResNetBlock), then optional self-attention.

    The first triple takes ``in_channels + cat_dim`` channels (the input
    concatenated with a skip connection) and is the only one that receives a
    resampling flag, in its first ``ResNetBlock`` (second stage slot).
    """

    def __init__(
            self, in_channels, cat_dim, out_channels, time_embed_dim, context_dim=None,
            num_blocks=3, groups=32, head_dim=64, expansion_ratio=4, compression_ratio=2,
            up_sample=True, self_attention=True
    ):
        super().__init__()
        first_triple = [None, set_default_item(up_sample, True), None, None]
        no_resampling = [None] * 4
        up_resolutions = [first_triple] + [no_resampling] * (num_blocks - 1)
        hidden_channels = (
            [(in_channels + cat_dim, in_channels)]
            + [(in_channels, in_channels)] * (num_blocks - 2)
            + [(in_channels, out_channels)]
        )

        triples = []
        for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
            triples.append(nn.ModuleList([
                ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution),
                set_default_layer(
                    exist(context_dim),
                    AttentionBlock,
                    (in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
                    layer_2=Identity
                ),
                ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio),
            ]))
        self.resnet_attn_blocks = nn.ModuleList(triples)

        self.self_attention_block = set_default_layer(
            self_attention,
            AttentionBlock,
            (out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
            layer_2=Identity
        )

    def forward(self, x, time_embed, context=None, context_mask=None):
        """Run the level on ``x`` (already concatenated with its skip tensor)."""
        for first_resnet, cross_attention, second_resnet in self.resnet_attn_blocks:
            x = first_resnet(x, time_embed)
            x = cross_attention(x, time_embed, context, context_mask)
            x = second_resnet(x, time_embed)
        return self.self_attention_block(x, time_embed)
192
+
193
class ControlNetModel(nn.Module):
    """Encoder-only network that returns the per-level hidden states.

    NOTE(review): this module defines a second class named ``ControlNetModel``
    further down; once the module has finished importing, that later
    definition is the one bound to the name, so this one is shadowed.

    ``forward`` returns the output of every down-sampling level except the
    last one, in level order.
    """

    def __init__(self,
                 model_channels,
                 init_channels=None,
                 num_channels=3,
                 out_channels=4,
                 time_embed_dim=None,
                 context_dim=None,
                 groups=32,
                 head_dim=64,
                 expansion_ratio=4,
                 compression_ratio=2,
                 dim_mult=(1, 2, 4, 8),
                 num_blocks=(3, 3, 3, 3),
                 add_cross_attention=(False, True, True, True),
                 add_self_attention=(False, True, True, True)
                 ):
        super().__init__()
        init_channels = init_channels or model_channels
        self.to_time_embed = nn.Sequential(
            SinusoidalPosEmb(init_channels),
            nn.Linear(init_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        self.feature_pooling = AttentionPolling(time_embed_dim, context_dim, head_dim)

        self.in_layer = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)

        # Channel widths per level and the (in, out) pair of each level.
        hidden_dims = [init_channels, *map(lambda mult: model_channels * mult, dim_mult)]
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        # Cross-attention width per level (falls back to set_default_item's
        # default where cross-attention is disabled).
        text_dims = [set_default_item(is_exist, context_dim) for is_exist in add_cross_attention]
        layer_params = [num_blocks, text_dims, add_self_attention]

        # The unused locals ``rev_layer_params`` and ``cat_dims`` were removed:
        # this class has no decoder, so neither was ever read.
        self.num_levels = len(in_out_dims)
        self.down_samples = nn.ModuleList([])
        for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(
                zip(in_out_dims, *layer_params)):
            # Every level except the last reduces the resolution.
            down_sample = level != (self.num_levels - 1)
            self.down_samples.append(
                DownSampleBlock(
                    in_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim, expansion_ratio,
                    compression_ratio, down_sample, self_attention
                )
            )

    def forward(self, x, time, context=None, context_mask=None):
        """Return the list of hidden states of all levels but the last."""
        time_embed = self.to_time_embed(time)
        if exist(context):
            time_embed = self.feature_pooling(time_embed, context, context_mask)

        hidden_states = []
        x = self.in_layer(x)
        for level, down_sample in enumerate(self.down_samples):
            x = down_sample(x, time_embed, context, context_mask)
            if level != self.num_levels - 1:
                hidden_states.append(x)
        return hidden_states
253
+
254
class UNet(nn.Module):
    """U-shaped network: encoder levels, decoder levels with skip
    concatenation, and a GroupNorm/SiLU/conv output head.

    The timestep is embedded by ``to_time_embed``; when a context is passed to
    ``forward`` the embedding is additionally updated by ``feature_pooling``.
    Extra positional and keyword constructor arguments are accepted and
    ignored.
    """

    def __init__(self,
                 model_channels,
                 init_channels=None,
                 num_channels=3,
                 out_channels=4,
                 time_embed_dim=None,
                 context_dim=None,
                 groups=32,
                 head_dim=64,
                 expansion_ratio=4,
                 compression_ratio=2,
                 dim_mult=(1, 2, 4, 8),
                 num_blocks=(3, 3, 3, 3),
                 add_cross_attention=(False, True, True, True),
                 add_self_attention=(False, True, True, True),
                 *args,
                 **kwargs,
                 ):
        super().__init__()
        init_channels = init_channels or model_channels
        self.to_time_embed = nn.Sequential(
            SinusoidalPosEmb(init_channels),
            nn.Linear(init_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        self.feature_pooling = AttentionPolling(time_embed_dim, context_dim, head_dim)

        self.in_layer = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)

        # Channel widths per level and the (in, out) pair of each level.
        hidden_dims = [init_channels] + [model_channels * mult for mult in dim_mult]
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        text_dims = [set_default_item(is_exist, context_dim) for is_exist in add_cross_attention]
        layer_params = [num_blocks, text_dims, add_self_attention]

        self.num_levels = len(in_out_dims)
        last_level = self.num_levels - 1

        # Encoder. cat_dims records, per level, the width of the skip tensor
        # the matching decoder level will concatenate (0 for the last level,
        # which contributes no skip).
        cat_dims = []
        self.down_samples = nn.ModuleList([])
        encoder_specs = zip(in_out_dims, *layer_params)
        for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(encoder_specs):
            down_sample = level != last_level
            cat_dims.append(set_default_item(level != last_level, out_dim, 0))
            self.down_samples.append(
                DownSampleBlock(
                    in_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim, expansion_ratio,
                    compression_ratio, down_sample, self_attention
                )
            )

        # Decoder: same per-level settings, walked in reverse order.
        rev_layer_params = [reversed(params) for params in layer_params]
        self.up_samples = nn.ModuleList([])
        decoder_specs = zip(reversed(in_out_dims), *rev_layer_params)
        for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate(decoder_specs):
            up_sample = level != 0
            self.up_samples.append(
                UpSampleBlock(
                    in_dim, cat_dims.pop(), out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim,
                    expansion_ratio, compression_ratio, up_sample, self_attention
                )
            )

        self.out_layer = nn.Sequential(
            nn.GroupNorm(groups, init_channels),
            nn.SiLU(),
            nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1)
        )

        self.control_net = None

    def forward(self, x, time, context=None, context_mask=None, control_net_residual=None):
        """Run the network and return the output of ``out_layer``.

        ``control_net_residual`` is only forwarded to the encoder levels.
        """
        time_embed = self.to_time_embed(time)
        if exist(context):
            time_embed = self.feature_pooling(time_embed, context, context_mask)

        skips = []
        x = self.in_layer(x)
        for level, down_sample in enumerate(self.down_samples):
            x = down_sample(x, time_embed, context, context_mask, control_net_residual)
            if level != self.num_levels - 1:
                skips.append(x)

        for level, up_sample in enumerate(self.up_samples):
            # The first decoder level has no skip; later ones take the most
            # recently stored encoder output.
            if level != 0:
                x = torch.cat([x, skips.pop()], dim=1)
            x = up_sample(x, time_embed, context, context_mask)

        return self.out_layer(x)
340
+
341
+
342
class ControlNetModel(nn.Module):
    """Encoder-only network that returns the per-level hidden states.

    NOTE(review): an earlier class with the same name exists in this module;
    this later definition is the one bound to ``ControlNetModel`` after
    import. Extra positional and keyword constructor arguments are accepted
    and ignored.

    ``forward`` returns the output of every down-sampling level except the
    last one, in level order.
    """

    def __init__(self,
                 model_channels,
                 init_channels=None,
                 num_channels=3,
                 out_channels=4,
                 time_embed_dim=None,
                 context_dim=None,
                 groups=32,
                 head_dim=64,
                 expansion_ratio=4,
                 compression_ratio=2,
                 dim_mult=(1, 2, 4, 8),
                 num_blocks=(3, 3, 3, 3),
                 add_cross_attention=(False, True, True, True),
                 add_self_attention=(False, True, True, True),
                 *args,
                 **kwargs,
                 ):
        super().__init__()
        init_channels = init_channels or model_channels
        self.to_time_embed = nn.Sequential(
            SinusoidalPosEmb(init_channels),
            nn.Linear(init_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        self.feature_pooling = AttentionPolling(time_embed_dim, context_dim, head_dim)

        self.in_layer = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)

        # Channel widths per level and the (in, out) pair of each level.
        hidden_dims = [init_channels, *map(lambda mult: model_channels * mult, dim_mult)]
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        # Cross-attention width per level (falls back to set_default_item's
        # default where cross-attention is disabled).
        text_dims = [set_default_item(is_exist, context_dim) for is_exist in add_cross_attention]
        layer_params = [num_blocks, text_dims, add_self_attention]

        # The unused locals ``rev_layer_params`` and ``cat_dims`` were removed:
        # this class has no decoder, so neither was ever read.
        self.num_levels = len(in_out_dims)
        self.down_samples = nn.ModuleList([])
        for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(
                zip(in_out_dims, *layer_params)):
            # Every level except the last reduces the resolution.
            down_sample = level != (self.num_levels - 1)
            self.down_samples.append(
                DownSampleBlock(
                    in_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim, expansion_ratio,
                    compression_ratio, down_sample, self_attention
                )
            )

    def forward(self, x, time, context=None, context_mask=None):
        """Return the list of hidden states of all levels but the last."""
        time_embed = self.to_time_embed(time)
        if exist(context):
            time_embed = self.feature_pooling(time_embed, context, context_mask)

        hidden_states = []
        x = self.in_layer(x)
        for level, down_sample in enumerate(self.down_samples):
            x = down_sample(x, time_embed, context, context_mask)
            if level != self.num_levels - 1:
                hidden_states.append(x)
        return hidden_states
404
+
405
class ControlUNet(nn.Module):
    """U-shaped network whose encoder outputs are summed with the hidden states
    of an attached ``ControlNetModel``.

    The control network is built with the same hyper-parameters as this
    network, except that it takes ``control_net_channels`` input channels.
    Extra positional and keyword constructor arguments are accepted and
    ignored.
    """

    def __init__(self,
                 model_channels,
                 init_channels=None,
                 num_channels=3,
                 out_channels=4,
                 time_embed_dim=None,
                 context_dim=None,
                 groups=32,
                 head_dim=64,
                 expansion_ratio=4,
                 compression_ratio=2,
                 dim_mult=(1, 2, 4, 8),
                 num_blocks=(3, 3, 3, 3),
                 add_cross_attention=(False, True, True, True),
                 add_self_attention=(False, True, True, True),
                 control_net_channels=5,
                 *args,
                 **kwargs,
                 ):
        super().__init__()
        init_channels = init_channels or model_channels
        self.to_time_embed = nn.Sequential(
            SinusoidalPosEmb(init_channels),
            nn.Linear(init_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        self.feature_pooling = AttentionPolling(time_embed_dim, context_dim, head_dim)

        self.in_layer = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)

        # Channel widths per level and the (in, out) pair of each level.
        hidden_dims = [init_channels] + [model_channels * mult for mult in dim_mult]
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        text_dims = [set_default_item(is_exist, context_dim) for is_exist in add_cross_attention]
        layer_params = [num_blocks, text_dims, add_self_attention]

        self.num_levels = len(in_out_dims)
        last_level = self.num_levels - 1

        # Encoder. cat_dims records the skip width each decoder level will
        # concatenate (0 for the last level, which contributes no skip).
        cat_dims = []
        self.down_samples = nn.ModuleList([])
        encoder_specs = zip(in_out_dims, *layer_params)
        for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(encoder_specs):
            down_sample = level != last_level
            cat_dims.append(set_default_item(level != last_level, out_dim, 0))
            self.down_samples.append(
                DownSampleBlock(
                    in_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim, expansion_ratio,
                    compression_ratio, down_sample, self_attention
                )
            )

        # Decoder: same per-level settings, walked in reverse order.
        rev_layer_params = [reversed(params) for params in layer_params]
        self.up_samples = nn.ModuleList([])
        decoder_specs = zip(reversed(in_out_dims), *rev_layer_params)
        for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate(decoder_specs):
            up_sample = level != 0
            self.up_samples.append(
                UpSampleBlock(
                    in_dim, cat_dims.pop(), out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim,
                    expansion_ratio, compression_ratio, up_sample, self_attention
                )
            )

        self.out_layer = nn.Sequential(
            nn.GroupNorm(groups, init_channels),
            nn.SiLU(),
            nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1)
        )

        self.control_net = ControlNetModel(
            model_channels,
            init_channels,
            control_net_channels,
            out_channels,
            time_embed_dim,
            context_dim,
            groups,
            head_dim,
            expansion_ratio,
            compression_ratio,
            dim_mult,
            num_blocks,
            add_cross_attention,
            add_self_attention
        )

    def forward(self, x, time, context=None, context_mask=None, control_net_data=None):
        """Run the network with control-net conditioning.

        ``control_net_data`` is passed to the control network unconditionally,
        so in practice it must be supplied even though it defaults to None.
        """
        time_embed = self.to_time_embed(time)
        if exist(context):
            time_embed = self.feature_pooling(time_embed, context, context_mask)

        control_net_hiddens = self.control_net(control_net_data, time, context, context_mask)

        skips = []
        x = self.in_layer(x)
        for level, down_sample in enumerate(self.down_samples):
            x = down_sample(x, time_embed, context, context_mask)
            if level != self.num_levels - 1:
                # Consume the control hidden states front to back; popping
                # also drops the list's reference to each one once used.
                x += control_net_hiddens.pop(0)
                skips.append(x)

        for level, up_sample in enumerate(self.up_samples):
            if level != 0:
                x = torch.cat([x, skips.pop()], dim=1)
            x = up_sample(x, time_embed, context, context_mask)

        return self.out_layer(x)
507
+
508
+
509
def get_control_unet(conf):
    """Build a ``ControlUNet`` from a mapping of constructor keyword arguments."""
    return ControlUNet(**conf)
512
+
513
+
514
def get_unet(conf):
    """Build a ``UNet`` from a mapping of constructor keyword arguments."""
    return UNet(**conf)
kandinsky3/model/__init__.py ADDED
File without changes
kandinsky3/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (212 Bytes). View file
 
kandinsky3/model/__pycache__/diffusion.cpython-310.pyc ADDED
Binary file (6.4 kB). View file
 
kandinsky3/model/__pycache__/nn.cpython-310.pyc ADDED
Binary file (3.53 kB). View file
 
kandinsky3/model/__pycache__/unet.cpython-310.pyc ADDED
Binary file (13.9 kB). View file
 
kandinsky3/model/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.85 kB). View file
 
kandinsky3/model/diffusion.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from tqdm import tqdm
6
+
7
+ from .utils import get_tensor_items
8
+
9
+
10
def get_named_beta_schedule(schedule_name, timesteps):
    """Return a 1-D float32 tensor of betas for a named noise schedule.

    Args:
        schedule_name: ``"linear"`` or ``"cosine"``.
        timesteps: number of diffusion steps (length of the returned tensor).

    Returns:
        A ``torch.float32`` tensor with ``timesteps`` entries.

    Raises:
        ValueError: if ``schedule_name`` is not one of the supported names
            (previously the function silently returned ``None``).
    """
    if schedule_name == "linear":
        # The end points are rescaled by 1000 / timesteps.
        scale = 1000 / timesteps
        return torch.linspace(
            scale * 0.0001, scale * 0.02, timesteps, dtype=torch.float32
        )

    if schedule_name == "cosine":
        def alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        # Each beta is 1 - alpha_bar(t_next) / alpha_bar(t), capped at 0.999.
        betas = [
            min(1 - alpha_bar((i + 1) / timesteps) / alpha_bar(i / timesteps), 0.999)
            for i in range(timesteps)
        ]
        return torch.tensor(betas, dtype=torch.float32)

    raise ValueError(f"unknown beta schedule: {schedule_name!r}")
26
+
27
+
28
class BaseDiffusion:
    """Diffusion schedule tables plus forward-noising and guided sampling helpers.

    All per-timestep coefficient tables are derived once from ``betas`` in the
    constructor; the methods look entries up with timestep tensors through
    ``get_tensor_items``.
    """

    def __init__(self, betas, percentile=None, gen_noise=torch.randn_like):
        """Precompute the schedule tables.

        Args:
            betas: 1-D tensor of per-step betas; its length is the number of timesteps.
            percentile: optional quantile handed to ``torch.quantile`` in
                ``process_x_start``; ``None`` selects plain clipping to [-1, 1].
            gen_noise: callable returning noise shaped like its argument
                (used by ``q_sample`` when no noise is supplied).
        """
        self.betas = betas
        self.num_timesteps = betas.shape[0]

        alphas = 1. - betas
        cumprod = torch.cumprod(alphas, dim=0)
        cumprod_prev = torch.cat([torch.ones(1, dtype=betas.dtype), cumprod[:-1]])
        self.alphas_cumprod = cumprod
        self.alphas_cumprod_prev = cumprod_prev

        # calculate q(x_t | x_{t-1})
        self.sqrt_alphas_cumprod = torch.sqrt(cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - cumprod)

        # calculate q(x_{t-1} | x_t, x_0)
        self.posterior_mean_coef_1 = torch.sqrt(cumprod_prev) * betas / (1. - cumprod)
        self.posterior_mean_coef_2 = torch.sqrt(alphas) * (1. - cumprod_prev) / (1. - cumprod)
        self.posterior_variance = betas * (1. - cumprod_prev) / (1. - cumprod)
        # Entry 0 of the variance table is replaced by entry 1 before the log is taken.
        variance = self.posterior_variance
        self.posterior_log_variance = torch.log(
            torch.cat([variance[1].unsqueeze(0), variance[1:]])
        )

        self.percentile = percentile
        self.time_scale = 1000 // self.num_timesteps
        self.gen_noise = gen_noise
        self.jump_length = 3

    def process_x_start(self, x_start):
        """Clip a predicted ``x_start``.

        Without a percentile the tensor is clipped to [-1, 1]. With one, each
        sample is clipped to +/- its per-sample quantile of absolute values
        (floored at 1) and divided by that quantile.
        """
        batch = x_start.shape[0]
        trailing = len(x_start.shape[1:])
        if self.percentile is None:
            return torch.clip(x_start, -1., 1.)

        flat_abs = rearrange(x_start, 'b ... -> b (...)').abs()
        quantile = torch.quantile(flat_abs, self.percentile, dim=-1)
        quantile = torch.clip(quantile, min=1.)
        quantile = quantile.reshape(batch, *((1,) * trailing))
        return torch.clip(x_start, -quantile, quantile) / quantile

    def get_x_start(self, x, t, noise):
        """Solve ``x = sqrt(a) * x_start + sqrt(1 - a) * noise`` for ``x_start`` at step ``t``."""
        noise_coef = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
        signal_coef = get_tensor_items(self.sqrt_alphas_cumprod, t, noise.shape)
        return (x - noise_coef * noise) / signal_coef

    def get_noise(self, x, t, x_start):
        """Solve ``x = sqrt(a) * x_start + sqrt(1 - a) * noise`` for ``noise`` at step ``t``."""
        noise_coef = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        signal_coef = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
        return (x - signal_coef * x_start) / noise_coef

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to step ``t``; noise is drawn with ``self.gen_noise`` when omitted."""
        if noise is None:
            noise = self.gen_noise(x_start)
        signal_coef = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_coef = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
        return signal_coef * x_start + noise_coef * noise

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Return ``(mean, variance, log_variance)`` read from the posterior tables at step ``t``."""
        coef_start = get_tensor_items(self.posterior_mean_coef_1, t, x_start.shape)
        coef_current = get_tensor_items(self.posterior_mean_coef_2, t, x_t.shape)
        mean = coef_start * x_start + coef_current * x_t

        variance = get_tensor_items(self.posterior_variance, t, x_start.shape)
        log_variance = get_tensor_items(self.posterior_log_variance, t, x_start.shape)
        return mean, variance, log_variance

    def q_posterior_variance(self, t, prev_t, shape, eta=1., ):
        """Return the square root of the eta-scaled variance term between ``prev_t`` and ``t``.

        Despite the name, the value returned is already passed through
        ``torch.sqrt``; ``p_mean_variance`` squares it again where needed.
        """
        current = get_tensor_items(self.alphas_cumprod, t, shape)
        previous = get_tensor_items(self.alphas_cumprod, prev_t, shape)
        return torch.sqrt(
            eta * (1. - current / previous) * (1. - previous) / (1. - current)
        )

    def text_guidance(
            self, model, x, t, context, context_mask, null_embedding, guidance_weight_text,
            uncondition_context=None, uncondition_context_mask=None, mask=None, masked_latent=None
    ):
        """Run the model on a doubled batch and combine conditional and unconditional noise.

        The first half of the batch uses ``context``; the second half uses the
        unconditional context. When no unconditional context is given, it is a
        zero tensor whose position 0 holds ``null_embedding`` (mask set to 1
        at that position only).
        """
        large_x = x.repeat(2, 1, 1, 1)
        large_t = t.repeat(2).to(x.dtype)

        if uncondition_context is None:
            uncondition_context = torch.zeros_like(context)
            uncondition_context_mask = torch.zeros_like(context_mask)
            uncondition_context[:, 0] = null_embedding
            uncondition_context_mask[:, 0] = 1
        large_context = torch.cat([context, uncondition_context])
        large_context_mask = torch.cat([context_mask, uncondition_context_mask])

        if mask is not None:
            mask = mask.repeat(2, 1, 1, 1)
        if masked_latent is not None:
            masked_latent = masked_latent.repeat(2, 1, 1, 1)

        # A model whose input layer takes 9 channels also receives mask and masked latent.
        if model.in_layer.in_channels == 9:
            large_x = torch.cat([large_x, mask, masked_latent], dim=1)

        both = model(large_x, large_t * self.time_scale, large_context, large_context_mask.bool())
        cond_noise, uncond_noise = torch.chunk(both, 2)
        return (guidance_weight_text + 1.) * cond_noise - guidance_weight_text * uncond_noise

    def p_mean_variance(
            self, model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
            negative_context=None, negative_context_mask=None, mask=None, masked_latent=None
    ):
        """Return ``(pred_mean, pred_var)`` for the step from ``t`` to ``prev_t``.

        The guided noise prediction is turned into ``x_start``, clipped by
        ``process_x_start``, and the noise is then recomputed from the clipped value.
        """
        guided_noise = self.text_guidance(
            model, x, t, context, context_mask, null_embedding, guidance_weight_text,
            negative_context, negative_context_mask, mask, masked_latent
        )

        pred_x_start = self.process_x_start(self.get_x_start(x, t, guided_noise))
        pred_noise = self.get_noise(x, t, pred_x_start)
        pred_var = self.q_posterior_variance(t, prev_t, x.shape, eta)

        prev_cumprod = get_tensor_items(self.alphas_cumprod, prev_t, x.shape)
        pred_mean = torch.sqrt(prev_cumprod) * pred_x_start
        pred_mean += torch.sqrt(1. - prev_cumprod - pred_var ** 2) * pred_noise
        return pred_mean, pred_var

    # @torch.no_grad()
    def p_sample(
            self, model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
            negative_context=None, negative_context_mask=None, mask=None, masked_latent=None
    ):
        """Draw one sample for step ``prev_t``; no noise is added where ``prev_t == 0``."""
        batch = x.shape[0]
        trailing = len(x.shape[1:])
        pred_mean, pred_var = self.p_mean_variance(
            model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta,
            negative_context=negative_context, negative_context_mask=negative_context_mask,
            mask=mask, masked_latent=masked_latent
        )
        noise = torch.randn_like(x)
        # Boolean factor that removes the noise term for entries whose prev_t is 0.
        add_noise = (prev_t != 0).reshape(batch, *((1,) * trailing))
        return pred_mean + add_noise * pred_var * noise

    # @torch.no_grad()
    def p_sample_loop(
            self, model, shape, times, device, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
            negative_context=None, negative_context_mask=None, mask=None, masked_latent=None, gan=False,
    ):
        """Iterate over consecutive pairs of ``times`` (with a trailing 0 appended) and return the result.

        With ``gan=True`` each step re-noises the current image with
        ``q_sample``, calls the model directly and takes ``get_x_start``;
        otherwise each step is a ``p_sample`` call.
        """
        img = torch.randn(*shape, device=device)
        schedule = times + [0, ]
        steps = list(zip(schedule[:-1], schedule[1:]))

        for time, prev_time in tqdm(steps):
            time = torch.tensor([time] * shape[0], device=device)
            if gan:
                x_t = self.q_sample(img, time)
                pred_noise = model(x_t, time.type(x_t.dtype), context, context_mask.bool())
                img = self.get_x_start(x_t, time, pred_noise)
            else:
                prev_time = torch.tensor([prev_time] * shape[0], device=device)
                img = self.p_sample(
                    model, img, time, prev_time, context, context_mask, null_embedding, guidance_weight_text, eta,
                    negative_context=negative_context, negative_context_mask=negative_context_mask,
                    mask=mask, masked_latent=masked_latent
                )
        return img
195
+
196
+
197
def get_diffusion(conf):
    """Build a ``BaseDiffusion`` from a config object.

    ``conf.schedule_params`` is unpacked into ``get_named_beta_schedule`` and
    ``conf.diffusion_params`` into the ``BaseDiffusion`` constructor.
    """
    schedule = get_named_beta_schedule(**conf.schedule_params)
    return BaseDiffusion(schedule, **conf.diffusion_params)
kandinsky3/model/nn.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn, einsum
5
+ from einops import rearrange, repeat
6
+
7
+ from .utils import exist
8
+
9
+
10
class Identity(nn.Module):
    """Pass-through module.

    The constructor and ``forward`` accept and ignore any extra positional or
    keyword arguments; ``forward`` returns its first argument unchanged.
    """

    def __init__(self, *args, **kwargs):
        # Extra arguments are deliberately discarded.
        super().__init__()

    @staticmethod
    def forward(x, *args, **kwargs):
        """Return ``x`` as-is."""
        return x
17
+
18
+
19
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of positions.

    ``forward`` returns the sines followed by the cosines of the input scaled
    by ``dim // 2`` exponentially spaced frequencies, concatenated on the last axis.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed ``x`` (rearranged with the pattern ``'i -> i 1'``) into ``2 * (dim // 2)`` features."""
        half_dim = self.dim // 2
        log_step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device, dtype=x.dtype) * -log_step)
        angles = rearrange(x, 'i -> i 1') * rearrange(freqs, 'j -> 1 j')
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
31
+
32
+
33
class ConditionalGroupNorm(nn.Module):
    """Group normalisation whose scale and shift are predicted from a context vector.

    The normalisation itself has no affine parameters. A SiLU + Linear head
    maps the context to ``2 * normalized_shape`` values that are split into
    scale and shift. The linear layer starts at zero, so the module initially
    reduces to the plain group norm.
    """

    def __init__(self, groups, normalized_shape, context_dim):
        super().__init__()
        self.norm = nn.GroupNorm(groups, normalized_shape, affine=False)
        projection = nn.Linear(context_dim, 2 * normalized_shape)
        # Zero initialisation: scale + 1 == 1 and shift == 0 at the start.
        projection.weight.data.zero_()
        projection.bias.data.zero_()
        self.context_mlp = nn.Sequential(nn.SiLU(), projection)

    def forward(self, x, context):
        """Normalise ``x`` and modulate it with scale/shift computed from ``context``."""
        modulation = self.context_mlp(context)
        # Append one singleton axis per trailing dimension of x after the channel axis.
        ndims = ' 1' * len(x.shape[2:])
        modulation = rearrange(modulation, f'b c -> b c{ndims}')

        scale, shift = modulation.chunk(2, dim=1)
        return self.norm(x) * (scale + 1.) + shift
53
+
54
+
55
class Attention(nn.Module):
    """Multi-head attention with queries from ``x`` and keys/values from ``context``.

    All projections are bias-free linear layers. ``out_channels`` must be a
    multiple of ``head_dim``.
    """

    def __init__(self, in_channels, out_channels, context_dim, head_dim=64):
        super().__init__()
        assert out_channels % head_dim == 0
        self.num_heads = out_channels // head_dim
        self.scale = head_dim ** -0.5

        self.to_query = nn.Linear(in_channels, out_channels, bias=False)
        self.to_key = nn.Linear(context_dim, out_channels, bias=False)
        self.to_value = nn.Linear(context_dim, out_channels, bias=False)

        self.output_layer = nn.Linear(out_channels, out_channels, bias=False)

    def forward(self, x, context, context_mask=None):
        """Attend from ``x`` over ``context``.

        When ``context_mask`` is given, positions where it is False are filled
        with the most negative finite value of the score dtype before the softmax.
        """
        heads = self.num_heads
        query = rearrange(self.to_query(x), 'b n (h d) -> b h n d', h=heads)
        key = rearrange(self.to_key(context), 'b n (h d) -> b h n d', h=heads)
        value = rearrange(self.to_value(context), 'b n (h d) -> b h n d', h=heads)

        scores = einsum('b h i d, b h j d -> b h i j', query, key) * self.scale
        if exist(context_mask):
            fill_value = -torch.finfo(scores.dtype).max
            keep = rearrange(context_mask, 'b j -> b 1 1 j')
            scores = scores.masked_fill(~keep, fill_value)
        weights = scores.softmax(dim=-1)

        mixed = einsum('b h i j, b h j d -> b h i d', weights, value)
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.output_layer(merged)
kandinsky3/model/unet.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn, einsum
3
+ from einops import rearrange
4
+
5
+ from .nn import Identity, Attention, SinusoidalPosEmb, ConditionalGroupNorm
6
+ from .utils import exist, set_default_item, set_default_layer
7
+ import torch.nn.functional as F
8
+
9
+
10
class Block(nn.Module):
    """Conditional norm -> SiLU -> optional up-sample -> conv -> optional down-sample.

    ``up_resolution`` selects the resampling: ``True`` adds a stride-2
    transposed convolution before the projection, ``False`` adds a stride-2
    convolution after it, and ``None`` adds neither (the default layer from
    ``set_default_layer`` is used in both slots).
    """

    def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
        super().__init__()
        resample_kwargs = {'kernel_size': 2, 'stride': 2}
        wants_up = exist(up_resolution) and up_resolution
        wants_down = exist(up_resolution) and not up_resolution

        self.group_norm = ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
        self.activation = nn.SiLU()
        self.up_sample = set_default_layer(
            wants_up, nn.ConvTranspose2d, (in_channels, in_channels), dict(resample_kwargs)
        )
        # A 1x1 convolution needs no padding; every other kernel size gets padding 1.
        padding = 0 if kernel_size == 1 else 1
        self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.down_sample = set_default_layer(
            wants_down, nn.Conv2d, (out_channels, out_channels), dict(resample_kwargs)
        )

    def forward(self, x, time_embed):
        """Apply the block to ``x`` conditioned on ``time_embed``."""
        hidden = self.activation(self.group_norm(x, time_embed))
        hidden = self.up_sample(hidden)
        hidden = self.projection(hidden)
        return self.down_sample(hidden)
34
+
35
+
36
class ResNetBlock(nn.Module):
    """Residual stack of four ``Block`` layers (kernel sizes 1, 3, 3, 1) with a learned shortcut.

    The inner width is ``max(in_channels, out_channels) // compression_ratio``.
    The shortcut path mirrors any resampling done by the main path: a stride-2
    transposed convolution if any entry of ``up_resolutions`` is ``True``, a
    1x1 projection if the channel counts differ, and a stride-2 convolution if
    any entry is ``False``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        time_embed_dim: width of the conditioning vector passed to each ``Block``.
        norm_groups: group count for the conditional group norms.
        compression_ratio: divisor applied to the inner width.
        up_resolutions: per-``Block`` resampling flags (``True`` / ``False`` /
            ``None``). The default is an immutable tuple rather than a shared
            list, so the default object can never be mutated between instances.
    """

    def __init__(
            self, in_channels, out_channels, time_embed_dim, norm_groups=32, compression_ratio=2,
            up_resolutions=(None, None, None, None)
    ):
        super().__init__()
        kernel_sizes = [1, 3, 3, 1]
        hidden_channel = max(in_channels, out_channels) // compression_ratio
        hidden_channels = (
            [(in_channels, hidden_channel)]
            + [(hidden_channel, hidden_channel)] * 2
            + [(hidden_channel, out_channels)]
        )
        self.resnet_blocks = nn.ModuleList([
            Block(in_channel, out_channel, time_embed_dim, kernel_size, norm_groups, up_resolution)
            for (in_channel, out_channel), kernel_size, up_resolution
            in zip(hidden_channels, kernel_sizes, up_resolutions)
        ])

        self.shortcut_up_sample = set_default_layer(
            True in up_resolutions,
            nn.ConvTranspose2d, (in_channels, in_channels), {'kernel_size': 2, 'stride': 2}
        )
        self.shortcut_projection = set_default_layer(
            in_channels != out_channels,
            nn.Conv2d, (in_channels, out_channels), {'kernel_size': 1}
        )
        self.shortcut_down_sample = set_default_layer(
            False in up_resolutions,
            nn.Conv2d, (out_channels, out_channels), {'kernel_size': 2, 'stride': 2}
        )

    def forward(self, x, time_embed):
        """Return shortcut(x) + main_path(x), both conditioned on ``time_embed``."""
        out = x
        for resnet_block in self.resnet_blocks:
            out = resnet_block(out, time_embed)

        x = self.shortcut_up_sample(x)
        x = self.shortcut_projection(x)
        x = self.shortcut_down_sample(x)
        return x + out
73
+
74
+
75
class AttentionPolling(nn.Module):
    """Pool a context sequence with attention and add the result to ``x``.

    The query is the mean of ``context`` over dimension 1 (kept as a length-1
    sequence); keys and values are the full ``context``.
    """

    def __init__(self, num_channels, context_dim, head_dim=64):
        super().__init__()
        self.attention = Attention(context_dim, num_channels, context_dim, head_dim)

    def forward(self, x, context, context_mask=None):
        """Return ``x`` plus the attention-pooled context with dimension 1 squeezed out."""
        mean_query = context.mean(dim=1, keepdim=True)
        pooled = self.attention(mean_query, context, context_mask)
        return x + pooled.squeeze(1)
84
+
85
+
86
class AttentionBlock(nn.Module):
    """Residual attention followed by a residual 1x1-convolution feed-forward.

    When ``context_dim`` is falsy the attention keys/values have the same
    width as the feature map. In ``forward``, a missing ``context`` makes the
    block attend over its own flattened feature map.
    """

    def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
        super().__init__()
        self.in_norm = ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
        self.attention = Attention(num_channels, num_channels, context_dim or num_channels, head_dim)

        widened = expansion_ratio * num_channels
        self.out_norm = ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Conv2d(num_channels, widened, kernel_size=1, bias=False),
            nn.SiLU(),
            nn.Conv2d(widened, num_channels, kernel_size=1, bias=False),
        )

    def forward(self, x, time_embed, context=None, context_mask=None):
        """Apply attention and feed-forward, each added back onto ``x``."""
        height, width = x.shape[-2:]

        tokens = self.in_norm(x, time_embed)
        tokens = rearrange(tokens, 'b c h w -> b (h w) c', h=height, w=width)
        # Without an external context the normalised tokens act as keys and values.
        if not exist(context):
            context = tokens
        attended = self.attention(tokens, context, context_mask)
        attended = rearrange(attended, 'b (h w) c -> b c h w', h=height, w=width)
        x = x + attended

        x = x + self.feed_forward(self.out_norm(x, time_embed))
        return x
114
+
115
+
116
class DownSampleBlock(nn.Module):
    """Encoder stage: optional self-attention, then ``num_blocks`` (ResNet, attention, ResNet) triples.

    Only the last triple may down-sample: when ``down_sample`` is true, its
    second ResNet block gets a ``False`` resampling flag in the third slot.
    The middle layer of each triple is an ``AttentionBlock`` when
    ``context_dim`` is given and ``Identity`` otherwise.
    """

    def __init__(
            self, in_channels, out_channels, time_embed_dim, context_dim=None,
            num_blocks=3, groups=32, head_dim=64, expansion_ratio=4, compression_ratio=2,
            down_sample=True, self_attention=True
    ):
        super().__init__()
        self.self_attention_block = set_default_layer(
            self_attention,
            AttentionBlock,
            (in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
            layer_2=Identity
        )

        # ``False`` requests a strided down-sampling convolution; ``None`` requests nothing.
        last_flag = False if down_sample else None
        up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, last_flag, None]]
        hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1)

        triples = []
        for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
            triples.append(nn.ModuleList([
                ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio),
                set_default_layer(
                    exist(context_dim),
                    AttentionBlock,
                    (out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
                    layer_2=Identity
                ),
                ResNetBlock(out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution),
            ]))
        self.resnet_attn_blocks = nn.ModuleList(triples)

    def forward(self, x, time_embed, context=None, context_mask=None, control_net_residual=None):
        """Run the stage.

        ``control_net_residual`` is accepted for call compatibility but is not
        read anywhere in this method.
        """
        x = self.self_attention_block(x, time_embed)
        for first_resnet, cross_attention, second_resnet in self.resnet_attn_blocks:
            x = first_resnet(x, time_embed)
            x = cross_attention(x, time_embed, context, context_mask)
            x = second_resnet(x, time_embed)
        return x
153
+
154
+
155
class UpSampleBlock(nn.Module):
    """Decoder stage: ``(ResNet, attention, ResNet)`` triples followed by optional self-attention.

    The first triple consumes ``in_channels + cat_dim`` channels (the stage
    input concatenated with a skip connection). When ``up_sample`` is true,
    its first ResNet block gets a ``True`` resampling flag in the second slot.
    """

    def __init__(
            self, in_channels, cat_dim, out_channels, time_embed_dim, context_dim=None,
            num_blocks=3, groups=32, head_dim=64, expansion_ratio=4, compression_ratio=2,
            up_sample=True, self_attention=True
    ):
        super().__init__()
        # ``True`` requests a transposed-convolution up-sample; ``None`` requests nothing.
        first_flag = True if up_sample else None
        up_resolutions = [[None, first_flag, None, None]] + [[None] * 4] * (num_blocks - 1)
        hidden_channels = (
            [(in_channels + cat_dim, in_channels)]
            + [(in_channels, in_channels)] * (num_blocks - 2)
            + [(in_channels, out_channels)]
        )
        # NOTE(review): for num_blocks == 1 zip() stops after the first pair, so the
        # (in_channels, out_channels) stage is never built -- confirm this case is unused.
        triples = []
        for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
            triples.append(nn.ModuleList([
                ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution),
                set_default_layer(
                    exist(context_dim),
                    AttentionBlock,
                    (in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
                    layer_2=Identity
                ),
                ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio),
            ]))
        self.resnet_attn_blocks = nn.ModuleList(triples)

        self.self_attention_block = set_default_layer(
            self_attention,
            AttentionBlock,
            (out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
            layer_2=Identity
        )

    def forward(self, x, time_embed, context=None, context_mask=None):
        """Run every triple in order, then the trailing self-attention layer."""
        for first_resnet, cross_attention, second_resnet in self.resnet_attn_blocks:
            x = first_resnet(x, time_embed)
            x = cross_attention(x, time_embed, context, context_mask)
            x = second_resnet(x, time_embed)
        return self.self_attention_block(x, time_embed)
192
+
193
class ControlNetModel(nn.Module):
    """Encoder-only network returning the feature map of every non-final down-sampling level.

    NOTE(review): this file defines ``ControlNetModel`` a second time further
    down; at import time the later definition replaces this one.

    Args:
        model_channels: base width multiplied by each entry of ``dim_mult``.
        init_channels: width of the input convolution; falls back to ``model_channels``.
        num_channels: channels of the input tensor.
        out_channels: accepted for signature parity; not read in this class.
        time_embed_dim: width of the time embedding.
        context_dim: width of the context vectors used for cross-attention and pooling.
        groups, head_dim, expansion_ratio, compression_ratio: forwarded to each ``DownSampleBlock``.
        dim_mult: per-level channel multipliers.
        num_blocks: per-level number of block triples.
        add_cross_attention: per-level flags enabling cross-attention.
        add_self_attention: per-level flags enabling self-attention.
    """

    def __init__(self,
                 model_channels,
                 init_channels=None,
                 num_channels=3,
                 out_channels=4,
                 time_embed_dim=None,
                 context_dim=None,
                 groups=32,
                 head_dim=64,
                 expansion_ratio=4,
                 compression_ratio=2,
                 dim_mult=(1, 2, 4, 8),
                 num_blocks=(3, 3, 3, 3),
                 add_cross_attention=(False, True, True, True),
                 add_self_attention=(False, True, True, True)
                 ):
        super().__init__()
        init_channels = init_channels or model_channels
        self.to_time_embed = nn.Sequential(
            SinusoidalPosEmb(init_channels),
            nn.Linear(init_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        self.feature_pooling = AttentionPolling(time_embed_dim, context_dim, head_dim)

        self.in_layer = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)

        hidden_dims = [init_channels, *(model_channels * mult for mult in dim_mult)]
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        # Levels without cross-attention get a ``None`` context width.
        text_dims = [set_default_item(is_exist, context_dim) for is_exist in add_cross_attention]

        # The unused ``layer_params``, ``rev_layer_params`` and ``cat_dims`` locals
        # were dropped: this class builds no decoder.
        self.num_levels = len(in_out_dims)
        self.down_samples = nn.ModuleList([])
        for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(
                zip(in_out_dims, num_blocks, text_dims, add_self_attention)
        ):
            down_sample = level != (self.num_levels - 1)
            self.down_samples.append(
                DownSampleBlock(
                    in_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim, expansion_ratio,
                    compression_ratio, down_sample, self_attention
                )
            )

    def forward(self, x, time, context=None, context_mask=None):
        """Return a list with the output of every level except the last one."""
        time_embed = self.to_time_embed(time)
        if exist(context):
            time_embed = self.feature_pooling(time_embed, context, context_mask)

        hidden_states = []
        x = self.in_layer(x)
        for level, down_sample in enumerate(self.down_samples):
            x = down_sample(x, time_embed, context, context_mask)
            if level != self.num_levels - 1:
                hidden_states.append(x)
        return hidden_states
253
+
254
class UNet(nn.Module):
    """U-shaped network with time/context conditioning and skip connections.

    The encoder is a list of ``DownSampleBlock`` stages and the decoder a list
    of ``UpSampleBlock`` stages built from the same per-level settings in
    reverse order. Extra positional and keyword constructor arguments are
    accepted and ignored.
    """

    def __init__(self,
                 model_channels,
                 init_channels=None,
                 num_channels=3,
                 out_channels=4,
                 time_embed_dim=None,
                 context_dim=None,
                 groups=32,
                 head_dim=64,
                 expansion_ratio=4,
                 compression_ratio=2,
                 dim_mult=(1, 2, 4, 8),
                 num_blocks=(3, 3, 3, 3),
                 add_cross_attention=(False, True, True, True),
                 add_self_attention=(False, True, True, True),
                 *args,
                 **kwargs,
                 ):
        super().__init__()
        init_channels = init_channels or model_channels
        self.to_time_embed = nn.Sequential(
            SinusoidalPosEmb(init_channels),
            nn.Linear(init_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        self.feature_pooling = AttentionPolling(time_embed_dim, context_dim, head_dim)

        self.in_layer = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)

        hidden_dims = [init_channels] + [model_channels * mult for mult in dim_mult]
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        # Levels without cross-attention get a ``None`` context width.
        text_dims = [context_dim if enabled else None for enabled in add_cross_attention]

        self.num_levels = len(in_out_dims)
        last_level = self.num_levels - 1

        # Encoder: every level but the last down-samples and contributes a skip connection.
        skip_dims = []
        self.down_samples = nn.ModuleList([])
        encoder_cfg = zip(in_out_dims, num_blocks, text_dims, add_self_attention)
        for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(encoder_cfg):
            down_sample = level != last_level
            skip_dims.append(out_dim if down_sample else 0)
            self.down_samples.append(
                DownSampleBlock(
                    in_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim, expansion_ratio,
                    compression_ratio, down_sample, self_attention
                )
            )

        # Decoder: same settings walked backwards; the first decoder level does not up-sample.
        self.up_samples = nn.ModuleList([])
        decoder_cfg = zip(
            reversed(in_out_dims), reversed(num_blocks), reversed(text_dims), reversed(add_self_attention)
        )
        for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate(decoder_cfg):
            up_sample = level != 0
            self.up_samples.append(
                UpSampleBlock(
                    in_dim, skip_dims.pop(), out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim,
                    expansion_ratio, compression_ratio, up_sample, self_attention
                )
            )

        self.out_layer = nn.Sequential(
            nn.GroupNorm(groups, init_channels),
            nn.SiLU(),
            nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1)
        )

        self.control_net = None

    def forward(self, x, time, context=None, context_mask=None, control_net_residual=None):
        """Run encoder and decoder; ``control_net_residual`` is forwarded to every encoder stage."""
        time_embed = self.to_time_embed(time)
        if exist(context):
            time_embed = self.feature_pooling(time_embed, context, context_mask)

        skips = []
        last_level = self.num_levels - 1
        x = self.in_layer(x)
        for level, down_sample in enumerate(self.down_samples):
            x = down_sample(x, time_embed, context, context_mask, control_net_residual)
            if level != last_level:
                skips.append(x)

        for level, up_sample in enumerate(self.up_samples):
            # The first decoder level has no skip connection to concatenate.
            if level != 0:
                x = torch.cat([x, skips.pop()], dim=1)
            x = up_sample(x, time_embed, context, context_mask)
        return self.out_layer(x)
340
+
341
+
342
class ControlNetModel(nn.Module):
    """Encoder-only network returning the feature map of every non-final down-sampling level.

    Extra positional and keyword constructor arguments are accepted and ignored.

    Args:
        model_channels: base width multiplied by each entry of ``dim_mult``.
        init_channels: width of the input convolution; falls back to ``model_channels``.
        num_channels: channels of the input tensor.
        out_channels: accepted for signature parity; not read in this class.
        time_embed_dim: width of the time embedding.
        context_dim: width of the context vectors used for cross-attention and pooling.
        groups, head_dim, expansion_ratio, compression_ratio: forwarded to each ``DownSampleBlock``.
        dim_mult: per-level channel multipliers.
        num_blocks: per-level number of block triples.
        add_cross_attention: per-level flags enabling cross-attention.
        add_self_attention: per-level flags enabling self-attention.
    """

    def __init__(self,
                 model_channels,
                 init_channels=None,
                 num_channels=3,
                 out_channels=4,
                 time_embed_dim=None,
                 context_dim=None,
                 groups=32,
                 head_dim=64,
                 expansion_ratio=4,
                 compression_ratio=2,
                 dim_mult=(1, 2, 4, 8),
                 num_blocks=(3, 3, 3, 3),
                 add_cross_attention=(False, True, True, True),
                 add_self_attention=(False, True, True, True),
                 *args,
                 **kwargs,
                 ):
        super().__init__()
        init_channels = init_channels or model_channels
        self.to_time_embed = nn.Sequential(
            SinusoidalPosEmb(init_channels),
            nn.Linear(init_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        self.feature_pooling = AttentionPolling(time_embed_dim, context_dim, head_dim)

        self.in_layer = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)

        hidden_dims = [init_channels, *(model_channels * mult for mult in dim_mult)]
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        # Levels without cross-attention get a ``None`` context width.
        text_dims = [set_default_item(is_exist, context_dim) for is_exist in add_cross_attention]

        # The unused ``layer_params``, ``rev_layer_params`` and ``cat_dims`` locals
        # were dropped: this class builds no decoder.
        self.num_levels = len(in_out_dims)
        self.down_samples = nn.ModuleList([])
        for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(
                zip(in_out_dims, num_blocks, text_dims, add_self_attention)
        ):
            down_sample = level != (self.num_levels - 1)
            self.down_samples.append(
                DownSampleBlock(
                    in_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim, expansion_ratio,
                    compression_ratio, down_sample, self_attention
                )
            )

    def forward(self, x, time, context=None, context_mask=None):
        """Return a list with the output of every level except the last one."""
        time_embed = self.to_time_embed(time)
        if exist(context):
            time_embed = self.feature_pooling(time_embed, context, context_mask)

        hidden_states = []
        x = self.in_layer(x)
        for level, down_sample in enumerate(self.down_samples):
            x = down_sample(x, time_embed, context, context_mask)
            if level != self.num_levels - 1:
                hidden_states.append(x)
        return hidden_states
404
+
405
class ControlUNet(nn.Module):
    """U-shaped network whose encoder features are offset by a ``ControlNetModel``.

    The control network is built with ``control_net_channels`` input channels
    and otherwise the same settings as the main encoder. In ``forward`` its
    per-level outputs are added in place to the matching encoder outputs
    before they are stored as skip connections. Extra positional and keyword
    constructor arguments are accepted and ignored.
    """

    def __init__(self,
                 model_channels,
                 init_channels=None,
                 num_channels=3,
                 out_channels=4,
                 time_embed_dim=None,
                 context_dim=None,
                 groups=32,
                 head_dim=64,
                 expansion_ratio=4,
                 compression_ratio=2,
                 dim_mult=(1, 2, 4, 8),
                 num_blocks=(3, 3, 3, 3),
                 add_cross_attention=(False, True, True, True),
                 add_self_attention=(False, True, True, True),
                 control_net_channels=5,
                 *args,
                 **kwargs,
                 ):
        super().__init__()
        init_channels = init_channels or model_channels
        self.to_time_embed = nn.Sequential(
            SinusoidalPosEmb(init_channels),
            nn.Linear(init_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        self.feature_pooling = AttentionPolling(time_embed_dim, context_dim, head_dim)

        self.in_layer = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)

        hidden_dims = [init_channels] + [model_channels * mult for mult in dim_mult]
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        # Levels without cross-attention get a ``None`` context width.
        text_dims = [context_dim if enabled else None for enabled in add_cross_attention]

        self.num_levels = len(in_out_dims)
        last_level = self.num_levels - 1

        # Encoder: every level but the last down-samples and contributes a skip connection.
        skip_dims = []
        self.down_samples = nn.ModuleList([])
        encoder_cfg = zip(in_out_dims, num_blocks, text_dims, add_self_attention)
        for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(encoder_cfg):
            down_sample = level != last_level
            skip_dims.append(out_dim if down_sample else 0)
            self.down_samples.append(
                DownSampleBlock(
                    in_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim, expansion_ratio,
                    compression_ratio, down_sample, self_attention
                )
            )

        # Decoder: same settings walked backwards; the first decoder level does not up-sample.
        self.up_samples = nn.ModuleList([])
        decoder_cfg = zip(
            reversed(in_out_dims), reversed(num_blocks), reversed(text_dims), reversed(add_self_attention)
        )
        for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate(decoder_cfg):
            up_sample = level != 0
            self.up_samples.append(
                UpSampleBlock(
                    in_dim, skip_dims.pop(), out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim,
                    expansion_ratio, compression_ratio, up_sample, self_attention
                )
            )

        self.out_layer = nn.Sequential(
            nn.GroupNorm(groups, init_channels),
            nn.SiLU(),
            nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1)
        )

        self.control_net = ControlNetModel(
            model_channels,
            init_channels,
            control_net_channels,
            out_channels,
            time_embed_dim,
            context_dim,
            groups,
            head_dim,
            expansion_ratio,
            compression_ratio,
            dim_mult,
            num_blocks,
            add_cross_attention,
            add_self_attention,
        )

    def forward(self, x, time, context=None, context_mask=None, control_net_data=None):
        """Run the network; ``control_net_data`` is the input of the control network."""
        time_embed = self.to_time_embed(time)
        if exist(context):
            time_embed = self.feature_pooling(time_embed, context, context_mask)

        # The control network receives the raw ``time`` and computes its own embedding.
        control_net_hiddens = self.control_net(control_net_data, time, context, context_mask)

        skips = []
        last_level = self.num_levels - 1
        x = self.in_layer(x)
        for level, down_sample in enumerate(self.down_samples):
            x = down_sample(x, time_embed, context, context_mask)
            if level != last_level:
                # Control features are consumed in level order, one per non-final level.
                x += control_net_hiddens[level]
                skips.append(x)

        for level, up_sample in enumerate(self.up_samples):
            # The first decoder level has no skip connection to concatenate.
            if level != 0:
                x = torch.cat([x, skips.pop()], dim=1)
            x = up_sample(x, time_embed, context, context_mask)
        return self.out_layer(x)
507
+
508
+
509
def get_control_unet(conf):
    """Build a ``ControlUNet`` from a configuration mapping.

    ``conf`` is unpacked as keyword arguments of the ``ControlUNet`` constructor.
    """
    return ControlUNet(**conf)
512
+
513
+
514
def get_unet(conf):
    """Build a ``UNet`` from a configuration mapping.

    ``conf`` is unpacked as keyword arguments of the ``UNet`` constructor.
    """
    return UNet(**conf)
kandinsky3/model/utils.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.nn import Identity
2
+ from einops import rearrange
3
+
4
+
5
def exist(item):
    """Return ``True`` unless ``item`` is ``None`` (falsy values still count as existing)."""
    return not (item is None)
7
+
8
+
9
def set_default_item(condition, item_1, item_2=None):
    """Return ``item_1`` when ``condition`` is truthy, otherwise ``item_2``."""
    return item_1 if condition else item_2
14
+
15
+
16
def set_default_layer(condition, layer_1, args_1=None, kwargs_1=None, layer_2=Identity, args_2=None, kwargs_2=None):
    """Instantiate one of two layers depending on ``condition``.

    Args:
        condition: When truthy ``layer_1`` is built, otherwise ``layer_2``.
        layer_1: Callable (usually a module class) used when the condition holds.
        args_1: Positional arguments for ``layer_1``; ``None`` means no arguments.
        kwargs_1: Keyword arguments for ``layer_1``; ``None`` means no arguments.
        layer_2: Fallback callable, ``Identity`` by default.
        args_2: Positional arguments for ``layer_2``; ``None`` means no arguments.
        kwargs_2: Keyword arguments for ``layer_2``; ``None`` means no arguments.

    Returns:
        The object returned by the selected callable.
    """
    # ``None`` sentinels replace the former mutable ``[]`` / ``{}`` defaults.
    if condition:
        layer, args, kwargs = layer_1, args_1, kwargs_1
    else:
        layer, args, kwargs = layer_2, args_2, kwargs_2
    args = () if args is None else args
    kwargs = {} if kwargs is None else kwargs
    return layer(*args, **kwargs)
21
+
22
+
23
def get_tensor_items(x, pos, broadcast_shape):
    """Pick ``x[pos]`` and shape it for broadcasting against ``broadcast_shape``.

    The lookup is performed on CPU copies of ``x`` and ``pos``; the result
    is reshaped to ``(pos.shape[0], 1, ..., 1)`` with one singleton axis per
    trailing dimension of ``broadcast_shape`` and moved to ``pos``'s device.
    """
    target_device = pos.device
    batch_size = pos.shape[0]
    singleton_axes = (1,) * (len(broadcast_shape) - 1)
    picked = x.cpu()[pos.cpu()]
    return picked.reshape(batch_size, *singleton_axes).to(target_device)
29
+
30
+
31
def local_patching(x, height, width, group_size):
    """Rearrange a ``b c h w`` map into token groups.

    With ``group_size > 0`` the map is cut into ``group_size x group_size``
    windows, giving ``b (h w) (g1 g2) c``; otherwise all positions are
    flattened into a single sequence ``b (h w) c``.
    """
    if not group_size > 0:
        return rearrange(x, 'b c h w -> b (h w) c', h=height, w=width)

    return rearrange(
        x, 'b c (h g1) (w g2) -> b (h w) (g1 g2) c',
        h=height//group_size, w=width//group_size, g1=group_size, g2=group_size
    )
40
+
41
+
42
def local_merge(x, height, width, group_size):
    """Inverse of ``local_patching``: fold token groups back into ``b c h w``."""
    if not group_size > 0:
        return rearrange(x, 'b (h w) c -> b c h w', h=height, w=width)

    return rearrange(
        x, 'b (h w) (g1 g2) c -> b c (h g1) (w g2)',
        h=height//group_size, w=width//group_size, g1=group_size, g2=group_size
    )
51
+
52
+
53
def global_patching(x, height, width, group_size):
    """Patch with windows of side ``height // group_size`` and swap the two token axes."""
    window = height//group_size
    patched = local_patching(x, height, width, window)
    return patched.transpose(-2, -3)
57
+
58
+
59
def global_merge(x, height, width, group_size):
    """Inverse of ``global_patching``: swap the token axes back, then merge windows."""
    window = height//group_size
    swapped = x.transpose(-2, -3)
    return local_merge(swapped, height, width, window)
kandinsky3/movq.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+
7
+ from .utils import freeze
8
+
9
+
10
def nonlinearity(x):
    """Apply the swish activation: ``x`` multiplied by ``sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x*gate
12
+
13
+
14
class SpatialNorm(nn.Module):
    """Normalization layer optionally modulated by a conditioning map ``zq``.

    ``f`` is always passed through ``norm_layer``.  When the module was
    built with ``zq_channels`` and ``forward`` receives ``zq``, the map is
    resized (nearest neighbour) to the spatial size of ``f`` and used to
    produce a per-pixel scale (``conv_y``) and shift (``conv_b``).

    Args:
        f_channels: Channel count of the feature map to normalize.
        zq_channels: Channel count of the conditioning map, or ``None`` to
            build a plain normalization layer without modulation convs.
        norm_layer: Factory called as ``norm_layer(num_channels=..., **norm_layer_params)``.
        freeze_norm_layer: If true (and ``zq_channels`` is given), the
            normalization layer's parameters stop requiring gradients.
        add_conv: If true, a 3x3 conv is applied to the resized ``zq``
            before the scale/shift projections.
        **norm_layer_params: Extra keyword arguments for ``norm_layer``.
    """

    def __init__(
        self, f_channels, zq_channels=None, norm_layer=nn.GroupNorm, freeze_norm_layer=False, add_conv=False,
        **norm_layer_params
    ):
        super().__init__()
        self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
        if zq_channels is not None:
            if freeze_norm_layer:
                # ``parameters`` is a method: it has to be called, iterating
                # the bound method itself raises ``TypeError``.
                for p in self.norm_layer.parameters():
                    p.requires_grad = False
            self.add_conv = add_conv
            if self.add_conv:
                self.conv = nn.Conv2d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1)
            self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
            self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, f, zq=None):
        """Normalize ``f``; modulate the result with ``zq`` when it is given."""
        norm_f = self.norm_layer(f)
        if zq is not None:
            f_size = f.shape[-2:]
            zq = torch.nn.functional.interpolate(zq, size=f_size, mode="nearest")
            if self.add_conv:
                zq = self.conv(zq)
            norm_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        return norm_f
38
+
39
+
40
def Normalize(in_channels, zq_ch=None, add_conv=None):
    """Build the ``SpatialNorm`` used throughout this file.

    The underlying layer is a trainable ``nn.GroupNorm`` with 32 groups,
    ``eps=1e-6`` and affine parameters.
    """
    group_norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
    return SpatialNorm(
        in_channels,
        zq_ch,
        norm_layer=nn.GroupNorm,
        freeze_norm_layer=False,
        add_conv=add_conv,
        **group_norm_kwargs
    )
45
+
46
+
47
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if not self.with_conv:
            return upsampled
        return self.conv(upsampled)
63
+
64
+
65
class Downsample(nn.Module):
    """2x spatial downsampling via a strided 3x3 conv or 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # No built-in padding: forward pads asymmetrically by hand.
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)

        # Zero-pad one pixel on the right and bottom edges only.
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
84
+
85
+
86
class ResnetBlock(nn.Module):
    """Residual block: two norm -> swish -> 3x3 conv stages plus a shortcut.

    When ``in_channels`` differs from ``out_channels`` the shortcut is a
    conv (3x3 if ``conv_shortcut`` else 1x1); otherwise the input is added
    unchanged.  A linear projection of the time embedding is created only
    if ``temb_channels > 0``.  ``zq_ch`` / ``add_conv`` are forwarded to
    both normalization layers.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512, zq_ch=None, add_conv=False):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels, zq_ch, add_conv=add_conv)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels, zq_ch, add_conv=add_conv)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb, zq=None):
        """Return ``shortcut(x) + residual(x)``; ``temb`` may be ``None``."""
        residual = self.conv1(nonlinearity(self.norm1(x, zq)))

        if temb is not None:
            # Broadcast the projected embedding over the spatial axes.
            residual = residual + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        residual = nonlinearity(self.norm2(residual, zq))
        residual = self.conv2(self.dropout(residual))

        if self.in_channels != self.out_channels:
            shortcut = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
            x = shortcut(x)

        return x+residual
146
+
147
+
148
class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions, with a residual.

    Queries, keys, values and the output projection are 1x1 convolutions
    that keep the channel count; attention logits are scaled by ``c**-0.5``.
    """

    def __init__(self, in_channels, zq_ch=None, add_conv=False):
        super().__init__()
        self.in_channels = in_channels

        def pointwise_conv():
            return torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.norm = Normalize(in_channels, zq_ch, add_conv=add_conv)
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def forward(self, x, zq=None):
        normed = self.norm(x, zq)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        # compute attention
        b, c, h, w = query.shape
        query = query.reshape(b, c, h*w).permute(0, 2, 1)   # b, hw, c
        key = key.reshape(b, c, h*w)                        # b, c, hw
        scores = torch.bmm(query, key)                      # b, hw(query), hw(key)
        scores = scores * (int(c)**(-0.5))
        weights = torch.nn.functional.softmax(scores, dim=2)

        # attend to values
        value = value.reshape(b, c, h*w)
        weights = weights.permute(0, 2, 1)                  # b, hw(key), hw(query)
        attended = torch.bmm(value, weights)                # b, c, hw(query)
        attended = attended.reshape(b, c, h, w)

        return x+self.proj_out(attended)
201
+
202
+
203
class Encoder(nn.Module):
    """Convolutional encoder: res/attention stages with downsampling between them.

    Args (keyword-only):
        ch: Base channel count; level ``i`` uses ``ch * ch_mult[i]`` channels.
        out_ch: Accepted for config compatibility; not used by the encoder.
        ch_mult: Per-level channel multipliers; its length is the number of levels.
        num_res_blocks: Residual blocks per level.
        attn_resolutions: Resolutions at which every residual block is
            followed by an attention block.
        dropout: Dropout rate passed to every residual block.
        resamp_with_conv: Forwarded to ``Downsample``.
        in_channels: Channel count of the input.
        resolution: Input resolution, used only to track ``attn_resolutions``.
        z_channels: Latent channels; the output has twice as many when
            ``double_z`` is true.
        double_z: Whether to double the output channels.
        **ignore_kwargs: Extra config entries, ignored.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        """Encode ``x`` and return the output of the final convolution."""
        # The encoder has no timestep conditioning (temb_ch == 0).
        temb = None

        # downsampling
        # Only the most recent activation is ever needed, so a single running
        # tensor is kept instead of a list of every intermediate result.
        h = self.conv_in(x)
        for i_level, stage in enumerate(self.down):
            for i_block in range(self.num_res_blocks):
                h = stage.block[i_block](h, temb)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h)
            if i_level != self.num_resolutions-1:
                h = stage.downsample(h)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
292
+
293
+
294
class Decoder(nn.Module):
    """Convolutional decoder whose norm layers are conditioned on ``zq``.

    Mirrors the encoder layout: an input conv, a res/attn/res middle, then
    one stage per entry of ``ch_mult`` (``num_res_blocks + 1`` residual
    blocks each) with 2x upsampling between stages.  ``zq_ch`` and
    ``add_conv`` are forwarded to every residual/attention block and to the
    output norm.  With ``give_pre_end`` the final norm/activation/conv are
    skipped in ``forward``.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, zq_ch=None, add_conv=False, **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # channel count and resolution at the lowest (deepest) level
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # conditioning arguments shared by every residual block
        cond_kwargs = dict(zq_ch=zq_ch, add_conv=add_conv)

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout, **cond_kwargs)
        self.mid.attn_1 = AttnBlock(block_in, zq_ch, add_conv=add_conv)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout, **cond_kwargs)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for _ in range(self.num_res_blocks+1):
                res_blocks.append(ResnetBlock(in_channels=block_in, out_channels=block_out,
                                              temb_channels=self.temb_ch, dropout=dropout, **cond_kwargs))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn_blocks.append(AttnBlock(block_in, zq_ch, add_conv=add_conv))
            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if i_level != 0:
                stage.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            # Built deepest-first; prepend so that self.up[i] is level i.
            self.up.insert(0, stage)

        # end
        self.norm_out = Normalize(block_in, zq_ch, add_conv=add_conv)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z, zq):
        """Decode ``z`` while conditioning every normalization on ``zq``."""
        self.last_z_shape = z.shape

        # The decoder has no timestep conditioning (temb_ch == 0).
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb, zq)
        h = self.mid.attn_1(h, zq)
        h = self.mid.block_2(h, temb, zq)

        # upsampling, from the deepest level up to level 0
        for i_level in reversed(range(self.num_resolutions)):
            stage = self.up[i_level]
            has_attn = len(stage.attn) > 0
            for i_block in range(self.num_res_blocks+1):
                h = stage.block[i_block](h, temb, zq)
                if has_attn:
                    h = stage.attn[i_block](h, zq)
            if i_level != 0:
                h = stage.upsample(h)

        # end
        if self.give_pre_end:
            return h

        return self.conv_out(nonlinearity(self.norm_out(h, zq)))
400
+
401
+
402
class MoVQ(nn.Module):
    """Encoder/decoder pair joined by 1x1 convolutions on the latent.

    ``generator_params`` is passed as keyword arguments to both ``Encoder``
    and ``Decoder``; its ``"z_channels"`` entry also sizes the two 1x1 convs
    and the decoder's ``zq_ch``.
    """

    def __init__(self, generator_params):
        super().__init__()
        latent_channels = generator_params["z_channels"]
        self.encoder = Encoder(**generator_params)
        self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
        self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
        self.decoder = Decoder(zq_ch=latent_channels, **generator_params)

    def encode(self, x):
        """Return ``quant_conv(encoder(x))``."""
        return self.quant_conv(self.encoder(x))

    def decode(self, quant):
        """Decode ``quant``; it is both projected as decoder input and used as ``zq``."""
        projected = self.post_quant_conv(quant)
        return self.decoder(projected, quant)
423
+
424
+
425
def get_vae(conf):
    """Build a frozen ``MoVQ`` from ``conf``.

    Args:
        conf: Object exposing ``params`` (passed to ``MoVQ``) and
            ``checkpoint`` (path of a state dict, or ``None`` to keep the
            freshly initialised weights).

    Returns:
        The model after ``freeze`` has been applied to it.
    """
    movq = MoVQ(conf.params)
    if conf.checkpoint is not None:
        # Load onto CPU so a checkpoint saved from a GPU also works on
        # machines without that device; load_state_dict copies the values
        # into the module's own parameters wherever they live.
        # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
        movq_state_dict = torch.load(conf.checkpoint, map_location="cpu")
        movq.load_state_dict(movq_state_dict)
    movq = freeze(movq)
    return movq
kandinsky3/setup.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import setup
2
+
3
# Package metadata and runtime dependencies for kandinsky3.
setup(
    name="kandinsky3",
    # Package names are dotted module paths, not filesystem paths.
    packages=[
        "kandinsky3",
        "kandinsky3.model",
    ],
    install_requires=[
        "timm",
        "torch==1.10.1+cu111",
        "torchvision==0.11.2+cu111",
        "torchaudio==0.10.1",
        "pytorch_lightning==1.7.5",
        "transformers",
        "accelerate",
        "diffusers",
        "setuptools==59.5.0",
        "omegaconf",
        "datasets",
        "einops",
        "webdataset",
        "fsspec",
        "s3fs",
        "hydra-core",
        "scikit-image",
        "matplotlib",
        "wandb",
        "albumentations",
        "bezier",
        "scipy",
        "Pillow",
        "tqdm",
        "huggingface_hub",
    ],
    author="",
)
kandinsky3/t2i_pipeline.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, List
2
+ import PIL
3
+
4
+ import torch
5
+ import torchvision.transforms as T
6
+ from einops import repeat
7
+
8
+ from kandinsky3.model.unet import UNet
9
+ from kandinsky3.movq import MoVQ
10
+ from kandinsky3.condition_encoders import T5TextConditionEncoder
11
+ from kandinsky3.condition_processors import T5TextConditionProcessor
12
+ from kandinsky3.model.diffusion import BaseDiffusion, get_named_beta_schedule
13
+
14
+
15
class Kandinsky3T2IPipeline:
    """Text-to-image pipeline: T5 text encoding, U-Net diffusion sampling, MoVQ decoding.

    ``device_map`` and ``dtype_map`` are indexed with the keys
    ``'text_encoder'``, ``'unet'`` and ``'movq'`` during a call.
    """

    def __init__(
        self,
        device_map: Union[str, torch.device, dict],
        dtype_map: Union[str, torch.dtype, dict],
        unet: UNet,
        null_embedding: torch.Tensor,
        t5_processor: T5TextConditionProcessor,
        t5_encoder: T5TextConditionEncoder,
        movq: MoVQ,
        gan: bool,
    ):
        self.device_map = device_map
        self.dtype_map = dtype_map
        self.to_pil = T.ToPILImage()

        self.unet = unet
        self.null_embedding = null_embedding
        self.t5_processor = t5_processor
        self.t5_encoder = t5_encoder
        self.movq = movq

        self.gan = gan

    def _to_text_encoder_device(self, model_input):
        """Add a leading batch axis to every entry and move it to the text-encoder device (in place)."""
        for input_type in model_input:
            model_input[input_type] = model_input[input_type][None].to(
                self.device_map['text_encoder']
            )

    def __call__(
        self,
        text: str,
        negative_text: str = None,
        images_num: int = 1,
        bs: int = 1,
        width: int = 1024,
        height: int = 1024,
        guidance_scale: float = 3.0,
        steps: int = 50,
        eta: float = 1.0
    ) -> List[PIL.Image.Image]:
        """Generate ``images_num`` images for ``text`` in minibatches of at most ``bs``.

        The prompt (and optional negative prompt) is encoded once and
        repeated for each minibatch.  Latents of shape
        ``(minibatch, 4, height // 8, width // 8)`` are sampled, decoded by
        MoVQ, mapped from ``[-1, 1]`` to ``[0, 1]`` and converted to PIL.
        """
        betas = get_named_beta_schedule('cosine', 1000)
        base_diffusion = BaseDiffusion(betas, 0.99)
        times = list(range(999, 0, -1000 // steps))
        if self.gan:
            # The GAN variant uses a fixed short schedule regardless of `steps`.
            times = list(range(979, 0, -250))

        condition_model_input, negative_condition_model_input = self.t5_processor.encode(text, negative_text)
        self._to_text_encoder_device(condition_model_input)
        if negative_condition_model_input is not None:
            self._to_text_encoder_device(negative_condition_model_input)

        pil_images = []
        with torch.no_grad():
            with torch.cuda.amp.autocast(dtype=self.dtype_map['text_encoder']):
                context, context_mask = self.t5_encoder(condition_model_input)
                negative_context, negative_context_mask = None, None
                if negative_condition_model_input is not None:
                    negative_context, negative_context_mask = self.t5_encoder(negative_condition_model_input)

            full_batches, remainder = divmod(images_num, bs)
            for minibatch in [bs] * full_batches + [remainder]:
                if not minibatch:
                    continue

                bs_context = repeat(context, '1 n d -> b n d', b=minibatch)
                bs_context_mask = repeat(context_mask, '1 n -> b n', b=minibatch)
                bs_negative_context, bs_negative_context_mask = None, None
                if negative_context is not None:
                    bs_negative_context = repeat(negative_context, '1 n d -> b n d', b=minibatch)
                    bs_negative_context_mask = repeat(negative_context_mask, '1 n -> b n', b=minibatch)

                with torch.cuda.amp.autocast(dtype=self.dtype_map['unet']):
                    images = base_diffusion.p_sample_loop(
                        self.unet, (minibatch, 4, height // 8, width // 8), times, self.device_map['unet'],
                        bs_context, bs_context_mask, self.null_embedding, guidance_scale, eta,
                        negative_context=bs_negative_context, negative_context_mask=bs_negative_context_mask,
                        gan=self.gan
                    )

                with torch.cuda.amp.autocast(dtype=self.dtype_map['movq']):
                    # Decode in (up to) two chunks rather than all at once.
                    decoded = [self.movq.decode(image) for image in images.chunk(2)]
                    images = torch.cat(decoded)
                    images = torch.clip((images + 1.) / 2., 0., 1.)
                    pil_images += [self.to_pil(image) for image in images]

        return pil_images
kandinsky3/utils.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from omegaconf import OmegaConf
2
+ import numpy as np
3
+ from scipy import ndimage
4
+ import torch.nn as nn
5
+ from skimage.transform import resize
6
+
7
+
8
def load_conf(config_path):
    """Load an OmegaConf config and copy shared values into the sections that use them.

    Values from ``conf.common`` (token length, seed, image size, training
    steps, experiment name) and from ``conf.model`` are duplicated into the
    data, trainer, scheduler, logger and encoder sections.
    """
    conf = OmegaConf.load(config_path)
    common = conf.common

    # data section
    conf.data.tokens_length = common.tokens_length
    conf.data.processor_names = conf.model.encoders.model_names
    conf.data.dataset.seed = common.seed
    conf.data.dataset.image_size = common.image_size

    # training section
    conf.trainer.trainer_params.max_steps = common.train_steps
    conf.scheduler.params.total_steps = common.train_steps
    conf.logger.tensorboard.name = common.experiment_name

    # encoders share the U-Net context width
    conf.model.encoders.context_dim = conf.model.unet_params.context_dim
    return conf
21
+
22
+
23
def freeze(model):
    """Disable gradients for every parameter of ``model`` and return it."""
    for param in model.parameters():
        param.requires_grad_(False)
    return model
27
+
28
def unfreeze(model):
    """Enable gradients for every parameter of ``model`` and return it."""
    for param in model.parameters():
        param.requires_grad_(True)
    return model
32
+
33
def zero_module(module):
    """Fill every parameter of ``module`` with zeros in place and return the module."""
    parameters = module.parameters()
    for param in parameters:
        nn.init.zeros_(param)
    return module
37
+
38
def resize_mask_for_diffusion(mask):
    """Resize ``mask`` so its element count is at most about 1024**2 and both sides are multiples of 64.

    Masks are only ever shrunk (the reduction factor is at least 1); each
    side is rounded after scaling and then floored to a multiple of 64.
    Value range is preserved and no anti-aliasing is applied.
    """
    reduce_factor = max(1, (mask.size / 1024**2)**0.5)

    def snap(side):
        return (round(side / reduce_factor) // 64) * 64

    target_shape = (snap(mask.shape[0]), snap(mask.shape[1]))
    return resize(mask, target_shape, preserve_range=True, anti_aliasing=False)
51
+
52
def resize_image_for_diffusion(image):
    """Resize ``image`` to at most about 1024**2 pixels with both sides multiples of 64.

    ``image`` must expose ``size`` (two side lengths) and ``resize``; the
    value returned by ``image.resize`` is returned unchanged.
    """
    side_0, side_1 = image.size[0], image.size[1]
    reduce_factor = max(1, (side_0 * side_1 / 1024**2)**0.5)

    def snap(side):
        return (round(side / reduce_factor) // 64) * 64

    return image.resize((snap(side_0), snap(side_1)))
59
+
60
def prepare_mask(mask):
    """Grow the non-zero region of ``mask`` and return it as a 0/1 integer array.

    The mask is convolved three times with a positive 5x5 kernel and then
    thresholded at ``> 0``.
    """
    # NOTE(review): ndimage.convolve keeps the input dtype, so an integer
    # mask would be truncated after each pass - presumably callers pass a
    # float mask; confirm.
    kernel = np.array([[1, 1, 1, 1, 1],
                       [1, 5, 5, 5, 1],
                       [1, 5, 44, 5, 1],
                       [1, 5, 5, 5, 1],
                       [1, 1, 1, 1, 1]]) / 100

    smoothed = mask
    for _ in range(3):
        smoothed = ndimage.convolve(smoothed, kernel)

    return (smoothed > 0).astype(int)
unet_model_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4165a1e48f2a7630729c03c8c7662b6acf8b9f6a590102ba80051c01afb480eb
3
+ size 12154895798