# -*- coding: utf-8 -*-
#
# @File:   inference.py
# @Author: Haozhe Xie
# @Date:   2024-03-02 16:30:00
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2024-10-13 15:17:20
# @Email:  root@haozhexie.com
"""Inference helpers: turn city layout maps into 3D points, attach Gaussian
attributes predicted by the generators, and render them with an orbit camera.
"""

import cv2
import math
import numpy as np
import scipy.spatial.transform
import torch

from tqdm import tqdm

# Semantic class IDs used in the layout segmentation maps.
CLASSES = {
    "NULL": 0,
    "ROAD": 1,
    "BLDG_FACADE": 2,
    "GREEN_LANDS": 3,
    "CONSTRUCTION": 4,
    "COAST_ZONES": 5,
    "ZONE": 6,
    "BLDG_ROOF": 7,
}

# Per-class sampling stride in pixels (get_point_map) and per-point scale
# (_get_point_scales, _get_bev_points). NULL intentionally has no entry.
SCALES = {
    "ROAD": 2,
    "BLDG_FACADE": 1,
    "BLDG_ROOF": 1,
    "GREEN_LANDS": 2,
    "CONSTRUCTION": 1,
    "COAST_ZONES": 4,
    "ZONE": 2,
}

CONSTANTS = {
    # Camera intrinsics as a row-major 3x3 matrix (fx, 0, cx, 0, fy, cy, 0, 0, 1).
    "CAM_K": [1528.1469407006614, 0, 480, 0, 1528.1469407006614, 270, 0, 0, 1],
    # Rendered image size, presumably (width, height): the principal point
    # above is (480, 270).
    "SENSOR_SIZE": [960, 540],
    # Instance IDs in [BLDG_INST_RANGE[0], BLDG_INST_RANGE[1]) are buildings;
    # smaller IDs are background classes.
    "BLDG_INST_RANGE": [100, 16384],
    # Side length (pixels) of the local layout crop rendered around (cx, cy).
    "PROJECTION_SIZE": 2048,
    # Multiplier applied to the per-point scales to get the Gaussian scales.
    "POINT_SCALE_FACTOR": 0.5,
    # Classes whose Gaussians get a z-scale of 1 (flat ground-like surfaces).
    "SPECIAL_Z_SCALE_CLASSES": [
        CLASSES["ROAD"],
        CLASSES["COAST_ZONES"],
        CLASSES["ZONE"],
    ],
}


def get_instance_seg_map(seg_map):
    """Convert a semantic segmentation map into an instance map.

    Construction pixels are merged into building facades, and every
    4-connected facade region becomes one building instance. Facade instance
    IDs are even: ``2 * (label + BLDG_INST_RANGE[0])`` where ``label >= 1`` is
    the connected-component label, so the smallest ID is 202. All other pixels
    keep their semantic class ID.

    Args:
        seg_map: 2D integer array of class IDs from ``CLASSES``. It is NOT
            modified (the original implementation rewrote it in place).

    Returns:
        np.ndarray: int32 instance map with the same shape as ``seg_map``.
    """
    # Work on a copy so the caller's semantic map is left untouched.
    seg_map = seg_map.copy()
    # Mapping constructions to buildings
    seg_map[seg_map == CLASSES["CONSTRUCTION"]] = CLASSES["BLDG_FACADE"]
    facade_mask = seg_map == CLASSES["BLDG_FACADE"]
    # Use connected components to get building instances
    _, labels, _, _ = cv2.connectedComponentsWithStats(
        facade_mask.astype(np.uint8), connectivity=4
    )
    # Remove non-building instance masks
    labels[~facade_mask] = 0
    building_mask = labels != 0
    # Facade IDs are even (2k). Roof voxels are presumably tagged 2k + 1:
    # _get_local_layout registers the roof center under ``inst + 1``.
    labels = (labels + CONSTANTS["BLDG_INST_RANGE"][0]) * 2
    assert np.max(labels) < 2147483648
    return np.where(building_mask, labels, seg_map).astype(np.int32)


def get_point_map(seg_map):
    """Return a boolean mask of the pixels that are sampled as points.

    Each semantic class is sampled on a regular grid anchored at (0, 0) whose
    stride is ``SCALES[class]``; NULL pixels are never sampled.

    Args:
        seg_map: 2D array of semantic class IDs. Every value must be an ID in
            ``CLASSES`` (otherwise a KeyError is raised).

    Returns:
        np.ndarray: bool array with the same shape as ``seg_map``.
    """
    inverted_index = {v: k for k, v in CLASSES.items()}
    pts_map = np.zeros(seg_map.shape, dtype=bool)
    for c in np.unique(seg_map):
        cls_name = inverted_index[c]
        if cls_name == "NULL":
            continue
        pt_map = _get_point_map(seg_map.shape, SCALES[cls_name])
        pts_map |= pt_map & (seg_map == c)
    return pts_map


def _get_point_map(map_size, stride):
    """Return a bool map that is True on every ``stride``-th row and column."""
    pts_map = np.zeros(map_size, dtype=bool)
    pts_map[::stride, ::stride] = True
    return pts_map


def get_centers(ins_map, td_hf):
    """Compute a bounding box and height for every instance in ``ins_map``.

    Args:
        ins_map: 2D instance map (e.g. the output of get_instance_seg_map).
        td_hf: Height field with the same shape as ``ins_map``.

    Returns:
        dict: instance ID -> float32 array ``[cx, cy, w, h, max_z]``.
        Building instances (ID >= BLDG_INST_RANGE[0]) use the bounding box of
        their external contours and ``max(td_hf) + 1`` inside the instance.
        Every other instance uses the full PROJECTION_SIZE x PROJECTION_SIZE
        extent and ``max(td_hf)`` over the whole map.
    """
    centers = {}
    instances = np.unique(ins_map)
    for i in tqdm(instances, desc="Calculating centers ..."):
        if i >= CONSTANTS["BLDG_INST_RANGE"][0]:
            ds_mask = ins_map == i
            contours, _ = cv2.findContours(
                ds_mask.astype(np.uint8),
                cv2.RETR_EXTERNAL,
                cv2.CHAIN_APPROX_SIMPLE,
            )
            # Contour points are (x, y) pixel coordinates.
            contours = np.vstack(contours).reshape(-1, 2)
            min_x, max_x = np.min(contours[:, 0]), np.max(contours[:, 0])
            min_y, max_y = np.min(contours[:, 1]), np.max(contours[:, 1])
            max_z = np.max(td_hf[ds_mask]) + 1
        else:
            min_x, max_x = 0, CONSTANTS["PROJECTION_SIZE"]
            min_y, max_y = 0, CONSTANTS["PROJECTION_SIZE"]
            max_z = np.max(td_hf)

        centers[i] = np.array(
            [
                (min_x + max_x) / 2,
                (min_y + max_y) / 2,
                (max_x - min_x),
                (max_y - min_y),
                max_z,
            ],
            dtype=np.float32,
        )
    return centers


def generate_city(
    fgm, bgm, city_layout, cx, cy, radius, altitude, azimuth, style_lut=None
):
    """Render one view of the city around (cx, cy) with an orbit camera.

    Args:
        fgm: Model used for building points ("BLDG").
        bgm: Model used for all other points ("REST").
        city_layout: dict of full-city maps "TD_HF", "BU_HF", "SEG", "INS",
            "PTS" (2D numpy arrays) and "CTR" (instance ID -> center array,
            presumably the output of get_centers).
        cx, cy: Pixel center of the PROJECTION_SIZE crop to render.
        radius, altitude: Orbit camera distance (in the crop's xy plane) and
            height.
        azimuth: Orbit angle in degrees.
        style_lut: Optional instance ID -> latent code mapping; sampled by
            _get_style_lut when None.

    Returns:
        np.ndarray: (H, W, C) uint8 image.
    """
    import gaussiancity.extensions.diff_gaussian_rasterization as dgr

    device = torch.device("cuda")
    gr = dgr.GaussianRasterizerWrapper(
        np.array(CONSTANTS["CAM_K"], dtype=np.float32).reshape((3, 3)),
        CONSTANTS["SENSOR_SIZE"],
        flip_lr=True,
        flip_ud=False,
        device=device,
    )
    layout = _get_local_layout(
        city_layout,
        cx,
        cy,
        CONSTANTS["PROJECTION_SIZE"] // 2,
        CONSTANTS["BLDG_INST_RANGE"],
        device,
    )
    bev_pts = _get_bev_points(layout, SCALES, CLASSES)
    bev_pt_classes = _instances_to_classes(
        bev_pts[:, [3]], CONSTANTS["BLDG_INST_RANGE"], CLASSES
    )
    bev_pt_classes_onehot = _get_onehot_seg(bev_pt_classes, len(CLASSES))
    bev_pt_scales = _get_point_scales(
        bev_pt_classes,
        SCALES,
        CLASSES,
        CONSTANTS["SPECIAL_Z_SCALE_CLASSES"],
    )
    # Columns: [XYZ (3) + Inst (1) + Scale3D (3) + one-hot (N_CLASSES)]
    bev_pts = torch.cat([bev_pts, bev_pt_scales, bev_pt_classes_onehot], dim=1)

    if style_lut is None:
        style_lut = _get_style_lut(
            layout["CTR"],
            {"BLDG": fgm, "REST": bgm},
            {
                "BLDG": CONSTANTS["BLDG_INST_RANGE"],
                "REST": [0, CONSTANTS["BLDG_INST_RANGE"][0]],
            },
            device,
        )

    cam_look_at, cam_pose = _get_orbit_camera_pose(
        radius, altitude, azimuth, CONSTANTS["PROJECTION_SIZE"] // 2, device
    )
    # NOTE(review): bev_pts[:, :3] is a view, and _get_volume ->
    # _get_localized_pt_cords subtracts the volume offsets from it in place,
    # so bev_pts is in the volume-local frame from here on. Confirm whether
    # _get_gs_attrs / the rasterizer should see the un-shifted coordinates.
    vp_idx = _get_visible_points(
        bev_pts[:, :3],
        bev_pt_scales,
        CONSTANTS["CAM_K"],
        CONSTANTS["SENSOR_SIZE"],
        cam_pose[:3],
        cam_look_at,
    )
    gs_attrs = _get_gs_attrs(
        bev_pts[vp_idx],
        layout["TD_HF"].float(),
        layout["SEG"].float(),
        style_lut,
        layout["CTR"],
        {"BLDG": fgm, "REST": bgm},
        CONSTANTS["POINT_SCALE_FACTOR"],
        CONSTANTS["BLDG_INST_RANGE"],
    )
    return _render(gs_attrs, gr, cam_pose)


def _get_local_layout(city_layout, cx, cy, half_proj_size, bldg_inst_range, device):
    """Crop the city layout around (cx, cy) and move it to ``device``.

    Returns:
        dict: the cropped "TD_HF", "BU_HF", "INS", "PTS" maps as (1, 1, H, W)
        tensors (only the keys present in ``city_layout``), "SEG" as a
        (1, len(CLASSES), H, W) bool one-hot tensor, and "CTR" mapping each
        instance in the crop to its center tensor. Building centers are
        translated into crop coordinates and are also registered (as the same
        tensor) under ``inst + 1``, the roof instance.
    """
    x_min, x_max = cx - half_proj_size, cx + half_proj_size
    y_min, y_max = cy - half_proj_size, cy + half_proj_size
    _layout = {
        k: torch.from_numpy(v[None, None, y_min:y_max, x_min:x_max]).cuda(device)
        for k, v in city_layout.items()
        if k in ["TD_HF", "BU_HF", "SEG", "INS", "PTS"]
    }
    _layout["SEG"] = _get_onehot_seg(_layout["SEG"], len(CLASSES))

    _centers = {}
    for inst in torch.unique(_layout["INS"]):
        inst = inst.item()
        _centers[inst] = torch.from_numpy(city_layout["CTR"][inst]).cuda(device)
        if inst >= bldg_inst_range[0]:
            _centers[inst][0] -= x_min
            _centers[inst][1] -= y_min
            # Fix the centers for BLDG_ROOF
            _centers[inst + 1] = _centers[inst]
        else:
            # NOTE(review): background centers get the crop origin assigned
            # (not subtracted) here; _get_pt_input_attrs does not read cx/cy
            # for background instances. Confirm this is intended.
            _centers[inst][0] = x_min
            _centers[inst][1] = y_min

    _layout["CTR"] = _centers
    return _layout


def _get_onehot_seg(seg_map, n_classes):
    """One-hot encode a class map of shape (N, 1, ...) into (N, n_classes, ...) bools."""
    shape = seg_map.shape
    output_shape = (shape[0], n_classes, *shape[2:])
    one_hot_masks = torch.zeros(output_shape, device=seg_map.device, dtype=torch.bool)
    for i in range(n_classes):
        one_hot_masks[:, [i]] = seg_map == i

    return one_hot_masks


def _get_style_lut(centers, models, inst_ranges, device, z_dim=256):
    """Pick a latent (style) code for every instance.

    Every instance first gets a uniform random (1, z_dim) code. Then, for
    each non-None model ``k`` and only for instances inside
    ``inst_ranges[k]``:

    * if ``cfg.Z_DIM`` is None, the instances are removed from the table;
    * otherwise, if the model exposes a ``z`` dict, each instance gets a
      randomly chosen entry of it.

    The range restriction matters: previously a model with ``z`` overwrote
    the codes of all instances, including those handled by the other model.
    """
    lut = {ins: torch.rand(1, z_dim, device=device) for ins in centers.keys()}
    for k, v in models.items():
        if v is None:
            continue

        lower, upper = inst_ranges[k]
        in_range = [ins for ins in centers.keys() if lower <= ins < upper]
        if v.module.cfg.Z_DIM is None:
            for ins in in_range:
                lut.pop(ins, None)
            continue

        if hasattr(v.module, "z"):
            zs = v.module.z
            z_keys = list(zs.keys())
            lut.update(
                {ins: zs[np.random.choice(z_keys)].unsqueeze(0) for ins in in_range}
            )

    return lut


def _get_orbit_camera_pose(radius, altitude, azimuth, half_proj_size, device):
    """Place a camera on a circle around the crop center, looking at it.

    Returns:
        tuple: (look-at point as a 3-vector tensor, pose tensor
        ``[x, y, z, *quaternion]``).
    """
    cx, cy = half_proj_size, half_proj_size
    theta = np.deg2rad(azimuth)
    cam_x = cx + radius * math.cos(theta)
    cam_y = cy + radius * math.sin(theta)
    cam_pos = np.array([cam_x, cam_y, altitude], dtype=np.float32)
    cam_look_at = np.array([cx, cy, 1], dtype=np.float32)
    quat = _get_quat_from_look_at(cam_pos, cam_look_at)
    return torch.tensor([*cam_look_at], device=device), torch.tensor(
        [*cam_pos, *quat], device=device
    )


def _get_quat_from_look_at(cam_pos, cam_look_at):
    """Return the camera orientation as a scipy quaternion ([x, y, z, w]).

    The rotation matrix has the camera forward, right and up axes as columns.

    Raises:
        ValueError: if the camera sits on the look-at point or looks straight
            along the world z-axis; the basis is undefined there and the old
            code produced NaNs.
    """
    fwd_vec = cam_look_at - cam_pos
    fwd_norm = np.linalg.norm(fwd_vec)
    if fwd_norm == 0:
        raise ValueError("The camera position must differ from the look-at point.")
    fwd_vec = fwd_vec / fwd_norm

    up_vec = np.array([0, 0, 1])
    right_vec = np.cross(up_vec, fwd_vec)
    right_norm = np.linalg.norm(right_vec)
    if right_norm < 1e-8:
        raise ValueError(
            "The camera must not look straight along the z-axis "
            "(use a non-zero orbit radius)."
        )
    right_vec = right_vec / right_norm
    up_vec = np.cross(fwd_vec, right_vec)

    R = np.stack([fwd_vec, right_vec, up_vec], axis=1)
    return scipy.spatial.transform.Rotation.from_matrix(R).as_quat()


def _get_bev_points(layout, scales, classes):
    """Voxelize the local layout and return its non-empty voxels.

    Returns:
        torch.Tensor: (N, 4) tensor ``[x, y, z, value]`` where ``value`` is
        the voxel content written by ``maps_to_volume`` (used downstream as
        the instance ID).
    """
    import gaussiancity.extensions.voxlib

    assert torch.max(layout["INS"]) < 16384
    # The height limit keeps the volume small enough for torch.nonzero:
    # torch.nonzero(torch.zeros(2048, 2048, 512).cuda())
    # -> nonzero is not supported for tensors with more than INT_MAX elements
    # torch.nonzero(torch.zeros(2048, 2048, 508).cuda())
    # -> an illegal memory access was encountered
    assert torch.max(layout["TD_HF"]) <= 500
    volume = gaussiancity.extensions.voxlib.maps_to_volume(
        layout["INS"].squeeze().short(),
        layout["TD_HF"].squeeze().short(),
        layout["BU_HF"].squeeze().short(),
        layout["PTS"].squeeze().bool(),
        torch.tensor(
            [scales[k] if k in scales else 0 for k in classes.keys()],
            dtype=torch.int8,
            device=layout["INS"].device,
        ),
    )
    non_zero_indices = torch.nonzero(volume, as_tuple=False)
    non_zero_values = volume[
        non_zero_indices[:, 0], non_zero_indices[:, 1], non_zero_indices[:, 2]
    ]
    return torch.cat(
        [non_zero_indices.short(), non_zero_values.unsqueeze(dim=1)], dim=1
    )


def _instances_to_classes(instances, bldg_inst_range, bldg_classes):
    """Map instance IDs to class IDs.

    Building IDs (>= bldg_inst_range[0]) become BLDG_FACADE when even and
    BLDG_ROOF when odd; smaller IDs are already class IDs and are kept.
    """
    bldg_facade_idx = (instances >= bldg_inst_range[0]) & (instances % 2 == 0)
    bldg_roof_idx = (instances >= bldg_inst_range[0]) & (instances % 2 == 1)
    classes = instances.clone()
    classes[bldg_facade_idx] = bldg_classes["BLDG_FACADE"]
    classes[bldg_roof_idx] = bldg_classes["BLDG_ROOF"]
    return classes


def _get_point_scales(pt_classes, scales, classes, special_z_scale_classes=()):
    """Return the (N, 3) per-point scale for (N, 1) point class IDs.

    The scale is ``scales[class]`` on every axis; classes in
    ``special_z_scale_classes`` get a z-scale of 1. Classes missing from
    ``scales`` keep their class ID as the scale (0 for NULL).
    """
    pt_scales = pt_classes.clone()
    for k, v in scales.items():
        pt_scales[pt_classes == classes[k]] = v

    pt_scales_3d = pt_scales.repeat(1, 3)
    # Set the z-scale = 1 for roads, zones, and waters
    pt_scales_3d[..., 2][
        torch.isin(
            pt_classes.squeeze(dim=-1),
            torch.tensor(
                list(special_z_scale_classes),
                device=pt_classes.device,
            ),
        )
    ] = 1
    return pt_scales_3d


def _get_visible_points(points, scales, K, sensor_size, cam_pos, cam_look_at):
    """Return the indices of the points hit by at least one camera ray.

    NOTE: Each point is assigned a unique ID. The values in the rendered map
    denote the visibility of the points; they are the point IDs.

    NOTE(review): ``points`` is modified in place by _get_volume (see
    _get_localized_pt_cords).
    """
    # Generate 3D volume
    volume, offsets = _get_volume(points, scales)
    # Ray-voxel intersection
    vp_map = _get_ray_voxel_intersection(
        K, sensor_size, cam_pos - offsets, cam_look_at - cam_pos, volume
    )
    # Manually release the memory to avoid OOM
    del volume
    torch.cuda.empty_cache()

    vp_idx = torch.unique(vp_map)
    # -1 marks rays that hit nothing.
    return vp_idx[vp_idx >= 0]


def _get_volume(points, scales):
    """Rasterize the points into a dense ID volume.

    Returns:
        tuple: (volume from ``points_to_volume`` holding 1-based point IDs,
        int16 offsets ``[x_min, y_min, z_min]`` of the volume origin).
    """
    import gaussiancity.extensions.voxlib

    x_min, x_max = torch.min(points[:, 0]).item(), torch.max(points[:, 0]).item()
    y_min, y_max = torch.min(points[:, 1]).item(), torch.max(points[:, 1]).item()
    z_min, z_max = torch.min(points[:, 2]).item(), torch.max(points[:, 2]).item()
    offsets = torch.tensor(
        [x_min, y_min, z_min], dtype=torch.int16, device=points.device
    )
    # Normalize points coordinates to local coordinate system
    points = _get_localized_pt_cords(points, offsets)
    # Generate an empty 3D volume
    w, h, d = x_max - x_min + 1, y_max - y_min + 1, z_max - z_min + 2

    # Generate point IDs
    # NOTE: The point IDs start from 1 to avoid the conflict with the NULL class.
    assert points.shape[0] < 2147483648
    pt_ids = torch.arange(
        start=1, end=points.shape[0] + 1, dtype=torch.int32, device=points.device
    ).unsqueeze(dim=1)
    volume = gaussiancity.extensions.voxlib.points_to_volume(
        points.contiguous(), pt_ids, scales, h, w, d
    )
    return volume, offsets


def _get_localized_pt_cords(points, offsets):
    """Shift points into the volume frame (z is shifted so z_min maps to 1).

    NOTE(review): this modifies ``points`` IN PLACE and returns it. When
    ``points`` is a view (e.g. ``bev_pts[:, :3]``), the parent tensor is
    shifted too.
    """
    points[:, 0] -= offsets[0]
    points[:, 1] -= offsets[1]
    points[:, 2] -= offsets[2] - 1
    return points


def _get_ray_voxel_intersection(K, sensor_size, cam_origin, viewdir, volume):
    """Cast one ray per pixel into ``volume`` and return the first-hit point index per pixel (-1 for no hit)."""
    import gaussiancity.extensions.voxlib

    N_MAX_SAMPLES = 1
    # The volume axes are (y, x, z), hence the [1, 0, 2] swaps.
    voxel_id, _, _ = gaussiancity.extensions.voxlib.ray_voxel_intersection_perspective(
        volume,
        cam_origin[[1, 0, 2]].float(),
        viewdir[[1, 0, 2]].float(),
        torch.tensor([0, 0, 1], dtype=torch.float32),
        K[0],
        [K[5], K[2]],
        [sensor_size[1], sensor_size[0]],
        N_MAX_SAMPLES,
    )
    # NOTE: The point ID for NULL class is -1, the rest point IDs are from 0 to N - 1.
    # The ray_voxel_intersection_perspective seems not accepting the negative values.
    return voxel_id.squeeze() - 1


def get_hf_seg_tensor(part_hf, part_seg, layout_cfg, output_device):
    """Stack a normalized height field and a one-hot segmentation map.

    NOTE(review): CONSTANTS defines neither "LAYOUT_MAX_HEIGHT" nor
    "LAYOUT_N_CLASSES", so this raises KeyError as written, and
    ``layout_cfg`` is unused. These values probably belong in ``layout_cfg``;
    confirm with the callers.
    """
    part_hf = torch.from_numpy(part_hf[None, None, ...]).to(output_device)
    part_seg = torch.from_numpy(part_seg[None, None, ...]).to(output_device)
    part_hf = part_hf / CONSTANTS["LAYOUT_MAX_HEIGHT"]
    part_seg = _masks_to_onehots(part_seg[:, 0, :, :], CONSTANTS["LAYOUT_N_CLASSES"])
    return torch.cat([part_hf, part_seg], dim=1)


def _masks_to_onehots(masks, n_class, ignored_classes=()):
    """One-hot encode (B, H, W) masks into float32 (B, n_class - len(ignored), H, W)."""
    b, h, w = masks.shape
    n_class_actual = n_class - len(ignored_classes)
    one_hot_masks = torch.zeros(
        (b, n_class_actual, h, w), dtype=torch.float32, device=masks.device
    )

    n_class_cnt = 0
    for i in range(n_class):
        if i not in ignored_classes:
            one_hot_masks[:, n_class_cnt] = masks == i
            n_class_cnt += 1

    return one_hot_masks


def _get_gs_attrs(
    pts,
    proj_hf,
    proj_seg,
    style_lut,
    centers,
    models,
    scale_factor,
    bldg_inst_range,
):
    """Assemble per-point Gaussian attributes.

    ``pts`` columns are XYZ (3), instance ID (1), scale (3), one-hot classes.
    Building points are colored by ``models["BLDG"]``, the rest by
    ``models["REST"]``. Output rows are ordered building points first.

    Returns:
        torch.Tensor: rows ``[xyz (3), opacity (1), scale (3),
        rotation quaternion (4), rgb]``.
    """
    n_pts, _ = pts.shape
    bldg_selector = pts[:, 3] >= bldg_inst_range[0]
    bldg_pts = pts[bldg_selector]
    rest_pts = pts[~bldg_selector]

    bldg_attrs = _get_pt_input_attrs(
        bldg_pts[:, :4],
        centers,
        style_lut,
        models["BLDG"].module.cfg.Z_DIM,
        bldg_inst_range,
    )
    rest_attrs = _get_pt_input_attrs(
        rest_pts[:, :4],
        centers,
        style_lut,
        models["REST"].module.cfg.Z_DIM,
        bldg_inst_range,
    )
    bldg_colors = _get_gs_colors(
        bldg_pts, bldg_attrs, proj_hf, proj_seg, models["BLDG"]
    )
    rest_colors = _get_gs_colors(
        rest_pts, rest_attrs, proj_hf, proj_seg, models["REST"]
    )

    abs_xyz = torch.cat([bldg_pts[:, :3], rest_pts[:, :3]], dim=0)
    scales = torch.cat([bldg_pts[:, 4:7], rest_pts[:, 4:7]], dim=0) * scale_factor
    rgb = torch.cat([bldg_colors, rest_colors], dim=0)
    # Attributes with default values: fully opaque, identity rotation.
    opacity = torch.ones((n_pts, 1), device=pts.device)
    rotations = torch.cat(
        [
            torch.ones(n_pts, 1, device=pts.device),
            torch.zeros(n_pts, 3, device=pts.device),
        ],
        dim=-1,
    )
    return torch.cat((abs_xyz, opacity, scales, rotations, rgb), dim=-1)


def _get_pt_input_attrs(pts, centers, style_lut, z_dim, bldg_inst_range):
    """Compute the model inputs for (N, 4) ``[x, y, z, instance]`` points.

    Returns:
        tuple: ``rel_xyz`` (1, N, 3) float32 coordinates relative to each
        instance box, ``batch_idx`` (1, N) int32 per-point instance index,
        and ``zs`` (instance -> {"z", "idx"}) or None when ``z_dim`` is None.
    """
    n_pts = pts.shape[0]
    instances = torch.unique(pts[:, -1])
    rel_xyz = torch.zeros(1, n_pts, 3, dtype=torch.float32, device=pts.device)
    batch_idx = torch.zeros(1, n_pts, dtype=torch.int32, device=pts.device)
    zs = {} if z_dim is not None else None
    for idx, ins in enumerate(instances):
        ins = ins.item()
        is_pts = pts[:, -1] == ins
        cx, cy, w, h, d = centers[ins]
        if ins >= bldg_inst_range[0]:
            # Buildings: map the bounding box to [-1, 1].
            rel_xyz[:, is_pts, 0] = (pts[is_pts, 0] - cx) / w * 2 if w > 0 else 0
            rel_xyz[:, is_pts, 1] = (pts[is_pts, 1] - cy) / h * 2 if h > 0 else 0
        else:
            # Make the BG contiguous: mirror the coordinates every period so
            # adjacent tiles meet seamlessly.
            period_x = torch.ceil((pts[is_pts, 0] / w / 2) - 0.5)
            period_y = torch.ceil((pts[is_pts, 1] / h / 2) - 0.5)
            rel_xyz[:, is_pts, 0] = (
                (pts[is_pts, 0] - 2 * period_x * w) * (-1) ** period_x
            ) / w
            rel_xyz[:, is_pts, 1] = (
                (pts[is_pts, 1] - 2 * period_y * h) * (-1) ** period_y
            ) / h

        rel_xyz[:, is_pts, 2] = (
            torch.clip(pts[is_pts, 2] / d * 2 - 1, -1, 1) if d > 0 else 0
        )
        batch_idx[:, is_pts] = idx
        if zs is not None:
            zs[ins] = {"z": style_lut[ins], "idx": is_pts.unsqueeze(dim=0)}

    return rel_xyz, batch_idx, zs


def _get_gs_colors(pts, pt_attrs, proj_hf, proj_seg, model):
    """Run ``model`` on the points and return its (N, C) "rgb" output."""
    if pts.shape[0] == 0:
        return torch.empty(0, 3, dtype=torch.float32, device=pts.device)

    abs_xyz, onehots = pts[None, :, :3], pts[None, :, 7:]
    rel_xyz, batch_idx, zs = pt_attrs
    proj_uv = None
    if model.module.cfg.ENCODER is not None:
        proj_uv = get_projection_uv(abs_xyz)

    with torch.no_grad():
        # TODO: Optimize the _instance_forward in Generator
        gs_attrs = model(
            proj_uv, rel_xyz, batch_idx, onehots.float(), zs, proj_hf, proj_seg
        )

    return gs_attrs["rgb"].squeeze(dim=0)


def get_projection_uv(xyz, proj_tlp=None, proj_size=2048):
    """Map the xy of (B, N, >=2) points into [-1, 1] projection coordinates.

    Args:
        xyz: Point coordinates (any numeric dtype).
        proj_tlp: Optional (B, 2) top-left corner subtracted from xy first.
        proj_size: Side length of the projection in the same units.

    Returns:
        torch.Tensor: float32 (B, N, 2) coordinates.
    """
    n_pts = xyz.size(1)
    if proj_tlp is None:
        proj_uv = xyz[..., :2].clone().float()
    else:
        # Cast to float as in the branch above: dividing an integer tensor
        # in place (as before) raised a dtype error.
        proj_uv = (xyz[..., :2] - proj_tlp.unsqueeze(dim=1)).float()

    assert proj_uv.size() == (xyz.size(0), n_pts, 2)
    # Normalize to [-1, 1]
    return proj_uv / proj_size * 2 - 1


def _render(gs_attrs, rasterizator, cam_pose):
    """Rasterize the Gaussians and return an (H, W, C) uint8 image.

    The rasterizer output is mapped from presumably [-1, 1] to [0, 1], then
    brightness and contrast are boosted by 1.2 each.
    """
    import torchvision.transforms.functional as F

    with torch.no_grad():
        img = rasterizator(
            gs_attrs,
            cam_pose[:3],  # Position
            cam_pose[3:],  # Quaternion
        )

    img = img.squeeze() / 2 + 0.5
    img = F.adjust_brightness(img, 1.2)
    img = F.adjust_contrast(img, 1.2)
    return (img * 255).permute(1, 2, 0).cpu().numpy().astype(np.uint8)