# -*- coding: utf-8 -*-
#
# @File:   generator.py
# @Author: Haozhe Xie
# @Date:   2024-03-09 20:36:52
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2024-09-23 20:49:35
# @Email:  root@haozhexie.com

import numpy as np
import torch
import torch.nn.functional as F

import extensions.grid_encoder
import gaussiancity.pt_v3


class Generator(torch.nn.Module):
    r"""Predict per-point Gaussian attributes from projection maps and points.

    Pipeline (see ``forward``): an optional projection encoder (``GLOBAL`` or
    ``LOCAL``) yields per-point features. These are concatenated with
    ``rel_xyz``, positionally encoded (hash grid or sin/cos), optionally
    refined by a Point Transformer V3 network, and finally decoded by
    :class:`GaussianAttrMLP`.

    Args:
        cfg: Network config. Fields read here: ``ENCODER``,
            ``ENCODER_OUT_DIM``, ``GLOBAL_ENCODER_N_BLOCKS``, ``POS_EMD``,
            ``HASH_GRID_N_LEVELS``, ``HASH_GRID_LEVEL_DIM``,
            ``SIN_COS_FREQ_BENDS``, ``PTV3.*``, ``Z_DIM``, ``MLP_HIDDEN_DIM``,
            ``MLP_N_SHARED_LAYERS``, ``ATTR_FACTORS``, ``ATTR_N_LAYERS``.
        n_classes (int): Number of semantic classes. This is the channel
            count of ``proj_seg`` and the width of ``onehots``.
        proj_size: Passed to the hash-grid encoder as ``desired_resolution``.

    Raises:
        ValueError: If ``cfg.ENCODER`` or ``cfg.POS_EMD`` is unknown.
    """

    def __init__(self, cfg, n_classes, proj_size):
        super(Generator, self).__init__()
        self.cfg = cfg
        self.n_classes = n_classes
        # The encoder emits ENCODER_OUT_DIM - 3 channels; rel_xyz supplies the
        # last 3 so the positional encoder always sees ENCODER_OUT_DIM inputs.
        if cfg.ENCODER == "GLOBAL":
            self.proj_encoder = GlobalEncoder(
                n_classes, cfg.GLOBAL_ENCODER_N_BLOCKS, cfg.ENCODER_OUT_DIM - 3
            )
        elif cfg.ENCODER == "LOCAL":
            self.proj_encoder = LocalEncoder(n_classes, cfg.ENCODER_OUT_DIM - 3)
        elif cfg.ENCODER is None:
            self.proj_encoder = None
            assert cfg.ENCODER_OUT_DIM == 3
        else:
            raise ValueError("Unknown encoder: %s" % cfg.ENCODER)

        if cfg.POS_EMD == "HASH_GRID":
            pt_feat_dim = cfg.HASH_GRID_N_LEVELS * cfg.HASH_GRID_LEVEL_DIM
            self.pos_encoder = extensions.grid_encoder.GridEncoder(
                in_channels=cfg.ENCODER_OUT_DIM,
                desired_resolution=proj_size,
                n_levels=cfg.HASH_GRID_N_LEVELS,
                lvl_channels=cfg.HASH_GRID_LEVEL_DIM,
            )
        elif cfg.POS_EMD == "SIN_COS":
            # SinCosEncoder emits a sin and a cos per input channel per band.
            pt_feat_dim = 2 * cfg.ENCODER_OUT_DIM * cfg.SIN_COS_FREQ_BENDS
            self.pos_encoder = SinCosEncoder(cfg.SIN_COS_FREQ_BENDS)
        else:
            raise ValueError("Unknown positional encoder: %s" % cfg.POS_EMD)

        if cfg.PTV3.ENABLED:
            self.pt_net = gaussiancity.pt_v3.PointTransformerV3(
                in_channels=pt_feat_dim,
                order=cfg.PTV3.ORDER,
                stride=cfg.PTV3.STRIDE,
                enc_depths=cfg.PTV3.ENC_DEPTHS,
                enc_channels=cfg.PTV3.ENC_CHANNELS,
                enc_num_head=cfg.PTV3.ENC_N_HEAD,
                enc_patch_size=cfg.PTV3.ENC_PATCH_SIZE,
                dec_depths=cfg.PTV3.DEC_DEPTHS,
                dec_channels=cfg.PTV3.DEC_CHANNELS,
                dec_num_head=cfg.PTV3.DEC_N_HEAD,
                dec_patch_size=cfg.PTV3.DEC_PATCH_SIZE,
                enable_flash=cfg.PTV3.ENABLE_FLASH_ATTN,
            )
            # forward() concatenates the PTv3 output onto the encoded features.
            pt_feat_dim += cfg.PTV3.DEC_CHANNELS[0]
        else:
            self.pt_net = None

        self.ga_mlp = GaussianAttrMLP(
            n_classes,
            pt_feat_dim,
            cfg.Z_DIM,
            cfg.MLP_HIDDEN_DIM,
            cfg.MLP_N_SHARED_LAYERS,
            cfg.ATTR_FACTORS,
            cfg.ATTR_N_LAYERS,
        )

    def forward(self, proj_uv, rel_xyz, batch_idx, onehots, z, proj_hf, proj_seg):
        r"""Run the full generator.

        Args:
            proj_uv: Per-point sampling grid for ``F.grid_sample`` (LOCAL
                encoder); its dim 1 is the point count for the GLOBAL encoder.
                Presumably [B, n_pts, 2] in [-1, 1] -- TODO confirm.
            rel_xyz: Per-point coordinates, [B, n_pts, C]. C is presumably 3,
                since __init__ reserves 3 encoder channels for it.
            batch_idx: Forwarded unchanged to the Point Transformer V3 network.
            onehots: Per-point class one-hots, [B, n_pts, n_classes].
            z: Forwarded to :class:`GaussianAttrMLP` as ``zs``.
            proj_hf: Height-field projection map (1 channel).
            proj_seg: Semantic projection map (``n_classes`` channels).

        Returns:
            dict: Attribute name -> predicted tensor (see GaussianAttrMLP).
        """
        # Ref: https://github.com/hzxie/CityDreamer/blob/master/models/gancraft.py#L381
        n_batches, n_pts = rel_xyz.size(0), rel_xyz.size(1)
        if self.cfg.ENCODER == "GLOBAL":
            # One global code per sample, repeated for every point.
            proj_feat = self.proj_encoder(proj_hf, proj_seg)
            pt_feat = proj_feat.unsqueeze(dim=1).repeat(1, proj_uv.size(1), 1)
        elif self.cfg.ENCODER == "LOCAL":
            # Sample the local feature map at each point's projected location.
            proj_feat = self.proj_encoder(proj_hf, proj_seg)
            pt_feat = (
                F.grid_sample(proj_feat, proj_uv.unsqueeze(dim=1), align_corners=True)
                .squeeze(dim=2)
                .permute(0, 2, 1)
            )
        elif self.cfg.ENCODER is None:
            # Zero-width feature so that only rel_xyz is encoded. Allocated
            # like rel_xyz (same device/dtype) because proj_uv is not needed
            # on this path.
            pt_feat = rel_xyz.new_empty((n_batches, n_pts, 0))

        # pt_feat: [B, n_pts, cfg.ENCODER_OUT_DIM - 3]
        pt_feat = torch.cat([pt_feat, rel_xyz], dim=2)
        # pt_feat: [B, n_pts, cfg.ENCODER_OUT_DIM]
        pt_feat1 = self.pos_encoder(pt_feat)
        # pt_feat1: [B, n_pts, pt_feat_dim]
        if self.pt_net is None:
            pt_feat2 = pt_feat1.new_empty((n_batches, n_pts, 0))
        else:
            pt_feat2 = self.pt_net(batch_idx, pt_feat1, rel_xyz)
            # pt_feat2: [B, n_pts, cfg.PTV3.DEC_CHANNELS[0]]

        return self.ga_mlp(torch.cat([pt_feat1, pt_feat2], dim=-1), onehots, z)


class GlobalEncoder(torch.nn.Module):
    r"""Encode the projection maps into one global code per sample.

    ``proj_hf`` and ``proj_seg`` each pass through a stride-2 conv (8 channels
    each). The result goes through ``n_blocks - 1`` SRTConvBlocks (each
    doubles the channels and halves the resolution), is spatially averaged,
    and is mapped by two linear layers. A final ``tanh`` bounds the output
    to [-1, 1].

    Args:
        n_classes (int): Channels of ``proj_seg``.
        n_blocks (int): One more than the number of SRTConvBlocks.
        out_channels (int): Size of the output code.
    """

    def __init__(self, n_classes, n_blocks, out_channels):
        super(GlobalEncoder, self).__init__()
        self.hf_conv = torch.nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1)
        self.seg_conv = torch.nn.Conv2d(
            n_classes,
            8,
            kernel_size=3,
            stride=2,
            padding=1,
        )
        conv_blocks = []
        cur_hidden_channels = 16
        for _ in range(1, n_blocks):
            conv_blocks.append(
                SRTConvBlock(in_channels=cur_hidden_channels, out_channels=None)
            )
            cur_hidden_channels *= 2

        self.conv_blocks = torch.nn.Sequential(*conv_blocks)
        self.fc1 = torch.nn.Linear(cur_hidden_channels, 16)
        self.fc2 = torch.nn.Linear(16, out_channels)
        self.act = torch.nn.LeakyReLU(0.2)

    def forward(self, proj_hf, proj_seg):
        r"""Return a [B, out_channels] code in [-1, 1]."""
        hf = self.act(self.hf_conv(proj_hf))
        seg = self.act(self.seg_conv(proj_seg))
        out = torch.cat([hf, seg], dim=1)
        for layer in self.conv_blocks:
            out = self.act(layer(out))

        # Global average pooling over all spatial positions.
        out = out.permute(0, 2, 3, 1)
        out = torch.mean(out.reshape(out.shape[0], -1, out.shape[-1]), dim=1)
        cond = self.act(self.fc1(out))
        cond = torch.tanh(self.fc2(cond))
        return cond


class LocalEncoder(torch.nn.Module):
    r"""Encode the projection maps into a dense feature map.

    The encoder downsamples to H/4 x W/4, runs three residual blocks, then
    upsamples back to H x W with two transposed convs. A 1x1 conv and a
    ``tanh`` finish the head.

    Args:
        n_classes (int): Channels of ``proj_seg``.
        out_channels (int): Channels of the output feature map.
    """

    def __init__(self, n_classes, out_channels):
        super(LocalEncoder, self).__init__()
        self.hf_conv = torch.nn.Conv2d(1, 32, kernel_size=7, stride=2, padding=3)
        self.seg_conv = torch.nn.Conv2d(
            n_classes, 32, kernel_size=7, stride=2, padding=3
        )
        self.bn1 = torch.nn.GroupNorm(32, 64)
        self.conv2 = ResConvBlock(64, 128)
        self.conv3 = ResConvBlock(128, 256)
        self.conv4 = ResConvBlock(256, 512)
        self.dconv5 = torch.nn.ConvTranspose2d(
            512, 128, kernel_size=4, stride=2, padding=1
        )
        self.dconv6 = torch.nn.ConvTranspose2d(
            128, 32, kernel_size=4, stride=2, padding=1
        )
        self.dconv7 = torch.nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, proj_hf, proj_seg):
        r"""Return a [N, out_channels, H, W] feature map in [-1, 1]."""
        hf = self.hf_conv(proj_hf)
        seg = self.seg_conv(proj_seg)
        out = F.relu(self.bn1(torch.cat([hf, seg], dim=1)), inplace=True)
        # out: [N, 64, H/2, W/2]
        out = F.avg_pool2d(self.conv2(out), 2, stride=2)
        # out: [N, 128, H/4, W/4]
        out = self.conv3(out)
        # out: [N, 256, H/4, W/4]
        out = self.conv4(out)
        # out: [N, 512, H/4, W/4]
        out = self.dconv5(out)
        # out: [N, 128, H/2, W/2]
        out = self.dconv6(out)
        # out: [N, 32, H, W]
        out = self.dconv7(out)
        # out: [N, out_channels, H, W]
        return torch.tanh(out)


class SRTConvBlock(torch.nn.Module):
    r"""Two 3x3 convs with ReLU; the second conv has stride 2.

    ``hidden_channels`` defaults to ``in_channels`` and ``out_channels``
    defaults to ``2 * hidden_channels``.
    """

    def __init__(self, in_channels, hidden_channels=None, out_channels=None):
        super(SRTConvBlock, self).__init__()
        if hidden_channels is None:
            hidden_channels = in_channels
        if out_channels is None:
            out_channels = 2 * hidden_channels

        self.layers = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels,
                hidden_channels,
                stride=1,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            torch.nn.ReLU(),
            torch.nn.Conv2d(
                hidden_channels,
                out_channels,
                stride=2,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        return self.layers(x)


class ResConvBlock(torch.nn.Module):
    r"""Pre-activation residual block with concatenated multi-scale branches.

    Three GroupNorm-ReLU-conv3x3 stages produce ``out/2``, ``out/4`` and
    ``out/4`` channels. They are concatenated to ``out_channels`` and added to
    the residual; a 1x1 projection is used when the channel counts differ.
    Spatial size is preserved.

    Note: ``GroupNorm(32, c)`` requires ``in_channels``, ``out_channels // 2``
    and ``out_channels // 4`` to be divisible by 32.
    """

    def __init__(self, in_channels, out_channels, bias=False):
        super(ResConvBlock, self).__init__()
        # conv3x3(in_planes, int(out_planes / 2))
        self.conv1 = torch.nn.Conv2d(
            in_channels,
            out_channels // 2,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
        )
        # conv3x3(int(out_planes / 2), int(out_planes / 4))
        self.conv2 = torch.nn.Conv2d(
            out_channels // 2,
            out_channels // 4,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
        )
        # conv3x3(int(out_planes / 4), int(out_planes / 4))
        self.conv3 = torch.nn.Conv2d(
            out_channels // 4,
            out_channels // 4,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
        )
        self.bn1 = torch.nn.GroupNorm(32, in_channels)
        self.bn2 = torch.nn.GroupNorm(32, out_channels // 2)
        self.bn3 = torch.nn.GroupNorm(32, out_channels // 4)
        self.bn4 = torch.nn.GroupNorm(32, in_channels)

        if in_channels != out_channels:
            self.downsample = torch.nn.Sequential(
                self.bn4,
                torch.nn.ReLU(True),
                torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, bias=False
                ),
            )
        else:
            self.downsample = None

    def forward(self, x):
        residual = x
        # residual: [N, in_channels, H, W]
        out1 = self.bn1(x)
        out1 = F.relu(out1, True)
        out1 = self.conv1(out1)
        # out1: [N, out_channels / 2, H, W]
        out2 = self.bn2(out1)
        out2 = F.relu(out2, True)
        out2 = self.conv2(out2)
        # out2: [N, out_channels / 4, H, W]
        out3 = self.bn3(out2)
        out3 = F.relu(out3, True)
        out3 = self.conv3(out3)
        # out3: [N, out_channels / 4, H, W]
        out3 = torch.cat((out1, out2, out3), dim=1)
        # out3: [N, out_channels, H, W]
        if self.downsample is not None:
            residual = self.downsample(residual)
            # residual: [N, out_channels, H, W]

        out3 += residual
        return out3


class SinCosEncoder(torch.nn.Module):
    r"""Fourier positional encoding with frequencies 2^0 .. 2^(n-1).

    The output is ``[sin(x * f_0), ..., sin(x * f_{n-1}), cos(x * f_0), ...]``
    concatenated on the last dim. Its size is ``2 * C * n_freq_bands`` for an
    input with ``C`` channels.
    """

    def __init__(self, n_freq_bands=8):
        super(SinCosEncoder, self).__init__()
        self.freq_bands = 2.0 ** torch.linspace(
            0,
            n_freq_bands - 1,
            steps=n_freq_bands,
        )

    def forward(self, features):
        cord_sin = torch.cat(
            [torch.sin(features * fb) for fb in self.freq_bands], dim=-1
        )
        cord_cos = torch.cat(
            [torch.cos(features * fb) for fb in self.freq_bands], dim=-1
        )
        return torch.cat([cord_sin, cord_cos], dim=-1)


class GaussianAttrMLP(torch.nn.Module):
    r"""MLP with affine modulation.

    Predicts Gaussian attributes per point. Shared layers (``fc_1`` plus a
    class embedding, then ``fc_2`` .. ``fc_{n_shared_layers}``) are followed
    by one head per key of ``factors``. Hidden layers are :class:`ModLinear`
    when ``z_dim`` is not None, otherwise plain ``Linear``.

    Args:
        n_classes (int): Width of the ``onehots`` input.
        in_dim (int): Width of the ``pt_feat`` input.
        z_dim (int or None): Style-code width for ModLinear layers.
        hidden_dim (int): Hidden width.
        n_shared_layers (int): Number of shared layers, including ``fc_1``.
        factors (dict): Attribute name ("xyz", "rgb", "scale", "opacity") ->
            output scaling factor. The keys select which heads exist.
        n_layers (dict): Attribute name -> number of hidden layers in its head.
    """

    def __init__(
        self,
        n_classes,
        in_dim,
        z_dim,
        hidden_dim,
        n_shared_layers,
        factors=None,
        n_layers=None,
    ):
        super(GaussianAttrMLP, self).__init__()
        # Avoid shared mutable defaults; an empty dict means "no heads".
        factors = {} if factors is None else factors
        n_layers = {} if n_layers is None else n_layers
        self.factors = factors
        self.n_layers = n_layers
        self.n_shared_layers = n_shared_layers
        self.act = torch.nn.LeakyReLU(negative_slope=0.2)
        self.fc_m_a = torch.nn.Linear(
            n_classes,
            hidden_dim,
            bias=False,
        )
        self.fc_1 = torch.nn.Linear(
            in_dim,
            hidden_dim,
        )
        for i in range(2, n_shared_layers + 1):
            setattr(
                self,
                "fc_%d" % i,
                (
                    ModLinear(
                        hidden_dim,
                        hidden_dim,
                        z_dim,
                        bias=False,
                        mod_bias=True,
                        output_mode=True,
                    )
                    if z_dim is not None
                    else torch.nn.Linear(hidden_dim, hidden_dim)
                ),
            )

        for k in factors.keys():
            assert k in ["xyz", "rgb", "scale", "opacity"], "Unknown key: %s" % k
            for i in range(n_layers[k]):
                setattr(
                    self,
                    "fc_%d_%s_%d" % (n_shared_layers + 1, k, i),
                    (
                        ModLinear(
                            hidden_dim,
                            hidden_dim,
                            z_dim,
                            bias=False,
                            mod_bias=True,
                            output_mode=True,
                        )
                        if z_dim is not None
                        else torch.nn.Linear(hidden_dim, hidden_dim)
                    ),
                )
            setattr(
                self,
                "fc_out_%s" % k,
                torch.nn.Linear(
                    hidden_dim,
                    1 if k == "opacity" else 3,
                ),
            )

    def forward(self, pt_feat, onehots, zs):
        r"""Predict attributes for every point.

        Args:
            pt_feat: [B, N, in_dim] point features.
            onehots: [B, N, n_classes] class one-hots.
            zs: None, or a dict whose values each hold a style code ``"z"``
                and an index ``"idx"`` into the first two dims of ``pt_feat``.
                Each entry is run separately with its own code.

        Returns:
            dict: Attribute name -> [B, N, 1] (opacity) or [B, N, 3] tensor.
            When ``zs`` is given, points not selected by any ``idx`` stay 0.
        """
        b, n, _ = pt_feat.size()
        f = self.fc_1(pt_feat)
        f = f + self.fc_m_a(onehots)
        f = self.act(f)

        if zs is None:
            return self._instance_forward(f)

        output = {
            k: torch.zeros(b, n, 1 if k == "opacity" else 3, device=pt_feat.device)
            for k in self.factors.keys()
        }
        for instance in zs.values():
            z = instance["z"]
            idx = instance["idx"]
            inst_output = self._instance_forward(f[idx].unsqueeze(dim=0), z)
            for k, attr in inst_output.items():
                output[k][idx] = attr

        return output

    def _instance_forward(self, f, z=None):
        r"""Run the remaining shared layers and every attribute head on ``f``.

        ``z`` is the style code for ModLinear layers; when it is None the
        layers are called with ``f`` only (plain Linear).
        """
        for i in range(2, self.n_shared_layers + 1):
            fc = getattr(self, "fc_%d" % i)
            f = self.act(fc(f, z) if z is not None else fc(f))

        output = {}
        for k in self.factors.keys():
            # No layer modifies its input in place, so each head can start
            # from the shared feature directly.
            _f = f
            for i in range(self.n_layers[k]):
                _fc = getattr(self, "fc_%d_%s_%d" % (self.n_shared_layers + 1, k, i))
                # Chain through the head: each layer consumes the previous
                # layer's output (_f), not the shared feature (f).
                _f = self.act(_fc(_f, z) if z is not None else _fc(_f))

            fc_out = getattr(self, "fc_out_%s" % k)
            output[k] = fc_out(_f)

        # Map raw head outputs into bounded ranges:
        #   xyz, rgb: (-factor / 2, factor / 2)
        #   scale:    [1 - factor, 1 + factor]
        #   opacity:  (1 - factor, 1)
        if "xyz" in self.factors:
            output["xyz"] = (torch.sigmoid(output["xyz"]) - 0.5) * self.factors["xyz"]
        if "rgb" in self.factors:
            output["rgb"] = (torch.sigmoid(output["rgb"]) - 0.5) * self.factors["rgb"]
        if "scale" in self.factors:
            output["scale"] = 1 + output["scale"].clamp(-1, 1) * self.factors["scale"]
        if "opacity" in self.factors:
            output["opacity"] = torch.sigmoid(output["opacity"]) * self.factors[
                "opacity"
            ] + (1 - self.factors["opacity"])

        return output


class ModLinear(torch.nn.Module):
    r"""Linear layer with affine modulation (Based on StyleGAN2 mod demod).

    Equivalent to affine modulation following linear, but faster when the same
    modulation parameters are shared across multiple inputs.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        style_features (int): Number of style features.
        bias (bool): Apply additive bias before the activation function?
        mod_bias (bool): Whether to modulate bias.
        output_mode (bool): If True, modulate output instead of input.
        weight_gain (float): Initialization gain
        bias_init (float): Initial value of the additive bias (if ``bias``).
    """

    def __init__(
        self,
        in_features,
        out_features,
        style_features,
        bias=True,
        mod_bias=True,
        output_mode=False,
        weight_gain=1,
        bias_init=0,
    ):
        super(ModLinear, self).__init__()
        weight_gain = weight_gain / np.sqrt(in_features)
        self.weight = torch.nn.Parameter(
            torch.randn([out_features, in_features]) * weight_gain
        )
        self.bias = (
            torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
            if bias
            else None
        )
        self.weight_alpha = torch.nn.Parameter(
            torch.randn([in_features, style_features]) / np.sqrt(style_features)
        )
        self.bias_alpha = torch.nn.Parameter(
            torch.full([in_features], 1, dtype=torch.float)
        )  # init to 1
        self.weight_beta = None
        self.bias_beta = None
        self.mod_bias = mod_bias
        self.output_mode = output_mode
        if mod_bias:
            if output_mode:
                mod_bias_dims = out_features
            else:
                mod_bias_dims = in_features
            self.weight_beta = torch.nn.Parameter(
                torch.randn([mod_bias_dims, style_features]) / np.sqrt(style_features)
            )
            self.bias_beta = torch.nn.Parameter(
                torch.full([mod_bias_dims], 0, dtype=torch.float)
            )

    @staticmethod
    def _linear_f(x, w, b):
        r"""Compute ``x @ w.T + b`` over the last dim, for any leading dims."""
        w = w.to(x.dtype)
        x_shape = x.shape
        x = x.reshape(-1, x_shape[-1])
        if b is not None:
            b = b.to(x.dtype)
            x = torch.addmm(b.unsqueeze(0), x, w.t())
        else:
            x = x.matmul(w.t())
        x = x.reshape(*x_shape[:-1], -1)
        return x

    # x: B, ... , Cin
    # z: B, ... , Cz
    def forward(self, x, z):
        x_shape = x.shape
        z_shape = z.shape
        x = x.reshape(x_shape[0], -1, x_shape[-1])
        z = z.reshape(z_shape[0], -1, z_shape[-1])

        # NOTE(review): the broadcast below presumably expects one style
        # vector per batch element (z -> [B, 1, Cz]) -- TODO confirm.
        alpha = self._linear_f(z, self.weight_alpha, self.bias_alpha)  # [B, ?, I]
        w = self.weight.to(x.dtype)  # [O I]
        w = w.unsqueeze(0) * alpha  # Scale each input column per sample.

        if self.mod_bias:
            # [B, ?, O] in output mode, [B, ?, I] otherwise.
            beta = self._linear_f(z, self.weight_beta, self.bias_beta)
            if not self.output_mode:
                x = x + beta

        b = self.bias
        if b is not None:
            b = b.to(x.dtype)[None, None, :]
        if self.mod_bias and self.output_mode:
            if b is None:
                b = beta
            else:
                b = b + beta

        # [B ? I] @ [B I O] = [B ? O]
        if b is not None:
            x = torch.baddbmm(b, x, w.transpose(1, 2))
        else:
            x = x.bmm(w.transpose(1, 2))
        x = x.reshape(*x_shape[:-1], x.shape[-1])
        return x