"""Convolutional autoencoder (AE) and variational autoencoder (VAE) models.

Both models pair a ``BasicEncoder`` with a ``BasicGenerator`` from
``ae_bases`` that is configured with the reversed feature-map sizes.
"""

import numpy as np
import torch
import torch.nn as nn
import torch.distributions as dist

from .ae_bases import BasicEncoder, BasicGenerator

# Activation layers that can be selected by name via ``VAE(activation_op=...)``.
_ACTIVATION_OPS = {
    "relu": nn.ReLU,
    "prelu": nn.PReLU,
    "leakyrelu": nn.LeakyReLU,
}


def _resolve_activation_op(activation_op):
    """Turn an activation specification into an ``nn.Module`` class.

    Args:
        activation_op (str or type): Either one of the names ``"relu"``,
            ``"prelu"``, ``"leakyrelu"`` or an ``nn.Module`` subclass, which is
            returned unchanged.

    Returns:
        type: The activation layer class.

    Raises:
        ValueError: If ``activation_op`` is neither a supported name nor an
            ``nn.Module`` subclass.
    """
    # The isinstance check keeps unhashable inputs away from the dict lookup,
    # so every unsupported value ends in the same ValueError.
    if isinstance(activation_op, str) and activation_op in _ACTIVATION_OPS:
        return _ACTIVATION_OPS[activation_op]
    if isinstance(activation_op, type) and issubclass(activation_op, nn.Module):
        return activation_op
    raise ValueError(f"Activation function {activation_op} not supported")


class VAE(torch.nn.Module):
    """Variational autoencoder made of a BasicEncoder and a BasicGenerator."""

    def __init__(
        self,
        d,
        input_size,
        z_dim=256,
        fmap_sizes=(16, 64, 256, 1024),
        to_1x1=True,
        conv_params=None,
        tconv_params=None,
        normalization_op=None,
        normalization_params=None,
        activation_op="leakyrelu",
        activation_params=None,
        block_op=None,
        block_params=None,
        *args,
        **kwargs
    ):
        """Basic VAE built from a symmetric BasicEncoder (encoder) and BasicGenerator (decoder).

        The encoder is created with ``z_dim * 2`` output channels; its output is
        split along dim 1 into the mean and the log standard deviation of the
        approximate posterior.

        Args:
            d (int): Dimensionality selector. ``2`` uses ``nn.Conv2d`` /
                ``nn.ConvTranspose2d``; any other value uses ``nn.Conv3d`` /
                ``nn.ConvTranspose3d``.
            input_size (sequence of int): Size of the input, channels first
                (e.g. CxHxW). Copied into a list for each sub-network.
            z_dim (int, optional): Channel dimension of the latent. Defaults to 256.
            fmap_sizes (tuple, optional): Number of feature maps per level. Passed
                as is to the encoder and reversed to the decoder.
                Defaults to (16, 64, 256, 1024).
            to_1x1 (bool, optional): Forwarded to both sub-networks. As described
                by the original documentation, True makes the latent a
                z_dim x 1 x 1 vector (similar to fully connected), False keeps a
                spatial resolution (z_dim x H x W). Defaults to True.
            conv_params (dict, optional): Init parameters for the encoder's conv
                operation. Defaults to None (forwarded unchanged to BasicEncoder).
            tconv_params (dict, optional): Init parameters for the decoder's
                transposed conv operation. Defaults to None (forwarded unchanged
                to BasicGenerator).
            normalization_op (type, optional): Normalization operation (e.g.
                BatchNorm, InstanceNorm, ...). Defaults to None (forwarded
                unchanged to both sub-networks).
            normalization_params (dict, optional): Init parameters for the
                normalization operation. Defaults to None.
            activation_op (str or type, optional): Activation / non-linearity,
                given as one of ``"relu"``, ``"prelu"``, ``"leakyrelu"`` or as an
                ``nn.Module`` subclass. Defaults to "leakyrelu".
            activation_params (dict, optional): Init parameters for the
                activation operation. Defaults to None.
            block_op (type, optional): Block operation used for each feature map
                size (e.g. ConvBlock / ResidualBlock). Defaults to None.
            block_params (dict, optional): Init parameters for the block
                operation. Defaults to None.
            *args: Accepted for call compatibility and ignored.
            **kwargs: Accepted for call compatibility and ignored.

        Raises:
            ValueError: If ``activation_op`` is not supported.
        """
        super().__init__()

        if d == 2:
            conv_op = nn.Conv2d
            tconv_op = nn.ConvTranspose2d
        else:
            conv_op = nn.Conv3d
            tconv_op = nn.ConvTranspose3d

        activation_op = _resolve_activation_op(activation_op)

        # Each sub-network gets its own list so that neither shares a mutable
        # object with the other or with the caller.
        input_size_enc = list(input_size)
        input_size_dec = list(input_size)

        self.enc = BasicEncoder(
            input_size=input_size_enc,
            fmap_sizes=fmap_sizes,
            z_dim=z_dim * 2,  # one half for mu, one half for log_std
            conv_op=conv_op,
            conv_params=conv_params,
            normalization_op=normalization_op,
            normalization_params=normalization_params,
            activation_op=activation_op,
            activation_params=activation_params,
            block_op=block_op,
            block_params=block_params,
            to_1x1=to_1x1,
        )
        self.dec = BasicGenerator(
            input_size=input_size_dec,
            fmap_sizes=fmap_sizes[::-1],
            z_dim=z_dim,
            upsample_op=tconv_op,
            conv_params=tconv_params,
            normalization_op=normalization_op,
            normalization_params=normalization_params,
            activation_op=activation_op,
            activation_params=activation_params,
            block_op=block_op,
            block_params=block_params,
            to_1x1=to_1x1,
        )

        self.hidden_size = self.enc.output_size

    @staticmethod
    def _posterior_params(enc_out):
        """Split an encoder output along dim 1 into ``(mu, std)``, with ``std = exp(log_std)``."""
        mu, log_std = torch.chunk(enc_out, 2, dim=1)
        return mu, torch.exp(log_std)

    def forward(self, inpt, sample=True, no_dist=False, **kwargs):
        """Encode ``inpt``, pick a latent and decode it.

        Args:
            inpt (tensor): The input to reconstruct.
            sample (bool, optional): If True the latent is drawn with
                ``rsample()`` from the approximate posterior, otherwise the
                posterior mean is used. Defaults to True.
            no_dist (bool, optional): If True only the reconstruction is
                returned. Defaults to False.
            **kwargs: Forwarded to the encoder only.

        Returns:
            ``x_rec`` if ``no_dist`` is True, otherwise the tuple
            ``(x_rec, z_dist)`` where ``z_dist`` is the ``Normal(mu, std)``
            approximate posterior.
        """
        mu, std = self._posterior_params(self.enc(inpt, **kwargs))
        z_dist = dist.Normal(mu, std)

        z_sample = z_dist.rsample() if sample else mu
        x_rec = self.dec(z_sample)

        if no_dist:
            return x_rec
        return x_rec, z_dist

    def encode(self, inpt, **kwargs):
        """Encodes a sample and returns the parameters of the approximate inference distribution (Normal).

        Args:
            inpt (tensor): The input to encode.
            **kwargs: Forwarded to the encoder.

        Returns:
            mu: The mean used to parameterize a Normal distribution.
            std: The standard deviation used to parameterize a Normal distribution.
        """
        return self._posterior_params(self.enc(inpt, **kwargs))

    def decode(self, inpt, **kwargs):
        """Decodes a latent space sample with the generative model.

        (decode = mu_{gen}(z) as used in p(x|z) = N(x | mu_{gen}(z), 1)).

        Args:
            inpt (tensor): A sample from the latent space to decode.
            **kwargs: Forwarded to the decoder.

        Returns:
            The decoder output for ``inpt``.
        """
        x_rec = self.dec(inpt, **kwargs)
        return x_rec


class AE(torch.nn.Module):
    """Plain autoencoder made of a BasicEncoder and a BasicGenerator."""

    def __init__(
        self,
        input_size,
        z_dim=1024,
        fmap_sizes=(16, 64, 256, 1024),
        to_1x1=True,
        conv_op=torch.nn.Conv2d,
        conv_params=None,
        tconv_op=torch.nn.ConvTranspose2d,
        tconv_params=None,
        normalization_op=None,
        normalization_params=None,
        activation_op=torch.nn.LeakyReLU,
        activation_params=None,
        block_op=None,
        block_params=None,
        *args,
        **kwargs
    ):
        """Basic AE built from a symmetric BasicEncoder (encoder) and BasicGenerator (decoder).

        Args:
            input_size (sequence of int): Size of the input, channels first
                (e.g. CxHxW). Copied into a list for each sub-network.
            z_dim (int, optional): Channel dimension of the latent. Defaults to 1024.
            fmap_sizes (tuple, optional): Number of feature maps per level. Passed
                as is to the encoder and reversed to the decoder.
                Defaults to (16, 64, 256, 1024).
            to_1x1 (bool, optional): Forwarded to both sub-networks. As described
                by the original documentation, True makes the latent a
                z_dim x 1 x 1 vector (similar to fully connected), False keeps a
                spatial resolution (z_dim x H x W). Defaults to True.
            conv_op (type, optional): Convolution operation used in the encoder.
                Defaults to nn.Conv2d.
            conv_params (dict, optional): Init parameters for the conv operation.
                Defaults to None (forwarded unchanged to BasicEncoder).
            tconv_op (type, optional): Upsampling / transposed conv operation
                used in the decoder. Defaults to nn.ConvTranspose2d.
            tconv_params (dict, optional): Init parameters for the transposed
                conv operation. Defaults to None (forwarded unchanged to
                BasicGenerator).
            normalization_op (type, optional): Normalization operation (e.g.
                BatchNorm, InstanceNorm, ...). Defaults to None (forwarded
                unchanged to both sub-networks).
            normalization_params (dict, optional): Init parameters for the
                normalization operation. Defaults to None.
            activation_op (type, optional): Activation / non-linearity (e.g.
                ReLU, Sigmoid, ...). Defaults to nn.LeakyReLU.
            activation_params (dict, optional): Init parameters for the
                activation operation. Defaults to None.
            block_op (type, optional): Block operation used for each feature map
                size (e.g. ConvBlock / ResidualBlock). Defaults to None.
            block_params (dict, optional): Init parameters for the block
                operation. Defaults to None.
            *args: Accepted for call compatibility and ignored.
            **kwargs: Accepted for call compatibility and ignored.
        """
        super().__init__()

        # Each sub-network gets its own list so that neither shares a mutable
        # object with the other or with the caller.
        input_size_enc = list(input_size)
        input_size_dec = list(input_size)

        self.enc = BasicEncoder(
            input_size=input_size_enc,
            fmap_sizes=fmap_sizes,
            z_dim=z_dim,
            conv_op=conv_op,
            conv_params=conv_params,
            normalization_op=normalization_op,
            normalization_params=normalization_params,
            activation_op=activation_op,
            activation_params=activation_params,
            block_op=block_op,
            block_params=block_params,
            to_1x1=to_1x1,
        )
        self.dec = BasicGenerator(
            input_size=input_size_dec,
            fmap_sizes=fmap_sizes[::-1],
            z_dim=z_dim,
            upsample_op=tconv_op,
            conv_params=tconv_params,
            normalization_op=normalization_op,
            normalization_params=normalization_params,
            activation_op=activation_op,
            activation_params=activation_params,
            block_op=block_op,
            block_params=block_params,
            to_1x1=to_1x1,
        )

        self.hidden_size = self.enc.output_size

    def forward(self, inpt, **kwargs):
        """Encode ``inpt`` and decode the result.

        Args:
            inpt (tensor): The input to reconstruct.
            **kwargs: Forwarded to the encoder only.

        Returns:
            The reconstruction produced by the decoder.
        """
        y1 = self.enc(inpt, **kwargs)
        x_rec = self.dec(y1)
        return x_rec

    def encode(self, inpt, **kwargs):
        """Encodes an input sample to a latent space sample.

        Args:
            inpt (tensor): Input sample.
            **kwargs: Forwarded to the encoder.

        Returns:
            enc: Encoded input sample in the latent space.
        """
        enc = self.enc(inpt, **kwargs)
        return enc

    def decode(self, inpt, **kwargs):
        """Decodes a latent space sample back to the input space.

        Args:
            inpt (tensor): Latent space sample.
            **kwargs: Forwarded to the decoder.

        Returns:
            rec: The latent sample decoded back into the input space.
        """
        rec = self.dec(inpt, **kwargs)
        return rec