Spaces:
Running on Zero
Running on Zero
| # -*- coding: utf-8 -*- | |
| # | |
| # @File: pt_v3.py | |
| # @Author: Xiaoyang Wu <xiaoyang.wu.cs@gmail.com> | |
| # @Date: 2024-04-01 16:31:36 | |
| # @Last Modified by: Haozhe Xie | |
| # @Last Modified at: 2024-05-15 22:05:09 | |
| # @Email: root@haozhexie.com | |
| # Ref: | |
| # - https://github.com/Pointcept/PointTransformerV3/blob/main/model.py | |
| # - https://huggingface.co/spaces/Roll20/pet_score/blame/main/lib/timm/models/layers/drop.py | |
| import addict | |
| import collections | |
| import functools | |
| import flash_attn | |
| import math | |
| import torch | |
| import spconv.pytorch as spconv | |
| import torch_scatter | |
| import typing | |
def offset2bincount(offset):
    """Convert cumulative per-sample end offsets into per-sample point counts.

    The count of sample ``i`` is ``offset[i] - offset[i - 1]``, with an
    implicit leading zero, e.g. ``[3, 5, 9] -> [3, 2, 4]``.
    """
    leading_zero = torch.tensor([0], device=offset.device, dtype=torch.long)
    return torch.diff(offset, prepend=leading_zero)
def offset2batch(offset):
    """Expand cumulative per-sample end offsets into a per-point sample index.

    Example: ``offset = [2, 5]`` gives ``[0, 0, 1, 1, 1]``.
    """
    counts = offset2bincount(offset)
    sample_ids = torch.arange(len(counts), device=offset.device, dtype=torch.long)
    return torch.repeat_interleave(sample_ids, counts)
def batch2offset(batch):
    """Turn a per-point sample index into cumulative per-sample end offsets.

    This reverses ``offset2batch``, except that trailing empty samples cannot
    be recovered from ``batch``.
    """
    counts = torch.bincount(batch)
    return counts.cumsum(dim=0).long()
class KeyLUT:
    """Lookup tables for 3D Morton (z-order) bit interleaving.

    Bit ``i`` of x lands on key bit ``3i+2``, y on ``3i+1`` and z on ``3i``.
    The encode tables map every 8-bit value of one axis to its interleaved
    contribution. The decode tables map every 9-bit key back to its 3-bit
    x/y/z parts. Tables are built on the CPU and copied lazily to other
    devices, where the copies are cached.
    """

    def __init__(self):
        cpu = torch.device("cpu")
        byte_values = torch.arange(256, dtype=torch.int64)
        no_bits = torch.zeros(256, dtype=torch.int64)
        encode_tables = (
            self.xyz2key(byte_values, no_bits, no_bits, 8),
            self.xyz2key(no_bits, byte_values, no_bits, 8),
            self.xyz2key(no_bits, no_bits, byte_values, 8),
        )
        self._encode = {cpu: encode_tables}
        key_values = torch.arange(512, dtype=torch.int64)
        self._decode = {cpu: self.key2xyz(key_values, 9)}

    def _lut_on(self, cache, device):
        # On first request for `device`, copy the CPU tables there and keep them.
        if device not in cache:
            cache[device] = tuple(t.to(device) for t in cache[torch.device("cpu")])
        return cache[device]

    def encode_lut(self, device=torch.device("cpu")):
        """Return the (EX, EY, EZ) encode tables located on ``device``."""
        return self._lut_on(self._encode, device)

    def decode_lut(self, device=torch.device("cpu")):
        """Return the (DX, DY, DZ) decode tables located on ``device``."""
        return self._lut_on(self._decode, device)

    def xyz2key(self, x, y, z, depth):
        """Interleave the lowest ``depth`` bits of x, y and z into Morton keys."""
        key = torch.zeros_like(x)
        for level in range(depth):
            bit = 1 << level
            for coord, lane in ((x, 2), (y, 1), (z, 0)):
                key = key | ((coord & bit) << (2 * level + lane))
        return key

    def key2xyz(self, key, depth):
        """Split Morton keys back into their (x, y, z) components."""
        components = []
        for lane in (2, 1, 0):
            value = torch.zeros_like(key)
            for level in range(depth):
                picked = key & (1 << (3 * level + lane))
                value = value | (picked >> (2 * level + lane))
            components.append(value)
        x, y, z = components
        return x, y, z
class Serializator:
    """Serialize integer grid coordinates along space-filling curves.

    Supported orders:
    - ``"cord"``: ``x / grid_size**2 + y / grid_size + z``.
    - ``"z"`` / ``"z-trans"``: Morton order; the ``-trans`` variant swaps x and y.
    - ``"hilbert"`` / ``"hilbert-trans"``: Hilbert order; the ``-trans``
      variant swaps x and y.
    """

    def encode(self, grid_coord, grid_size=0.01, batch=None, depth=16, order="cord"):
        """Encode ``grid_coord`` (one row of 3 integers per point) into int64 codes.

        Args:
            grid_coord: integer grid coordinates, shape (N, 3).
            grid_size: only used by the ``"cord"`` order.
            batch: optional per-point sample index. When given, it is shifted
                above the ``3 * depth`` coordinate bits, so codes of different
                samples never interleave.
            depth: bits per axis used by the z / Hilbert orders.
            order: one of ``"cord"``, ``"z"``, ``"z-trans"``, ``"hilbert"``,
                ``"hilbert-trans"``.

        Returns:
            One code per point.
        """
        assert order in {"cord", "z", "z-trans", "hilbert", "hilbert-trans"}
        if order == "cord":
            code = self.cord_encode(grid_coord, grid_size)
        elif order == "z":
            code = self.z_order_encode(grid_coord, depth=depth)
        elif order == "z-trans":
            code = self.z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth)
        elif order == "hilbert":
            code = self.hilbert_encode(grid_coord, depth=depth)
        elif order == "hilbert-trans":
            code = self.hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth)
        else:
            raise NotImplementedError
        if batch is not None:
            batch = batch.long()
            # Place the sample index above the 3 * depth coordinate bits.
            code = batch << depth * 3 | code
        return code

    def _get_key_lut(self):
        """Return the Morton lookup tables, building them only on first use.

        Reusing one KeyLUT also reuses its cached per-device table copies.
        """
        key_lut = getattr(self, "key_lut", None)
        if key_lut is None:
            key_lut = KeyLUT()
            self.key_lut = key_lut
        return key_lut

    def cord_encode(self, grid_coord: torch.Tensor, grid_size: float):
        """Encode as ``x / grid_size**2 + y / grid_size + z``, truncated to int64."""
        x, y, z = (
            grid_coord[:, 0].long(),
            grid_coord[:, 1].long(),
            grid_coord[:, 2].long(),
        )
        # we block the support to batch, maintain batched code in Point class
        code = x / grid_size**2 + y / grid_size + z
        return code.long()

    def z_order_encode(self, grid_coord: torch.Tensor, depth: int = 16):
        """Morton-encode the (N, 3) coordinates using ``depth`` bits per axis."""
        x, y, z = (
            grid_coord[:, 0].long(),
            grid_coord[:, 1].long(),
            grid_coord[:, 2].long(),
        )
        # we block the support to batch, maintain batched code in Point class
        code = self._xyz2key(x, y, z, b=None, depth=depth)
        return code

    def _xyz2key(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        z: torch.Tensor,
        b: typing.Optional[typing.Union[torch.Tensor, int]] = None,
        depth: int = 16,
    ):
        r"""Encode :attr:`x`, :attr:`y`, :attr:`z` coordinates into shuffled keys
        using pre-computed lookup tables. This is much faster than the
        for-loop based method.

        Args:
            x (torch.Tensor): The x coordinate.
            y (torch.Tensor): The y coordinate.
            z (torch.Tensor): The z coordinate.
            b (torch.Tensor or int): The batch index of the coordinates, which should
              be smaller than 32768. If :attr:`b` is a :obj:`torch.Tensor`, its size
              must match that of :attr:`x`, :attr:`y` and :attr:`z`.
            depth (int): The depth of the shuffled key; must be smaller than 17 (< 17).
        """
        EX, EY, EZ = self._get_key_lut().encode_lut(x.device)
        x, y, z = x.long(), y.long(), z.long()
        # Encode the low byte of each axis; the tables cover 8 bits per axis.
        mask = 255 if depth > 8 else (1 << depth) - 1
        key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
        if depth > 8:
            # Encode the remaining high bits and place them above the 24 low key bits.
            mask = (1 << (depth - 8)) - 1
            key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
            key = key16 << 24 | key
        if b is not None:
            # The docstring allows a plain int here, which has no `.long()`.
            if isinstance(b, torch.Tensor):
                b = b.long()
            key = b << 48 | key
        return key

    def hilbert_encode(self, grid_coord: torch.Tensor, depth: int = 16):
        """Hilbert-encode the (N, 3) coordinates using ``depth`` bits per axis."""
        return self._hilbert_encode(grid_coord, num_dims=3, num_bits=depth)

    def _hilbert_encode(self, locs, num_dims, num_bits):
        """Encode an array of locations in a hypercube into Hilbert integers.

        This is a vectorized-ish version of the Hilbert curve implementation by John
        Skilling as described in:

        Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
        Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.

        Params:
        -------
        locs - A tensor of locations in a hypercube of num_dims dimensions, in
               which each dimension runs from 0 to 2**num_bits-1. The shape can
               be arbitrary, as long as the last dimension has size num_dims.
        num_dims - The dimensionality of the hypercube. Integer.
        num_bits - The number of bits for each dimension. Integer.

        Returns:
        --------
        An int64 tensor of Hilbert codes with the same shape as the input,
        excluding the last dimension (which must be num_dims).
        """
        # Keep around the original shape for later.
        orig_shape = locs.shape
        bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
        bitpack_mask_rev = bitpack_mask.flip(-1)
        if orig_shape[-1] != num_dims:
            raise ValueError(
                """
                The shape of locs was surprising in that the last dimension was of size
                %d, but num_dims=%d. These need to be equal.
                """
                % (orig_shape[-1], num_dims)
            )
        if num_dims * num_bits > 63:
            raise ValueError(
                """
                num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
                into a int64. Are you sure you need that many points on your Hilbert
                curve?
                """
                % (num_dims, num_bits, num_dims * num_bits)
            )
        if num_bits == 0:
            # No bits per axis: the curve has a single cell, so every code is 0.
            # (Slicing `[..., -num_bits:]` below would otherwise select all 64 bits.)
            return torch.zeros(orig_shape[:-1], dtype=torch.int64, device=locs.device)
        # An explicit point count (instead of -1) keeps the reshapes valid for empty input.
        num_points = locs.numel() // num_dims
        # Treat the location integers as 64-bit and split them into uint8s,
        # most significant byte first, keeping the association by dimension.
        locs_uint8 = (
            locs.long().view(torch.uint8).reshape((num_points, num_dims, 8)).flip(-1)
        )
        # Now turn these into bits (MSB first) and truncate to num_bits.
        gray = (
            locs_uint8.unsqueeze(-1)
            .bitwise_and(bitpack_mask_rev)
            .ne(0)
            .byte()
            .flatten(-2, -1)[..., -num_bits:]
        )
        # Run the decoding process the other way.
        # Iterate forwards through the bits.
        for bit in range(0, num_bits):
            # Iterate forwards through the dimensions.
            for dim in range(0, num_dims):
                # Identify which ones have this bit active.
                mask = gray[:, dim, bit]
                # Where this bit is on, invert the 0 dimension for lower bits.
                gray[:, 0, bit + 1 :] = torch.logical_xor(
                    gray[:, 0, bit + 1 :], mask[:, None]
                )
                # Where the bit is off, exchange the lower bits with the 0 dimension.
                to_flip = torch.logical_and(
                    torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
                    torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
                )
                gray[:, dim, bit + 1 :] = torch.logical_xor(
                    gray[:, dim, bit + 1 :], to_flip
                )
                gray[:, 0, bit + 1 :] = torch.logical_xor(
                    gray[:, 0, bit + 1 :], to_flip
                )
        # Interleave the dimensions bit by bit into one Gray code per point.
        gray = gray.swapaxes(1, 2).reshape((num_points, num_dims * num_bits))
        # Convert Gray back to binary.
        hh_bin = self._gray2binary(gray)
        # Pad back out to 64 bits.
        extra_dims = 64 - gray.size(1)
        padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)
        # Convert binary values into uint8s (least significant byte first).
        hh_uint8 = (
            (padded.flip(-1).reshape((num_points, 8, 8)) * bitpack_mask)
            .sum(2)
            .type(torch.uint8)
        )
        # Reinterpret each group of 8 bytes as one int64, then restore the
        # leading shape. An explicit reshape (rather than squeeze) keeps a
        # single-point input 1-D instead of collapsing it to a 0-dim tensor.
        hh_uint64 = hh_uint8.view(torch.int64)
        return hh_uint64.reshape(orig_shape[:-1])

    def _gray2binary(self, gray, axis=-1):
        """Convert an array of Gray codes back into binary values.

        Parameters:
        -----------
        gray: A tensor of gray codes.
        axis: The axis along which to perform Gray decoding. Default=-1.

        Returns:
        --------
        A tensor of binary values.
        """
        # Loop log2(bits) times with shift and xor. The first shift is the
        # largest power of two below the bit count (2 ** (ceil(log2(n)) - 1)).
        n = gray.shape[axis]
        shift = 1 << ((n - 1).bit_length() - 1) if n > 1 else 0
        while shift > 0:
            gray = torch.logical_xor(gray, self._right_shift(gray, shift))
            shift >>= 1
        return gray

    def _right_shift(self, binary, k=1, axis=-1):
        """Right shift an array of binary values.

        Parameters:
        -----------
        binary: A tensor of binary values.
        k: The number of bits to shift. Default 1.
        axis: The axis along which to shift. Default -1.

        Returns:
        --------
        A tensor with zeros prepended and the ends truncated along ``axis``.
        Note that the zero padding is always applied to the last dimension.
        """
        # If we're shifting the whole thing, just return zeros.
        if binary.shape[axis] <= k:
            return torch.zeros_like(binary)
        # Determine the slicing pattern to eliminate just the last k entries.
        slicing = [slice(None)] * len(binary.shape)
        slicing[axis] = slice(None, -k)
        shifted = torch.nn.functional.pad(
            binary[tuple(slicing)], (k, 0), mode="constant", value=0
        )
        return shifted
class PointModule(torch.nn.Module):
    r"""Marker base class for modules that consume a whole ``Point``.

    ``PointSequential`` passes the full Point object to instances of this
    class, whereas plain PyTorch modules only receive the features. The
    inherited ``torch.nn.Module.__init__`` is used unchanged.
    """
class Point(addict.Dict):
    """
    Point Structure of Pointcept

    A Point (point cloud) in Pointcept is a dictionary that contains various
    properties of a batched point cloud. The properties with the following
    names have a specific definition:
    - "coord": original coordinate of point cloud;
    - "grid_coord": grid coordinate for specific grid size (related to GridSampling);
    Point also supports the following optional attributes:
    - "offset": cumulative end index of each sample; derived from "batch" when
      only "batch" is given;
    - "batch": per-point sample index; derived from "offset" when only
      "offset" is given;
    - "feat": feature of point cloud, default input of model;
    - "grid_size": grid size of point cloud (related to GridSampling);
    (related to Serialization)
    - "serialized_depth": depth of serialization; 2 ** depth * grid_size
      describes the maximum extent of the point cloud;
    - "serialized_code": serialization codes stacked as a (k, n) tensor, one
      row per order;
    - "serialized_order": per-order argsort of "serialized_code";
    - "serialized_inverse": per-order inverse permutation of "serialized_order";
    (related to Sparsify: SpConv)
    - "sparse_shape": sparse shape for the SparseConvTensor;
    - "sparse_conv_feat": SparseConvTensor built from the information in Point.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.serializator = Serializator()
        # If exactly one of "offset" / "batch" exists, derive the other from it.
        if "batch" not in self.keys() and "offset" in self.keys():
            self["batch"] = offset2batch(self.offset)
        elif "offset" not in self.keys() and "batch" in self.keys():
            self["offset"] = batch2offset(self.batch)

    def _ensure_grid_coord(self):
        """Derive "grid_coord" from "coord" and "grid_size" if it is missing.

        If you don't want to run GridSampling in data augmentation, add the
        following augmentation to your pipeline (adjust `grid_size` as needed):
            dict(type="Copy", keys_dict={"grid_size": 0.01})
        """
        if "grid_coord" not in self.keys():
            assert {"grid_size", "coord"}.issubset(self.keys())
            # Shift to a non-negative origin, then truncate to integer cells.
            self["grid_coord"] = torch.div(
                self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
            ).int()

    def serialization(self, order="z", depth=None, shuffle_orders=False):
        """
        Point Cloud Serialization

        Relies on ["grid_coord" or "coord" + "grid_size", "batch"].

        Args:
            order: a single order name (e.g. ``"z"``) or an iterable of order
                names accepted by ``Serializator.encode``.
            depth: bits per axis; inferred from the largest grid coordinate
                when None.
            shuffle_orders: randomly permute the k serializations.
        """
        assert "batch" in self.keys()
        self._ensure_grid_coord()
        # A bare string is one order; iterating it would yield single
        # characters (e.g. "hilbert" -> "h", "i", ...).
        orders = [order] if isinstance(order, str) else order
        if depth is None:
            # Adaptively measure the depth of the serialization cube (length = 2 ^ depth).
            depth = int(self.grid_coord.max()).bit_length()
        self["serialized_depth"] = depth
        # Maximum bit length for serialization code is 63 (int64).
        assert depth * 3 + len(self.offset).bit_length() <= 63
        # Following OCNN, the depth is limited to 16 (48 bits) for the point position.
        # Even so, a 655.36^3 (2^16 * 0.01) meter^3 cube can be encoded with a
        # grid size of 0.01 meter, which is enough for the current stage. The
        # limit could be lifted by optimizing the z-order encoding if necessary.
        assert depth <= 16
        # The serialization codes are arranged as [Order1 ([n]), ..., OrderK ([n])] (k, n).
        code = torch.stack(
            [
                self.serializator.encode(
                    self.grid_coord, self.grid_size, self.batch, depth, order=order_
                )
                for order_ in orders
            ]
        )
        serialized_order = torch.argsort(code)
        inverse = torch.zeros_like(serialized_order).scatter_(
            dim=1,
            index=serialized_order,
            src=torch.arange(0, code.shape[1], device=serialized_order.device).repeat(
                code.shape[0], 1
            ),
        )
        if shuffle_orders:
            perm = torch.randperm(code.shape[0])
            code = code[perm]
            serialized_order = serialized_order[perm]
            inverse = inverse[perm]
        self["serialized_code"] = code
        self["serialized_order"] = serialized_order
        self["serialized_inverse"] = inverse

    def sparsify(self, pad=96):
        """
        Point Cloud Sparsification

        Point clouds are sparse; "sparsify" here specifically means preparing
        a "spconv.SparseConvTensor" for SpConv.

        Relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"].

        Args:
            pad: padding added to the maximum grid coordinate to form the
                sparse shape (only used when "sparse_shape" is not set).
        """
        assert {"feat", "batch"}.issubset(self.keys())
        self._ensure_grid_coord()
        if "sparse_shape" in self.keys():
            sparse_shape = self.sparse_shape
        else:
            sparse_shape = torch.add(
                torch.max(self.grid_coord, dim=0).values, pad
            ).tolist()
        sparse_conv_feat = spconv.SparseConvTensor(
            features=self.feat,
            indices=torch.cat(
                [self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1
            ).contiguous(),
            spatial_shape=sparse_shape,
            batch_size=self.batch[-1].tolist() + 1,
        )
        self["sparse_shape"] = sparse_shape
        self["sparse_conv_feat"] = sparse_conv_feat
class PointSequential(PointModule):
    r"""A sequential container that understands ``Point`` objects.

    Modules are added in the order they are passed to the constructor;
    alternatively, an ordered dict of modules can be passed. In ``forward``:
    - ``PointModule`` instances receive the whole Point;
    - spconv modules run on ``sparse_conv_feat``;
    - other PyTorch modules run on ``feat``.
    After a module updates ``feat``, it is mirrored into ``sparse_conv_feat``
    (when present).
    """

    def __init__(self, name="", *args, **kwargs):
        # NOTE: `name` is the first positional parameter, so a call such as
        # `PointSequential(module_a, module_b)` binds module_a to `name`.
        # Because nn.Module.__setattr__ registers Module values as submodules,
        # module_a is still registered (under the key "name", ahead of the
        # positional modules) and runs first in forward. Parameter names in the
        # state_dict depend on this, so the signature is intentionally unchanged.
        super().__init__()
        self.name = name
        if len(args) == 1 and isinstance(args[0], collections.OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)
        for key, module in kwargs.items():
            if key in self._modules:
                raise ValueError("name exists.")
            self.add_module(key, module)

    def __getitem__(self, idx):
        if not (-len(self) <= idx < len(self)):
            raise IndexError("index {} is out of range".format(idx))
        if idx < 0:
            idx += len(self)
        it = iter(self._modules.values())
        for _ in range(idx):
            next(it)
        return next(it)

    def __len__(self):
        return len(self._modules)

    def add(self, module, name=None):
        """Append ``module``, named by its position unless ``name`` is given."""
        if name is None:
            name = str(len(self._modules))
        if name in self._modules:
            raise KeyError("name exists")
        self.add_module(name, module)

    @staticmethod
    def _set_point_feat(point, feat):
        """Store new features on a Point and mirror them into its SparseConvTensor."""
        point.feat = feat
        if "sparse_conv_feat" in point.keys():
            point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(feat)

    def forward(self, x):
        for module in self._modules.values():
            # Point module: receives the whole Point.
            if isinstance(module, PointModule):
                x = module(x)
            # Spconv module: operates on the SparseConvTensor.
            elif spconv.modules.is_spconv_module(module):
                if isinstance(x, Point):
                    x.sparse_conv_feat = module(x.sparse_conv_feat)
                    x.feat = x.sparse_conv_feat.features
                else:
                    x = module(x)
            # BatchNorm1d on a Point: skipped for a single point to avoid
            # "Expected more than 1 value per channel when training". Otherwise
            # it is treated like any feature module, including keeping
            # sparse_conv_feat in sync with feat.
            elif isinstance(module, torch.nn.BatchNorm1d) and isinstance(x, Point):
                if x.feat.size(0) != 1:
                    self._set_point_feat(x, module(x.feat))
            # Plain PyTorch module: applied to features only.
            else:
                if isinstance(x, Point):
                    self._set_point_feat(x, module(x.feat))
                elif isinstance(x, spconv.SparseConvTensor):
                    if x.indices.shape[0] != 0:
                        x = x.replace_feature(module(x.features))
                else:
                    x = module(x)
        return x
class PDNorm(PointModule):
    """Condition-dependent normalization of Point features.

    With ``decouple=True``, one norm layer is built per entry of
    ``conditions``, and the layer matching ``point.condition`` is applied.
    With ``adaptive=True``, the normalized features are additionally modulated
    by a shift/scale predicted from ``point.context``.

    Args:
        num_features: feature size passed to ``norm_layer``.
        norm_layer: factory called as ``norm_layer(num_features)``. For
            ``decouple=False``, an already-built ``torch.nn.Module`` is also
            accepted and used as-is.
        context_channels: input size of the modulation Linear.
        conditions: names of the supported conditions.
        decouple: use a separate norm per condition.
        adaptive: enable context-based shift/scale modulation.
    """

    def __init__(
        self,
        num_features,
        norm_layer,
        context_channels=256,
        conditions=("ScanNet", "S3DIS", "Structured3D"),
        decouple=True,
        adaptive=False,
    ):
        super().__init__()
        self.conditions = conditions
        self.decouple = decouple
        self.adaptive = adaptive
        if self.decouple:
            self.norm = torch.nn.ModuleList(
                [norm_layer(num_features) for _ in conditions]
            )
        elif isinstance(norm_layer, torch.nn.Module):
            # Backward compatibility: an already-constructed norm is used directly.
            self.norm = norm_layer
        else:
            # Instantiate the factory, as the decoupled branch does; storing the
            # factory itself would make forward call it on the feature tensor.
            self.norm = norm_layer(num_features)
        if self.adaptive:
            self.modulation = torch.nn.Sequential(
                torch.nn.SiLU(),
                torch.nn.Linear(context_channels, 2 * num_features, bias=True),
            )

    def forward(self, point):
        assert {"feat", "condition"}.issubset(point.keys())
        # `condition` is either a single string or a sequence; only the first
        # entry of a sequence is used for the whole batch.
        if isinstance(point.condition, str):
            condition = point.condition
        else:
            condition = point.condition[0]
        if self.decouple:
            assert condition in self.conditions
            norm = self.norm[self.conditions.index(condition)]
        else:
            norm = self.norm
        point.feat = norm(point.feat)
        if self.adaptive:
            assert "context" in point.keys()
            shift, scale = self.modulation(point.context).chunk(2, dim=1)
            point.feat = point.feat * (1.0 + scale) + shift
        return point
class RPE(torch.nn.Module):
    """Learnable relative positional bias over clamped integer offsets.

    ``rpe_table`` holds three consecutive blocks of ``rpe_num`` rows (one
    block per axis) with one column per head. The bias for an offset vector
    is the sum of its three per-axis rows.
    """

    def __init__(self, patch_size, num_heads):
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads
        # Offsets are clamped to [-pos_bnd, pos_bnd] on each axis.
        self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
        self.rpe_num = 2 * self.pos_bnd + 1
        table = torch.zeros(3 * self.rpe_num, num_heads)
        self.rpe_table = torch.nn.Parameter(table)
        torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)

    def forward(self, coord):
        """Map integer offsets shaped (N, K, K, 3) to a bias shaped (N, H, K, K)."""
        bound = self.pos_bnd
        # Row offset of each axis' block inside the table.
        axis_stride = torch.arange(3, device=coord.device) * self.rpe_num
        idx = coord.clamp(-bound, bound) + bound + axis_stride
        gathered = self.rpe_table.index_select(0, idx.reshape(-1))
        summed = gathered.view(*idx.shape, -1).sum(dim=3)
        return summed.permute(0, 3, 1, 2)
class SerializedAttention(PointModule):
    """Multi-head self-attention within fixed-size patches of serialized points.

    Points are gathered in the order given by
    ``point.serialized_order[order_index]`` and padded per cloud (only when a
    cloud has more points than the patch size) up to a multiple of the patch
    size. They are then attended patch by patch, either with flash-attn's
    varlen kernel or an explicit softmax path, scattered back to the original
    point order with the inverse mapping, and projected by ``proj``.
    """

    def __init__(
        self,
        channels,
        num_heads,
        patch_size,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        order_index=0,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    ):
        super().__init__()
        assert channels % num_heads == 0
        self.channels = channels
        self.num_heads = num_heads
        # Defaults to 1/sqrt(head_dim); any falsy qk_scale (None or 0) selects it.
        self.scale = qk_scale or (channels // num_heads) ** -0.5
        # Which row of point.serialized_order / serialized_inverse this layer uses.
        self.order_index = order_index
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.enable_rpe = enable_rpe
        self.enable_flash = enable_flash
        if enable_flash:
            # The flash-attn path supports neither RPE nor upcasting.
            assert (
                enable_rpe is False
            ), "Set enable_rpe to False when enable Flash Attention"
            assert (
                upcast_attention is False
            ), "Set upcast_attention to False when enable Flash Attention"
            assert (
                upcast_softmax is False
            ), "Set upcast_softmax to False when enable Flash Attention"
            # NOTE: flash_attn is imported unconditionally at the top of this
            # file, so a missing package already fails at import time.
            assert flash_attn is not None, "Make sure flash_attn is installed."
            self.patch_size = patch_size
            # Kept as a float probability: forward passes it as dropout_p.
            self.attn_drop = attn_drop
        else:
            # when disable flash attention, we still don't want to use mask
            # consequently, patch size will auto set to the
            # min number of patch_size_max and number of points
            self.patch_size_max = patch_size
            self.patch_size = 0  # recomputed at the start of every forward
            self.attn_drop = torch.nn.Dropout(attn_drop)
        self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
        self.proj = torch.nn.Linear(channels, channels)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None

    def get_rel_pos(self, point, order):
        """Return pairwise grid-coordinate offsets inside each patch.

        ``order`` is the padded gather index built in forward. The result has
        shape (num_patches, K, K, 3). It is cached on ``point`` under
        ``f"rel_pos_{order_index}"`` and reused while that key exists.
        """
        K = self.patch_size
        rel_pos_key = f"rel_pos_{self.order_index}"
        if rel_pos_key not in point.keys():
            grid_coord = point.grid_coord[order]
            grid_coord = grid_coord.reshape(-1, K, 3)
            point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)
        return point[rel_pos_key]

    def get_padding_and_inverse(self, point):
        """Build (and cache on ``point``) the patch padding bookkeeping.

        Returns:
            pad: for every slot of the padded sequence, a position in the
                unpadded sequence. A cloud's trailing pad slots reuse the
                positions exactly one patch earlier.
            unpad: for every unpadded position, its slot in the padded sequence.
            cu_seqlens: int32 start slot of every patch, followed by the total
                padded length. forward passes it as ``cu_seqlens`` to flash-attn.

        The tensors are cached under "pad", "unpad" and "cu_seqlens_key" and are
        recomputed only if any of them is missing.
        """
        pad_key = "pad"
        unpad_key = "unpad"
        cu_seqlens_key = "cu_seqlens_key"
        if (
            pad_key not in point.keys()
            or unpad_key not in point.keys()
            or cu_seqlens_key not in point.keys()
        ):
            offset = point.offset
            bincount = offset2bincount(offset)
            # Per-cloud point count rounded up to a multiple of patch_size.
            bincount_pad = (
                torch.div(
                    bincount + self.patch_size - 1,
                    self.patch_size,
                    rounding_mode="trunc",
                )
                * self.patch_size
            )
            # only pad point when num of points larger than patch_size
            mask_pad = bincount > self.patch_size
            bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
            # Exclusive prefix sums: start of each cloud, unpadded and padded.
            _offset = torch.nn.functional.pad(offset, (1, 0))
            _offset_pad = torch.nn.functional.pad(
                torch.cumsum(bincount_pad, dim=0), (1, 0)
            )
            pad = torch.arange(_offset_pad[-1], device=offset.device)
            unpad = torch.arange(_offset[-1], device=offset.device)
            cu_seqlens = []
            for i in range(len(offset)):
                # Shift cloud i's unpadded positions to its padded start.
                unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i]
                if bincount[i] != bincount_pad[i]:
                    # Fill the tail slots that have no real point with the
                    # indices from the same slots one patch earlier.
                    pad[
                        _offset_pad[i + 1]
                        - self.patch_size
                        + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                    ] = pad[
                        _offset_pad[i + 1]
                        - 2 * self.patch_size
                        + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                        - self.patch_size
                    ]
                # Convert padded-space indices back to unpadded positions.
                pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]
                # Start slot of every patch of cloud i.
                cu_seqlens.append(
                    torch.arange(
                        _offset_pad[i],
                        _offset_pad[i + 1],
                        step=self.patch_size,
                        dtype=torch.int32,
                        device=offset.device,
                    )
                )
            point[pad_key] = pad
            point[unpad_key] = unpad
            # Append the total padded length as the final boundary.
            point[cu_seqlens_key] = torch.nn.functional.pad(
                torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1]
            )
        return point[pad_key], point[unpad_key], point[cu_seqlens_key]

    def forward(self, point):
        if not self.enable_flash:
            # K is the smallest cloud size (capped at patch_size_max). Every
            # cloud's padded length is therefore a multiple of K, which the
            # reshape into (N', K, ...) below relies on.
            self.patch_size = min(
                offset2bincount(point.offset).min().tolist(), self.patch_size_max
            )

        H = self.num_heads
        K = self.patch_size
        C = self.channels

        pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)

        # order: for each padded slot, the point it reads (in serialized order).
        order = point.serialized_order[self.order_index][pad]
        # inverse: for each point, its slot in the padded serialized sequence.
        inverse = unpad[point.serialized_inverse[self.order_index]]

        # padding and reshape feat and batch for serialized point patch
        qkv = self.qkv(point.feat)[order]

        if not self.enable_flash:
            # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
            q, k, v = (
                qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
            )
            # attn
            if self.upcast_attention:
                q = q.float()
                k = k.float()
            attn = (q * self.scale) @ k.transpose(-2, -1)  # (N', H, K, K)
            if self.enable_rpe:
                attn = attn + self.rpe(self.get_rel_pos(point, order))
            if self.upcast_softmax:
                attn = attn.float()
            attn = self.softmax(attn)
            attn = self.attn_drop(attn).to(qkv.dtype)
            feat = (attn @ v).transpose(1, 2).reshape(-1, C)
        else:
            # Variable-length packed attention, one sequence per patch as
            # delimited by cu_seqlens; inputs are cast to fp16 for the kernel.
            feat = flash_attn.flash_attn_varlen_qkvpacked_func(
                qkv.half().reshape(-1, 3, H, C // H),
                cu_seqlens,
                max_seqlen=self.patch_size,
                dropout_p=self.attn_drop if self.training else 0,
                softmax_scale=self.scale,
            ).reshape(-1, C)
            feat = feat.to(qkv.dtype)
        # Back to the original point order (drops the padded duplicates).
        feat = feat[inverse]

        # ffn
        feat = self.proj(feat)
        feat = self.proj_drop(feat)
        point.feat = feat
        return point
class MLP(torch.nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear.

    ``hidden_channels`` and ``out_channels`` default to ``in_channels``.
    ``drop`` is accepted for interface compatibility but is not used: no
    dropout layer is created or applied.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_channels=None,
        act_layer=torch.nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        hidden = hidden_channels or in_channels
        out = out_channels or in_channels
        self.fc1 = torch.nn.Linear(in_channels, hidden)
        self.act = act_layer()
        self.fc2 = torch.nn.Linear(hidden, out)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))
class Block(PointModule):
    """Transformer block operating on a ``Point``.

    Three residual stages are applied in sequence:
    1. ``cpe``: submanifold sparse conv -> Linear -> norm, added to the input;
    2. serialized attention;
    3. MLP.
    With ``pre_norm=True``, ``norm1`` / ``norm2`` run before stages 2 and 3;
    otherwise they run after the residual addition. ``sparse_conv_feat`` is
    synced with the final features before returning.
    """

    def __init__(
        self,
        channels,
        num_heads,
        patch_size=48,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        norm_layer=torch.nn.LayerNorm,
        act_layer=torch.nn.GELU,
        pre_norm=True,
        order_index=0,
        cpe_indice_key=None,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    ):
        super().__init__()
        self.channels = channels
        self.pre_norm = pre_norm
        # Submodules are created in the original order so that random
        # initialisation consumes the RNG identically. The modules are passed
        # positionally to PointSequential (whose first positional parameter is
        # `name`), which keeps the existing parameter names unchanged.
        self.cpe = PointSequential(
            spconv.SubMConv3d(
                channels,
                channels,
                kernel_size=3,
                bias=True,
                indice_key=cpe_indice_key,
            ),
            torch.nn.Linear(channels, channels),
            norm_layer(channels),
        )
        self.norm1 = PointSequential(norm_layer(channels))
        self.attn = SerializedAttention(
            channels=channels,
            patch_size=patch_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            order_index=order_index,
            enable_rpe=enable_rpe,
            enable_flash=enable_flash,
            upcast_attention=upcast_attention,
            upcast_softmax=upcast_softmax,
        )
        self.norm2 = PointSequential(norm_layer(channels))
        self.mlp = PointSequential(
            MLP(
                in_channels=channels,
                hidden_channels=int(channels * mlp_ratio),
                out_channels=channels,
                act_layer=act_layer,
                drop=proj_drop,
            )
        )
        self.drop_path = PointSequential(
            DropPath(drop_path) if drop_path > 0.0 else torch.nn.Identity()
        )

    def _residual(self, point, norm, branch):
        """Apply ``branch`` with stochastic depth inside a residual connection.

        ``norm`` runs before the branch when pre_norm is set, else after the add.
        """
        skip = point.feat
        if self.pre_norm:
            point = norm(point)
        point = self.drop_path(branch(point))
        point.feat = skip + point.feat
        if not self.pre_norm:
            point = norm(point)
        return point

    def forward(self, point: Point):
        skip = point.feat
        point = self.cpe(point)
        point.feat = skip + point.feat
        point = self._residual(point, self.norm1, self.attn)
        point = self._residual(point, self.norm2, self.mlp)
        point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
        return point
class DropPath(torch.nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob: probability of zeroing a sample's path. ``None`` (the
            default) and ``0.0`` both disable dropping.
        scale_by_keep: divide kept samples by the keep probability so the
            expected value is unchanged.
    """

    def __init__(self, drop_prob=None, scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def _drop_path(
        self,
        x,
        drop_prob: float = 0.0,
        training: bool = False,
        scale_by_keep: bool = True,
    ):
        """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

        This is the same as the DropConnect impl created for EfficientNet, etc. networks; however,
        the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
        See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... The layer and
        argument names are 'drop path' rather than mixing DropConnect as a layer name with
        'survival rate' as the argument.
        """
        # `not drop_prob` also covers None, the constructor default; previously
        # None reached `1 - drop_prob` in training mode and raised TypeError.
        if not drop_prob or not training:
            return x
        keep_prob = 1 - drop_prob
        # One Bernoulli draw per sample, broadcast over the remaining dims
        # (works with tensors of any rank, not just 2D ConvNets).
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor

    def forward(self, x):
        return self._drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
class SerializedPooling(PointModule):
    """Downsample a serialized point cloud by merging points with the same coarsened code.

    The serialization codes are shifted right by ``3 * log2(stride)`` bits. Points
    whose coarsened code under the FIRST serialization order coincides form one
    cluster. Per cluster:

    - features are linearly projected to ``out_channels`` and reduced with ``reduce``;
    - coordinates are averaged;
    - grid_coord (shifted right by ``log2(stride)``), codes of every order and
      batch index are taken from one representative member.

    New serialization order/inverse permutations are recomputed for the pooled
    points. Optional norm and activation layers are then applied to the result.

    Args:
        in_channels: input feature width.
        out_channels: feature width after the linear projection.
        stride: pooling stride; must be a power of two (2, 4, 8, ...).
        norm_layer: optional factory called as ``norm_layer(out_channels)``.
        act_layer: optional factory called as ``act_layer()``.
        reduce: feature reduction, one of "sum", "mean", "min", "max".
        shuffle_orders: if True, randomly permute the serialization orders of
            the pooled point.
        traceable: if True, store ``pooling_inverse`` (cluster index of every
            input point) and ``pooling_parent`` (the input point) on the output,
            as consumed by SerializedUnpooling.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=2,
        norm_layer=None,
        act_layer=None,
        reduce="max",
        shuffle_orders=True,
        traceable=True,  # record parent and cluster
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Only power-of-two strides pass this check (2, 4, 8, ...).
        assert stride == 2 ** (math.ceil(stride) - 1).bit_length()  # 2, 4, 8
        # TODO: add support to grid pool (any stride)
        self.stride = stride
        assert reduce in ["sum", "mean", "min", "max"]
        self.reduce = reduce
        self.shuffle_orders = shuffle_orders
        self.traceable = traceable
        self.proj = torch.nn.Linear(in_channels, out_channels)
        # forward() compares these with None, so they must always exist.
        # Previously they were only assigned when a factory was given, and
        # forward() raised AttributeError for norm_layer=None / act_layer=None.
        self.norm = None
        self.act = None
        if norm_layer is not None:
            self.norm = PointSequential(norm_layer(out_channels))
        if act_layer is not None:
            self.act = PointSequential(act_layer())

    def forward(self, point: Point):
        # Validate first so a missing key yields this message rather than an
        # unrelated error from reading point.serialized_depth below.
        assert {
            "serialized_code",
            "serialized_order",
            "serialized_inverse",
            "serialized_depth",
        }.issubset(
            point.keys()
        ), "Run point.serialization() point cloud before SerializedPooling"
        # log2(stride) levels of coarsening. Fall back to 0 (merge only
        # identical codes) when the serialization has fewer levels left.
        pooling_depth = (math.ceil(self.stride) - 1).bit_length()
        if pooling_depth > point.serialized_depth:
            pooling_depth = 0
        # code is indexed as (order, point) below; 3 code bits per level.
        code = point.serialized_code >> pooling_depth * 3
        # Clusters are defined by the first serialization order only.
        _, cluster, counts = torch.unique(
            code[0],
            sorted=True,
            return_inverse=True,
            return_counts=True,
        )
        # indices of point sorted by cluster, for torch_scatter.segment_csr
        _, indices = torch.sort(cluster)
        # index pointer for sorted point, for torch_scatter.segment_csr
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        # head_indices of each cluster, for reduce attr e.g. code, batch
        head_indices = indices[idx_ptr[:-1]]
        # generate down code, order, inverse
        code = code[:, head_indices]
        order = torch.argsort(code)
        # inverse[k, order[k, j]] = j, i.e. the inverse permutation of each row.
        inverse = torch.zeros_like(order).scatter_(
            dim=1,
            index=order,
            src=torch.arange(0, code.shape[1], device=order.device).repeat(
                code.shape[0], 1
            ),
        )
        if self.shuffle_orders:
            perm = torch.randperm(code.shape[0])
            code = code[perm]
            order = order[perm]
            inverse = inverse[perm]
        # collect information
        point_dict = addict.Dict(
            feat=torch_scatter.segment_csr(
                self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
            ),
            coord=torch_scatter.segment_csr(
                point.coord[indices], idx_ptr, reduce="mean"
            ),
            grid_coord=point.grid_coord[head_indices] >> pooling_depth,
            serialized_code=code,
            serialized_order=order,
            serialized_inverse=inverse,
            serialized_depth=point.serialized_depth - pooling_depth,
            batch=point.batch[head_indices],
        )
        if "condition" in point.keys():
            point_dict["condition"] = point.condition
        if "context" in point.keys():
            point_dict["context"] = point.context
        if self.traceable:
            point_dict["pooling_inverse"] = cluster
            point_dict["pooling_parent"] = point
        point = Point(point_dict)
        # Skip the norm when only one point survives. Batch-norm style layers
        # fail in training with "Expected more than 1 value per channel".
        if self.norm is not None and point.feat.size(0) != 1:
            point = self.norm(point)
        if self.act is not None:
            point = self.act(point)
        point.sparsify()
        return point
class SerializedUnpooling(PointModule):
    """Map pooled features back onto the parent point cloud (the decoder "up" step).

    The incoming point must carry two keys, as recorded by SerializedPooling
    with ``traceable=True``:

    - ``pooling_parent``: the point before pooling;
    - ``pooling_inverse``: the cluster index of every parent point.

    The pooled point and its parent are each linearly projected to
    ``out_channels``, optionally followed by a norm and an activation. Every
    parent point then adds the projected feature of its cluster to its own.

    Args:
        in_channels: feature width of the pooled (coarse) point.
        skip_channels: feature width of the parent (skip) point.
        out_channels: common output width of both branches.
        norm_layer: optional factory ``norm_layer(out_channels)`` appended to both branches.
        act_layer: optional factory ``act_layer()`` appended to both branches.
        traceable: if True, keep the projected coarse point on the returned
            parent under ``unpooling_parent``.
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        norm_layer=None,
        act_layer=None,
        traceable=False,  # record parent and cluster
    ):
        super().__init__()
        self.proj = PointSequential(torch.nn.Linear(in_channels, out_channels))
        self.proj_skip = PointSequential(torch.nn.Linear(skip_channels, out_channels))
        branches = (self.proj, self.proj_skip)
        # All norms are created before any activation, coarse branch first.
        if norm_layer is not None:
            for branch in branches:
                branch.add(norm_layer(out_channels))
        if act_layer is not None:
            for branch in branches:
                branch.add(act_layer())
        self.traceable = traceable

    def forward(self, point):
        for required_key in ("pooling_parent", "pooling_inverse"):
            assert required_key in point.keys()
        parent = point.pop("pooling_parent")
        cluster_of = point.pop("pooling_inverse")
        coarse = self.proj(point)
        parent = self.proj_skip(parent)
        # Gather each parent point's cluster feature and fuse it residually.
        parent.feat = parent.feat + coarse.feat[cluster_of]
        if self.traceable:
            parent["unpooling_parent"] = coarse
        return parent
class Embedding(PointModule):
    """Stem that lifts raw per-point features to ``embed_channels``.

    The stem is a bias-free ``spconv.SubMConv3d`` with ``kernel_size=5``,
    registered as ``conv``. Optionally it is followed by a norm layer
    (registered as ``norm``) and an activation (registered as ``act``).

    Args:
        in_channels: width of the input features.
        embed_channels: width of the embedded features.
        norm_layer: optional factory called as ``norm_layer(embed_channels)``.
        act_layer: optional factory called as ``act_layer()``.
    """

    def __init__(
        self,
        in_channels,
        embed_channels,
        norm_layer=None,
        act_layer=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.embed_channels = embed_channels
        # TODO: check remove spconv
        stem_conv = spconv.SubMConv3d(
            in_channels,
            embed_channels,
            kernel_size=5,
            padding=1,
            bias=False,
            indice_key="stem",
        )
        self.stem = PointSequential(conv=stem_conv)
        if norm_layer is not None:
            self.stem.add(norm_layer(embed_channels), name="norm")
        if act_layer is not None:
            self.stem.add(act_layer(), name="act")

    def forward(self, point: Point):
        return self.stem(point)
class PointTransformerV3(PointModule):
    """Point Transformer V3 backbone.

    Structure:

    - ``embedding``: a sparse-conv Embedding stem.
    - ``enc``: ``len(enc_depths)`` encoder stages. Every stage after the first
      opens with a SerializedPooling layer (``down``), followed by
      ``enc_depths[s]`` attention Blocks.
    - ``dec``: unless ``cls_mode``, decoder stages visited from the deepest
      skip level back to stage 0. Each opens with a SerializedUnpooling layer
      (``up``), followed by ``dec_depths[s]`` Blocks.

    Stochastic-depth rates rise linearly from 0 to ``drop_path`` over the
    encoder blocks. Over the decoder blocks, in execution order, they fall
    from ``drop_path`` back to 0.

    Args:
        in_channels: width of the input ``feat``.
        order: serialization order name(s) handed to ``Point.serialization``.
            A single string is wrapped in a list. Blocks cycle through the
            orders via ``order_index = i % len(order)``.
        stride: pooling stride of each of the ``len(enc_depths) - 1`` downsampling steps.
        enc_depths, enc_channels, enc_num_head, enc_patch_size: per-stage
            encoder settings (one entry per stage).
        dec_depths, dec_channels, dec_num_head, dec_patch_size: per-stage
            decoder settings (one entry fewer than the encoder). Unused when
            ``cls_mode`` is True.
        mlp_ratio, qkv_bias, qk_scale, attn_drop, proj_drop, pre_norm,
            enable_rpe, enable_flash, upcast_attention, upcast_softmax:
            forwarded unchanged to every Block.
        grid_size: stored and passed to Point in ``forward``.
        drop_path: maximum stochastic-depth rate.
        shuffle_orders: passed to ``Point.serialization`` in ``forward``.
        cls_mode: if True, build no decoder; ``forward`` returns encoder output.
        pdnorm_bn, pdnorm_ln: wrap the BatchNorm1d / LayerNorm factories in PDNorm.
        pdnorm_decouple, pdnorm_adaptive, pdnorm_conditions: forwarded to PDNorm.
        pdnorm_affine: ``affine`` / ``elementwise_affine`` of the PDNorm-wrapped norms.
    """

    def __init__(
        self,
        in_channels=6,
        order=("cord"),  # NOTE: bare parentheses, so this is the str "cord", not a tuple
        stride=(2, 2, 2, 2),
        enc_depths=(2, 2, 2, 6, 2),
        enc_channels=(32, 64, 128, 256, 512),
        enc_num_head=(2, 4, 8, 16, 32),
        enc_patch_size=(1024, 1024, 1024, 1024, 1024),
        dec_depths=(2, 2, 2, 2),
        dec_channels=(64, 64, 128, 256),
        dec_num_head=(4, 4, 8, 16),
        dec_patch_size=(1024, 1024, 1024, 1024),
        mlp_ratio=4,
        grid_size=0.01,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.3,
        pre_norm=True,
        shuffle_orders=True,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=False,
        upcast_softmax=False,
        cls_mode=False,
        pdnorm_bn=False,
        pdnorm_ln=False,
        pdnorm_decouple=True,
        pdnorm_adaptive=False,
        pdnorm_affine=True,
        pdnorm_conditions=("ScanNet", "S3DIS", "Structured3D"),
    ):
        super().__init__()
        self.num_stages = len(enc_depths)
        self.order = [order] if isinstance(order, str) else order
        self.cls_mode = cls_mode
        self.shuffle_orders = shuffle_orders
        self.grid_size = grid_size

        # Encoder settings need one entry per stage. The stride and (unless
        # cls_mode) the decoder settings need one entry per stage transition.
        assert self.num_stages == len(stride) + 1
        for enc_cfg in (enc_depths, enc_channels, enc_num_head, enc_patch_size):
            assert self.num_stages == len(enc_cfg)
        for dec_cfg in (dec_depths, dec_channels, dec_num_head, dec_patch_size):
            assert self.cls_mode or self.num_stages == len(dec_cfg) + 1

        # Normalization factories, optionally wrapped in PDNorm.
        pdnorm_options = {
            "conditions": pdnorm_conditions,
            "decouple": pdnorm_decouple,
            "adaptive": pdnorm_adaptive,
        }
        if pdnorm_bn:
            bn_layer = functools.partial(
                PDNorm,
                norm_layer=functools.partial(
                    torch.nn.BatchNorm1d, eps=1e-3, momentum=0.01, affine=pdnorm_affine
                ),
                **pdnorm_options,
            )
        else:
            bn_layer = functools.partial(torch.nn.BatchNorm1d, eps=1e-3, momentum=0.01)
        if pdnorm_ln:
            ln_layer = functools.partial(
                PDNorm,
                norm_layer=functools.partial(
                    torch.nn.LayerNorm, elementwise_affine=pdnorm_affine
                ),
                **pdnorm_options,
            )
        else:
            ln_layer = torch.nn.LayerNorm
        act_layer = torch.nn.GELU

        self.embedding = Embedding(
            in_channels=in_channels,
            embed_channels=enc_channels[0],
            norm_layer=bn_layer,
            act_layer=act_layer,
        )

        # Block arguments that are identical for every encoder/decoder block.
        block_options = dict(
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=ln_layer,
            act_layer=act_layer,
            pre_norm=pre_norm,
            enable_rpe=enable_rpe,
            enable_flash=enable_flash,
            upcast_attention=upcast_attention,
            upcast_softmax=upcast_softmax,
        )

        # Encoder.
        # NOTE(review): shuffle_orders is not forwarded to SerializedPooling,
        # so pooling always uses that class's own default; confirm intended.
        enc_rates = torch.linspace(0, drop_path, sum(enc_depths)).tolist()
        self.enc = PointSequential(name="encoder")
        start = 0
        for s, depth in enumerate(enc_depths):
            stage_rates = enc_rates[start : start + depth]
            start += depth
            stage = PointSequential(name="encoder_layer_%d" % s)
            if s > 0:
                stage.add(
                    SerializedPooling(
                        in_channels=enc_channels[s - 1],
                        out_channels=enc_channels[s],
                        stride=stride[s - 1],
                        norm_layer=bn_layer,
                        act_layer=act_layer,
                    ),
                    name="down",
                )
            for i, rate in enumerate(stage_rates):
                stage.add(
                    Block(
                        channels=enc_channels[s],
                        num_heads=enc_num_head[s],
                        patch_size=enc_patch_size[s],
                        drop_path=rate,
                        order_index=i % len(self.order),
                        cpe_indice_key=f"stage{s}",
                        **block_options,
                    ),
                    name=f"block{i}",
                )
            if len(stage) != 0:
                self.enc.add(module=stage, name=f"enc{s}")

        # Decoder.
        if not self.cls_mode:
            dec_rates = torch.linspace(0, drop_path, sum(dec_depths)).tolist()
            self.dec = PointSequential(name="decoder")
            # Append the bottleneck width so dec_channels[s + 1] is always the
            # width entering decoder stage s.
            dec_channels = list(dec_channels) + [enc_channels[-1]]
            for s in reversed(range(self.num_stages - 1)):
                lo = sum(dec_depths[:s])
                hi = sum(dec_depths[: s + 1])
                stage_rates = dec_rates[lo:hi][::-1]
                stage = PointSequential(name="decoder_layer_%d" % s)
                stage.add(
                    SerializedUnpooling(
                        in_channels=dec_channels[s + 1],
                        skip_channels=enc_channels[s],
                        out_channels=dec_channels[s],
                        norm_layer=bn_layer,
                        act_layer=act_layer,
                    ),
                    name="up",
                )
                for i, rate in enumerate(stage_rates):
                    stage.add(
                        Block(
                            channels=dec_channels[s],
                            num_heads=dec_num_head[s],
                            patch_size=dec_patch_size[s],
                            drop_path=rate,
                            order_index=i % len(self.order),
                            cpe_indice_key=f"stage{s}",
                            **block_options,
                        ),
                        name=f"block{i}",
                    )
                self.dec.add(module=stage, name=f"dec{s}")

    def forward(self, batch, feat, coord):
        """Run the backbone on one batched point cloud.

        Args:
            batch: per-point batch index. A leading dim of size 1 is squeezed away.
            feat: per-point input features, squeezed at dim 0 the same way.
            coord: per-point coordinates, squeezed at dim 0 the same way.

        The model's ``grid_size`` is passed to Point alongside these
        (presumably so grid coordinates can be derived from ``coord``).

        Returns:
            Features of the last layer run (decoder stage 0, or the final encoder
            stage in ``cls_mode``), with a leading dim of size 1 added back.
        """
        inputs = {
            "batch": batch.squeeze(dim=0),
            "feat": feat.squeeze(dim=0),
            "coord": coord.squeeze(dim=0),
            "grid_size": self.grid_size,
        }
        point = Point(inputs)
        point.serialization(order=self.order, shuffle_orders=self.shuffle_orders)
        point.sparsify()
        point = self.enc(self.embedding(point))
        if not self.cls_mode:
            point = self.dec(point)
        return point.feat.unsqueeze(dim=0)