import os
from typing import Optional, Tuple, Union

import torch
from huggingface_hub import hf_hub_download, snapshot_download

from kandinsky3.model.unet import UNet
from kandinsky3.movq import MoVQ
from kandinsky3.condition_encoders import T5TextConditionEncoder
from kandinsky3.condition_processors import T5TextConditionProcessor
from kandinsky3.model.diffusion import BaseDiffusion, get_named_beta_schedule

from .t2i_pipeline import Kandinsky3T2IPipeline
from .inpainting_pipeline import Kandinsky3InpaintingPipeline
|
|
|
|
def get_T2I_unet(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> Tuple[UNet, Optional[torch.Tensor]]:
    """Build the text-to-image UNet and optionally load its checkpoint.

    Args:
        device: Device the UNet is moved to.
        weights_path: Path to a checkpoint readable by ``torch.load``. The
            checkpoint must provide the keys ``'unet'`` (a state dict) and
            ``'null_embedding'``. When ``None`` or empty, nothing is loaded
            and the UNet keeps its freshly constructed weights.
        dtype: Dtype the UNet is cast to.

    Returns:
        A 2-tuple ``(unet, null_embedding)``. ``null_embedding`` is ``None``
        when no checkpoint was loaded.
    """
    # Fixed architecture configuration; it has to match the checkpoint layout.
    unet = UNet(
        model_channels=384,
        num_channels=4,
        init_channels=192,
        time_embed_dim=1536,
        context_dim=4096,
        groups=32,
        head_dim=64,
        expansion_ratio=4,
        compression_ratio=2,
        dim_mult=(1, 2, 4, 8),
        num_blocks=(3, 3, 3, 3),
        add_cross_attention=(False, True, True, True),
        add_self_attention=(False, True, True, True),
    )

    null_embedding = None
    if weights_path:
        # Deserialize on CPU first; the model is moved to `device` afterwards.
        # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        null_embedding = state_dict['null_embedding']
        unet.load_state_dict(state_dict['unet'])

    unet.to(device=device, dtype=dtype).eval()
    return unet, null_embedding
|
|
|
|
def get_T5encoder(
    device: Union[str, torch.device],
    weights_path: str,
    projection_name: str,
    dtype: Union[str, torch.dtype] = torch.float32,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
) -> Tuple[T5TextConditionProcessor, T5TextConditionEncoder]:
    """Create the text condition processor and encoder, then load the projection.

    Args:
        device: Device passed to the encoder and used for its projection.
        weights_path: Directory with the text encoder weights; the projection
            checkpoint is looked up inside the same directory.
        projection_name: File name of the projection checkpoint, relative to
            ``weights_path``.
        dtype: Dtype passed to the encoder and used for its projection.
        low_cpu_mem_usage: Forwarded to ``T5TextConditionEncoder``.
        load_in_8bit: Forwarded to ``T5TextConditionEncoder``.
        load_in_4bit: Forwarded to ``T5TextConditionEncoder``.

    Returns:
        A 2-tuple ``(processor, condition_encoder)``.
    """
    tokens_length = 128
    context_dim = 4096

    processor = T5TextConditionProcessor(tokens_length, weights_path)
    condition_encoder = T5TextConditionEncoder(
        weights_path,
        context_dim,
        low_cpu_mem_usage=low_cpu_mem_usage,
        device=device,
        dtype=dtype,
        load_in_8bit=load_in_8bit,
        load_in_4bit=load_in_4bit,
    )

    if weights_path:
        projections_weights_path = os.path.join(weights_path, projection_name)
        # Deserialize on CPU first; the projection is moved to `device` below.
        # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
        state_dict = torch.load(projections_weights_path, map_location=torch.device('cpu'))
        condition_encoder.projection.load_state_dict(state_dict)

    # Only the projection is moved/cast here; the encoder itself received
    # `device` and `dtype` through its constructor.
    condition_encoder.projection.to(device=device, dtype=dtype).eval()
    return processor, condition_encoder
|
|
|
|
def get_movq(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> MoVQ:
    """Build the MoVQ model and optionally load its weights.

    Args:
        device: Device the model is moved to.
        weights_path: Path to a state dict readable by ``torch.load``. When
            ``None`` or empty, the model keeps its freshly constructed weights.
        dtype: Dtype the model is cast to.

    Returns:
        The MoVQ instance after ``.to(device, dtype)`` and ``.eval()``.
    """
    # Fixed generator configuration handed to MoVQ unchanged.
    movq_config = dict(
        double_z=False,
        z_channels=4,
        resolution=256,
        in_channels=3,
        out_ch=3,
        ch=256,
        ch_mult=[1, 2, 2, 4],
        num_res_blocks=2,
        attn_resolutions=[32],
        dropout=0.0,
    )
    movq = MoVQ(movq_config)

    if weights_path:
        # Weights are read onto the CPU; the move to `device` happens below.
        cpu = torch.device('cpu')
        movq.load_state_dict(torch.load(weights_path, map_location=cpu))

    movq.to(device=device, dtype=dtype).eval()
    return movq
|
|
|
|
def get_inpainting_unet(
    device: Union[str, torch.device],
    weights_path: Optional[str] = None,
    dtype: Union[str, torch.dtype] = torch.float32,
) -> Tuple[UNet, Optional[torch.Tensor]]:
    """Build the inpainting UNet and optionally load its checkpoint.

    The configuration differs from the text-to-image UNet only in
    ``num_channels`` (9 here instead of 4).

    Args:
        device: Device the UNet is moved to.
        weights_path: Path to a checkpoint readable by ``torch.load``. The
            checkpoint must provide the keys ``'unet'`` (a state dict) and
            ``'null_embedding'``. When ``None`` or empty, nothing is loaded
            and the UNet keeps its freshly constructed weights.
        dtype: Dtype the UNet is cast to.

    Returns:
        A 2-tuple ``(unet, null_embedding)``. ``null_embedding`` is ``None``
        when no checkpoint was loaded.
    """
    # Fixed architecture configuration; it has to match the checkpoint layout.
    unet = UNet(
        model_channels=384,
        num_channels=9,
        init_channels=192,
        time_embed_dim=1536,
        context_dim=4096,
        groups=32,
        head_dim=64,
        expansion_ratio=4,
        compression_ratio=2,
        dim_mult=(1, 2, 4, 8),
        num_blocks=(3, 3, 3, 3),
        add_cross_attention=(False, True, True, True),
        add_self_attention=(False, True, True, True),
    )

    null_embedding = None
    if weights_path:
        # Deserialize on CPU first; the model is moved to `device` afterwards.
        # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        null_embedding = state_dict['null_embedding']
        unet.load_state_dict(state_dict['unet'])

    unet.to(device=device, dtype=dtype).eval()
    return unet, null_embedding
|
|
|
|
def get_T2I_pipeline(
    device_map: Union[str, torch.device, dict],
    dtype_map: Union[str, torch.dtype, dict] = torch.float32,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
    cache_dir: str = '/tmp/kandinsky3/',
    unet_path: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    movq_path: Optional[str] = None,
) -> Kandinsky3T2IPipeline:
    """Assemble the text-to-image pipeline, downloading missing weights.

    Args:
        device_map: Either one device used for every component, or a dict
            with the keys ``'unet'``, ``'text_encoder'`` and ``'movq'``.
        dtype_map: Either one dtype used for every component, or a dict with
            the same three keys.
        low_cpu_mem_usage: Forwarded to ``get_T5encoder``.
        load_in_8bit: Forwarded to ``get_T5encoder``.
        load_in_4bit: Forwarded to ``get_T5encoder``.
        cache_dir: Hugging Face Hub cache directory used for downloads.
        unet_path: Local UNet checkpoint; downloaded when ``None``.
        text_encoder_path: Local text encoder directory; downloaded when
            ``None``.
        movq_path: Local MoVQ checkpoint; downloaded when ``None``.

    Returns:
        A ``Kandinsky3T2IPipeline`` built from the loaded components.
    """
    components = ('unet', 'text_encoder', 'movq')
    repo_id = "ai-forever/Kandinsky3.1"

    # A single device/dtype is broadcast to every component.
    if not isinstance(device_map, dict):
        device_map = dict.fromkeys(components, device_map)
    if not isinstance(dtype_map, dict):
        dtype_map = dict.fromkeys(components, dtype_map)

    if unet_path is None:
        unet_path = hf_hub_download(
            repo_id=repo_id, filename='weights/kandinsky3.pt', cache_dir=cache_dir
        )
    if text_encoder_path is None:
        text_encoder_path = snapshot_download(
            repo_id=repo_id, allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        # snapshot_download returns the snapshot root; the encoder files sit
        # in the sub-folder selected by `allow_patterns`.
        text_encoder_path = os.path.join(text_encoder_path, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(
            repo_id=repo_id, filename='weights/movq.pt', cache_dir=cache_dir
        )

    unet, null_embedding = get_T2I_unet(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'], text_encoder_path, 'projection.pt', dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])
    # NOTE(review): the trailing positional False presumably selects the
    # non-flash variant (the flash factory passes True) - confirm against
    # Kandinsky3T2IPipeline.
    return Kandinsky3T2IPipeline(
        device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq, False
    )
|
|
|
|
def get_T2I_Flash_pipeline(
    device_map: Union[str, torch.device, dict],
    dtype_map: Union[str, torch.dtype, dict] = torch.float32,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
    cache_dir: str = '/tmp/kandinsky3/',
    unet_path: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    movq_path: Optional[str] = None,
) -> Kandinsky3T2IPipeline:
    """Assemble the "flash" text-to-image pipeline, downloading missing weights.

    Uses the ``kandinsky3_flash.pt`` UNet checkpoint and the
    ``projection_flash.pt`` projection.

    Args:
        device_map: Either one device used for every component, or a dict
            with the keys ``'unet'``, ``'text_encoder'`` and ``'movq'``.
        dtype_map: Either one dtype used for every component, or a dict with
            the same three keys.
        low_cpu_mem_usage: Forwarded to ``get_T5encoder``.
        load_in_8bit: Forwarded to ``get_T5encoder``.
        load_in_4bit: Forwarded to ``get_T5encoder``.
        cache_dir: Hugging Face Hub cache directory used for downloads.
        unet_path: Local UNet checkpoint; downloaded when ``None``.
        text_encoder_path: Local text encoder directory; downloaded when
            ``None``.
        movq_path: Local MoVQ checkpoint; downloaded when ``None``.

    Returns:
        A ``Kandinsky3T2IPipeline`` built from the loaded components.
    """
    components = ('unet', 'text_encoder', 'movq')
    repo_id = "ai-forever/Kandinsky3.1"

    # A single device/dtype is broadcast to every component.
    if not isinstance(device_map, dict):
        device_map = dict.fromkeys(components, device_map)
    if not isinstance(dtype_map, dict):
        dtype_map = dict.fromkeys(components, dtype_map)

    if unet_path is None:
        unet_path = hf_hub_download(
            repo_id=repo_id, filename='weights/kandinsky3_flash.pt', cache_dir=cache_dir
        )
    if text_encoder_path is None:
        text_encoder_path = snapshot_download(
            repo_id=repo_id, allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        # snapshot_download returns the snapshot root; the encoder files sit
        # in the sub-folder selected by `allow_patterns`.
        text_encoder_path = os.path.join(text_encoder_path, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(
            repo_id=repo_id, filename='weights/movq.pt', cache_dir=cache_dir
        )

    unet, null_embedding = get_T2I_unet(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'], text_encoder_path, 'projection_flash.pt', dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])
    # NOTE(review): the trailing positional True presumably selects the flash
    # variant (the regular factory passes False) - confirm against
    # Kandinsky3T2IPipeline.
    return Kandinsky3T2IPipeline(
        device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq, True
    )
|
|
|
|
def get_inpainting_pipeline(
    device_map: Union[str, torch.device, dict],
    dtype_map: Union[str, torch.dtype, dict] = torch.float32,
    low_cpu_mem_usage: bool = True,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
    cache_dir: str = '/tmp/kandinsky3/',
    unet_path: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    movq_path: Optional[str] = None,
) -> Kandinsky3InpaintingPipeline:
    """Assemble the inpainting pipeline, downloading missing weights.

    Uses the ``kandinsky3_inpainting.pt`` UNet checkpoint and the
    ``projection_inpainting.pt`` projection.

    Args:
        device_map: Either one device used for every component, or a dict
            with the keys ``'unet'``, ``'text_encoder'`` and ``'movq'``.
        dtype_map: Either one dtype used for every component, or a dict with
            the same three keys.
        low_cpu_mem_usage: Forwarded to ``get_T5encoder``.
        load_in_8bit: Forwarded to ``get_T5encoder``.
        load_in_4bit: Forwarded to ``get_T5encoder``.
        cache_dir: Hugging Face Hub cache directory used for downloads.
        unet_path: Local UNet checkpoint; downloaded when ``None``.
        text_encoder_path: Local text encoder directory; downloaded when
            ``None``.
        movq_path: Local MoVQ checkpoint; downloaded when ``None``.

    Returns:
        A ``Kandinsky3InpaintingPipeline`` built from the loaded components.
    """
    components = ('unet', 'text_encoder', 'movq')
    repo_id = "ai-forever/Kandinsky3.1"

    # A single device/dtype is broadcast to every component.
    if not isinstance(device_map, dict):
        device_map = dict.fromkeys(components, device_map)
    if not isinstance(dtype_map, dict):
        dtype_map = dict.fromkeys(components, dtype_map)

    if unet_path is None:
        unet_path = hf_hub_download(
            repo_id=repo_id, filename='weights/kandinsky3_inpainting.pt', cache_dir=cache_dir
        )
    if text_encoder_path is None:
        text_encoder_path = snapshot_download(
            repo_id=repo_id, allow_patterns="weights/flan_ul2_encoder/*", cache_dir=cache_dir
        )
        # snapshot_download returns the snapshot root; the encoder files sit
        # in the sub-folder selected by `allow_patterns`.
        text_encoder_path = os.path.join(text_encoder_path, 'weights/flan_ul2_encoder')
    if movq_path is None:
        movq_path = hf_hub_download(
            repo_id=repo_id, filename='weights/movq.pt', cache_dir=cache_dir
        )

    unet, null_embedding = get_inpainting_unet(device_map['unet'], unet_path, dtype=dtype_map['unet'])
    processor, condition_encoder = get_T5encoder(
        device_map['text_encoder'], text_encoder_path, 'projection_inpainting.pt', dtype=dtype_map['text_encoder'],
        low_cpu_mem_usage=low_cpu_mem_usage, load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
    )
    movq = get_movq(device_map['movq'], movq_path, dtype=dtype_map['movq'])
    return Kandinsky3InpaintingPipeline(
        device_map, dtype_map, unet, null_embedding, processor, condition_encoder, movq
    )
|
|