Spaces:
Running on Zero
Running on Zero
| # -*- coding: utf-8 -*- | |
| # | |
| # @File: inference.py | |
| # @Author: Haozhe Xie | |
| # @Date: 2024-03-02 16:30:00 | |
| # @Last Modified by: Haozhe Xie | |
| # @Last Modified at: 2024-10-13 15:17:20 | |
| # @Email: root@haozhexie.com | |
| import cv2 | |
| import math | |
| import numpy as np | |
| import scipy.spatial.transform | |
| import torch | |
| from tqdm import tqdm | |
# Semantic class IDs used by the layout segmentation maps.
# NOTE: The insertion order matters: _get_bev_points() walks CLASSES.keys()
# to build a per-class table, so the position of a name must equal its ID.
CLASSES = {
    name: class_id
    for class_id, name in enumerate(
        (
            "NULL",
            "ROAD",
            "BLDG_FACADE",
            "GREEN_LANDS",
            "CONSTRUCTION",
            "COAST_ZONES",
            "ZONE",
            "BLDG_ROOF",
        )
    )
}

# Per-class integer scale. It is used both as the sampling stride of the
# point map (get_point_map) and as the size of the generated points
# (_get_point_scales). "NULL" intentionally has no entry.
SCALES = dict(
    ROAD=2,
    BLDG_FACADE=1,
    BLDG_ROOF=1,
    GREEN_LANDS=2,
    CONSTRUCTION=1,
    COAST_ZONES=4,
    ZONE=2,
)

CONSTANTS = {
    # Flattened 3x3 camera intrinsics; reshaped to (3, 3) in generate_city().
    "CAM_K": [
        1528.1469407006614, 0, 480,
        0, 1528.1469407006614, 270,
        0, 0, 1,
    ],
    # Rendered image size; the principal point above is its centre.
    "SENSOR_SIZE": [960, 540],
    # Instance IDs >= the first value are treated as building instances.
    "BLDG_INST_RANGE": [100, 16384],
    # Side length of the square layout window used for one rendering.
    "PROJECTION_SIZE": 2048,
    # Multiplier applied to the point scales in _get_gs_attrs().
    "POINT_SCALE_FACTOR": 0.5,
    # Classes whose z-scale is forced to 1 in _get_point_scales().
    "SPECIAL_Z_SCALE_CLASSES": [
        CLASSES[name] for name in ("ROAD", "COAST_ZONES", "ZONE")
    ],
}
def get_instance_seg_map(seg_map):
    """Convert a semantic segmentation map into an instance map.

    Every 4-connected facade region becomes one building instance whose ID
    is ``(component_label + BLDG_INST_RANGE[0]) * 2``, i.e. an even number.
    Odd IDs are left free; _instances_to_classes() interprets them as roofs.
    All other pixels keep their semantic class ID.

    NOTE: ``seg_map`` is modified in place (constructions are relabelled as
    facades and all facade pixels are then set to 0).

    Args:
        seg_map: 2D integer array of semantic class IDs.

    Returns:
        A new ``np.int32`` array with the same shape as ``seg_map``.
    """
    facade_id = CLASSES["BLDG_FACADE"]
    # Constructions are handled as ordinary buildings.
    seg_map[seg_map == CLASSES["CONSTRUCTION"]] = facade_id

    # One connected component per building instance.
    is_facade = seg_map == facade_id
    _, labels, _, _ = cv2.connectedComponentsWithStats(
        is_facade.astype(np.uint8), connectivity=4
    )
    # Anything that is not a facade pixel must not carry a component label.
    labels[~is_facade] = 0
    building_mask = labels != 0

    # Shift the labels into the building ID range and make them even.
    labels = (labels + CONSTANTS["BLDG_INST_RANGE"][0]) * 2

    # Replace the facade class with the per-building instance IDs.
    seg_map[is_facade] = 0
    seg_map = labels * building_mask + seg_map * (1 - building_mask)
    assert np.max(labels) < 2147483648
    return seg_map.astype(np.int32)
def get_point_map(seg_map):
    """Build the boolean map of pixels at which points are sampled.

    Each non-NULL class is sampled on a regular grid whose stride is the
    class entry in SCALES; the per-class grids are restricted to the pixels
    of that class and merged.

    Args:
        seg_map: 2D array of semantic class IDs.

    Returns:
        Boolean array with the shape of ``seg_map``.

    Raises:
        KeyError: If ``seg_map`` holds an ID missing from CLASSES, or a
            class that has no entry in SCALES.
    """
    id_to_name = {class_id: name for name, class_id in CLASSES.items()}
    sampled = np.zeros(seg_map.shape, dtype=bool)
    for class_id in np.unique(seg_map):
        name = id_to_name[class_id]
        if name != "NULL":
            grid = _get_point_map(seg_map.shape, SCALES[name])
            # Keep the grid points that fall on pixels of this class only.
            grid &= seg_map == class_id
            sampled |= grid
    return sampled
def _get_point_map(map_size, stride):
    """Return a boolean grid that is True at every ``stride``-th row/column.

    Args:
        map_size: ``(height, width)`` of the grid.
        stride: Step between two sampled rows and between two sampled
            columns; sampling starts at index 0.

    Returns:
        Boolean array of shape ``map_size``.
    """
    grid = np.zeros(map_size, dtype=bool)
    rows = np.arange(0, map_size[0], stride)
    cols = np.arange(0, map_size[1], stride)
    # np.ix_ addresses the full cross product of the selected rows/columns.
    grid[np.ix_(rows, cols)] = True
    return grid
def get_centers(ins_map, td_hf):
    """Compute the bounding volume of every instance in ``ins_map``.

    Building instances (ID >= BLDG_INST_RANGE[0]) get the bounding box of
    their external contours and the highest value of ``td_hf`` inside the
    instance plus one. Every other instance spans the whole projection
    window and uses the global maximum of ``td_hf``.

    Args:
        ins_map: 2D instance ID map.
        td_hf: Height field with the same shape as ``ins_map``.

    Returns:
        Dict mapping instance ID to a float32 array
        ``[center_x, center_y, width, height, max_z]``. Width/height are
        ``max - min`` and can therefore be 0 for one-pixel-wide instances.
    """
    first_bldg_id = CONSTANTS["BLDG_INST_RANGE"][0]
    centers = {}
    for inst_id in tqdm(np.unique(ins_map), desc="Calculating centers ..."):
        if inst_id < first_bldg_id:
            # Background classes are not localized; use the full window.
            x_lo, x_hi = 0, CONSTANTS["PROJECTION_SIZE"]
            y_lo, y_hi = 0, CONSTANTS["PROJECTION_SIZE"]
            z_hi = np.max(td_hf)
        else:
            inst_mask = ins_map == inst_id
            contours, _ = cv2.findContours(
                inst_mask.astype(np.uint8),
                cv2.RETR_EXTERNAL,
                cv2.CHAIN_APPROX_SIMPLE,
            )
            # Flatten all contours into one (n_points, 2) array of (x, y).
            outline = np.vstack(contours).reshape(-1, 2)
            xs, ys = outline[:, 0], outline[:, 1]
            x_lo, x_hi = np.min(xs), np.max(xs)
            y_lo, y_hi = np.min(ys), np.max(ys)
            z_hi = np.max(td_hf[inst_mask]) + 1

        centers[inst_id] = np.array(
            [
                (x_lo + x_hi) / 2,
                (y_lo + y_hi) / 2,
                (x_hi - x_lo),
                (y_hi - y_lo),
                z_hi,
            ],
            dtype=np.float32,
        )
    return centers
def generate_city(
    fgm, bgm, city_layout, cx, cy, radius, altitude, azimuth, style_lut=None
):
    """Render one view of the city around ``(cx, cy)``.

    Args:
        fgm: Generator used for building points (passed on as "BLDG").
        bgm: Generator used for all remaining points (passed on as "REST").
        city_layout: Dict of NumPy layout maps plus the "CTR" centers dict;
            see _get_local_layout() for the keys that are read.
        cx, cy: Center of the rendered window in layout coordinates.
        radius, altitude, azimuth: Orbit camera parameters; ``azimuth`` is
            in degrees.
        style_lut: Optional per-instance latent codes. Sampled with
            _get_style_lut() when omitted.

    Returns:
        The rendered image as a uint8 NumPy array (H, W, C).
    """
    import gaussiancity.extensions.diff_gaussian_rasterization as dgr

    device = torch.device("cuda")
    half_proj_size = CONSTANTS["PROJECTION_SIZE"] // 2
    bldg_inst_range = CONSTANTS["BLDG_INST_RANGE"]
    generators = {"BLDG": fgm, "REST": bgm}

    cam_k = np.array(CONSTANTS["CAM_K"], dtype=np.float32).reshape((3, 3))
    gr = dgr.GaussianRasterizerWrapper(
        cam_k,
        CONSTANTS["SENSOR_SIZE"],
        flip_lr=True,
        flip_ud=False,
        device=device,
    )

    # Crop the layout and turn it into a point cloud.
    layout = _get_local_layout(
        city_layout, cx, cy, half_proj_size, bldg_inst_range, device
    )
    bev_pts = _get_bev_points(layout, SCALES, CLASSES)
    bev_pt_classes = _instances_to_classes(
        bev_pts[:, [3]], bldg_inst_range, CLASSES
    )
    bev_pt_classes_onehot = _get_onehot_seg(bev_pt_classes, len(CLASSES))
    bev_pt_scales = _get_point_scales(
        bev_pt_classes,
        SCALES,
        CLASSES,
        CONSTANTS["SPECIAL_Z_SCALE_CLASSES"],
    )
    # Resulting layout: [N, XYZ + Inst + Scale3D + N_CLASSES]
    bev_pts = torch.cat([bev_pts, bev_pt_scales, bev_pt_classes_onehot], dim=1)

    if style_lut is None:
        style_lut = _get_style_lut(
            layout["CTR"],
            generators,
            {
                "BLDG": bldg_inst_range,
                "REST": [0, bldg_inst_range[0]],
            },
            device,
        )

    cam_look_at, cam_pose = _get_orbit_camera_pose(
        radius, altitude, azimuth, half_proj_size, device
    )
    # NOTE: bev_pts[:, :3] is a view, and _get_volume() (called inside)
    # shifts the coordinates it receives in place. bev_pts therefore holds
    # the shifted coordinates from here on.
    vp_idx = _get_visible_points(
        bev_pts[:, :3],
        bev_pt_scales,
        CONSTANTS["CAM_K"],
        CONSTANTS["SENSOR_SIZE"],
        cam_pose[:3],
        cam_look_at,
    )
    # Only the visible points are sent through the generators.
    gs_attrs = _get_gs_attrs(
        bev_pts[vp_idx],
        layout["TD_HF"].float(),
        layout["SEG"].float(),
        style_lut,
        layout["CTR"],
        generators,
        CONSTANTS["POINT_SCALE_FACTOR"],
        bldg_inst_range,
    )
    return _render(gs_attrs, gr, cam_pose)
def _get_local_layout(city_layout, cx, cy, half_proj_size, bldg_inst_range, device):
    """Crop the city layout to the window around ``(cx, cy)`` on the GPU.

    Args:
        city_layout: Dict of 2D NumPy maps ("TD_HF", "BU_HF", "SEG", "INS",
            "PTS") and "CTR", a dict of per-instance center arrays.
        cx, cy: Window center in layout coordinates.
        half_proj_size: Half of the window side length.
        bldg_inst_range: ``[first, last)`` building instance IDs; only the
            first value is used here.
        device: CUDA device that receives the tensors.

    Returns:
        Dict with the cropped maps as ``(1, 1, H, W)`` tensors ("SEG" is
        replaced by its one-hot encoding) and "CTR", the centers of the
        instances that occur in the window.
    """
    x_min, x_max = cx - half_proj_size, cx + half_proj_size
    y_min, y_max = cy - half_proj_size, cy + half_proj_size

    map_keys = ("TD_HF", "BU_HF", "SEG", "INS", "PTS")
    local = {}
    for key, full_map in city_layout.items():
        if key in map_keys:
            window = full_map[None, None, y_min:y_max, x_min:x_max]
            local[key] = torch.from_numpy(window).cuda(device)
    local["SEG"] = _get_onehot_seg(local["SEG"], len(CLASSES))

    centers = {}
    for inst_tensor in torch.unique(local["INS"]):
        inst = inst_tensor.item()
        center = torch.from_numpy(city_layout["CTR"][inst]).cuda(device)
        if inst >= bldg_inst_range[0]:
            # Building centers are translated into window coordinates.
            center[0] -= x_min
            center[1] -= y_min
            centers[inst] = center
            # The roof instance (ID + 1) shares the center of its facade.
            centers[inst + 1] = center
        else:
            # For the remaining classes the first two entries are replaced
            # by the top-left corner of the window.
            center[0] = x_min
            center[1] = y_min
            centers[inst] = center

    local["CTR"] = centers
    return local
def _get_onehot_seg(seg_map, n_classes):
    """One-hot encode a class ID tensor along dimension 1.

    Args:
        seg_map: Tensor of class IDs shaped ``(N, 1, H, W)`` or ``(N, 1)``.
        n_classes: Number of output channels.

    Returns:
        Boolean tensor shaped ``(N, n_classes, ...)`` on the device of
        ``seg_map``; channel ``i`` is True where ``seg_map == i``.
    """
    batch, _, *spatial = seg_map.shape
    one_hot = torch.zeros(
        (batch, n_classes, *spatial), device=seg_map.device, dtype=torch.bool
    )
    for class_id in range(n_classes):
        # The slice keeps the channel dimension so the shapes line up.
        one_hot[:, class_id : class_id + 1] = seg_map == class_id
    return one_hot
def _get_style_lut(centers, models, inst_ranges, device, z_dim=256):
    """Create the per-instance latent code look-up table.

    Every instance starts with a random code. Then, for each model:

    * a model without latent codes (``cfg.Z_DIM is None``) removes the
      entries of the instances it is responsible for;
    * a model that stores codes in ``module.z`` assigns a randomly chosen
      stored code to each instance it is responsible for.

    A model is responsible for the instance IDs inside its half-open
    ``inst_ranges`` entry. Stored codes used to be written to all instances,
    so one model could overwrite (or re-create) the entries of another.

    Args:
        centers: Dict keyed by instance ID; only the keys are used.
        models: Dict mapping a model name to a wrapped model or ``None``.
        inst_ranges: Dict mapping a model name to ``[first, last)`` IDs.
        device: Device of the random codes.
        z_dim: Size of the random codes.

    Returns:
        Dict mapping instance ID to a ``(1, z_dim)`` tensor.
    """
    lut = {ins: torch.rand(1, z_dim, device=device) for ins in centers}
    for name, model in models.items():
        if model is None:
            continue

        first, last = inst_ranges[name]
        owned = [ins for ins in centers if first <= ins < last]

        if model.module.cfg.Z_DIM is None:
            for ins in owned:
                lut.pop(ins, None)
            continue

        if hasattr(model.module, "z"):
            zs = model.module.z
            z_keys = list(zs.keys())
            for ins in owned:
                lut[ins] = zs[np.random.choice(z_keys)].unsqueeze(0)
    return lut
def _get_orbit_camera_pose(radius, altitude, azimuth, half_proj_size, device):
    """Place a camera on a circle around the center of the window.

    Args:
        radius: Horizontal distance between the camera and the center.
        altitude: Camera height (z coordinate).
        azimuth: Angle on the circle in degrees.
        half_proj_size: Half of the window side length; the window center
            is ``(half_proj_size, half_proj_size)``.
        device: Device of the returned tensors.

    Returns:
        ``(cam_look_at, cam_pose)`` where ``cam_look_at`` holds the target
        point (center of the window at z = 1) and ``cam_pose`` holds the
        camera position followed by its orientation quaternion.
    """
    angle = np.deg2rad(azimuth)
    target = np.array([half_proj_size, half_proj_size, 1], dtype=np.float32)
    position = np.array(
        [
            half_proj_size + radius * math.cos(angle),
            half_proj_size + radius * math.sin(angle),
            altitude,
        ],
        dtype=np.float32,
    )
    orientation = _get_quat_from_look_at(position, target)

    look_at_tensor = torch.tensor([*target], device=device)
    pose_tensor = torch.tensor([*position, *orientation], device=device)
    return look_at_tensor, pose_tensor
def _get_quat_from_look_at(cam_pos, cam_look_at):
    """Return the quaternion of a camera at ``cam_pos`` facing ``cam_look_at``.

    The rotation matrix has the columns ``[forward, right, up]`` with the
    world up axis being +Z. When the camera looks straight up or down the
    right vector is undefined (it used to become NaN); +Y is used as the
    right vector in that case.

    Args:
        cam_pos: Camera position, array-like of 3 numbers.
        cam_look_at: Target point, array-like of 3 numbers.

    Returns:
        Quaternion in SciPy's scalar-last ``(x, y, z, w)`` order.

    Raises:
        ValueError: If ``cam_pos`` and ``cam_look_at`` coincide.
    """
    eps = 1e-8
    fwd_vec = np.asarray(cam_look_at) - np.asarray(cam_pos)
    fwd_norm = np.linalg.norm(fwd_vec)
    if fwd_norm < eps:
        raise ValueError("cam_pos and cam_look_at must be different points.")
    # Out-of-place division also works for integer inputs.
    fwd_vec = fwd_vec / fwd_norm

    right_vec = np.cross(np.array([0, 0, 1]), fwd_vec)
    right_norm = np.linalg.norm(right_vec)
    if right_norm < eps:
        # The view direction is parallel to the world up axis.
        right_vec = np.array([0.0, 1.0, 0.0])
    else:
        right_vec = right_vec / right_norm

    # Re-derive up so that the three axes are orthonormal.
    up_vec = np.cross(fwd_vec, right_vec)
    R = np.stack([fwd_vec, right_vec, up_vec], axis=1)
    return scipy.spatial.transform.Rotation.from_matrix(R).as_quat()
def _get_bev_points(layout, scales, classes):
    """Extrude the layout maps into a 3D volume and list its occupied cells.

    Args:
        layout: Dict with the "INS", "TD_HF", "BU_HF" and "PTS" tensors.
        scales: Dict mapping class names to their scale.
        classes: Dict mapping class names to class IDs; its key order
            defines the order of the per-class scale table.

    Returns:
        Tensor with one row per occupied cell: the three volume indices
        followed by the value stored in that cell.
    """
    import gaussiancity.extensions.voxlib as voxlib

    ins_map = layout["INS"]
    assert torch.max(ins_map) < 16384
    # The height is capped because of the size of the resulting volume:
    # torch.nonzero(torch.zeros(2048, 2048, 512).cuda())
    #   -> nonzero is not supported for tensors with more than INT_MAX elements
    # torch.nonzero(torch.zeros(2048, 2048, 508).cuda())
    #   -> an illegal memory access was encountered
    assert torch.max(layout["TD_HF"]) <= 500

    # Classes without a scale (e.g. NULL) get 0.
    class_scales = torch.tensor(
        [scales.get(name, 0) for name in classes],
        dtype=torch.int8,
        device=ins_map.device,
    )
    volume = voxlib.maps_to_volume(
        ins_map.squeeze().short(),
        layout["TD_HF"].squeeze().short(),
        layout["BU_HF"].squeeze().short(),
        layout["PTS"].squeeze().bool(),
        class_scales,
    )

    occupied = torch.nonzero(volume, as_tuple=False)
    values = volume[occupied[:, 0], occupied[:, 1], occupied[:, 2]]
    return torch.cat([occupied.short(), values.unsqueeze(dim=1)], dim=1)
def _instances_to_classes(instances, bldg_inst_range, bldg_classes):
    """Map instance IDs back to semantic class IDs.

    Building instances (ID >= ``bldg_inst_range[0]``) become BLDG_FACADE
    when the ID is even and BLDG_ROOF when it is odd. All other IDs are
    returned unchanged.

    Args:
        instances: Integer tensor of instance IDs.
        bldg_inst_range: ``[first, last)`` building IDs; only the first
            value is used.
        bldg_classes: Dict providing the "BLDG_FACADE" and "BLDG_ROOF" IDs.

    Returns:
        A new tensor shaped like ``instances``.
    """
    is_building = instances >= bldg_inst_range[0]
    parity = instances % 2
    classes = instances.clone()
    for remainder, class_name in ((0, "BLDG_FACADE"), (1, "BLDG_ROOF")):
        classes[is_building & (parity == remainder)] = bldg_classes[class_name]
    return classes
def _get_point_scales(pt_classes, scales, classes, special_z_scale_classes=()):
    """Look up the 3D scale of every point from its class.

    Args:
        pt_classes: ``(N, 1)`` integer tensor of class IDs.
        scales: Dict mapping class names to their scale.
        classes: Dict mapping class names to class IDs.
        special_z_scale_classes: Class IDs whose z-scale is forced to 1
            (roads, zones and waters in this project).

    Returns:
        ``(N, 3)`` tensor of per-axis scales with the dtype of
        ``pt_classes``. Points of a class without an entry in ``scales``
        keep their class ID as scale (0 for the NULL class).
    """
    pt_scales = pt_classes.clone()
    for name, scale in scales.items():
        # Always match against the untouched class IDs, not the partially
        # rewritten scales.
        pt_scales[pt_classes == classes[name]] = scale
    pt_scales_3d = pt_scales.repeat(1, 3)

    special_ids = list(special_z_scale_classes)
    # An empty list would create an empty float tensor for torch.isin();
    # there is nothing to flatten in that case anyway.
    if special_ids:
        is_special = torch.isin(
            pt_classes.squeeze(dim=-1),
            torch.tensor(special_ids, device=pt_classes.device),
        )
        pt_scales_3d[is_special, 2] = 1
    return pt_scales_3d
def _get_visible_points(points, scales, K, sensor_size, cam_pos, cam_look_at):
    """Return the indices of the points that are hit by a camera ray.

    Each point is rasterized into a volume with a unique ID. The rays cast
    through that volume return the ID of the cell they hit, so the unique
    IDs of the resulting map are exactly the visible points.

    NOTE: ``points`` is shifted in place by _get_volume().

    Args:
        points: ``(N, 3)`` point coordinates.
        scales: ``(N, 3)`` point scales.
        K: Flattened 3x3 camera intrinsics.
        sensor_size: Image size as stored in CONSTANTS["SENSOR_SIZE"].
        cam_pos: Camera position tensor.
        cam_look_at: Camera target tensor.

    Returns:
        1D tensor of unique, non-negative point indices.
    """
    volume, offsets = _get_volume(points, scales)

    # The volume lives in local coordinates, so the camera is moved too.
    local_origin = cam_pos - offsets
    view_dir = cam_look_at - cam_pos
    vp_map = _get_ray_voxel_intersection(
        K, sensor_size, local_origin, view_dir, volume
    )
    # An instance segmentation map could be derived from vp_map here
    # (instances[vp_map], with the NULL class where vp_map == -1).

    # Manually release the memory to avoid OOM
    del volume
    torch.cuda.empty_cache()

    visible = torch.unique(vp_map)
    # -1 marks rays that did not hit any point.
    return visible[visible >= 0]
def _get_volume(points, scales):
    """Rasterize the points into a volume of point IDs.

    NOTE: ``points`` is shifted to local coordinates in place (see
    _get_localized_pt_cords), which is visible to the caller.

    Args:
        points: ``(N, 3)`` integer point coordinates.
        scales: ``(N, 3)`` point scales, forwarded to the extension.

    Returns:
        ``(volume, offsets)`` where ``offsets`` is the int16 tensor of the
        minimum x, y and z coordinate before the shift.
    """
    import gaussiancity.extensions.voxlib as voxlib

    lows, highs = [], []
    for axis in range(3):
        coords = points[:, axis]
        lows.append(torch.min(coords).item())
        highs.append(torch.max(coords).item())
    offsets = torch.tensor(lows, dtype=torch.int16, device=points.device)

    # Normalize points coordinates to local coordinate system
    points = _get_localized_pt_cords(points, offsets)

    # Extent of the volume; z gets one extra cell because the local z
    # coordinates start at 1.
    w = highs[0] - lows[0] + 1
    h = highs[1] - lows[1] + 1
    d = highs[2] - lows[2] + 2

    # NOTE: The point IDs start from 1 to avoid the conflict with the NULL class.
    n_points = points.shape[0]
    assert n_points < 2147483648
    pt_ids = torch.arange(
        start=1, end=n_points + 1, dtype=torch.int32, device=points.device
    ).unsqueeze(dim=1)

    volume = voxlib.points_to_volume(points.contiguous(), pt_ids, scales, h, w, d)
    return volume, offsets
def _get_localized_pt_cords(points, offsets):
    """Shift the points in place so that they start at ``(0, 0, 1)``.

    Args:
        points: ``(N, >=3)`` tensor of coordinates; modified in place.
        offsets: Tensor holding the x, y and z offset to subtract.

    Returns:
        The same ``points`` tensor.
    """
    # z keeps a margin of one cell: it is shifted by ``offset - 1``.
    for axis, margin in enumerate((0, 0, 1)):
        points[:, axis] -= offsets[axis] - margin
    return points
def _get_ray_voxel_intersection(K, sensor_size, cam_origin, viewdir, volume):
    """Cast one ray per pixel through ``volume`` and return the hit IDs.

    Args:
        K: Flattened 3x3 camera intrinsics.
        sensor_size: Two-element image size; passed on in reversed order.
        cam_origin: Camera position in the coordinates of ``volume``.
        viewdir: Viewing direction.
        volume: Volume of point IDs starting at 1 (0 means empty).

    Returns:
        Tensor of 0-based point IDs per pixel; -1 where nothing was hit.
    """
    import gaussiancity.extensions.voxlib as voxlib

    N_MAX_SAMPLES = 1
    # The extension is fed with the first two axes swapped, which is why
    # the vectors, the principal point and the sensor size are reordered.
    axis_order = [1, 0, 2]
    focal_length = K[0]
    principal_point = [K[5], K[2]]
    swapped_sensor_size = [sensor_size[1], sensor_size[0]]

    voxel_id, _, _ = voxlib.ray_voxel_intersection_perspective(
        volume,
        cam_origin[axis_order].float(),
        viewdir[axis_order].float(),
        torch.tensor([0, 0, 1], dtype=torch.float32),
        focal_length,
        principal_point,
        swapped_sensor_size,
        N_MAX_SAMPLES,
    )
    # NOTE: The point ID for NULL class is -1, the rest point IDs are from 0 to N - 1.
    # The ray_voxel_intersection_perspective seems not accepting the negative values.
    return voxel_id.squeeze() - 1
def get_hf_seg_tensor(part_hf, part_seg, layout_cfg, output_device):
    """Stack a normalized height field and a one-hot segmentation map.

    NOTE(review): "LAYOUT_MAX_HEIGHT" and "LAYOUT_N_CLASSES" are not part
    of the CONSTANTS literal in this file; unless they are added elsewhere
    this function raises a KeyError. ``layout_cfg`` is not used.

    Args:
        part_hf: 2D NumPy height field.
        part_seg: 2D NumPy segmentation map.
        layout_cfg: Unused.
        output_device: Device of the returned tensor.

    Returns:
        ``(1, 1 + n_classes, H, W)`` tensor: height channel first, followed
        by the one-hot segmentation channels.
    """
    hf = torch.from_numpy(part_hf[np.newaxis, np.newaxis, ...]).to(output_device)
    seg = torch.from_numpy(part_seg[np.newaxis, np.newaxis, ...]).to(output_device)

    hf = hf / CONSTANTS["LAYOUT_MAX_HEIGHT"]
    seg_onehot = _masks_to_onehots(seg[:, 0, :, :], CONSTANTS["LAYOUT_N_CLASSES"])
    return torch.cat([hf, seg_onehot], dim=1)
def _masks_to_onehots(masks, n_class, ignored_classes=()):
    """One-hot encode class ID masks, optionally dropping some classes.

    Args:
        masks: ``(B, H, W)`` tensor of class IDs.
        n_class: Number of classes; IDs ``0 .. n_class - 1`` are encoded.
        ignored_classes: Class IDs that get no output channel. Duplicates
            and IDs outside ``0 .. n_class - 1`` are tolerated (they used
            to make the output too small and raise an IndexError).

    Returns:
        ``(B, n_kept, H, W)`` float32 tensor on the device of ``masks``.
        The channels follow the ascending order of the kept class IDs.
    """
    b, h, w = masks.shape
    ignored = set(ignored_classes)
    kept_classes = [c for c in range(n_class) if c not in ignored]

    one_hot_masks = torch.zeros(
        (b, len(kept_classes), h, w), dtype=torch.float32, device=masks.device
    )
    for channel, class_id in enumerate(kept_classes):
        one_hot_masks[:, channel] = masks == class_id
    return one_hot_masks
def _get_gs_attrs(
    pts,
    proj_hf,
    proj_seg,
    style_lut,
    centers,
    models,
    scale_factor,
    bldg_inst_range,
):
    """Assemble the Gaussian attributes of all points.

    The points are split into buildings and the rest; each group is colored
    by its own generator. The output lists the building points first.

    Args:
        pts: ``(N, 4 + 3 + N_CLASSES)`` tensor: XYZ, instance ID, 3D scale
            and the one-hot class.
        proj_hf: Height field passed to the generators.
        proj_seg: One-hot segmentation map passed to the generators.
        style_lut: Per-instance latent codes.
        centers: Per-instance ``[cx, cy, w, h, d]`` tensors.
        models: Dict with the "BLDG" and "REST" generators.
        scale_factor: Multiplier applied to the point scales.
        bldg_inst_range: ``[first, last)`` building IDs; only the first
            value is used.

    Returns:
        ``(N, 14)`` tensor: XYZ (3), opacity (1), scales (3), rotation
        quaternion (4) and RGB (3). Opacity is 1 and the rotation is the
        identity quaternion ``(1, 0, 0, 0)`` for every point.
    """
    n_pts, _ = pts.shape
    device = pts.device

    is_bldg = pts[:, 3] >= bldg_inst_range[0]
    groups = {"BLDG": pts[is_bldg], "REST": pts[~is_bldg]}

    # Generator inputs of both groups first, then the colors.
    attrs = {
        name: _get_pt_input_attrs(
            group[:, :4],
            centers,
            style_lut,
            models[name].module.cfg.Z_DIM,
            bldg_inst_range,
        )
        for name, group in groups.items()
    }
    colors = [
        _get_gs_colors(group, attrs[name], proj_hf, proj_seg, models[name])
        for name, group in groups.items()
    ]

    ordered = list(groups.values())
    abs_xyz = torch.cat([group[:, :3] for group in ordered], dim=0)
    scales = torch.cat([group[:, 4:7] for group in ordered], dim=0) * scale_factor
    rgb = torch.cat(colors, dim=0)

    # Attributes with default values
    opacity = torch.ones((n_pts, 1), device=device)
    rotations = torch.zeros(n_pts, 4, device=device)
    rotations[:, 0] = 1
    return torch.cat((abs_xyz, opacity, scales, rotations, rgb), dim=-1)
def _get_pt_input_attrs(pts, centers, style_lut, z_dim, bldg_inst_range):
    """Compute the generator inputs of a set of points.

    Args:
        pts: ``(N, 4)`` tensor holding XYZ and the instance ID (last column).
        centers: Dict mapping instance ID to ``[cx, cy, w, h, d]``.
        style_lut: Dict mapping instance ID to its latent code.
        z_dim: Latent size of the generator; ``None`` disables latent codes.
        bldg_inst_range: ``[first, last)`` building IDs; only the first
            value is used.

    Returns:
        ``(rel_xyz, batch_idx, zs)``:

        * ``rel_xyz``: ``(1, N, 3)`` float32 instance-relative coordinates.
        * ``batch_idx``: ``(1, N)`` int32 index of the instance of a point
          within the sorted unique instance IDs.
        * ``zs``: ``None`` when ``z_dim`` is ``None``; otherwise a dict
          mapping instance ID to ``{"z": code, "idx": (1, N) bool mask}``.
    """
    n_pts = pts.shape[0]
    inst_ids = pts[:, -1]
    rel_xyz = torch.zeros(1, n_pts, 3, dtype=torch.float32, device=pts.device)
    batch_idx = torch.zeros(1, n_pts, dtype=torch.int32, device=pts.device)
    zs = None if z_dim is None else {}

    for order, inst_tensor in enumerate(torch.unique(inst_ids)):
        ins = inst_tensor.item()
        selected = inst_ids == ins
        xs, ys, heights = pts[selected, 0], pts[selected, 1], pts[selected, 2]
        cx, cy, w, h, d = centers[ins]

        if ins < bldg_inst_range[0]:
            # Make the BG contiguous: the coordinates are mirrored from one
            # period to the next instead of jumping back.
            period_x = torch.ceil((xs / w / 2) - 0.5)
            period_y = torch.ceil((ys / h / 2) - 0.5)
            rel_xyz[:, selected, 0] = (
                (xs - 2 * period_x * w) * (-1) ** period_x
            ) / w
            rel_xyz[:, selected, 1] = (
                (ys - 2 * period_y * h) * (-1) ** period_y
            ) / h
        else:
            # Buildings: offset from the center in units of half the box
            # size; degenerate (zero-sized) boxes map to 0.
            rel_xyz[:, selected, 0] = (xs - cx) / w * 2 if w > 0 else 0
            rel_xyz[:, selected, 1] = (ys - cy) / h * 2 if h > 0 else 0

        rel_xyz[:, selected, 2] = (
            torch.clip(heights / d * 2 - 1, -1, 1) if d > 0 else 0
        )
        batch_idx[:, selected] = order
        if zs is not None:
            zs[ins] = {"z": style_lut[ins], "idx": selected.unsqueeze(dim=0)}

    return rel_xyz, batch_idx, zs
def _get_gs_colors(pts, pt_attrs, proj_hf, proj_seg, model):
    """Query a generator for the colors of the given points.

    Args:
        pts: ``(N, 4 + 3 + N_CLASSES)`` point tensor; columns 0-2 hold XYZ
            and the columns from 7 on hold the one-hot class.
        pt_attrs: ``(rel_xyz, batch_idx, zs)`` from _get_pt_input_attrs().
        proj_hf: Height field passed to the generator.
        proj_seg: One-hot segmentation map passed to the generator.
        model: Wrapped generator; called without gradients.

    Returns:
        ``(N, 3)`` tensor of colors; an empty float32 tensor when there are
        no points (the generator is not called in that case).
    """
    if pts.shape[0] == 0:
        return torch.empty(0, 3, dtype=torch.float32, device=pts.device)

    rel_xyz, batch_idx, zs = pt_attrs
    # Add the batch dimension expected by the generator.
    batched = pts.unsqueeze(0)
    abs_xyz = batched[..., :3]
    onehots = batched[..., 7:]

    # The projection coordinates are only needed by generators that
    # have an encoder.
    if model.module.cfg.ENCODER is not None:
        proj_uv = get_projection_uv(abs_xyz)
    else:
        proj_uv = None

    with torch.no_grad():
        # TODO: Optimize the _instance_forward in Generator
        outputs = model(
            proj_uv, rel_xyz, batch_idx, onehots.float(), zs, proj_hf, proj_seg
        )
    return outputs["rgb"].squeeze(dim=0)
def get_projection_uv(xyz, proj_tlp=None, proj_size=2048):
    """Normalize the XY coordinates of points to the projection window.

    Args:
        xyz: ``(B, N, >=2)`` tensor of coordinates; not modified.
        proj_tlp: Optional ``(B, 2)`` top-left corner of the window that is
            subtracted from the coordinates first.
        proj_size: Side length of the window.

    Returns:
        ``(B, N, 2)`` floating point tensor; coordinates inside the window
        map to ``[-1, 1]``. Integer inputs are converted to float32 in both
        branches (the ``proj_tlp`` branch used to fail on integer input).
    """
    n_pts = xyz.size(1)
    proj_uv = xyz[..., :2]
    if proj_tlp is None:
        proj_uv = proj_uv.float()
    else:
        proj_uv = proj_uv - proj_tlp.unsqueeze(dim=1)
        if not proj_uv.is_floating_point():
            # Integer tensors cannot hold the normalized coordinates.
            proj_uv = proj_uv.float()

    assert proj_uv.size() == (xyz.size(0), n_pts, 2)
    # Normalize to [-1, 1]
    return proj_uv / proj_size * 2 - 1
def _render(gs_attrs, rasterizator, cam_pose):
    """Rasterize the Gaussians and convert the result to an 8-bit image.

    Args:
        gs_attrs: Gaussian attribute tensor from _get_gs_attrs().
        rasterizator: Callable taking the attributes, the camera position
            and the camera quaternion.
        cam_pose: Tensor holding the camera position (first 3 values)
            followed by the quaternion.

    Returns:
        ``(H, W, C)`` uint8 NumPy array.
    """
    import torchvision.transforms.functional as F

    position, quaternion = cam_pose[:3], cam_pose[3:]
    with torch.no_grad():
        img = rasterizator(gs_attrs, position, quaternion)

    # Map the raw output from [-1, 1] to [0, 1].
    img = img.squeeze() / 2 + 0.5
    # Slightly boost brightness, then contrast.
    for adjust in (F.adjust_brightness, F.adjust_contrast):
        img = adjust(img, 1.2)

    frame = (img * 255).permute(1, 2, 0)
    return frame.cpu().numpy().astype(np.uint8)