Upload folder using huggingface_hub
Browse files
- kandinsky3/.ipynb_checkpoints/__init__-checkpoint.py +267 -0
- kandinsky3/.ipynb_checkpoints/condition_encoders-checkpoint.py +40 -0
- kandinsky3/.ipynb_checkpoints/condition_processors-checkpoint.py +34 -0
- kandinsky3/.ipynb_checkpoints/inpainting_pipeline-checkpoint.py +168 -0
- kandinsky3/.ipynb_checkpoints/movq-checkpoint.py +431 -0
- kandinsky3/.ipynb_checkpoints/t2i_pipeline-checkpoint.py +109 -0
- kandinsky3/.ipynb_checkpoints/utils-checkpoint.py +71 -0
- kandinsky3/__init__.py +267 -0
- kandinsky3/__pycache__/__init__.cpython-310.pyc +0 -0
- kandinsky3/__pycache__/condition_encoders.cpython-310.pyc +0 -0
- kandinsky3/__pycache__/condition_processors.cpython-310.pyc +0 -0
- kandinsky3/__pycache__/inpainting_pipeline.cpython-310.pyc +0 -0
- kandinsky3/__pycache__/movq.cpython-310.pyc +0 -0
- kandinsky3/__pycache__/t2i_pipeline.cpython-310.pyc +0 -0
- kandinsky3/__pycache__/utils.cpython-310.pyc +0 -0
- kandinsky3/condition_encoders.py +40 -0
- kandinsky3/condition_processors.py +34 -0
- kandinsky3/inpainting_pipeline.py +168 -0
- kandinsky3/model/.ipynb_checkpoints/diffusion-checkpoint.py +200 -0
- kandinsky3/model/.ipynb_checkpoints/unet-checkpoint.py +516 -0
- kandinsky3/model/__init__.py +0 -0
- kandinsky3/model/__pycache__/__init__.cpython-310.pyc +0 -0
- kandinsky3/model/__pycache__/diffusion.cpython-310.pyc +0 -0
- kandinsky3/model/__pycache__/nn.cpython-310.pyc +0 -0
- kandinsky3/model/__pycache__/unet.cpython-310.pyc +0 -0
- kandinsky3/model/__pycache__/utils.cpython-310.pyc +0 -0
- kandinsky3/model/diffusion.py +200 -0
- kandinsky3/model/nn.py +84 -0
- kandinsky3/model/unet.py +516 -0
- kandinsky3/model/utils.py +62 -0
- kandinsky3/movq.py +431 -0
- kandinsky3/setup.py +38 -0
- kandinsky3/t2i_pipeline.py +109 -0
- kandinsky3/utils.py +71 -0
- unet_model_checkpoint.pt +3 -0
kandinsky3/.ipynb_checkpoints/__init__-checkpoint.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Optional, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 6 |
+
|
| 7 |
+
from kandinsky3.model.unet import UNet
|
| 8 |
+
from kandinsky3.movq import MoVQ
|
| 9 |
+
from kandinsky3.condition_encoders import T5TextConditionEncoder
|
| 10 |
+
from kandinsky3.condition_processors import T5TextConditionProcessor
|
| 11 |
+
from kandinsky3.model.diffusion import BaseDiffusion, get_named_beta_schedule
|
| 12 |
+
|
| 13 |
+
from .t2i_pipeline import Kandinsky3T2IPipeline
|
| 14 |
+
from .inpainting_pipeline import Kandinsky3InpaintingPipeline
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_T2I_unet(
        device: Union[str, torch.device],
        weights_path: Optional[str] = None,
        dtype: Union[str, torch.dtype] = torch.float32,
) -> (UNet, Optional[torch.Tensor]):
    """Build the text-to-image UNet and optionally load its checkpoint.

    Args:
        device: Device the UNet is moved to.
        weights_path: Path to a checkpoint readable by ``torch.load`` that
            holds the keys ``'unet'`` and ``'null_embedding'``. When falsy,
            the UNet keeps its freshly initialised weights.
        dtype: Dtype the UNet is cast to.

    Returns:
        ``(unet, null_embedding)``. ``null_embedding`` is the value stored
        under ``'null_embedding'`` in the checkpoint, or ``None`` when no
        checkpoint was given.
    """
    unet = UNet(
        model_channels=384,
        num_channels=4,
        init_channels=192,
        time_embed_dim=1536,
        context_dim=4096,
        groups=32,
        head_dim=64,
        expansion_ratio=4,
        compression_ratio=2,
        dim_mult=(1, 2, 4, 8),
        num_blocks=(3, 3, 3, 3),
        add_cross_attention=(False, True, True, True),
        add_self_attention=(False, True, True, True),
    )

    null_embedding = None
    if weights_path:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        null_embedding = state_dict['null_embedding']
        unet.load_state_dict(state_dict['unet'])

    unet.to(device=device, dtype=dtype).eval()
    return unet, null_embedding
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_T5encoder(
        device: Union[str, torch.device],
        weights_path: str,
        projection_name: str,
        dtype: Union[str, torch.dtype] = torch.float32,
        low_cpu_mem_usage: bool = True,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
) -> (T5TextConditionProcessor, T5TextConditionEncoder):
    """Create the T5 text processor and condition encoder.

    Args:
        device: Device passed to the encoder and used for its projection.
        weights_path: Directory handed to both the processor and the encoder;
            the projection weights are read from
            ``os.path.join(weights_path, projection_name)``.
        projection_name: File name of the projection checkpoint inside
            ``weights_path``.
        dtype: Dtype passed to the encoder and used for its projection.
        low_cpu_mem_usage: Forwarded to ``T5TextConditionEncoder``.
        load_in_8bit: Forwarded to ``T5TextConditionEncoder``.
        load_in_4bit: Forwarded to ``T5TextConditionEncoder``.

    Returns:
        ``(processor, condition_encoder)``.
    """
    tokens_length = 128
    context_dim = 4096
    processor = T5TextConditionProcessor(tokens_length, weights_path)
    condition_encoder = T5TextConditionEncoder(
        weights_path, context_dim, low_cpu_mem_usage=low_cpu_mem_usage, device=device,
        dtype=dtype, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
    )

    if weights_path:
        projections_weights_path = os.path.join(weights_path, projection_name)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(projections_weights_path, map_location=torch.device('cpu'))
        condition_encoder.projection.load_state_dict(state_dict)

    # Only the projection is moved/cast here; the encoder itself received
    # device and dtype through its constructor above.
    condition_encoder.projection.to(device=device, dtype=dtype).eval()
    return processor, condition_encoder
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_movq(
        device: Union[str, torch.device],
        weights_path: Optional[str] = None,
        dtype: Union[str, torch.dtype] = torch.float32,
) -> MoVQ:
    """Build the MoVQ autoencoder and optionally load its checkpoint.

    Args:
        device: Device the model is moved to.
        weights_path: Path to a state dict readable by ``torch.load``. When
            falsy, the model keeps its freshly initialised weights.
        dtype: Dtype the model is cast to.

    Returns:
        The MoVQ model in eval mode.
    """
    generator_config = {
        'double_z': False,
        'z_channels': 4,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 256,
        'ch_mult': [1, 2, 2, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [32],
        'dropout': 0.0
    }
    movq = MoVQ(generator_config)

    if weights_path:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        movq.load_state_dict(state_dict)

    movq.to(device=device, dtype=dtype).eval()
    return movq
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def get_inpainting_unet(
        device: Union[str, torch.device],
        weights_path: Optional[str] = None,
        dtype: Union[str, torch.dtype] = torch.float32,
) -> (UNet, Optional[torch.Tensor]):
    """Build the inpainting UNet and optionally load its checkpoint.

    The configuration matches the text-to-image UNet except for
    ``num_channels=9``.

    Args:
        device: Device the UNet is moved to.
        weights_path: Path to a checkpoint readable by ``torch.load`` that
            holds the keys ``'unet'`` and ``'null_embedding'``. When falsy,
            the UNet keeps its freshly initialised weights.
        dtype: Dtype the UNet is cast to.

    Returns:
        ``(unet, null_embedding)``. ``null_embedding`` is the value stored
        under ``'null_embedding'`` in the checkpoint, or ``None`` when no
        checkpoint was given.
    """
    unet = UNet(
        model_channels=384,
        num_channels=9,
        init_channels=192,
        time_embed_dim=1536,
        context_dim=4096,
        groups=32,
        head_dim=64,
        expansion_ratio=4,
        compression_ratio=2,
        dim_mult=(1, 2, 4, 8),
        num_blocks=(3, 3, 3, 3),
        add_cross_attention=(False, True, True, True),
        add_self_attention=(False, True, True, True),
    )

    null_embedding = None
    if weights_path:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        null_embedding = state_dict['null_embedding']
        unet.load_state_dict(state_dict['unet'])

    unet.to(device=device, dtype=dtype).eval()
    return unet, null_embedding
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
_KANDINSKY3_REPO_ID = "ai-forever/Kandinsky3.1"
_PIPELINE_COMPONENTS = ('unet', 'text_encoder', 'movq')


def _per_component(value) -> dict:
    """Return ``value`` unchanged if it is a dict, else map every component to it."""
    if isinstance(value, dict):
        return value
    return {name: value for name in _PIPELINE_COMPONENTS}


def _build_pipeline_components(
        unet_loader,
        unet_filename: str,
        projection_name: str,
        device_map,
        dtype_map,
        low_cpu_mem_usage: bool,
        load_in_8bit: bool,
        load_in_4bit: bool,
        cache_dir: str,
        unet_path: Optional[str],
        text_encoder_path: Optional[str],
        movq_path: Optional[str],
) -> tuple:
    """Resolve weight locations and load the parts shared by all pipelines.

    Missing paths are downloaded from the Hugging Face Hub into ``cache_dir``.

    Returns:
        ``(device_map, dtype_map, unet, null_embedding, processor,
        condition_encoder, movq)`` in the positional order the pipeline
        constructors are called with.
    """
    device_map = _per_component(device_map)
    dtype_map = _per_component(dtype_map)

    if unet_path is None:
        unet_path = hf_hub_download(
            repo_id=_KANDINSKY3_REPO_ID, filename=unet_filename, cache_dir=cache_dir
        )
    if text_encoder_path is None:
        text_encoder_path = snapshot_download(
            repo_id=_KANDINSKY3_REPO_ID, allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        # NOTE(review): nesting reconstructed from a flattened source; the
        # sub-directory is appended only to the downloaded snapshot root, a
        # caller-supplied path is used as given.
        text_encoder_path = os.path.join(text_encoder_path, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(
            repo_id=_KANDINSKY3_REPO_ID, filename='weights/movq.pt', cache_dir=cache_dir
        )

    unet, null_embedding = unet_loader(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'], text_encoder_path, projection_name, dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])
    return device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq


def get_T2I_pipeline(
        device_map: Union[str, torch.device, dict],
        dtype_map: Union[str, torch.dtype, dict] = torch.float32,
        low_cpu_mem_usage: bool = True,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
        cache_dir: str = '/tmp/kandinsky3/',
        unet_path: Optional[str] = None,
        text_encoder_path: Optional[str] = None,
        movq_path: Optional[str] = None,
) -> Kandinsky3T2IPipeline:
    """Assemble the text-to-image pipeline.

    Args:
        device_map: One device for every component, or a dict with the keys
            ``'unet'``, ``'text_encoder'`` and ``'movq'``.
        dtype_map: One dtype for every component, or a dict with the same keys.
        low_cpu_mem_usage: Forwarded to the text encoder loader.
        load_in_8bit: Forwarded to the text encoder loader.
        load_in_4bit: Forwarded to the text encoder loader.
        cache_dir: Download cache used for any weights not given explicitly.
        unet_path: Local UNet checkpoint; downloaded
            (``weights/kandinsky3.pt``) when ``None``.
        text_encoder_path: Local text encoder directory; downloaded when ``None``.
        movq_path: Local MoVQ checkpoint; downloaded when ``None``.

    Returns:
        A ``Kandinsky3T2IPipeline`` constructed with ``False`` as its last
        positional argument.
    """
    components = _build_pipeline_components(
        get_T2I_unet, 'weights/kandinsky3.pt', 'projection.pt',
        device_map, dtype_map, low_cpu_mem_usage, load_in_8bit, load_in_4bit,
        cache_dir, unet_path, text_encoder_path, movq_path,
    )
    return Kandinsky3T2IPipeline(*components, False)


def get_T2I_Flash_pipeline(
        device_map: Union[str, torch.device, dict],
        dtype_map: Union[str, torch.dtype, dict] = torch.float32,
        low_cpu_mem_usage: bool = True,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
        cache_dir: str = '/tmp/kandinsky3/',
        unet_path: Optional[str] = None,
        text_encoder_path: Optional[str] = None,
        movq_path: Optional[str] = None,
) -> Kandinsky3T2IPipeline:
    """Assemble the "flash" text-to-image pipeline.

    Same as ``get_T2I_pipeline`` except that the default UNet checkpoint is
    ``weights/kandinsky3_flash.pt``, the projection file is
    ``projection_flash.pt`` and the pipeline's last positional argument is
    ``True``.

    Args:
        device_map: One device for every component, or a dict with the keys
            ``'unet'``, ``'text_encoder'`` and ``'movq'``.
        dtype_map: One dtype for every component, or a dict with the same keys.
        low_cpu_mem_usage: Forwarded to the text encoder loader.
        load_in_8bit: Forwarded to the text encoder loader.
        load_in_4bit: Forwarded to the text encoder loader.
        cache_dir: Download cache used for any weights not given explicitly.
        unet_path: Local UNet checkpoint; downloaded when ``None``.
        text_encoder_path: Local text encoder directory; downloaded when ``None``.
        movq_path: Local MoVQ checkpoint; downloaded when ``None``.

    Returns:
        A ``Kandinsky3T2IPipeline``.
    """
    components = _build_pipeline_components(
        get_T2I_unet, 'weights/kandinsky3_flash.pt', 'projection_flash.pt',
        device_map, dtype_map, low_cpu_mem_usage, load_in_8bit, load_in_4bit,
        cache_dir, unet_path, text_encoder_path, movq_path,
    )
    return Kandinsky3T2IPipeline(*components, True)


def get_inpainting_pipeline(
        device_map: Union[str, torch.device, dict],
        dtype_map: Union[str, torch.dtype, dict] = torch.float32,
        low_cpu_mem_usage: bool = True,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
        cache_dir: str = '/tmp/kandinsky3/',
        unet_path: Optional[str] = None,
        text_encoder_path: Optional[str] = None,
        movq_path: Optional[str] = None,
) -> Kandinsky3InpaintingPipeline:
    """Assemble the inpainting pipeline.

    Uses ``get_inpainting_unet``, the default UNet checkpoint
    ``weights/kandinsky3_inpainting.pt`` and the projection file
    ``projection_inpainting.pt``.

    Args:
        device_map: One device for every component, or a dict with the keys
            ``'unet'``, ``'text_encoder'`` and ``'movq'``.
        dtype_map: One dtype for every component, or a dict with the same keys.
        low_cpu_mem_usage: Forwarded to the text encoder loader.
        load_in_8bit: Forwarded to the text encoder loader.
        load_in_4bit: Forwarded to the text encoder loader.
        cache_dir: Download cache used for any weights not given explicitly.
        unet_path: Local UNet checkpoint; downloaded when ``None``.
        text_encoder_path: Local text encoder directory; downloaded when ``None``.
        movq_path: Local MoVQ checkpoint; downloaded when ``None``.

    Returns:
        A ``Kandinsky3InpaintingPipeline``.
    """
    components = _build_pipeline_components(
        get_inpainting_unet, 'weights/kandinsky3_inpainting.pt', 'projection_inpainting.pt',
        device_map, dtype_map, low_cpu_mem_usage, load_in_8bit, load_in_4bit,
        cache_dir, unet_path, text_encoder_path, movq_path,
    )
    return Kandinsky3InpaintingPipeline(*components)
|
kandinsky3/.ipynb_checkpoints/condition_encoders-checkpoint.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from transformers import T5EncoderModel
|
| 4 |
+
from typing import Optional, Union
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class T5TextConditionEncoder(nn.Module):
    """T5 encoder followed by a linear + LayerNorm projection to ``context_dim``."""

    def __init__(
            self, model_path, context_dim,
            low_cpu_mem_usage: bool = True, device: Optional[Union[str, torch.device]] = None,
            dtype: Union[str, torch.dtype] = torch.float32, load_in_4bit: bool = False, load_in_8bit: bool = False
    ):
        """Load the pretrained encoder and create the projection head.

        Args:
            model_path: Passed to ``T5EncoderModel.from_pretrained``.
            context_dim: Output width of the projection.
            low_cpu_mem_usage: Forwarded to ``from_pretrained``.
            device: Forwarded to ``from_pretrained`` as ``device_map``.
            dtype: Forwarded to ``from_pretrained`` as ``torch_dtype``.
            load_in_4bit: Forwarded to ``from_pretrained``.
            load_in_8bit: Forwarded to ``from_pretrained``.
        """
        super().__init__()
        self.encoder = T5EncoderModel.from_pretrained(
            model_path, low_cpu_mem_usage=low_cpu_mem_usage, device_map=device,
            torch_dtype=dtype, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit,
        ).encoder
        self.projection = nn.Sequential(
            nn.Linear(self.encoder.config.d_model, context_dim, bias=False),
            nn.LayerNorm(context_dim)
        )

    def forward(self, model_input):
        """Encode tokenised text into a projected context and its mask.

        Args:
            model_input: Mapping unpacked into the encoder call. If it has an
                ``'attention_mask'`` entry, that mask is also used to zero and
                trim the context.

        Returns:
            ``(context, context_mask)``. With an attention mask, positions
            where the mask is 0 are zeroed and both tensors are cut along
            dim 1 to ``max(mask.sum(-1)) + 1`` entries. Without one, the
            mask is all ones with the shape of the leading embedding dims.
        """
        embeddings = self.encoder(**model_input).last_hidden_state
        context = self.projection(embeddings)
        if 'attention_mask' in model_input:
            context_mask = model_input['attention_mask']
            # Zero the masked-out positions without writing in place into the
            # projection output.
            context = context.masked_fill((context_mask == 0).unsqueeze(-1), 0)
            # Keep one position past the longest unmasked length in the batch.
            max_seq_length = context_mask.sum(-1).max() + 1
            context = context[:, :max_seq_length]
            context_mask = context_mask[:, :max_seq_length]
        else:
            context_mask = torch.ones(*embeddings.shape[:-1], dtype=torch.long, device=embeddings.device)
        return context, context_mask
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_condition_encoder(conf):
    """Build a ``T5TextConditionEncoder`` from a mapping of constructor keyword arguments."""
    return T5TextConditionEncoder(**conf)
|
| 40 |
+
|
kandinsky3/.ipynb_checkpoints/condition_processors-checkpoint.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import T5Tokenizer
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class T5TextConditionProcessor:
    """Tokenise prompts into fixed-length ``input_ids`` / ``attention_mask`` tensors."""

    def __init__(self, tokens_length, processor_path):
        """Store the target length and load the tokenizer.

        Args:
            tokens_length: Length every encoded prompt is truncated/padded to.
            processor_path: Passed to ``T5Tokenizer.from_pretrained``.
        """
        self.tokens_length = tokens_length
        self.processor = T5Tokenizer.from_pretrained(processor_path)

    def encode(self, text=None, negative_text=None):
        """Encode a prompt and, optionally, a negative prompt.

        Args:
            text: Prompt handed to the tokenizer.
            negative_text: Optional negative prompt.

        Returns:
            ``(condition_model_input, negative_condition_model_input)``.
            Each is a dict with 1-D ``torch.long`` tensors ``'input_ids'``
            and ``'attention_mask'`` of length ``tokens_length``; the second
            element is ``None`` when ``negative_text`` is ``None``.
        """
        encoded = self.processor(text, max_length=self.tokens_length, truncation=True)
        pad_length = self.tokens_length - len(encoded['input_ids'])
        input_ids = encoded['input_ids'] + [self.processor.pad_token_id] * pad_length
        attention_mask = encoded['attention_mask'] + [0] * pad_length
        condition_model_input = {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long)
        }

        if negative_text is None:
            return condition_model_input, None

        negative_encoded = self.processor(negative_text, max_length=self.tokens_length, truncation=True)
        # The negative prompt is cut to at most the positive prompt's token
        # count and its last kept token is forced to EOS.
        negative_input_ids = negative_encoded['input_ids'][:len(encoded['input_ids'])]
        negative_input_ids[-1] = self.processor.eos_token_id
        negative_pad_length = self.tokens_length - len(negative_input_ids)
        negative_input_ids = negative_input_ids + [self.processor.pad_token_id] * negative_pad_length
        # NOTE(review): the negative mask is built from the POSITIVE prompt's
        # attention mask, so a shorter negative prompt gets mask 1 on some
        # pad tokens. Kept as in the original - confirm this is intended.
        negative_attention_mask = encoded['attention_mask'] + [0] * pad_length
        negative_condition_model_input = {
            'input_ids': torch.tensor(negative_input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(negative_attention_mask, dtype=torch.long)
        }
        return condition_model_input, negative_condition_model_input
|
kandinsky3/.ipynb_checkpoints/inpainting_pipeline-checkpoint.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union, List
|
| 2 |
+
import PIL
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torchvision.transforms as T
|
| 7 |
+
from einops import repeat
|
| 8 |
+
|
| 9 |
+
from kandinsky3.model.unet import UNet
|
| 10 |
+
from kandinsky3.movq import MoVQ
|
| 11 |
+
from kandinsky3.condition_encoders import T5TextConditionEncoder
|
| 12 |
+
from kandinsky3.condition_processors import T5TextConditionProcessor
|
| 13 |
+
from kandinsky3.model.diffusion import BaseDiffusion, get_named_beta_schedule
|
| 14 |
+
from kandinsky3.utils import resize_image_for_diffusion, resize_mask_for_diffusion
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Kandinsky3InpaintingPipeline:
    """Text-conditioned inpainting built from a UNet, a T5 text encoder and MoVQ."""

    def __init__(
            self,
            device_map: Union[str, torch.device, dict],
            dtype_map: Union[str, torch.dtype, dict],
            unet: UNet,
            null_embedding: torch.Tensor,
            t5_processor: T5TextConditionProcessor,
            t5_encoder: T5TextConditionEncoder,
            movq: MoVQ,
    ):
        """Store the already-loaded components.

        ``device_map`` and ``dtype_map`` are indexed with the keys
        ``'unet'``, ``'text_encoder'`` and ``'movq'`` by the other methods,
        so in practice they must be dicts.
        """
        self.device_map = device_map
        self.dtype_map = dtype_map
        self.to_pil = T.ToPILImage()
        self.to_tensor = T.ToTensor()

        self.unet = unet
        self.null_embedding = null_embedding
        self.t5_processor = t5_processor
        self.t5_encoder = t5_encoder
        self.movq = movq

    def shared_step(self, batch: dict) -> dict:
        """Encode the masked image and the prompts of a prepared batch.

        Args:
            batch: Dict with ``'image'``, ``'mask'``, ``'text'`` and
                ``'negative_text'``; an optional ``'masked_image'`` overrides
                the masked image derived from ``image`` and ``mask``.

        Returns:
            Dict with ``'context'``, ``'context_mask'``,
            ``'negative_context'``, ``'negative_context_mask'``, ``'image'``,
            ``'masked_latent'`` and ``'mask'`` (resized to the latent's
            spatial size).

        Raises:
            ValueError: If no ``'masked_image'`` is given and the UNet input
                layer does not have 9 channels.
        """
        image = batch['image']
        condition_model_input = batch['text']
        negative_condition_model_input = batch['negative_text']
        mask = batch['mask']

        if 'masked_image' in batch:
            masked_latent = batch['masked_image']
        elif self.unet.in_layer.in_channels == 9:
            # Zero the image wherever (1 - mask) is non-zero.
            masked_latent = image.masked_fill((1 - mask).bool(), 0)
        else:
            raise ValueError(
                "batch has no 'masked_image' and the UNet input layer does not have 9 channels"
            )

        with torch.cuda.amp.autocast(dtype=self.dtype_map['movq']):
            masked_latent = self.movq.encode(masked_latent)
            mask = torch.nn.functional.interpolate(mask, size=(masked_latent.shape[2], masked_latent.shape[3]))

        with torch.cuda.amp.autocast(dtype=self.dtype_map['text_encoder']):
            context, context_mask = self.t5_encoder(condition_model_input)

            if negative_condition_model_input is not None:
                negative_context, negative_context_mask = self.t5_encoder(negative_condition_model_input)
            else:
                negative_context, negative_context_mask = None, None

        return {
            'context': context,
            'context_mask': context_mask,
            'negative_context': negative_context,
            'negative_context_mask': negative_context_mask,
            'image': image,
            'masked_latent': masked_latent,
            'mask': mask
        }

    def prepare_batch(
            self,
            text: str,
            negative_text: str,
            image: PIL.Image.Image,
            mask: np.ndarray,
    ) -> dict:
        """Tokenise the prompts and turn image and mask into batched tensors.

        Args:
            text: Prompt.
            negative_text: Negative prompt, or ``None``.
            image: Source image; converted to RGB, resized by
                ``resize_image_for_diffusion`` and rescaled with ``* 2 - 1``.
            mask: Mask passed through ``resize_mask_for_diffusion``; it is
                stored inverted (``1 - mask``).

        Returns:
            Dict with ``'image'``, ``'mask'``, ``'text'`` and
            ``'negative_text'``; every tensor has a leading batch dimension
            of 1 and sits on the device of the component that consumes it.
        """
        condition_model_input, negative_condition_model_input = self.t5_processor.encode(
            text=text, negative_text=negative_text
        )
        batch = {
            'image': self.to_tensor(resize_image_for_diffusion(image.convert("RGB"))) * 2 - 1,
            'mask': 1 - self.to_tensor(resize_mask_for_diffusion(mask)),
            'text': condition_model_input,
            'negative_text': negative_condition_model_input
        }
        batch['mask'] = batch['mask'].type(self.dtype_map['movq'])

        batch['image'] = batch['image'].unsqueeze(0).to(self.device_map['movq'])
        batch['mask'] = batch['mask'].unsqueeze(0).to(self.device_map['movq'])

        text_device = self.device_map['text_encoder']
        for key in ('input_ids', 'attention_mask'):
            batch['text'][key] = batch['text'][key].unsqueeze(0).to(text_device)

        if negative_condition_model_input is not None:
            # The negative prompt needs the same leading batch dimension as
            # the positive one: __call__ later repeats the negative context
            # with the pattern '1 n d -> b n d'.
            for key in ('input_ids', 'attention_mask'):
                batch['negative_text'][key] = batch['negative_text'][key].unsqueeze(0).to(text_device)

        return batch

    def __call__(
            self,
            text: str,
            image: PIL.Image.Image,
            mask: np.ndarray,
            negative_text: str = None,
            images_num: int = 1,
            bs: int = 1,
            steps: int = 50,
            guidance_weight_text: float = 4,
            eta=1.0
    ) -> List[PIL.Image.Image]:
        """Generate ``images_num`` inpainted images.

        Args:
            text: Prompt.
            image: Source image.
            mask: Inpainting mask.
            negative_text: Optional negative prompt.
            images_num: Total number of images to produce.
            bs: Number of images sampled per minibatch.
            steps: Controls the timestep stride (``-1000 // steps``).
            guidance_weight_text: Forwarded to the sampler.
            eta: Forwarded to the sampler.

        Returns:
            List of PIL images.
        """
        with torch.no_grad():
            batch = self.prepare_batch(text, negative_text, image, mask)
            processed = self.shared_step(batch)

        betas = get_named_beta_schedule('cosine', 1000)
        base_diffusion = BaseDiffusion(betas, percentile=0.95)
        times = list(range(999, 0, -1000 // steps))

        pil_images = []
        # k full minibatches of size bs plus one remainder minibatch of size m.
        k, m = images_num // bs, images_num % bs
        for minibatch in [bs] * k + [m]:
            if minibatch == 0:
                continue

            bs_context = repeat(processed['context'], '1 n d -> b n d', b=minibatch)
            bs_context_mask = repeat(processed['context_mask'], '1 n -> b n', b=minibatch)

            if processed['negative_context'] is not None:
                bs_negative_context = repeat(processed['negative_context'], '1 n d -> b n d', b=minibatch)
                bs_negative_context_mask = repeat(processed['negative_context_mask'], '1 n -> b n', b=minibatch)
            else:
                bs_negative_context, bs_negative_context_mask = None, None

            bs_mask = processed['mask'].repeat_interleave(minibatch, dim=0)
            masked_latent = processed['masked_latent'].repeat_interleave(minibatch, dim=0)

            minibatch = masked_latent.shape[0]

            # Inference only: sampling and decoding both run without autograd.
            with torch.no_grad():
                with torch.cuda.amp.autocast(dtype=self.dtype_map['unet']):
                    images = base_diffusion.p_sample_loop(
                        self.unet, (minibatch, 4, masked_latent.shape[2], masked_latent.shape[3]), times,
                        self.device_map['unet'],
                        bs_context, bs_context_mask, self.null_embedding, guidance_weight_text, eta,
                        negative_context=bs_negative_context, negative_context_mask=bs_negative_context_mask,
                        mask=bs_mask, masked_latent=masked_latent, gan=False
                    )

                with torch.cuda.amp.autocast(dtype=self.dtype_map['movq']):
                    # Decode in (up to) two chunks rather than in one call.
                    images = torch.cat([self.movq.decode(chunk) for chunk in images.chunk(2)])
                    images = torch.clip((images + 1.) / 2., 0., 1.).cpu()

            pil_images += [self.to_pil(single) for single in images]

        return pil_images
|
kandinsky3/.ipynb_checkpoints/movq-checkpoint.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from .utils import freeze
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def nonlinearity(x):
    """Return ``x * sigmoid(x)`` element-wise."""
    return x * torch.sigmoid(x)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SpatialNorm(nn.Module):
    """Normalisation layer optionally modulated by a second feature map ``zq``."""

    def __init__(
            self, f_channels, zq_channels=None, norm_layer=nn.GroupNorm, freeze_norm_layer=False, add_conv=False, **norm_layer_params
    ):
        """Create the norm layer and, if ``zq_channels`` is given, the modulation convs.

        Args:
            f_channels: Channel count of the normalised input.
            zq_channels: Channel count of ``zq``; ``None`` creates a plain
                norm layer without any modulation convolutions.
            norm_layer: Norm class, called with ``num_channels=f_channels``
                and ``**norm_layer_params``.
            freeze_norm_layer: If true, the norm layer's parameters get
                ``requires_grad = False``.
            add_conv: If truthy, ``zq`` passes through an extra 3x3 conv
                before modulation.
            **norm_layer_params: Extra keyword arguments for ``norm_layer``.
        """
        super().__init__()
        self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
        # NOTE(review): nesting reconstructed from a flattened source; every
        # zq-related attribute is created only when zq_channels is given.
        if zq_channels is not None:
            if freeze_norm_layer:
                # parameters is a method and must be called; iterating the
                # bound method itself raises TypeError.
                for p in self.norm_layer.parameters():
                    p.requires_grad = False
            self.add_conv = add_conv
            if self.add_conv:
                self.conv = nn.Conv2d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1)
            self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
            self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, f, zq=None):
        """Normalise ``f`` and, if ``zq`` is given, scale and shift it.

        ``zq`` is resized (nearest) to ``f``'s last two dims, optionally
        convolved, then used as ``norm(f) * conv_y(zq) + conv_b(zq)``.
        """
        norm_f = self.norm_layer(f)
        if zq is not None:
            f_size = f.shape[-2:]
            zq = torch.nn.functional.interpolate(zq, size=f_size, mode="nearest")
            if self.add_conv:
                zq = self.conv(zq)
            norm_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        return norm_f
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def Normalize(in_channels, zq_ch=None, add_conv=None):
    """Build the ``SpatialNorm`` variant used by the blocks in this file.

    The underlying normalization is an affine ``nn.GroupNorm`` with 32 groups
    and ``eps=1e-6``; it is never frozen.

    Args:
        in_channels: Channel count of the feature map to normalize.
        zq_ch: Channel count of the optional conditioning tensor.
        add_conv: Forwarded to ``SpatialNorm`` unchanged.
    """
    group_norm_params = dict(num_groups=32, eps=1e-6, affine=True)
    return SpatialNorm(
        in_channels,
        zq_ch,
        norm_layer=nn.GroupNorm,
        freeze_norm_layer=False,
        add_conv=add_conv,
        **group_norm_params,
    )
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class Upsample(nn.Module):
    """Double the spatial size by nearest-neighbour interpolation.

    Args:
        in_channels: Channel count of the input (and output).
        with_conv: If true, a 3x3 convolution is applied after upsampling.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return ``x`` upsampled by a factor of two."""
        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        if not self.with_conv:
            return upsampled
        return self.conv(upsampled)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class Downsample(nn.Module):
    """Halve the spatial size, either with a strided conv or average pooling.

    Args:
        in_channels: Channel count of the input (and output).
        with_conv: If true, use a stride-2 3x3 convolution; otherwise use
            2x2 average pooling.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # The conv itself is unpadded; forward() pads right/bottom only.
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        """Return the downsampled version of ``x``."""
        if not self.with_conv:
            return F.avg_pool2d(x, kernel_size=2, stride=2)
        # One zero column on the right and one zero row at the bottom.
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class ResnetBlock(nn.Module):
    """Residual block: two (norm, swish, 3x3 conv) stages plus a skip path.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count of the output; defaults to ``in_channels``.
        conv_shortcut: When the channel counts differ, use a 3x3 conv on the
            skip path instead of a 1x1 conv.
        dropout: Dropout probability applied before the second conv.
        temb_channels: Size of the embedding accepted by ``forward``; the
            projection layer is only created when this is positive.
        zq_ch: Channel count of the conditioning tensor given to the norms.
        add_conv: Forwarded to ``Normalize``.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512, zq_ch=None, add_conv=False):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # First stage.
        self.norm1 = Normalize(in_channels, zq_ch, add_conv=add_conv)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = nn.Linear(temb_channels, out_channels)

        # Second stage.
        self.norm2 = Normalize(out_channels, zq_ch, add_conv=add_conv)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # The skip path needs a projection only when the channel count changes.
        if in_channels != out_channels:
            if conv_shortcut:
                self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb, zq=None):
        """Run the block.

        Args:
            x: Input feature map.
            temb: Optional embedding added after the first conv (broadcast
                over the two spatial dimensions); ``None`` skips it.
            zq: Optional conditioning tensor passed to both norms.

        Returns:
            ``x`` (projected if needed) plus the residual branch.
        """
        h = self.conv1(nonlinearity(self.norm1(x, zq)))

        if temb is not None:
            temb_bias = self.temb_proj(nonlinearity(temb))
            h = h + temb_bias[:, :, None, None]

        h = self.conv2(self.dropout(nonlinearity(self.norm2(h, zq))))

        if self.in_channels == self.out_channels:
            return x + h
        shortcut = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
        return shortcut(x) + h
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class AttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    Args:
        in_channels: Channel count of the input (and output).
        zq_ch: Channel count of the conditioning tensor given to the norm.
        add_conv: Forwarded to ``Normalize``.
    """

    def __init__(self, in_channels, zq_ch=None, add_conv=False):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels, zq_ch, add_conv=add_conv)
        # Query / key / value / output projections are all 1x1 convolutions.
        pointwise = dict(kernel_size=1, stride=1, padding=0)
        self.q = nn.Conv2d(in_channels, in_channels, **pointwise)
        self.k = nn.Conv2d(in_channels, in_channels, **pointwise)
        self.v = nn.Conv2d(in_channels, in_channels, **pointwise)
        self.proj_out = nn.Conv2d(in_channels, in_channels, **pointwise)

    def forward(self, x, zq=None):
        """Return ``x`` plus the attention output computed from the normed ``x``."""
        normed = self.norm(x, zq)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        # Attention weights: scores[b, i, j] = sum_c q[b, i, c] * k[b, c, j].
        b, c, h, w = q.shape
        tokens = h * w
        q = q.reshape(b, c, tokens).permute(0, 2, 1)  # b, hw, c
        k = k.reshape(b, c, tokens)  # b, c, hw
        scores = torch.bmm(q, k) * (int(c) ** (-0.5))
        weights = F.softmax(scores, dim=2)

        # Weighted sum of values: out[b, c, j] = sum_i v[b, c, i] * weights[b, j, i].
        v = v.reshape(b, c, tokens)
        attended = torch.bmm(v, weights.permute(0, 2, 1)).reshape(b, c, h, w)

        return x + self.proj_out(attended)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class Encoder(nn.Module):
    """Convolutional encoder mapping an image to a latent feature map.

    Structure: an input conv, ``len(ch_mult)`` stages of ``num_res_blocks``
    ``ResnetBlock``s (each followed by an ``AttnBlock`` when the current
    resolution is in ``attn_resolutions``), a ``Downsample`` between stages,
    a middle ResnetBlock/AttnBlock/ResnetBlock group, and an output conv.

    Args:
        ch: Base channel count.
        out_ch: Accepted for signature compatibility; not used by the encoder.
        ch_mult: Per-stage channel multipliers applied to ``ch``.
        num_res_blocks: Number of ResnetBlocks per stage.
        attn_resolutions: Resolutions at which attention blocks are inserted.
        dropout: Dropout probability for the ResnetBlocks.
        resamp_with_conv: Passed to ``Downsample`` as ``with_conv``.
        in_channels: Channel count of the input.
        resolution: Input resolution, used to track ``attn_resolutions``.
        z_channels: Channel count of the latent.
        double_z: If true the output conv emits ``2 * z_channels`` channels.
        **ignore_kwargs: Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, **ignore_kwargs):
        super().__init__()
        self.ch = ch
        # No timestep embedding: ResnetBlocks are built without ``temb_proj``.
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            # Every stage except the last one ends with a downsampling step.
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        """Encode ``x`` and return the output of the final convolution."""
        temb = None

        # downsampling
        # Only the most recent activation is ever consumed, so a single
        # running tensor is kept instead of a list of every intermediate
        # feature map (which would keep all of them alive until return).
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            stage = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = stage.block[i_block](h, temb)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h)
            if i_level != self.num_resolutions-1:
                h = stage.downsample(h)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class Decoder(nn.Module):
    """Convolutional decoder mapping a latent back to an image-sized tensor.

    Every normalization layer is built with ``zq_ch``/``add_conv`` and every
    block receives the ``zq`` tensor passed to ``forward``.

    Args:
        ch: Base channel count.
        out_ch: Channel count of the decoded output.
        ch_mult: Per-stage channel multipliers applied to ``ch``.
        num_res_blocks: Each stage holds ``num_res_blocks + 1`` ResnetBlocks.
        attn_resolutions: Resolutions at which attention blocks are inserted.
        dropout: Dropout probability for the ResnetBlocks.
        resamp_with_conv: Passed to ``Upsample`` as ``with_conv``.
        in_channels: Stored on the module; not otherwise used here.
        resolution: Output resolution, used to derive the lowest resolution.
        z_channels: Channel count of the latent input.
        give_pre_end: If true, ``forward`` returns before the output head.
        zq_ch: Channel count of the conditioning tensor.
        add_conv: Forwarded to the normalization layers.
        **ignorekwargs: Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, zq_ch=None, add_conv=False, **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # Channel count and resolution at the coarsest level.
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # Latent -> coarsest feature map.
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # Middle group.
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch,
            dropout=dropout, zq_ch=zq_ch, add_conv=add_conv,
        )
        self.mid.attn_1 = AttnBlock(block_in, zq_ch, add_conv=add_conv)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch,
            dropout=dropout, zq_ch=zq_ch, add_conv=add_conv,
        )

        # Upsampling stages, built from coarsest to finest.
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                res_blocks.append(ResnetBlock(
                    in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch,
                    dropout=dropout, zq_ch=zq_ch, add_conv=add_conv,
                ))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn_blocks.append(AttnBlock(block_in, zq_ch, add_conv=add_conv))
            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if i_level != 0:
                stage.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            # Prepend so that self.up[i] corresponds to level i.
            self.up.insert(0, stage)

        # Output head.
        self.norm_out = Normalize(block_in, zq_ch, add_conv=add_conv)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z, zq):
        """Decode ``z`` while conditioning every block on ``zq``.

        The shape of ``z`` is recorded in ``self.last_z_shape``; it is not
        checked against ``self.z_shape``.
        """
        self.last_z_shape = z.shape

        # There is no timestep embedding.
        temb = None

        h = self.conv_in(z)

        h = self.mid.block_1(h, temb, zq)
        h = self.mid.attn_1(h, zq)
        h = self.mid.block_2(h, temb, zq)

        for i_level in reversed(range(self.num_resolutions)):
            stage = self.up[i_level]
            for i_block in range(self.num_res_blocks + 1):
                h = stage.block[i_block](h, temb, zq)
                if len(stage.attn) > 0:
                    h = stage.attn[i_block](h, zq)
            if i_level != 0:
                h = stage.upsample(h)

        if self.give_pre_end:
            return h

        return self.conv_out(nonlinearity(self.norm_out(h, zq)))
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
class MoVQ(nn.Module):
    """Encoder/decoder pair with 1x1 convs around the latent.

    Args:
        generator_params: Mapping of keyword arguments for ``Encoder`` and
            ``Decoder``; must contain ``"z_channels"``.
    """

    def __init__(self, generator_params):
        super().__init__()
        latent_channels = generator_params["z_channels"]
        self.encoder = Encoder(**generator_params)
        self.quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
        # The decoder's norms are conditioned on a tensor with z_channels channels.
        self.decoder = Decoder(zq_ch=latent_channels, **generator_params)

    def encode(self, x):
        """Encode ``x`` and apply ``quant_conv`` to the encoder output."""
        return self.quant_conv(self.encoder(x))

    def decode(self, quant):
        """Decode ``quant``.

        ``quant`` is used twice: passed through ``post_quant_conv`` as the
        decoder input, and unchanged as the decoder's conditioning tensor.
        """
        return self.decoder(self.post_quant_conv(quant), quant)
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def get_vae(conf):
    """Build a frozen ``MoVQ`` model from a config object.

    Args:
        conf: Object with ``params`` (passed to ``MoVQ``) and ``checkpoint``
            (path of a state dict to load, or ``None`` to skip loading).

    Returns:
        The ``MoVQ`` instance after ``freeze`` has been applied to it.
    """
    movq = MoVQ(conf.params)
    if conf.checkpoint is not None:
        # Load onto CPU so a checkpoint saved from a GPU can also be read on a
        # machine without one; load_state_dict copies into the model's own
        # parameters afterwards.
        # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
        movq_state_dict = torch.load(conf.checkpoint, map_location=torch.device('cpu'))
        movq.load_state_dict(movq_state_dict)
    movq = freeze(movq)
    return movq
|
kandinsky3/.ipynb_checkpoints/t2i_pipeline-checkpoint.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union, List
|
| 2 |
+
import PIL
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torchvision.transforms as T
|
| 6 |
+
from einops import repeat
|
| 7 |
+
|
| 8 |
+
from kandinsky3.model.unet import UNet
|
| 9 |
+
from kandinsky3.movq import MoVQ
|
| 10 |
+
from kandinsky3.condition_encoders import T5TextConditionEncoder
|
| 11 |
+
from kandinsky3.condition_processors import T5TextConditionProcessor
|
| 12 |
+
from kandinsky3.model.diffusion import BaseDiffusion, get_named_beta_schedule
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Kandinsky3T2IPipeline:
    """Text-to-image pipeline: T5 text encoding, UNet diffusion sampling, MoVQ decoding.

    ``device_map`` and ``dtype_map`` are indexed with the keys ``'unet'``,
    ``'text_encoder'`` and ``'movq'`` inside ``__call__``.
    """

    def __init__(
        self,
        device_map: Union[str, torch.device, dict],
        dtype_map: Union[str, torch.dtype, dict],
        unet: UNet,
        null_embedding: torch.Tensor,
        t5_processor: T5TextConditionProcessor,
        t5_encoder: T5TextConditionEncoder,
        movq: MoVQ,
        gan: bool,
    ):
        self.device_map = device_map
        self.dtype_map = dtype_map
        self.to_pil = T.ToPILImage()

        self.unet = unet
        self.null_embedding = null_embedding
        self.t5_processor = t5_processor
        self.t5_encoder = t5_encoder
        self.movq = movq

        self.gan = gan

    def __call__(
        self,
        text: str,
        negative_text: str = None,
        images_num: int = 1,
        bs: int = 1,
        width: int = 1024,
        height: int = 1024,
        guidance_scale: float = 3.0,
        steps: int = 50,
        eta: float = 1.0
    ) -> List[PIL.Image.Image]:
        """Generate ``images_num`` images for ``text``.

        Args:
            text: Prompt.
            negative_text: Optional negative prompt.
            images_num: Total number of images to produce.
            bs: Maximum number of images sampled per diffusion run.
            width: Image width; the latent width is ``width // 8``.
            height: Image height; the latent height is ``height // 8``.
            guidance_scale: Forwarded to the sampling loop.
            steps: Controls the timestep stride (ignored when ``self.gan``).
            eta: Forwarded to the sampling loop.

        Returns:
            List of PIL images.
        """
        betas = get_named_beta_schedule('cosine', 1000)
        base_diffusion = BaseDiffusion(betas, 0.99)
        times = list(range(999, 0, -1000 // steps))
        if self.gan:
            # The GAN variant uses a fixed schedule of four timesteps.
            times = list(range(979, 0, -250))

        text_device = self.device_map['text_encoder']
        condition_model_input, negative_condition_model_input = self.t5_processor.encode(text, negative_text)
        # Add a leading batch dimension and move each input to the text encoder device.
        for input_type in condition_model_input:
            condition_model_input[input_type] = condition_model_input[input_type][None].to(text_device)

        if negative_condition_model_input is not None:
            for input_type in negative_condition_model_input:
                negative_condition_model_input[input_type] = (
                    negative_condition_model_input[input_type][None].to(text_device)
                )

        pil_images = []
        with torch.no_grad():
            with torch.cuda.amp.autocast(dtype=self.dtype_map['text_encoder']):
                context, context_mask = self.t5_encoder(condition_model_input)
                if negative_condition_model_input is None:
                    negative_context, negative_context_mask = None, None
                else:
                    negative_context, negative_context_mask = self.t5_encoder(negative_condition_model_input)

            # Full batches of size ``bs`` followed by one remainder batch.
            full_batches, remainder = images_num // bs, images_num % bs
            for minibatch in [bs] * full_batches + [remainder]:
                if minibatch == 0:
                    continue

                # Replicate the single encoded prompt across the minibatch.
                bs_context = repeat(context, '1 n d -> b n d', b=minibatch)
                bs_context_mask = repeat(context_mask, '1 n -> b n', b=minibatch)
                if negative_context is None:
                    bs_negative_context, bs_negative_context_mask = None, None
                else:
                    bs_negative_context = repeat(negative_context, '1 n d -> b n d', b=minibatch)
                    bs_negative_context_mask = repeat(negative_context_mask, '1 n -> b n', b=minibatch)

                latent_shape = (minibatch, 4, height // 8, width // 8)
                with torch.cuda.amp.autocast(dtype=self.dtype_map['unet']):
                    images = base_diffusion.p_sample_loop(
                        self.unet, latent_shape, times, self.device_map['unet'],
                        bs_context, bs_context_mask, self.null_embedding, guidance_scale, eta,
                        negative_context=bs_negative_context, negative_context_mask=bs_negative_context_mask,
                        gan=self.gan
                    )

                with torch.cuda.amp.autocast(dtype=self.dtype_map['movq']):
                    # Decode the latents in (up to) two chunks.
                    decoded = [self.movq.decode(image) for image in images.chunk(2)]
                    images = torch.cat(decoded)
                    # Map from [-1, 1] to [0, 1] and clamp.
                    images = torch.clip((images + 1.) / 2., 0., 1.)
                    for images_chunk in images.chunk(1):
                        pil_images.extend(self.to_pil(image) for image in images_chunk)

        return pil_images
|
kandinsky3/.ipynb_checkpoints/utils-checkpoint.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from omegaconf import OmegaConf
|
| 2 |
+
import numpy as np
|
| 3 |
+
from scipy import ndimage
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from skimage.transform import resize
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def load_conf(config_path):
    """Load an OmegaConf config and copy shared settings into its sections.

    Values under ``common`` are copied into the ``data``, ``trainer``,
    ``scheduler`` and ``logger`` sections, and the encoder settings are
    linked with the model section.

    Args:
        config_path: Path handed to ``OmegaConf.load``.

    Returns:
        The loaded config with the derived fields filled in.
    """
    conf = OmegaConf.load(config_path)

    # Data section.
    conf.data.tokens_length = conf.common.tokens_length
    conf.data.processor_names = conf.model.encoders.model_names
    conf.data.dataset.seed = conf.common.seed
    conf.data.dataset.image_size = conf.common.image_size

    # Training length and experiment name.
    train_steps = conf.common.train_steps
    conf.trainer.trainer_params.max_steps = train_steps
    conf.scheduler.params.total_steps = train_steps
    conf.logger.tensorboard.name = conf.common.experiment_name

    # The encoders take the UNet's context dimension.
    conf.model.encoders.context_dim = conf.model.unet_params.context_dim
    return conf
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def freeze(model):
    """Set ``requires_grad = False`` on every parameter of ``model``; return ``model``."""
    for param in model.parameters():
        param.requires_grad = False
    return model
|
| 27 |
+
|
| 28 |
+
def unfreeze(model):
    """Set ``requires_grad = True`` on every parameter of ``model``; return ``model``."""
    for param in model.parameters():
        param.requires_grad = True
    return model
|
| 32 |
+
|
| 33 |
+
def zero_module(module):
    """Fill every parameter of ``module`` with zeros in place; return ``module``."""
    for param in module.parameters():
        nn.init.zeros_(param)
    return module
|
| 37 |
+
|
| 38 |
+
def resize_mask_for_diffusion(mask):
    """Resize ``mask`` so each side is a multiple of 64.

    When ``mask.size`` exceeds ``1024**2`` both sides are first scaled down
    by the same factor, then floored to a multiple of 64.

    Args:
        mask: Array exposing ``size`` and ``shape``; the first two entries of
            ``shape`` are resized.

    Returns:
        The resized mask (value range preserved, no anti-aliasing).
    """
    reduce_factor = max(1, (mask.size / 1024**2)**0.5)
    rows = (round(mask.shape[0] / reduce_factor) // 64) * 64
    cols = (round(mask.shape[1] / reduce_factor) // 64) * 64
    return resize(mask, (rows, cols), preserve_range=True, anti_aliasing=False)
|
| 51 |
+
|
| 52 |
+
def resize_image_for_diffusion(image):
    """Resize ``image`` so each side is a multiple of 64.

    When the pixel count exceeds ``1024**2`` both sides are first scaled down
    by the same factor, then floored to a multiple of 64.

    Args:
        image: Object exposing ``size`` (two entries) and ``resize(size)``.

    Returns:
        Whatever ``image.resize`` returns for the computed size.
    """
    first, second = image.size[0], image.size[1]
    reduce_factor = max(1, (first * second / 1024**2)**0.5)
    target_size = (
        (round(first / reduce_factor) // 64) * 64,
        (round(second / reduce_factor) // 64) * 64,
    )
    return image.resize(target_size)
|
| 59 |
+
|
| 60 |
+
def prepare_mask(mask):
    """Smooth ``mask`` three times with a 5x5 kernel and binarize the result.

    All kernel weights are positive, so for a non-negative mask each pass
    spreads non-zero values by two pixels in every direction.

    Args:
        mask: Array accepted by ``scipy.ndimage.convolve``.

    Returns:
        Integer array that is 1 where the smoothed mask is positive, else 0.
    """
    ker = np.array([[1, 1, 1, 1, 1],
                    [1, 5, 5, 5, 1],
                    [1, 5, 44, 5, 1],
                    [1, 5, 5, 5, 1],
                    [1, 1, 1, 1, 1]]) / 100

    smoothed = mask
    for _ in range(3):
        smoothed = ndimage.convolve(smoothed, ker)

    return (smoothed > 0).astype(int)
|
kandinsky3/__init__.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Optional, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 6 |
+
|
| 7 |
+
from kandinsky3.model.unet import UNet
|
| 8 |
+
from kandinsky3.movq import MoVQ
|
| 9 |
+
from kandinsky3.condition_encoders import T5TextConditionEncoder
|
| 10 |
+
from kandinsky3.condition_processors import T5TextConditionProcessor
|
| 11 |
+
from kandinsky3.model.diffusion import BaseDiffusion, get_named_beta_schedule
|
| 12 |
+
|
| 13 |
+
from .t2i_pipeline import Kandinsky3T2IPipeline
|
| 14 |
+
from .inpainting_pipeline import Kandinsky3InpaintingPipeline
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_T2I_unet(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> (UNet, Optional[torch.Tensor]):
    """Build the text-to-image UNet and optionally load its weights.

    Args:
        device: Device the UNet is moved to.
        weights_path: Path of a checkpoint holding the keys ``'unet'`` (state
            dict) and ``'null_embedding'``. When falsy, nothing is loaded.
        dtype: Dtype the UNet is converted to.

    Returns:
        ``(unet, null_embedding)``: the UNet in eval mode, and the null
        embedding from the checkpoint (``None`` when no weights were loaded).
    """
    unet = UNet(
        model_channels=384,
        num_channels=4,
        init_channels=192,
        time_embed_dim=1536,
        context_dim=4096,
        groups=32,
        head_dim=64,
        expansion_ratio=4,
        compression_ratio=2,
        dim_mult=(1, 2, 4, 8),
        num_blocks=(3, 3, 3, 3),
        add_cross_attention=(False, True, True, True),
        add_self_attention=(False, True, True, True),
    )

    null_embedding = None
    if weights_path:
        # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        null_embedding = state_dict['null_embedding']
        unet.load_state_dict(state_dict['unet'])

    unet.to(device=device, dtype=dtype).eval()
    return unet, null_embedding
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_T5encoder(
    device: Union[str, torch.device],
    weights_path: str,
    projection_name: str,
    dtype: Union[str, torch.dtype] = torch.float32,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
) -> (T5TextConditionProcessor, T5TextConditionEncoder):
    """Create the T5 text processor and encoder and load the projection weights.

    Args:
        device: Device for the encoder and its projection.
        weights_path: Directory given to the processor and encoder; also the
            directory that contains the projection checkpoint.
        projection_name: File name of the projection checkpoint.
        dtype: Dtype for the encoder and its projection.
        low_cpu_mem_usage: Forwarded to the encoder.
        load_in_8bit: Forwarded to the encoder.
        load_in_4bit: Forwarded to the encoder.

    Returns:
        ``(processor, condition_encoder)``.
    """
    max_tokens = 128
    embedding_dim = 4096

    processor = T5TextConditionProcessor(max_tokens, weights_path)
    condition_encoder = T5TextConditionEncoder(
        weights_path,
        embedding_dim,
        low_cpu_mem_usage=low_cpu_mem_usage,
        device=device,
        dtype=dtype,
        load_in_8bit=load_in_8bit,
        load_in_4bit=load_in_4bit,
    )

    if weights_path:
        projection_file = os.path.join(weights_path, projection_name)
        projection_state = torch.load(projection_file, map_location=torch.device('cpu'))
        condition_encoder.projection.load_state_dict(projection_state)

    condition_encoder.projection.to(device=device, dtype=dtype).eval()
    return processor, condition_encoder
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_movq(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> MoVQ:
    """Build the MoVQ autoencoder and optionally load its weights.

    Args:
        device: Device the model is moved to.
        weights_path: Path of a state-dict checkpoint; nothing is loaded when falsy.
        dtype: Dtype the model is converted to.

    Returns:
        The MoVQ model in eval mode.
    """
    generator_config = dict(
        double_z=False,
        z_channels=4,
        resolution=256,
        in_channels=3,
        out_ch=3,
        ch=256,
        ch_mult=[1, 2, 2, 4],
        num_res_blocks=2,
        attn_resolutions=[32],
        dropout=0.0,
    )
    movq = MoVQ(generator_config)

    if weights_path:
        weights = torch.load(weights_path, map_location=torch.device('cpu'))
        movq.load_state_dict(weights)

    movq.to(device=device, dtype=dtype).eval()
    return movq
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def get_inpainting_unet(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> (UNet, Optional[torch.Tensor]):
    """Build the inpainting UNet and optionally load its weights.

    The configuration matches the text-to-image UNet except for
    ``num_channels=9``.

    Args:
        device: Device the UNet is moved to.
        weights_path: Path of a checkpoint holding the keys ``'unet'`` (state
            dict) and ``'null_embedding'``. When falsy, nothing is loaded.
        dtype: Dtype the UNet is converted to.

    Returns:
        ``(unet, null_embedding)``: the UNet in eval mode, and the null
        embedding from the checkpoint (``None`` when no weights were loaded).
    """
    unet = UNet(
        model_channels=384,
        num_channels=9,
        init_channels=192,
        time_embed_dim=1536,
        context_dim=4096,
        groups=32,
        head_dim=64,
        expansion_ratio=4,
        compression_ratio=2,
        dim_mult=(1, 2, 4, 8),
        num_blocks=(3, 3, 3, 3),
        add_cross_attention=(False, True, True, True),
        add_self_attention=(False, True, True, True),
    )

    null_embedding = None
    if weights_path:
        # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        null_embedding = state_dict['null_embedding']
        unet.load_state_dict(state_dict['unet'])

    unet.to(device=device, dtype=dtype).eval()
    return unet, null_embedding
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def get_T2I_pipeline(
    device_map: Union[str, torch.device, dict],
    dtype_map: Union[str, torch.dtype, dict] = torch.float32,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
    cache_dir: str = '/tmp/kandinsky3/',
    unet_path: str = None,
    text_encoder_path: str = None,
    movq_path: str = None,
) -> Kandinsky3T2IPipeline:
    """Assemble the text-to-image pipeline, downloading missing weights.

    Args:
        device_map: One device for all components, or a dict with the keys
            ``'unet'``, ``'text_encoder'`` and ``'movq'``.
        dtype_map: One dtype for all components, or a dict with the same keys.
        low_cpu_mem_usage: Forwarded to the text encoder loader.
        load_in_8bit: Forwarded to the text encoder loader.
        load_in_4bit: Forwarded to the text encoder loader.
        cache_dir: Download cache directory for the Hugging Face Hub calls.
        unet_path: Local UNet checkpoint; downloaded when ``None``.
        text_encoder_path: Local text encoder directory; downloaded when ``None``.
        movq_path: Local MoVQ checkpoint; downloaded when ``None``.

    Returns:
        A ``Kandinsky3T2IPipeline`` constructed with ``gan=False``.
    """
    component_names = ('unet', 'text_encoder', 'movq')
    # A single device/dtype is expanded to one entry per component.
    if not isinstance(device_map, dict):
        device_map = {name: device_map for name in component_names}
    if not isinstance(dtype_map, dict):
        dtype_map = {name: dtype_map for name in component_names}

    hub_repo = "ai-forever/Kandinsky3.1"
    if unet_path is None:
        unet_path = hf_hub_download(repo_id=hub_repo, filename='weights/kandinsky3.pt', cache_dir=cache_dir)
    if text_encoder_path is None:
        snapshot_root = snapshot_download(
            repo_id=hub_repo, allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        text_encoder_path = os.path.join(snapshot_root, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(repo_id=hub_repo, filename='weights/movq.pt', cache_dir=cache_dir)

    unet, null_embedding = get_T2I_unet(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'], text_encoder_path, 'projection.pt', dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])

    return Kandinsky3T2IPipeline(
        device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq, False
    )
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def get_T2I_Flash_pipeline(
        device_map: Union[str, torch.device, dict],
        dtype_map: Union[str, torch.dtype, dict] = torch.float32,
        low_cpu_mem_usage: bool = True,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
        cache_dir: str = '/tmp/kandinsky3/',
        unet_path: str = None,
        text_encoder_path: str = None,
        movq_path: str = None,
) -> Kandinsky3T2IPipeline:
    """Assemble the "flash" text-to-image pipeline.

    Args:
        device_map: One device for every component, or a dict with the keys
            ``'unet'``, ``'text_encoder'`` and ``'movq'``.
        dtype_map: One dtype for every component, or a dict with the same keys.
        low_cpu_mem_usage, load_in_8bit, load_in_4bit: Forwarded to ``get_T5encoder``.
        cache_dir: Cache directory handed to the Hugging Face Hub download helpers.
        unet_path: Local UNet checkpoint; downloaded when ``None``.
        text_encoder_path: Local text-encoder directory; downloaded when ``None``.
        movq_path: Local MoVQ checkpoint; downloaded when ``None``.

    Returns:
        A ``Kandinsky3T2IPipeline`` built from the loaded components. The last
        positional argument is ``True`` here.
        NOTE(review): presumably that flag selects the flash sampling mode -
        confirm against ``Kandinsky3T2IPipeline``.
    """
    # A scalar device/dtype is broadcast to all three components.
    if not isinstance(device_map, dict):
        device_map = {name: device_map for name in ('unet', 'text_encoder', 'movq')}
    if not isinstance(dtype_map, dict):
        dtype_map = {name: dtype_map for name in ('unet', 'text_encoder', 'movq')}

    hub_repo = "ai-forever/Kandinsky3.1"

    # Fetch any weights the caller did not supply (UNet, then encoder, then MoVQ).
    if unet_path is None:
        unet_path = hf_hub_download(
            repo_id=hub_repo, filename='weights/kandinsky3_flash.pt', cache_dir=cache_dir
        )
    if text_encoder_path is None:
        snapshot_root = snapshot_download(
            repo_id=hub_repo, allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        text_encoder_path = os.path.join(snapshot_root, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(
            repo_id=hub_repo, filename='weights/movq.pt', cache_dir=cache_dir
        )

    unet, null_embedding = get_T2I_unet(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'],
        text_encoder_path,
        'projection_flash.pt',
        dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage,
        load_in_8bit=load_in_8bit,
        load_in_4bit=load_in_4bit,
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])

    pipeline = Kandinsky3T2IPipeline(
        device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq, True
    )
    return pipeline
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def get_inpainting_pipeline(
        device_map: Union[str, torch.device, dict],
        dtype_map: Union[str, torch.dtype, dict] = torch.float32,
        low_cpu_mem_usage: bool = True,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
        cache_dir: str = '/tmp/kandinsky3/',
        unet_path: str = None,
        text_encoder_path: str = None,
        movq_path: str = None,
) -> Kandinsky3InpaintingPipeline:
    """Assemble the inpainting pipeline.

    Args:
        device_map: One device for every component, or a dict with the keys
            ``'unet'``, ``'text_encoder'`` and ``'movq'``.
        dtype_map: One dtype for every component, or a dict with the same keys.
        low_cpu_mem_usage, load_in_8bit, load_in_4bit: Forwarded to ``get_T5encoder``.
        cache_dir: Cache directory handed to the Hugging Face Hub download helpers.
        unet_path: Local inpainting UNet checkpoint; downloaded when ``None``.
        text_encoder_path: Local text-encoder directory; downloaded when ``None``.
        movq_path: Local MoVQ checkpoint; downloaded when ``None``.

    Returns:
        A ``Kandinsky3InpaintingPipeline`` built from the loaded components.
    """
    # A scalar device/dtype is broadcast to all three components.
    if not isinstance(device_map, dict):
        device_map = {name: device_map for name in ('unet', 'text_encoder', 'movq')}
    if not isinstance(dtype_map, dict):
        dtype_map = {name: dtype_map for name in ('unet', 'text_encoder', 'movq')}

    hub_repo = "ai-forever/Kandinsky3.1"

    # Fetch any weights the caller did not supply (UNet, then encoder, then MoVQ).
    if unet_path is None:
        unet_path = hf_hub_download(
            repo_id=hub_repo, filename='weights/kandinsky3_inpainting.pt', cache_dir=cache_dir
        )
    if text_encoder_path is None:
        snapshot_root = snapshot_download(
            repo_id=hub_repo, allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        text_encoder_path = os.path.join(snapshot_root, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(
            repo_id=hub_repo, filename='weights/movq.pt', cache_dir=cache_dir
        )

    unet, null_embedding = get_inpainting_unet(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'],
        text_encoder_path,
        'projection_inpainting.pt',
        dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage,
        load_in_8bit=load_in_8bit,
        load_in_4bit=load_in_4bit,
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])

    pipeline = Kandinsky3InpaintingPipeline(
        device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq
    )
    return pipeline
|
kandinsky3/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (5.38 kB). View file
|
|
|
kandinsky3/__pycache__/condition_encoders.cpython-310.pyc
ADDED
|
Binary file (1.84 kB). View file
|
|
|
kandinsky3/__pycache__/condition_processors.cpython-310.pyc
ADDED
|
Binary file (1.46 kB). View file
|
|
|
kandinsky3/__pycache__/inpainting_pipeline.cpython-310.pyc
ADDED
|
Binary file (5.32 kB). View file
|
|
|
kandinsky3/__pycache__/movq.cpython-310.pyc
ADDED
|
Binary file (10 kB). View file
|
|
|
kandinsky3/__pycache__/t2i_pipeline.cpython-310.pyc
ADDED
|
Binary file (3.63 kB). View file
|
|
|
kandinsky3/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (2.31 kB). View file
|
|
|
kandinsky3/condition_encoders.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from transformers import T5EncoderModel
|
| 4 |
+
from typing import Optional, Union
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class T5TextConditionEncoder(nn.Module):
    """T5 encoder followed by a linear projection and LayerNorm into the context width."""

    def __init__(
            self, model_path, context_dim,
            low_cpu_mem_usage: bool = True, device: Optional[str] = None,
            dtype: Union[str, torch.dtype] = torch.float32, load_in_4bit: bool = False, load_in_8bit: bool = False
    ):
        """Load the pretrained encoder from ``model_path`` and build the projection head.

        Args:
            model_path: Location passed to ``T5EncoderModel.from_pretrained``.
            context_dim: Output width of the projection head.
            low_cpu_mem_usage, load_in_4bit, load_in_8bit: Forwarded to ``from_pretrained``.
            device: Forwarded to ``from_pretrained`` as ``device_map``.
            dtype: Forwarded to ``from_pretrained`` as ``torch_dtype``.
        """
        super().__init__()
        pretrained = T5EncoderModel.from_pretrained(
            model_path,
            low_cpu_mem_usage=low_cpu_mem_usage,
            device_map=device,
            torch_dtype=dtype,
            load_in_8bit=load_in_8bit,
            load_in_4bit=load_in_4bit,
        )
        # Only the encoder stack of the loaded model is kept.
        self.encoder = pretrained.encoder
        encoder_width = self.encoder.config.d_model
        self.projection = nn.Sequential(
            nn.Linear(encoder_width, context_dim, bias=False),
            nn.LayerNorm(context_dim),
        )

    def forward(self, model_input):
        """Encode a tokenized prompt.

        Args:
            model_input: Mapping unpacked into the encoder call; an optional
                ``'attention_mask'`` entry marks real tokens with non-zero values.

        Returns:
            ``(context, context_mask)``. When a mask is supplied, masked positions
            of the context are zeroed and both tensors are cut along dim 1 to the
            longest masked length in the batch plus one.
        """
        hidden = self.encoder(**model_input).last_hidden_state
        context = self.projection(hidden)

        if 'attention_mask' not in model_input:
            # No mask given: every position counts as a real token.
            everything = torch.ones(*hidden.shape[:-1], dtype=torch.long, device=hidden.device)
            return context, everything

        context_mask = model_input['attention_mask']
        is_padding = context_mask == 0
        context[is_padding] = torch.zeros_like(context[is_padding])
        # One extra position beyond the longest prompt is kept on purpose.
        keep = context_mask.sum(-1).max() + 1
        return context[:, :keep], context_mask[:, :keep]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_condition_encoder(conf):
    """Build a ``T5TextConditionEncoder`` from ``conf``, unpacked as keyword arguments."""
    encoder = T5TextConditionEncoder(**conf)
    return encoder
|
| 40 |
+
|
kandinsky3/condition_processors.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import T5Tokenizer
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class T5TextConditionProcessor:
    """Tokenizes prompts into fixed-length ``input_ids`` / ``attention_mask`` tensors."""

    def __init__(self, tokens_length, processor_path):
        """Store the target length and load the tokenizer from ``processor_path``."""
        self.tokens_length = tokens_length
        self.processor = T5Tokenizer.from_pretrained(processor_path)

    def encode(self, text=None, negative_text=None):
        """Tokenize ``text`` (and optionally ``negative_text``), padding to ``tokens_length``.

        Returns:
            ``(condition_model_input, negative_condition_model_input)``. Each is a
            dict with long tensors ``'input_ids'`` and ``'attention_mask'``; the
            second is ``None`` when ``negative_text`` is ``None``.
        """
        tokenized = self.processor(text, max_length=self.tokens_length, truncation=True)
        prompt_ids = tokenized['input_ids']
        prompt_mask = tokenized['attention_mask']
        n_pad = self.tokens_length - len(prompt_ids)

        padded_mask = prompt_mask + [0] * n_pad
        condition_model_input = {
            'input_ids': torch.tensor(prompt_ids + [self.processor.pad_token_id] * n_pad, dtype=torch.long),
            'attention_mask': torch.tensor(padded_mask, dtype=torch.long),
        }

        if negative_text is None:
            return condition_model_input, None

        negative_tokenized = self.processor(negative_text, max_length=self.tokens_length, truncation=True)
        # The negative prompt is cut to the positive prompt's token count and
        # re-terminated with EOS.
        negative_ids = negative_tokenized['input_ids'][:len(prompt_ids)]
        negative_ids[-1] = self.processor.eos_token_id
        n_negative_pad = self.tokens_length - len(negative_ids)
        negative_ids = negative_ids + [self.processor.pad_token_id] * n_negative_pad
        # NOTE(review): the negative mask is the positive prompt's mask, not the
        # negative prompt's own - presumably so both contexts trim to the same
        # length downstream; confirm before changing.
        negative_condition_model_input = {
            'input_ids': torch.tensor(negative_ids, dtype=torch.long),
            'attention_mask': torch.tensor(prompt_mask + [0] * n_pad, dtype=torch.long),
        }
        return condition_model_input, negative_condition_model_input
|
kandinsky3/inpainting_pipeline.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union, List
|
| 2 |
+
import PIL
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torchvision.transforms as T
|
| 7 |
+
from einops import repeat
|
| 8 |
+
|
| 9 |
+
from kandinsky3.model.unet import UNet
|
| 10 |
+
from kandinsky3.movq import MoVQ
|
| 11 |
+
from kandinsky3.condition_encoders import T5TextConditionEncoder
|
| 12 |
+
from kandinsky3.condition_processors import T5TextConditionProcessor
|
| 13 |
+
from kandinsky3.model.diffusion import BaseDiffusion, get_named_beta_schedule
|
| 14 |
+
from kandinsky3.utils import resize_image_for_diffusion, resize_mask_for_diffusion
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Kandinsky3InpaintingPipeline:
    """Text-guided inpainting: T5 prompt encoding, UNet diffusion sampling, MoVQ decoding.

    ``device_map`` and ``dtype_map`` are indexed with the keys ``'unet'``,
    ``'text_encoder'`` and ``'movq'`` by the methods below, so the dict form is
    what the pipeline actually requires.
    """

    def __init__(
        self,
        device_map: Union[str, torch.device, dict],
        dtype_map: Union[str, torch.dtype, dict],
        unet: UNet,
        null_embedding: torch.Tensor,
        t5_processor: T5TextConditionProcessor,
        t5_encoder: T5TextConditionEncoder,
        movq: MoVQ,
    ):
        """Store the already-loaded components and the PIL/tensor converters."""
        self.device_map = device_map
        self.dtype_map = dtype_map
        self.to_pil = T.ToPILImage()
        self.to_tensor = T.ToTensor()

        self.unet = unet
        self.null_embedding = null_embedding
        self.t5_processor = t5_processor
        self.t5_encoder = t5_encoder
        self.movq = movq

    def shared_step(self, batch: dict) -> dict:
        """Encode the masked image and the prompts of a prepared batch.

        Args:
            batch: Dict with ``'image'``, ``'mask'``, ``'text'`` and
                ``'negative_text'``; an optional ``'masked_image'`` overrides the
                masked image computed here.

        Returns:
            Dict with ``'context'``, ``'context_mask'``, ``'negative_context'``,
            ``'negative_context_mask'``, ``'image'``, ``'masked_latent'`` and
            ``'mask'`` (the mask resized to the latent's spatial size).

        Raises:
            ValueError: If no ``'masked_image'`` is given and the UNet input
                layer does not take 9 channels.
        """
        image = batch['image']
        condition_model_input = batch['text']
        negative_condition_model_input = batch['negative_text']
        mask = batch['mask']

        if 'masked_image' in batch:
            masked_latent = batch['masked_image']
        elif self.unet.in_layer.in_channels == 9:
            # Zero the image wherever (1 - mask) is non-zero.
            masked_latent = image.masked_fill((1 - mask).bool(), 0)
        else:
            raise ValueError(
                "batch has no 'masked_image' and the UNet input layer does not take 9 channels"
            )

        with torch.cuda.amp.autocast(dtype=self.dtype_map['movq']):
            masked_latent = self.movq.encode(masked_latent)
        mask = torch.nn.functional.interpolate(mask, size=(masked_latent.shape[2], masked_latent.shape[3]))

        with torch.cuda.amp.autocast(dtype=self.dtype_map['text_encoder']):
            context, context_mask = self.t5_encoder(condition_model_input)

            if negative_condition_model_input is not None:
                negative_context, negative_context_mask = self.t5_encoder(negative_condition_model_input)
            else:
                negative_context, negative_context_mask = None, None

        return {
            'context': context,
            'context_mask': context_mask,
            'negative_context': negative_context,
            'negative_context_mask': negative_context_mask,
            'image': image,
            'masked_latent': masked_latent,
            'mask': mask
        }

    def prepare_batch(
        self,
        text: str,
        negative_text: str,
        image: PIL.Image.Image,
        mask: np.ndarray,
    ) -> dict:
        """Turn raw inputs into a single-sample batch placed on the right devices.

        The image is converted to RGB, resized, and rescaled with ``* 2 - 1``;
        the mask is resized and inverted (``1 - mask``). Every tensor receives a
        leading batch dimension of size 1.
        """
        condition_model_input, negative_condition_model_input = self.t5_processor.encode(
            text=text, negative_text=negative_text
        )
        batch = {
            'image': self.to_tensor(resize_image_for_diffusion(image.convert("RGB"))) * 2 - 1,
            'mask': 1 - self.to_tensor(resize_mask_for_diffusion(mask)),
            'text': condition_model_input,
            'negative_text': negative_condition_model_input
        }
        batch['mask'] = batch['mask'].type(self.dtype_map['movq'])

        batch['image'] = batch['image'].unsqueeze(0).to(self.device_map['movq'])
        batch['text']['input_ids'] = batch['text']['input_ids'].unsqueeze(0).to(self.device_map['text_encoder'])
        batch['text']['attention_mask'] = batch['text']['attention_mask'].unsqueeze(0).to(
            self.device_map['text_encoder'])
        batch['mask'] = batch['mask'].unsqueeze(0).to(self.device_map['movq'])

        if negative_condition_model_input is not None:
            # The negative prompt needs the same leading batch axis as the
            # positive one: the encoder output is later repeated with the
            # pattern '1 n d -> b n d'.
            batch['negative_text']['input_ids'] = batch['negative_text']['input_ids'].unsqueeze(0).to(
                self.device_map['text_encoder'])
            batch['negative_text']['attention_mask'] = batch['negative_text']['attention_mask'].unsqueeze(0).to(
                self.device_map['text_encoder'])

        return batch

    def __call__(
        self,
        text: str,
        image: PIL.Image.Image,
        mask: np.ndarray,
        negative_text: str = None,
        images_num: int = 1,
        bs: int = 1,
        steps: int = 50,
        guidance_weight_text: float = 4,
        eta=1.0
    ) -> List[PIL.Image.Image]:
        """Generate ``images_num`` inpainted images for ``text``.

        Args:
            text: Prompt describing the content to paint.
            image: Source image.
            mask: Mask array passed to ``resize_mask_for_diffusion``.
            negative_text: Optional negative prompt.
            images_num: Total number of images to produce.
            bs: Maximum number of images sampled per diffusion run.
            steps: Controls the stride of the timestep list
                ``range(999, 0, -1000 // steps)``.
            guidance_weight_text: Text guidance weight passed to the sampler.
            eta: Noise scale passed to the sampler.

        Returns:
            List of PIL images, ``images_num`` long.
        """
        # Inference only: the whole call runs without autograd tracking.
        with torch.no_grad():
            batch = self.prepare_batch(text, negative_text, image, mask)
            processed = self.shared_step(batch)

            betas = get_named_beta_schedule('cosine', 1000)
            base_diffusion = BaseDiffusion(betas, percentile=0.95)
            times = list(range(999, 0, -1000 // steps))

            pil_images = []
            # k full minibatches of size bs, then one remainder of size m.
            k, m = images_num // bs, images_num % bs
            for minibatch in [bs] * k + [m]:
                if minibatch == 0:
                    continue

                bs_context = repeat(processed['context'], '1 n d -> b n d', b=minibatch)
                bs_context_mask = repeat(processed['context_mask'], '1 n -> b n', b=minibatch)

                if processed['negative_context'] is not None:
                    bs_negative_context = repeat(processed['negative_context'], '1 n d -> b n d', b=minibatch)
                    bs_negative_context_mask = repeat(processed['negative_context_mask'], '1 n -> b n', b=minibatch)
                else:
                    bs_negative_context, bs_negative_context_mask = None, None

                bs_mask = processed['mask'].repeat_interleave(minibatch, dim=0)
                masked_latent = processed['masked_latent'].repeat_interleave(minibatch, dim=0)

                minibatch = masked_latent.shape[0]

                with torch.cuda.amp.autocast(dtype=self.dtype_map['unet']):
                    images = base_diffusion.p_sample_loop(
                        self.unet, (minibatch, 4, masked_latent.shape[2], masked_latent.shape[3]), times,
                        self.device_map['unet'],
                        bs_context, bs_context_mask, self.null_embedding, guidance_weight_text, eta,
                        negative_context=bs_negative_context, negative_context_mask=bs_negative_context_mask,
                        mask=bs_mask, masked_latent=masked_latent, gan=False
                    )

                with torch.cuda.amp.autocast(dtype=self.dtype_map['movq']):
                    # Decode in up to two chunks, then rescale with (x + 1) / 2
                    # and clip to [0, 1].
                    images = torch.cat([self.movq.decode(latents) for latents in images.chunk(2)])
                    images = torch.clip((images + 1.) / 2., 0., 1.).cpu()

                pil_images.extend(self.to_pil(decoded) for decoded in images)

        return pil_images
|
kandinsky3/model/.ipynb_checkpoints/diffusion-checkpoint.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
from .utils import get_tensor_items
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_named_beta_schedule(schedule_name, timesteps):
    """Return a float32 tensor of ``timesteps`` betas for the named schedule.

    Args:
        schedule_name: ``"linear"`` or ``"cosine"``.
        timesteps: Number of diffusion steps.

    Returns:
        1-D ``torch.float32`` tensor of length ``timesteps``.

    Raises:
        NotImplementedError: If ``schedule_name`` is not a known schedule
            (previously this fell through and returned ``None``).
    """
    if schedule_name == "linear":
        # The endpoints are scaled by 1000 / timesteps.
        scale = 1000 / timesteps
        return torch.linspace(
            scale * 0.0001, scale * 0.02, timesteps, dtype=torch.float32
        )

    if schedule_name == "cosine":
        def alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at 0.999.
        betas = [
            min(1 - alpha_bar((i + 1) / timesteps) / alpha_bar(i / timesteps), 0.999)
            for i in range(timesteps)
        ]
        return torch.tensor(betas, dtype=torch.float32)

    raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class BaseDiffusion:
    """Diffusion schedule tensors derived from ``betas`` plus text-guided sampling."""

    def __init__(self, betas, percentile=None, gen_noise=torch.randn_like):
        """Precompute the schedule tensors derived from ``betas``.

        Args:
            betas: 1-D tensor of per-step betas.
            percentile: Quantile used by ``process_x_start``; ``None`` selects a
                plain clip to [-1, 1].
            gen_noise: Callable used by ``q_sample`` to draw noise shaped like
                its argument.
        """
        self.betas = betas
        self.num_timesteps = betas.shape[0]

        alphas = 1. - betas
        cumprod = torch.cumprod(alphas, dim=0)
        cumprod_prev = torch.cat([torch.ones(1, dtype=betas.dtype), cumprod[:-1]])
        self.alphas_cumprod = cumprod
        self.alphas_cumprod_prev = cumprod_prev

        # Terms of q(x_t | x_0).
        self.sqrt_alphas_cumprod = torch.sqrt(cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - cumprod)

        # Terms of q(x_{t-1} | x_t, x_0).
        self.posterior_mean_coef_1 = torch.sqrt(cumprod_prev) * betas / (1. - cumprod)
        self.posterior_mean_coef_2 = torch.sqrt(alphas) * (1. - cumprod_prev) / (1. - cumprod)
        self.posterior_variance = betas * (1. - cumprod_prev) / (1. - cumprod)
        # Entry 0 of the variance is replaced by entry 1 before taking the log.
        patched_variance = torch.cat(
            [self.posterior_variance[1].unsqueeze(0), self.posterior_variance[1:]]
        )
        self.posterior_log_variance = torch.log(patched_variance)

        self.percentile = percentile
        self.time_scale = 1000 // self.num_timesteps
        self.gen_noise = gen_noise
        self.jump_length = 3

    def process_x_start(self, x_start):
        """Clip a predicted ``x_0``.

        With ``percentile`` set, each sample is clipped to its own
        ``percentile`` quantile of absolute values (floored at 1) and divided by
        it; otherwise values are clipped to [-1, 1].
        """
        bs, ndims = x_start.shape[0], len(x_start.shape[1:])
        if self.percentile is None:
            return torch.clip(x_start, -1., 1.)

        magnitudes = rearrange(x_start, 'b ... -> b (...)').abs()
        bound = torch.quantile(magnitudes, self.percentile, dim=-1)
        bound = torch.clip(bound, min=1.)
        bound = bound.reshape(bs, *((1,) * ndims))
        return torch.clip(x_start, -bound, bound) / bound

    def get_x_start(self, x, t, noise):
        """Recover ``x_0`` from ``x_t`` and a noise estimate."""
        noise_scale = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
        signal_scale = get_tensor_items(self.sqrt_alphas_cumprod, t, noise.shape)
        return (x - noise_scale * noise) / signal_scale

    def get_noise(self, x, t, x_start):
        """Recover the noise from ``x_t`` and an ``x_0`` estimate."""
        noise_scale = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        signal_scale = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
        return (x - signal_scale * x_start) / noise_scale

    def q_sample(self, x_start, t, noise=None):
        """Sample ``x_t`` from ``x_0``; noise comes from ``gen_noise`` when not given."""
        if noise is None:
            noise = self.gen_noise(x_start)
        signal_scale = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_scale = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
        return signal_scale * x_start + noise_scale * noise

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Return mean, variance and log-variance of ``q(x_{t-1} | x_t, x_0)``."""
        coef_start = get_tensor_items(self.posterior_mean_coef_1, t, x_start.shape)
        coef_current = get_tensor_items(self.posterior_mean_coef_2, t, x_t.shape)
        posterior_mean = coef_start * x_start + coef_current * x_t

        posterior_variance = get_tensor_items(self.posterior_variance, t, x_start.shape)
        posterior_log_variance = get_tensor_items(self.posterior_log_variance, t, x_start.shape)
        return posterior_mean, posterior_variance, posterior_log_variance

    def q_posterior_variance(self, t, prev_t, shape, eta=1., ):
        """Return the noise scale for a ``t -> prev_t`` step.

        Despite the name, the value returned is a square root; ``eta`` is
        applied inside it.
        """
        current = get_tensor_items(self.alphas_cumprod, t, shape)
        previous = get_tensor_items(self.alphas_cumprod, prev_t, shape)
        return torch.sqrt(
            eta * (1. - current / previous) * (1. - previous) / (1. - current)
        )

    def text_guidance(
            self, model, x, t, context, context_mask, null_embedding, guidance_weight_text,
            uncondition_context=None, uncondition_context_mask=None, mask=None, masked_latent=None
    ):
        """Predict noise with text guidance from one doubled-batch model call.

        The first half of the doubled batch carries ``context``, the second the
        unconditional context. When no unconditional context is supplied, it is
        all zeros with ``null_embedding`` at position 0.
        """
        doubled_x = x.repeat(2, 1, 1, 1)
        doubled_t = t.repeat(2).to(x.dtype)

        if uncondition_context is None:
            uncondition_context = torch.zeros_like(context)
            uncondition_context_mask = torch.zeros_like(context_mask)
            uncondition_context[:, 0] = null_embedding
            uncondition_context_mask[:, 0] = 1
        doubled_context = torch.cat([context, uncondition_context])
        doubled_context_mask = torch.cat([context_mask, uncondition_context_mask])

        if mask is not None:
            mask = mask.repeat(2, 1, 1, 1)
        if masked_latent is not None:
            masked_latent = masked_latent.repeat(2, 1, 1, 1)

        # A 9-channel input layer receives latent, mask and masked latent stacked on dim 1.
        if model.in_layer.in_channels == 9:
            doubled_x = torch.cat([doubled_x, mask, masked_latent], dim=1)

        doubled_noise = model(
            doubled_x, doubled_t * self.time_scale, doubled_context, doubled_context_mask.bool()
        )
        cond_noise, uncond_noise = torch.chunk(doubled_noise, 2)
        return (guidance_weight_text + 1.) * cond_noise - guidance_weight_text * uncond_noise

    def p_mean_variance(
            self, model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
            negative_context=None, negative_context_mask=None, mask=None, masked_latent=None
    ):
        """Return the mean and noise scale of the ``t -> prev_t`` reverse step."""
        guided_noise = self.text_guidance(
            model, x, t, context, context_mask, null_embedding, guidance_weight_text,
            negative_context, negative_context_mask, mask, masked_latent
        )

        # Clip the implied x_0, then recompute the noise consistent with it.
        pred_x_start = self.process_x_start(self.get_x_start(x, t, guided_noise))
        consistent_noise = self.get_noise(x, t, pred_x_start)
        pred_var = self.q_posterior_variance(t, prev_t, x.shape, eta)

        previous = get_tensor_items(self.alphas_cumprod, prev_t, x.shape)
        pred_mean = torch.sqrt(previous) * pred_x_start
        pred_mean += torch.sqrt(1. - previous - pred_var ** 2) * consistent_noise
        return pred_mean, pred_var

    def p_sample(
            self, model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
            negative_context=None, negative_context_mask=None, mask=None, masked_latent=None
    ):
        """Draw ``x_{prev_t}`` from ``x_t``; no noise is added where ``prev_t == 0``."""
        batch_size = x.shape[0]
        trailing_dims = len(x.shape[1:])
        pred_mean, pred_var = self.p_mean_variance(
            model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta,
            negative_context=negative_context, negative_context_mask=negative_context_mask,
            mask=mask, masked_latent=masked_latent
        )
        fresh_noise = torch.randn_like(x)
        add_noise = (prev_t != 0).reshape(batch_size, *((1,) * trailing_dims))
        return pred_mean + add_noise * pred_var * fresh_noise

    def p_sample_loop(
            self, model, shape, times, device, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
            negative_context=None, negative_context_mask=None, mask=None, masked_latent=None, gan=False,
    ):
        """Run the sampling loop over ``times`` (a list; 0 is appended) from random noise.

        With ``gan=True`` each step re-noises the current image to the step's
        time and takes the model's implied ``x_0`` directly; otherwise each step
        calls ``p_sample``.
        """
        img = torch.randn(*shape, device=device)
        boundaries = times + [0, ]
        steps = list(zip(boundaries[:-1], boundaries[1:]))

        for current, previous in tqdm(steps):
            time = torch.tensor([current] * shape[0], device=device)
            if gan:
                x_t = self.q_sample(img, time)
                pred_noise = model(x_t, time.type(x_t.dtype), context, context_mask.bool())
                img = self.get_x_start(x_t, time, pred_noise)
                continue

            prev_time = torch.tensor([previous] * shape[0], device=device)
            img = self.p_sample(
                model, img, time, prev_time, context, context_mask, null_embedding, guidance_weight_text, eta,
                negative_context=negative_context, negative_context_mask=negative_context_mask,
                mask=mask, masked_latent=masked_latent
            )
        return img
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def get_diffusion(conf):
    """Build a ``BaseDiffusion`` from ``conf.schedule_params`` and ``conf.diffusion_params``."""
    schedule = get_named_beta_schedule(**conf.schedule_params)
    return BaseDiffusion(schedule, **conf.diffusion_params)
|
kandinsky3/model/.ipynb_checkpoints/unet-checkpoint.py
ADDED
|
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn, einsum
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
|
| 5 |
+
from .nn import Identity, Attention, SinusoidalPosEmb, ConditionalGroupNorm
|
| 6 |
+
from .utils import exist, set_default_item, set_default_layer
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Block(nn.Module):
    """Time-conditioned norm -> SiLU -> [up-sample] -> conv -> [down-sample].

    ``up_resolution`` chooses the resampling slot that is requested from
    ``set_default_layer``: a truthy value requests the stride-2 transposed
    convolution in front of the projection, ``False`` requests the stride-2
    convolution behind it, and ``None`` requests neither.  What is placed in
    an unrequested slot is decided by ``set_default_layer`` (presumably an
    identity layer -- confirm in ``.utils``).
    """

    def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
        super().__init__()
        resample_given = exist(up_resolution)

        self.group_norm = ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
        self.activation = nn.SiLU()
        self.up_sample = set_default_layer(
            resample_given and up_resolution,
            nn.ConvTranspose2d, (in_channels, in_channels), dict(kernel_size=2, stride=2)
        )
        # A 1x1 kernel needs no padding; every other kernel size is padded by 1.
        self.projection = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=kernel_size,
            padding=set_default_item(kernel_size == 1, 0, 1),
        )
        self.down_sample = set_default_layer(
            resample_given and not up_resolution,
            nn.Conv2d, (out_channels, out_channels), dict(kernel_size=2, stride=2)
        )

    def forward(self, x, time_embed):
        """Apply the block to ``x``, conditioning the norm on ``time_embed``."""
        hidden = self.activation(self.group_norm(x, time_embed))
        hidden = self.projection(self.up_sample(hidden))
        return self.down_sample(hidden)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ResNetBlock(nn.Module):
    """Bottleneck residual block made of four ``Block`` layers (1x1, 3x3, 3x3, 1x1).

    The inner width is ``max(in_channels, out_channels) // compression_ratio``.
    ``up_resolutions`` carries one entry per inner ``Block`` (``True`` = up,
    ``False`` = down, ``None`` = keep); the shortcut path requests a matching
    resampling layer whenever any entry is ``True`` / ``False`` so that both
    branches can be summed.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        time_embed_dim: width of the conditioning embedding given to each ``Block``.
        norm_groups: group count forwarded to the conditional group norms.
        compression_ratio: divisor used for the bottleneck width.
        up_resolutions: per-``Block`` resampling flags.  The default is an
            immutable tuple so it cannot be shared-and-mutated across instances.
    """

    def __init__(
        self, in_channels, out_channels, time_embed_dim, norm_groups=32, compression_ratio=2,
        up_resolutions=4 * (None,)
    ):
        super().__init__()
        kernel_sizes = [1, 3, 3, 1]
        hidden_channel = max(in_channels, out_channels) // compression_ratio
        hidden_channels = (
            [(in_channels, hidden_channel)]
            + [(hidden_channel, hidden_channel)] * 2
            + [(hidden_channel, out_channels)]
        )
        self.resnet_blocks = nn.ModuleList([
            Block(in_channel, out_channel, time_embed_dim, kernel_size, norm_groups, up_resolution)
            for (in_channel, out_channel), kernel_size, up_resolution
            in zip(hidden_channels, kernel_sizes, up_resolutions)
        ])

        # The shortcut has to follow every shape change made by the main branch.
        self.shortcut_up_sample = set_default_layer(
            True in up_resolutions,
            nn.ConvTranspose2d, (in_channels, in_channels), {'kernel_size': 2, 'stride': 2}
        )
        self.shortcut_projection = set_default_layer(
            in_channels != out_channels,
            nn.Conv2d, (in_channels, out_channels), {'kernel_size': 1}
        )
        self.shortcut_down_sample = set_default_layer(
            False in up_resolutions,
            nn.Conv2d, (out_channels, out_channels), {'kernel_size': 2, 'stride': 2}
        )

    def forward(self, x, time_embed):
        """Return ``shortcut(x) + blocks(x)``."""
        out = x
        for resnet_block in self.resnet_blocks:
            out = resnet_block(out, time_embed)

        x = self.shortcut_up_sample(x)
        x = self.shortcut_projection(x)
        x = self.shortcut_down_sample(x)
        x = x + out
        return x
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class AttentionPolling(nn.Module):
    """Pool a context sequence with attention and add the result to ``x``.

    The query is the mean of ``context`` over dim 1 (kept as a length-1
    sequence); the attention output has that length-1 axis squeezed away
    before being added to ``x``.
    """

    def __init__(self, num_channels, context_dim, head_dim=64):
        super().__init__()
        self.attention = Attention(context_dim, num_channels, context_dim, head_dim)

    def forward(self, x, context, context_mask=None):
        """Return ``x`` plus the attention-pooled summary of ``context``."""
        mean_query = context.mean(dim=1, keepdim=True)
        pooled = self.attention(mean_query, context, context_mask)
        return x + pooled.squeeze(1)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class AttentionBlock(nn.Module):
    """Residual attention followed by a residual 1x1-conv feed-forward.

    When ``context_dim`` is falsy the attention layer is built with
    ``num_channels`` as its context width, and when no ``context`` is passed
    to ``forward`` the flattened feature map itself is used as context.
    """

    def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
        super().__init__()
        attention_context_dim = context_dim or num_channels
        self.in_norm = ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
        self.attention = Attention(num_channels, num_channels, attention_context_dim, head_dim)

        inner_channels = expansion_ratio * num_channels
        self.out_norm = ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Conv2d(num_channels, inner_channels, kernel_size=1, bias=False),
            nn.SiLU(),
            nn.Conv2d(inner_channels, num_channels, kernel_size=1, bias=False),
        )

    def forward(self, x, time_embed, context=None, context_mask=None):
        """Run the attention and feed-forward sub-blocks on a ``b c h w`` map."""
        height, width = x.shape[-2:]

        # Attention operates on a (batch, h*w, channels) token sequence.
        tokens = rearrange(self.in_norm(x, time_embed), 'b c h w -> b (h w) c', h=height, w=width)
        context = set_default_item(exist(context), context, tokens)
        attended = self.attention(tokens, context, context_mask)
        x = x + rearrange(attended, 'b (h w) c -> b c h w', h=height, w=width)

        return x + self.feed_forward(self.out_norm(x, time_embed))
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class DownSampleBlock(nn.Module):
    """Encoder stage: optional self-attention, then ``num_blocks`` triples of
    (ResNetBlock, cross-attention-or-Identity, ResNetBlock).

    Only the last triple carries a down-sampling flag, and only when
    ``down_sample`` is truthy.  Cross-attention is requested when
    ``context_dim`` is given; self-attention when ``self_attention`` is truthy.
    """

    def __init__(
        self, in_channels, out_channels, time_embed_dim, context_dim=None,
        num_blocks=3, groups=32, head_dim=64, expansion_ratio=4, compression_ratio=2,
        down_sample=True, self_attention=True
    ):
        super().__init__()
        self.self_attention_block = set_default_layer(
            self_attention,
            AttentionBlock,
            (in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
            layer_2=Identity
        )

        # Per-triple resampling flags for the trailing ResNetBlock.
        final_flags = [None, None, set_default_item(down_sample, False), None]
        stage_flags = [[None] * 4 for _ in range(num_blocks - 1)] + [final_flags]
        stage_channels = [(in_channels, out_channels), *([(out_channels, out_channels)] * (num_blocks - 1))]

        self.resnet_attn_blocks = nn.ModuleList([])
        for (in_channel, out_channel), up_resolution in zip(stage_channels, stage_flags):
            in_resnet = ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
            cross_attention = set_default_layer(
                exist(context_dim),
                AttentionBlock,
                (out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
                layer_2=Identity
            )
            out_resnet = ResNetBlock(
                out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution
            )
            self.resnet_attn_blocks.append(nn.ModuleList([in_resnet, cross_attention, out_resnet]))

    def forward(self, x, time_embed, context=None, context_mask=None, control_net_residual=None):
        """Run the stage.  ``control_net_residual`` is accepted but not used here."""
        x = self.self_attention_block(x, time_embed)
        for triple in self.resnet_attn_blocks:
            in_resnet_block, attention, out_resnet_block = triple
            x = in_resnet_block(x, time_embed)
            x = attention(x, time_embed, context, context_mask)
            x = out_resnet_block(x, time_embed)
        return x
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class UpSampleBlock(nn.Module):
    """Decoder stage: ``num_blocks`` triples of (ResNetBlock,
    cross-attention-or-Identity, ResNetBlock), then optional self-attention.

    The first triple takes ``in_channels + cat_dim`` channels (the skip
    connection is concatenated by the caller) and is the only one carrying an
    up-sampling flag, set when ``up_sample`` is truthy.
    """

    def __init__(
        self, in_channels, cat_dim, out_channels, time_embed_dim, context_dim=None,
        num_blocks=3, groups=32, head_dim=64, expansion_ratio=4, compression_ratio=2,
        up_sample=True, self_attention=True
    ):
        super().__init__()
        first_flags = [None, set_default_item(up_sample, True), None, None]
        stage_flags = [first_flags] + [[None] * 4 for _ in range(num_blocks - 1)]
        stage_channels = [
            (in_channels + cat_dim, in_channels),
            *([(in_channels, in_channels)] * (num_blocks - 2)),
            (in_channels, out_channels),
        ]

        self.resnet_attn_blocks = nn.ModuleList([])
        for (in_channel, out_channel), up_resolution in zip(stage_channels, stage_flags):
            in_resnet = ResNetBlock(
                in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution
            )
            cross_attention = set_default_layer(
                exist(context_dim),
                AttentionBlock,
                (in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
                layer_2=Identity
            )
            out_resnet = ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
            self.resnet_attn_blocks.append(nn.ModuleList([in_resnet, cross_attention, out_resnet]))

        self.self_attention_block = set_default_layer(
            self_attention,
            AttentionBlock,
            (out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
            layer_2=Identity
        )

    def forward(self, x, time_embed, context=None, context_mask=None):
        """Run every triple in order, then the trailing self-attention layer."""
        for triple in self.resnet_attn_blocks:
            in_resnet_block, attention, out_resnet_block = triple
            x = in_resnet_block(x, time_embed)
            x = attention(x, time_embed, context, context_mask)
            x = out_resnet_block(x, time_embed)
        return self.self_attention_block(x, time_embed)
|
| 192 |
+
|
| 193 |
+
class ControlNetModel(nn.Module):
    """Encoder-only network: time embedding, input conv and the down path.

    ``forward`` returns the feature map produced by every down-sampling level
    except the last one.

    NOTE: a second class with this same name is defined further down in this
    module; at import time that later definition rebinds ``ControlNetModel``,
    so this one is not what other code in the module resolves the name to.

    Args:
        model_channels: base width; level ``i`` has ``model_channels * dim_mult[i]`` channels.
        init_channels: width after the input conv (defaults to ``model_channels``).
        num_channels: channels of the input tensor.
        out_channels: accepted for signature parity; not used by this class.
        time_embed_dim: width of the time embedding.
        context_dim: width of the context tokens used for pooling / cross-attention.
        groups, head_dim, expansion_ratio, compression_ratio: forwarded to each ``DownSampleBlock``.
        dim_mult, num_blocks, add_cross_attention, add_self_attention: per-level settings.
    """

    def __init__(self,
                 model_channels,
                 init_channels=None,
                 num_channels=3,
                 out_channels=4,
                 time_embed_dim=None,
                 context_dim=None,
                 groups=32,
                 head_dim=64,
                 expansion_ratio=4,
                 compression_ratio=2,
                 dim_mult=(1, 2, 4, 8),
                 num_blocks=(3, 3, 3, 3),
                 add_cross_attention=(False, True, True, True),
                 add_self_attention=(False, True, True, True)
                 ):
        super().__init__()
        init_channels = init_channels or model_channels
        self.to_time_embed = nn.Sequential(
            SinusoidalPosEmb(init_channels),
            nn.Linear(init_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        self.feature_pooling = AttentionPolling(time_embed_dim, context_dim, head_dim)

        self.in_layer = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)

        hidden_dims = [init_channels, *(model_channels * mult for mult in dim_mult)]
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        # A level gets ``context_dim`` only where cross-attention is enabled.
        text_dims = [set_default_item(is_exist, context_dim) for is_exist in add_cross_attention]

        self.num_levels = len(in_out_dims)
        self.down_samples = nn.ModuleList([])
        level_settings = zip(in_out_dims, num_blocks, text_dims, add_self_attention)
        for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(level_settings):
            # Every level but the deepest halves the resolution.
            down_sample = level != (self.num_levels - 1)
            self.down_samples.append(
                DownSampleBlock(
                    in_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim, expansion_ratio,
                    compression_ratio, down_sample, self_attention
                )
            )

    def forward(self, x, time, context=None, context_mask=None):
        """Return the list of feature maps from all levels except the deepest."""
        time_embed = self.to_time_embed(time)
        if exist(context):
            time_embed = self.feature_pooling(time_embed, context, context_mask)

        hidden_states = []
        x = self.in_layer(x)
        for level, down_sample in enumerate(self.down_samples):
            x = down_sample(x, time_embed, context, context_mask)
            if level != self.num_levels - 1:
                hidden_states.append(x)
        return hidden_states
|
| 253 |
+
|
| 254 |
+
class UNet(nn.Module):
    """Time- and context-conditioned U-Net with skip connections.

    The down path stores the output of every level except the deepest; the up
    path concatenates those maps (deepest first) onto its input from the
    second level onward.  Extra positional / keyword arguments are accepted
    and ignored.
    """

    def __init__(self,
                 model_channels,
                 init_channels=None,
                 num_channels=3,
                 out_channels=4,
                 time_embed_dim=None,
                 context_dim=None,
                 groups=32,
                 head_dim=64,
                 expansion_ratio=4,
                 compression_ratio=2,
                 dim_mult=(1, 2, 4, 8),
                 num_blocks=(3, 3, 3, 3),
                 add_cross_attention=(False, True, True, True),
                 add_self_attention=(False, True, True, True),
                 *args,
                 **kwargs,
                 ):
        super().__init__()
        init_channels = init_channels or model_channels
        self.to_time_embed = nn.Sequential(
            SinusoidalPosEmb(init_channels),
            nn.Linear(init_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )
        self.feature_pooling = AttentionPolling(time_embed_dim, context_dim, head_dim)

        self.in_layer = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)

        widths = [init_channels] + [model_channels * mult for mult in dim_mult]
        in_out_dims = list(zip(widths[:-1], widths[1:]))
        text_dims = [set_default_item(is_exist, context_dim) for is_exist in add_cross_attention]

        self.num_levels = len(in_out_dims)
        deepest_level = self.num_levels - 1

        # Channel count of the skip tensor each level hands to the decoder
        # (0 for the deepest level, which has no skip connection).
        skip_dims = []
        self.down_samples = nn.ModuleList([])
        encoder_settings = zip(in_out_dims, num_blocks, text_dims, add_self_attention)
        for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(encoder_settings):
            down_sample = level != deepest_level
            skip_dims.append(set_default_item(level != deepest_level, out_dim, 0))
            self.down_samples.append(
                DownSampleBlock(
                    in_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim, expansion_ratio,
                    compression_ratio, down_sample, self_attention
                )
            )

        self.up_samples = nn.ModuleList([])
        decoder_settings = zip(
            reversed(in_out_dims),
            reversed(num_blocks), reversed(text_dims), reversed(add_self_attention),
            reversed(skip_dims),
        )
        for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention, cat_dim) in enumerate(decoder_settings):
            up_sample = level != 0
            self.up_samples.append(
                UpSampleBlock(
                    in_dim, cat_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim,
                    expansion_ratio, compression_ratio, up_sample, self_attention
                )
            )

        self.out_layer = nn.Sequential(
            nn.GroupNorm(groups, init_channels),
            nn.SiLU(),
            nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1),
        )

        self.control_net = None

    def forward(self, x, time, context=None, context_mask=None, control_net_residual=None):
        """Run the full U-Net; ``control_net_residual`` is handed to each down block."""
        time_embed = self.to_time_embed(time)
        if exist(context):
            time_embed = self.feature_pooling(time_embed, context, context_mask)

        x = self.in_layer(x)

        skips = []
        deepest_level = self.num_levels - 1
        for level, down_sample in enumerate(self.down_samples):
            x = down_sample(x, time_embed, context, context_mask, control_net_residual)
            if level != deepest_level:
                skips.append(x)

        for level, up_sample in enumerate(self.up_samples):
            if level != 0:
                # Most recently stored skip first (LIFO).
                x = torch.cat([x, skips.pop()], dim=1)
            x = up_sample(x, time_embed, context, context_mask)

        return self.out_layer(x)
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class ControlNetModel(nn.Module):
    """Encoder-only network: time embedding, input conv and the down path.

    ``forward`` returns the feature map produced by every down-sampling level
    except the last one, in encoder order.  Extra positional / keyword
    arguments are accepted and ignored.

    Args:
        model_channels: base width; level ``i`` has ``model_channels * dim_mult[i]`` channels.
        init_channels: width after the input conv (defaults to ``model_channels``).
        num_channels: channels of the input tensor.
        out_channels: accepted for signature parity; not used by this class.
        time_embed_dim: width of the time embedding.
        context_dim: width of the context tokens used for pooling / cross-attention.
        groups, head_dim, expansion_ratio, compression_ratio: forwarded to each ``DownSampleBlock``.
        dim_mult, num_blocks, add_cross_attention, add_self_attention: per-level settings.
    """

    def __init__(self,
                 model_channels,
                 init_channels=None,
                 num_channels=3,
                 out_channels=4,
                 time_embed_dim=None,
                 context_dim=None,
                 groups=32,
                 head_dim=64,
                 expansion_ratio=4,
                 compression_ratio=2,
                 dim_mult=(1, 2, 4, 8),
                 num_blocks=(3, 3, 3, 3),
                 add_cross_attention=(False, True, True, True),
                 add_self_attention=(False, True, True, True),
                 *args,
                 **kwargs,
                 ):
        super().__init__()
        init_channels = init_channels or model_channels
        self.to_time_embed = nn.Sequential(
            SinusoidalPosEmb(init_channels),
            nn.Linear(init_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        self.feature_pooling = AttentionPolling(time_embed_dim, context_dim, head_dim)

        self.in_layer = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)

        hidden_dims = [init_channels, *(model_channels * mult for mult in dim_mult)]
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        # A level gets ``context_dim`` only where cross-attention is enabled.
        text_dims = [set_default_item(is_exist, context_dim) for is_exist in add_cross_attention]

        self.num_levels = len(in_out_dims)
        self.down_samples = nn.ModuleList([])
        level_settings = zip(in_out_dims, num_blocks, text_dims, add_self_attention)
        for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(level_settings):
            # Every level but the deepest halves the resolution.
            down_sample = level != (self.num_levels - 1)
            self.down_samples.append(
                DownSampleBlock(
                    in_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim, expansion_ratio,
                    compression_ratio, down_sample, self_attention
                )
            )

    def forward(self, x, time, context=None, context_mask=None):
        """Return the list of feature maps from all levels except the deepest."""
        time_embed = self.to_time_embed(time)
        if exist(context):
            time_embed = self.feature_pooling(time_embed, context, context_mask)

        hidden_states = []
        x = self.in_layer(x)
        for level, down_sample in enumerate(self.down_samples):
            x = down_sample(x, time_embed, context, context_mask)
            if level != self.num_levels - 1:
                hidden_states.append(x)
        return hidden_states
|
| 404 |
+
|
| 405 |
+
class ControlUNet(nn.Module):
    """U-Net whose encoder features are offset by a ``ControlNetModel`` branch.

    The control branch is built with the same hyper-parameters as the main
    encoder, except that its input has ``control_net_channels`` channels.  In
    ``forward`` its per-level outputs are added in place to the matching
    encoder outputs before those are stored as skip connections.

    Extra positional / keyword arguments are accepted and ignored.
    """

    def __init__(self,
                 model_channels,
                 init_channels=None,
                 num_channels=3,
                 out_channels=4,
                 time_embed_dim=None,
                 context_dim=None,
                 groups=32,
                 head_dim=64,
                 expansion_ratio=4,
                 compression_ratio=2,
                 dim_mult=(1, 2, 4, 8),
                 num_blocks=(3, 3, 3, 3),
                 add_cross_attention=(False, True, True, True),
                 add_self_attention=(False, True, True, True),
                 control_net_channels=5,
                 *args,
                 **kwargs,
                 ):
        super().__init__()
        init_channels = init_channels or model_channels
        self.to_time_embed = nn.Sequential(
            SinusoidalPosEmb(init_channels),
            nn.Linear(init_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        self.feature_pooling = AttentionPolling(time_embed_dim, context_dim, head_dim)

        self.in_layer = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)

        hidden_dims = [init_channels, *map(lambda mult: model_channels * mult, dim_mult)]
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        text_dims = [set_default_item(is_exist, context_dim) for is_exist in add_cross_attention]
        layer_params = [num_blocks, text_dims, add_self_attention]
        rev_layer_params = map(reversed, layer_params)

        # Skip-connection widths, consumed in reverse by the decoder below.
        cat_dims = []
        self.num_levels = len(in_out_dims)
        self.down_samples = nn.ModuleList([])
        for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(zip(in_out_dims, *layer_params)):
            down_sample = level != (self.num_levels - 1)
            cat_dims.append(set_default_item(level != (self.num_levels - 1), out_dim, 0))
            self.down_samples.append(
                DownSampleBlock(
                    in_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim, expansion_ratio,
                    compression_ratio, down_sample, self_attention
                )
            )

        self.up_samples = nn.ModuleList([])
        for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate(zip(reversed(in_out_dims), *rev_layer_params)):
            up_sample = level != 0
            self.up_samples.append(
                UpSampleBlock(
                    in_dim, cat_dims.pop(), out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim,
                    expansion_ratio, compression_ratio, up_sample, self_attention
                )
            )

        self.out_layer = nn.Sequential(
            nn.GroupNorm(groups, init_channels),
            nn.SiLU(),
            nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1)
        )

        self.control_net = ControlNetModel(model_channels,
                                           init_channels,
                                           control_net_channels,
                                           out_channels,
                                           time_embed_dim,
                                           context_dim,
                                           groups,
                                           head_dim,
                                           expansion_ratio,
                                           compression_ratio,
                                           dim_mult,
                                           num_blocks,
                                           add_cross_attention,
                                           add_self_attention)

    def forward(self, x, time, context=None, context_mask=None, control_net_data=None):
        """Run the U-Net, adding control-branch features when control data is given.

        Args:
            x: input tensor for ``in_layer``.
            time: timesteps for the time embedding (the control branch embeds
                the same raw ``time`` with its own layers).
            context, context_mask: optional conditioning tokens and their mask.
            control_net_data: input for the control branch.  When ``None``
                (the default) the control branch is skipped and the encoder
                features are used unchanged, instead of failing inside the
                control branch's input convolution.
        """
        time_embed = self.to_time_embed(time)
        if exist(context):
            time_embed = self.feature_pooling(time_embed, context, context_mask)

        use_control = control_net_data is not None
        control_net_hiddens = (
            self.control_net(control_net_data, time, context, context_mask) if use_control else []
        )

        hidden_states = []
        x = self.in_layer(x)
        for level, down_sample in enumerate(self.down_samples):
            x = down_sample(x, time_embed, context, context_mask)
            if level != self.num_levels - 1:
                if use_control:
                    # Control features come back in encoder order, so take from the front.
                    x += control_net_hiddens.pop(0)
                hidden_states.append(x)
        for level, up_sample in enumerate(self.up_samples):
            if level != 0:
                x = torch.cat([x, hidden_states.pop()], dim=1)
            x = up_sample(x, time_embed, context, context_mask)
        x = self.out_layer(x)
        return x
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def get_control_unet(conf):
    """Build a ``ControlUNet`` from the keyword mapping ``conf``."""
    return ControlUNet(**conf)
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def get_unet(conf):
    """Build a ``UNet`` from the keyword mapping ``conf``."""
    return UNet(**conf)
|
kandinsky3/model/__init__.py
ADDED
|
File without changes
|
kandinsky3/model/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (212 Bytes). View file
|
|
|
kandinsky3/model/__pycache__/diffusion.cpython-310.pyc
ADDED
|
Binary file (6.4 kB). View file
|
|
|
kandinsky3/model/__pycache__/nn.cpython-310.pyc
ADDED
|
Binary file (3.53 kB). View file
|
|
|
kandinsky3/model/__pycache__/unet.cpython-310.pyc
ADDED
|
Binary file (13.9 kB). View file
|
|
|
kandinsky3/model/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (1.85 kB). View file
|
|
|
kandinsky3/model/diffusion.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
from .utils import get_tensor_items
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_named_beta_schedule(schedule_name, timesteps):
    """Return a 1-D float32 tensor of ``timesteps`` beta values.

    Args:
        schedule_name: ``"linear"`` or ``"cosine"``.
        timesteps: number of diffusion steps.

    Returns:
        ``"linear"``: evenly spaced betas from ``0.0001 * s`` to ``0.02 * s``
        with ``s = 1000 / timesteps``.
        ``"cosine"``: ``1 - alpha_bar(t2) / alpha_bar(t1)`` for consecutive
        fractions ``t1 = i / timesteps`` and ``t2 = (i + 1) / timesteps``,
        with ``alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2``,
        each value capped at 0.999.

    Raises:
        ValueError: if ``schedule_name`` is not one of the supported names
            (previously this fell through and returned ``None``).
    """
    if schedule_name == "linear":
        # Rescale the 1000-step reference range so other step counts cover a
        # comparable total amount of noise.
        scale = 1000 / timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return torch.linspace(
            beta_start, beta_end, timesteps, dtype=torch.float32
        )
    if schedule_name == "cosine":
        def alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        betas = [
            min(1 - alpha_bar((i + 1) / timesteps) / alpha_bar(i / timesteps), 0.999)
            for i in range(timesteps)
        ]
        return torch.tensor(betas, dtype=torch.float32)
    raise ValueError(f"unknown beta schedule: {schedule_name!r}")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class BaseDiffusion:
|
| 29 |
+
|
| 30 |
+
def __init__(self, betas, percentile=None, gen_noise=torch.randn_like):
|
| 31 |
+
self.betas = betas
|
| 32 |
+
self.num_timesteps = betas.shape[0]
|
| 33 |
+
|
| 34 |
+
alphas = 1. - betas
|
| 35 |
+
self.alphas_cumprod = torch.cumprod(alphas, dim=0)
|
| 36 |
+
self.alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=betas.dtype), self.alphas_cumprod[:-1]])
|
| 37 |
+
|
| 38 |
+
# calculate q(x_t | x_{t-1})
|
| 39 |
+
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
| 40 |
+
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
|
| 41 |
+
|
| 42 |
+
# calculate q(x_{t-1} | x_t, x_0)
|
| 43 |
+
self.posterior_mean_coef_1 = torch.sqrt(self.alphas_cumprod_prev) * betas / (1. - self.alphas_cumprod)
|
| 44 |
+
self.posterior_mean_coef_2 = torch.sqrt(alphas) * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
|
| 45 |
+
self.posterior_variance = betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
|
| 46 |
+
self.posterior_log_variance = torch.log(
|
| 47 |
+
torch.cat([self.posterior_variance[1].unsqueeze(0), self.posterior_variance[1:]])
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
self.percentile = percentile
|
| 51 |
+
self.time_scale = 1000 // self.num_timesteps
|
| 52 |
+
self.gen_noise = gen_noise
|
| 53 |
+
self.jump_length = 3
|
| 54 |
+
|
| 55 |
+
def process_x_start(self, x_start):
|
| 56 |
+
bs, ndims = x_start.shape[0], len(x_start.shape[1:])
|
| 57 |
+
if self.percentile is not None:
|
| 58 |
+
quantile = torch.quantile(
|
| 59 |
+
rearrange(x_start, 'b ... -> b (...)').abs(),
|
| 60 |
+
self.percentile,
|
| 61 |
+
dim=-1
|
| 62 |
+
)
|
| 63 |
+
quantile = torch.clip(quantile, min=1.)
|
| 64 |
+
quantile = quantile.reshape(bs, *((1,) * ndims))
|
| 65 |
+
return torch.clip(x_start, -quantile, quantile) / quantile
|
| 66 |
+
else:
|
| 67 |
+
return torch.clip(x_start, -1., 1.)
|
| 68 |
+
|
| 69 |
+
def get_x_start(self, x, t, noise):
    """Recover ``x_0`` from a noisy ``x`` and a noise estimate.

    Computes ``(x - sqrt(1 - a_bar_t) * noise) / sqrt(a_bar_t)`` with the
    coefficients looked up through ``get_tensor_items`` (presumably indexed
    by ``t`` and shaped to broadcast against ``noise`` -- confirm in ``.utils``).
    """
    noise_coef = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
    signal_coef = get_tensor_items(self.sqrt_alphas_cumprod, t, noise.shape)
    return (x - noise_coef * noise) / signal_coef
|
| 74 |
+
|
| 75 |
+
def get_noise(self, x, t, x_start):
    """Recover the noise from a noisy ``x`` and an ``x_0`` estimate.

    Computes ``(x - sqrt(a_bar_t) * x_start) / sqrt(1 - a_bar_t)``; the
    coefficients come from ``get_tensor_items`` (presumably indexed by ``t``
    and broadcastable to ``x_start`` -- confirm in ``.utils``).
    """
    noise_coef = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    signal_coef = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
    return (x - signal_coef * x_start) / noise_coef
|
| 80 |
+
|
| 81 |
+
def q_sample(self, x_start, t, noise=None):
    """Return ``sqrt(a_bar_t) * x_start + sqrt(1 - a_bar_t) * noise``.

    When ``noise`` is omitted it is drawn with ``self.gen_noise(x_start)``.
    """
    noise = self.gen_noise(x_start) if noise is None else noise
    signal_coef = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
    noise_coef = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
    return signal_coef * x_start + noise_coef * noise
|
| 88 |
+
|
| 89 |
+
def q_posterior_mean_variance(self, x_start, x_t, t):
    """Return ``(mean, variance, log_variance)`` from the precomputed posterior tables.

    The mean is ``coef_1 * x_start + coef_2 * x_t``; variance and
    log-variance are looked up from ``self.posterior_variance`` and
    ``self.posterior_log_variance``.
    """
    coef_x_start = get_tensor_items(self.posterior_mean_coef_1, t, x_start.shape)
    coef_x_t = get_tensor_items(self.posterior_mean_coef_2, t, x_t.shape)
    mean = coef_x_start * x_start + coef_x_t * x_t

    variance = get_tensor_items(self.posterior_variance, t, x_start.shape)
    log_variance = get_tensor_items(self.posterior_log_variance, t, x_start.shape)
    return mean, variance, log_variance
|
| 97 |
+
|
| 98 |
+
def q_posterior_variance(self, t, prev_t, shape, eta=1.):
    """Return ``sqrt(eta * (1 - a_t / a_prev) * (1 - a_prev) / (1 - a_t))``.

    ``a_t`` and ``a_prev`` are ``self.alphas_cumprod`` looked up at ``t`` and
    ``prev_t``.  Despite the method name, the returned value is a square
    root, and ``eta`` is applied inside that root.
    """
    a_t = get_tensor_items(self.alphas_cumprod, t, shape)
    a_prev = get_tensor_items(self.alphas_cumprod, prev_t, shape)
    return torch.sqrt(eta * (1. - a_t / a_prev) * (1. - a_prev) / (1. - a_t))
|
| 106 |
+
|
| 107 |
+
def text_guidance(
|
| 108 |
+
self, model, x, t, context, context_mask, null_embedding, guidance_weight_text,
|
| 109 |
+
uncondition_context=None, uncondition_context_mask=None, mask=None, masked_latent=None
|
| 110 |
+
):
|
| 111 |
+
large_x = x.repeat(2, 1, 1, 1)
|
| 112 |
+
large_t = t.repeat(2).to(x.dtype)
|
| 113 |
+
|
| 114 |
+
if uncondition_context is None:
|
| 115 |
+
uncondition_context = torch.zeros_like(context)
|
| 116 |
+
uncondition_context_mask = torch.zeros_like(context_mask)
|
| 117 |
+
uncondition_context[:, 0] = null_embedding
|
| 118 |
+
uncondition_context_mask[:, 0] = 1
|
| 119 |
+
large_context = torch.cat([context, uncondition_context])
|
| 120 |
+
large_context_mask = torch.cat([context_mask, uncondition_context_mask])
|
| 121 |
+
|
| 122 |
+
if mask is not None:
|
| 123 |
+
mask = mask.repeat(2, 1, 1, 1)
|
| 124 |
+
if masked_latent is not None:
|
| 125 |
+
masked_latent = masked_latent.repeat(2, 1, 1, 1)
|
| 126 |
+
|
| 127 |
+
if model.in_layer.in_channels == 9:
|
| 128 |
+
large_x = torch.cat([large_x, mask, masked_latent], dim=1)
|
| 129 |
+
|
| 130 |
+
pred_large_noise = model(large_x, large_t * self.time_scale, large_context, large_context_mask.bool())
|
| 131 |
+
pred_noise, uncond_pred_noise = torch.chunk(pred_large_noise, 2)
|
| 132 |
+
pred_noise = (guidance_weight_text + 1.) * pred_noise - guidance_weight_text * uncond_pred_noise
|
| 133 |
+
return pred_noise
|
| 134 |
+
|
| 135 |
+
def p_mean_variance(
        self, model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
        negative_context=None, negative_context_mask=None, mask=None, masked_latent=None
):
    """Compute the mean and noise scale used to sample the ``prev_t`` state.

    Steps: guided noise prediction -> predicted start sample -> post-processed
    start sample -> noise re-derived from it -> mean assembled from
    ``self.alphas_cumprod`` at ``prev_t``. Returns ``(pred_mean, pred_var)``
    where ``pred_var`` is what ``self.q_posterior_variance`` returns.
    """
    guided_noise = self.text_guidance(
        model, x, t, context, context_mask, null_embedding, guidance_weight_text,
        negative_context, negative_context_mask, mask, masked_latent
    )

    x_start = self.process_x_start(self.get_x_start(x, t, guided_noise))
    # Recompute the noise so it is consistent with the processed start sample.
    consistent_noise = self.get_noise(x, t, x_start)
    pred_var = self.q_posterior_variance(t, prev_t, x.shape, eta)

    alpha_bar_prev = get_tensor_items(self.alphas_cumprod, prev_t, x.shape)
    noise_coeff = torch.sqrt(1. - alpha_bar_prev - pred_var ** 2)
    pred_mean = torch.sqrt(alpha_bar_prev) * x_start
    pred_mean += noise_coeff * consistent_noise
    return pred_mean, pred_var
|
| 154 |
+
|
| 155 |
+
# @torch.no_grad()
|
| 156 |
+
def p_sample(
        self, model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
        negative_context=None, negative_context_mask=None, mask=None, masked_latent=None
):
    """Draw one sample for ``prev_t`` given the current state ``x`` at ``t``.

    Fresh Gaussian noise scaled by the value from ``self.p_mean_variance`` is
    added to the predicted mean, except for batch entries whose ``prev_t`` is
    0, which receive the mean unchanged.
    """
    pred_mean, pred_var = self.p_mean_variance(
        model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta,
        negative_context=negative_context, negative_context_mask=negative_context_mask,
        mask=mask, masked_latent=masked_latent
    )
    noise = torch.randn_like(x)

    # One flag per batch entry, shaped to broadcast over the remaining axes.
    trailing_axes = (1,) * (len(x.shape) - 1)
    add_noise = (prev_t != 0).reshape(x.shape[0], *trailing_axes)
    return pred_mean + add_noise * pred_var * noise
|
| 171 |
+
|
| 172 |
+
# @torch.no_grad()
|
| 173 |
+
def p_sample_loop(
        self, model, shape, times, device, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
        negative_context=None, negative_context_mask=None, mask=None, masked_latent=None, gan=False,
):
    """Generate a sample by iterating over the timestep schedule ``times``.

    Starts from Gaussian noise of ``shape`` on ``device`` and walks through
    consecutive ``(time, prev_time)`` pairs of ``times`` followed by a final
    0. ``times`` may be any iterable of timesteps (list, tuple, array).

    With ``gan=False`` each pair is handled by ``self.p_sample``. With
    ``gan=True`` the current image is re-noised with ``self.q_sample``, the
    model is called directly (no guidance), and the image is replaced by
    ``self.get_x_start`` of that prediction.
    """
    img = torch.randn(*shape, device=device)

    # Materialise the schedule first: `times + [0]` is only list
    # concatenation when `times` is a list. A tuple would raise, and an array
    # would be element-wise added, silently dropping the terminal 0 step.
    schedule = list(times) + [0]
    steps = list(zip(schedule[:-1], schedule[1:]))

    for time, prev_time in tqdm(steps):
        time = torch.tensor([time] * shape[0], device=device)
        if gan:
            # NOTE(review): unlike text_guidance, this branch passes the raw
            # timestep without the self.time_scale factor - confirm intended.
            x_t = self.q_sample(img, time)
            pred_noise = model(x_t, time.type(x_t.dtype), context, context_mask.bool())
            img = self.get_x_start(x_t, time, pred_noise)
        else:
            prev_time = torch.tensor([prev_time] * shape[0], device=device)
            img = self.p_sample(
                model, img, time, prev_time, context, context_mask, null_embedding, guidance_weight_text, eta,
                negative_context=negative_context, negative_context_mask=negative_context_mask,
                mask=mask, masked_latent=masked_latent
            )
    return img
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def get_diffusion(conf):
    """Build a ``BaseDiffusion`` from ``conf``.

    ``conf.schedule_params`` is expanded into ``get_named_beta_schedule`` and
    ``conf.diffusion_params`` into the ``BaseDiffusion`` constructor.
    """
    return BaseDiffusion(
        get_named_beta_schedule(**conf.schedule_params),
        **conf.diffusion_params
    )
|
kandinsky3/model/nn.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn, einsum
|
| 5 |
+
from einops import rearrange, repeat
|
| 6 |
+
|
| 7 |
+
from .utils import exist
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Identity(nn.Module):
    """Pass-through module: hands back its first input untouched."""

    def __init__(self, *args, **kwargs):
        # Constructor arguments are accepted and discarded, so this class can
        # be built with the same arguments as the layer it stands in for.
        super().__init__()

    @staticmethod
    def forward(x, *args, **kwargs):
        """Return ``x``; any further positional/keyword inputs are ignored."""
        return x
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of positions/timesteps.

    For input ``x`` of length ``n`` the output has shape
    ``(n, 2 * (dim // 2))``: the sines of ``x * freq`` followed by the cosines,
    with ``dim // 2`` frequencies spaced geometrically from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # With dim 2 or 3 there is a single frequency and `half_dim - 1` is 0;
        # clamp the divisor so that case yields frequency 1 instead of raising
        # ZeroDivisionError. Larger dims are unaffected.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device, dtype=x.dtype) * -log_step)
        # Outer product: (n, 1) * (1, half_dim) -> (n, half_dim).
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class ConditionalGroupNorm(nn.Module):
    """Group normalisation whose scale and shift come from a context vector.

    The affine-free ``GroupNorm`` output is multiplied by ``scale + 1`` and
    offset by ``shift``; both are produced by ``SiLU -> Linear`` applied to the
    context. The linear layer starts at zero, so the module initially behaves
    like plain group normalisation.
    """

    def __init__(self, groups, normalized_shape, context_dim):
        super().__init__()
        self.norm = nn.GroupNorm(groups, normalized_shape, affine=False)

        projection = nn.Linear(context_dim, 2 * normalized_shape)
        # Zero-initialise so scale == 0 and shift == 0 before any training.
        projection.weight.data.zero_()
        projection.bias.data.zero_()
        self.context_mlp = nn.Sequential(nn.SiLU(), projection)

    def forward(self, x, context):
        modulation = self.context_mlp(context)
        # Append one singleton axis per axis of x after (batch, channel) so
        # the modulation broadcasts over them.
        singleton_axes = ' 1' * (x.dim() - 2)
        modulation = rearrange(modulation, f'b c -> b c{singleton_axes}')

        scale, shift = modulation.chunk(2, dim=1)
        return self.norm(x) * (scale + 1.) + shift
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Attention(nn.Module):
    """Multi-head attention: queries from ``x``, keys/values from ``context``.

    ``out_channels`` is split into ``out_channels // head_dim`` heads. All four
    linear maps are bias-free.
    """

    def __init__(self, in_channels, out_channels, context_dim, head_dim=64):
        super().__init__()
        assert out_channels % head_dim == 0
        self.num_heads = out_channels // head_dim
        self.scale = head_dim ** -0.5

        self.to_query = nn.Linear(in_channels, out_channels, bias=False)
        self.to_key = nn.Linear(context_dim, out_channels, bias=False)
        self.to_value = nn.Linear(context_dim, out_channels, bias=False)

        self.output_layer = nn.Linear(out_channels, out_channels, bias=False)

    def forward(self, x, context, context_mask=None):
        """Attend from ``x`` to ``context``.

        ``context_mask`` (optional, boolean) marks the context positions that
        may be attended to; positions where it is False are filled with the
        most negative finite value before the softmax.
        """
        heads = self.num_heads
        q = rearrange(self.to_query(x), 'b n (h d) -> b h n d', h=heads)
        k = rearrange(self.to_key(context), 'b n (h d) -> b h n d', h=heads)
        v = rearrange(self.to_value(context), 'b n (h d) -> b h n d', h=heads)

        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        if context_mask is not None:
            fill_value = -torch.finfo(scores.dtype).max
            keep = rearrange(context_mask, 'b j -> b 1 1 j')
            scores = scores.masked_fill(~keep, fill_value)
        weights = scores.softmax(dim=-1)

        mixed = einsum('b h i j, b h j d -> b h i d', weights, v)
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.output_layer(merged)
|
kandinsky3/model/unet.py
ADDED
|
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn, einsum
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
|
| 5 |
+
from .nn import Identity, Attention, SinusoidalPosEmb, ConditionalGroupNorm
|
| 6 |
+
from .utils import exist, set_default_item, set_default_layer
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Block(nn.Module):
    """Conditional norm -> SiLU -> optional up-sample -> conv -> optional down-sample.

    ``up_resolution`` selects the resampling: a truthy value requests the
    stride-2 transposed convolution before the projection, a falsy non-None
    value requests the stride-2 convolution after it, and None requests
    neither (``set_default_layer`` supplies the fallback layer).
    """

    def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
        super().__init__()
        resample_requested = up_resolution is not None
        resample_kwargs = {'kernel_size': 2, 'stride': 2}

        self.group_norm = ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
        self.activation = nn.SiLU()
        self.up_sample = set_default_layer(
            resample_requested and up_resolution,
            nn.ConvTranspose2d, (in_channels, in_channels), dict(resample_kwargs)
        )
        # 1x1 kernels need no padding; every other kernel size is padded by 1.
        self.projection = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, padding=0 if kernel_size == 1 else 1
        )
        self.down_sample = set_default_layer(
            resample_requested and not up_resolution,
            nn.Conv2d, (out_channels, out_channels), dict(resample_kwargs)
        )

    def forward(self, x, time_embed):
        x = self.activation(self.group_norm(x, time_embed))
        x = self.projection(self.up_sample(x))
        return self.down_sample(x)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ResNetBlock(nn.Module):
    """Residual stack of four ``Block`` layers (kernels 1, 3, 3, 1) plus a shortcut.

    The inner width is ``max(in_channels, out_channels) // compression_ratio``.
    ``up_resolutions`` holds one resampling flag per inner block (see
    ``Block``). The shortcut path is resampled with a stride-2 transposed
    convolution if any flag is True and with a stride-2 convolution if any
    flag is False, and is projected with a 1x1 convolution when the channel
    count changes.
    """

    def __init__(
            self, in_channels, out_channels, time_embed_dim, norm_groups=32, compression_ratio=2,
            up_resolutions=(None, None, None, None)
    ):
        # The default is an immutable tuple rather than a list: a list default
        # is one shared object reused by every call that omits the argument.
        super().__init__()
        kernel_sizes = [1, 3, 3, 1]
        hidden_channel = max(in_channels, out_channels) // compression_ratio
        hidden_channels = (
            [(in_channels, hidden_channel)]
            + [(hidden_channel, hidden_channel)] * 2
            + [(hidden_channel, out_channels)]
        )
        self.resnet_blocks = nn.ModuleList([
            Block(in_channel, out_channel, time_embed_dim, kernel_size, norm_groups, up_resolution)
            for (in_channel, out_channel), kernel_size, up_resolution
            in zip(hidden_channels, kernel_sizes, up_resolutions)
        ])

        self.shortcut_up_sample = set_default_layer(
            True in up_resolutions,
            nn.ConvTranspose2d, (in_channels, in_channels), {'kernel_size': 2, 'stride': 2}
        )
        self.shortcut_projection = set_default_layer(
            in_channels != out_channels,
            nn.Conv2d, (in_channels, out_channels), {'kernel_size': 1}
        )
        self.shortcut_down_sample = set_default_layer(
            False in up_resolutions,
            nn.Conv2d, (out_channels, out_channels), {'kernel_size': 2, 'stride': 2}
        )

    def forward(self, x, time_embed):
        """Return shortcut(x) + blocks(x); every block is conditioned on ``time_embed``."""
        out = x
        for resnet_block in self.resnet_blocks:
            out = resnet_block(out, time_embed)

        x = self.shortcut_up_sample(x)
        x = self.shortcut_projection(x)
        x = self.shortcut_down_sample(x)
        x = x + out
        return x
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class AttentionPolling(nn.Module):
    """Attention-pool a context sequence into one vector and add it to ``x``.

    The query is the mean of ``context`` over ``dim=1``; keys and values are
    the context itself.
    """

    def __init__(self, num_channels, context_dim, head_dim=64):
        super().__init__()
        self.attention = Attention(context_dim, num_channels, context_dim, head_dim)

    def forward(self, x, context, context_mask=None):
        mean_query = context.mean(dim=1, keepdim=True)
        pooled = self.attention(mean_query, context, context_mask)
        # Drop the length-1 axis kept by the mean before the residual add.
        return x + pooled.squeeze(1)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class AttentionBlock(nn.Module):
    """Residual attention followed by a residual 1x1-conv feed-forward.

    When no ``context`` is given to ``forward`` the block attends over its own
    (normalised, flattened) feature map; otherwise it attends to ``context``.
    Both sub-layers are preceded by a ``ConditionalGroupNorm`` driven by
    ``time_embed``.
    """

    def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
        super().__init__()
        self.in_norm = ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
        # A missing (falsy) context_dim means keys/values have num_channels features.
        self.attention = Attention(num_channels, num_channels, context_dim or num_channels, head_dim)

        widened = expansion_ratio * num_channels
        self.out_norm = ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Conv2d(num_channels, widened, kernel_size=1, bias=False),
            nn.SiLU(),
            nn.Conv2d(widened, num_channels, kernel_size=1, bias=False),
        )

    def forward(self, x, time_embed, context=None, context_mask=None):
        height, width = x.shape[-2:]

        # Flatten the spatial grid into a token sequence for attention.
        tokens = rearrange(self.in_norm(x, time_embed), 'b c h w -> b (h w) c', h=height, w=width)
        if context is None:
            context = tokens
        attended = self.attention(tokens, context, context_mask)
        x = x + rearrange(attended, 'b (h w) c -> b c h w', h=height, w=width)

        return x + self.feed_forward(self.out_norm(x, time_embed))
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class DownSampleBlock(nn.Module):
    """One encoder level: optional self-attention, then ``num_blocks`` stages.

    Each stage is ``ResNetBlock -> (AttentionBlock | Identity) -> ResNetBlock``;
    the attention layer is requested only when ``context_dim`` is not None.
    When ``down_sample`` is truthy, the closing ResNetBlock of the last stage
    is configured to halve the resolution.
    """

    def __init__(
            self, in_channels, out_channels, time_embed_dim, context_dim=None,
            num_blocks=3, groups=32, head_dim=64, expansion_ratio=4, compression_ratio=2,
            down_sample=True, self_attention=True
    ):
        super().__init__()
        self.self_attention_block = set_default_layer(
            self_attention,
            AttentionBlock,
            (in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
            layer_2=Identity
        )

        # Only the final stage may resample (third inner block of its last ResNetBlock).
        no_resample = [None] * 4
        final_stage = [None, None, False if down_sample else None, None]
        up_resolutions = [no_resample] * (num_blocks - 1) + [final_stage]
        hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1)

        has_context = context_dim is not None
        self.resnet_attn_blocks = nn.ModuleList()
        for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
            stage = nn.ModuleList([
                ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio),
                set_default_layer(
                    has_context,
                    AttentionBlock,
                    (out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
                    layer_2=Identity
                ),
                ResNetBlock(out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution),
            ])
            self.resnet_attn_blocks.append(stage)

    def forward(self, x, time_embed, context=None, context_mask=None, control_net_residual=None):
        # control_net_residual is accepted but not used by this implementation.
        x = self.self_attention_block(x, time_embed)
        for first_resnet, cross_attention, second_resnet in self.resnet_attn_blocks:
            x = first_resnet(x, time_embed)
            x = cross_attention(x, time_embed, context, context_mask)
            x = second_resnet(x, time_embed)
        return x
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class UpSampleBlock(nn.Module):
    """One decoder level: ``num_blocks`` stages, then optional self-attention.

    Each stage is ``ResNetBlock -> (AttentionBlock | Identity) -> ResNetBlock``.
    The first stage takes ``in_channels + cat_dim`` channels and, when
    ``up_sample`` is truthy, its opening ResNetBlock is configured to double
    the resolution. The last stage maps to ``out_channels``.
    """

    def __init__(
            self, in_channels, cat_dim, out_channels, time_embed_dim, context_dim=None,
            num_blocks=3, groups=32, head_dim=64, expansion_ratio=4, compression_ratio=2,
            up_sample=True, self_attention=True
    ):
        super().__init__()
        # Only the first stage may resample (second inner block of its first ResNetBlock).
        first_stage = [None, True if up_sample else None, None, None]
        no_resample = [None] * 4
        up_resolutions = [first_stage] + [no_resample] * (num_blocks - 1)
        # NOTE(review): this list has num_blocks entries only for num_blocks >= 2;
        # zip below truncates to the shorter list - confirm num_blocks == 1 is never used.
        hidden_channels = (
            [(in_channels + cat_dim, in_channels)]
            + [(in_channels, in_channels)] * (num_blocks - 2)
            + [(in_channels, out_channels)]
        )

        has_context = context_dim is not None
        self.resnet_attn_blocks = nn.ModuleList()
        for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
            stage = nn.ModuleList([
                ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution),
                set_default_layer(
                    has_context,
                    AttentionBlock,
                    (in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
                    layer_2=Identity
                ),
                ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio),
            ])
            self.resnet_attn_blocks.append(stage)

        self.self_attention_block = set_default_layer(
            self_attention,
            AttentionBlock,
            (out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
            layer_2=Identity
        )

    def forward(self, x, time_embed, context=None, context_mask=None):
        for first_resnet, cross_attention, second_resnet in self.resnet_attn_blocks:
            x = first_resnet(x, time_embed)
            x = cross_attention(x, time_embed, context, context_mask)
            x = second_resnet(x, time_embed)
        return self.self_attention_block(x, time_embed)
|
| 192 |
+
|
| 193 |
+
class ControlNetModel(nn.Module):
    """Encoder-only tower that returns the feature map of every non-final level.

    NOTE(review): this module contains a second ``class ControlNetModel``
    further down. Being defined later, that one is what the name refers to
    once the module has finished loading, which leaves this definition
    effectively unused - confirm and consider removing one of them.
    """

    def __init__(self,
                 model_channels,
                 init_channels=None,
                 num_channels=3,
                 out_channels=4,
                 time_embed_dim=None,
                 context_dim=None,
                 groups=32,
                 head_dim=64,
                 expansion_ratio=4,
                 compression_ratio=2,
                 dim_mult=(1, 2, 4, 8),
                 num_blocks=(3, 3, 3, 3),
                 add_cross_attention=(False, True, True, True),
                 add_self_attention=(False, True, True, True)
                 ):
        # out_channels is accepted for signature parity but not used here.
        super().__init__()
        init_channels = init_channels or model_channels
        self.to_time_embed = nn.Sequential(
            SinusoidalPosEmb(init_channels),
            nn.Linear(init_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        self.feature_pooling = AttentionPolling(time_embed_dim, context_dim, head_dim)

        self.in_layer = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)

        hidden_dims = [init_channels, *map(lambda mult: model_channels * mult, dim_mult)]
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        # Levels without cross-attention get a None context dimension.
        text_dims = [set_default_item(is_exist, context_dim) for is_exist in add_cross_attention]
        layer_params = [num_blocks, text_dims, add_self_attention]

        self.num_levels = len(in_out_dims)
        self.down_samples = nn.ModuleList([])
        for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(zip(in_out_dims, *layer_params)):
            # Every level except the last reduces the resolution.
            down_sample = level != (self.num_levels - 1)
            self.down_samples.append(
                DownSampleBlock(
                    in_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim, expansion_ratio,
                    compression_ratio, down_sample, self_attention
                )
            )

    def forward(self, x, time, context=None, context_mask=None):
        """Return the list of feature maps produced by all levels but the last."""
        time_embed = self.to_time_embed(time)
        if exist(context):
            time_embed = self.feature_pooling(time_embed, context, context_mask)

        hidden_states = []
        x = self.in_layer(x)
        for level, down_sample in enumerate(self.down_samples):
            x = down_sample(x, time_embed, context, context_mask)
            if level != self.num_levels - 1:
                hidden_states.append(x)
        return hidden_states
|
| 253 |
+
|
| 254 |
+
class UNet(nn.Module):
    """Time- and context-conditioned U-Net.

    The encoder levels are ``DownSampleBlock`` modules and the decoder levels
    are ``UpSampleBlock`` modules. The output of every encoder level except
    the last is kept and concatenated on the channel axis to the input of the
    matching decoder level. Extra positional/keyword constructor arguments are
    accepted and ignored.
    """

    def __init__(self,
                 model_channels,
                 init_channels=None,
                 num_channels=3,
                 out_channels=4,
                 time_embed_dim=None,
                 context_dim=None,
                 groups=32,
                 head_dim=64,
                 expansion_ratio=4,
                 compression_ratio=2,
                 dim_mult=(1, 2, 4, 8),
                 num_blocks=(3, 3, 3, 3),
                 add_cross_attention=(False, True, True, True),
                 add_self_attention=(False, True, True, True),
                 *args,
                 **kwargs,
                 ):
        super().__init__()
        init_channels = init_channels or model_channels
        self.to_time_embed = nn.Sequential(
            SinusoidalPosEmb(init_channels),
            nn.Linear(init_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        self.feature_pooling = AttentionPolling(time_embed_dim, context_dim, head_dim)

        self.in_layer = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)

        hidden_dims = [init_channels] + [model_channels * mult for mult in dim_mult]
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        # Levels without cross-attention get a None context dimension.
        text_dims = [context_dim if enabled else None for enabled in add_cross_attention]
        self.num_levels = len(in_out_dims)

        # Encoder. skip_dims records how many channels each level hands to the
        # decoder (0 for the last level, which has no skip connection).
        skip_dims = []
        self.down_samples = nn.ModuleList([])
        encoder_specs = zip(in_out_dims, num_blocks, text_dims, add_self_attention)
        for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(encoder_specs):
            is_last_level = level == (self.num_levels - 1)
            skip_dims.append(0 if is_last_level else out_dim)
            self.down_samples.append(
                DownSampleBlock(
                    in_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim, expansion_ratio,
                    compression_ratio, not is_last_level, self_attention
                )
            )

        # Decoder: walk every per-level setting from the end.
        self.up_samples = nn.ModuleList([])
        decoder_specs = zip(
            reversed(in_out_dims), reversed(skip_dims),
            reversed(num_blocks), reversed(text_dims), reversed(add_self_attention)
        )
        for level, ((out_dim, in_dim), cat_dim, res_block_num, text_dim, self_attention) in enumerate(decoder_specs):
            self.up_samples.append(
                UpSampleBlock(
                    in_dim, cat_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim,
                    expansion_ratio, compression_ratio, level != 0, self_attention
                )
            )

        self.out_layer = nn.Sequential(
            nn.GroupNorm(groups, init_channels),
            nn.SiLU(),
            nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1)
        )

        self.control_net = None

    def forward(self, x, time, context=None, context_mask=None, control_net_residual=None):
        """Run the network; ``control_net_residual`` is forwarded to every encoder level."""
        time_embed = self.to_time_embed(time)
        if context is not None:
            time_embed = self.feature_pooling(time_embed, context, context_mask)

        skips = []
        last_level = self.num_levels - 1
        x = self.in_layer(x)
        for level, down_sample in enumerate(self.down_samples):
            x = down_sample(x, time_embed, context, context_mask, control_net_residual)
            if level != last_level:
                skips.append(x)

        for level, up_sample in enumerate(self.up_samples):
            # The first decoder level has no skip input.
            if level != 0:
                x = torch.cat([x, skips.pop()], dim=1)
            x = up_sample(x, time_embed, context, context_mask)
        return self.out_layer(x)
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class ControlNetModel(nn.Module):
    """Encoder-only tower that returns the feature map of every non-final level.

    Built from the same time embedding, context pooling, input convolution and
    ``DownSampleBlock`` levels as the U-Net encoder in this module. Extra
    positional/keyword constructor arguments are accepted and ignored.

    NOTE(review): an earlier class with this same name exists in this module;
    this later definition is the one the name refers to after import.
    """

    def __init__(self,
                 model_channels,
                 init_channels=None,
                 num_channels=3,
                 out_channels=4,
                 time_embed_dim=None,
                 context_dim=None,
                 groups=32,
                 head_dim=64,
                 expansion_ratio=4,
                 compression_ratio=2,
                 dim_mult=(1, 2, 4, 8),
                 num_blocks=(3, 3, 3, 3),
                 add_cross_attention=(False, True, True, True),
                 add_self_attention=(False, True, True, True),
                 *args,
                 **kwargs,
                 ):
        # out_channels is accepted for signature parity but not used here.
        super().__init__()
        init_channels = init_channels or model_channels
        self.to_time_embed = nn.Sequential(
            SinusoidalPosEmb(init_channels),
            nn.Linear(init_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        self.feature_pooling = AttentionPolling(time_embed_dim, context_dim, head_dim)

        self.in_layer = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)

        hidden_dims = [init_channels, *map(lambda mult: model_channels * mult, dim_mult)]
        in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
        # Levels without cross-attention get a None context dimension.
        text_dims = [set_default_item(is_exist, context_dim) for is_exist in add_cross_attention]
        layer_params = [num_blocks, text_dims, add_self_attention]

        self.num_levels = len(in_out_dims)
        self.down_samples = nn.ModuleList([])
        for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(zip(in_out_dims, *layer_params)):
            # Every level except the last reduces the resolution.
            down_sample = level != (self.num_levels - 1)
            self.down_samples.append(
                DownSampleBlock(
                    in_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim, expansion_ratio,
                    compression_ratio, down_sample, self_attention
                )
            )

    def forward(self, x, time, context=None, context_mask=None):
        """Return the list of feature maps produced by all levels but the last."""
        time_embed = self.to_time_embed(time)
        if exist(context):
            time_embed = self.feature_pooling(time_embed, context, context_mask)

        hidden_states = []
        x = self.in_layer(x)
        for level, down_sample in enumerate(self.down_samples):
            x = down_sample(x, time_embed, context, context_mask)
            if level != self.num_levels - 1:
                hidden_states.append(x)
        return hidden_states
|
| 404 |
+
|
| 405 |
+
class ControlUNet(nn.Module):
|
| 406 |
+
|
| 407 |
+
def __init__(self,
|
| 408 |
+
model_channels,
|
| 409 |
+
init_channels=None,
|
| 410 |
+
num_channels=3,
|
| 411 |
+
out_channels=4,
|
| 412 |
+
time_embed_dim=None,
|
| 413 |
+
context_dim=None,
|
| 414 |
+
groups=32,
|
| 415 |
+
head_dim=64,
|
| 416 |
+
expansion_ratio=4,
|
| 417 |
+
compression_ratio=2,
|
| 418 |
+
dim_mult=(1, 2, 4, 8),
|
| 419 |
+
num_blocks=(3, 3, 3, 3),
|
| 420 |
+
add_cross_attention=(False, True, True, True),
|
| 421 |
+
add_self_attention=(False, True, True, True),
|
| 422 |
+
control_net_channels=5,
|
| 423 |
+
*args,
|
| 424 |
+
**kwargs,
|
| 425 |
+
):
|
| 426 |
+
super().__init__()
|
| 427 |
+
init_channels = init_channels or model_channels
|
| 428 |
+
self.to_time_embed = nn.Sequential(
|
| 429 |
+
SinusoidalPosEmb(init_channels),
|
| 430 |
+
nn.Linear(init_channels, time_embed_dim),
|
| 431 |
+
nn.SiLU(),
|
| 432 |
+
nn.Linear(time_embed_dim, time_embed_dim)
|
| 433 |
+
)
|
| 434 |
+
self.feature_pooling = AttentionPolling(time_embed_dim, context_dim, head_dim)
|
| 435 |
+
|
| 436 |
+
self.in_layer = nn.Conv2d(num_channels, init_channels, kernel_size=3, padding=1)
|
| 437 |
+
|
| 438 |
+
hidden_dims = [init_channels, *map(lambda mult: model_channels * mult, dim_mult)]
|
| 439 |
+
in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
|
| 440 |
+
text_dims = [set_default_item(is_exist, context_dim) for is_exist in add_cross_attention]
|
| 441 |
+
layer_params = [num_blocks, text_dims, add_self_attention]
|
| 442 |
+
rev_layer_params = map(reversed, layer_params)
|
| 443 |
+
|
| 444 |
+
cat_dims = []
|
| 445 |
+
self.num_levels = len(in_out_dims)
|
| 446 |
+
self.down_samples = nn.ModuleList([])
|
| 447 |
+
for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(zip(in_out_dims, *layer_params)):
|
| 448 |
+
down_sample = level != (self.num_levels - 1)
|
| 449 |
+
cat_dims.append(set_default_item(level != (self.num_levels - 1), out_dim, 0))
|
| 450 |
+
self.down_samples.append(
|
| 451 |
+
DownSampleBlock(
|
| 452 |
+
in_dim, out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim, expansion_ratio,
|
| 453 |
+
compression_ratio, down_sample, self_attention
|
| 454 |
+
)
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
self.up_samples = nn.ModuleList([])
|
| 458 |
+
for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate(zip(reversed(in_out_dims), *rev_layer_params)):
|
| 459 |
+
up_sample = level != 0
|
| 460 |
+
self.up_samples.append(
|
| 461 |
+
UpSampleBlock(
|
| 462 |
+
in_dim, cat_dims.pop(), out_dim, time_embed_dim, text_dim, res_block_num, groups, head_dim,
|
| 463 |
+
expansion_ratio, compression_ratio, up_sample, self_attention
|
| 464 |
+
)
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
self.out_layer = nn.Sequential(
|
| 468 |
+
nn.GroupNorm(groups, init_channels),
|
| 469 |
+
nn.SiLU(),
|
| 470 |
+
nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1)
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
self.control_net = ControlNetModel(model_channels,
|
| 474 |
+
init_channels,
|
| 475 |
+
control_net_channels,
|
| 476 |
+
out_channels,
|
| 477 |
+
time_embed_dim,
|
| 478 |
+
context_dim,
|
| 479 |
+
groups,
|
| 480 |
+
head_dim,
|
| 481 |
+
expansion_ratio,
|
| 482 |
+
compression_ratio,
|
| 483 |
+
dim_mult,
|
| 484 |
+
num_blocks,
|
| 485 |
+
add_cross_attention,
|
| 486 |
+
add_self_attention)
|
| 487 |
+
|
| 488 |
+
def forward(self, x, time, context=None, context_mask=None, control_net_data=None):
    """Run the control-conditioned U-Net.

    The time embedding is optionally pooled with the text context, the
    control branch produces one residual per non-final encoder level, and
    the decoder concatenates the stored encoder outputs as skip connections.
    """
    time_embed = self.to_time_embed(time)
    if exist(context):
        time_embed = self.feature_pooling(time_embed, context, context_mask)

    # The control branch is evaluated before the main trunk touches ``x``.
    control_residuals = self.control_net(control_net_data, time, context, context_mask)
    skips = []
    last_level = self.num_levels - 1

    x = self.in_layer(x)
    for level, encoder_block in enumerate(self.down_samples):
        x = encoder_block(x, time_embed, context, context_mask)
        if level == last_level:
            # The deepest level gets neither a control residual nor a skip.
            continue
        x += control_residuals.pop(0)
        skips.append(x)

    for level, decoder_block in enumerate(self.up_samples):
        if level != 0:
            # Skips are consumed in reverse order of how they were stored.
            x = torch.cat([x, skips.pop()], dim=1)
        x = decoder_block(x, time_embed, context, context_mask)

    return self.out_layer(x)
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def get_control_unet(conf):
    """Build a ``ControlUNet`` from a mapping of constructor keyword arguments."""
    return ControlUNet(**conf)
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def get_unet(conf):
    """Build a ``UNet`` from a mapping of constructor keyword arguments."""
    return UNet(**conf)
|
kandinsky3/model/utils.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.nn import Identity
|
| 2 |
+
from einops import rearrange
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def exist(item):
    """Return ``True`` for any value other than ``None`` (falsy values count as existing)."""
    return not (item is None)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def set_default_item(condition, item_1, item_2=None):
    """Pick ``item_1`` when ``condition`` is truthy, otherwise ``item_2``."""
    return item_1 if condition else item_2
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def set_default_layer(condition, layer_1, args_1=None, kwargs_1=None, layer_2=Identity, args_2=None, kwargs_2=None):
    """Instantiate one of two layer factories depending on ``condition``.

    Args:
        condition: Truthy selects ``layer_1``, falsy selects ``layer_2``.
        layer_1: Callable invoked as ``layer_1(*args_1, **kwargs_1)``.
        args_1: Positional arguments for ``layer_1``; ``None`` means none.
        kwargs_1: Keyword arguments for ``layer_1``; ``None`` means none.
        layer_2: Fallback callable, ``torch.nn.Identity`` by default.
        args_2: Positional arguments for ``layer_2``; ``None`` means none.
        kwargs_2: Keyword arguments for ``layer_2``; ``None`` means none.

    Returns:
        Whatever the selected callable returns.
    """
    # ``None`` sentinels replace the former mutable ``[]`` / ``{}`` defaults,
    # which were shared between calls.
    if condition:
        factory, args, kwargs = layer_1, args_1, kwargs_1
    else:
        factory, args, kwargs = layer_2, args_2, kwargs_2
    return factory(*(args or ()), **(kwargs or {}))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_tensor_items(x, pos, broadcast_shape):
    """Gather ``x[pos]`` and reshape it for broadcasting.

    The lookup is done on CPU copies of both tensors. The result has shape
    ``(pos.shape[0], 1, ..., 1)`` with one singleton axis per trailing
    dimension of ``broadcast_shape``, and is moved to ``pos``'s device.
    """
    target_device = pos.device
    gathered = x.cpu()[pos.cpu()]
    singleton_axes = (1,) * len(broadcast_shape[1:])
    reshaped = gathered.reshape(pos.shape[0], *singleton_axes)
    return reshaped.to(target_device)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def local_patching(x, height, width, group_size):
    """Rearrange a ``b c h w`` feature map into attention tokens.

    With a positive ``group_size`` the map is cut into non-overlapping
    ``group_size x group_size`` windows, giving ``b (h w) (g1 g2) c``.
    Otherwise the whole map is flattened to ``b (h w) c``.
    """
    if not group_size > 0:
        return rearrange(x, 'b c h w -> b (h w) c', h=height, w=width)

    windows_h = height // group_size
    windows_w = width // group_size
    return rearrange(
        x, 'b c (h g1) (w g2) -> b (h w) (g1 g2) c',
        h=windows_h, w=windows_w, g1=group_size, g2=group_size
    )
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def local_merge(x, height, width, group_size):
    """Inverse of ``local_patching``: fold tokens back into a ``b c h w`` map."""
    if not group_size > 0:
        return rearrange(x, 'b (h w) c -> b c h w', h=height, w=width)

    windows_h = height // group_size
    windows_w = width // group_size
    return rearrange(
        x, 'b (h w) (g1 g2) c -> b c (h g1) (w g2)',
        h=windows_h, w=windows_w, g1=group_size, g2=group_size
    )
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def global_patching(x, height, width, group_size):
    """Window the map with ``local_patching`` and swap the two token axes.

    The window size passed on is ``height // group_size`` for both spatial
    dimensions.
    NOTE(review): ``width`` is not used to derive the window size, so this
    presumably expects square inputs - confirm against callers.
    """
    windowed = local_patching(x, height, width, height // group_size)
    return windowed.transpose(-2, -3)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def global_merge(x, height, width, group_size):
    """Inverse of ``global_patching``: swap the token axes back, then ``local_merge``.

    The window size passed on is ``height // group_size`` for both spatial
    dimensions.
    """
    swapped = x.transpose(-2, -3)
    return local_merge(swapped, height, width, height // group_size)
|
kandinsky3/movq.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from .utils import freeze
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def nonlinearity(x):
    """Swish activation: ``x`` multiplied element-wise by ``sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SpatialNorm(nn.Module):
    """Normalisation layer optionally modulated by a spatial conditioning map.

    Without ``zq_channels`` this is just ``norm_layer``. With it, the
    normalised features are scaled and shifted by two 1x1 convolutions of
    the conditioning map ``zq``.

    Args:
        f_channels: Channel count of the feature map being normalised.
        zq_channels: Channel count of the conditioning map, or ``None`` to
            build a plain normalisation layer.
        norm_layer: Factory called as
            ``norm_layer(num_channels=f_channels, **norm_layer_params)``.
        freeze_norm_layer: If true (and ``zq_channels`` is given), the
            parameters of ``norm_layer`` are excluded from training.
        add_conv: If true (and ``zq_channels`` is given), a 3x3 convolution
            is applied to the resized ``zq`` before the modulation convs.
        **norm_layer_params: Extra keyword arguments for ``norm_layer``.
    """

    def __init__(
        self, f_channels, zq_channels=None, norm_layer=nn.GroupNorm, freeze_norm_layer=False, add_conv=False, **norm_layer_params
    ):
        super().__init__()
        self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
        if zq_channels is not None:
            if freeze_norm_layer:
                # ``parameters`` is a method; iterating the bound method
                # itself (as before) raised TypeError.
                for p in self.norm_layer.parameters():
                    p.requires_grad = False
            self.add_conv = add_conv
            if self.add_conv:
                self.conv = nn.Conv2d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1)
            self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
            self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, f, zq=None):
        """Normalise ``f``; if ``zq`` is given, modulate the result with it.

        ``zq`` is resized with nearest-neighbour interpolation to the last
        two dimensions of ``f`` before use. Passing ``zq`` requires the layer
        to have been built with ``zq_channels``.
        """
        norm_f = self.norm_layer(f)
        if zq is not None:
            f_size = f.shape[-2:]
            zq = torch.nn.functional.interpolate(zq, size=f_size, mode="nearest")
            if self.add_conv:
                zq = self.conv(zq)
            norm_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
        return norm_f
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def Normalize(in_channels, zq_ch=None, add_conv=None):
    """Build a ``SpatialNorm`` backed by a 32-group affine ``GroupNorm`` (eps 1e-6)."""
    group_norm_settings = dict(num_groups=32, eps=1e-6, affine=True)
    return SpatialNorm(
        in_channels,
        zq_ch,
        norm_layer=nn.GroupNorm,
        freeze_norm_layer=False,
        add_conv=add_conv,
        **group_norm_settings
    )
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class Upsample(nn.Module):
    """Double the spatial resolution (nearest neighbour), optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # Channel-preserving convolution applied after the upsampling.
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        if not self.with_conv:
            return upsampled
        return self.conv(upsampled)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class Downsample(nn.Module):
    """Halve the spatial resolution with a strided 3x3 conv or 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # The convolution itself is unpadded; forward() pads explicitly.
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return F.avg_pool2d(x, kernel_size=2, stride=2)
        # Zero-pad one pixel on the right and bottom only.
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class ResnetBlock(nn.Module):
    """Two-convolution residual block with optional time-embedding and ``zq`` conditioning.

    The residual branch is norm -> swish -> conv, twice, with dropout before
    the second conv. When the channel count changes, the skip path is
    projected by a 3x3 conv (``conv_shortcut=True``) or a 1x1 conv.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512, zq_ch=None, add_conv=False):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels, zq_ch, add_conv=add_conv)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            # Only built when a time embedding is configured.
            self.temb_proj = nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels, zq_ch, add_conv=add_conv)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if in_channels != out_channels:
            if conv_shortcut:
                self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb, zq=None):
        residual = self.conv1(nonlinearity(self.norm1(x, zq)))

        if temb is not None:
            # Broadcast the projected embedding over the spatial dimensions.
            residual = residual + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        residual = self.conv2(self.dropout(nonlinearity(self.norm2(residual, zq))))

        if self.in_channels != self.out_channels:
            project = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
            x = project(x)

        return x + residual
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions, with a residual connection."""

    def __init__(self, in_channels, zq_ch=None, add_conv=False):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels, zq_ch, add_conv=add_conv)
        # Query, key, value and output projections are all 1x1 convolutions.
        pointwise = dict(kernel_size=1, stride=1, padding=0)
        self.q = nn.Conv2d(in_channels, in_channels, **pointwise)
        self.k = nn.Conv2d(in_channels, in_channels, **pointwise)
        self.v = nn.Conv2d(in_channels, in_channels, **pointwise)
        self.proj_out = nn.Conv2d(in_channels, in_channels, **pointwise)

    def forward(self, x, zq=None):
        normed = self.norm(x, zq)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        b, c, h, w = query.shape
        tokens = h * w

        # scores[b, i, j] = sum_c query[b, i, c] * key[b, c, j]
        query = query.reshape(b, c, tokens).permute(0, 2, 1)
        key = key.reshape(b, c, tokens)
        scores = torch.bmm(query, key)
        scores = scores * (int(c)**(-0.5))
        weights = F.softmax(scores, dim=2)

        # out[b, c, j] = sum_i value[b, c, i] * weights[b, j, i]
        value = value.reshape(b, c, tokens)
        weights = weights.permute(0, 2, 1)
        attended = torch.bmm(value, weights).reshape(b, c, h, w)

        attended = self.proj_out(attended)
        return x + attended
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class Encoder(nn.Module):
    """Convolutional encoder: input conv, a stack of resolution levels, a middle stage and an output conv.

    Each level holds ``num_res_blocks`` ``ResnetBlock``s (each followed by an
    ``AttnBlock`` when the current resolution is listed in
    ``attn_resolutions``). Every level except the last ends with a
    ``Downsample``. ``out_ch`` and any extra keyword arguments are accepted
    but not used by this class.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, **ignore_kwargs):
        super().__init__()
        self.ch = ch
        # No time embedding: with temb_channels == 0 ResnetBlock builds no temb_proj.
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        # Input multiplier of level i is the output multiplier of level i-1 (1 for the first).
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            # A bare nn.Module is used as a named container for this level.
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        # double_z doubles the number of output channels.
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)


    def forward(self, x):
        """Encode ``x`` and return the output of ``conv_out``."""
        temb = None

        # downsampling
        # hs collects every intermediate activation; only hs[-1] is read back.
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class Decoder(nn.Module):
    """Convolutional decoder whose normalisation layers are conditioned on ``zq``.

    Structure: input conv, middle stage (ResnetBlock, AttnBlock, ResnetBlock),
    then one level per entry of ``ch_mult`` processed from the lowest
    resolution upwards. Each level has ``num_res_blocks + 1`` ResnetBlocks and
    ends with an ``Upsample``, except level 0. ``in_channels`` is stored but
    not otherwise used here; extra keyword arguments are ignored.
    """

    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, zq_ch=None, add_conv=False, **ignorekwargs):
        super().__init__()
        self.ch = ch
        # No time embedding: with temb_channels == 0 ResnetBlock builds no temb_proj.
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        # Nominal latent shape; forward() does not enforce it (the assert is commented out).
        self.z_shape = (1,z_channels,curr_res,curr_res)

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout,
                                       zq_ch=zq_ch,
                                       add_conv=add_conv)
        self.mid.attn_1 = AttnBlock(block_in, zq_ch, add_conv=add_conv)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout,
                                       zq_ch=zq_ch,
                                       add_conv=add_conv)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout,
                                         zq_ch=zq_ch,
                                         add_conv=add_conv))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in, zq_ch, add_conv=add_conv))
            # A bare nn.Module is used as a named container for this level.
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend so self.up[i] corresponds to level i

        # end
        self.norm_out = Normalize(block_in, zq_ch, add_conv=add_conv)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z, zq):
        """Decode ``z``; ``zq`` is passed to every conditioned block as the modulation map.

        Returns the pre-output features when ``give_pre_end`` is set,
        otherwise the output of ``conv_out``.
        """
        #assert z.shape[1:] == self.z_shape[1:]
        # Side effect: remembers the shape of the most recent input.
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb, zq)
        h = self.mid.attn_1(h, zq)
        h = self.mid.block_2(h, temb, zq)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h, temb, zq)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h, zq)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h, zq)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
class MoVQ(nn.Module):
    """Autoencoder pairing ``Encoder`` with a ``zq``-conditioned ``Decoder``.

    ``generator_params`` is forwarded as keyword arguments to both halves and
    must contain ``"z_channels"``.
    """

    def __init__(self, generator_params):
        super().__init__()
        latent_channels = generator_params["z_channels"]
        self.encoder = Encoder(**generator_params)
        self.quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
        self.decoder = Decoder(zq_ch=latent_channels, **generator_params)

    def encode(self, x):
        """Return the encoder output passed through ``quant_conv``."""
        return self.quant_conv(self.encoder(x))

    def decode(self, quant):
        """Decode ``quant``; it is also used, unprojected, as the decoder's conditioning map."""
        projected = self.post_quant_conv(quant)
        return self.decoder(projected, quant)
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def get_vae(conf):
    """Build a frozen ``MoVQ`` model, optionally restoring weights from a checkpoint.

    Args:
        conf: Object exposing ``params`` (passed to ``MoVQ``) and
            ``checkpoint`` (path of a state dict, or ``None`` to keep the
            random initialisation).

    Returns:
        The ``MoVQ`` instance with every parameter's ``requires_grad`` disabled.
    """
    movq = MoVQ(conf.params)
    if conf.checkpoint is not None:
        # Load onto CPU so a checkpoint saved from GPU can be restored on a
        # host without that device; load_state_dict copies the values into
        # the model's own parameters afterwards.
        # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
        movq_state_dict = torch.load(conf.checkpoint, map_location="cpu")
        movq.load_state_dict(movq_state_dict)
    movq = freeze(movq)
    return movq
|
kandinsky3/setup.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup
|
| 2 |
+
|
| 3 |
+
setup(
    name="kandinsky3",
    # Package names are dotted module paths, not filesystem paths.
    packages=[
        "kandinsky3",
        "kandinsky3.model"
    ],
    # NOTE(review): the "+cu111" local-version pins and the pinned setuptools
    # are kept as-is; confirm they still resolve from the package index in use.
    install_requires=[
        "timm",
        "torch==1.10.1+cu111",
        "torchvision==0.11.2+cu111",
        "torchaudio==0.10.1",
        "pytorch_lightning==1.7.5",
        "transformers",
        "accelerate",
        "diffusers",
        "setuptools==59.5.0",
        "omegaconf",
        "datasets",
        "einops",
        "webdataset",
        "fsspec",
        "s3fs",
        "hydra-core",
        "scikit-image",
        "matplotlib",
        "wandb",
        "albumentations",
        "bezier",
        "scipy",
        "Pillow",
        "tqdm",
        "huggingface_hub"

    ],
    author="",
)
|
kandinsky3/t2i_pipeline.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union, List
|
| 2 |
+
import PIL
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torchvision.transforms as T
|
| 6 |
+
from einops import repeat
|
| 7 |
+
|
| 8 |
+
from kandinsky3.model.unet import UNet
|
| 9 |
+
from kandinsky3.movq import MoVQ
|
| 10 |
+
from kandinsky3.condition_encoders import T5TextConditionEncoder
|
| 11 |
+
from kandinsky3.condition_processors import T5TextConditionProcessor
|
| 12 |
+
from kandinsky3.model.diffusion import BaseDiffusion, get_named_beta_schedule
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Kandinsky3T2IPipeline:
    """Text-to-image pipeline: T5 text encoding, diffusion sampling with the U-Net, MoVQ decoding."""

    def __init__(
        self,
        device_map: Union[str, torch.device, dict],
        dtype_map: Union[str, torch.dtype, dict],
        unet: UNet,
        null_embedding: torch.Tensor,
        t5_processor: T5TextConditionProcessor,
        t5_encoder: T5TextConditionEncoder,
        movq: MoVQ,
        gan: bool,
    ):
        # NOTE(review): __call__ indexes both maps with the keys 'text_encoder',
        # 'unet' and 'movq', so in practice they must be dicts despite the
        # wider annotation.
        self.device_map = device_map
        self.dtype_map = dtype_map
        self.to_pil = T.ToPILImage()

        self.unet = unet
        self.null_embedding = null_embedding
        self.t5_processor = t5_processor
        self.t5_encoder = t5_encoder
        self.movq = movq

        self.gan = gan

    def __call__(
        self,
        text: str,
        negative_text: str = None,
        images_num: int = 1,
        bs: int = 1,
        width: int = 1024,
        height: int = 1024,
        guidance_scale: float = 3.0,
        steps: int = 50,
        eta: float = 1.0
    ) -> List[PIL.Image.Image]:
        """Generate ``images_num`` images for ``text``.

        Args:
            text: Prompt.
            negative_text: Optional negative prompt.
            images_num: Total number of images to return.
            bs: Maximum number of images sampled per mini-batch.
            width: Output width; the latent width is ``width // 8``.
            height: Output height; the latent height is ``height // 8``.
            guidance_scale: Forwarded to ``p_sample_loop``.
            steps: Sets the timestep stride (``-1000 // steps``); ignored
                when ``self.gan`` is set.
            eta: Forwarded to ``p_sample_loop``.

        Returns:
            The generated images as a list of PIL images.
        """

        betas = get_named_beta_schedule('cosine', 1000)
        base_diffusion = BaseDiffusion(betas, 0.99)
        times = list(range(999, 0, -1000 // steps))
        if self.gan:
            # Fixed four-step schedule: 979, 729, 479, 229.
            times = list(range(979, 0, -250))

        condition_model_input, negative_condition_model_input = self.t5_processor.encode(text, negative_text)
        for input_type in condition_model_input:
            # [None] adds a leading batch axis of size 1.
            condition_model_input[input_type] = condition_model_input[input_type][None].to(
                self.device_map['text_encoder']
            )

        if negative_condition_model_input is not None:
            for input_type in negative_condition_model_input:
                negative_condition_model_input[input_type] = negative_condition_model_input[input_type][None].to(
                    self.device_map['text_encoder']
                )

        pil_images = []
        with torch.no_grad():
            with torch.cuda.amp.autocast(dtype=self.dtype_map['text_encoder']):
                context, context_mask = self.t5_encoder(condition_model_input)
                if negative_condition_model_input is not None:
                    negative_context, negative_context_mask = self.t5_encoder(negative_condition_model_input)
                else:
                    negative_context, negative_context_mask = None, None

            # k full mini-batches of size bs plus one remainder batch of size m.
            k, m = images_num // bs, images_num % bs
            for minibatch in [bs] * k + [m]:
                if minibatch == 0:
                    continue
                # Replicate the single encoded prompt across the mini-batch.
                bs_context = repeat(context, '1 n d -> b n d', b=minibatch)
                bs_context_mask = repeat(context_mask, '1 n -> b n', b=minibatch)
                if negative_context is not None:
                    bs_negative_context = repeat(negative_context, '1 n d -> b n d', b=minibatch)
                    bs_negative_context_mask = repeat(negative_context_mask, '1 n -> b n', b=minibatch)
                else:
                    bs_negative_context, bs_negative_context_mask = None, None

                with torch.cuda.amp.autocast(dtype=self.dtype_map['unet']):
                    images = base_diffusion.p_sample_loop(
                        self.unet, (minibatch, 4, height // 8, width // 8), times, self.device_map['unet'],
                        bs_context, bs_context_mask, self.null_embedding, guidance_scale, eta,
                        negative_context=bs_negative_context, negative_context_mask=bs_negative_context_mask,
                        gan=self.gan
                    )

                with torch.cuda.amp.autocast(dtype=self.dtype_map['movq']):
                    # Decode the latents in at most two chunks.
                    images = torch.cat([self.movq.decode(image) for image in images.chunk(2)])
                    # Map from [-1, 1] to [0, 1], clipping anything outside.
                    images = torch.clip((images + 1.) / 2., 0., 1.)
                    # chunk(1) yields the whole batch as a single chunk.
                    for images_chunk in images.chunk(1):
                        pil_images += [self.to_pil(image) for image in images_chunk]

        return pil_images
|
kandinsky3/utils.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from omegaconf import OmegaConf
|
| 2 |
+
import numpy as np
|
| 3 |
+
from scipy import ndimage
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from skimage.transform import resize
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def load_conf(config_path):
    """Load an OmegaConf config and copy shared settings into the sections that use them.

    Values from ``common`` are copied into the data, trainer, scheduler and
    logger sections; the encoders' ``context_dim`` is taken from
    ``model.unet_params``.
    """
    conf = OmegaConf.load(config_path)
    common = conf.common

    # Data section.
    conf.data.tokens_length = common.tokens_length
    conf.data.processor_names = conf.model.encoders.model_names
    conf.data.dataset.seed = common.seed
    conf.data.dataset.image_size = common.image_size

    # Training length and experiment naming.
    conf.trainer.trainer_params.max_steps = common.train_steps
    conf.scheduler.params.total_steps = common.train_steps
    conf.logger.tensorboard.name = common.experiment_name

    conf.model.encoders.context_dim = conf.model.unet_params.context_dim
    return conf
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def freeze(model):
    """Disable gradients for every parameter of ``model`` and return the same object."""
    for param in model.parameters():
        param.requires_grad_(False)
    return model
|
| 27 |
+
|
| 28 |
+
def unfreeze(model):
    """Enable gradients for every parameter of ``model`` and return the same object."""
    for param in model.parameters():
        param.requires_grad_(True)
    return model
|
| 32 |
+
|
| 33 |
+
def zero_module(module):
    """Fill every parameter of ``module`` with zeros in place and return the same object."""
    for weight in module.parameters():
        nn.init.zeros_(weight)
    return module
|
| 37 |
+
|
| 38 |
+
def resize_mask_for_diffusion(mask):
    """Resize ``mask`` so both sides are multiples of 64, shrinking large masks.

    The shrink factor is ``sqrt(mask.size / 1024**2)``, never below 1. Each of
    the first two dimensions is divided by it, rounded, then floored to a
    multiple of 64.
    NOTE(review): ``mask.size`` counts all elements, so this presumably
    expects a 2-D mask - confirm against callers.
    """
    reduce_factor = max(1, (mask.size / 1024**2)**0.5)
    rows = (round(mask.shape[0] / reduce_factor) // 64) * 64
    cols = (round(mask.shape[1] / reduce_factor) // 64) * 64
    return resize(mask, (rows, cols), preserve_range=True, anti_aliasing=False)
|
| 51 |
+
|
| 52 |
+
def resize_image_for_diffusion(image):
    """Resize ``image`` so both sides are multiples of 64, shrinking large images.

    The shrink factor is ``sqrt(width * height / 1024**2)``, never below 1.
    Each side is divided by it, rounded, then floored to a multiple of 64.
    """
    width, height = image.size[0], image.size[1]
    reduce_factor = max(1, (width * height / 1024**2)**0.5)
    target_size = (
        (round(width / reduce_factor) // 64) * 64,
        (round(height / reduce_factor) // 64) * 64
    )
    return image.resize(target_size)
|
| 59 |
+
|
| 60 |
+
def prepare_mask(mask):
    """Dilate ``mask`` by smoothing it three times and thresholding at zero.

    Args:
        mask: Array-like mask; non-zero entries mark the selected region.
            Any numeric or boolean dtype is accepted.

    Returns:
        An integer array of the same shape holding 1 wherever the smoothed
        mask is positive and 0 elsewhere. Each 5x5 pass extends the support
        by two pixels, so the region grows by six pixels in total.
    """
    ker = np.array([[1, 1, 1, 1, 1],
                    [1, 5, 5, 5, 1],
                    [1, 5, 44, 5, 1],
                    [1, 5, 5, 5, 1],
                    [1, 1, 1, 1, 1]]) / 100
    # ndimage.convolve returns the input's dtype. Casting to float first keeps
    # the fractional kernel responses, which would otherwise be truncated to 0
    # for integer or boolean masks.
    out = np.asarray(mask, dtype=float)
    for _ in range(3):
        out = ndimage.convolve(out, ker)

    return (out > 0).astype(int)
|
unet_model_checkpoint.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4165a1e48f2a7630729c03c8c7662b6acf8b9f6a590102ba80051c01afb480eb
|
| 3 |
+
size 12154895798
|