hzxie's picture
fix: reinitialize the repo.
83d5461 verified
Raw
History Blame
49 kB
# -*- coding: utf-8 -*-
#
# @File: pt_v3.py
# @Author: Xiaoyang Wu <xiaoyang.wu.cs@gmail.com>
# @Date: 2024-04-01 16:31:36
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2024-05-15 22:05:09
# @Email: root@haozhexie.com
# Ref:
# - https://github.com/Pointcept/PointTransformerV3/blob/main/model.py
# - https://huggingface.co/spaces/Roll20/pet_score/blame/main/lib/timm/models/layers/drop.py
import addict
import collections
import functools
import flash_attn
import math
import torch
import spconv.pytorch as spconv
import torch_scatter
import typing
@torch.inference_mode()
def offset2bincount(offset):
    """Convert cumulative batch offsets into per-sample point counts.

    The result is the first difference of ``offset`` with an implicit leading
    zero, e.g. ``[2, 5, 6] -> [2, 3, 1]``.
    """
    leading_zero = torch.tensor([0], device=offset.device, dtype=torch.long)
    counts = torch.diff(offset, prepend=leading_zero)
    return counts
@torch.inference_mode()
def offset2batch(offset):
    """Expand cumulative batch offsets into a per-point sample index.

    Each sample id ``i`` is repeated once per point of sample ``i``,
    e.g. ``[2, 5] -> [0, 0, 1, 1, 1]``.
    """
    counts = offset2bincount(offset)
    sample_ids = torch.arange(len(counts), device=offset.device, dtype=torch.long)
    return sample_ids.repeat_interleave(counts)
@torch.inference_mode()
def batch2offset(batch):
    """Convert a per-point sample index into cumulative batch offsets.

    Counts the points of every sample id and returns the running total,
    e.g. ``[0, 0, 1, 1, 1] -> [2, 5]``.
    """
    per_sample = batch.bincount()
    return per_sample.cumsum(dim=0).long()
class KeyLUT:
    """Lookup tables for 3-D Morton (z-order) bit interleaving.

    The encode tables map an 8-bit coordinate value to its interleaved key bits
    along x, y and z respectively; the decode table maps a 9-bit key back to its
    (x, y, z) components. Tables are built on the CPU and copied to any other
    device the first time they are requested there.
    """

    def __init__(self):
        cpu = torch.device("cpu")
        values = torch.arange(256, dtype=torch.int64)
        keys = torch.arange(512, dtype=torch.int64)
        zeros = torch.zeros(256, dtype=torch.int64)
        encode_x = self.xyz2key(values, zeros, zeros, 8)
        encode_y = self.xyz2key(zeros, values, zeros, 8)
        encode_z = self.xyz2key(zeros, zeros, values, 8)
        self._encode = {cpu: (encode_x, encode_y, encode_z)}
        self._decode = {cpu: self.key2xyz(keys, 9)}

    @staticmethod
    def _fetch(cache, device):
        """Return the tables stored in ``cache`` for ``device``, copying from CPU once."""
        if device not in cache:
            source = cache[torch.device("cpu")]
            cache[device] = tuple(table.to(device) for table in source)
        return cache[device]

    def encode_lut(self, device=torch.device("cpu")):
        """Return the (x, y, z) encode tables on ``device``."""
        return self._fetch(self._encode, device)

    def decode_lut(self, device=torch.device("cpu")):
        """Return the (x, y, z) decode tables on ``device``."""
        return self._fetch(self._decode, device)

    def xyz2key(self, x, y, z, depth):
        """Interleave the low ``depth`` bits of x, y, z (x highest) into one key."""
        key = torch.zeros_like(x)
        for level in range(depth):
            bit = 1 << level
            # bit `level` of x/y/z lands at key bit 3*level+2 / +1 / +0
            key |= (x & bit) << (2 * level + 2)
            key |= (y & bit) << (2 * level + 1)
            key |= (z & bit) << (2 * level)
        return key

    def key2xyz(self, key, depth):
        """Inverse of ``xyz2key``: split ``depth`` interleaved levels of ``key``."""
        x, y, z = (torch.zeros_like(key) for _ in range(3))
        for level in range(depth):
            x |= (key & (1 << (3 * level + 2))) >> (2 * level + 2)
            y |= (key & (1 << (3 * level + 1))) >> (2 * level + 1)
            z |= (key & (1 << (3 * level))) >> (2 * level)
        return x, y, z
class Serializator:
def encode(self, grid_coord, grid_size=0.01, batch=None, depth=16, order="cord"):
assert order in {"cord", "z", "z-trans", "hilbert", "hilbert-trans"}
if order in ["z", "z-trans"]:
self.key_lut = KeyLUT()
if order == "cord":
code = self.cord_encode(grid_coord, grid_size)
elif order == "z":
code = self.z_order_encode(grid_coord, depth=depth)
elif order == "z-trans":
code = self.z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth)
elif order == "hilbert":
code = self.hilbert_encode(grid_coord, depth=depth)
elif order == "hilbert-trans":
code = self.hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth)
else:
raise NotImplementedError
if batch is not None:
batch = batch.long()
code = batch << depth * 3 | code
return code
def cord_encode(self, grid_coord: torch.Tensor, grid_size: float):
x, y, z = (
grid_coord[:, 0].long(),
grid_coord[:, 1].long(),
grid_coord[:, 2].long(),
)
# we block the support to batch, maintain batched code in Point class
code = x / grid_size**2 + y / grid_size + z
return code.long()
def z_order_encode(self, grid_coord: torch.Tensor, depth: int = 16):
x, y, z = (
grid_coord[:, 0].long(),
grid_coord[:, 1].long(),
grid_coord[:, 2].long(),
)
# we block the support to batch, maintain batched code in Point class
code = self._xyz2key(x, y, z, b=None, depth=depth)
return code
def _xyz2key(
self,
x: torch.Tensor,
y: torch.Tensor,
z: torch.Tensor,
b: typing.Optional[typing.Union[torch.Tensor, int]] = None,
depth: int = 16,
):
r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
based on pre-computed look up tables. The speed of this function is much
faster than the method based on for-loop.
Args:
x (torch.Tensor): The x coordinate.
y (torch.Tensor): The y coordinate.
z (torch.Tensor): The z coordinate.
b (torch.Tensor or int): The batch index of the coordinates, and should be
smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
:attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
"""
EX, EY, EZ = self.key_lut.encode_lut(x.device)
x, y, z = x.long(), y.long(), z.long()
mask = 255 if depth > 8 else (1 << depth) - 1
key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
if depth > 8:
mask = (1 << (depth - 8)) - 1
key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
key = key16 << 24 | key
if b is not None:
b = b.long()
key = b << 48 | key
return key
def hilbert_encode(self, grid_coord: torch.Tensor, depth: int = 16):
return self._hilbert_encode(grid_coord, num_dims=3, num_bits=depth)
def _hilbert_encode(self, locs, num_dims, num_bits):
"""Decode an array of locations in a hypercube into a Hilbert integer.
This is a vectorized-ish version of the Hilbert curve implementation by John
Skilling as described in:
Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
Params:
-------
locs - An ndarray of locations in a hypercube of num_dims dimensions, in
which each dimension runs from 0 to 2**num_bits-1. The shape can
be arbitrary, as long as the last dimension of the same has size
num_dims.
num_dims - The dimensionality of the hypercube. Integer.
num_bits - The number of bits for each dimension. Integer.
Returns:
--------
The output is an ndarray of uint64 integers with the same shape as the
input, excluding the last dimension, which needs to be num_dims.
"""
# Keep around the original shape for later.
orig_shape = locs.shape
bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
bitpack_mask_rev = bitpack_mask.flip(-1)
if orig_shape[-1] != num_dims:
raise ValueError(
"""
The shape of locs was surprising in that the last dimension was of size
%d, but num_dims=%d. These need to be equal.
"""
% (orig_shape[-1], num_dims)
)
if num_dims * num_bits > 63:
raise ValueError(
"""
num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
into a int64. Are you sure you need that many points on your Hilbert
curve?
"""
% (num_dims, num_bits, num_dims * num_bits)
)
# Treat the location integers as 64-bit unsigned and then split them up into
# a sequence of uint8s. Preserve the association by dimension.
locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)
# Now turn these into bits and truncate to num_bits.
gray = (
locs_uint8.unsqueeze(-1)
.bitwise_and(bitpack_mask_rev)
.ne(0)
.byte()
.flatten(-2, -1)[..., -num_bits:]
)
# Run the decoding process the other way.
# Iterate forwards through the bits.
for bit in range(0, num_bits):
# Iterate forwards through the dimensions.
for dim in range(0, num_dims):
# Identify which ones have this bit active.
mask = gray[:, dim, bit]
# Where this bit is on, invert the 0 dimension for lower bits.
gray[:, 0, bit + 1 :] = torch.logical_xor(
gray[:, 0, bit + 1 :], mask[:, None]
)
# Where the bit is off, exchange the lower bits with the 0 dimension.
to_flip = torch.logical_and(
torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
)
gray[:, dim, bit + 1 :] = torch.logical_xor(
gray[:, dim, bit + 1 :], to_flip
)
gray[:, 0, bit + 1 :] = torch.logical_xor(
gray[:, 0, bit + 1 :], to_flip
)
# Now flatten out.
# Fix: shape '[-1, 0]' is invalid for input of size 192
gray = gray.swapaxes(1, 2).reshape((gray.size(0), -1))
# Convert Gray back to binary.
hh_bin = self._gray2binary(gray)
# Pad back out to 64 bits.
extra_dims = 64 - gray.size(1)
padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)
# Convert binary values into uint8s.
hh_uint8 = (
(padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask)
.sum(2)
.squeeze()
.type(torch.uint8)
)
# Convert uint8s into uint64s.
hh_uint64 = hh_uint8.view(torch.int64).squeeze()
return hh_uint64
def _gray2binary(self, gray, axis=-1):
"""Convert an array of Gray codes back into binary values.
Parameters:
-----------
gray: An ndarray of gray codes.
axis: The axis along which to perform Gray decoding. Default=-1.
Returns:
--------
Returns an ndarray of binary values.
"""
# Loop the log2(bits) number of times necessary, with shift and xor.
shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
while shift > 0:
gray = torch.logical_xor(gray, self._right_shift(gray, shift))
shift = torch.div(shift, 2, rounding_mode="floor")
return gray
def _right_shift(self, binary, k=1, axis=-1):
"""Right shift an array of binary values.
Parameters:
-----------
binary: An ndarray of binary values.
k: The number of bits to shift. Default 1.
axis: The axis along which to shift. Default -1.
Returns:
--------
Returns an ndarray with zero prepended and the ends truncated, along
whatever axis was specified."""
# If we're shifting the whole thing, just return zeros.
if binary.shape[axis] <= k:
return torch.zeros_like(binary)
# Determine the padding pattern.
# padding = [(0,0)] * len(binary.shape)
# padding[axis] = (k,0)
# Determine the slicing pattern to eliminate just the last one.
slicing = [slice(None)] * len(binary.shape)
slicing[axis] = slice(None, -k)
shifted = torch.nn.functional.pad(
binary[tuple(slicing)], (k, 0), mode="constant", value=0
)
return shifted
class PointModule(torch.nn.Module):
    r"""Marker base class for modules that consume and return a whole ``Point``.

    ``PointSequential`` hands the full ``Point`` to any child deriving from this
    class instead of only its ``feat`` tensor. The constructor is inherited
    unchanged from ``torch.nn.Module``.
    """
class Point(addict.Dict):
    """
    Point Structure of Pointcept

    A Point (point cloud) in Pointcept is a dictionary that contains various properties of
    a batched point cloud. The properties with the following names have a specific definition:
    - "coord": original coordinate of point cloud;
    - "grid_coord": grid coordinate for specific grid size (related to GridSampling);
    Point also supports the following optional attributes:
    - "offset": cumulative point count per sample; derived from "batch" when missing;
    - "batch": per-point sample index; derived from "offset" when missing;
      (if both are missing, neither is created)
    - "feat": feature of point cloud, default input of model;
    - "grid_size": Grid size of point cloud (related to GridSampling);
    (related to Serialization)
    - "serialized_depth": depth of serialization, 2 ** depth * grid_size describes the maximum point cloud range;
    - "serialized_code": a list of serialization codes;
    - "serialized_order": a list of serialization orders determined by code;
    - "serialized_inverse": a list of inverse mappings determined by code;
    (related to Sparsify: SpConv)
    - "sparse_shape": Sparse shape for Sparse Conv Tensor;
    - "sparse_conv_feat": SparseConvTensor initialized with information provided by Point;
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.serializator = Serializator()
        # If one of "offset" or "batch" does not exist, generate it from the existing one
        if "batch" not in self.keys() and "offset" in self.keys():
            self["batch"] = offset2batch(self.offset)
        elif "offset" not in self.keys() and "batch" in self.keys():
            self["offset"] = batch2offset(self.batch)

    def _ensure_grid_coord(self):
        """Derive "grid_coord" from "coord" and "grid_size" when it is missing."""
        if "grid_coord" not in self.keys():
            # if you don't want to operate GridSampling in data augmentation,
            # please add the following augmentation into your pipeline:
            # dict(type="Copy", keys_dict={"grid_size": 0.01}),
            # (adjust `grid_size` to what you want)
            assert {"grid_size", "coord"}.issubset(self.keys())
            self["grid_coord"] = torch.div(
                self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
            ).int()

    def serialization(self, order="z", depth=None, shuffle_orders=False):
        """
        Point Cloud Serialization

        relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]

        Args:
            order (str or sequence of str): one serialization order name, or
                several; one code row is produced per order.
            depth (int, optional): bits per coordinate; derived from the largest
                grid coordinate when None.
            shuffle_orders (bool): randomly permute the rows of code/order/inverse.
        """
        assert "batch" in self.keys()
        self._ensure_grid_coord()

        # A single order name must not be iterated character by character.
        if isinstance(order, str):
            order = [order]

        if depth is None:
            # Adaptive measure the depth of serialization cube (length = 2 ^ depth)
            depth = int(self.grid_coord.max()).bit_length()
        self["serialized_depth"] = depth
        # Maximum bit length for serialization code is 63 (int64)
        assert depth * 3 + len(self.offset).bit_length() <= 63
        # Here we follow OCNN and set the depth limitation to 16 (48bit) for the point position.
        # Although depth is limited to less than 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3
        # cube with a grid size of 0.01 meter. We consider it is enough for the current stage.
        # We can unlock the limitation by optimizing the z-order encoding function if necessary.
        assert depth <= 16

        # The serialization codes are arranged as following structures:
        # [Order1 ([n]),
        #  Order2 ([n]),
        #   ...
        #  OrderN ([n])] (k, n)
        code = [
            self.serializator.encode(
                self.grid_coord, self.grid_size, self.batch, depth, order=order_
            )
            for order_ in order
        ]
        code = torch.stack(code)
        order = torch.argsort(code)
        # inverse[k, order[k, j]] = j, i.e. the rank of each point in order k
        inverse = torch.zeros_like(order).scatter_(
            dim=1,
            index=order,
            src=torch.arange(0, code.shape[1], device=order.device).repeat(
                code.shape[0], 1
            ),
        )

        if shuffle_orders:
            perm = torch.randperm(code.shape[0])
            code = code[perm]
            order = order[perm]
            inverse = inverse[perm]

        self["serialized_code"] = code
        self["serialized_order"] = order
        self["serialized_inverse"] = inverse

    def sparsify(self, pad=96):
        """
        Point Cloud Sparsification

        Point cloud is sparse, here we use "sparsify" to specifically refer to
        preparing "spconv.SparseConvTensor" for SpConv.

        relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]

        pad: padding added to the maximum grid coordinate to form the sparse shape.
        """
        assert {"feat", "batch"}.issubset(self.keys())
        self._ensure_grid_coord()
        if "sparse_shape" in self.keys():
            sparse_shape = self.sparse_shape
        else:
            sparse_shape = torch.add(
                torch.max(self.grid_coord, dim=0).values, pad
            ).tolist()
        sparse_conv_feat = spconv.SparseConvTensor(
            features=self.feat,
            indices=torch.cat(
                [self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1
            ).contiguous(),
            spatial_shape=sparse_shape,
            # NOTE(review): assumes "batch" is sorted so its last entry is the max id
            batch_size=self.batch[-1].tolist() + 1,
        )
        self["sparse_shape"] = sparse_shape
        self["sparse_conv_feat"] = sparse_conv_feat
class PointSequential(PointModule):
    r"""A sequential container that understands ``Point`` inputs.

    Modules run in registration order. Positional modules are registered as
    "0", "1", ...; a single ``collections.OrderedDict`` argument registers its
    items under their keys; keyword arguments register under their names.

    NOTE(review): the first positional parameter is ``name``. When the first
    positional argument is a module (e.g. ``PointSequential(conv, bn)``),
    ``self.name = name`` registers that module as the child "name". It still
    runs first because registration order is preserved, but state-dict keys
    contain ".name." — do not change this without migrating checkpoints.
    """

    def __init__(self, name="", *args, **kwargs):
        super().__init__()
        self.name = name
        if len(args) == 1 and isinstance(args[0], collections.OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)
        for name, module in kwargs.items():
            if name in self._modules:
                raise ValueError("name exists.")
            self.add_module(name, module)

    def __getitem__(self, idx):
        if not (-len(self) <= idx < len(self)):
            raise IndexError("index {} is out of range".format(idx))
        return list(self._modules.values())[idx]

    def __len__(self):
        return len(self._modules)

    def add(self, module, name=None):
        """Append ``module``; its name defaults to the current module count."""
        if name is None:
            name = str(len(self._modules))
            if name in self._modules:
                raise KeyError("name exists")
        self.add_module(name, module)

    def forward(self, x):
        """Run every child on ``x``, routing Point / SparseConvTensor / tensor inputs."""
        for module in self._modules.values():
            # Point module
            if isinstance(module, PointModule):
                x = module(x)
            # Spconv module
            elif spconv.modules.is_spconv_module(module):
                if isinstance(x, Point):
                    x.sparse_conv_feat = module(x.sparse_conv_feat)
                    x.feat = x.sparse_conv_feat.features
                else:
                    x = module(x)
            # Fix: Expected more than 1 value per channel when training
            elif isinstance(module, torch.nn.BatchNorm1d) and isinstance(x, Point):
                if x.feat.size(0) != 1:
                    x.feat = module(x.feat)
                    # Keep the sparse tensor in sync, as the generic branch does;
                    # otherwise a following spconv layer would see stale features.
                    if "sparse_conv_feat" in x.keys():
                        x.sparse_conv_feat = x.sparse_conv_feat.replace_feature(x.feat)
            # PyTorch module
            else:
                if isinstance(x, Point):
                    x.feat = module(x.feat)
                    if "sparse_conv_feat" in x.keys():
                        x.sparse_conv_feat = x.sparse_conv_feat.replace_feature(x.feat)
                elif isinstance(x, spconv.SparseConvTensor):
                    if x.indices.shape[0] != 0:
                        x = x.replace_feature(module(x.features))
                else:
                    x = module(x)
        return x
class PDNorm(PointModule):
    """Normalization that can be specialized per data condition.

    With ``decouple`` one ``norm_layer`` instance is kept per entry of
    ``conditions`` and selected by ``point.condition``; otherwise a single
    instance is shared. With ``adaptive`` a ``SiLU -> Linear`` head maps
    ``point.context`` to a shift and scale applied after normalization.
    """

    def __init__(
        self,
        num_features,
        norm_layer,
        context_channels=256,
        conditions=("ScanNet", "S3DIS", "Structured3D"),
        decouple=True,
        adaptive=False,
    ):
        super().__init__()
        self.conditions = conditions
        self.decouple = decouple
        self.adaptive = adaptive
        if self.decouple:
            self.norm = torch.nn.ModuleList(
                [norm_layer(num_features) for _ in conditions]
            )
        else:
            # Instantiate the layer. Storing the factory itself made forward()
            # call ``norm_layer(point.feat)``, i.e. the constructor, on a tensor.
            self.norm = norm_layer(num_features)
        if self.adaptive:
            self.modulation = torch.nn.Sequential(
                torch.nn.SiLU(),
                torch.nn.Linear(context_channels, 2 * num_features, bias=True),
            )

    def forward(self, point):
        """Normalize ``point.feat`` with the layer for ``point.condition``."""
        assert {"feat", "condition"}.issubset(point.keys())
        if isinstance(point.condition, str):
            condition = point.condition
        else:
            # a sequence of conditions: the first one is used for the whole batch
            condition = point.condition[0]
        if self.decouple:
            assert condition in self.conditions
            norm = self.norm[self.conditions.index(condition)]
        else:
            norm = self.norm
        point.feat = norm(point.feat)
        if self.adaptive:
            assert "context" in point.keys()
            # NOTE(review): assumes point.context broadcasts against point.feat
            # (e.g. a single row, or one row per point) — confirm with callers.
            shift, scale = self.modulation(point.context).chunk(2, dim=1)
            point.feat = point.feat * (1.0 + scale) + shift
        return point
class RPE(torch.nn.Module):
    """Learned relative position bias for patch attention.

    Each axis of a relative grid offset is clamped to ``[-pos_bnd, pos_bnd]``
    and looks up its own per-head bias; the three axis biases are summed.
    """

    def __init__(self, patch_size, num_heads):
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
        self.rpe_num = 2 * self.pos_bnd + 1
        # rows [d * rpe_num, (d + 1) * rpe_num) hold the biases for axis d
        self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))
        torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)

    def forward(self, coord):
        """Map offsets of shape (N, K, K, 3) to a bias of shape (N, H, K, K)."""
        bound = self.pos_bnd
        axis_base = torch.arange(3, device=coord.device) * self.rpe_num
        # shift clamped offsets to non-negative rows inside each axis' slice
        rows = coord.clamp(-bound, bound) + bound + axis_base
        looked_up = self.rpe_table.index_select(0, rows.reshape(-1))
        per_axis = looked_up.view(rows.shape + (-1,))
        bias = per_axis.sum(3)
        return bias.permute(0, 3, 1, 2)
class SerializedAttention(PointModule):
    """Multi-head self-attention within fixed-size patches of serialized points.

    Points are reordered by ``point.serialized_order[order_index]``, padded per
    sample to a multiple of the patch size and attended within each patch,
    either through ``flash_attn.flash_attn_varlen_qkvpacked_func`` or through
    an explicit softmax path that can add a relative position bias (RPE).
    """

    def __init__(
        self,
        channels,
        num_heads,
        patch_size,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        order_index=0,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    ):
        super().__init__()
        assert channels % num_heads == 0
        self.channels = channels
        self.num_heads = num_heads
        self.scale = qk_scale or (channels // num_heads) ** -0.5
        self.order_index = order_index
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.enable_rpe = enable_rpe
        self.enable_flash = enable_flash
        if enable_flash:
            assert (
                enable_rpe is False
            ), "Set enable_rpe to False when enable Flash Attention"
            assert (
                upcast_attention is False
            ), "Set upcast_attention to False when enable Flash Attention"
            assert (
                upcast_softmax is False
            ), "Set upcast_softmax to False when enable Flash Attention"
            assert flash_attn is not None, "Make sure flash_attn is installed."
            self.patch_size = patch_size
            # a float dropout probability, passed to the flash kernel
            self.attn_drop = attn_drop
        else:
            # when disable flash attention, we still don't want to use mask
            # consequently, patch size will auto set to the
            # min number of patch_size_max and number of points
            self.patch_size_max = patch_size
            self.patch_size = 0
            self.attn_drop = torch.nn.Dropout(attn_drop)

        self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
        self.proj = torch.nn.Linear(channels, channels)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None

    @torch.no_grad()
    def get_rel_pos(self, point, order):
        """Pairwise grid offsets within each patch, shape (num_patches, K, K, 3).

        Cached on ``point`` per ``order_index``.
        """
        K = self.patch_size
        rel_pos_key = f"rel_pos_{self.order_index}"
        # The cache key does not include K, so rebuild when the cached tensor was
        # produced for a different patch size (e.g. encoder and decoder blocks
        # with different patch sizes operating on the same Point).
        if rel_pos_key not in point.keys() or point[rel_pos_key].shape[1] != K:
            grid_coord = point.grid_coord[order]
            grid_coord = grid_coord.reshape(-1, K, 3)
            point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)
        return point[rel_pos_key]

    @torch.no_grad()
    def get_padding_and_inverse(self, point):
        """Compute (and cache on ``point``) the patch padding for this patch size.

        Returns:
            pad: for every padded slot, the index of the original point it reads
                (tail slots of a sample repeat points from its previous patch).
            unpad: for every original point, its position in the padded layout.
            cu_seqlens: int32 start offset of every patch, followed by the total
                padded length.
        """
        pad_key = "pad"
        unpad_key = "unpad"
        cu_seqlens_key = "cu_seqlens_key"
        # Records which patch size the cached tensors were built for; they are
        # stale if a block with a different patch size sees the same Point.
        patch_size_key = "pad_patch_size"
        if (
            pad_key not in point.keys()
            or unpad_key not in point.keys()
            or cu_seqlens_key not in point.keys()
            or point.get(patch_size_key) != self.patch_size
        ):
            offset = point.offset
            bincount = offset2bincount(offset)
            bincount_pad = (
                torch.div(
                    bincount + self.patch_size - 1,
                    self.patch_size,
                    rounding_mode="trunc",
                )
                * self.patch_size
            )
            # only pad point when num of points larger than patch_size
            mask_pad = bincount > self.patch_size
            bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
            _offset = torch.nn.functional.pad(offset, (1, 0))
            _offset_pad = torch.nn.functional.pad(
                torch.cumsum(bincount_pad, dim=0), (1, 0)
            )
            pad = torch.arange(_offset_pad[-1], device=offset.device)
            unpad = torch.arange(_offset[-1], device=offset.device)
            cu_seqlens = []
            for i in range(len(offset)):
                unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i]
                if bincount[i] != bincount_pad[i]:
                    # fill the padded tail of the last patch with points copied
                    # from the preceding patch of the same sample
                    pad[
                        _offset_pad[i + 1]
                        - self.patch_size
                        + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                    ] = pad[
                        _offset_pad[i + 1]
                        - 2 * self.patch_size
                        + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                        - self.patch_size
                    ]
                pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]
                cu_seqlens.append(
                    torch.arange(
                        _offset_pad[i],
                        _offset_pad[i + 1],
                        step=self.patch_size,
                        dtype=torch.int32,
                        device=offset.device,
                    )
                )
            point[pad_key] = pad
            point[unpad_key] = unpad
            point[cu_seqlens_key] = torch.nn.functional.pad(
                torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1]
            )
            point[patch_size_key] = self.patch_size
        return point[pad_key], point[unpad_key], point[cu_seqlens_key]

    def forward(self, point):
        """Replace ``point.feat`` with the attention output (after ``proj``)."""
        if not self.enable_flash:
            self.patch_size = min(
                offset2bincount(point.offset).min().tolist(), self.patch_size_max
            )

        H = self.num_heads
        K = self.patch_size
        C = self.channels

        pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)

        order = point.serialized_order[self.order_index][pad]
        inverse = unpad[point.serialized_inverse[self.order_index]]

        # padding and reshape feat and batch for serialized point patch
        qkv = self.qkv(point.feat)[order]

        if not self.enable_flash:
            # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
            q, k, v = (
                qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
            )
            # attn
            if self.upcast_attention:
                q = q.float()
                k = k.float()
            attn = (q * self.scale) @ k.transpose(-2, -1)  # (N', H, K, K)
            if self.enable_rpe:
                attn = attn + self.rpe(self.get_rel_pos(point, order))
            if self.upcast_softmax:
                attn = attn.float()
            attn = self.softmax(attn)
            attn = self.attn_drop(attn).to(qkv.dtype)
            feat = (attn @ v).transpose(1, 2).reshape(-1, C)
        else:
            feat = flash_attn.flash_attn_varlen_qkvpacked_func(
                qkv.half().reshape(-1, 3, H, C // H),
                cu_seqlens,
                max_seqlen=self.patch_size,
                dropout_p=self.attn_drop if self.training else 0,
                softmax_scale=self.scale,
            ).reshape(-1, C)
            feat = feat.to(qkv.dtype)
        feat = feat[inverse]

        # ffn
        feat = self.proj(feat)
        feat = self.proj_drop(feat)
        point.feat = feat
        return point
class MLP(torch.nn.Module):
    """Two-layer feed-forward network computing ``fc2(act(fc1(x)))``.

    ``hidden_channels`` and ``out_channels`` default to ``in_channels``.
    ``drop`` is accepted for interface compatibility but unused: this
    implementation applies no dropout.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_channels=None,
        act_layer=torch.nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        width = hidden_channels or in_channels
        width_out = out_channels or in_channels
        self.fc1 = torch.nn.Linear(in_channels, width)
        self.act = act_layer()
        self.fc2 = torch.nn.Linear(width, width_out)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
class Block(PointModule):
    """Transformer block: a ``cpe`` branch, serialized attention and an MLP.

    Every sub-layer is residual. With ``pre_norm`` the norm is applied to the
    sub-layer input; otherwise it is applied after the residual addition.
    ``drop_path`` > 0 wraps the attention and MLP branches in ``DropPath``.
    """

    def __init__(
        self,
        channels,
        num_heads,
        patch_size=48,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        norm_layer=torch.nn.LayerNorm,
        act_layer=torch.nn.GELU,
        pre_norm=True,
        order_index=0,
        cpe_indice_key=None,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    ):
        super().__init__()
        self.channels = channels
        self.pre_norm = pre_norm

        # cpe branch: submanifold conv -> linear -> norm, added residually in forward
        self.cpe = PointSequential(
            spconv.SubMConv3d(
                channels,
                channels,
                kernel_size=3,
                bias=True,
                indice_key=cpe_indice_key,
            ),
            torch.nn.Linear(channels, channels),
            norm_layer(channels),
        )

        self.norm1 = PointSequential(norm_layer(channels))
        self.attn = SerializedAttention(
            channels=channels,
            patch_size=patch_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            order_index=order_index,
            enable_rpe=enable_rpe,
            enable_flash=enable_flash,
            upcast_attention=upcast_attention,
            upcast_softmax=upcast_softmax,
        )
        self.norm2 = PointSequential(norm_layer(channels))
        self.mlp = PointSequential(
            MLP(
                in_channels=channels,
                hidden_channels=int(channels * mlp_ratio),
                out_channels=channels,
                act_layer=act_layer,
                drop=proj_drop,
            )
        )
        branch_drop = DropPath(drop_path) if drop_path > 0.0 else torch.nn.Identity()
        self.drop_path = PointSequential(branch_drop)

    def _residual(self, point, norm, branch):
        """Apply ``branch`` residually, with ``norm`` placed per ``pre_norm``."""
        identity = point.feat
        if self.pre_norm:
            point = norm(point)
        point = self.drop_path(branch(point))
        point.feat = identity + point.feat
        if not self.pre_norm:
            point = norm(point)
        return point

    def forward(self, point: Point):
        identity = point.feat
        point = self.cpe(point)
        point.feat = identity + point.feat

        point = self._residual(point, self.norm1, self.attn)
        point = self._residual(point, self.norm2, self.mlp)

        # propagate the final features into the sparse tensor
        point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
        return point
class DropPath(torch.nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    ``drop_prob=None`` (the default) behaves like ``0.0``: inputs pass through.
    """

    def __init__(self, drop_prob=None, scale_by_keep=True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def _drop_path(
        self,
        x,
        drop_prob: float = 0.0,
        training: bool = False,
        scale_by_keep: bool = True,
    ):
        """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

        This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
        the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
        See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
        changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
        'survival rate' as the argument.

        Each sample (index along dim 0) is zeroed with probability ``drop_prob``;
        survivors are divided by ``1 - drop_prob`` when ``scale_by_keep``.
        """
        # `not drop_prob` also covers None, the constructor default, which
        # previously crashed on `1 - None` in training mode.
        if not drop_prob or not training:
            return x
        keep_prob = 1 - drop_prob
        # one Bernoulli draw per sample, broadcast over the remaining dims
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor

    def forward(self, x):
        return self._drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
class SerializedPooling(PointModule):
    """Downsample a serialized Point by merging points with a shared coarse code.

    Codes are right-shifted by ``pooling_depth * 3`` bits. Points whose first
    shifted code coincides form one cluster. Features are projected and reduced
    with ``reduce``, and coordinates are averaged. With ``traceable`` the
    cluster index and parent Point are stored for ``SerializedUnpooling``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=2,
        norm_layer=None,
        act_layer=None,
        reduce="max",
        shuffle_orders=True,
        traceable=True,  # record parent and cluster
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        assert stride == 2 ** (math.ceil(stride) - 1).bit_length()  # 2, 4, 8
        # TODO: add support to grid pool (any stride)
        self.stride = stride
        assert reduce in ["sum", "mean", "min", "max"]
        self.reduce = reduce
        self.shuffle_orders = shuffle_orders
        self.traceable = traceable

        self.proj = torch.nn.Linear(in_channels, out_channels)
        # Always define norm/act: forward() tests them against None, which
        # raised AttributeError when norm_layer/act_layer was not given.
        self.norm = None
        self.act = None
        if norm_layer is not None:
            self.norm = PointSequential(norm_layer(out_channels))
        if act_layer is not None:
            self.act = PointSequential(act_layer())

    def forward(self, point: Point):
        pooling_depth = (math.ceil(self.stride) - 1).bit_length()
        if pooling_depth > point.serialized_depth:
            pooling_depth = 0
        assert {
            "serialized_code",
            "serialized_order",
            "serialized_inverse",
            "serialized_depth",
        }.issubset(
            point.keys()
        ), "Run point.serialization() point cloud before SerializedPooling"

        code = point.serialized_code >> pooling_depth * 3
        _, cluster, counts = torch.unique(
            code[0],
            sorted=True,
            return_inverse=True,
            return_counts=True,
        )
        # indices of point sorted by cluster, for torch_scatter.segment_csr
        _, indices = torch.sort(cluster)
        # index pointer for sorted point, for torch_scatter.segment_csr
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        # head_indices of each cluster, for reduce attr e.g. code, batch
        head_indices = indices[idx_ptr[:-1]]
        # generate down code, order, inverse
        code = code[:, head_indices]
        order = torch.argsort(code)
        inverse = torch.zeros_like(order).scatter_(
            dim=1,
            index=order,
            src=torch.arange(0, code.shape[1], device=order.device).repeat(
                code.shape[0], 1
            ),
        )

        if self.shuffle_orders:
            perm = torch.randperm(code.shape[0])
            code = code[perm]
            order = order[perm]
            inverse = inverse[perm]

        # collect information
        point_dict = addict.Dict(
            feat=torch_scatter.segment_csr(
                self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
            ),
            coord=torch_scatter.segment_csr(
                point.coord[indices], idx_ptr, reduce="mean"
            ),
            grid_coord=point.grid_coord[head_indices] >> pooling_depth,
            serialized_code=code,
            serialized_order=order,
            serialized_inverse=inverse,
            serialized_depth=point.serialized_depth - pooling_depth,
            batch=point.batch[head_indices],
        )

        if "condition" in point.keys():
            point_dict["condition"] = point.condition
        if "context" in point.keys():
            point_dict["context"] = point.context

        if self.traceable:
            point_dict["pooling_inverse"] = cluster
            point_dict["pooling_parent"] = point
        point = Point(point_dict)
        # Fix: Expected more than 1 value per channel when training
        if self.norm is not None and point.feat.size(0) != 1:
            point = self.norm(point)
        if self.act is not None:
            point = self.act(point)
        point.sparsify()
        return point
class SerializedUnpooling(PointModule):
    """Upsample by adding coarse features back onto the pooling parent.

    The coarse Point and its stored ``pooling_parent`` are both projected to
    ``out_channels`` (each optionally followed by ``norm_layer`` and
    ``act_layer``). Every parent point then receives the projected feature of
    the cluster it was pooled into (``pooling_inverse``).
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        norm_layer=None,
        act_layer=None,
        traceable=False,  # record parent and cluster
    ):
        super().__init__()
        self.proj = PointSequential(torch.nn.Linear(in_channels, out_channels))
        self.proj_skip = PointSequential(torch.nn.Linear(skip_channels, out_channels))

        if norm_layer is not None:
            self.proj.add(norm_layer(out_channels))
            self.proj_skip.add(norm_layer(out_channels))

        if act_layer is not None:
            self.proj.add(act_layer())
            self.proj_skip.add(act_layer())

        self.traceable = traceable

    def forward(self, point):
        for required_key in ("pooling_parent", "pooling_inverse"):
            assert required_key in point.keys()
        parent = point.pop("pooling_parent")
        cluster_of_parent = point.pop("pooling_inverse")
        coarse = self.proj(point)
        parent = self.proj_skip(parent)
        parent.feat = parent.feat + coarse.feat[cluster_of_parent]

        if self.traceable:
            parent["unpooling_parent"] = coarse
        return parent
class Embedding(PointModule):
    """Input stem mapping ``in_channels`` point features to ``embed_channels``.

    The stem is a kernel-5 ``spconv.SubMConv3d`` (indice key "stem", no bias),
    optionally followed by ``norm_layer`` ("norm") and ``act_layer`` ("act").
    """

    def __init__(
        self,
        in_channels,
        embed_channels,
        norm_layer=None,
        act_layer=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.embed_channels = embed_channels

        # TODO: check remove spconv
        stem_conv = spconv.SubMConv3d(
            in_channels,
            embed_channels,
            kernel_size=5,
            padding=1,
            bias=False,
            indice_key="stem",
        )
        self.stem = PointSequential(conv=stem_conv)
        if norm_layer is not None:
            self.stem.add(norm_layer(embed_channels), name="norm")
        if act_layer is not None:
            self.stem.add(act_layer(), name="act")

    def forward(self, point: Point):
        return self.stem(point)
class PointTransformerV3(PointModule):
    """Point Transformer V3 backbone built from serialized-attention blocks.

    Layout (every container is a ``PointSequential``):

    * ``embedding``: an ``Embedding`` stem mapping ``in_channels`` to
      ``enc_channels[0]`` with ``bn_layer`` and GELU.
    * ``enc``: ``len(enc_depths)`` stages. Every stage after the first begins
      with a ``SerializedPooling`` named ``"down"`` (stride ``stride[s - 1]``).
      Each stage then has ``enc_depths[s]`` attention ``Block``s.
    * ``dec`` (only when ``cls_mode`` is False): ``len(enc_depths) - 1``
      stages, built from the deepest to the shallowest. Each begins with a
      ``SerializedUnpooling`` named ``"up"`` that fuses the skip connection
      from encoder stage ``s``. Each stage then has ``dec_depths[s]`` ``Block``s.

    NOTE(review): the ``name=`` strings given to ``PointSequential`` and its
    ``add`` calls presumably become submodule names, and thus ``state_dict``
    keys. Keep them and the construction order stable to stay compatible
    with existing checkpoints.

    Args:
        in_channels: number of input feature channels fed to the stem.
        order: serialization order name, or a sequence of names. A single
            string is wrapped in a list. Blocks cycle through the orders via
            ``i % len(order)``.
        stride: pooling stride for each encoder stage after the first.
            Must satisfy ``len(stride) == len(enc_depths) - 1``.
        enc_depths, enc_channels, enc_num_head, enc_patch_size: per-stage
            encoder settings. All must have the same length.
        dec_depths, dec_channels, dec_num_head, dec_patch_size: per-stage
            decoder settings, with one entry fewer than the encoder settings.
            Unused when ``cls_mode`` is True.
        mlp_ratio, qkv_bias, qk_scale, attn_drop, proj_drop, pre_norm,
        enable_rpe, enable_flash, upcast_attention, upcast_softmax: forwarded
            unchanged to every ``Block``.
        grid_size: stored and passed to ``Point`` in ``forward``.
        drop_path: largest stochastic-depth rate. Encoder and decoder block
            rates are each spaced linearly from 0 to this value.
        shuffle_orders: forwarded to ``Point.serialization``.
        cls_mode: if True, no decoder is built and ``forward`` returns the
            encoder output.
        pdnorm_bn, pdnorm_ln: wrap BatchNorm1d / LayerNorm in ``PDNorm``.
        pdnorm_decouple, pdnorm_adaptive, pdnorm_affine, pdnorm_conditions:
            options used only when ``PDNorm`` is enabled.

    Raises:
        AssertionError: if the per-stage settings have inconsistent lengths.
    """

    def __init__(
        self,
        in_channels=6,
        # A single order name. The former spelling ``("cord")`` was the same
        # plain string, not a tuple.
        order="cord",
        stride=(2, 2, 2, 2),
        enc_depths=(2, 2, 2, 6, 2),
        enc_channels=(32, 64, 128, 256, 512),
        enc_num_head=(2, 4, 8, 16, 32),
        enc_patch_size=(1024, 1024, 1024, 1024, 1024),
        dec_depths=(2, 2, 2, 2),
        dec_channels=(64, 64, 128, 256),
        dec_num_head=(4, 4, 8, 16),
        dec_patch_size=(1024, 1024, 1024, 1024),
        mlp_ratio=4,
        grid_size=0.01,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.3,
        pre_norm=True,
        shuffle_orders=True,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=False,
        upcast_softmax=False,
        cls_mode=False,
        pdnorm_bn=False,
        pdnorm_ln=False,
        pdnorm_decouple=True,
        pdnorm_adaptive=False,
        pdnorm_affine=True,
        pdnorm_conditions=("ScanNet", "S3DIS", "Structured3D"),
    ):
        super().__init__()
        self.num_stages = len(enc_depths)
        self.order = [order] if isinstance(order, str) else order
        self.cls_mode = cls_mode
        self.shuffle_orders = shuffle_orders
        self.grid_size = grid_size
        # Per-stage settings must agree with the number of encoder stages.
        # Decoder settings have one entry fewer and are only checked when a
        # decoder is built.
        assert self.num_stages == len(stride) + 1
        assert self.num_stages == len(enc_channels)
        assert self.num_stages == len(enc_num_head)
        assert self.num_stages == len(enc_patch_size)
        assert self.cls_mode or self.num_stages == len(dec_depths) + 1
        assert self.cls_mode or self.num_stages == len(dec_channels) + 1
        assert self.cls_mode or self.num_stages == len(dec_num_head) + 1
        assert self.cls_mode or self.num_stages == len(dec_patch_size) + 1
        # norm layers
        if pdnorm_bn:
            bn_layer = functools.partial(
                PDNorm,
                norm_layer=functools.partial(
                    torch.nn.BatchNorm1d, eps=1e-3, momentum=0.01, affine=pdnorm_affine
                ),
                conditions=pdnorm_conditions,
                decouple=pdnorm_decouple,
                adaptive=pdnorm_adaptive,
            )
        else:
            bn_layer = functools.partial(torch.nn.BatchNorm1d, eps=1e-3, momentum=0.01)
        if pdnorm_ln:
            ln_layer = functools.partial(
                PDNorm,
                norm_layer=functools.partial(
                    torch.nn.LayerNorm, elementwise_affine=pdnorm_affine
                ),
                conditions=pdnorm_conditions,
                decouple=pdnorm_decouple,
                adaptive=pdnorm_adaptive,
            )
        else:
            ln_layer = torch.nn.LayerNorm
        # activation layers
        act_layer = torch.nn.GELU
        self.embedding = Embedding(
            in_channels=in_channels,
            embed_channels=enc_channels[0],
            norm_layer=bn_layer,
            act_layer=act_layer,
        )
        # encoder
        # One drop-path rate per encoder block, increasing linearly with depth.
        enc_drop_path = [
            x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))
        ]
        self.enc = PointSequential(name="encoder")
        for s in range(self.num_stages):
            # Slice out the rates that belong to this stage's blocks.
            enc_drop_path_ = enc_drop_path[
                sum(enc_depths[:s]) : sum(enc_depths[: s + 1])
            ]
            enc = PointSequential(name="encoder_layer_%d" % s)
            if s > 0:
                enc.add(
                    SerializedPooling(
                        in_channels=enc_channels[s - 1],
                        out_channels=enc_channels[s],
                        stride=stride[s - 1],
                        norm_layer=bn_layer,
                        act_layer=act_layer,
                    ),
                    name="down",
                )
            for i in range(enc_depths[s]):
                enc.add(
                    Block(
                        channels=enc_channels[s],
                        num_heads=enc_num_head[s],
                        patch_size=enc_patch_size[s],
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        attn_drop=attn_drop,
                        proj_drop=proj_drop,
                        drop_path=enc_drop_path_[i],
                        norm_layer=ln_layer,
                        act_layer=act_layer,
                        pre_norm=pre_norm,
                        order_index=i % len(self.order),
                        cpe_indice_key=f"stage{s}",
                        enable_rpe=enable_rpe,
                        enable_flash=enable_flash,
                        upcast_attention=upcast_attention,
                        upcast_softmax=upcast_softmax,
                    ),
                    name=f"block{i}",
                )
            if len(enc) != 0:
                self.enc.add(module=enc, name=f"enc{s}")
        # decoder
        if not self.cls_mode:
            dec_drop_path = [
                x.item() for x in torch.linspace(0, drop_path, sum(dec_depths))
            ]
            self.dec = PointSequential(name="decoder")
            # Append the deepest encoder width so that dec_channels[s + 1] is
            # the input width of decoder stage s, including the deepest one.
            dec_channels = list(dec_channels) + [enc_channels[-1]]
            for s in reversed(range(self.num_stages - 1)):
                dec_drop_path_ = dec_drop_path[
                    sum(dec_depths[:s]) : sum(dec_depths[: s + 1])
                ]
                # Stages are built from deep to shallow, so within a stage the
                # rates are reversed as well.
                dec_drop_path_.reverse()
                dec = PointSequential(name="decoder_layer_%d" % s)
                dec.add(
                    SerializedUnpooling(
                        in_channels=dec_channels[s + 1],
                        skip_channels=enc_channels[s],
                        out_channels=dec_channels[s],
                        norm_layer=bn_layer,
                        act_layer=act_layer,
                    ),
                    name="up",
                )
                for i in range(dec_depths[s]):
                    dec.add(
                        Block(
                            channels=dec_channels[s],
                            num_heads=dec_num_head[s],
                            patch_size=dec_patch_size[s],
                            mlp_ratio=mlp_ratio,
                            qkv_bias=qkv_bias,
                            qk_scale=qk_scale,
                            attn_drop=attn_drop,
                            proj_drop=proj_drop,
                            drop_path=dec_drop_path_[i],
                            norm_layer=ln_layer,
                            act_layer=act_layer,
                            pre_norm=pre_norm,
                            order_index=i % len(self.order),
                            cpe_indice_key=f"stage{s}",
                            enable_rpe=enable_rpe,
                            enable_flash=enable_flash,
                            upcast_attention=upcast_attention,
                            upcast_softmax=upcast_softmax,
                        ),
                        name=f"block{i}",
                    )
                self.dec.add(module=dec, name=f"dec{s}")

    def forward(self, batch, feat, coord):
        """Encode (and, unless ``cls_mode``, decode) one batched point cloud.

        The three tensors are packed into a ``Point`` together with
        ``self.grid_size``. The point is serialized with ``self.order`` and
        sparsified, then passed through the stem, the encoder and, when
        ``cls_mode`` is False, the decoder. ``grid_coord`` is not supplied
        here. Following the Pointcept ``Point`` convention, it is presumably
        derived from ``coord`` and ``grid_size`` (see
        https://github.com/Pointcept/Pointcept?tab=readme-ov-file#offset).

        Args:
            batch: per-point sample index ("batch" in Pointcept terms). Dim 0
                is squeezed, i.e. dropped when it has size 1.
            feat: per-point input features. Dim 0 is squeezed. The channel
                count presumably must equal ``in_channels`` -- confirm.
            coord: per-point continuous coordinates. Dim 0 is squeezed.

        Returns:
            ``feat`` of the final point with a new leading dimension of
            size 1. There are ``dec_channels[0]`` channels when the decoder
            runs, otherwise ``enc_channels[-1]`` channels (on the pooled
            points).
        """
        point = Point(
            {
                "batch": batch.squeeze(dim=0),
                "feat": feat.squeeze(dim=0),
                "coord": coord.squeeze(dim=0),
                "grid_size": self.grid_size,
            }
        )
        point.serialization(order=self.order, shuffle_orders=self.shuffle_orders)
        point.sparsify()
        point = self.embedding(point)
        point = self.enc(point)
        if not self.cls_mode:
            point = self.dec(point)
        return point.feat.unsqueeze(dim=0)