hzxie commited on
Commit
83d5461
·
verified ·
0 Parent(s):

fix: reinitialize the repo.

Browse files
Files changed (45) hide show
  1. .gitattributes +37 -0
  2. .gitignore +183 -0
  3. .gitmodules +3 -0
  4. ARTICLE.md +25 -0
  5. LICENSE +35 -0
  6. README.md +17 -0
  7. app.py +241 -0
  8. assets/CENTERS.pkl +3 -0
  9. assets/NYC-HghtFld.png +3 -0
  10. assets/NYC-SegMap.png +3 -0
  11. gaussiancity/__init__.py +0 -0
  12. gaussiancity/extensions/__init__.py +0 -0
  13. gaussiancity/extensions/diff_gaussian_rasterization/CMakeLists.txt +36 -0
  14. gaussiancity/extensions/diff_gaussian_rasterization/LICENSE.md +83 -0
  15. gaussiancity/extensions/diff_gaussian_rasterization/__init__.py +426 -0
  16. gaussiancity/extensions/diff_gaussian_rasterization/bindings.cpp +19 -0
  17. gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/auxiliary.h +169 -0
  18. gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/backward.cu +622 -0
  19. gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/backward.h +41 -0
  20. gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/config.h +19 -0
  21. gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/forward.cu +376 -0
  22. gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/forward.h +43 -0
  23. gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/rasterizer.h +52 -0
  24. gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/rasterizer_impl.cu +339 -0
  25. gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/rasterizer_impl.h +70 -0
  26. gaussiancity/extensions/diff_gaussian_rasterization/rasterize_points.cu +173 -0
  27. gaussiancity/extensions/diff_gaussian_rasterization/rasterize_points.h +46 -0
  28. gaussiancity/extensions/diff_gaussian_rasterization/setup.py +40 -0
  29. gaussiancity/extensions/diff_gaussian_rasterization/third_party/glm +1 -0
  30. gaussiancity/extensions/diff_gaussian_rasterization/third_party/stbi_image_write.h +1724 -0
  31. gaussiancity/extensions/grid_encoder/__init__.py +193 -0
  32. gaussiancity/extensions/grid_encoder/bindings.cpp +40 -0
  33. gaussiancity/extensions/grid_encoder/grid_encoder_ext.cu +605 -0
  34. gaussiancity/extensions/grid_encoder/setup.py +39 -0
  35. gaussiancity/extensions/voxlib/__init__.py +12 -0
  36. gaussiancity/extensions/voxlib/bindings.cpp +41 -0
  37. gaussiancity/extensions/voxlib/maps_to_volume.cu +142 -0
  38. gaussiancity/extensions/voxlib/points_to_volume.cu +79 -0
  39. gaussiancity/extensions/voxlib/ray_voxel_intersection.cu +332 -0
  40. gaussiancity/extensions/voxlib/setup.py +32 -0
  41. gaussiancity/extensions/voxlib/voxlib_common.h +83 -0
  42. gaussiancity/generator.py +536 -0
  43. gaussiancity/inference.py +582 -0
  44. gaussiancity/pt_v3.py +1344 -0
  45. requirements.txt +14 -0
.gitattributes ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.rar filter=lfs diff=lfs merge=lfs -text
24
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
25
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
26
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.png filter=lfs diff=lfs merge=lfs -text
36
+ *.pth filter=lfs diff=lfs merge=lfs -text
37
+ *.whl filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ---> Python
2
+ # Byte-compiled / optimized / DLL files
3
+ __pycache__/
4
+ *.py[cod]
5
+ *$py.class
6
+
7
+ # C extensions
8
+ *.so
9
+
10
+ # Distribution / packaging
11
+ .Python
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ wheels/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+ cover/
54
+
55
+ # Translations
56
+ *.mo
57
+ *.pot
58
+
59
+ # Django stuff:
60
+ *.log
61
+ local_settings.py
62
+ db.sqlite3
63
+ db.sqlite3-journal
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ .pybuilder/
77
+ target/
78
+
79
+ # Jupyter Notebook
80
+ .ipynb_checkpoints
81
+
82
+ # IPython
83
+ profile_default/
84
+ ipython_config.py
85
+
86
+ # pyenv
87
+ # For a library or package, you might want to ignore these files since the code is
88
+ # intended to run in multiple environments; otherwise, check them in:
89
+ # .python-version
90
+
91
+ # pipenv
92
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
94
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
95
+ # install all needed dependencies.
96
+ #Pipfile.lock
97
+
98
+ # poetry
99
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
101
+ # commonly ignored for libraries.
102
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103
+ #poetry.lock
104
+
105
+ # pdm
106
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107
+ #pdm.lock
108
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109
+ # in version control.
110
+ # https://pdm.fming.dev/#use-with-ide
111
+ .pdm.toml
112
+
113
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114
+ __pypackages__/
115
+
116
+ # Celery stuff
117
+ celerybeat-schedule
118
+ celerybeat.pid
119
+
120
+ # SageMath parsed files
121
+ *.sage.py
122
+
123
+ # Environments
124
+ .env
125
+ .venv
126
+ env/
127
+ venv/
128
+ ENV/
129
+ env.bak/
130
+ venv.bak/
131
+
132
+ # Spyder project settings
133
+ .spyderproject
134
+ .spyproject
135
+
136
+ # Rope project settings
137
+ .ropeproject
138
+
139
+ # mkdocs documentation
140
+ /site
141
+
142
+ # mypy
143
+ .mypy_cache/
144
+ .dmypy.json
145
+ dmypy.json
146
+
147
+ # Pyre type checker
148
+ .pyre/
149
+
150
+ # pytype static type analyzer
151
+ .pytype/
152
+
153
+ # Cython debug symbols
154
+ cython_debug/
155
+
156
+ # PyCharm
157
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
160
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161
+ .idea/
162
+
163
+ # VSCode
164
+ .vscode/
165
+
166
+ # ---> JupyterNotebooks
167
+ # gitignore template for Jupyter Notebooks
168
+ # website: http://jupyter.org/
169
+
170
+ .ipynb_checkpoints
171
+ */.ipynb_checkpoints/*
172
+
173
+ # IPython
174
+ profile_default/
175
+ ipython_config.py
176
+
177
+ # User data
178
+ configs/
179
+ data/
180
+ notebooks/
181
+ output/
182
+ flagged/
183
+ *.pth
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "gaussiancity/extensions/diff_gaussian_rasterization/third_party/glm"]
2
+ path = gaussiancity/extensions/diff_gaussian_rasterization/third_party/glm
3
+ url = https://github.com/g-truc/glm.git
ARTICLE.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Citation 📝
2
+
3
+ If our work is useful for your research, please consider citing:
4
+
5
+ ```bibtex
6
+ @inproceedings{xie2025gaussiancity,
7
+ title = {Generative Gaussian Splatting for Unbounded 3{D} City Generation},
8
+ author = {Xie, Haozhe and
9
+ Chen, Zhaoxi and
10
+ Hong, Fangzhou and
11
+ Liu, Ziwei},
12
+ booktitle = {CVPR},
13
+ year = {2025}
14
+ }
15
+ ```
16
+
17
+ ### License 📋
18
+
19
+ This project is licensed under [S-Lab License 1.0](https://huggingface.co/hzxie/city-dreamer/blob/main/LICENSE).
20
+ Redistribution and use for non-commercial purposes should follow this license.
21
+
22
+ ![Counter](https://api.infinitescript.com/badgen/count?name=hzxie/CityDreamer&ltext=Visitors&color=f97316)
23
+
24
+ ---
25
+
LICENSE ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ S-Lab License 1.0
2
+
3
+ Copyright 2025 S-Lab
4
+
5
+ Redistribution and use for non-commercial purpose in source and
6
+ binary forms, with or without modification, are permitted provided
7
+ that the following conditions are met:
8
+
9
+ 1. Redistributions of source code must retain the above copyright
10
+ notice, this list of conditions and the following disclaimer.
11
+
12
+ 2. Redistributions in binary form must reproduce the above copyright
13
+ notice, this list of conditions and the following disclaimer in
14
+ the documentation and/or other materials provided with the
15
+ distribution.
16
+
17
+ 3. Neither the name of the copyright holder nor the names of its
18
+ contributors may be used to endorse or promote products derived
19
+ from this software without specific prior written permission.
20
+
21
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
+
33
+ In the event that redistribution and/or use for commercial purpose in
34
+ source or binary forms, with or without modification is required,
35
+ please contact the contributor(s) of the work.
README.md ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: GaussianCity
3
+ emoji: 🏙️
4
+ colorFrom: red
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 5.20.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Official demo for **[Generative Gaussian Splatting for Unbounded 3D City Generation](https://github.com/hzxie/GaussianCity) (CVPR 2025).**
13
+
14
+ - 🔥 GaussianCity is an unbounded 3D city generator based on 3D Gaussian Splatting.
15
+ - 🤗 Try GaussianCity to generate photorealistic 3D cities.
16
+ - ⚠️ Due to the limited computational resources at Hugging Face, this demo only generates **A SINGLE IMAGE** based on the New York City layout.
17
+
app.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ # @File: app.py
4
+ # @Author: Haozhe Xie
5
+ # @Date: 2024-03-02 16:30:00
6
+ # @Last Modified by: Haozhe Xie
7
+ # @Last Modified at: 2024-10-13 15:36:50
8
+ # @Email: root@haozhexie.com
9
+
10
+ import gradio as gr
11
+ import logging
12
+ import numpy as np
13
+ import os
14
+ import pickle
15
+ import ssl
16
+ import subprocess
17
+ import sys
18
+ import urllib.request
19
+
20
+ from PIL import Image
21
+
22
+ # Reinstall PyTorch with CUDA 11.8 (Default version is 12.1)
23
+ # subprocess.call(
24
+ # [
25
+ # "pip",
26
+ # "install",
27
+ # "torch==2.2.2",
28
+ # "torchvision==0.17.2",
29
+ # "--index-url",
30
+ # "https://download.pytorch.org/whl/cu118",
31
+ # ]
32
+ # )
33
+ import torch
34
+
35
# ZeroGPU Spaces ship the real `spaces` package. Everywhere else a minimal
# stand-in is defined so that `@spaces.GPU` can be applied unconditionally.
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    import functools

    class spaces:
        """Stand-in for the Hugging Face ``spaces`` package outside ZeroGPU."""

        @staticmethod
        def GPU(func=None, **_options):
            """No-op replacement for the ``spaces.GPU`` decorator.

            Usable both as ``@spaces.GPU`` and as ``@spaces.GPU(duration=...)``;
            keyword options are accepted and ignored.

            Args:
                func: The function to decorate, or ``None`` when the decorator
                    is called with keyword options only.

            Returns:
                A wrapper that simply calls ``func`` (or, when ``func`` is
                ``None``, a decorator producing such a wrapper).
            """

            def decorate(target):
                # functools.wraps keeps the name/docstring of the wrapped
                # function, which helps logging and debugging.
                @functools.wraps(target)
                def wrapper(*args, **kwargs):
                    return target(*args, **kwargs)

                return wrapper

            return decorate if func is None else decorate(func)
48
+
49
+
50
+ # Workaround for ssl.SSLCertVerificationError ([SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed). NOTE: this disables HTTPS certificate verification for the whole process.
51
+ ssl._create_default_https_context = ssl._create_unverified_context
52
+ # Import GaussianCity modules
53
+ sys.path.append(os.path.join(os.path.dirname(__file__), "gaussiancity"))
54
+
55
+
56
def _get_output(cmd):
    """Run ``cmd`` and return its standard output decoded as UTF-8.

    Args:
        cmd: The command and its arguments as a list; it is passed to
            ``subprocess.check_output`` without a shell.

    Returns:
        The decoded output, or ``None`` when the command could not be started,
        exited with a non-zero status, or produced output that is not UTF-8.
    """
    try:
        return subprocess.check_output(cmd).decode("utf-8")
    except Exception:
        # Deliberately best effort: a ``None`` result is how a failed or
        # missing command is reported. Log which command failed together
        # with the traceback.
        logging.exception("Failed to get the output of %s", cmd)

    return None
63
+
64
+
65
def install_cuda_toolkit():
    """Download and run the CUDA toolkit installer, then export its paths.

    The installer is fetched with ``wget`` into ``/tmp``, made executable and
    run with ``--silent --toolkit``. Afterwards ``CUDA_HOME``, ``PATH``,
    ``LD_LIBRARY_PATH`` and ``TORCH_CUDA_ARCH_LIST`` are set in ``os.environ``.

    A step that exits with a non-zero status is logged as an error; the
    remaining steps still run, as before.
    """
    # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
    steps = (
        ["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE],
        ["chmod", "+x", CUDA_TOOLKIT_FILE],
        [CUDA_TOOLKIT_FILE, "--silent", "--toolkit"],
    )
    for step in steps:
        status = subprocess.call(step)
        if status != 0:
            # Previously ignored; surface it so a broken install is visible.
            logging.error("Command %s exited with status %d", step, status)

    cuda_home = "/usr/local/cuda"
    os.environ["CUDA_HOME"] = cuda_home
    os.environ["PATH"] = "%s/bin:%s" % (cuda_home, os.environ["PATH"])
    # Only append the previous value when there is one: a trailing ":" would
    # add an empty entry, which the dynamic loader treats as the current
    # directory.
    previous_ld_path = os.environ.get("LD_LIBRARY_PATH", "")
    os.environ["LD_LIBRARY_PATH"] = ":".join(
        entry for entry in ("%s/lib" % cuda_home, previous_ld_path) if entry
    )
    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
81
+
82
+
83
def setup_runtime_env():
    """Log toolchain details and build the bundled CUDA extensions.

    Logs the Python, nvcc and gcc versions and the CUDA status, then runs
    ``pip install .`` inside every sub-directory of
    ``gaussiancity/extensions`` and finally logs the installed packages.
    """
    logging.info("Python Version: %s" % _get_output(["python", "--version"]))
    logging.info("CUDA Version: %s" % _get_output(["nvcc", "--version"]))
    logging.info("GCC Version: %s" % _get_output(["gcc", "--version"]))
    cuda_available = torch.cuda.is_available()
    logging.info("CUDA is available: %s" % cuda_available)
    if cuda_available:
        # Querying the capability raises when no CUDA device is usable.
        logging.info(
            "CUDA Device Capability: %s" % (torch.cuda.get_device_capability(),)
        )

    # Install Pre-compiled CUDA extensions (Not working)
    # Ref: https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/110
    #
    # ext_dir = os.path.join(os.path.dirname(__file__), "wheels")
    # for e in os.listdir(ext_dir):
    #     logging.info("Installing Extensions from %s" % e)
    #     subprocess.call(
    #         ["pip", "install", os.path.join(ext_dir, e)], stderr=subprocess.STDOUT
    #     )
    # Compile CUDA extensions
    ext_dir = os.path.join(os.path.dirname(__file__), "gaussiancity", "extensions")
    # Use the running interpreter's pip so the extensions are installed into
    # the environment that will import them.
    pip = [sys.executable, "-m", "pip"]
    for e in os.listdir(ext_dir):
        ext_path = os.path.join(ext_dir, e)
        if not os.path.isdir(ext_path):
            continue

        status = subprocess.call(pip + ["install", "."], cwd=ext_path)
        if status != 0:
            logging.error("Building extension %s failed with status %d", e, status)

    logging.info("Installed Python Packages: %s" % _get_output(pip + ["list"]))
106
+
107
+
108
def get_models(file_name):
    """Load a pretrained generator, downloading the checkpoint when missing.

    Args:
        file_name: Checkpoint file name. It is used both as the local path and
            as the file name inside the ``hzxie/gaussian-city`` repository on
            Hugging Face.

    Returns:
        The generator in evaluation mode with the ``gaussian_g`` weights
        loaded; wrapped in ``torch.nn.DataParallel`` on CUDA.
    """
    import gaussiancity.generator

    if not os.path.exists(file_name):
        # Download under a temporary name first so that an interrupted
        # transfer is not mistaken for a complete checkpoint next time.
        partial_name = file_name + ".part"
        urllib.request.urlretrieve(
            "https://huggingface.co/hzxie/gaussian-city/resolve/main/%s" % file_name,
            partial_name,
        )
        os.replace(partial_name, file_name)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(security): weights_only=False unpickles arbitrary Python objects
    # (the checkpoint stores a "cfg" object next to the weights). Only load
    # checkpoints from a trusted source.
    ckpt = torch.load(file_name, map_location=torch.device(device), weights_only=False)
    cfg = ckpt["cfg"]
    model = gaussiancity.generator.Generator(
        cfg.NETWORK.GAUSSIAN,
        n_classes=cfg.DATASETS.GOOGLE_EARTH.N_CLASSES,
        proj_size=cfg.DATASETS.GOOGLE_EARTH.PROJ_SIZE,
    )
    state_dict = ckpt["gaussian_g"]
    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model).cuda()
    else:
        # NOTE(review): the weights are loaded into a DataParallel wrapper on
        # CUDA, so the keys presumably carry a "module." prefix. Strip it for
        # the unwrapped model; otherwise strict=False would silently skip
        # every weight. Keys without the prefix are left untouched.
        prefix = "module."
        state_dict = {
            (key[len(prefix) :] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    model.load_state_dict(state_dict, strict=False)
    # Inference only: switch to evaluation mode on every device.
    return model.eval()
129
+
130
+
131
def get_city_layout():
    """Return the New York City layout, building and caching it when needed.

    The layout dictionary (keys ``TD_HF``, ``BU_HF``, ``SEG``, ``INS``,
    ``PTS``) is read from ``assets/NYC.pkl`` when that file exists; otherwise
    it is derived from the height-field and segmentation-map images and
    written to that file. The centers are handled the same way through
    ``assets/CENTERS.pkl`` and stored under the ``CTR`` key.
    """
    import gaussiancity.inference

    if not os.path.exists("assets/NYC.pkl"):
        height_field = np.array(Image.open("assets/NYC-HghtFld.png")).astype(np.int32)
        # Fix: nonzero is not supported for tensors with more than INT_MAX elements
        height_field[height_field > 500] = 500
        bottom_up = np.zeros_like(height_field)
        segmentation = np.array(
            Image.open("assets/NYC-SegMap.png").convert("P")
        ).astype(np.int32)
        # The instance map is computed from a copy of the segmentation map;
        # the point map receives the original array.
        instances = gaussiancity.inference.get_instance_seg_map(segmentation.copy())
        points = gaussiancity.inference.get_point_map(segmentation)
        layout = {
            "TD_HF": height_field,
            "BU_HF": bottom_up,
            "SEG": segmentation,
            "INS": instances,
            "PTS": points,
        }
        with open("assets/NYC.pkl", "wb") as fp:
            pickle.dump(layout, fp)
    else:
        with open("assets/NYC.pkl", "rb") as fp:
            layout = pickle.load(fp)

    if not os.path.exists("assets/CENTERS.pkl"):
        centers = gaussiancity.inference.get_centers(layout["INS"], layout["TD_HF"])
        with open("assets/CENTERS.pkl", "wb") as fp:
            pickle.dump(centers, fp)
    else:
        with open("assets/CENTERS.pkl", "rb") as fp:
            centers = pickle.load(fp)

    layout["CTR"] = centers
    return layout
169
+
170
+
171
@spaces.GPU
def get_generated_city(radius, altitude, azimuth, map_center):
    """Generate a city image for the given camera settings.

    The models and the city layout are read from the ``fgm``, ``bgm`` and
    ``city_layout`` attributes attached to this function at start-up.

    Args:
        radius: Camera radius taken from the UI slider.
        altitude: Camera altitude taken from the UI slider.
        azimuth: Camera azimuth taken from the UI slider.
        map_center: Map center taken from the UI slider; it is passed twice
            to ``generate_city``.

    Returns:
        Whatever ``gaussiancity.inference.generate_city`` returns.
    """
    logging.info("CUDA is available: %s" % torch.cuda.is_available())
    logging.info("PyTorch is built with CUDA: %s" % torch.version.cuda)
    # The import must be done after CUDA extension compilation
    import gaussiancity.inference

    generate = gaussiancity.inference.generate_city
    shared = get_generated_city
    foreground = shared.fgm.to("cuda")
    background = shared.bgm.to("cuda")
    return generate(
        foreground,
        background,
        shared.city_layout,
        map_center,
        map_center,
        radius,
        altitude,
        azimuth,
    )
188
+
189
+
190
def main(debug):
    """Build the Gradio interface and launch it.

    The description is the part of ``README.md`` after the last ``---``
    marker and the article is the content of ``ARTICLE.md``.

    Args:
        debug: Passed through to ``launch`` as its ``debug`` argument.
    """
    title = "Generative Gaussian Splatting for Unbounded 3D City Generation"
    # Both files contain emoji, so read them as UTF-8 regardless of the locale.
    with open("README.md", "r", encoding="utf-8") as f:
        markdown = f.read()
    # Without a marker use the whole file; slicing from rfind() + 3 would
    # otherwise drop the first two characters.
    marker = markdown.rfind("---")
    desc = markdown if marker == -1 else markdown[marker + 3 :]
    with open("ARTICLE.md", "r", encoding="utf-8") as f:
        arti = f.read()

    app = gr.Interface(
        get_generated_city,
        [
            gr.Slider(256, 768, value=512, step=4, label="Camera Radius (m)"),
            gr.Slider(256, 768, value=512, step=4, label="Camera Altitude (m)"),
            gr.Slider(0, 360, value=60, step=5, label="Camera Azimuth (°)"),
            gr.Slider(1024, 7168, value=3570, step=4, label="Map Center (px)"),
        ],
        [gr.Image(type="numpy", label="Generated City")],
        title=title,
        description=desc,
        article=arti,
        flagging_mode="never",
    )
    app.queue(api_open=False)
    app.launch(debug=debug)
214
+
215
+
216
if __name__ == "__main__":
    # Start-up sequence: configure logging, make sure a CUDA toolkit is
    # present, load both models and the city layout, then start the web app.
    logging.basicConfig(
        format="[%(levelname)s] %(asctime)s %(message)s", level=logging.INFO
    )
    # Log the variable names only: the values may contain credentials (for
    # example access tokens) that must not end up in the logs.
    logging.info("Environment Variables: %s" % sorted(os.environ))
    # Query nvcc once and reuse the result for both detection and logging.
    nvcc_version = _get_output(["nvcc", "--version"])
    if nvcc_version is None:
        logging.info("Installing CUDA toolkit...")
        install_cuda_toolkit()
    else:
        logging.info("Detected CUDA: %s" % nvcc_version)

    logging.info("Compiling CUDA extensions...")
    # NOTE: the compilation step is currently disabled.
    # setup_runtime_env()

    logging.info("Downloading pretrained models...")
    fgm = get_models("GaussianCity-Fgnd.pth")
    bgm = get_models("GaussianCity-Bgnd.pth")
    # The request handler reads its models and layout from these attributes.
    get_generated_city.fgm = fgm
    get_generated_city.bgm = bgm

    logging.info("Loading New York city layout to RAM...")
    city_layout = get_city_layout()
    get_generated_city.city_layout = city_layout

    logging.info("Starting the main application...")
    main(os.getenv("DEBUG") == "1")
assets/CENTERS.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cad871bfb3997485a6d464c1464c2a551b601a4444913b9ec808530d093eefd8
3
+ size 728474
assets/NYC-HghtFld.png ADDED

Git LFS Details

  • SHA256: 51bcb2d4b097e1307e254427dbf8ec05772ff8e833a5d53993c7380188214ba9
  • Pointer size: 132 Bytes
  • Size of remote file: 5.29 MB
assets/NYC-SegMap.png ADDED

Git LFS Details

  • SHA256: 0e6f34f802829f97462885ab3a07b7720d7eb18bf15f810e683c89c8d53c3b6d
  • Pointer size: 132 Bytes
  • Size of remote file: 3 MB
gaussiancity/__init__.py ADDED
File without changes
gaussiancity/extensions/__init__.py ADDED
File without changes
gaussiancity/extensions/diff_gaussian_rasterization/CMakeLists.txt ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ cmake_minimum_required(VERSION 3.20)
13
+
14
+ project(DiffRast LANGUAGES CUDA CXX)
15
+
16
+ set(CMAKE_CXX_STANDARD 17)
17
+ set(CMAKE_CXX_EXTENSIONS OFF)
18
+ set(CMAKE_CUDA_STANDARD 17)
19
+
20
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
21
+
22
+ add_library(CudaRasterizer
23
+ cuda_rasterizer/backward.h
24
+ cuda_rasterizer/backward.cu
25
+ cuda_rasterizer/forward.h
26
+ cuda_rasterizer/forward.cu
27
+ cuda_rasterizer/auxiliary.h
28
+ cuda_rasterizer/rasterizer_impl.cu
29
+ cuda_rasterizer/rasterizer_impl.h
30
+ cuda_rasterizer/rasterizer.h
31
+ )
32
+
33
+ set_target_properties(CudaRasterizer PROPERTIES CUDA_ARCHITECTURES "70;75;86")
34
+
35
+ target_include_directories(CudaRasterizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rasterizer)
36
+ target_include_directories(CudaRasterizer PRIVATE third_party/glm ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
gaussiancity/extensions/diff_gaussian_rasterization/LICENSE.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Gaussian-Splatting License
2
+ ===========================
3
+
4
+ **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
5
+ The *Software* is in the process of being registered with the Agence pour la Protection des
6
+ Programmes (APP).
7
+
8
+ The *Software* is still being developed by the *Licensor*.
9
+
10
+ *Licensor*'s goal is to allow the research community to use, test and evaluate
11
+ the *Software*.
12
+
13
+ ## 1. Definitions
14
+
15
+ *Licensee* means any person or entity that uses the *Software* and distributes
16
+ its *Work*.
17
+
18
+ *Licensor* means the owners of the *Software*, i.e Inria and MPII
19
+
20
+ *Software* means the original work of authorship made available under this
21
+ License ie gaussian-splatting.
22
+
23
+ *Work* means the *Software* and any additions to or derivative works of the
24
+ *Software* that are made available under this License.
25
+
26
+
27
+ ## 2. Purpose
28
+ This license is intended to define the rights granted to the *Licensee* by
29
+ Licensors under the *Software*.
30
+
31
+ ## 3. Rights granted
32
+
33
+ For the above reasons Licensors have decided to distribute the *Software*.
34
+ Licensors grant non-exclusive rights to use the *Software* for research purposes
35
+ to research users (both academic and industrial), free of charge, without right
36
+ to sublicense.. The *Software* may be used "non-commercially", i.e., for research
37
+ and/or evaluation purposes only.
38
+
39
+ Subject to the terms and conditions of this License, you are granted a
40
+ non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
41
+ publicly display, publicly perform and distribute its *Work* and any resulting
42
+ derivative works in any form.
43
+
44
+ ## 4. Limitations
45
+
46
+ **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
47
+ so under this License, (b) you include a complete copy of this License with
48
+ your distribution, and (c) you retain without modification any copyright,
49
+ patent, trademark, or attribution notices that are present in the *Work*.
50
+
51
+ **4.2 Derivative Works.** You may specify that additional or different terms apply
52
+ to the use, reproduction, and distribution of your derivative works of the *Work*
53
+ ("Your Terms") only if (a) Your Terms provide that the use limitation in
54
+ Section 2 applies to your derivative works, and (b) you identify the specific
55
+ derivative works that are subject to Your Terms. Notwithstanding Your Terms,
56
+ this License (including the redistribution requirements in Section 3.1) will
57
+ continue to apply to the *Work* itself.
58
+
59
+ **4.3** Any other use without of prior consent of Licensors is prohibited. Research
60
+ users explicitly acknowledge having received from Licensors all information
61
+ allowing to appreciate the adequacy between of the *Software* and their needs and
62
+ to undertake all necessary precautions for its execution and use.
63
+
64
+ **4.4** The *Software* is provided both as a compiled library file and as source
65
+ code. In case of using the *Software* for a publication or other results obtained
66
+ through the use of the *Software*, users are strongly encouraged to cite the
67
+ corresponding publications as explained in the documentation of the *Software*.
68
+
69
+ ## 5. Disclaimer
70
+
71
+ THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
72
+ WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
73
+ UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL
74
+ CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
75
+ OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
76
+ USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
77
+ ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
78
+ AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
79
+ CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
80
+ GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
81
+ HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
82
+ LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
83
+ IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
gaussiancity/extensions/diff_gaussian_rasterization/__init__.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ # @File: __init__.py
4
+ # @Author: Inria <george.drettakis@inria.fr>
5
+ # @Date: 2024-01-31 19:07:01
6
+ # @Last Modified by: Haozhe Xie
7
+ # @Last Modified at: 2024-05-01 14:14:49
8
+ # @Email: root@haozhexie.com
9
+
10
+ import math
11
+ import numpy as np
12
+ import scipy.spatial.transform
13
+ import torch
14
+ import typing
15
+
16
+ import diff_gaussian_rasterization_ext as dgr_ext
17
+
18
+
19
class RasterizeGaussiansFunction(torch.autograd.Function):
    """Autograd bridge to the compiled Gaussian rasterizer extension.

    ``forward`` reorders its inputs into the positional layout taken by
    ``dgr_ext.rasterize_gaussians``; ``backward`` does the same for
    ``dgr_ext.rasterize_gaussians_backward`` and returns one gradient slot
    per ``forward`` input (``None`` for ``raster_settings``).
    """

    @staticmethod
    def _cpu_deep_copy_tuple(input_tuple):
        """Return a tuple where every tensor is replaced by a CPU clone.

        Entries that are not ``torch.Tensor`` instances are passed through
        unchanged.
        """
        return tuple(
            entry.cpu().clone() if isinstance(entry, torch.Tensor) else entry
            for entry in input_tuple
        )

    @staticmethod
    def forward(
        ctx,
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        raster_settings,
    ):
        """Rasterize the Gaussians and return ``(color, radii)``.

        ``means2D`` is not handed to the extension; it only receives a
        gradient slot in ``backward``.
        """
        # Positional layout expected by the C++ lib.
        ext_args = (
            raster_settings.bg,
            means3D,
            colors_precomp,
            opacities,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.view_matrix,
            raster_settings.proj_matrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            raster_settings.img_h,
            raster_settings.img_w,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            raster_settings.prefiltered,
            raster_settings.debug,
        )

        debug = raster_settings.debug
        snapshot = None
        if debug:
            # Copy them before they can be corrupted
            snapshot = RasterizeGaussiansFunction._cpu_deep_copy_tuple(
                input_tuple=ext_args
            )

        # Invoke C++/CUDA rasterizer
        try:
            outputs = dgr_ext.rasterize_gaussians(*ext_args)
            (
                num_rendered,
                color,
                radii,
                geom_buffer,
                binning_buffer,
                img_buffer,
            ) = outputs
        except Exception as ex:
            if not debug:
                raise
            torch.save(snapshot, "snapshot_fw.dump")
            print(
                "\nAn error occured in forward. Please forward snapshot_fw.dump for debugging."
            )
            raise ex

        # Keep relevant tensors for backward
        ctx.raster_settings = raster_settings
        ctx.num_rendered = num_rendered
        ctx.save_for_backward(
            colors_precomp,
            means3D,
            scales,
            rotations,
            cov3Ds_precomp,
            radii,
            sh,
            geom_buffer,
            binning_buffer,
            img_buffer,
        )
        return color, radii

    @staticmethod
    def backward(ctx, grad_out_color, _):
        """Return gradients for every ``forward`` input, in ``forward`` order.

        The gradient arriving for the ``radii`` output (second argument) is
        ignored.
        """
        # Restore necessary values from context
        num_rendered = ctx.num_rendered
        raster_settings = ctx.raster_settings
        (
            colors_precomp,
            means3D,
            scales,
            rotations,
            cov3Ds_precomp,
            radii,
            sh,
            geom_buffer,
            binning_buffer,
            img_buffer,
        ) = ctx.saved_tensors

        # Positional layout expected by the C++ backward method.
        ext_args = (
            raster_settings.bg,
            means3D,
            radii,
            colors_precomp,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.view_matrix,
            raster_settings.proj_matrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            grad_out_color,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            geom_buffer,
            num_rendered,
            binning_buffer,
            img_buffer,
            raster_settings.debug,
        )

        debug = raster_settings.debug
        snapshot = None
        if debug:
            # Copy them before they can be corrupted
            snapshot = RasterizeGaussiansFunction._cpu_deep_copy_tuple(
                input_tuple=ext_args
            )

        # Compute gradients for relevant tensors by invoking backward method
        try:
            outputs = dgr_ext.rasterize_gaussians_backward(*ext_args)
            (
                grad_means2D,
                grad_colors_precomp,
                grad_opacities,
                grad_means3D,
                grad_cov3Ds_precomp,
                grad_sh,
                grad_scales,
                grad_rotations,
            ) = outputs
        except Exception as ex:
            if not debug:
                raise
            torch.save(snapshot, "snapshot_bw.dump")
            print(
                "\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n"
            )
            raise ex

        # One slot per forward() argument after ctx; raster_settings gets None.
        return (
            grad_means3D,
            grad_means2D,
            grad_sh,
            grad_colors_precomp,
            grad_opacities,
            grad_scales,
            grad_rotations,
            grad_cov3Ds_precomp,
            None,
        )
201
+
202
+
203
class GaussianRasterizationSettings(typing.NamedTuple):
    """Per-view settings that RasterizeGaussiansFunction forwards to the extension.

    Every field is read by ``RasterizeGaussiansFunction.forward`` and passed
    positionally to ``dgr_ext.rasterize_gaussians``; most are passed again to
    the backward call.
    """

    # Output image height and width; GaussianRasterizerWrapper fills these from
    # sensor_size[1] and sensor_size[0] respectively.
    img_h: int
    img_w: int
    # GaussianRasterizerWrapper sets these to tan(fov_x * 0.5) / tan(fov_y * 0.5).
    tanfovx: float
    tanfovy: float
    # Background colour tensor (the wrapper passes a 3-element float32 tensor).
    bg: torch.Tensor
    # Scale multiplier handed to the extension (the wrapper passes 1.0).
    scale_modifier: float
    # The wrapper passes the transposed world-to-camera matrix here ...
    view_matrix: torch.Tensor
    # ... and that matrix multiplied by the transposed projection matrix here.
    proj_matrix: torch.Tensor
    # Spherical-harmonics degree handed to the extension (the wrapper passes 0).
    sh_degree: int
    # Camera position tensor (the wrapper takes it from the inverse view matrix).
    campos: torch.Tensor
    # Forwarded to the extension's forward call only.
    # NOTE(review): presumably signals that points were already frustum-culled - confirm.
    prefiltered: bool
    # When true, RasterizeGaussiansFunction snapshots its arguments and dumps
    # them to disk if the extension raises; also forwarded to the extension.
    debug: bool
216
+
217
+
218
class GaussianRasterizer(torch.nn.Module):
    """Module front-end that validates inputs and calls RasterizeGaussiansFunction.

    Args:
        raster_settings: Settings object stored on the module and passed
            unchanged to ``RasterizeGaussiansFunction.apply`` on every call.
    """

    def __init__(self, raster_settings):
        super(GaussianRasterizer, self).__init__()
        self.raster_settings = raster_settings

    def forward(
        self,
        means3D,
        means2D,
        opacities,
        shs=None,
        colors_precomp=None,
        scales=None,
        rotations=None,
        cov3D_precomp=None,
    ):
        """Rasterize the given Gaussians.

        Exactly one of ``shs`` / ``colors_precomp`` must be given, and exactly
        one of the ``scales`` + ``rotations`` pair / ``cov3D_precomp``.

        Returns:
            Whatever ``RasterizeGaussiansFunction.apply`` returns
            (``(color, radii)``).

        Raises:
            ValueError: If the colour or covariance inputs are missing or
                given redundantly. ``ValueError`` derives from ``Exception``,
                so existing ``except Exception`` handlers keep working.
        """
        raster_settings = self.raster_settings

        # Both missing or both present is an error: exactly one is required.
        if (shs is None) == (colors_precomp is None):
            raise ValueError(
                "Please provide exactly one of either SHs or precomputed colors!"
            )

        has_cov = cov3D_precomp is not None
        has_full_pair = scales is not None and rotations is not None
        has_partial_pair = scales is not None or rotations is not None
        if (not has_full_pair and not has_cov) or (has_partial_pair and has_cov):
            raise ValueError(
                "Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!"
            )

        # The extension takes every argument positionally, so omitted inputs
        # are replaced by empty tensors.
        if shs is None:
            shs = torch.Tensor([])
        if colors_precomp is None:
            colors_precomp = torch.Tensor([])

        if scales is None:
            scales = torch.Tensor([])
        if rotations is None:
            rotations = torch.Tensor([])
        if cov3D_precomp is None:
            cov3D_precomp = torch.Tensor([])

        # Invoke C++/CUDA rasterization routine
        return RasterizeGaussiansFunction.apply(
            means3D,
            means2D,
            shs,
            colors_precomp,
            opacities,
            scales,
            rotations,
            cov3D_precomp,
            raster_settings,
        )
274
+
275
+
276
class GaussianRasterizerWrapper(torch.nn.Module):
    """Wrapper that drives GaussianRasterizer from a camera intrinsic and pose.

    This class is a wrapper for the GaussianRasterizer class. It is used to
    port the rasterizer to the GaussianCity project.

    Args:
        K: Camera intrinsic matrix, indexed as ``K[0, 0]`` (fx), ``K[1, 1]``
            (fy), ``K[0, 2]`` (cx) and ``K[1, 2]`` (cy).
        sensor_size: ``(width, height)`` of the rendered image.
        flip_lr: Flip the rendered image along dim 2 before returning it.
        flip_ud: Flip the rendered image along dim 1 before returning it.
        z_near: Near plane used in the projection matrix.
        z_far: Far plane used in the projection matrix.
        device: Device on which the matrices and helper tensors are created.
    """

    def __init__(
        self,
        K,
        sensor_size,
        flip_lr=True,
        flip_ud=False,
        z_near=0.01,
        z_far=50000.0,
        device=torch.device("cuda"),
    ):
        super(GaussianRasterizerWrapper, self).__init__()
        self.flip_lr = flip_lr
        self.flip_ud = flip_ud
        self.z_near = z_near
        self.z_far = z_far
        self.device = device
        # Shared camera parameters
        self.K = K
        self.sensor_size = sensor_size
        self.fov_x, self.fov_y = self._intrinsic_to_fov()
        self.P = self._get_projection_matrix()

    def get_gaussian_rasterizer(self, cam_position, cam_quaternion):
        """Build a GaussianRasterizer for one camera pose.

        Args:
            cam_position: Camera position ``(tx, ty, tz)``.
            cam_quaternion: Camera orientation ``(qx, qy, qz, qw)``.
        """
        return GaussianRasterizer(
            raster_settings=self._get_gaussian_rasterization_settings(
                cam_position, cam_quaternion
            )
        )

    def forward(
        self, points, cam_position=None, cam_quaternion=None, gaussian_rasterizer=None
    ):
        """Render ``points`` and return the rendered image.

        Args:
            points: ``[N, 14]`` tensor; columns are 0:3 xyz, 3:4 opacity,
                4:7 scale, 7:11 rotation, 11:14 rgb.
            cam_position: Camera position; required when no rasterizer is given.
            cam_quaternion: Camera quaternion; required when no rasterizer is
                given.
            gaussian_rasterizer: Optional rasterizer previously built with
                ``get_gaussian_rasterizer``; when given, the pose arguments
                are ignored.

        Raises:
            ValueError: If neither a rasterizer nor a full camera pose is given.
        """
        _, M = points.shape
        assert M == 14, "The input tensor should have 14 channels."

        if gaussian_rasterizer is None:
            # Fail with a clear message instead of an obscure error inside
            # the pose conversion.
            if cam_position is None or cam_quaternion is None:
                raise ValueError(
                    "Either gaussian_rasterizer or both cam_position and "
                    "cam_quaternion must be provided."
                )
            gaussian_rasterizer = self.get_gaussian_rasterizer(
                cam_position, cam_quaternion
            )

        return self._get_gaussian_rasterization(points, gaussian_rasterizer)

    def _intrinsic_to_fov(self):
        """Return ``(fov_x, fov_y)`` in radians derived from K and sensor_size."""
        # graphdeco-inria/gaussian-splatting/utils/graphics_utils.py#L76
        fx, fy = self.K[0, 0], self.K[1, 1]
        fov_x = 2 * np.arctan2(self.sensor_size[0], (2 * fx))
        fov_y = 2 * np.arctan2(self.sensor_size[1], (2 * fy))
        return fov_x, fov_y

    def _get_projection_matrix(self):
        """Return the 4x4 float32 projection matrix on ``self.device``."""
        fx = self.K[0, 0]
        fy = self.K[1, 1]
        cx = self.K[0, 2]
        cy = self.K[1, 2]

        P = np.zeros((4, 4), dtype=np.float32)
        P[0, 0] = 2.0 * fx / self.sensor_size[0]
        P[1, 1] = 2.0 * fy / self.sensor_size[1]
        P[0, 2] = (2.0 * cx / self.sensor_size[0]) - 1.0
        P[1, 2] = (2.0 * cy / self.sensor_size[1]) - 1.0
        P[2, 2] = -(self.z_far + self.z_near) / (self.z_far - self.z_near)
        P[3, 2] = -1.0
        P[2, 3] = -2.0 * self.z_far * self.z_near / (self.z_far - self.z_near)
        return torch.from_numpy(P).to(self.device)

    def _get_w2c_matrix(self, cam_position, cam_quaternion):
        """Return the 4x4 float32 world-to-camera matrix on ``self.device``.

        Args:
            cam_position: Length-3 position as tensor, ndarray or sequence.
            cam_quaternion: ``(qx, qy, qz, qw)`` as tensor, ndarray or sequence.
        """
        # isinstance also covers Tensor subclasses (e.g. nn.Parameter), and
        # detach() is required because numpy() refuses tensors that need grad.
        if isinstance(cam_position, torch.Tensor):
            cam_position = cam_position.detach().cpu().numpy()
        if isinstance(cam_quaternion, torch.Tensor):
            cam_quaternion = cam_quaternion.detach().cpu().numpy()
        # Accept plain sequences too; ndarrays pass through unchanged.
        cam_position = np.asarray(cam_position)

        R = scipy.spatial.transform.Rotation.from_quat(cam_quaternion).as_matrix()
        # look_at = cam_position + R[:3, 0]
        R = R[:, [1, 2, 0]]  # [F|R|U] -> [R|U|F]
        # graphdeco-inria/gaussian-splatting/blob/main/scene/cameras.py#L31
        # The w2c matrix: rotation R^T and translation -R^T @ position.
        Rt = np.zeros((4, 4), dtype=np.float32)
        Rt[:3, :3] = R.transpose()
        Rt[:3, [3]] = -R.transpose() @ cam_position[:, None]
        Rt[3, 3] = 1.0
        return torch.from_numpy(Rt).to(self.device)

    def _world_to_pixel(self, world_coords, w2c):
        """Project a world-space point to integer pixel coordinates.

        NOTE: The function is used to debug whether the w2c matrix is correct.
        """
        # Convert world coordinates to camera coordinates by applying w2c
        # directly (rotation block, then translation column).
        camera_coords = np.dot(w2c[:3, :3], world_coords) + w2c[:3, 3]
        # Apply the camera intrinsic matrix K to obtain homogeneous image coordinates
        homogeneous_coords = np.dot(self.K, camera_coords)
        # Normalize homogeneous coordinates
        normalized_coords = homogeneous_coords / homogeneous_coords[2]
        # Convert normalized coordinates to pixel coordinates
        return normalized_coords[:2].astype(int)

    def _get_gaussian_rasterization_settings(self, cam_position, cam_quaternion):
        """Assemble the GaussianRasterizationSettings for one camera pose."""
        bg_color = torch.tensor(
            [0.0, 0.0, 0.0], dtype=torch.float32, device=self.device
        )
        # Both matrices are handed over transposed.
        w2c = self._get_w2c_matrix(cam_position, cam_quaternion).transpose(0, 1)
        prj_mtx = self.P.transpose(0, 1)

        return GaussianRasterizationSettings(
            img_h=self.sensor_size[1],
            img_w=self.sensor_size[0],
            tanfovx=math.tan(self.fov_x * 0.5),
            tanfovy=math.tan(self.fov_y * 0.5),
            bg=bg_color,
            scale_modifier=1.0,
            view_matrix=w2c,
            proj_matrix=w2c @ prj_mtx,
            sh_degree=0,
            # Last row of the inverted (transposed) view matrix.
            campos=w2c.inverse()[3, :3],
            prefiltered=False,
            debug=False,
        )

    def _get_gaussian_rasterization(self, points, rasterizer):
        """Split ``points`` into attributes, rasterize, and apply the flips."""
        xyz = points[:, 0:3]
        opacity = points[:, 3:4]
        scales = points[:, 4:7]
        quaternion = points[:, 7:11]
        rgbs = points[:, 11:]

        rendered_image, _ = rasterizer(
            means3D=xyz,
            means2D=torch.zeros_like(xyz, dtype=torch.float32, device=self.device),
            shs=None,
            colors_precomp=rgbs,
            opacities=opacity,
            scales=scales,
            rotations=quaternion,
            cov3D_precomp=None,
        )
        if self.flip_lr:
            rendered_image = torch.flip(rendered_image, dims=[2])
        if self.flip_ud:
            rendered_image = torch.flip(rendered_image, dims=[1])

        return rendered_image
gaussiancity/extensions/diff_gaussian_rasterization/bindings.cpp ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact george.drettakis@inria.fr
10
+ */
11
+
12
+ #include "rasterize_points.h"
13
+ #include <torch/extension.h>
14
+
15
// Python bindings of the rasterizer extension. The module name comes from
// TORCH_EXTENSION_NAME (supplied by the build); the bound functions are
// declared in rasterize_points.h.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Forward rasterization, exposed to Python as `rasterize_gaussians`.
  m.def("rasterize_gaussians", &RasterizeGaussiansCUDA);
  // Backward pass, exposed to Python as `rasterize_gaussians_backward`.
  m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA);
  // Visibility helper, exposed to Python as `mark_visible`.
  m.def("mark_visible", &markVisible);
}
gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/auxiliary.h ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact george.drettakis@inria.fr
10
+ */
11
+
12
+ #ifndef CUDA_RASTERIZER_AUXILIARY_H_INCLUDED
13
+ #define CUDA_RASTERIZER_AUXILIARY_H_INCLUDED
14
+
15
+ #include "config.h"
16
+ #include "stdio.h"
17
+
18
// Product of the two block dimensions. BLOCK_X / BLOCK_Y are not defined in
// this header; presumably they come from config.h, included above.
#define BLOCK_SIZE (BLOCK_X * BLOCK_Y)
// BLOCK_SIZE divided by 32 (presumably the warp width - confirm).
#define NUM_WARPS (BLOCK_SIZE / 32)

// Spherical harmonics coefficients: one constant for degree 0, one for
// degree 1, five for degree 2 and seven for degree 3.
__device__ const float SH_C0 = 0.28209479177387814f;
__device__ const float SH_C1 = 0.4886025119029199f;
__device__ const float SH_C2[] = {1.0925484305920792f, -1.0925484305920792f,
                                  0.31539156525252005f, -1.0925484305920792f,
                                  0.5462742152960396f};
__device__ const float SH_C3[] = {-0.5900435899266435f, 2.890611442640554f,
                                  -0.4570457994644658f, 0.3731763325901154f,
                                  -0.4570457994644658f, 1.445305721320277f,
                                  -0.5900435899266435f};
31
+
32
// Map a normalized coordinate v to a pixel coordinate along an axis of S
// pixels: ((v + 1) * S - 1) / 2, so v = -1 gives -0.5 and v = +1 gives S - 0.5.
__forceinline__ __device__ float ndc2Pix(float v, int S) {
  const float pix = ((v + 1.0) * S - 1.0) * 0.5;
  return pix;
}
35
+
36
// Compute the range of BLOCK_X x BLOCK_Y cells touched by a square of
// half-size max_radius centred on p.
//   rect_min: (p - max_radius) / BLOCK, truncated, clamped to [0, grid].
//   rect_max: (p + max_radius + BLOCK - 1) / BLOCK, i.e. rounded up, clamped
//             to [0, grid] (presumably an exclusive upper bound - confirm
//             against the callers).
__forceinline__ __device__ void getRect(const float2 p, int max_radius,
                                        uint2 &rect_min, uint2 &rect_max,
                                        dim3 grid) {
  rect_min = {min(grid.x, max((int)0, (int)((p.x - max_radius) / BLOCK_X))),
              min(grid.y, max((int)0, (int)((p.y - max_radius) / BLOCK_Y)))};
  rect_max = {
      min(grid.x,
          max((int)0, (int)((p.x + max_radius + BLOCK_X - 1) / BLOCK_X))),
      min(grid.y,
          max((int)0, (int)((p.y + max_radius + BLOCK_Y - 1) / BLOCK_Y)))};
}
47
+
48
// Transform point p by a 16-float matrix and return the first three output
// components. Output component r is
//   matrix[r] * p.x + matrix[4 + r] * p.y + matrix[8 + r] * p.z + matrix[12 + r],
// so the translation terms matrix[12..14] are included.
__forceinline__ __device__ float3 transformPoint4x3(const float3 &p,
                                                    const float *matrix) {
  const float out_x =
      matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12];
  const float out_y =
      matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13];
  const float out_z =
      matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14];
  return {out_x, out_y, out_z};
}
57
+
58
// Transform point p by a 16-float matrix and return all four output
// components (homogeneous result). Output component r is
//   matrix[r] * p.x + matrix[4 + r] * p.y + matrix[8 + r] * p.z + matrix[12 + r].
__forceinline__ __device__ float4 transformPoint4x4(const float3 &p,
                                                    const float *matrix) {
  const float out_x =
      matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12];
  const float out_y =
      matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13];
  const float out_z =
      matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14];
  const float out_w =
      matrix[3] * p.x + matrix[7] * p.y + matrix[11] * p.z + matrix[15];
  return {out_x, out_y, out_z, out_w};
}
67
+
68
// Transform p as a direction: same products as transformPoint4x3 but without
// the translation terms matrix[12..14].
__forceinline__ __device__ float3 transformVec4x3(const float3 &p,
                                                  const float *matrix) {
  const float out_x = matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z;
  const float out_y = matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z;
  const float out_z = matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z;
  return {out_x, out_y, out_z};
}
77
+
78
// Transform p as a direction with the transposed 3x3 block of the matrix:
// output component r is
//   matrix[4 * r] * p.x + matrix[4 * r + 1] * p.y + matrix[4 * r + 2] * p.z.
__forceinline__ __device__ float3
transformVec4x3Transpose(const float3 &p, const float *matrix) {
  const float out_x = matrix[0] * p.x + matrix[1] * p.y + matrix[2] * p.z;
  const float out_y = matrix[4] * p.x + matrix[5] * p.y + matrix[6] * p.z;
  const float out_z = matrix[8] * p.x + matrix[9] * p.y + matrix[10] * p.z;
  return {out_x, out_y, out_z};
}
87
+
88
// z component of the gradient pushed back through v / |v|:
//   (-v.x v.z dv.x - v.y v.z dv.y + (|v|^2 - v.z^2) dv.z) / |v|^3,
// where dv is the gradient with respect to the normalized vector.
__forceinline__ __device__ float dnormvdz(float3 v, float3 dv) {
  float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
  float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
  float grad_z =
      (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) *
      invsum32;
  return grad_z;
}
96
+
97
// Gradient pushed back through the normalization v / |v| (3-component
// version): returns (|v|^2 * dv - v * dot(v, dv)) / |v|^3, written out per
// component, where dv is the gradient with respect to the normalized vector.
__forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv) {
  float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
  float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);

  float3 grad;
  grad.x = ((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) *
           invsum32;
  grad.y = (-v.x * v.y * dv.x + (sum2 - v.y * v.y) * dv.y - v.z * v.y * dv.z) *
           invsum32;
  grad.z = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) *
           invsum32;
  return grad;
}
113
+
114
// Gradient pushed back through the normalization v / |v| (4-component
// version). Component i is
//   ((|v|^2 - v_i^2) * dv_i - v_i * sum_{j != i} v_j * dv_j) / |v|^3.
__forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv) {
  float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w;
  float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);

  // Per-component products v_i * dv_i and their total.
  float4 vdv = {v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w};
  float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w;

  float4 grad;
  grad.x = ((sum2 - v.x * v.x) * dv.x - v.x * (vdv_sum - vdv.x)) * invsum32;
  grad.y = ((sum2 - v.y * v.y) * dv.y - v.y * (vdv_sum - vdv.y)) * invsum32;
  grad.z = ((sum2 - v.z * v.z) * dv.z - v.z * (vdv_sum - vdv.z)) * invsum32;
  grad.w = ((sum2 - v.w * v.w) * dv.w - v.w * (vdv_sum - vdv.w)) * invsum32;
  return grad;
}
127
+
128
// Logistic function 1 / (1 + exp(-x)).
__forceinline__ __device__ float sigmoid(float x) {
  const float e = expf(-x);
  return 1.0f / (1.0f + e);
}
131
+
132
// Decide whether point `idx` is kept for rasterization and write its
// view-space position to p_view.
//
// The point is rejected when its view-space z is <= 0.2. The former lateral
// test on projected x/y (|coord| > 1.3) is disabled, so projmatrix is
// currently unused; the projected-point locals that only fed that test were
// dead code and have been removed.
//
// idx:          index of the point; orig_points holds 3 floats per point.
// viewmatrix:   16-float matrix applied with transformPoint4x3.
// projmatrix:   unused (kept so the signature stays unchanged).
// prefiltered:  if set, a rejected point is treated as a fatal error: a
//               message is printed and the kernel traps.
// p_view:       output, view-space position of the point.
// Returns true when the point is kept.
__forceinline__ __device__ bool in_frustum(int idx, const float *orig_points,
                                           const float *viewmatrix,
                                           const float *projmatrix,
                                           bool prefiltered, float3 &p_view) {
  (void)projmatrix; // only needed by the disabled lateral test

  float3 p_orig = {orig_points[3 * idx], orig_points[3 * idx + 1],
                   orig_points[3 * idx + 2]};

  // Bring the point to view space
  p_view = transformPoint4x3(p_orig, viewmatrix);

  if (p_view.z <= 0.2f) {
    if (prefiltered) {
      printf("Point is filtered although prefiltered is set. This shouldn't "
             "happen!");
      __trap();
    }
    return false;
  }
  return true;
}
157
+
158
// Execute statement A; when `debug` is true, additionally synchronize the
// device and, if that reports an error, print the error with file and line
// to std::cerr and throw std::runtime_error.
// NOTE(review): the macro expands to several statements (it is not wrapped in
// do { ... } while (0)), so it must not be used as the sole body of an
// unbraced if/else. It also relies on the including file to provide
// std::cerr and std::runtime_error.
#define CHECK_CUDA(A, debug)                                                   \
  A;                                                                           \
  if (debug) {                                                                 \
    auto ret = cudaDeviceSynchronize();                                        \
    if (ret != cudaSuccess) {                                                  \
      std::cerr << "\n[CUDA ERROR] in " << __FILE__ << "\nLine " << __LINE__   \
                << ": " << cudaGetErrorString(ret);                            \
      throw std::runtime_error(cudaGetErrorString(ret));                       \
    }                                                                          \
  }
168
+
169
+ #endif
gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/backward.cu ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact george.drettakis@inria.fr
10
+ */
11
+
12
+ #include "auxiliary.h"
13
+ #include "backward.h"
14
+ #include <cooperative_groups.h>
15
+ #include <cooperative_groups/reduce.h>
16
+ namespace cg = cooperative_groups;
17
+
18
// Backward pass for conversion of spherical harmonics to RGB for
// each Gaussian.
//
// idx:        index of the Gaussian handled by this call.
// deg:        SH degree in use; terms of degree 1, 2 and 3 are only
//             processed when deg > 0, > 1 and > 2 respectively.
// max_coeffs: number of vec3 coefficients stored per Gaussian (stride of
//             both shs and dL_dshs).
// means:      Gaussian centres.
// campos:     camera position; the view direction is means[idx] - campos.
// shs:        SH coefficients, reinterpreted as vec3 triples.
// clamped:    three flags per Gaussian; a set flag zeroes the incoming
//             gradient of that colour channel.
// dL_dcolor:  incoming gradient w.r.t. the RGB colour of each Gaussian.
// dL_dmeans:  output; the view-direction part of the mean gradient is ADDED
//             to dL_dmeans[idx].
// dL_dshs:    output; gradients w.r.t. the SH coefficients are written here.
__device__ void computeColorFromSH(int idx, int deg, int max_coeffs,
                                   const glm::vec3 *means, glm::vec3 campos,
                                   const float *shs, const bool *clamped,
                                   const glm::vec3 *dL_dcolor,
                                   glm::vec3 *dL_dmeans, glm::vec3 *dL_dshs) {
  // Compute intermediate values, as it is done during forward
  glm::vec3 pos = means[idx];
  glm::vec3 dir_orig = pos - campos;
  glm::vec3 dir = dir_orig / glm::length(dir_orig);

  glm::vec3 *sh = ((glm::vec3 *)shs) + idx * max_coeffs;

  // Use PyTorch rule for clamping: if clamping was applied,
  // gradient becomes 0.
  glm::vec3 dL_dRGB = dL_dcolor[idx];
  dL_dRGB.x *= clamped[3 * idx + 0] ? 0 : 1;
  dL_dRGB.y *= clamped[3 * idx + 1] ? 0 : 1;
  dL_dRGB.z *= clamped[3 * idx + 2] ? 0 : 1;

  // Derivatives of the RGB colour w.r.t. the x / y / z components of the
  // normalized view direction, accumulated degree by degree below.
  glm::vec3 dRGBdx(0, 0, 0);
  glm::vec3 dRGBdy(0, 0, 0);
  glm::vec3 dRGBdz(0, 0, 0);
  float x = dir.x;
  float y = dir.y;
  float z = dir.z;

  // Target location for this Gaussian to write SH gradients to
  glm::vec3 *dL_dsh = dL_dshs + idx * max_coeffs;

  // No tricks here, just high school-level calculus.
  float dRGBdsh0 = SH_C0;
  dL_dsh[0] = dRGBdsh0 * dL_dRGB;
  if (deg > 0) {
    float dRGBdsh1 = -SH_C1 * y;
    float dRGBdsh2 = SH_C1 * z;
    float dRGBdsh3 = -SH_C1 * x;
    dL_dsh[1] = dRGBdsh1 * dL_dRGB;
    dL_dsh[2] = dRGBdsh2 * dL_dRGB;
    dL_dsh[3] = dRGBdsh3 * dL_dRGB;

    dRGBdx = -SH_C1 * sh[3];
    dRGBdy = -SH_C1 * sh[1];
    dRGBdz = SH_C1 * sh[2];

    if (deg > 1) {
      float xx = x * x, yy = y * y, zz = z * z;
      float xy = x * y, yz = y * z, xz = x * z;

      float dRGBdsh4 = SH_C2[0] * xy;
      float dRGBdsh5 = SH_C2[1] * yz;
      float dRGBdsh6 = SH_C2[2] * (2.f * zz - xx - yy);
      float dRGBdsh7 = SH_C2[3] * xz;
      float dRGBdsh8 = SH_C2[4] * (xx - yy);
      dL_dsh[4] = dRGBdsh4 * dL_dRGB;
      dL_dsh[5] = dRGBdsh5 * dL_dRGB;
      dL_dsh[6] = dRGBdsh6 * dL_dRGB;
      dL_dsh[7] = dRGBdsh7 * dL_dRGB;
      dL_dsh[8] = dRGBdsh8 * dL_dRGB;

      dRGBdx += SH_C2[0] * y * sh[4] + SH_C2[2] * 2.f * -x * sh[6] +
                SH_C2[3] * z * sh[7] + SH_C2[4] * 2.f * x * sh[8];
      dRGBdy += SH_C2[0] * x * sh[4] + SH_C2[1] * z * sh[5] +
                SH_C2[2] * 2.f * -y * sh[6] + SH_C2[4] * 2.f * -y * sh[8];
      dRGBdz += SH_C2[1] * y * sh[5] + SH_C2[2] * 2.f * 2.f * z * sh[6] +
                SH_C2[3] * x * sh[7];

      if (deg > 2) {
        float dRGBdsh9 = SH_C3[0] * y * (3.f * xx - yy);
        float dRGBdsh10 = SH_C3[1] * xy * z;
        float dRGBdsh11 = SH_C3[2] * y * (4.f * zz - xx - yy);
        float dRGBdsh12 = SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy);
        float dRGBdsh13 = SH_C3[4] * x * (4.f * zz - xx - yy);
        float dRGBdsh14 = SH_C3[5] * z * (xx - yy);
        float dRGBdsh15 = SH_C3[6] * x * (xx - 3.f * yy);
        dL_dsh[9] = dRGBdsh9 * dL_dRGB;
        dL_dsh[10] = dRGBdsh10 * dL_dRGB;
        dL_dsh[11] = dRGBdsh11 * dL_dRGB;
        dL_dsh[12] = dRGBdsh12 * dL_dRGB;
        dL_dsh[13] = dRGBdsh13 * dL_dRGB;
        dL_dsh[14] = dRGBdsh14 * dL_dRGB;
        dL_dsh[15] = dRGBdsh15 * dL_dRGB;

        dRGBdx += (SH_C3[0] * sh[9] * 3.f * 2.f * xy + SH_C3[1] * sh[10] * yz +
                   SH_C3[2] * sh[11] * -2.f * xy +
                   SH_C3[3] * sh[12] * -3.f * 2.f * xz +
                   SH_C3[4] * sh[13] * (-3.f * xx + 4.f * zz - yy) +
                   SH_C3[5] * sh[14] * 2.f * xz +
                   SH_C3[6] * sh[15] * 3.f * (xx - yy));

        dRGBdy +=
            (SH_C3[0] * sh[9] * 3.f * (xx - yy) + SH_C3[1] * sh[10] * xz +
             SH_C3[2] * sh[11] * (-3.f * yy + 4.f * zz - xx) +
             SH_C3[3] * sh[12] * -3.f * 2.f * yz +
             SH_C3[4] * sh[13] * -2.f * xy + SH_C3[5] * sh[14] * -2.f * yz +
             SH_C3[6] * sh[15] * -3.f * 2.f * xy);

        dRGBdz += (SH_C3[1] * sh[10] * xy + SH_C3[2] * sh[11] * 4.f * 2.f * yz +
                   SH_C3[3] * sh[12] * 3.f * (2.f * zz - xx - yy) +
                   SH_C3[4] * sh[13] * 4.f * 2.f * xz +
                   SH_C3[5] * sh[14] * (xx - yy));
      }
    }
  }

  // The view direction is an input to the computation. View direction
  // is influenced by the Gaussian's mean, so SHs gradients
  // must propagate back into 3D position.
  glm::vec3 dL_ddir(glm::dot(dRGBdx, dL_dRGB), glm::dot(dRGBdy, dL_dRGB),
                    glm::dot(dRGBdz, dL_dRGB));

  // Account for normalization of direction
  float3 dL_dmean = dnormvdv(float3{dir_orig.x, dir_orig.y, dir_orig.z},
                             float3{dL_ddir.x, dL_ddir.y, dL_ddir.z});

  // Gradients of loss w.r.t. Gaussian means, but only the portion
  // that is caused because the mean affects the view-dependent color.
  // Additional mean gradient is accumulated in below methods.
  dL_dmeans[idx] += glm::vec3(dL_dmean.x, dL_dmean.y, dL_dmean.z);
}
139
+
140
// Backward version of INVERSE 2D covariance matrix computation
// (due to length launched as separate kernel before other
// backward steps contained in preprocess)
//
// One thread per Gaussian; threads with idx >= P or a non-positive radius
// return immediately without writing anything.
//
// P:           number of Gaussians.
// means:       Gaussian centres.
// radii:       per-Gaussian radius; only entries > 0 are processed.
// cov3Ds:      3D covariances, 6 floats per Gaussian (upper triangle).
// h_x, h_y:    scale factors used in the Jacobian J (presumably the focal
//              lengths in pixels - confirm against the caller).
// tan_fovx/y:  used to clamp t.x / t.z and t.y / t.z to +-1.3 * tan_fov.
// view_matrix: 16-float view matrix.
// dL_dconics:  incoming gradients, 4 floats per Gaussian; entries 0, 1 and 3
//              are read.
// dL_dmeans:   output; dL_dmeans[idx] is OVERWRITTEN with the covariance
//              part of the mean gradient.
// dL_dcov:     output; 6 gradient entries per Gaussian.
__global__ void computeCov2DCUDA(int P, const float3 *means, const int *radii,
                                 const float *cov3Ds, const float h_x,
                                 float h_y, const float tan_fovx,
                                 float tan_fovy, const float *view_matrix,
                                 const float *dL_dconics, float3 *dL_dmeans,
                                 float *dL_dcov) {
  auto idx = cg::this_grid().thread_rank();
  if (idx >= P || !(radii[idx] > 0))
    return;

  // Reading location of 3D covariance for this Gaussian
  const float *cov3D = cov3Ds + 6 * idx;

  // Fetch gradients, recompute 2D covariance and relevant
  // intermediate forward results needed in the backward.
  float3 mean = means[idx];
  float3 dL_dconic = {dL_dconics[4 * idx], dL_dconics[4 * idx + 1],
                      dL_dconics[4 * idx + 3]};
  float3 t = transformPoint4x3(mean, view_matrix);

  // Clamp the view-space x/z and y/z ratios to +-1.3 * tan_fov.
  const float limx = 1.3f * tan_fovx;
  const float limy = 1.3f * tan_fovy;
  const float txtz = t.x / t.z;
  const float tytz = t.y / t.z;
  t.x = min(limx, max(-limx, txtz)) * t.z;
  t.y = min(limy, max(-limy, tytz)) * t.z;

  // Where the clamp was active, no gradient flows back to t.x / t.y.
  const float x_grad_mul = txtz < -limx || txtz > limx ? 0 : 1;
  const float y_grad_mul = tytz < -limy || tytz > limy ? 0 : 1;

  glm::mat3 J = glm::mat3(h_x / t.z, 0.0f, -(h_x * t.x) / (t.z * t.z), 0.0f,
                          h_y / t.z, -(h_y * t.y) / (t.z * t.z), 0, 0, 0);

  glm::mat3 W = glm::mat3(view_matrix[0], view_matrix[4], view_matrix[8],
                          view_matrix[1], view_matrix[5], view_matrix[9],
                          view_matrix[2], view_matrix[6], view_matrix[10]);

  // Symmetric 3D covariance rebuilt from its 6 stored entries.
  glm::mat3 Vrk = glm::mat3(cov3D[0], cov3D[1], cov3D[2], cov3D[1], cov3D[3],
                            cov3D[4], cov3D[2], cov3D[4], cov3D[5]);

  glm::mat3 T = W * J;

  glm::mat3 cov2D = glm::transpose(T) * glm::transpose(Vrk) * T;

  // Use helper variables for 2D covariance entries. More compact.
  // 0.3 is added to both diagonal entries (presumably mirroring the forward
  // pass - confirm).
  float a = cov2D[0][0] += 0.3f;
  float b = cov2D[0][1];
  float c = cov2D[1][1] += 0.3f;

  float denom = a * c - b * b;
  float dL_da = 0, dL_db = 0, dL_dc = 0;
  // The small epsilon keeps the division finite when denom is 0.
  float denom2inv = 1.0f / ((denom * denom) + 0.0000001f);

  if (denom2inv != 0) {
    // Gradients of loss w.r.t. entries of 2D covariance matrix,
    // given gradients of loss w.r.t. conic matrix (inverse covariance matrix).
    // e.g., dL / da = dL / d_conic_a * d_conic_a / d_a
    dL_da = denom2inv * (-c * c * dL_dconic.x + 2 * b * c * dL_dconic.y +
                         (denom - a * c) * dL_dconic.z);
    dL_dc = denom2inv * (-a * a * dL_dconic.z + 2 * a * b * dL_dconic.y +
                         (denom - a * c) * dL_dconic.x);
    dL_db = denom2inv * 2 *
            (b * c * dL_dconic.x - (denom + 2 * b * b) * dL_dconic.y +
             a * b * dL_dconic.z);

    // Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry,
    // given gradients w.r.t. 2D covariance matrix (diagonal).
    // cov2D = transpose(T) * transpose(Vrk) * T;
    dL_dcov[6 * idx + 0] =
        (T[0][0] * T[0][0] * dL_da + T[0][0] * T[1][0] * dL_db +
         T[1][0] * T[1][0] * dL_dc);
    dL_dcov[6 * idx + 3] =
        (T[0][1] * T[0][1] * dL_da + T[0][1] * T[1][1] * dL_db +
         T[1][1] * T[1][1] * dL_dc);
    dL_dcov[6 * idx + 5] =
        (T[0][2] * T[0][2] * dL_da + T[0][2] * T[1][2] * dL_db +
         T[1][2] * T[1][2] * dL_dc);

    // Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry,
    // given gradients w.r.t. 2D covariance matrix (off-diagonal).
    // Off-diagonal elements appear twice --> double the gradient.
    // cov2D = transpose(T) * transpose(Vrk) * T;
    dL_dcov[6 * idx + 1] = 2 * T[0][0] * T[0][1] * dL_da +
                           (T[0][0] * T[1][1] + T[0][1] * T[1][0]) * dL_db +
                           2 * T[1][0] * T[1][1] * dL_dc;
    dL_dcov[6 * idx + 2] = 2 * T[0][0] * T[0][2] * dL_da +
                           (T[0][0] * T[1][2] + T[0][2] * T[1][0]) * dL_db +
                           2 * T[1][0] * T[1][2] * dL_dc;
    dL_dcov[6 * idx + 4] = 2 * T[0][2] * T[0][1] * dL_da +
                           (T[0][1] * T[1][2] + T[0][2] * T[1][1]) * dL_db +
                           2 * T[1][1] * T[1][2] * dL_dc;
  } else {
    for (int i = 0; i < 6; i++)
      dL_dcov[6 * idx + i] = 0;
  }

  // Gradients of loss w.r.t. upper 2x3 portion of intermediate matrix T
  // cov2D = transpose(T) * transpose(Vrk) * T;
  float dL_dT00 =
      2 * (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) *
          dL_da +
      (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_db;
  float dL_dT01 =
      2 * (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) *
          dL_da +
      (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_db;
  float dL_dT02 =
      2 * (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) *
          dL_da +
      (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_db;
  float dL_dT10 =
      2 * (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) *
          dL_dc +
      (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_db;
  float dL_dT11 =
      2 * (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) *
          dL_dc +
      (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_db;
  float dL_dT12 =
      2 * (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) *
          dL_dc +
      (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_db;

  // Gradients of loss w.r.t. upper 3x2 non-zero entries of Jacobian matrix
  // T = W * J
  float dL_dJ00 = W[0][0] * dL_dT00 + W[0][1] * dL_dT01 + W[0][2] * dL_dT02;
  float dL_dJ02 = W[2][0] * dL_dT00 + W[2][1] * dL_dT01 + W[2][2] * dL_dT02;
  float dL_dJ11 = W[1][0] * dL_dT10 + W[1][1] * dL_dT11 + W[1][2] * dL_dT12;
  float dL_dJ12 = W[2][0] * dL_dT10 + W[2][1] * dL_dT11 + W[2][2] * dL_dT12;

  // Powers of 1 / t.z used by the Jacobian derivatives below.
  float tz = 1.f / t.z;
  float tz2 = tz * tz;
  float tz3 = tz2 * tz;

  // Gradients of loss w.r.t. transformed Gaussian mean t
  float dL_dtx = x_grad_mul * -h_x * tz2 * dL_dJ02;
  float dL_dty = y_grad_mul * -h_y * tz2 * dL_dJ12;
  float dL_dtz = -h_x * tz2 * dL_dJ00 - h_y * tz2 * dL_dJ11 +
                 (2 * h_x * t.x) * tz3 * dL_dJ02 +
                 (2 * h_y * t.y) * tz3 * dL_dJ12;

  // Account for transformation of mean to t
  // t = transformPoint4x3(mean, view_matrix);
  float3 dL_dmean =
      transformVec4x3Transpose({dL_dtx, dL_dty, dL_dtz}, view_matrix);

  // Gradients of loss w.r.t. Gaussian means, but only the portion
  // that is caused because the mean affects the covariance matrix.
  // Additional mean gradient is accumulated in BACKWARD::preprocess.
  dL_dmeans[idx] = dL_dmean;
}
294
+
295
+ // Backward pass for the conversion of scale and rotation to a
296
+ // 3D covariance matrix for each Gaussian.
297
+ __device__ void computeCov3D(int idx, const glm::vec3 scale, float mod,
298
+ const glm::vec4 rot, const float *dL_dcov3Ds,
299
+ glm::vec3 *dL_dscales, glm::vec4 *dL_drots) {
300
+ // Recompute (intermediate) results for the 3D covariance computation.
301
+ glm::vec4 q = rot; // / glm::length(rot);
302
+ float r = q.x;
303
+ float x = q.y;
304
+ float y = q.z;
305
+ float z = q.w;
306
+
307
+ glm::mat3 R = glm::mat3(1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z),
308
+ 2.f * (x * z + r * y), 2.f * (x * y + r * z),
309
+ 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x),
310
+ 2.f * (x * z - r * y), 2.f * (y * z + r * x),
311
+ 1.f - 2.f * (x * x + y * y));
312
+
313
+ glm::mat3 S = glm::mat3(1.0f);
314
+
315
+ glm::vec3 s = mod * scale;
316
+ S[0][0] = s.x;
317
+ S[1][1] = s.y;
318
+ S[2][2] = s.z;
319
+
320
+ glm::mat3 M = S * R;
321
+
322
+ const float *dL_dcov3D = dL_dcov3Ds + 6 * idx;
323
+
324
+ glm::vec3 dunc(dL_dcov3D[0], dL_dcov3D[3], dL_dcov3D[5]);
325
+ glm::vec3 ounc = 0.5f * glm::vec3(dL_dcov3D[1], dL_dcov3D[2], dL_dcov3D[4]);
326
+
327
+ // Convert per-element covariance loss gradients to matrix form
328
+ glm::mat3 dL_dSigma =
329
+ glm::mat3(dL_dcov3D[0], 0.5f * dL_dcov3D[1], 0.5f * dL_dcov3D[2],
330
+ 0.5f * dL_dcov3D[1], dL_dcov3D[3], 0.5f * dL_dcov3D[4],
331
+ 0.5f * dL_dcov3D[2], 0.5f * dL_dcov3D[4], dL_dcov3D[5]);
332
+
333
+ // Compute loss gradient w.r.t. matrix M
334
+ // dSigma_dM = 2 * M
335
+ glm::mat3 dL_dM = 2.0f * M * dL_dSigma;
336
+
337
+ glm::mat3 Rt = glm::transpose(R);
338
+ glm::mat3 dL_dMt = glm::transpose(dL_dM);
339
+
340
+ // Gradients of loss w.r.t. scale
341
+ glm::vec3 *dL_dscale = dL_dscales + idx;
342
+ dL_dscale->x = glm::dot(Rt[0], dL_dMt[0]);
343
+ dL_dscale->y = glm::dot(Rt[1], dL_dMt[1]);
344
+ dL_dscale->z = glm::dot(Rt[2], dL_dMt[2]);
345
+
346
+ dL_dMt[0] *= s.x;
347
+ dL_dMt[1] *= s.y;
348
+ dL_dMt[2] *= s.z;
349
+
350
+ // Gradients of loss w.r.t. normalized quaternion
351
+ glm::vec4 dL_dq;
352
+ dL_dq.x = 2 * z * (dL_dMt[0][1] - dL_dMt[1][0]) +
353
+ 2 * y * (dL_dMt[2][0] - dL_dMt[0][2]) +
354
+ 2 * x * (dL_dMt[1][2] - dL_dMt[2][1]);
355
+ dL_dq.y = 2 * y * (dL_dMt[1][0] + dL_dMt[0][1]) +
356
+ 2 * z * (dL_dMt[2][0] + dL_dMt[0][2]) +
357
+ 2 * r * (dL_dMt[1][2] - dL_dMt[2][1]) -
358
+ 4 * x * (dL_dMt[2][2] + dL_dMt[1][1]);
359
+ dL_dq.z = 2 * x * (dL_dMt[1][0] + dL_dMt[0][1]) +
360
+ 2 * r * (dL_dMt[2][0] - dL_dMt[0][2]) +
361
+ 2 * z * (dL_dMt[1][2] + dL_dMt[2][1]) -
362
+ 4 * y * (dL_dMt[2][2] + dL_dMt[0][0]);
363
+ dL_dq.w = 2 * r * (dL_dMt[0][1] - dL_dMt[1][0]) +
364
+ 2 * x * (dL_dMt[2][0] + dL_dMt[0][2]) +
365
+ 2 * y * (dL_dMt[1][2] + dL_dMt[2][1]) -
366
+ 4 * z * (dL_dMt[1][1] + dL_dMt[0][0]);
367
+
368
+ // Gradients of loss w.r.t. unnormalized quaternion
369
+ float4 *dL_drot = (float4 *)(dL_drots + idx);
370
+ *dL_drot = float4{dL_dq.x, dL_dq.y, dL_dq.z,
371
+ dL_dq.w}; // dnormvdv(float4{ rot.x, rot.y, rot.z, rot.w },
372
+ // float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w });
373
+ }
374
+
375
+ // Backward pass of the preprocessing steps, except
376
+ // for the covariance computation and inversion
377
+ // (those are handled by a previous kernel call)
378
+ template <int C>
379
+ __global__ void
380
+ preprocessCUDA(int P, int D, int M, const float3 *means, const int *radii,
381
+ const float *shs, const bool *clamped, const glm::vec3 *scales,
382
+ const glm::vec4 *rotations, const float scale_modifier,
383
+ const float *proj, const glm::vec3 *campos,
384
+ const float3 *dL_dmean2D, glm::vec3 *dL_dmeans, float *dL_dcolor,
385
+ float *dL_dcov3D, float *dL_dsh, glm::vec3 *dL_dscale,
386
+ glm::vec4 *dL_drot) {
387
+ auto idx = cg::this_grid().thread_rank();
388
+ if (idx >= P || !(radii[idx] > 0))
389
+ return;
390
+
391
+ float3 m = means[idx];
392
+
393
+ // Taking care of gradients from the screenspace points
394
+ float4 m_hom = transformPoint4x4(m, proj);
395
+ float m_w = 1.0f / (m_hom.w + 0.0000001f);
396
+
397
+ // Compute loss gradient w.r.t. 3D means due to gradients of 2D means
398
+ // from rendering procedure
399
+ glm::vec3 dL_dmean;
400
+ float mul1 =
401
+ (proj[0] * m.x + proj[4] * m.y + proj[8] * m.z + proj[12]) * m_w * m_w;
402
+ float mul2 =
403
+ (proj[1] * m.x + proj[5] * m.y + proj[9] * m.z + proj[13]) * m_w * m_w;
404
+ dL_dmean.x = (proj[0] * m_w - proj[3] * mul1) * dL_dmean2D[idx].x +
405
+ (proj[1] * m_w - proj[3] * mul2) * dL_dmean2D[idx].y;
406
+ dL_dmean.y = (proj[4] * m_w - proj[7] * mul1) * dL_dmean2D[idx].x +
407
+ (proj[5] * m_w - proj[7] * mul2) * dL_dmean2D[idx].y;
408
+ dL_dmean.z = (proj[8] * m_w - proj[11] * mul1) * dL_dmean2D[idx].x +
409
+ (proj[9] * m_w - proj[11] * mul2) * dL_dmean2D[idx].y;
410
+
411
+ // That's the second part of the mean gradient. Previous computation
412
+ // of cov2D and following SH conversion also affects it.
413
+ dL_dmeans[idx] += dL_dmean;
414
+
415
+ // Compute gradient updates due to computing colors from SHs
416
+ if (shs)
417
+ computeColorFromSH(idx, D, M, (glm::vec3 *)means, *campos, shs, clamped,
418
+ (glm::vec3 *)dL_dcolor, (glm::vec3 *)dL_dmeans,
419
+ (glm::vec3 *)dL_dsh);
420
+
421
+ // Compute gradient updates due to computing covariance from scale/rotation
422
+ if (scales)
423
+ computeCov3D(idx, scales[idx], scale_modifier, rotations[idx], dL_dcov3D,
424
+ dL_dscale, dL_drot);
425
+ }
426
+
427
+ // Backward version of the rendering procedure.
428
+ template <uint32_t C>
429
+ __global__ void __launch_bounds__(BLOCK_X *BLOCK_Y) renderCUDA(
430
+ const uint2 *__restrict__ ranges, const uint32_t *__restrict__ point_list,
431
+ int W, int H, const float *__restrict__ bg_color,
432
+ const float2 *__restrict__ points_xy_image,
433
+ const float4 *__restrict__ conic_opacity, const float *__restrict__ colors,
434
+ const float *__restrict__ final_Ts, const uint32_t *__restrict__ n_contrib,
435
+ const float *__restrict__ dL_dpixels, float3 *__restrict__ dL_dmean2D,
436
+ float4 *__restrict__ dL_dconic2D, float *__restrict__ dL_dopacity,
437
+ float *__restrict__ dL_dcolors) {
438
+ // We rasterize again. Compute necessary block info.
439
+ auto block = cg::this_thread_block();
440
+ const uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X;
441
+ const uint2 pix_min = {block.group_index().x * BLOCK_X,
442
+ block.group_index().y * BLOCK_Y};
443
+ const uint2 pix_max = {min(pix_min.x + BLOCK_X, W),
444
+ min(pix_min.y + BLOCK_Y, H)};
445
+ const uint2 pix = {pix_min.x + block.thread_index().x,
446
+ pix_min.y + block.thread_index().y};
447
+ const uint32_t pix_id = W * pix.y + pix.x;
448
+ const float2 pixf = {(float)pix.x, (float)pix.y};
449
+
450
+ const bool inside = pix.x < W && pix.y < H;
451
+ const uint2 range =
452
+ ranges[block.group_index().y * horizontal_blocks + block.group_index().x];
453
+
454
+ const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE);
455
+
456
+ bool done = !inside;
457
+ int toDo = range.y - range.x;
458
+
459
+ __shared__ int collected_id[BLOCK_SIZE];
460
+ __shared__ float2 collected_xy[BLOCK_SIZE];
461
+ __shared__ float4 collected_conic_opacity[BLOCK_SIZE];
462
+ __shared__ float collected_colors[C * BLOCK_SIZE];
463
+
464
+ // In the forward, we stored the final value for T, the
465
+ // product of all (1 - alpha) factors.
466
+ const float T_final = inside ? final_Ts[pix_id] : 0;
467
+ float T = T_final;
468
+
469
+ // We start from the back. The ID of the last contributing
470
+ // Gaussian is known from each pixel from the forward.
471
+ uint32_t contributor = toDo;
472
+ const int last_contributor = inside ? n_contrib[pix_id] : 0;
473
+
474
+ float accum_rec[C] = {0};
475
+ float dL_dpixel[C];
476
+ if (inside)
477
+ for (int i = 0; i < C; i++)
478
+ dL_dpixel[i] = dL_dpixels[i * H * W + pix_id];
479
+
480
+ float last_alpha = 0;
481
+ float last_color[C] = {0};
482
+
483
+ // Gradient of pixel coordinate w.r.t. normalized
484
+ // screen-space viewport corrdinates (-1 to 1)
485
+ const float ddelx_dx = 0.5 * W;
486
+ const float ddely_dy = 0.5 * H;
487
+
488
+ // Traverse all Gaussians
489
+ for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE) {
490
+ // Load auxiliary data into shared memory, start in the BACK
491
+ // and load them in revers order.
492
+ block.sync();
493
+ const int progress = i * BLOCK_SIZE + block.thread_rank();
494
+ if (range.x + progress < range.y) {
495
+ const int coll_id = point_list[range.y - progress - 1];
496
+ collected_id[block.thread_rank()] = coll_id;
497
+ collected_xy[block.thread_rank()] = points_xy_image[coll_id];
498
+ collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
499
+ for (int i = 0; i < C; i++)
500
+ collected_colors[i * BLOCK_SIZE + block.thread_rank()] =
501
+ colors[coll_id * C + i];
502
+ }
503
+ block.sync();
504
+
505
+ // Iterate over Gaussians
506
+ for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++) {
507
+ // Keep track of current Gaussian ID. Skip, if this one
508
+ // is behind the last contributor for this pixel.
509
+ contributor--;
510
+ if (contributor >= last_contributor)
511
+ continue;
512
+
513
+ // Compute blending values, as before.
514
+ const float2 xy = collected_xy[j];
515
+ const float2 d = {xy.x - pixf.x, xy.y - pixf.y};
516
+ const float4 con_o = collected_conic_opacity[j];
517
+ const float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) -
518
+ con_o.y * d.x * d.y;
519
+ if (power > 0.0f)
520
+ continue;
521
+
522
+ const float G = exp(power);
523
+ const float alpha = min(0.99f, con_o.w * G);
524
+ if (alpha < 1.0f / 255.0f)
525
+ continue;
526
+
527
+ T = T / (1.f - alpha);
528
+ const float dchannel_dcolor = alpha * T;
529
+
530
+ // Propagate gradients to per-Gaussian colors and keep
531
+ // gradients w.r.t. alpha (blending factor for a Gaussian/pixel
532
+ // pair).
533
+ float dL_dalpha = 0.0f;
534
+ const int global_id = collected_id[j];
535
+ for (int ch = 0; ch < C; ch++) {
536
+ const float c = collected_colors[ch * BLOCK_SIZE + j];
537
+ // Update last color (to be used in the next iteration)
538
+ accum_rec[ch] =
539
+ last_alpha * last_color[ch] + (1.f - last_alpha) * accum_rec[ch];
540
+ last_color[ch] = c;
541
+
542
+ const float dL_dchannel = dL_dpixel[ch];
543
+ dL_dalpha += (c - accum_rec[ch]) * dL_dchannel;
544
+ // Update the gradients w.r.t. color of the Gaussian.
545
+ // Atomic, since this pixel is just one of potentially
546
+ // many that were affected by this Gaussian.
547
+ atomicAdd(&(dL_dcolors[global_id * C + ch]),
548
+ dchannel_dcolor * dL_dchannel);
549
+ }
550
+ dL_dalpha *= T;
551
+ // Update last alpha (to be used in the next iteration)
552
+ last_alpha = alpha;
553
+
554
+ // Account for fact that alpha also influences how much of
555
+ // the background color is added if nothing left to blend
556
+ float bg_dot_dpixel = 0;
557
+ for (int i = 0; i < C; i++)
558
+ bg_dot_dpixel += bg_color[i] * dL_dpixel[i];
559
+ dL_dalpha += (-T_final / (1.f - alpha)) * bg_dot_dpixel;
560
+
561
+ // Helpful reusable temporary variables
562
+ const float dL_dG = con_o.w * dL_dalpha;
563
+ const float gdx = G * d.x;
564
+ const float gdy = G * d.y;
565
+ const float dG_ddelx = -gdx * con_o.x - gdy * con_o.y;
566
+ const float dG_ddely = -gdy * con_o.z - gdx * con_o.y;
567
+
568
+ // Update gradients w.r.t. 2D mean position of the Gaussian
569
+ atomicAdd(&dL_dmean2D[global_id].x, dL_dG * dG_ddelx * ddelx_dx);
570
+ atomicAdd(&dL_dmean2D[global_id].y, dL_dG * dG_ddely * ddely_dy);
571
+
572
+ // Update gradients w.r.t. 2D covariance (2x2 matrix, symmetric)
573
+ atomicAdd(&dL_dconic2D[global_id].x, -0.5f * gdx * d.x * dL_dG);
574
+ atomicAdd(&dL_dconic2D[global_id].y, -0.5f * gdx * d.y * dL_dG);
575
+ atomicAdd(&dL_dconic2D[global_id].w, -0.5f * gdy * d.y * dL_dG);
576
+
577
+ // Update gradients w.r.t. opacity of the Gaussian
578
+ atomicAdd(&(dL_dopacity[global_id]), G * dL_dalpha);
579
+ }
580
+ }
581
+ }
582
+
583
+ void BACKWARD::preprocess(
584
+ int P, int D, int M, const float3 *means3D, const int *radii,
585
+ const float *shs, const bool *clamped, const glm::vec3 *scales,
586
+ const glm::vec4 *rotations, const float scale_modifier, const float *cov3Ds,
587
+ const float *viewmatrix, const float *projmatrix, const float focal_x,
588
+ float focal_y, const float tan_fovx, float tan_fovy,
589
+ const glm::vec3 *campos, const float3 *dL_dmean2D, const float *dL_dconic,
590
+ glm::vec3 *dL_dmean3D, float *dL_dcolor, float *dL_dcov3D, float *dL_dsh,
591
+ glm::vec3 *dL_dscale, glm::vec4 *dL_drot) {
592
+ // Propagate gradients for the path of 2D conic matrix computation.
593
+ // Somewhat long, thus it is its own kernel rather than being part of
594
+ // "preprocess". When done, loss gradient w.r.t. 3D means has been
595
+ // modified and gradient w.r.t. 3D covariance matrix has been computed.
596
+ computeCov2DCUDA<<<(P + 255) / 256, 256>>>(
597
+ P, means3D, radii, cov3Ds, focal_x, focal_y, tan_fovx, tan_fovy,
598
+ viewmatrix, dL_dconic, (float3 *)dL_dmean3D, dL_dcov3D);
599
+
600
+ // Propagate gradients for remaining steps: finish 3D mean gradients,
601
+ // propagate color gradients to SH (if desireD), propagate 3D covariance
602
+ // matrix gradients to scale and rotation.
603
+ preprocessCUDA<NUM_CHANNELS><<<(P + 255) / 256, 256>>>(
604
+ P, D, M, (float3 *)means3D, radii, shs, clamped, (glm::vec3 *)scales,
605
+ (glm::vec4 *)rotations, scale_modifier, projmatrix, campos,
606
+ (float3 *)dL_dmean2D, (glm::vec3 *)dL_dmean3D, dL_dcolor, dL_dcov3D,
607
+ dL_dsh, dL_dscale, dL_drot);
608
+ }
609
+
610
+ void BACKWARD::render(const dim3 grid, const dim3 block, const uint2 *ranges,
611
+ const uint32_t *point_list, int W, int H,
612
+ const float *bg_color, const float2 *means2D,
613
+ const float4 *conic_opacity, const float *colors,
614
+ const float *final_Ts, const uint32_t *n_contrib,
615
+ const float *dL_dpixels, float3 *dL_dmean2D,
616
+ float4 *dL_dconic2D, float *dL_dopacity,
617
+ float *dL_dcolors) {
618
+ renderCUDA<NUM_CHANNELS>
619
+ <<<grid, block>>>(ranges, point_list, W, H, bg_color, means2D,
620
+ conic_opacity, colors, final_Ts, n_contrib, dL_dpixels,
621
+ dL_dmean2D, dL_dconic2D, dL_dopacity, dL_dcolors);
622
+ }
gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/backward.h ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact george.drettakis@inria.fr
10
+ */
11
+
12
+ #ifndef CUDA_RASTERIZER_BACKWARD_H_INCLUDED
13
+ #define CUDA_RASTERIZER_BACKWARD_H_INCLUDED
14
+
15
+ #include "cuda_runtime.h"
16
+ #include "device_launch_parameters.h"
17
+ #include <cuda.h>
18
+ #define GLM_FORCE_CUDA
19
+ #include <glm/glm.hpp>
20
+
21
// Host-callable launchers for the backward (gradient) pass.
namespace BACKWARD {
// Launches the backward rendering kernel with the given grid / block
// dimensions. Reads the per-pixel gradients in dL_dpixels together with the
// data saved by the forward pass (final_Ts, n_contrib) and accumulates
// gradients into dL_dmean2D, dL_dconic2D, dL_dopacity and dL_dcolors.
void render(const dim3 grid, dim3 block, const uint2 *ranges,
            const uint32_t *point_list, int W, int H, const float *bg_color,
            const float2 *means2D, const float4 *conic_opacity,
            const float *colors, const float *final_Ts,
            const uint32_t *n_contrib, const float *dL_dpixels,
            float3 *dL_dmean2D, float4 *dL_dconic2D, float *dL_dopacity,
            float *dL_dcolors);

// Launches the backward preprocessing kernels for P Gaussians: gradients of
// the 2D conics and 2D means are propagated to the 3D means (dL_dmeans), the
// 3D covariances (dL_dcov3D), the SH coefficients (dL_dsh) and the scale /
// rotation parameters (dL_dscale, dL_drot).
void preprocess(int P, int D, int M, const float3 *means, const int *radii,
                const float *shs, const bool *clamped, const glm::vec3 *scales,
                const glm::vec4 *rotations, const float scale_modifier,
                const float *cov3Ds, const float *view, const float *proj,
                const float focal_x, float focal_y, const float tan_fovx,
                float tan_fovy, const glm::vec3 *campos,
                const float3 *dL_dmean2D, const float *dL_dconics,
                glm::vec3 *dL_dmeans, float *dL_dcolor, float *dL_dcov3D,
                float *dL_dsh, glm::vec3 *dL_dscale, glm::vec4 *dL_drot);
} // namespace BACKWARD
40
+
41
+ #endif
gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/config.h ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact george.drettakis@inria.fr
10
+ */
11
+
12
// Compile-time configuration shared by the rasterizer kernels.
#ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED
#define CUDA_RASTERIZER_CONFIG_H_INCLUDED

// Number of colour channels; used as the template argument of the
// render / preprocess kernels.
#define NUM_CHANNELS 3 // Default 3, RGB
// Tile width and height in pixels. The render kernels map one thread to one
// pixel of a tile and are bounded to BLOCK_X * BLOCK_Y threads per block.
#define BLOCK_X 16
#define BLOCK_Y 16

#endif
gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/forward.cu ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact george.drettakis@inria.fr
10
+ */
11
+
12
+ #include "auxiliary.h"
13
+ #include "forward.h"
14
+ #include <cooperative_groups.h>
15
+ #include <cooperative_groups/reduce.h>
16
+ namespace cg = cooperative_groups;
17
+
18
+ // Forward method for converting the input spherical harmonics
19
+ // coefficients of each Gaussian to a simple RGB color.
20
+ __device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs,
21
+ const glm::vec3 *means,
22
+ glm::vec3 campos, const float *shs,
23
+ bool *clamped) {
24
+ // The implementation is loosely based on code for
25
+ // "Differentiable Point-Based Radiance Fields for
26
+ // Efficient View Synthesis" by Zhang et al. (2022)
27
+ glm::vec3 pos = means[idx];
28
+ glm::vec3 dir = pos - campos;
29
+ dir = dir / glm::length(dir);
30
+
31
+ glm::vec3 *sh = ((glm::vec3 *)shs) + idx * max_coeffs;
32
+ glm::vec3 result = SH_C0 * sh[0];
33
+
34
+ if (deg > 0) {
35
+ float x = dir.x;
36
+ float y = dir.y;
37
+ float z = dir.z;
38
+ result = result - SH_C1 * y * sh[1] + SH_C1 * z * sh[2] - SH_C1 * x * sh[3];
39
+
40
+ if (deg > 1) {
41
+ float xx = x * x, yy = y * y, zz = z * z;
42
+ float xy = x * y, yz = y * z, xz = x * z;
43
+ result = result + SH_C2[0] * xy * sh[4] + SH_C2[1] * yz * sh[5] +
44
+ SH_C2[2] * (2.0f * zz - xx - yy) * sh[6] +
45
+ SH_C2[3] * xz * sh[7] + SH_C2[4] * (xx - yy) * sh[8];
46
+
47
+ if (deg > 2) {
48
+ result = result + SH_C3[0] * y * (3.0f * xx - yy) * sh[9] +
49
+ SH_C3[1] * xy * z * sh[10] +
50
+ SH_C3[2] * y * (4.0f * zz - xx - yy) * sh[11] +
51
+ SH_C3[3] * z * (2.0f * zz - 3.0f * xx - 3.0f * yy) * sh[12] +
52
+ SH_C3[4] * x * (4.0f * zz - xx - yy) * sh[13] +
53
+ SH_C3[5] * z * (xx - yy) * sh[14] +
54
+ SH_C3[6] * x * (xx - 3.0f * yy) * sh[15];
55
+ }
56
+ }
57
+ }
58
+ result += 0.5f;
59
+
60
+ // RGB colors are clamped to positive values. If values are
61
+ // clamped, we need to keep track of this for the backward pass.
62
+ clamped[3 * idx + 0] = (result.x < 0);
63
+ clamped[3 * idx + 1] = (result.y < 0);
64
+ clamped[3 * idx + 2] = (result.z < 0);
65
+ return glm::max(result, 0.0f);
66
+ }
67
+
68
+ // Forward version of 2D covariance matrix computation
69
+ __device__ float3 computeCov2D(const float3 &mean, float focal_x, float focal_y,
70
+ float tan_fovx, float tan_fovy,
71
+ const float *cov3D, const float *viewmatrix) {
72
+ // The following models the steps outlined by equations 29
73
+ // and 31 in "EWA Splatting" (Zwicker et al., 2002).
74
+ // Additionally considers aspect / scaling of viewport.
75
+ // Transposes used to account for row-/column-major conventions.
76
+ float3 t = transformPoint4x3(mean, viewmatrix);
77
+
78
+ const float limx = 1.3f * tan_fovx;
79
+ const float limy = 1.3f * tan_fovy;
80
+ const float txtz = t.x / t.z;
81
+ const float tytz = t.y / t.z;
82
+ t.x = min(limx, max(-limx, txtz)) * t.z;
83
+ t.y = min(limy, max(-limy, tytz)) * t.z;
84
+
85
+ glm::mat3 J =
86
+ glm::mat3(focal_x / t.z, 0.0f, -(focal_x * t.x) / (t.z * t.z), 0.0f,
87
+ focal_y / t.z, -(focal_y * t.y) / (t.z * t.z), 0, 0, 0);
88
+
89
+ glm::mat3 W = glm::mat3(viewmatrix[0], viewmatrix[4], viewmatrix[8],
90
+ viewmatrix[1], viewmatrix[5], viewmatrix[9],
91
+ viewmatrix[2], viewmatrix[6], viewmatrix[10]);
92
+
93
+ glm::mat3 T = W * J;
94
+
95
+ glm::mat3 Vrk = glm::mat3(cov3D[0], cov3D[1], cov3D[2], cov3D[1], cov3D[3],
96
+ cov3D[4], cov3D[2], cov3D[4], cov3D[5]);
97
+
98
+ glm::mat3 cov = glm::transpose(T) * glm::transpose(Vrk) * T;
99
+
100
+ // Apply low-pass filter: every Gaussian should be at least
101
+ // one pixel wide/high. Discard 3rd row and column.
102
+ cov[0][0] += 0.3f;
103
+ cov[1][1] += 0.3f;
104
+ return {float(cov[0][0]), float(cov[0][1]), float(cov[1][1])};
105
+ }
106
+
107
+ // Forward method for converting scale and rotation properties of each
108
+ // Gaussian to a 3D covariance matrix in world space. Also takes care
109
+ // of quaternion normalization.
110
+ __device__ void computeCov3D(const glm::vec3 scale, float mod,
111
+ const glm::vec4 rot, float *cov3D) {
112
+ // Create scaling matrix
113
+ glm::mat3 S = glm::mat3(1.0f);
114
+ S[0][0] = mod * scale.x;
115
+ S[1][1] = mod * scale.y;
116
+ S[2][2] = mod * scale.z;
117
+
118
+ // Normalize quaternion to get valid rotation
119
+ glm::vec4 q = rot; // / glm::length(rot);
120
+ float r = q.x;
121
+ float x = q.y;
122
+ float y = q.z;
123
+ float z = q.w;
124
+
125
+ // Compute rotation matrix from quaternion
126
+ glm::mat3 R = glm::mat3(1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z),
127
+ 2.f * (x * z + r * y), 2.f * (x * y + r * z),
128
+ 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x),
129
+ 2.f * (x * z - r * y), 2.f * (y * z + r * x),
130
+ 1.f - 2.f * (x * x + y * y));
131
+
132
+ glm::mat3 M = S * R;
133
+
134
+ // Compute 3D world covariance matrix Sigma
135
+ glm::mat3 Sigma = glm::transpose(M) * M;
136
+
137
+ // Covariance is symmetric, only store upper right
138
+ cov3D[0] = Sigma[0][0];
139
+ cov3D[1] = Sigma[0][1];
140
+ cov3D[2] = Sigma[0][2];
141
+ cov3D[3] = Sigma[1][1];
142
+ cov3D[4] = Sigma[1][2];
143
+ cov3D[5] = Sigma[2][2];
144
+ }
145
+
146
+ // Perform initial steps for each Gaussian prior to rasterization.
147
+ template <int C>
148
+ __global__ void
149
+ preprocessCUDA(int P, int D, int M, const float *orig_points,
150
+ const glm::vec3 *scales, const float scale_modifier,
151
+ const glm::vec4 *rotations, const float *opacities,
152
+ const float *shs, bool *clamped, const float *cov3D_precomp,
153
+ const float *colors_precomp, const float *viewmatrix,
154
+ const float *projmatrix, const glm::vec3 *cam_pos, const int W,
155
+ int H, const float tan_fovx, float tan_fovy, const float focal_x,
156
+ float focal_y, int *radii, float2 *points_xy_image,
157
+ float *depths, float *cov3Ds, float *rgb, float4 *conic_opacity,
158
+ const dim3 grid, uint32_t *tiles_touched, bool prefiltered) {
159
+ auto idx = cg::this_grid().thread_rank();
160
+ if (idx >= P)
161
+ return;
162
+
163
+ // Initialize radius and touched tiles to 0. If this isn't changed,
164
+ // this Gaussian will not be processed further.
165
+ radii[idx] = 0;
166
+ tiles_touched[idx] = 0;
167
+
168
+ // Perform near culling, quit if outside.
169
+ float3 p_view;
170
+ if (!in_frustum(idx, orig_points, viewmatrix, projmatrix, prefiltered,
171
+ p_view))
172
+ return;
173
+
174
+ // Transform point by projecting
175
+ float3 p_orig = {orig_points[3 * idx], orig_points[3 * idx + 1],
176
+ orig_points[3 * idx + 2]};
177
+ float4 p_hom = transformPoint4x4(p_orig, projmatrix);
178
+ float p_w = 1.0f / (p_hom.w + 0.0000001f);
179
+ float3 p_proj = {p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w};
180
+
181
+ // If 3D covariance matrix is precomputed, use it, otherwise compute
182
+ // from scaling and rotation parameters.
183
+ const float *cov3D;
184
+ if (cov3D_precomp != nullptr) {
185
+ cov3D = cov3D_precomp + idx * 6;
186
+ } else {
187
+ computeCov3D(scales[idx], scale_modifier, rotations[idx], cov3Ds + idx * 6);
188
+ cov3D = cov3Ds + idx * 6;
189
+ }
190
+
191
+ // Compute 2D screen-space covariance matrix
192
+ float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D,
193
+ viewmatrix);
194
+
195
+ // Invert covariance (EWA algorithm)
196
+ float det = (cov.x * cov.z - cov.y * cov.y);
197
+ if (det == 0.0f)
198
+ return;
199
+ float det_inv = 1.f / det;
200
+ float3 conic = {cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv};
201
+
202
+ // Compute extent in screen space (by finding eigenvalues of
203
+ // 2D covariance matrix). Use extent to compute a bounding rectangle
204
+ // of screen-space tiles that this Gaussian overlaps with. Quit if
205
+ // rectangle covers 0 tiles.
206
+ float mid = 0.5f * (cov.x + cov.z);
207
+ float lambda1 = mid + sqrt(max(0.1f, mid * mid - det));
208
+ float lambda2 = mid - sqrt(max(0.1f, mid * mid - det));
209
+ float my_radius = ceil(3.f * sqrt(max(lambda1, lambda2)));
210
+ float2 point_image = {ndc2Pix(p_proj.x, W), ndc2Pix(p_proj.y, H)};
211
+ uint2 rect_min, rect_max;
212
+ getRect(point_image, my_radius, rect_min, rect_max, grid);
213
+ if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0)
214
+ return;
215
+
216
+ // If colors have been precomputed, use them, otherwise convert
217
+ // spherical harmonics coefficients to RGB color.
218
+ if (colors_precomp == nullptr) {
219
+ glm::vec3 result = computeColorFromSH(idx, D, M, (glm::vec3 *)orig_points,
220
+ *cam_pos, shs, clamped);
221
+ rgb[idx * C + 0] = result.x;
222
+ rgb[idx * C + 1] = result.y;
223
+ rgb[idx * C + 2] = result.z;
224
+ }
225
+
226
+ // Store some useful helper data for the next steps.
227
+ depths[idx] = p_view.z;
228
+ radii[idx] = my_radius;
229
+ points_xy_image[idx] = point_image;
230
+ // Inverse 2D covariance and opacity neatly pack into one float4
231
+ conic_opacity[idx] = {conic.x, conic.y, conic.z, opacities[idx]};
232
+ tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x);
233
+ }
234
+
235
+ // Main rasterization method. Collaboratively works on one tile per
236
+ // block, each thread treats one pixel. Alternates between fetching
237
+ // and rasterizing data.
238
+ template <uint32_t CHANNELS>
239
+ __global__ void __launch_bounds__(BLOCK_X *BLOCK_Y)
240
+ renderCUDA(const uint2 *__restrict__ ranges,
241
+ const uint32_t *__restrict__ point_list, int W, int H,
242
+ const float2 *__restrict__ points_xy_image,
243
+ const float *__restrict__ features,
244
+ const float4 *__restrict__ conic_opacity,
245
+ float *__restrict__ final_T, uint32_t *__restrict__ n_contrib,
246
+ const float *__restrict__ bg_color,
247
+ float *__restrict__ out_color) {
248
+ // Identify current tile and associated min/max pixel range.
249
+ auto block = cg::this_thread_block();
250
+ uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X;
251
+ uint2 pix_min = {block.group_index().x * BLOCK_X,
252
+ block.group_index().y * BLOCK_Y};
253
+ uint2 pix_max = {min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y, H)};
254
+ uint2 pix = {pix_min.x + block.thread_index().x,
255
+ pix_min.y + block.thread_index().y};
256
+ uint32_t pix_id = W * pix.y + pix.x;
257
+ float2 pixf = {(float)pix.x, (float)pix.y};
258
+
259
+ // Check if this thread is associated with a valid pixel or outside.
260
+ bool inside = pix.x < W && pix.y < H;
261
+ // Done threads can help with fetching, but don't rasterize
262
+ bool done = !inside;
263
+
264
+ // Load start/end range of IDs to process in bit sorted list.
265
+ uint2 range =
266
+ ranges[block.group_index().y * horizontal_blocks + block.group_index().x];
267
+ const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE);
268
+ int toDo = range.y - range.x;
269
+
270
+ // Allocate storage for batches of collectively fetched data.
271
+ __shared__ int collected_id[BLOCK_SIZE];
272
+ __shared__ float2 collected_xy[BLOCK_SIZE];
273
+ __shared__ float4 collected_conic_opacity[BLOCK_SIZE];
274
+
275
+ // Initialize helper variables
276
+ float T = 1.0f;
277
+ uint32_t contributor = 0;
278
+ uint32_t last_contributor = 0;
279
+ float C[CHANNELS] = {0};
280
+
281
+ // Iterate over batches until all done or range is complete
282
+ for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE) {
283
+ // End if entire block votes that it is done rasterizing
284
+ int num_done = __syncthreads_count(done);
285
+ if (num_done == BLOCK_SIZE)
286
+ break;
287
+
288
+ // Collectively fetch per-Gaussian data from global to shared
289
+ int progress = i * BLOCK_SIZE + block.thread_rank();
290
+ if (range.x + progress < range.y) {
291
+ int coll_id = point_list[range.x + progress];
292
+ collected_id[block.thread_rank()] = coll_id;
293
+ collected_xy[block.thread_rank()] = points_xy_image[coll_id];
294
+ collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
295
+ }
296
+ block.sync();
297
+
298
+ // Iterate over current batch
299
+ for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++) {
300
+ // Keep track of current position in range
301
+ contributor++;
302
+
303
+ // Resample using conic matrix (cf. "Surface
304
+ // Splatting" by Zwicker et al., 2001)
305
+ float2 xy = collected_xy[j];
306
+ float2 d = {xy.x - pixf.x, xy.y - pixf.y};
307
+ float4 con_o = collected_conic_opacity[j];
308
+ float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) -
309
+ con_o.y * d.x * d.y;
310
+ if (power > 0.0f)
311
+ continue;
312
+
313
+ // Eq. (2) from 3D Gaussian splatting paper.
314
+ // Obtain alpha by multiplying with Gaussian opacity
315
+ // and its exponential falloff from mean.
316
+ // Avoid numerical instabilities (see paper appendix).
317
+ float alpha = min(0.99f, con_o.w * exp(power));
318
+ if (alpha < 1.0f / 255.0f)
319
+ continue;
320
+ float test_T = T * (1 - alpha);
321
+ if (test_T < 0.0001f) {
322
+ done = true;
323
+ continue;
324
+ }
325
+
326
+ // Eq. (3) from 3D Gaussian splatting paper.
327
+ for (int ch = 0; ch < CHANNELS; ch++)
328
+ C[ch] += features[collected_id[j] * CHANNELS + ch] * alpha * T;
329
+
330
+ T = test_T;
331
+
332
+ // Keep track of last range entry to update this
333
+ // pixel.
334
+ last_contributor = contributor;
335
+ }
336
+ }
337
+
338
+ // All threads that treat valid pixel write out their final
339
+ // rendering data to the frame and auxiliary buffers.
340
+ if (inside) {
341
+ final_T[pix_id] = T;
342
+ n_contrib[pix_id] = last_contributor;
343
+ for (int ch = 0; ch < CHANNELS; ch++)
344
+ out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch];
345
+ }
346
+ }
347
+
348
+ void FORWARD::render(const dim3 grid, dim3 block, const uint2 *ranges,
349
+ const uint32_t *point_list, int W, int H,
350
+ const float2 *means2D, const float *colors,
351
+ const float4 *conic_opacity, float *final_T,
352
+ uint32_t *n_contrib, const float *bg_color,
353
+ float *out_color) {
354
+ renderCUDA<NUM_CHANNELS><<<grid, block>>>(ranges, point_list, W, H, means2D,
355
+ colors, conic_opacity, final_T,
356
+ n_contrib, bg_color, out_color);
357
+ }
358
+
359
+ void FORWARD::preprocess(int P, int D, int M, const float *means3D,
360
+ const glm::vec3 *scales, const float scale_modifier,
361
+ const glm::vec4 *rotations, const float *opacities,
362
+ const float *shs, bool *clamped,
363
+ const float *cov3D_precomp,
364
+ const float *colors_precomp, const float *viewmatrix,
365
+ const float *projmatrix, const glm::vec3 *cam_pos,
366
+ const int W, int H, const float focal_x, float focal_y,
367
+ const float tan_fovx, float tan_fovy, int *radii,
368
+ float2 *means2D, float *depths, float *cov3Ds,
369
+ float *rgb, float4 *conic_opacity, const dim3 grid,
370
+ uint32_t *tiles_touched, bool prefiltered) {
371
+ preprocessCUDA<NUM_CHANNELS><<<(P + 255) / 256, 256>>>(
372
+ P, D, M, means3D, scales, scale_modifier, rotations, opacities, shs,
373
+ clamped, cov3D_precomp, colors_precomp, viewmatrix, projmatrix, cam_pos,
374
+ W, H, tan_fovx, tan_fovy, focal_x, focal_y, radii, means2D, depths,
375
+ cov3Ds, rgb, conic_opacity, grid, tiles_touched, prefiltered);
376
+ }
gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/forward.h ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact george.drettakis@inria.fr
10
+ */
11
+
12
+ #ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED
13
+ #define CUDA_RASTERIZER_FORWARD_H_INCLUDED
14
+
15
+ #include "cuda_runtime.h"
16
+ #include "device_launch_parameters.h"
17
+ #include <cuda.h>
18
+ #define GLM_FORCE_CUDA
19
+ #include <glm/glm.hpp>
20
+
21
+ namespace FORWARD {
22
+ // Perform initial steps for each Gaussian prior to rasterization.
23
+ void preprocess(int P, int D, int M, const float *orig_points,
24
+ const glm::vec3 *scales, const float scale_modifier,
25
+ const glm::vec4 *rotations, const float *opacities,
26
+ const float *shs, bool *clamped, const float *cov3D_precomp,
27
+ const float *colors_precomp, const float *viewmatrix,
28
+ const float *projmatrix, const glm::vec3 *cam_pos, const int W,
29
+ int H, const float focal_x, float focal_y, const float tan_fovx,
30
+ float tan_fovy, int *radii, float2 *points_xy_image,
31
+ float *depths, float *cov3Ds, float *colors,
32
+ float4 *conic_opacity, const dim3 grid, uint32_t *tiles_touched,
33
+ bool prefiltered);
34
+
35
+ // Main rasterization method.
36
+ void render(const dim3 grid, dim3 block, const uint2 *ranges,
37
+ const uint32_t *point_list, int W, int H,
38
+ const float2 *points_xy_image, const float *features,
39
+ const float4 *conic_opacity, float *final_T, uint32_t *n_contrib,
40
+ const float *bg_color, float *out_color);
41
+ } // namespace FORWARD
42
+
43
+ #endif
gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/rasterizer.h ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact george.drettakis@inria.fr
10
+ */
11
+
12
+ #ifndef CUDA_RASTERIZER_H_INCLUDED
13
+ #define CUDA_RASTERIZER_H_INCLUDED
14
+
15
+ #include <functional>
16
+ #include <vector>
17
+
18
namespace CudaRasterizer {

// Static entry points of the CUDA Gaussian rasterizer.
class Rasterizer {
public:
  // Frustum-test P Gaussians and write one flag per Gaussian to `present`.
  static void markVisible(int P, float *means3D, float *viewmatrix,
                          float *projmatrix, bool *present);

  // Forward render pass.
  //
  // The three buffer callbacks are each called with a byte count and must
  // return a pointer to storage of at least that size; the rasterizer lays
  // out its geometry, binning and image scratch state inside them.
  // `radii` may be null, in which case an internal buffer is used.
  // Returns the number of rendered (Gaussian, tile) instances, which
  // backward() expects as its `R` argument.
  static int forward(std::function<char *(size_t)> geometryBuffer,
                     std::function<char *(size_t)> binningBuffer,
                     std::function<char *(size_t)> imageBuffer, const int P,
                     int D, int M, const float *background, const int width,
                     int height, const float *means3D, const float *shs,
                     const float *colors_precomp, const float *opacities,
                     const float *scales, const float scale_modifier,
                     const float *rotations, const float *cov3D_precomp,
                     const float *viewmatrix, const float *projmatrix,
                     const float *cam_pos, const float tan_fovx, float tan_fovy,
                     const bool prefiltered, float *out_color,
                     int *radii = nullptr, bool debug = false);

  // Backward pass matching forward().
  //
  // `geom_buffer`, `binning_buffer` and `image_buffer` are the scratch
  // buffers filled by forward(); `R` is forward()'s return value. Gradients
  // are written to the dL_d* output pointers.
  static void
  backward(const int P, int D, int M, int R, const float *background,
           const int width, int height, const float *means3D, const float *shs,
           const float *colors_precomp, const float *scales,
           const float scale_modifier, const float *rotations,
           const float *cov3D_precomp, const float *viewmatrix,
           const float *projmatrix, const float *campos, const float tan_fovx,
           float tan_fovy, const int *radii, char *geom_buffer,
           char *binning_buffer, char *image_buffer, const float *dL_dpix,
           float *dL_dmean2D, float *dL_dconic, float *dL_dopacity,
           float *dL_dcolor, float *dL_dmean3D, float *dL_dcov3D, float *dL_dsh,
           float *dL_dscale, float *dL_drot, bool debug);
};

}; // namespace CudaRasterizer
51
+
52
+ #endif
gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/rasterizer_impl.cu ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact george.drettakis@inria.fr
10
+ */
11
+
12
+ #include "cuda_runtime.h"
13
+ #include "device_launch_parameters.h"
14
+ #include "rasterizer_impl.h"
15
+ #include <algorithm>
16
+ #include <cub/cub.cuh>
17
+ #include <cub/device/device_radix_sort.cuh>
18
+ #include <cuda.h>
19
+ #include <fstream>
20
+ #include <iostream>
21
+ #include <numeric>
22
+ #define GLM_FORCE_CUDA
23
+ #include <glm/glm.hpp>
24
+
25
+ #include <cooperative_groups.h>
26
+ #include <cooperative_groups/reduce.h>
27
+ namespace cg = cooperative_groups;
28
+
29
+ #include "auxiliary.h"
30
+ #include "backward.h"
31
+ #include "forward.h"
32
+
33
// CPU helper: returns the position just above the most significant set bit
// of `n`, i.e. the number of bits needed to represent `n`. An input of 0
// yields 1.
//
// Implemented as a binary search over bit positions: start in the middle of
// the word and move up or down by a halving stride depending on whether any
// bit at or above the current position is set.
uint32_t getHigherMsb(uint32_t n) {
  uint32_t pos = sizeof(n) * 4;
  for (uint32_t stride = pos / 2; stride >= 1; stride /= 2) {
    pos = (n >> pos) ? pos + stride : pos - stride;
  }
  // The search ends on the bit itself or one below it; correct the latter.
  return (n >> pos) ? pos + 1 : pos;
}
49
+
50
// Kernel wrapper around the coarse frustum containment test `in_frustum`
// (declared outside this file). Thread i handles Gaussian i and stores the
// test result in present[i]; threads with i >= P do nothing.
__global__ void checkFrustum(int P, const float *orig_points,
                             const float *viewmatrix, const float *projmatrix,
                             bool *present) {
  const auto gaussian_id = cg::this_grid().thread_rank();
  if (gaussian_id < P) {
    // in_frustum needs an output slot for the view-space point; the value
    // itself is not used here.
    float3 view_pos;
    present[gaussian_id] = in_frustum(gaussian_id, orig_points, viewmatrix,
                                      projmatrix, false, view_pos);
  }
}
63
+
64
// Emits one key/value pair per (Gaussian, overlapped tile) combination.
// Runs once per Gaussian (1:N mapping).
//
// Key layout:   | tile ID (upper 32 bits) | depth bit pattern (lower 32) |
// Value:        index of the Gaussian.
// Sorting the values by these keys groups Gaussian indices by tile first and
// orders them by depth within each tile.
//
// `offsets` holds the inclusive prefix sum of per-Gaussian tile counts, so
// Gaussian i starts writing at offsets[i - 1] (or 0 for the first one).
__global__ void duplicateWithKeys(int P, const float2 *points_xy,
                                  const float *depths, const uint32_t *offsets,
                                  uint64_t *gaussian_keys_unsorted,
                                  uint32_t *gaussian_values_unsorted,
                                  int *radii, dim3 grid) {
  const auto gaussian_id = cg::this_grid().thread_rank();
  if (gaussian_id >= P)
    return;

  // Invisible Gaussians (non-positive radius) produce no pairs.
  const int radius = radii[gaussian_id];
  if (radius <= 0)
    return;

  // First slot in the output buffers reserved for this Gaussian.
  uint32_t write_pos = (gaussian_id == 0) ? 0 : offsets[gaussian_id - 1];

  // Tile-space bounding rectangle; the upper bounds are exclusive.
  uint2 tile_min, tile_max;
  getRect(points_xy[gaussian_id], radius, tile_min, tile_max, grid);

  // Raw bit pattern of the float depth, used as the low half of every key.
  const uint32_t depth_bits = *((uint32_t *)&depths[gaussian_id]);

  for (int ty = tile_min.y; ty < tile_max.y; ty++) {
    for (int tx = tile_min.x; tx < tile_max.x; tx++) {
      uint64_t key = ty * grid.x + tx;
      key <<= 32;
      key |= depth_bits;
      gaussian_keys_unsorted[write_pos] = key;
      gaussian_values_unsorted[write_pos] = gaussian_id;
      write_pos++;
    }
  }
}
100
+
101
// Finds, for every tile, the [start, end) range it occupies in the sorted
// key list. Runs once per instanced (duplicated) Gaussian entry.
//
// Each thread compares the tile ID of its own key with that of the previous
// entry; a change marks the end of the previous tile's range (ranges[].y)
// and the start of the current one (ranges[].x). The first and last entries
// close the outer boundaries. Tiles that never appear are not written.
__global__ void identifyTileRanges(int L, uint64_t *point_list_keys,
                                   uint2 *ranges) {
  const auto entry = cg::this_grid().thread_rank();
  if (entry >= L)
    return;

  // The tile ID lives in the upper 32 bits of the key.
  const uint32_t tile = point_list_keys[entry] >> 32;

  if (entry == 0) {
    ranges[tile].x = 0;
  } else {
    const uint32_t tile_before = point_list_keys[entry - 1] >> 32;
    if (tile != tile_before) {
      ranges[tile_before].y = entry;
      ranges[tile].x = entry;
    }
  }

  if (entry == L - 1)
    ranges[tile].y = L;
}
125
+
126
// Marks each of the P Gaussians as visible/invisible via view-frustum
// testing. Launches checkFrustum with 256 threads per block and enough
// blocks to cover P threads; results land in `present`.
void CudaRasterizer::Rasterizer::markVisible(int P, float *means3D,
                                             float *viewmatrix,
                                             float *projmatrix, bool *present) {
  const int threads_per_block = 256;
  const int num_blocks = (P + 255) / threads_per_block;

  checkFrustum<<<num_blocks, threads_per_block>>>(P, means3D, viewmatrix,
                                                  projmatrix, present);
}
133
+
134
// Lays out the per-Gaussian scratch arrays for P Gaussians inside `chunk`,
// each aligned to 128 bytes, and advances `chunk` past the last one.
//
// The order of the obtain() calls defines the memory layout and must stay
// the same between the forward and backward pass. The InclusiveSum call with
// a null temp-storage pointer only queries how many scratch bytes the scan
// needs (stored in scan_size); it does not run the scan.
CudaRasterizer::GeometryState
CudaRasterizer::GeometryState::fromChunk(char *&chunk, size_t P) {
  GeometryState state;

  obtain(chunk, state.depths, P, 128);
  obtain(chunk, state.clamped, P * 3, 128);
  obtain(chunk, state.internal_radii, P, 128);
  obtain(chunk, state.means2D, P, 128);
  obtain(chunk, state.cov3D, P * 6, 128);
  obtain(chunk, state.conic_opacity, P, 128);
  obtain(chunk, state.rgb, P * 3, 128);
  obtain(chunk, state.tiles_touched, P, 128);

  // Size query for the prefix-sum temp storage, then reserve that space.
  cub::DeviceScan::InclusiveSum(nullptr, state.scan_size, state.tiles_touched,
                                state.tiles_touched, P);
  obtain(chunk, state.scanning_space, state.scan_size, 128);

  obtain(chunk, state.point_offsets, P, 128);
  return state;
}
151
+
152
// Lays out the per-pixel scratch arrays (N entries each, 128-byte aligned)
// inside `chunk` and advances `chunk` past them. The order of the obtain()
// calls defines the memory layout.
CudaRasterizer::ImageState CudaRasterizer::ImageState::fromChunk(char *&chunk,
                                                                 size_t N) {
  ImageState state;
  obtain(chunk, state.accum_alpha, N, 128);
  obtain(chunk, state.n_contrib, N, 128);
  obtain(chunk, state.ranges, N, 128);
  return state;
}
160
+
161
// Lays out the sorting buffers for P (Gaussian, tile) instances inside
// `chunk`, each aligned to 128 bytes, and advances `chunk` past them.
//
// The SortPairs call with a null temp-storage pointer only queries how many
// scratch bytes the radix sort needs (stored in sorting_size); nothing is
// sorted here. The order of the obtain() calls defines the memory layout.
CudaRasterizer::BinningState
CudaRasterizer::BinningState::fromChunk(char *&chunk, size_t P) {
  BinningState state;

  obtain(chunk, state.point_list, P, 128);
  obtain(chunk, state.point_list_unsorted, P, 128);
  obtain(chunk, state.point_list_keys, P, 128);
  obtain(chunk, state.point_list_keys_unsorted, P, 128);

  // Size query for the radix-sort temp storage, then reserve that space.
  cub::DeviceRadixSort::SortPairs(
      nullptr, state.sorting_size, state.point_list_keys_unsorted,
      state.point_list_keys, state.point_list_unsorted, state.point_list, P);
  obtain(chunk, state.list_sorting_space, state.sorting_size, 128);

  return state;
}
175
+
176
// Forward rendering procedure for differentiable rasterization of Gaussians.
//
// Pipeline: preprocess every Gaussian -> prefix-sum the per-Gaussian tile
// counts -> emit one [tile | depth] key per (Gaussian, tile) instance ->
// radix-sort the instances -> find each tile's range in the sorted list ->
// blend each tile's range into `out_color`.
//
// geometryBuffer / binningBuffer / imageBuffer are each called once with the
// number of bytes needed and must return storage of at least that size.
// `radii` may be null, in which case the internal radii buffer is used.
// Throws std::runtime_error when NUM_CHANNELS != 3 and no precomputed
// colors are supplied.
// Returns the number of (Gaussian, tile) instances that were rendered.
//
// NOTE(review): CHECK_CUDA is a macro defined outside this file; several
// invocations below have no trailing semicolon, and two wrap an empty first
// argument after a kernel launch, so the macro presumably supplies the
// statement terminator itself. Keep each invocation's form unchanged.
int CudaRasterizer::Rasterizer::forward(
    std::function<char *(size_t)> geometryBuffer,
    std::function<char *(size_t)> binningBuffer,
    std::function<char *(size_t)> imageBuffer, const int P, int D, int M,
    const float *background, const int width, int height, const float *means3D,
    const float *shs, const float *colors_precomp, const float *opacities,
    const float *scales, const float scale_modifier, const float *rotations,
    const float *cov3D_precomp, const float *viewmatrix,
    const float *projmatrix, const float *cam_pos, const float tan_fovx,
    float tan_fovy, const bool prefiltered, float *out_color, int *radii,
    bool debug) {
  // Focal lengths in pixels, derived from the field-of-view tangents.
  const float focal_x = width / (2.0f * tan_fovx);
  const float focal_y = height / (2.0f * tan_fovy);

  // Per-Gaussian scratch state, laid out in caller-provided storage.
  const size_t geom_bytes = required<GeometryState>(P);
  char *geom_cursor = geometryBuffer(geom_bytes);
  GeometryState geomState = GeometryState::fromChunk(geom_cursor, P);

  if (radii == nullptr) {
    radii = geomState.internal_radii;
  }

  // One thread block per BLOCK_X x BLOCK_Y tile, rounding the image up.
  dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X,
                 (height + BLOCK_Y - 1) / BLOCK_Y, 1);
  dim3 block(BLOCK_X, BLOCK_Y, 1);
  const unsigned int num_tiles = tile_grid.x * tile_grid.y;

  // Per-pixel scratch state; sized on every call so the resolution may
  // change between calls.
  const int num_pixels = width * height;
  const size_t img_bytes = required<ImageState>(num_pixels);
  char *img_cursor = imageBuffer(img_bytes);
  ImageState imgState = ImageState::fromChunk(img_cursor, num_pixels);

  if (NUM_CHANNELS != 3 && colors_precomp == nullptr) {
    throw std::runtime_error(
        "For non-RGB, provide precomputed Gaussian colors!");
  }

  // Per-Gaussian preprocessing (transformation, bounding, conversion of SHs
  // to RGB).
  CHECK_CUDA(FORWARD::preprocess(
                 P, D, M, means3D, (glm::vec3 *)scales, scale_modifier,
                 (glm::vec4 *)rotations, opacities, shs, geomState.clamped,
                 cov3D_precomp, colors_precomp, viewmatrix, projmatrix,
                 (glm::vec3 *)cam_pos, width, height, focal_x, focal_y,
                 tan_fovx, tan_fovy, radii, geomState.means2D, geomState.depths,
                 geomState.cov3D, geomState.rgb, geomState.conic_opacity,
                 tile_grid, geomState.tiles_touched, prefiltered),
             debug)

  // Inclusive prefix sum over the per-Gaussian touched-tile counts,
  // e.g. [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8].
  CHECK_CUDA(cub::DeviceScan::InclusiveSum(
                 geomState.scanning_space, geomState.scan_size,
                 geomState.tiles_touched, geomState.point_offsets, P),
             debug)

  // The last prefix-sum entry is the total number of instances to render;
  // it sizes the binning buffers.
  int num_rendered;
  CHECK_CUDA(cudaMemcpy(&num_rendered, geomState.point_offsets + P - 1,
                        sizeof(int), cudaMemcpyDeviceToHost),
             debug);

  const size_t binning_bytes = required<BinningState>(num_rendered);
  char *binning_cursor = binningBuffer(binning_bytes);
  BinningState binningState =
      BinningState::fromChunk(binning_cursor, num_rendered);

  // For each instance to be rendered, produce its [ tile | depth ] key and
  // the corresponding duplicated Gaussian index to be sorted.
  duplicateWithKeys<<<(P + 255) / 256, 256>>>(
      P, geomState.means2D, geomState.depths, geomState.point_offsets,
      binningState.point_list_keys_unsorted, binningState.point_list_unsorted,
      radii, tile_grid) CHECK_CUDA(, debug)

  // Only the 32 depth bits plus as many bits as the tile IDs can occupy
  // need to take part in the sort.
  int tile_id_bits = getHigherMsb(num_tiles);

  // Sort the complete list of (duplicated) Gaussian indices by key.
  CHECK_CUDA(cub::DeviceRadixSort::SortPairs(
                 binningState.list_sorting_space, binningState.sorting_size,
                 binningState.point_list_keys_unsorted,
                 binningState.point_list_keys, binningState.point_list_unsorted,
                 binningState.point_list, num_rendered, 0, 32 + tile_id_bits),
             debug)

  // Zero all tile ranges so that untouched tiles stay empty.
  CHECK_CUDA(cudaMemset(imgState.ranges, 0, num_tiles * sizeof(uint2)),
             debug);

  // Identify start and end of the per-tile workloads in the sorted list.
  if (num_rendered > 0)
    identifyTileRanges<<<(num_rendered + 255) / 256, 256>>>(
        num_rendered, binningState.point_list_keys, imgState.ranges);
  CHECK_CUDA(, debug)

  // Let each tile blend its range of Gaussians independently in parallel,
  // using precomputed colors when given and the preprocessed RGB otherwise.
  const float *feature_ptr =
      colors_precomp != nullptr ? colors_precomp : geomState.rgb;
  CHECK_CUDA(FORWARD::render(tile_grid, block, imgState.ranges,
                             binningState.point_list, width, height,
                             geomState.means2D, feature_ptr,
                             geomState.conic_opacity, imgState.accum_alpha,
                             imgState.n_contrib, background, out_color),
             debug)

  return num_rendered;
}
284
+
285
// Produces the gradients needed for optimization, matching a previous
// forward() call.
//
// geom_buffer / binning_buffer / img_buffer must be the scratch buffers
// filled by forward(), and R must be forward()'s return value, so that the
// state structs can be re-mapped onto the same memory layout.
// `radii` may be null, in which case the internal radii buffer is used.
//
// Two stages run in order: BACKWARD::render turns per-pixel loss gradients
// into gradients w.r.t. 2D mean, conic, opacity and color; then
// BACKWARD::preprocess handles the remaining per-Gaussian gradients.
//
// NOTE(review): CHECK_CUDA is a macro defined outside this file; its
// invocations here carry no trailing semicolon, so it presumably supplies
// the statement terminator itself. Keep the invocation form unchanged.
void CudaRasterizer::Rasterizer::backward(
    const int P, int D, int M, int R, const float *background, const int width,
    int height, const float *means3D, const float *shs,
    const float *colors_precomp, const float *scales,
    const float scale_modifier, const float *rotations,
    const float *cov3D_precomp, const float *viewmatrix,
    const float *projmatrix, const float *campos, const float tan_fovx,
    float tan_fovy, const int *radii, char *geom_buffer, char *binning_buffer,
    char *img_buffer, const float *dL_dpix, float *dL_dmean2D, float *dL_dconic,
    float *dL_dopacity, float *dL_dcolor, float *dL_dmean3D, float *dL_dcov3D,
    float *dL_dsh, float *dL_dscale, float *dL_drot, bool debug) {
  // Re-map the state structs onto the buffers written by the forward pass.
  GeometryState geomState = GeometryState::fromChunk(geom_buffer, P);
  BinningState binningState = BinningState::fromChunk(binning_buffer, R);
  ImageState imgState = ImageState::fromChunk(img_buffer, width * height);

  if (radii == nullptr) {
    radii = geomState.internal_radii;
  }

  // Focal lengths in pixels, derived from the field-of-view tangents.
  const float focal_x = width / (2.0f * tan_fovx);
  const float focal_y = height / (2.0f * tan_fovy);

  // Same tiling as the forward pass.
  const dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X,
                       (height + BLOCK_Y - 1) / BLOCK_Y, 1);
  const dim3 block(BLOCK_X, BLOCK_Y, 1);

  // Stage 1: loss gradients w.r.t. 2D mean position, conic matrix, opacity
  // and color of the Gaussians, from the per-pixel loss gradients. Use the
  // precomputed colors when they were supplied instead of SHs.
  const float *color_ptr =
      (colors_precomp != nullptr) ? colors_precomp : geomState.rgb;
  CHECK_CUDA(BACKWARD::render(
                 tile_grid, block, imgState.ranges, binningState.point_list,
                 width, height, background, geomState.means2D,
                 geomState.conic_opacity, color_ptr, imgState.accum_alpha,
                 imgState.n_contrib, dL_dpix, (float3 *)dL_dmean2D,
                 (float4 *)dL_dconic, dL_dopacity, dL_dcolor),
             debug)

  // Stage 2: the rest of preprocessing. Use the precomputed 3D covariance
  // when one was supplied, otherwise the one computed in the forward pass
  // from the scale/rotation pair.
  const float *cov3D_ptr =
      (cov3D_precomp != nullptr) ? cov3D_precomp : geomState.cov3D;
  CHECK_CUDA(BACKWARD::preprocess(
                 P, D, M, (float3 *)means3D, radii, shs, geomState.clamped,
                 (glm::vec3 *)scales, (glm::vec4 *)rotations, scale_modifier,
                 cov3D_ptr, viewmatrix, projmatrix, focal_x, focal_y, tan_fovx,
                 tan_fovy, (glm::vec3 *)campos, (float3 *)dL_dmean2D, dL_dconic,
                 (glm::vec3 *)dL_dmean3D, dL_dcolor, dL_dcov3D, dL_dsh,
                 (glm::vec3 *)dL_dscale, (glm::vec4 *)dL_drot),
             debug)
}
gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/rasterizer_impl.h ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact george.drettakis@inria.fr
10
+ */
11
+
12
+ #pragma once
13
+
14
+ #include "rasterizer.h"
15
+ #include <cuda_runtime_api.h>
16
+ #include <iostream>
17
+ #include <vector>
18
+
19
+ namespace CudaRasterizer {
20
+ template <typename T>
21
+ static void obtain(char *&chunk, T *&ptr, std::size_t count,
22
+ std::size_t alignment) {
23
+ std::size_t offset =
24
+ (reinterpret_cast<std::uintptr_t>(chunk) + alignment - 1) &
25
+ ~(alignment - 1);
26
+ ptr = reinterpret_cast<T *>(offset);
27
+ chunk = reinterpret_cast<char *>(ptr + count);
28
+ }
29
+
30
+ struct GeometryState {
31
+ size_t scan_size;
32
+ float *depths;
33
+ char *scanning_space;
34
+ bool *clamped;
35
+ int *internal_radii;
36
+ float2 *means2D;
37
+ float *cov3D;
38
+ float4 *conic_opacity;
39
+ float *rgb;
40
+ uint32_t *point_offsets;
41
+ uint32_t *tiles_touched;
42
+
43
+ static GeometryState fromChunk(char *&chunk, size_t P);
44
+ };
45
+
46
+ struct ImageState {
47
+ uint2 *ranges;
48
+ uint32_t *n_contrib;
49
+ float *accum_alpha;
50
+
51
+ static ImageState fromChunk(char *&chunk, size_t N);
52
+ };
53
+
54
+ struct BinningState {
55
+ size_t sorting_size;
56
+ uint64_t *point_list_keys_unsorted;
57
+ uint64_t *point_list_keys;
58
+ uint32_t *point_list_unsorted;
59
+ uint32_t *point_list;
60
+ char *list_sorting_space;
61
+
62
+ static BinningState fromChunk(char *&chunk, size_t P);
63
+ };
64
+
65
+ template <typename T> size_t required(size_t P) {
66
+ char *size = nullptr;
67
+ T::fromChunk(size, P);
68
+ return ((size_t)size) + 128;
69
+ }
70
+ }; // namespace CudaRasterizer
gaussiancity/extensions/diff_gaussian_rasterization/rasterize_points.cu ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact george.drettakis@inria.fr
10
+ */
11
+
12
+ #include "cuda_rasterizer/config.h"
13
+ #include "cuda_rasterizer/rasterizer.h"
14
+ #include <cstdio>
15
+ #include <cuda_runtime_api.h>
16
+ #include <fstream>
17
+ #include <functional>
18
+ #include <iostream>
19
+ #include <math.h>
20
+ #include <memory>
21
+ #include <sstream>
22
+ #include <stdio.h>
23
+ #include <string>
24
+ #include <torch/extension.h>
25
+ #include <tuple>
26
+
27
// Builds an allocation callback backed by tensor `t`.
//
// The returned callable resizes `t` in place to N elements and returns its
// data pointer as char*. `t` is captured by reference, so it must outlive
// every use of the returned callable.
std::function<char *(size_t N)> resizeFunctional(torch::Tensor &t) {
  return [&t](size_t N) {
    t.resize_({static_cast<long long>(N)});
    return reinterpret_cast<char *>(t.contiguous().data_ptr());
  };
}
34
+
35
// Torch entry point for the forward rasterization pass.
//
// Validates that `means3D` has shape (num_points, 3), allocates the output
// image (NUM_CHANNELS, image_height, image_width) and the per-Gaussian radii
// on means3D's device, creates three resizable byte buffers on the CUDA
// device for the rasterizer's scratch state, and calls
// CudaRasterizer::Rasterizer::forward (skipped when there are no points).
//
// Returns (rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer),
// where `rendered` is the rasterizer's instance count (0 when there are no
// points). The three buffers must be passed unchanged to
// RasterizeGaussiansBackwardCUDA.
//
// Raises a c10::Error if means3D has the wrong shape.
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
           torch::Tensor>
RasterizeGaussiansCUDA(
    const torch::Tensor &background, const torch::Tensor &means3D,
    const torch::Tensor &colors, const torch::Tensor &opacity,
    const torch::Tensor &scales, const torch::Tensor &rotations,
    const float scale_modifier, const torch::Tensor &cov3D_precomp,
    const torch::Tensor &viewmatrix, const torch::Tensor &projmatrix,
    const float tan_fovx, const float tan_fovy, const int image_height,
    const int image_width, const torch::Tensor &sh, const int degree,
    const torch::Tensor &campos, const bool prefiltered, const bool debug) {
  TORCH_CHECK(means3D.ndimension() == 2 && means3D.size(1) == 3,
              "means3D must have dimensions (num_points, 3)");

  const int P = means3D.size(0);
  const int H = image_height;
  const int W = image_width;

  auto int_opts = means3D.options().dtype(torch::kInt32);
  auto float_opts = means3D.options().dtype(torch::kFloat32);

  torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts);
  torch::Tensor radii = torch::full({P}, 0, int_opts);

  // Scratch buffers start empty; the rasterizer grows them through the
  // resize callbacks once it knows how many bytes it needs.
  torch::Device device(torch::kCUDA);
  torch::TensorOptions options(torch::kByte);
  torch::Tensor geomBuffer = torch::empty({0}, options.device(device));
  torch::Tensor binningBuffer = torch::empty({0}, options.device(device));
  torch::Tensor imgBuffer = torch::empty({0}, options.device(device));
  std::function<char *(size_t)> geomFunc = resizeFunctional(geomBuffer);
  std::function<char *(size_t)> binningFunc = resizeFunctional(binningBuffer);
  std::function<char *(size_t)> imgFunc = resizeFunctional(imgBuffer);

  int rendered = 0;
  if (P != 0) {
    // Second dimension of `sh`, or 0 when no SH data was given.
    int M = 0;
    if (sh.size(0) != 0) {
      M = sh.size(1);
    }

    // The temporaries returned by contiguous() live until the end of this
    // full expression, so the raw pointers stay valid for the whole call.
    rendered = CudaRasterizer::Rasterizer::forward(
        geomFunc, binningFunc, imgFunc, P, degree, M,
        background.contiguous().data_ptr<float>(), W, H,
        means3D.contiguous().data_ptr<float>(),
        sh.contiguous().data_ptr<float>(),
        colors.contiguous().data_ptr<float>(),
        opacity.contiguous().data_ptr<float>(),
        scales.contiguous().data_ptr<float>(), scale_modifier,
        rotations.contiguous().data_ptr<float>(),
        cov3D_precomp.contiguous().data_ptr<float>(),
        viewmatrix.contiguous().data_ptr<float>(),
        projmatrix.contiguous().data_ptr<float>(),
        campos.contiguous().data_ptr<float>(), tan_fovx, tan_fovy, prefiltered,
        out_color.contiguous().data_ptr<float>(),
        radii.contiguous().data_ptr<int>(), debug);
  }
  return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer,
                         imgBuffer);
}
94
+
95
// Torch entry point for the backward rasterization pass.
//
// `geomBuffer`, `binningBuffer` and `imageBuffer` must be the buffers
// returned by RasterizeGaussiansCUDA and `R` its `rendered` count. The image
// size is taken from dimensions 1 and 2 of `dL_dout_color`.
//
// All gradient tensors are zero-initialised with means3D's options and are
// left as zeros when there are no points.
//
// Returns (dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D,
// dL_dsh, dL_dscales, dL_drotations). dL_dconic is an intermediate and is
// not returned.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
           torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeGaussiansBackwardCUDA(
    const torch::Tensor &background, const torch::Tensor &means3D,
    const torch::Tensor &radii, const torch::Tensor &colors,
    const torch::Tensor &scales, const torch::Tensor &rotations,
    const float scale_modifier, const torch::Tensor &cov3D_precomp,
    const torch::Tensor &viewmatrix, const torch::Tensor &projmatrix,
    const float tan_fovx, const float tan_fovy,
    const torch::Tensor &dL_dout_color, const torch::Tensor &sh,
    const int degree, const torch::Tensor &campos,
    const torch::Tensor &geomBuffer, const int R,
    const torch::Tensor &binningBuffer, const torch::Tensor &imageBuffer,
    const bool debug) {
  const int P = means3D.size(0);
  const int H = dL_dout_color.size(1);
  const int W = dL_dout_color.size(2);

  // Second dimension of `sh`, or 0 when no SH data was given.
  int M = 0;
  if (sh.size(0) != 0) {
    M = sh.size(1);
  }

  torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options());
  torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options());
  torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options());
  torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, means3D.options());
  torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options());
  torch::Tensor dL_dcov3D = torch::zeros({P, 6}, means3D.options());
  torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options());
  torch::Tensor dL_dscales = torch::zeros({P, 3}, means3D.options());
  torch::Tensor dL_drotations = torch::zeros({P, 4}, means3D.options());

  if (P != 0) {
    // Every input goes through contiguous() so the kernels always see a
    // dense layout. `scales` and `rotations` are included, matching the
    // forward pass. The temporaries live until the end of this full
    // expression, so the raw pointers stay valid for the whole call.
    CudaRasterizer::Rasterizer::backward(
        P, degree, M, R, background.contiguous().data_ptr<float>(), W, H,
        means3D.contiguous().data_ptr<float>(),
        sh.contiguous().data_ptr<float>(),
        colors.contiguous().data_ptr<float>(),
        scales.contiguous().data_ptr<float>(), scale_modifier,
        rotations.contiguous().data_ptr<float>(),
        cov3D_precomp.contiguous().data_ptr<float>(),
        viewmatrix.contiguous().data_ptr<float>(),
        projmatrix.contiguous().data_ptr<float>(),
        campos.contiguous().data_ptr<float>(), tan_fovx, tan_fovy,
        radii.contiguous().data_ptr<int>(),
        reinterpret_cast<char *>(geomBuffer.contiguous().data_ptr()),
        reinterpret_cast<char *>(binningBuffer.contiguous().data_ptr()),
        reinterpret_cast<char *>(imageBuffer.contiguous().data_ptr()),
        dL_dout_color.contiguous().data_ptr<float>(),
        dL_dmeans2D.contiguous().data_ptr<float>(),
        dL_dconic.contiguous().data_ptr<float>(),
        dL_dopacity.contiguous().data_ptr<float>(),
        dL_dcolors.contiguous().data_ptr<float>(),
        dL_dmeans3D.contiguous().data_ptr<float>(),
        dL_dcov3D.contiguous().data_ptr<float>(),
        dL_dsh.contiguous().data_ptr<float>(),
        dL_dscales.contiguous().data_ptr<float>(),
        dL_drotations.contiguous().data_ptr<float>(), debug);
  }

  return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D,
                         dL_dcov3D, dL_dsh, dL_dscales, dL_drotations);
}
156
+
157
// Torch entry point for the frustum visibility test.
//
// Returns a bool tensor with one entry per row of `means3D`, allocated with
// means3D's options and initialised to false. When there is at least one
// point, CudaRasterizer::Rasterizer::markVisible fills it in; with zero
// points the empty tensor is returned as is.
torch::Tensor markVisible(torch::Tensor &means3D, torch::Tensor &viewmatrix,
                          torch::Tensor &projmatrix) {
  const int P = means3D.size(0);

  torch::Tensor present =
      torch::full({P}, false, means3D.options().dtype(at::kBool));

  if (P != 0) {
    CudaRasterizer::Rasterizer::markVisible(
        P, means3D.contiguous().data_ptr<float>(),
        viewmatrix.contiguous().data_ptr<float>(),
        projmatrix.contiguous().data_ptr<float>(),
        present.contiguous().data_ptr<bool>());
  }

  return present;
}
gaussiancity/extensions/diff_gaussian_rasterization/rasterize_points.h ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact george.drettakis@inria.fr
10
+ */
11
+
12
+ #pragma once
13
+ #include <cstdio>
14
+ #include <string>
15
+ #include <torch/extension.h>
16
+ #include <tuple>
17
+
18
+ std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
19
+ torch::Tensor>
20
+ RasterizeGaussiansCUDA(
21
+ const torch::Tensor &background, const torch::Tensor &means3D,
22
+ const torch::Tensor &colors, const torch::Tensor &opacity,
23
+ const torch::Tensor &scales, const torch::Tensor &rotations,
24
+ const float scale_modifier, const torch::Tensor &cov3D_precomp,
25
+ const torch::Tensor &viewmatrix, const torch::Tensor &projmatrix,
26
+ const float tan_fovx, const float tan_fovy, const int image_height,
27
+ const int image_width, const torch::Tensor &sh, const int degree,
28
+ const torch::Tensor &campos, const bool prefiltered, const bool debug);
29
+
30
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
31
+ torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
32
+ RasterizeGaussiansBackwardCUDA(
33
+ const torch::Tensor &background, const torch::Tensor &means3D,
34
+ const torch::Tensor &radii, const torch::Tensor &colors,
35
+ const torch::Tensor &scales, const torch::Tensor &rotations,
36
+ const float scale_modifier, const torch::Tensor &cov3D_precomp,
37
+ const torch::Tensor &viewmatrix, const torch::Tensor &projmatrix,
38
+ const float tan_fovx, const float tan_fovy,
39
+ const torch::Tensor &dL_dout_color, const torch::Tensor &sh,
40
+ const int degree, const torch::Tensor &campos,
41
+ const torch::Tensor &geomBuffer, const int R,
42
+ const torch::Tensor &binningBuffer, const torch::Tensor &imageBuffer,
43
+ const bool debug);
44
+
45
+ torch::Tensor markVisible(torch::Tensor &means3D, torch::Tensor &viewmatrix,
46
+ torch::Tensor &projmatrix);
gaussiancity/extensions/diff_gaussian_rasterization/setup.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ from setuptools import setup
13
+ from torch.utils.cpp_extension import CUDAExtension, BuildExtension
14
+ import os
15
+
16
# Absolute path of the vendored GLM headers, passed to nvcc as an include dir
# so the build works regardless of the current working directory.
_GLM_INCLUDE_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "third_party/glm/"
)

# Translation units compiled into the extension module.
_SOURCES = [
    "cuda_rasterizer/rasterizer_impl.cu",
    "cuda_rasterizer/forward.cu",
    "cuda_rasterizer/backward.cu",
    "rasterize_points.cu",
    "bindings.cpp",
]

setup(
    name="diff_gaussian_rasterization",
    version="1.0.0",
    ext_modules=[
        CUDAExtension(
            name="diff_gaussian_rasterization_ext",
            sources=_SOURCES,
            extra_compile_args={"nvcc": ["-I" + _GLM_INCLUDE_DIR]},
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
gaussiancity/extensions/diff_gaussian_rasterization/third_party/glm ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 2d4c4b4dd31fde06cfffad7915c2b3006402322f
gaussiancity/extensions/diff_gaussian_rasterization/third_party/stbi_image_write.h ADDED
@@ -0,0 +1,1724 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* stb_image_write - v1.16 - public domain - http://nothings.org/stb
2
+ writes out PNG/BMP/TGA/JPEG/HDR images to C stdio - Sean Barrett 2010-2015
3
+ no warranty implied; use at your own risk
4
+
5
+ Before #including,
6
+
7
+ #define STB_IMAGE_WRITE_IMPLEMENTATION
8
+
9
+ in the file that you want to have the implementation.
10
+
11
+ Will probably not work correctly with strict-aliasing optimizations.
12
+
13
+ ABOUT:
14
+
15
+ This header file is a library for writing images to C stdio or a callback.
16
+
17
+ The PNG output is not optimal; it is 20-50% larger than the file
18
+ written by a decent optimizing implementation; though providing a custom
19
+ zlib compress function (see STBIW_ZLIB_COMPRESS) can mitigate that.
20
+ This library is designed for source code compactness and simplicity,
21
+ not optimal image file size or run-time performance.
22
+
23
+ BUILDING:
24
+
25
+ You can #define STBIW_ASSERT(x) before the #include to avoid using assert.h.
26
+ You can #define STBIW_MALLOC(), STBIW_REALLOC(), and STBIW_FREE() to replace
27
+ malloc,realloc,free.
28
+ You can #define STBIW_MEMMOVE() to replace memmove()
29
+ You can #define STBIW_ZLIB_COMPRESS to use a custom zlib-style compress function
30
+ for PNG compression (instead of the builtin one), it must have the following signature:
31
+ unsigned char * my_compress(unsigned char *data, int data_len, int *out_len, int quality);
32
+ The returned data will be freed with STBIW_FREE() (free() by default),
33
+ so it must be heap allocated with STBIW_MALLOC() (malloc() by default).
34
+
35
+ UNICODE:
36
+
37
+ If compiling for Windows and you wish to use Unicode filenames, compile
38
+ with
39
+ #define STBIW_WINDOWS_UTF8
40
+ and pass utf8-encoded filenames. Call stbiw_convert_wchar_to_utf8 to convert
41
+ Windows wchar_t filenames to utf8.
42
+
43
+ USAGE:
44
+
45
+ There are five functions, one for each image file format:
46
+
47
+ int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes);
48
+ int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data);
49
+ int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data);
50
+ int stbi_write_jpg(char const *filename, int w, int h, int comp, const void *data, int quality);
51
+ int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data);
52
+
53
+ void stbi_flip_vertically_on_write(int flag); // flag is non-zero to flip data vertically
54
+
55
+ There are also five equivalent functions that use an arbitrary write function. You are
56
+ expected to open/close your file-equivalent before and after calling these:
57
+
58
+ int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data, int stride_in_bytes);
59
+ int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data);
60
+ int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data);
61
+ int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data);
62
+ int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality);
63
+
64
+ where the callback is:
65
+ void stbi_write_func(void *context, void *data, int size);
66
+
67
+ You can configure it with these global variables:
68
+ int stbi_write_tga_with_rle; // defaults to true; set to 0 to disable RLE
69
+ int stbi_write_png_compression_level; // defaults to 8; set to higher for more compression
70
+ int stbi_write_force_png_filter; // defaults to -1; set to 0..5 to force a filter mode
71
+
72
+
73
+ You can define STBI_WRITE_NO_STDIO to disable the file variant of these
74
+ functions, so the library will not use stdio.h at all. However, this will
75
+ also disable HDR writing, because it requires stdio for formatted output.
76
+
77
+ Each function returns 0 on failure and non-0 on success.
78
+
79
+ The functions create an image file defined by the parameters. The image
80
+ is a rectangle of pixels stored from left-to-right, top-to-bottom.
81
+ Each pixel contains 'comp' channels of data stored interleaved with 8-bits
82
+ per channel, in the following order: 1=Y, 2=YA, 3=RGB, 4=RGBA. (Y is
83
+ monochrome color.) The rectangle is 'w' pixels wide and 'h' pixels tall.
84
+ The *data pointer points to the first byte of the top-left-most pixel.
85
+ For PNG, "stride_in_bytes" is the distance in bytes from the first byte of
86
+ a row of pixels to the first byte of the next row of pixels.
87
+
88
+ PNG creates output files with the same number of components as the input.
89
+ The BMP format expands Y to RGB in the file format and does not
90
+ output alpha.
91
+
92
+ PNG supports writing rectangles of data even when the bytes storing rows of
93
+ data are not consecutive in memory (e.g. sub-rectangles of a larger image),
94
+ by supplying the stride between the beginning of adjacent rows. The other
95
+ formats do not. (Thus you cannot write a native-format BMP through the BMP
96
+ writer, both because it is in BGR order and because it may have padding
97
+ at the end of the line.)
98
+
99
+ PNG allows you to set the deflate compression level by setting the global
100
+ variable 'stbi_write_png_compression_level' (it defaults to 8).
101
+
102
+ HDR expects linear float data. Since the format is always 32-bit rgb(e)
103
+ data, alpha (if provided) is discarded, and for monochrome data it is
104
+ replicated across all three channels.
105
+
106
+ TGA supports RLE or non-RLE compressed data. To use non-RLE-compressed
107
+ data, set the global variable 'stbi_write_tga_with_rle' to 0.
108
+
109
+ JPEG ignores alpha channels in input data; quality is between 1 and 100.
110
+ Higher quality looks better but results in a bigger image.
111
+ JPEG baseline (no JPEG progressive).
112
+
113
+ CREDITS:
114
+
115
+
116
+ Sean Barrett - PNG/BMP/TGA
117
+ Baldur Karlsson - HDR
118
+ Jean-Sebastien Guay - TGA monochrome
119
+ Tim Kelsey - misc enhancements
120
+ Alan Hickman - TGA RLE
121
+ Emmanuel Julien - initial file IO callback implementation
122
+ Jon Olick - original jo_jpeg.cpp code
123
+ Daniel Gibson - integrate JPEG, allow external zlib
124
+ Aarni Koskela - allow choosing PNG filter
125
+
126
+ bugfixes:
127
+ github:Chribba
128
+ Guillaume Chereau
129
+ github:jry2
130
+ github:romigrou
131
+ Sergio Gonzalez
132
+ Jonas Karlsson
133
+ Filip Wasil
134
+ Thatcher Ulrich
135
+ github:poppolopoppo
136
+ Patrick Boettcher
137
+ github:xeekworx
138
+ Cap Petschulat
139
+ Simon Rodriguez
140
+ Ivan Tikhonov
141
+ github:ignotion
142
+ Adam Schackart
143
+ Andrew Kensler
144
+
145
+ LICENSE
146
+
147
+ See end of file for license information.
148
+
149
+ */
150
+
151
+ #ifndef INCLUDE_STB_IMAGE_WRITE_H
152
+ #define INCLUDE_STB_IMAGE_WRITE_H
153
+
154
+ #include <stdlib.h>
155
+
156
+ // if STB_IMAGE_WRITE_STATIC causes problems, try defining STBIWDEF to 'inline' or 'static inline'
157
+ #ifndef STBIWDEF
158
+ #ifdef STB_IMAGE_WRITE_STATIC
159
+ #define STBIWDEF static
160
+ #else
161
+ #ifdef __cplusplus
162
+ #define STBIWDEF extern "C"
163
+ #else
164
+ #define STBIWDEF extern
165
+ #endif
166
+ #endif
167
+ #endif
168
+
169
+ #ifndef STB_IMAGE_WRITE_STATIC // C++ forbids static forward declarations
170
+ STBIWDEF int stbi_write_tga_with_rle;
171
+ STBIWDEF int stbi_write_png_compression_level;
172
+ STBIWDEF int stbi_write_force_png_filter;
173
+ #endif
174
+
175
+ #ifndef STBI_WRITE_NO_STDIO
176
+ STBIWDEF int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes);
177
+ STBIWDEF int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data);
178
+ STBIWDEF int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data);
179
+ STBIWDEF int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data);
180
+ STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality);
181
+
182
+ #ifdef STBIW_WINDOWS_UTF8
183
+ STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input);
184
+ #endif
185
+ #endif
186
+
187
+ typedef void stbi_write_func(void *context, void *data, int size);
188
+
189
+ STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data, int stride_in_bytes);
190
+ STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data);
191
+ STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data);
192
+ STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data);
193
+ STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality);
194
+
195
+ STBIWDEF void stbi_flip_vertically_on_write(int flip_boolean);
196
+
197
+ #endif//INCLUDE_STB_IMAGE_WRITE_H
198
+
199
+ #ifdef STB_IMAGE_WRITE_IMPLEMENTATION
200
+
201
+ #ifdef _WIN32
202
+ #ifndef _CRT_SECURE_NO_WARNINGS
203
+ #define _CRT_SECURE_NO_WARNINGS
204
+ #endif
205
+ #ifndef _CRT_NONSTDC_NO_DEPRECATE
206
+ #define _CRT_NONSTDC_NO_DEPRECATE
207
+ #endif
208
+ #endif
209
+
210
+ #ifndef STBI_WRITE_NO_STDIO
211
+ #include <stdio.h>
212
+ #endif // STBI_WRITE_NO_STDIO
213
+
214
+ #include <stdarg.h>
215
+ #include <stdlib.h>
216
+ #include <string.h>
217
+ #include <math.h>
218
+
219
+ #if defined(STBIW_MALLOC) && defined(STBIW_FREE) && (defined(STBIW_REALLOC) || defined(STBIW_REALLOC_SIZED))
220
+ // ok
221
+ #elif !defined(STBIW_MALLOC) && !defined(STBIW_FREE) && !defined(STBIW_REALLOC) && !defined(STBIW_REALLOC_SIZED)
222
+ // ok
223
+ #else
224
+ #error "Must define all or none of STBIW_MALLOC, STBIW_FREE, and STBIW_REALLOC (or STBIW_REALLOC_SIZED)."
225
+ #endif
226
+
227
+ #ifndef STBIW_MALLOC
228
+ #define STBIW_MALLOC(sz) malloc(sz)
229
+ #define STBIW_REALLOC(p,newsz) realloc(p,newsz)
230
+ #define STBIW_FREE(p) free(p)
231
+ #endif
232
+
233
+ #ifndef STBIW_REALLOC_SIZED
234
+ #define STBIW_REALLOC_SIZED(p,oldsz,newsz) STBIW_REALLOC(p,newsz)
235
+ #endif
236
+
237
+
238
+ #ifndef STBIW_MEMMOVE
239
+ #define STBIW_MEMMOVE(a,b,sz) memmove(a,b,sz)
240
+ #endif
241
+
242
+
243
+ #ifndef STBIW_ASSERT
244
+ #include <assert.h>
245
+ #define STBIW_ASSERT(x) assert(x)
246
+ #endif
247
+
248
+ #define STBIW_UCHAR(x) (unsigned char) ((x) & 0xff)
249
+
250
+ #ifdef STB_IMAGE_WRITE_STATIC
251
+ static int stbi_write_png_compression_level = 8;
252
+ static int stbi_write_tga_with_rle = 1;
253
+ static int stbi_write_force_png_filter = -1;
254
+ #else
255
+ int stbi_write_png_compression_level = 8;
256
+ int stbi_write_tga_with_rle = 1;
257
+ int stbi_write_force_png_filter = -1;
258
+ #endif
259
+
260
+ static int stbi__flip_vertically_on_write = 0;
261
+
262
+ STBIWDEF void stbi_flip_vertically_on_write(int flag)
263
+ {
264
+ stbi__flip_vertically_on_write = flag;
265
+ }
266
+
267
// Output sink shared by the writers in this file: a byte-emitting callback plus a small staging buffer.
+ typedef struct
268
+ {
269
+ stbi_write_func *func; // callback that receives each chunk of output bytes
270
+ void *context; // opaque pointer handed back to func as its first argument
271
+ unsigned char buffer[64]; // staging area filled by stbiw__write1/stbiw__write3, drained by stbiw__write_flush
272
+ int buf_used; // number of bytes currently pending in buffer
273
+ } stbi__write_context;
274
+
275
+ // initialize a callback-based context
276
+ static void stbi__start_write_callbacks(stbi__write_context *s, stbi_write_func *c, void *context)
277
+ {
278
+ s->func = c;
279
+ s->context = context;
280
+ }
281
+
282
+ #ifndef STBI_WRITE_NO_STDIO
283
+
284
/* stbi_write_func implementation that forwards the bytes to a stdio stream.
   'context' is the FILE* to write to; the fwrite result is not inspected. */
static void stbi__stdio_write(void *context, void *data, int size)
{
   FILE *stream = (FILE *) context;
   fwrite(data, 1, size, stream);
}
288
+
289
+ #if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8)
290
+ #ifdef __cplusplus
291
+ #define STBIW_EXTERN extern "C"
292
+ #else
293
+ #define STBIW_EXTERN extern
294
+ #endif
295
+ STBIW_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide);
296
+ STBIW_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default);
297
+
298
+ STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input)
299
+ {
300
+ return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL);
301
+ }
302
+ #endif
303
+
304
/* Open 'filename' with the given stdio mode string and return the stream,
   or 0 on failure.
   With STBIW_WINDOWS_UTF8 on Windows, both strings are treated as UTF-8 and
   converted to wide characters first; a failed conversion returns 0.
   On MSVC >= 1400 the *_s variants are used. */
static FILE *stbiw__fopen(char const *filename, char const *mode)
{
   FILE *fp;
#if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8)
   wchar_t wide_mode[64];
   wchar_t wide_name[1024];

   if (MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wide_name, sizeof(wide_name)/sizeof(*wide_name)) == 0)
      return 0;

   if (MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wide_mode, sizeof(wide_mode)/sizeof(*wide_mode)) == 0)
      return 0;

#if defined(_MSC_VER) && _MSC_VER >= 1400
   if (_wfopen_s(&fp, wide_name, wide_mode) != 0)
      fp = 0;
#else
   fp = _wfopen(wide_name, wide_mode);
#endif

#elif defined(_MSC_VER) && _MSC_VER >= 1400
   if (fopen_s(&fp, filename, mode) != 0)
      fp = 0;
#else
   fp = fopen(filename, mode);
#endif
   return fp;
}
331
+
332
+ static int stbi__start_write_file(stbi__write_context *s, const char *filename)
333
+ {
334
+ FILE *f = stbiw__fopen(filename, "wb");
335
+ stbi__start_write_callbacks(s, stbi__stdio_write, (void *) f);
336
+ return f != NULL;
337
+ }
338
+
339
+ static void stbi__end_write_file(stbi__write_context *s)
340
+ {
341
+ fclose((FILE *)s->context);
342
+ }
343
+
344
+ #endif // !STBI_WRITE_NO_STDIO
345
+
346
+ typedef unsigned int stbiw_uint32;
347
+ typedef int stb_image_write_test[sizeof(stbiw_uint32)==4 ? 1 : -1];
348
+
349
+ static void stbiw__writefv(stbi__write_context *s, const char *fmt, va_list v)
350
+ {
351
+ while (*fmt) {
352
+ switch (*fmt++) {
353
+ case ' ': break;
354
+ case '1': { unsigned char x = STBIW_UCHAR(va_arg(v, int));
355
+ s->func(s->context,&x,1);
356
+ break; }
357
+ case '2': { int x = va_arg(v,int);
358
+ unsigned char b[2];
359
+ b[0] = STBIW_UCHAR(x);
360
+ b[1] = STBIW_UCHAR(x>>8);
361
+ s->func(s->context,b,2);
362
+ break; }
363
+ case '4': { stbiw_uint32 x = va_arg(v,int);
364
+ unsigned char b[4];
365
+ b[0]=STBIW_UCHAR(x);
366
+ b[1]=STBIW_UCHAR(x>>8);
367
+ b[2]=STBIW_UCHAR(x>>16);
368
+ b[3]=STBIW_UCHAR(x>>24);
369
+ s->func(s->context,b,4);
370
+ break; }
371
+ default:
372
+ STBIW_ASSERT(0);
373
+ return;
374
+ }
375
+ }
376
+ }
377
+
378
+ static void stbiw__writef(stbi__write_context *s, const char *fmt, ...)
379
+ {
380
+ va_list v;
381
+ va_start(v, fmt);
382
+ stbiw__writefv(s, fmt, v);
383
+ va_end(v);
384
+ }
385
+
386
+ static void stbiw__write_flush(stbi__write_context *s)
387
+ {
388
+ if (s->buf_used) {
389
+ s->func(s->context, &s->buffer, s->buf_used);
390
+ s->buf_used = 0;
391
+ }
392
+ }
393
+
394
+ static void stbiw__putc(stbi__write_context *s, unsigned char c)
395
+ {
396
+ s->func(s->context, &c, 1);
397
+ }
398
+
399
+ static void stbiw__write1(stbi__write_context *s, unsigned char a)
400
+ {
401
+ if ((size_t)s->buf_used + 1 > sizeof(s->buffer))
402
+ stbiw__write_flush(s);
403
+ s->buffer[s->buf_used++] = a;
404
+ }
405
+
406
+ static void stbiw__write3(stbi__write_context *s, unsigned char a, unsigned char b, unsigned char c)
407
+ {
408
+ int n;
409
+ if ((size_t)s->buf_used + 3 > sizeof(s->buffer))
410
+ stbiw__write_flush(s);
411
+ n = s->buf_used;
412
+ s->buf_used = n+3;
413
+ s->buffer[n+0] = a;
414
+ s->buffer[n+1] = b;
415
+ s->buffer[n+2] = c;
416
+ }
417
+
418
+ static void stbiw__write_pixel(stbi__write_context *s, int rgb_dir, int comp, int write_alpha, int expand_mono, unsigned char *d)
419
+ {
420
+ unsigned char bg[3] = { 255, 0, 255}, px[3];
421
+ int k;
422
+
423
+ if (write_alpha < 0)
424
+ stbiw__write1(s, d[comp - 1]);
425
+
426
+ switch (comp) {
427
+ case 2: // 2 pixels = mono + alpha, alpha is written separately, so same as 1-channel case
428
+ case 1:
429
+ if (expand_mono)
430
+ stbiw__write3(s, d[0], d[0], d[0]); // monochrome bmp
431
+ else
432
+ stbiw__write1(s, d[0]); // monochrome TGA
433
+ break;
434
+ case 4:
435
+ if (!write_alpha) {
436
+ // composite against pink background
437
+ for (k = 0; k < 3; ++k)
438
+ px[k] = bg[k] + ((d[k] - bg[k]) * d[3]) / 255;
439
+ stbiw__write3(s, px[1 - rgb_dir], px[1], px[1 + rgb_dir]);
440
+ break;
441
+ }
442
+ /* FALLTHROUGH */
443
+ case 3:
444
+ stbiw__write3(s, d[1 - rgb_dir], d[1], d[1 + rgb_dir]);
445
+ break;
446
+ }
447
+ if (write_alpha > 0)
448
+ stbiw__write1(s, d[comp - 1]);
449
+ }
450
+
451
+ static void stbiw__write_pixels(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, void *data, int write_alpha, int scanline_pad, int expand_mono)
452
+ {
453
+ stbiw_uint32 zero = 0;
454
+ int i,j, j_end;
455
+
456
+ if (y <= 0)
457
+ return;
458
+
459
+ if (stbi__flip_vertically_on_write)
460
+ vdir *= -1;
461
+
462
+ if (vdir < 0) {
463
+ j_end = -1; j = y-1;
464
+ } else {
465
+ j_end = y; j = 0;
466
+ }
467
+
468
+ for (; j != j_end; j += vdir) {
469
+ for (i=0; i < x; ++i) {
470
+ unsigned char *d = (unsigned char *) data + (j*x+i)*comp;
471
+ stbiw__write_pixel(s, rgb_dir, comp, write_alpha, expand_mono, d);
472
+ }
473
+ stbiw__write_flush(s);
474
+ s->func(s->context, &zero, scanline_pad);
475
+ }
476
+ }
477
+
478
+ static int stbiw__outfile(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, int expand_mono, void *data, int alpha, int pad, const char *fmt, ...)
479
+ {
480
+ if (y < 0 || x < 0) {
481
+ return 0;
482
+ } else {
483
+ va_list v;
484
+ va_start(v, fmt);
485
+ stbiw__writefv(s, fmt, v);
486
+ va_end(v);
487
+ stbiw__write_pixels(s,rgb_dir,vdir,x,y,comp,data,alpha,pad, expand_mono);
488
+ return 1;
489
+ }
490
+ }
491
+
492
+ static int stbi_write_bmp_core(stbi__write_context *s, int x, int y, int comp, const void *data)
493
+ {
494
+ if (comp != 4) {
495
+ // write RGB bitmap
496
+ int pad = (-x*3) & 3;
497
+ return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *) data,0,pad,
498
+ "11 4 22 4" "4 44 22 444444",
499
+ 'B', 'M', 14+40+(x*3+pad)*y, 0,0, 14+40, // file header
500
+ 40, x,y, 1,24, 0,0,0,0,0,0); // bitmap header
501
+ } else {
502
+ // RGBA bitmaps need a v4 header
503
+ // use BI_BITFIELDS mode with 32bpp and alpha mask
504
+ // (straight BI_RGB with alpha mask doesn't work in most readers)
505
+ return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *)data,1,0,
506
+ "11 4 22 4" "4 44 22 444444 4444 4 444 444 444 444",
507
+ 'B', 'M', 14+108+x*y*4, 0, 0, 14+108, // file header
508
+ 108, x,y, 1,32, 3,0,0,0,0,0, 0xff0000,0xff00,0xff,0xff000000u, 0, 0,0,0, 0,0,0, 0,0,0, 0,0,0); // bitmap V4 header
509
+ }
510
+ }
511
+
512
+ STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data)
513
+ {
514
+ stbi__write_context s = { 0 };
515
+ stbi__start_write_callbacks(&s, func, context);
516
+ return stbi_write_bmp_core(&s, x, y, comp, data);
517
+ }
518
+
519
+ #ifndef STBI_WRITE_NO_STDIO
520
+ STBIWDEF int stbi_write_bmp(char const *filename, int x, int y, int comp, const void *data)
521
+ {
522
+ stbi__write_context s = { 0 };
523
+ if (stbi__start_write_file(&s,filename)) {
524
+ int r = stbi_write_bmp_core(&s, x, y, comp, data);
525
+ stbi__end_write_file(&s);
526
+ return r;
527
+ } else
528
+ return 0;
529
+ }
530
+ #endif //!STBI_WRITE_NO_STDIO
531
+
532
+ static int stbi_write_tga_core(stbi__write_context *s, int x, int y, int comp, void *data)
533
+ {
534
+ int has_alpha = (comp == 2 || comp == 4);
535
+ int colorbytes = has_alpha ? comp-1 : comp;
536
+ int format = colorbytes < 2 ? 3 : 2; // 3 color channels (RGB/RGBA) = 2, 1 color channel (Y/YA) = 3
537
+
538
+ if (y < 0 || x < 0)
539
+ return 0;
540
+
541
+ if (!stbi_write_tga_with_rle) {
542
+ return stbiw__outfile(s, -1, -1, x, y, comp, 0, (void *) data, has_alpha, 0,
543
+ "111 221 2222 11", 0, 0, format, 0, 0, 0, 0, 0, x, y, (colorbytes + has_alpha) * 8, has_alpha * 8);
544
+ } else {
545
+ int i,j,k;
546
+ int jend, jdir;
547
+
548
+ stbiw__writef(s, "111 221 2222 11", 0,0,format+8, 0,0,0, 0,0,x,y, (colorbytes + has_alpha) * 8, has_alpha * 8);
549
+
550
+ if (stbi__flip_vertically_on_write) {
551
+ j = 0;
552
+ jend = y;
553
+ jdir = 1;
554
+ } else {
555
+ j = y-1;
556
+ jend = -1;
557
+ jdir = -1;
558
+ }
559
+ for (; j != jend; j += jdir) {
560
+ unsigned char *row = (unsigned char *) data + j * x * comp;
561
+ int len;
562
+
563
+ for (i = 0; i < x; i += len) {
564
+ unsigned char *begin = row + i * comp;
565
+ int diff = 1;
566
+ len = 1;
567
+
568
+ if (i < x - 1) {
569
+ ++len;
570
+ diff = memcmp(begin, row + (i + 1) * comp, comp);
571
+ if (diff) {
572
+ const unsigned char *prev = begin;
573
+ for (k = i + 2; k < x && len < 128; ++k) {
574
+ if (memcmp(prev, row + k * comp, comp)) {
575
+ prev += comp;
576
+ ++len;
577
+ } else {
578
+ --len;
579
+ break;
580
+ }
581
+ }
582
+ } else {
583
+ for (k = i + 2; k < x && len < 128; ++k) {
584
+ if (!memcmp(begin, row + k * comp, comp)) {
585
+ ++len;
586
+ } else {
587
+ break;
588
+ }
589
+ }
590
+ }
591
+ }
592
+
593
+ if (diff) {
594
+ unsigned char header = STBIW_UCHAR(len - 1);
595
+ stbiw__write1(s, header);
596
+ for (k = 0; k < len; ++k) {
597
+ stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin + k * comp);
598
+ }
599
+ } else {
600
+ unsigned char header = STBIW_UCHAR(len - 129);
601
+ stbiw__write1(s, header);
602
+ stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin);
603
+ }
604
+ }
605
+ }
606
+ stbiw__write_flush(s);
607
+ }
608
+ return 1;
609
+ }
610
+
611
+ STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data)
612
+ {
613
+ stbi__write_context s = { 0 };
614
+ stbi__start_write_callbacks(&s, func, context);
615
+ return stbi_write_tga_core(&s, x, y, comp, (void *) data);
616
+ }
617
+
618
+ #ifndef STBI_WRITE_NO_STDIO
619
+ STBIWDEF int stbi_write_tga(char const *filename, int x, int y, int comp, const void *data)
620
+ {
621
+ stbi__write_context s = { 0 };
622
+ if (stbi__start_write_file(&s,filename)) {
623
+ int r = stbi_write_tga_core(&s, x, y, comp, (void *) data);
624
+ stbi__end_write_file(&s);
625
+ return r;
626
+ } else
627
+ return 0;
628
+ }
629
+ #endif
630
+
631
+ // *************************************************************************************************
632
+ // Radiance RGBE HDR writer
633
+ // by Baldur Karlsson
634
+
635
+ #define stbiw__max(a, b) ((a) > (b) ? (a) : (b))
636
+
637
+ #ifndef STBI_WRITE_NO_STDIO
638
+
639
/* Convert one linear RGB triple into a 4-byte shared-exponent RGBE value.
   The largest component picks the exponent (via frexp); the three mantissas
   are scaled by the same factor and the exponent is stored biased by 128.
   Triples whose largest component is below 1e-32 encode as four zero bytes. */
static void stbiw__linear_to_rgbe(unsigned char *rgbe, float *linear)
{
   int exponent;
   float scale;
   float maxcomp = linear[1] > linear[2] ? linear[1] : linear[2];

   if (linear[0] > maxcomp)
      maxcomp = linear[0];

   if (maxcomp < 1e-32f) {
      rgbe[0] = rgbe[1] = rgbe[2] = rgbe[3] = 0;
      return;
   }

   scale = (float) frexp(maxcomp, &exponent) * 256.0f/maxcomp;

   rgbe[0] = (unsigned char)(linear[0] * scale);
   rgbe[1] = (unsigned char)(linear[1] * scale);
   rgbe[2] = (unsigned char)(linear[2] * scale);
   rgbe[3] = (unsigned char)(exponent + 128);
}
655
+
656
+ static void stbiw__write_run_data(stbi__write_context *s, int length, unsigned char databyte)
657
+ {
658
+ unsigned char lengthbyte = STBIW_UCHAR(length+128);
659
+ STBIW_ASSERT(length+128 <= 255);
660
+ s->func(s->context, &lengthbyte, 1);
661
+ s->func(s->context, &databyte, 1);
662
+ }
663
+
664
+ static void stbiw__write_dump_data(stbi__write_context *s, int length, unsigned char *data)
665
+ {
666
+ unsigned char lengthbyte = STBIW_UCHAR(length);
667
+ STBIW_ASSERT(length <= 128); // inconsistent with spec but consistent with official code
668
+ s->func(s->context, &lengthbyte, 1);
669
+ s->func(s->context, data, length);
670
+ }
671
+
672
+ static void stbiw__write_hdr_scanline(stbi__write_context *s, int width, int ncomp, unsigned char *scratch, float *scanline)
673
+ {
674
+ unsigned char scanlineheader[4] = { 2, 2, 0, 0 };
675
+ unsigned char rgbe[4];
676
+ float linear[3];
677
+ int x;
678
+
679
+ scanlineheader[2] = (width&0xff00)>>8;
680
+ scanlineheader[3] = (width&0x00ff);
681
+
682
+ /* skip RLE for images too small or large */
683
+ if (width < 8 || width >= 32768) {
684
+ for (x=0; x < width; x++) {
685
+ switch (ncomp) {
686
+ case 4: /* fallthrough */
687
+ case 3: linear[2] = scanline[x*ncomp + 2];
688
+ linear[1] = scanline[x*ncomp + 1];
689
+ linear[0] = scanline[x*ncomp + 0];
690
+ break;
691
+ default:
692
+ linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0];
693
+ break;
694
+ }
695
+ stbiw__linear_to_rgbe(rgbe, linear);
696
+ s->func(s->context, rgbe, 4);
697
+ }
698
+ } else {
699
+ int c,r;
700
+ /* encode into scratch buffer */
701
+ for (x=0; x < width; x++) {
702
+ switch(ncomp) {
703
+ case 4: /* fallthrough */
704
+ case 3: linear[2] = scanline[x*ncomp + 2];
705
+ linear[1] = scanline[x*ncomp + 1];
706
+ linear[0] = scanline[x*ncomp + 0];
707
+ break;
708
+ default:
709
+ linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0];
710
+ break;
711
+ }
712
+ stbiw__linear_to_rgbe(rgbe, linear);
713
+ scratch[x + width*0] = rgbe[0];
714
+ scratch[x + width*1] = rgbe[1];
715
+ scratch[x + width*2] = rgbe[2];
716
+ scratch[x + width*3] = rgbe[3];
717
+ }
718
+
719
+ s->func(s->context, scanlineheader, 4);
720
+
721
+ /* RLE each component separately */
722
+ for (c=0; c < 4; c++) {
723
+ unsigned char *comp = &scratch[width*c];
724
+
725
+ x = 0;
726
+ while (x < width) {
727
+ // find first run
728
+ r = x;
729
+ while (r+2 < width) {
730
+ if (comp[r] == comp[r+1] && comp[r] == comp[r+2])
731
+ break;
732
+ ++r;
733
+ }
734
+ if (r+2 >= width)
735
+ r = width;
736
+ // dump up to first run
737
+ while (x < r) {
738
+ int len = r-x;
739
+ if (len > 128) len = 128;
740
+ stbiw__write_dump_data(s, len, &comp[x]);
741
+ x += len;
742
+ }
743
+ // if there's a run, output it
744
+ if (r+2 < width) { // same test as what we break out of in search loop, so only true if we break'd
745
+ // find next byte after run
746
+ while (r < width && comp[r] == comp[x])
747
+ ++r;
748
+ // output run up to r
749
+ while (x < r) {
750
+ int len = r-x;
751
+ if (len > 127) len = 127;
752
+ stbiw__write_run_data(s, len, comp[x]);
753
+ x += len;
754
+ }
755
+ }
756
+ }
757
+ }
758
+ }
759
+ }
760
+
761
+ static int stbi_write_hdr_core(stbi__write_context *s, int x, int y, int comp, float *data)
762
+ {
763
+ if (y <= 0 || x <= 0 || data == NULL)
764
+ return 0;
765
+ else {
766
+ // Each component is stored separately. Allocate scratch space for full output scanline.
767
+ unsigned char *scratch = (unsigned char *) STBIW_MALLOC(x*4);
768
+ int i, len;
769
+ char buffer[128];
770
+ char header[] = "#?RADIANCE\n# Written by stb_image_write.h\nFORMAT=32-bit_rle_rgbe\n";
771
+ s->func(s->context, header, sizeof(header)-1);
772
+
773
+ #ifdef __STDC_LIB_EXT1__
774
+ len = sprintf_s(buffer, sizeof(buffer), "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n", y, x);
775
+ #else
776
+ len = sprintf(buffer, "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n", y, x);
777
+ #endif
778
+ s->func(s->context, buffer, len);
779
+
780
+ for(i=0; i < y; i++)
781
+ stbiw__write_hdr_scanline(s, x, comp, scratch, data + comp*x*(stbi__flip_vertically_on_write ? y-1-i : i));
782
+ STBIW_FREE(scratch);
783
+ return 1;
784
+ }
785
+ }
786
+
787
+ STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const float *data)
788
+ {
789
+ stbi__write_context s = { 0 };
790
+ stbi__start_write_callbacks(&s, func, context);
791
+ return stbi_write_hdr_core(&s, x, y, comp, (float *) data);
792
+ }
793
+
794
+ STBIWDEF int stbi_write_hdr(char const *filename, int x, int y, int comp, const float *data)
795
+ {
796
+ stbi__write_context s = { 0 };
797
+ if (stbi__start_write_file(&s,filename)) {
798
+ int r = stbi_write_hdr_core(&s, x, y, comp, (float *) data);
799
+ stbi__end_write_file(&s);
800
+ return r;
801
+ } else
802
+ return 0;
803
+ }
804
+ #endif // STBI_WRITE_NO_STDIO
805
+
806
+
807
+ //////////////////////////////////////////////////////////////////////////////
808
+ //
809
+ // PNG writer
810
+ //
811
+
812
+ #ifndef STBIW_ZLIB_COMPRESS
813
// stretchy buffer; stbiw__sbpush() == vector<>::push_back() -- stbiw__sbcount() == vector<>::size()
//
// A stretchy buffer is referenced through a plain element pointer `a`. Two
// ints of bookkeeping sit immediately in front of the elements:
//   stbiw__sbraw(a)[0]  (stbiw__sbm)  allocated capacity, in elements
//   stbiw__sbraw(a)[1]  (stbiw__sbn)  number of elements in use
// A NULL pointer stands for an empty, not-yet-allocated buffer.
#define stbiw__sbraw(a) ((int *) (void *) (a) - 2)
#define stbiw__sbm(a)   stbiw__sbraw(a)[0]
#define stbiw__sbn(a)   stbiw__sbraw(a)[1]

// Macro arguments are parenthesized at every use so that an expression passed
// as `n` cannot re-associate with the surrounding operators.
#define stbiw__sbneedgrow(a,n)  ((a)==0 || stbiw__sbn(a)+(n) >= stbiw__sbm(a))
#define stbiw__sbmaybegrow(a,n) (stbiw__sbneedgrow(a,(n)) ? stbiw__sbgrow(a,(n)) : 0)
#define stbiw__sbgrow(a,n)  stbiw__sbgrowf((void **) &(a), (n), sizeof(*(a)))

// push: make room for one more element, then store it and bump the count
#define stbiw__sbpush(a, v)      (stbiw__sbmaybegrow(a,1), (a)[stbiw__sbn(a)++] = (v))
#define stbiw__sbcount(a)        ((a) ? stbiw__sbn(a) : 0)
// free: releases the underlying allocation (header included); NULL is a no-op
#define stbiw__sbfree(a)         ((a) ? STBIW_FREE(stbiw__sbraw(a)),0 : 0)
825
+
826
// Grow (or create) the stretchy buffer referenced by *arr.
//
//   increment  minimum number of extra elements the caller needs
//   itemsize   size of one element in bytes
//
// New capacity is 2*old+increment for an existing buffer and increment+1 for
// a fresh one. On success *arr is updated to the (possibly moved) element
// pointer and the stored capacity is refreshed; a fresh buffer starts with a
// count of 0. If the reallocation fails, *arr is left untouched.
// Returns the resulting value of *arr.
static void *stbiw__sbgrowf(void **arr, int increment, int itemsize)
{
   int capacity;
   void *block;

   if (*arr) {
      capacity = 2*stbiw__sbm(*arr)+increment;
      block = STBIW_REALLOC_SIZED(stbiw__sbraw(*arr), stbiw__sbm(*arr)*itemsize + sizeof(int)*2, itemsize * capacity + sizeof(int)*2);
   } else {
      capacity = increment+1;
      block = STBIW_REALLOC_SIZED(0, 0, itemsize * capacity + sizeof(int)*2);
   }

   STBIW_ASSERT(block);
   if (block) {
      int *header = (int *) block;
      if (!*arr)
         header[1] = 0;               // brand-new buffer: nothing stored yet
      *arr = (void *) (header + 2);   // elements start right after the two-int header
      header[0] = capacity;
   }
   return *arr;
}
838
+
839
// Move every complete byte out of the bit accumulator into the stretchy
// output buffer `data`, low byte first. *bitbuffer and *bitcount are updated
// to describe the bits that remain. Returns the output buffer pointer, which
// may have changed because pushing can reallocate.
static unsigned char *stbiw__zlib_flushf(unsigned char *data, unsigned int *bitbuffer, int *bitcount)
{
   for (; *bitcount >= 8; *bitcount -= 8, *bitbuffer >>= 8)
      stbiw__sbpush(data, STBIW_UCHAR(*bitbuffer));
   return data;
}
848
+
849
// Return the low `codebits` bits of `code` in reversed bit order
// (bit 0 becomes bit codebits-1, and so on).
static int stbiw__zlib_bitrev(int code, int codebits)
{
   int reversed = 0;
   for (; codebits != 0; --codebits) {
      reversed = (reversed << 1) | (code & 1);
      code >>= 1;
   }
   return reversed;
}
858
+
859
// Count how many leading bytes of `a` and `b` are equal, examining at most
// `limit` bytes and never more than 258.
static unsigned int stbiw__zlib_countm(unsigned char *a, unsigned char *b, int limit)
{
   int matched = 0;
   while (matched < limit && matched < 258 && a[matched] == b[matched])
      ++matched;
   return matched;
}
866
+
867
+ static unsigned int stbiw__zhash(unsigned char *data)
868
+ {
869
+ stbiw_uint32 hash = data[0] + (data[1] << 8) + (data[2] << 16);
870
+ hash ^= hash << 3;
871
+ hash += hash >> 5;
872
+ hash ^= hash << 4;
873
+ hash += hash >> 17;
874
+ hash ^= hash << 25;
875
+ hash += hash >> 6;
876
+ return hash;
877
+ }
878
+
879
// Bit-writer helpers. They expect locals named `out` (stretchy byte buffer),
// `bitbuf` (bit accumulator) and `bitcount` (bits held in it) to be in scope
// at the point of use.

// drain whole bytes from the accumulator into `out`
#define stbiw__zlib_flush() \
      (out = stbiw__zlib_flushf(out, &bitbuf, &bitcount))

// append `codebits` bits of `code` above the bits already pending, then drain
#define stbiw__zlib_add(code,codebits) \
      (bitbuf |= (code) << bitcount, bitcount += (codebits), stbiw__zlib_flush())

// append a code after reversing its bit order
#define stbiw__zlib_huffa(b,c) \
      stbiw__zlib_add(stbiw__zlib_bitrev(b,c),c)

// default huffman tables
#define stbiw__zlib_huff1(n) \
      stbiw__zlib_huffa(0x30 + (n), 8)
#define stbiw__zlib_huff2(n) \
      stbiw__zlib_huffa(0x190 + (n)-144, 9)
#define stbiw__zlib_huff3(n) \
      stbiw__zlib_huffa(0 + (n)-256,7)
#define stbiw__zlib_huff4(n) \
      stbiw__zlib_huffa(0xc0 + (n)-280,8)

// emit symbol n, picking the code range it falls into
#define stbiw__zlib_huff(n) \
      ((n) <= 143 ? stbiw__zlib_huff1(n) : (n) <= 255 ? stbiw__zlib_huff2(n) : (n) <= 279 ? stbiw__zlib_huff3(n) : stbiw__zlib_huff4(n))
// same, for symbols known to be at most 255 (used for literal bytes)
#define stbiw__zlib_huffb(n) \
      ((n) <= 143 ? stbiw__zlib_huff1(n) : stbiw__zlib_huff2(n))

// number of buckets in the match-finder hash table
#define stbiw__ZHASH   16384
892
+
893
+ #endif // STBIW_ZLIB_COMPRESS
894
+
895
+ STBIWDEF unsigned char * stbi_zlib_compress(unsigned char *data, int data_len, int *out_len, int quality)
896
+ {
897
+ #ifdef STBIW_ZLIB_COMPRESS
898
+ // user provided a zlib compress implementation, use that
899
+ return STBIW_ZLIB_COMPRESS(data, data_len, out_len, quality);
900
+ #else // use builtin
901
+ static unsigned short lengthc[] = { 3,4,5,6,7,8,9,10,11,13,15,17,19,23,27,31,35,43,51,59,67,83,99,115,131,163,195,227,258, 259 };
902
+ static unsigned char lengtheb[]= { 0,0,0,0,0,0,0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0 };
903
+ static unsigned short distc[] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577, 32768 };
904
+ static unsigned char disteb[] = { 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13 };
905
+ unsigned int bitbuf=0;
906
+ int i,j, bitcount=0;
907
+ unsigned char *out = NULL;
908
+ unsigned char ***hash_table = (unsigned char***) STBIW_MALLOC(stbiw__ZHASH * sizeof(unsigned char**));
909
+ if (hash_table == NULL)
910
+ return NULL;
911
+ if (quality < 5) quality = 5;
912
+
913
+ stbiw__sbpush(out, 0x78); // DEFLATE 32K window
914
+ stbiw__sbpush(out, 0x5e); // FLEVEL = 1
915
+ stbiw__zlib_add(1,1); // BFINAL = 1
916
+ stbiw__zlib_add(1,2); // BTYPE = 1 -- fixed huffman
917
+
918
+ for (i=0; i < stbiw__ZHASH; ++i)
919
+ hash_table[i] = NULL;
920
+
921
+ i=0;
922
+ while (i < data_len-3) {
923
+ // hash next 3 bytes of data to be compressed
924
+ int h = stbiw__zhash(data+i)&(stbiw__ZHASH-1), best=3;
925
+ unsigned char *bestloc = 0;
926
+ unsigned char **hlist = hash_table[h];
927
+ int n = stbiw__sbcount(hlist);
928
+ for (j=0; j < n; ++j) {
929
+ if (hlist[j]-data > i-32768) { // if entry lies within window
930
+ int d = stbiw__zlib_countm(hlist[j], data+i, data_len-i);
931
+ if (d >= best) { best=d; bestloc=hlist[j]; }
932
+ }
933
+ }
934
+ // when hash table entry is too long, delete half the entries
935
+ if (hash_table[h] && stbiw__sbn(hash_table[h]) == 2*quality) {
936
+ STBIW_MEMMOVE(hash_table[h], hash_table[h]+quality, sizeof(hash_table[h][0])*quality);
937
+ stbiw__sbn(hash_table[h]) = quality;
938
+ }
939
+ stbiw__sbpush(hash_table[h],data+i);
940
+
941
+ if (bestloc) {
942
+ // "lazy matching" - check match at *next* byte, and if it's better, do cur byte as literal
943
+ h = stbiw__zhash(data+i+1)&(stbiw__ZHASH-1);
944
+ hlist = hash_table[h];
945
+ n = stbiw__sbcount(hlist);
946
+ for (j=0; j < n; ++j) {
947
+ if (hlist[j]-data > i-32767) {
948
+ int e = stbiw__zlib_countm(hlist[j], data+i+1, data_len-i-1);
949
+ if (e > best) { // if next match is better, bail on current match
950
+ bestloc = NULL;
951
+ break;
952
+ }
953
+ }
954
+ }
955
+ }
956
+
957
+ if (bestloc) {
958
+ int d = (int) (data+i - bestloc); // distance back
959
+ STBIW_ASSERT(d <= 32767 && best <= 258);
960
+ for (j=0; best > lengthc[j+1]-1; ++j);
961
+ stbiw__zlib_huff(j+257);
962
+ if (lengtheb[j]) stbiw__zlib_add(best - lengthc[j], lengtheb[j]);
963
+ for (j=0; d > distc[j+1]-1; ++j);
964
+ stbiw__zlib_add(stbiw__zlib_bitrev(j,5),5);
965
+ if (disteb[j]) stbiw__zlib_add(d - distc[j], disteb[j]);
966
+ i += best;
967
+ } else {
968
+ stbiw__zlib_huffb(data[i]);
969
+ ++i;
970
+ }
971
+ }
972
+ // write out final bytes
973
+ for (;i < data_len; ++i)
974
+ stbiw__zlib_huffb(data[i]);
975
+ stbiw__zlib_huff(256); // end of block
976
+ // pad with 0 bits to byte boundary
977
+ while (bitcount)
978
+ stbiw__zlib_add(0,1);
979
+
980
+ for (i=0; i < stbiw__ZHASH; ++i)
981
+ (void) stbiw__sbfree(hash_table[i]);
982
+ STBIW_FREE(hash_table);
983
+
984
+ // store uncompressed instead if compression was worse
985
+ if (stbiw__sbn(out) > data_len + 2 + ((data_len+32766)/32767)*5) {
986
+ stbiw__sbn(out) = 2; // truncate to DEFLATE 32K window and FLEVEL = 1
987
+ for (j = 0; j < data_len;) {
988
+ int blocklen = data_len - j;
989
+ if (blocklen > 32767) blocklen = 32767;
990
+ stbiw__sbpush(out, data_len - j == blocklen); // BFINAL = ?, BTYPE = 0 -- no compression
991
+ stbiw__sbpush(out, STBIW_UCHAR(blocklen)); // LEN
992
+ stbiw__sbpush(out, STBIW_UCHAR(blocklen >> 8));
993
+ stbiw__sbpush(out, STBIW_UCHAR(~blocklen)); // NLEN
994
+ stbiw__sbpush(out, STBIW_UCHAR(~blocklen >> 8));
995
+ memcpy(out+stbiw__sbn(out), data+j, blocklen);
996
+ stbiw__sbn(out) += blocklen;
997
+ j += blocklen;
998
+ }
999
+ }
1000
+
1001
+ {
1002
+ // compute adler32 on input
1003
+ unsigned int s1=1, s2=0;
1004
+ int blocklen = (int) (data_len % 5552);
1005
+ j=0;
1006
+ while (j < data_len) {
1007
+ for (i=0; i < blocklen; ++i) { s1 += data[j+i]; s2 += s1; }
1008
+ s1 %= 65521; s2 %= 65521;
1009
+ j += blocklen;
1010
+ blocklen = 5552;
1011
+ }
1012
+ stbiw__sbpush(out, STBIW_UCHAR(s2 >> 8));
1013
+ stbiw__sbpush(out, STBIW_UCHAR(s2));
1014
+ stbiw__sbpush(out, STBIW_UCHAR(s1 >> 8));
1015
+ stbiw__sbpush(out, STBIW_UCHAR(s1));
1016
+ }
1017
+ *out_len = stbiw__sbn(out);
1018
+ // make returned pointer freeable
1019
+ STBIW_MEMMOVE(stbiw__sbraw(out), out, *out_len);
1020
+ return (unsigned char *) stbiw__sbraw(out);
1021
+ #endif // STBIW_ZLIB_COMPRESS
1022
+ }
1023
+
1024
// Compute the CRC-32 of `len` bytes at `buffer`.
// If STBIW_CRC32 is defined the call is forwarded to it; otherwise a
// byte-at-a-time table lookup is used, starting from all-ones and inverting
// the final value. An empty input yields 0.
static unsigned int stbiw__crc32(unsigned char *buffer, int len)
{
#ifdef STBIW_CRC32
    return STBIW_CRC32(buffer, len);
#else
   // Lookup table indexed by (input byte XOR low CRC byte). It is never
   // written, so it is const-qualified.
   static const unsigned int crc_table[256] =
   {
      0x00000000, 0x77073096, 0xEE0E612C, 0x990951BA, 0x076DC419, 0x706AF48F, 0xE963A535, 0x9E6495A3,
      0x0eDB8832, 0x79DCB8A4, 0xE0D5E91E, 0x97D2D988, 0x09B64C2B, 0x7EB17CBD, 0xE7B82D07, 0x90BF1D91,
      0x1DB71064, 0x6AB020F2, 0xF3B97148, 0x84BE41DE, 0x1ADAD47D, 0x6DDDE4EB, 0xF4D4B551, 0x83D385C7,
      0x136C9856, 0x646BA8C0, 0xFD62F97A, 0x8A65C9EC, 0x14015C4F, 0x63066CD9, 0xFA0F3D63, 0x8D080DF5,
      0x3B6E20C8, 0x4C69105E, 0xD56041E4, 0xA2677172, 0x3C03E4D1, 0x4B04D447, 0xD20D85FD, 0xA50AB56B,
      0x35B5A8FA, 0x42B2986C, 0xDBBBC9D6, 0xACBCF940, 0x32D86CE3, 0x45DF5C75, 0xDCD60DCF, 0xABD13D59,
      0x26D930AC, 0x51DE003A, 0xC8D75180, 0xBFD06116, 0x21B4F4B5, 0x56B3C423, 0xCFBA9599, 0xB8BDA50F,
      0x2802B89E, 0x5F058808, 0xC60CD9B2, 0xB10BE924, 0x2F6F7C87, 0x58684C11, 0xC1611DAB, 0xB6662D3D,
      0x76DC4190, 0x01DB7106, 0x98D220BC, 0xEFD5102A, 0x71B18589, 0x06B6B51F, 0x9FBFE4A5, 0xE8B8D433,
      0x7807C9A2, 0x0F00F934, 0x9609A88E, 0xE10E9818, 0x7F6A0DBB, 0x086D3D2D, 0x91646C97, 0xE6635C01,
      0x6B6B51F4, 0x1C6C6162, 0x856530D8, 0xF262004E, 0x6C0695ED, 0x1B01A57B, 0x8208F4C1, 0xF50FC457,
      0x65B0D9C6, 0x12B7E950, 0x8BBEB8EA, 0xFCB9887C, 0x62DD1DDF, 0x15DA2D49, 0x8CD37CF3, 0xFBD44C65,
      0x4DB26158, 0x3AB551CE, 0xA3BC0074, 0xD4BB30E2, 0x4ADFA541, 0x3DD895D7, 0xA4D1C46D, 0xD3D6F4FB,
      0x4369E96A, 0x346ED9FC, 0xAD678846, 0xDA60B8D0, 0x44042D73, 0x33031DE5, 0xAA0A4C5F, 0xDD0D7CC9,
      0x5005713C, 0x270241AA, 0xBE0B1010, 0xC90C2086, 0x5768B525, 0x206F85B3, 0xB966D409, 0xCE61E49F,
      0x5EDEF90E, 0x29D9C998, 0xB0D09822, 0xC7D7A8B4, 0x59B33D17, 0x2EB40D81, 0xB7BD5C3B, 0xC0BA6CAD,
      0xEDB88320, 0x9ABFB3B6, 0x03B6E20C, 0x74B1D29A, 0xEAD54739, 0x9DD277AF, 0x04DB2615, 0x73DC1683,
      0xE3630B12, 0x94643B84, 0x0D6D6A3E, 0x7A6A5AA8, 0xE40ECF0B, 0x9309FF9D, 0x0A00AE27, 0x7D079EB1,
      0xF00F9344, 0x8708A3D2, 0x1E01F268, 0x6906C2FE, 0xF762575D, 0x806567CB, 0x196C3671, 0x6E6B06E7,
      0xFED41B76, 0x89D32BE0, 0x10DA7A5A, 0x67DD4ACC, 0xF9B9DF6F, 0x8EBEEFF9, 0x17B7BE43, 0x60B08ED5,
      0xD6D6A3E8, 0xA1D1937E, 0x38D8C2C4, 0x4FDFF252, 0xD1BB67F1, 0xA6BC5767, 0x3FB506DD, 0x48B2364B,
      0xD80D2BDA, 0xAF0A1B4C, 0x36034AF6, 0x41047A60, 0xDF60EFC3, 0xA867DF55, 0x316E8EEF, 0x4669BE79,
      0xCB61B38C, 0xBC66831A, 0x256FD2A0, 0x5268E236, 0xCC0C7795, 0xBB0B4703, 0x220216B9, 0x5505262F,
      0xC5BA3BBE, 0xB2BD0B28, 0x2BB45A92, 0x5CB36A04, 0xC2D7FFA7, 0xB5D0CF31, 0x2CD99E8B, 0x5BDEAE1D,
      0x9B64C2B0, 0xEC63F226, 0x756AA39C, 0x026D930A, 0x9C0906A9, 0xEB0E363F, 0x72076785, 0x05005713,
      0x95BF4A82, 0xE2B87A14, 0x7BB12BAE, 0x0CB61B38, 0x92D28E9B, 0xE5D5BE0D, 0x7CDCEFB7, 0x0BDBDF21,
      0x86D3D2D4, 0xF1D4E242, 0x68DDB3F8, 0x1FDA836E, 0x81BE16CD, 0xF6B9265B, 0x6FB077E1, 0x18B74777,
      0x88085AE6, 0xFF0F6A70, 0x66063BCA, 0x11010B5C, 0x8F659EFF, 0xF862AE69, 0x616BFFD3, 0x166CCF45,
      0xA00AE278, 0xD70DD2EE, 0x4E048354, 0x3903B3C2, 0xA7672661, 0xD06016F7, 0x4969474D, 0x3E6E77DB,
      0xAED16A4A, 0xD9D65ADC, 0x40DF0B66, 0x37D83BF0, 0xA9BCAE53, 0xDEBB9EC5, 0x47B2CF7F, 0x30B5FFE9,
      0xBDBDF21C, 0xCABAC28A, 0x53B39330, 0x24B4A3A6, 0xBAD03605, 0xCDD70693, 0x54DE5729, 0x23D967BF,
      0xB3667A2E, 0xC4614AB8, 0x5D681B02, 0x2A6F2B94, 0xB40BBE37, 0xC30C8EA1, 0x5A05DF1B, 0x2D02EF8D
   };

   unsigned int crc = ~0u;
   int i;
   for (i=0; i < len; ++i)
      crc = (crc >> 8) ^ crc_table[buffer[i] ^ (crc & 0xff)];
   return ~crc;
#endif
}
1072
+
1073
// Helpers that store bytes at the cursor `o`/`data` and advance it by four.
// stbiw__wpng4: store four bytes (each truncated with STBIW_UCHAR)
// stbiw__wp32:  store `v` as four bytes, most significant byte first
// stbiw__wptag: store the first four characters of `s`
// None of the macros ends in a semicolon; the call site supplies it, so they
// stay usable as ordinary expressions (for example in an unbraced if/else).
#define stbiw__wpng4(o,a,b,c,d) ((o)[0]=STBIW_UCHAR(a),(o)[1]=STBIW_UCHAR(b),(o)[2]=STBIW_UCHAR(c),(o)[3]=STBIW_UCHAR(d),(o)+=4)
#define stbiw__wp32(data,v) stbiw__wpng4(data, (v)>>24,(v)>>16,(v)>>8,(v))
#define stbiw__wptag(data,s) stbiw__wpng4(data, s[0],s[1],s[2],s[3])
1076
+
1077
// Append a CRC to the output at *data. The CRC covers the len+4 bytes that
// immediately precede the cursor (i.e. it starts at *data - len - 4). The
// cursor is advanced past the four CRC bytes.
static void stbiw__wpcrc(unsigned char **data, int len)
{
   unsigned char *covered = *data - len - 4;
   unsigned int crc = stbiw__crc32(covered, len+4);
   stbiw__wp32(*data, crc);
}
1082
+
1083
// Paeth predictor: of a, b and c, return the one closest to a + b - c.
// Ties are resolved in the order a, then b, then c.
static unsigned char stbiw__paeth(int a, int b, int c)
{
   int estimate = a + b - c;
   int dist_a = abs(estimate-a);
   int dist_b = abs(estimate-b);
   int dist_c = abs(estimate-c);

   if (dist_a > dist_b || dist_a > dist_c) {
      // a is not the closest; choose between b and c
      if (dist_b <= dist_c)
         return STBIW_UCHAR(b);
      return STBIW_UCHAR(c);
   }
   return STBIW_UCHAR(a);
}
1090
+
1091
+ // @OPTIMIZE: provide an option that always forces left-predict or paeth predict
1092
+ static void stbiw__encode_png_line(unsigned char *pixels, int stride_bytes, int width, int height, int y, int n, int filter_type, signed char *line_buffer)
1093
+ {
1094
+ static int mapping[] = { 0,1,2,3,4 };
1095
+ static int firstmap[] = { 0,1,0,5,6 };
1096
+ int *mymap = (y != 0) ? mapping : firstmap;
1097
+ int i;
1098
+ int type = mymap[filter_type];
1099
+ unsigned char *z = pixels + stride_bytes * (stbi__flip_vertically_on_write ? height-1-y : y);
1100
+ int signed_stride = stbi__flip_vertically_on_write ? -stride_bytes : stride_bytes;
1101
+
1102
+ if (type==0) {
1103
+ memcpy(line_buffer, z, width*n);
1104
+ return;
1105
+ }
1106
+
1107
+ // first loop isn't optimized since it's just one pixel
1108
+ for (i = 0; i < n; ++i) {
1109
+ switch (type) {
1110
+ case 1: line_buffer[i] = z[i]; break;
1111
+ case 2: line_buffer[i] = z[i] - z[i-signed_stride]; break;
1112
+ case 3: line_buffer[i] = z[i] - (z[i-signed_stride]>>1); break;
1113
+ case 4: line_buffer[i] = (signed char) (z[i] - stbiw__paeth(0,z[i-signed_stride],0)); break;
1114
+ case 5: line_buffer[i] = z[i]; break;
1115
+ case 6: line_buffer[i] = z[i]; break;
1116
+ }
1117
+ }
1118
+ switch (type) {
1119
+ case 1: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-n]; break;
1120
+ case 2: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-signed_stride]; break;
1121
+ case 3: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - ((z[i-n] + z[i-signed_stride])>>1); break;
1122
+ case 4: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], z[i-signed_stride], z[i-signed_stride-n]); break;
1123
+ case 5: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - (z[i-n]>>1); break;
1124
+ case 6: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], 0,0); break;
1125
+ }
1126
+ }
1127
+
1128
+ STBIWDEF unsigned char *stbi_write_png_to_mem(const unsigned char *pixels, int stride_bytes, int x, int y, int n, int *out_len)
1129
+ {
1130
+ int force_filter = stbi_write_force_png_filter;
1131
+ int ctype[5] = { -1, 0, 4, 2, 6 };
1132
+ unsigned char sig[8] = { 137,80,78,71,13,10,26,10 };
1133
+ unsigned char *out,*o, *filt, *zlib;
1134
+ signed char *line_buffer;
1135
+ int j,zlen;
1136
+
1137
+ if (stride_bytes == 0)
1138
+ stride_bytes = x * n;
1139
+
1140
+ if (force_filter >= 5) {
1141
+ force_filter = -1;
1142
+ }
1143
+
1144
+ filt = (unsigned char *) STBIW_MALLOC((x*n+1) * y); if (!filt) return 0;
1145
+ line_buffer = (signed char *) STBIW_MALLOC(x * n); if (!line_buffer) { STBIW_FREE(filt); return 0; }
1146
+ for (j=0; j < y; ++j) {
1147
+ int filter_type;
1148
+ if (force_filter > -1) {
1149
+ filter_type = force_filter;
1150
+ stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, force_filter, line_buffer);
1151
+ } else { // Estimate the best filter by running through all of them:
1152
+ int best_filter = 0, best_filter_val = 0x7fffffff, est, i;
1153
+ for (filter_type = 0; filter_type < 5; filter_type++) {
1154
+ stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, filter_type, line_buffer);
1155
+
1156
+ // Estimate the entropy of the line using this filter; the less, the better.
1157
+ est = 0;
1158
+ for (i = 0; i < x*n; ++i) {
1159
+ est += abs((signed char) line_buffer[i]);
1160
+ }
1161
+ if (est < best_filter_val) {
1162
+ best_filter_val = est;
1163
+ best_filter = filter_type;
1164
+ }
1165
+ }
1166
+ if (filter_type != best_filter) { // If the last iteration already got us the best filter, don't redo it
1167
+ stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, best_filter, line_buffer);
1168
+ filter_type = best_filter;
1169
+ }
1170
+ }
1171
+ // when we get here, filter_type contains the filter type, and line_buffer contains the data
1172
+ filt[j*(x*n+1)] = (unsigned char) filter_type;
1173
+ STBIW_MEMMOVE(filt+j*(x*n+1)+1, line_buffer, x*n);
1174
+ }
1175
+ STBIW_FREE(line_buffer);
1176
+ zlib = stbi_zlib_compress(filt, y*( x*n+1), &zlen, stbi_write_png_compression_level);
1177
+ STBIW_FREE(filt);
1178
+ if (!zlib) return 0;
1179
+
1180
+ // each tag requires 12 bytes of overhead
1181
+ out = (unsigned char *) STBIW_MALLOC(8 + 12+13 + 12+zlen + 12);
1182
+ if (!out) return 0;
1183
+ *out_len = 8 + 12+13 + 12+zlen + 12;
1184
+
1185
+ o=out;
1186
+ STBIW_MEMMOVE(o,sig,8); o+= 8;
1187
+ stbiw__wp32(o, 13); // header length
1188
+ stbiw__wptag(o, "IHDR");
1189
+ stbiw__wp32(o, x);
1190
+ stbiw__wp32(o, y);
1191
+ *o++ = 8;
1192
+ *o++ = STBIW_UCHAR(ctype[n]);
1193
+ *o++ = 0;
1194
+ *o++ = 0;
1195
+ *o++ = 0;
1196
+ stbiw__wpcrc(&o,13);
1197
+
1198
+ stbiw__wp32(o, zlen);
1199
+ stbiw__wptag(o, "IDAT");
1200
+ STBIW_MEMMOVE(o, zlib, zlen);
1201
+ o += zlen;
1202
+ STBIW_FREE(zlib);
1203
+ stbiw__wpcrc(&o, zlen);
1204
+
1205
+ stbiw__wp32(o,0);
1206
+ stbiw__wptag(o, "IEND");
1207
+ stbiw__wpcrc(&o,0);
1208
+
1209
+ STBIW_ASSERT(o == out + *out_len);
1210
+
1211
+ return out;
1212
+ }
1213
+
1214
+ #ifndef STBI_WRITE_NO_STDIO
1215
+ STBIWDEF int stbi_write_png(char const *filename, int x, int y, int comp, const void *data, int stride_bytes)
1216
+ {
1217
+ FILE *f;
1218
+ int len;
1219
+ unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len);
1220
+ if (png == NULL) return 0;
1221
+
1222
+ f = stbiw__fopen(filename, "wb");
1223
+ if (!f) { STBIW_FREE(png); return 0; }
1224
+ fwrite(png, 1, len, f);
1225
+ fclose(f);
1226
+ STBIW_FREE(png);
1227
+ return 1;
1228
+ }
1229
+ #endif
1230
+
1231
+ STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int stride_bytes)
1232
+ {
1233
+ int len;
1234
+ unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len);
1235
+ if (png == NULL) return 0;
1236
+ func(context, png, len);
1237
+ STBIW_FREE(png);
1238
+ return 1;
1239
+ }
1240
+
1241
+
1242
+ /* ***************************************************************************
1243
+ *
1244
+ * JPEG writer
1245
+ *
1246
+ * This is based on Jon Olick's jo_jpeg.cpp:
1247
+ * public domain Simple, Minimalistic JPEG writer - http://www.jonolick.com/code.html
1248
+ */
1249
+
1250
// Zig-zag index table, laid out as the 8x8 block it describes: entry j gives
// the zig-zag position of the coefficient at row-major position j.
static const unsigned char stbiw__jpg_ZigZag[] = {
    0, 1, 5, 6,14,15,27,28,
    2, 4, 7,13,16,26,29,42,
    3, 8,12,17,25,30,41,43,
    9,11,18,24,31,40,44,53,
   10,19,23,32,39,45,52,54,
   20,22,33,38,46,51,55,60,
   21,34,37,47,50,56,59,61,
   35,36,48,49,57,58,62,63
};
1252
+
1253
+ static void stbiw__jpg_writeBits(stbi__write_context *s, int *bitBufP, int *bitCntP, const unsigned short *bs) {
1254
+ int bitBuf = *bitBufP, bitCnt = *bitCntP;
1255
+ bitCnt += bs[1];
1256
+ bitBuf |= bs[0] << (24 - bitCnt);
1257
+ while(bitCnt >= 8) {
1258
+ unsigned char c = (bitBuf >> 16) & 255;
1259
+ stbiw__putc(s, c);
1260
+ if(c == 255) {
1261
+ stbiw__putc(s, 0);
1262
+ }
1263
+ bitBuf <<= 8;
1264
+ bitCnt -= 8;
1265
+ }
1266
+ *bitBufP = bitBuf;
1267
+ *bitCntP = bitCnt;
1268
+ }
1269
+
1270
// In-place 8-point forward DCT butterfly on the eight values addressed by
// d0p..d7p. All inputs are read before any output is stored.
static void stbiw__jpg_DCT(float *d0p, float *d1p, float *d2p, float *d3p, float *d4p, float *d5p, float *d6p, float *d7p) {
   float in0 = *d0p, in1 = *d1p, in2 = *d2p, in3 = *d3p;
   float in4 = *d4p, in5 = *d5p, in6 = *d6p, in7 = *d7p;
   float out0, out2, out4, out6;
   float rot_even, rot_a, rot_b, rot_c, rot_shared, odd_hi, odd_lo;

   // first stage: sums and differences of mirrored inputs
   float sum07 = in0 + in7;
   float dif07 = in0 - in7;
   float sum16 = in1 + in6;
   float dif16 = in1 - in6;
   float sum25 = in2 + in5;
   float dif25 = in2 - in5;
   float sum34 = in3 + in4;
   float dif34 = in3 - in4;

   // Even part
   float e10 = sum07 + sum34; // phase 2
   float e13 = sum07 - sum34;
   float e11 = sum16 + sum25;
   float e12 = sum16 - sum25;

   out0 = e10 + e11; // phase 3
   out4 = e10 - e11;

   rot_even = (e12 + e13) * 0.707106781f; // c4
   out2 = e13 + rot_even; // phase 5
   out6 = e13 - rot_even;

   // Odd part
   e10 = dif34 + dif25; // phase 2
   e11 = dif25 + dif16;
   e12 = dif16 + dif07;

   // The rotator is modified from fig 4-8 to avoid extra negations.
   rot_shared = (e10 - e12) * 0.382683433f; // c6
   rot_a = e10 * 0.541196100f + rot_shared; // c2-c6
   rot_c = e12 * 1.306562965f + rot_shared; // c2+c6
   rot_b = e11 * 0.707106781f; // c4

   odd_hi = dif07 + rot_b; // phase 5
   odd_lo = dif07 - rot_b;

   *d5p = odd_lo + rot_a; // phase 6
   *d3p = odd_lo - rot_a;
   *d1p = odd_hi + rot_c;
   *d7p = odd_hi - rot_c;

   *d0p = out0;
   *d2p = out2;
   *d4p = out4;
   *d6p = out6;
}
1317
+
1318
// Compute the variable-length encoding of `val`.
//   bits[1] receives the number of bits needed for |val| (at least 1)
//   bits[0] receives the low bits[1] bits of val, or of val-1 when val < 0
static void stbiw__jpg_calcBits(int val, unsigned short bits[2]) {
   int magnitude = val < 0 ? -val : val;
   unsigned short nbits = 1;

   if (val < 0)
      val = val - 1;

   for (magnitude >>= 1; magnitude != 0; magnitude >>= 1)
      ++nbits;

   bits[1] = nbits;
   bits[0] = val & ((1<<nbits)-1);
}
1327
+
1328
+ static int stbiw__jpg_processDU(stbi__write_context *s, int *bitBuf, int *bitCnt, float *CDU, int du_stride, float *fdtbl, int DC, const unsigned short HTDC[256][2], const unsigned short HTAC[256][2]) {
1329
+ const unsigned short EOB[2] = { HTAC[0x00][0], HTAC[0x00][1] };
1330
+ const unsigned short M16zeroes[2] = { HTAC[0xF0][0], HTAC[0xF0][1] };
1331
+ int dataOff, i, j, n, diff, end0pos, x, y;
1332
+ int DU[64];
1333
+
1334
+ // DCT rows
1335
+ for(dataOff=0, n=du_stride*8; dataOff<n; dataOff+=du_stride) {
1336
+ stbiw__jpg_DCT(&CDU[dataOff], &CDU[dataOff+1], &CDU[dataOff+2], &CDU[dataOff+3], &CDU[dataOff+4], &CDU[dataOff+5], &CDU[dataOff+6], &CDU[dataOff+7]);
1337
+ }
1338
+ // DCT columns
1339
+ for(dataOff=0; dataOff<8; ++dataOff) {
1340
+ stbiw__jpg_DCT(&CDU[dataOff], &CDU[dataOff+du_stride], &CDU[dataOff+du_stride*2], &CDU[dataOff+du_stride*3], &CDU[dataOff+du_stride*4],
1341
+ &CDU[dataOff+du_stride*5], &CDU[dataOff+du_stride*6], &CDU[dataOff+du_stride*7]);
1342
+ }
1343
+ // Quantize/descale/zigzag the coefficients
1344
+ for(y = 0, j=0; y < 8; ++y) {
1345
+ for(x = 0; x < 8; ++x,++j) {
1346
+ float v;
1347
+ i = y*du_stride+x;
1348
+ v = CDU[i]*fdtbl[j];
1349
+ // DU[stbiw__jpg_ZigZag[j]] = (int)(v < 0 ? ceilf(v - 0.5f) : floorf(v + 0.5f));
1350
+ // ceilf() and floorf() are C99, not C89, but I /think/ they're not needed here anyway?
1351
+ DU[stbiw__jpg_ZigZag[j]] = (int)(v < 0 ? v - 0.5f : v + 0.5f);
1352
+ }
1353
+ }
1354
+
1355
+ // Encode DC
1356
+ diff = DU[0] - DC;
1357
+ if (diff == 0) {
1358
+ stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTDC[0]);
1359
+ } else {
1360
+ unsigned short bits[2];
1361
+ stbiw__jpg_calcBits(diff, bits);
1362
+ stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTDC[bits[1]]);
1363
+ stbiw__jpg_writeBits(s, bitBuf, bitCnt, bits);
1364
+ }
1365
+ // Encode ACs
1366
+ end0pos = 63;
1367
+ for(; (end0pos>0)&&(DU[end0pos]==0); --end0pos) {
1368
+ }
1369
+ // end0pos = first element in reverse order !=0
1370
+ if(end0pos == 0) {
1371
+ stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB);
1372
+ return DU[0];
1373
+ }
1374
+ for(i = 1; i <= end0pos; ++i) {
1375
+ int startpos = i;
1376
+ int nrzeroes;
1377
+ unsigned short bits[2];
1378
+ for (; DU[i]==0 && i<=end0pos; ++i) {
1379
+ }
1380
+ nrzeroes = i-startpos;
1381
+ if ( nrzeroes >= 16 ) {
1382
+ int lng = nrzeroes>>4;
1383
+ int nrmarker;
1384
+ for (nrmarker=1; nrmarker <= lng; ++nrmarker)
1385
+ stbiw__jpg_writeBits(s, bitBuf, bitCnt, M16zeroes);
1386
+ nrzeroes &= 15;
1387
+ }
1388
+ stbiw__jpg_calcBits(DU[i], bits);
1389
+ stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTAC[(nrzeroes<<4)+bits[1]]);
1390
+ stbiw__jpg_writeBits(s, bitBuf, bitCnt, bits);
1391
+ }
1392
+ if(end0pos != 63) {
1393
+ stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB);
1394
+ }
1395
+ return DU[0];
1396
+ }
1397
+
1398
+ static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, int comp, const void* data, int quality) {
1399
+ // Constants that don't pollute global namespace
1400
+ static const unsigned char std_dc_luminance_nrcodes[] = {0,0,1,5,1,1,1,1,1,1,0,0,0,0,0,0,0};
1401
+ static const unsigned char std_dc_luminance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11};
1402
+ static const unsigned char std_ac_luminance_nrcodes[] = {0,0,2,1,3,3,2,4,3,5,5,4,4,0,0,1,0x7d};
1403
+ static const unsigned char std_ac_luminance_values[] = {
1404
+ 0x01,0x02,0x03,0x00,0x04,0x11,0x05,0x12,0x21,0x31,0x41,0x06,0x13,0x51,0x61,0x07,0x22,0x71,0x14,0x32,0x81,0x91,0xa1,0x08,
1405
+ 0x23,0x42,0xb1,0xc1,0x15,0x52,0xd1,0xf0,0x24,0x33,0x62,0x72,0x82,0x09,0x0a,0x16,0x17,0x18,0x19,0x1a,0x25,0x26,0x27,0x28,
1406
+ 0x29,0x2a,0x34,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58,0x59,
1407
+ 0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x83,0x84,0x85,0x86,0x87,0x88,0x89,
1408
+ 0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4,0xb5,0xb6,
1409
+ 0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda,0xe1,0xe2,
1410
+ 0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf1,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa
1411
+ };
1412
+ static const unsigned char std_dc_chrominance_nrcodes[] = {0,0,3,1,1,1,1,1,1,1,1,1,0,0,0,0,0};
1413
+ static const unsigned char std_dc_chrominance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11};
1414
+ static const unsigned char std_ac_chrominance_nrcodes[] = {0,0,2,1,2,4,4,3,4,7,5,4,4,0,1,2,0x77};
1415
+ static const unsigned char std_ac_chrominance_values[] = {
1416
+ 0x00,0x01,0x02,0x03,0x11,0x04,0x05,0x21,0x31,0x06,0x12,0x41,0x51,0x07,0x61,0x71,0x13,0x22,0x32,0x81,0x08,0x14,0x42,0x91,
1417
+ 0xa1,0xb1,0xc1,0x09,0x23,0x33,0x52,0xf0,0x15,0x62,0x72,0xd1,0x0a,0x16,0x24,0x34,0xe1,0x25,0xf1,0x17,0x18,0x19,0x1a,0x26,
1418
+ 0x27,0x28,0x29,0x2a,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58,
1419
+ 0x59,0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x82,0x83,0x84,0x85,0x86,0x87,
1420
+ 0x88,0x89,0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4,
1421
+ 0xb5,0xb6,0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda,
1422
+ 0xe2,0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa
1423
+ };
1424
+ // Huffman tables
1425
+ static const unsigned short YDC_HT[256][2] = { {0,2},{2,3},{3,3},{4,3},{5,3},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9}};
1426
+ static const unsigned short UVDC_HT[256][2] = { {0,2},{1,2},{2,2},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9},{1022,10},{2046,11}};
1427
+ static const unsigned short YAC_HT[256][2] = {
1428
+ {10,4},{0,2},{1,2},{4,3},{11,4},{26,5},{120,7},{248,8},{1014,10},{65410,16},{65411,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1429
+ {12,4},{27,5},{121,7},{502,9},{2038,11},{65412,16},{65413,16},{65414,16},{65415,16},{65416,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1430
+ {28,5},{249,8},{1015,10},{4084,12},{65417,16},{65418,16},{65419,16},{65420,16},{65421,16},{65422,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1431
+ {58,6},{503,9},{4085,12},{65423,16},{65424,16},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1432
+ {59,6},{1016,10},{65430,16},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1433
+ {122,7},{2039,11},{65438,16},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1434
+ {123,7},{4086,12},{65446,16},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1435
+ {250,8},{4087,12},{65454,16},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1436
+ {504,9},{32704,15},{65462,16},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1437
+ {505,9},{65470,16},{65471,16},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1438
+ {506,9},{65479,16},{65480,16},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1439
+ {1017,10},{65488,16},{65489,16},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1440
+ {1018,10},{65497,16},{65498,16},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1441
+ {2040,11},{65506,16},{65507,16},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1442
+ {65515,16},{65516,16},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{0,0},{0,0},{0,0},{0,0},{0,0},
1443
+ {2041,11},{65525,16},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0}
1444
+ };
1445
+ static const unsigned short UVAC_HT[256][2] = {
1446
+ {0,2},{1,2},{4,3},{10,4},{24,5},{25,5},{56,6},{120,7},{500,9},{1014,10},{4084,12},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1447
+ {11,4},{57,6},{246,8},{501,9},{2038,11},{4085,12},{65416,16},{65417,16},{65418,16},{65419,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1448
+ {26,5},{247,8},{1015,10},{4086,12},{32706,15},{65420,16},{65421,16},{65422,16},{65423,16},{65424,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1449
+ {27,5},{248,8},{1016,10},{4087,12},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{65430,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1450
+ {58,6},{502,9},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{65438,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1451
+ {59,6},{1017,10},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{65446,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1452
+ {121,7},{2039,11},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{65454,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1453
+ {122,7},{2040,11},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{65462,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1454
+ {249,8},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{65470,16},{65471,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1455
+ {503,9},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{65479,16},{65480,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1456
+ {504,9},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{65488,16},{65489,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1457
+ {505,9},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{65497,16},{65498,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1458
+ {506,9},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{65506,16},{65507,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1459
+ {2041,11},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{65515,16},{65516,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
1460
+ {16352,14},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{65525,16},{0,0},{0,0},{0,0},{0,0},{0,0},
1461
+ {1018,10},{32707,15},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0}
1462
+ };
1463
+ static const int YQT[] = {16,11,10,16,24,40,51,61,12,12,14,19,26,58,60,55,14,13,16,24,40,57,69,56,14,17,22,29,51,87,80,62,18,22,
1464
+ 37,56,68,109,103,77,24,35,55,64,81,104,113,92,49,64,78,87,103,121,120,101,72,92,95,98,112,100,103,99};
1465
+ static const int UVQT[] = {17,18,24,47,99,99,99,99,18,21,26,66,99,99,99,99,24,26,56,99,99,99,99,99,47,66,99,99,99,99,99,99,
1466
+ 99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99};
1467
+ static const float aasf[] = { 1.0f * 2.828427125f, 1.387039845f * 2.828427125f, 1.306562965f * 2.828427125f, 1.175875602f * 2.828427125f,
1468
+ 1.0f * 2.828427125f, 0.785694958f * 2.828427125f, 0.541196100f * 2.828427125f, 0.275899379f * 2.828427125f };
1469
+
1470
+ int row, col, i, k, subsample;
1471
+ float fdtbl_Y[64], fdtbl_UV[64];
1472
+ unsigned char YTable[64], UVTable[64];
1473
+
1474
+ if(!data || !width || !height || comp > 4 || comp < 1) {
1475
+ return 0;
1476
+ }
1477
+
1478
+ quality = quality ? quality : 90;
1479
+ subsample = quality <= 90 ? 1 : 0;
1480
+ quality = quality < 1 ? 1 : quality > 100 ? 100 : quality;
1481
+ quality = quality < 50 ? 5000 / quality : 200 - quality * 2;
1482
+
1483
+ for(i = 0; i < 64; ++i) {
1484
+ int uvti, yti = (YQT[i]*quality+50)/100;
1485
+ YTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (yti < 1 ? 1 : yti > 255 ? 255 : yti);
1486
+ uvti = (UVQT[i]*quality+50)/100;
1487
+ UVTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (uvti < 1 ? 1 : uvti > 255 ? 255 : uvti);
1488
+ }
1489
+
1490
+ for(row = 0, k = 0; row < 8; ++row) {
1491
+ for(col = 0; col < 8; ++col, ++k) {
1492
+ fdtbl_Y[k] = 1 / (YTable [stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]);
1493
+ fdtbl_UV[k] = 1 / (UVTable[stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]);
1494
+ }
1495
+ }
1496
+
1497
+ // Write Headers
1498
+ {
1499
+ static const unsigned char head0[] = { 0xFF,0xD8,0xFF,0xE0,0,0x10,'J','F','I','F',0,1,1,0,0,1,0,1,0,0,0xFF,0xDB,0,0x84,0 };
1500
+ static const unsigned char head2[] = { 0xFF,0xDA,0,0xC,3,1,0,2,0x11,3,0x11,0,0x3F,0 };
1501
+ const unsigned char head1[] = { 0xFF,0xC0,0,0x11,8,(unsigned char)(height>>8),STBIW_UCHAR(height),(unsigned char)(width>>8),STBIW_UCHAR(width),
1502
+ 3,1,(unsigned char)(subsample?0x22:0x11),0,2,0x11,1,3,0x11,1,0xFF,0xC4,0x01,0xA2,0 };
1503
+ s->func(s->context, (void*)head0, sizeof(head0));
1504
+ s->func(s->context, (void*)YTable, sizeof(YTable));
1505
+ stbiw__putc(s, 1);
1506
+ s->func(s->context, UVTable, sizeof(UVTable));
1507
+ s->func(s->context, (void*)head1, sizeof(head1));
1508
+ s->func(s->context, (void*)(std_dc_luminance_nrcodes+1), sizeof(std_dc_luminance_nrcodes)-1);
1509
+ s->func(s->context, (void*)std_dc_luminance_values, sizeof(std_dc_luminance_values));
1510
+ stbiw__putc(s, 0x10); // HTYACinfo
1511
+ s->func(s->context, (void*)(std_ac_luminance_nrcodes+1), sizeof(std_ac_luminance_nrcodes)-1);
1512
+ s->func(s->context, (void*)std_ac_luminance_values, sizeof(std_ac_luminance_values));
1513
+ stbiw__putc(s, 1); // HTUDCinfo
1514
+ s->func(s->context, (void*)(std_dc_chrominance_nrcodes+1), sizeof(std_dc_chrominance_nrcodes)-1);
1515
+ s->func(s->context, (void*)std_dc_chrominance_values, sizeof(std_dc_chrominance_values));
1516
+ stbiw__putc(s, 0x11); // HTUACinfo
1517
+ s->func(s->context, (void*)(std_ac_chrominance_nrcodes+1), sizeof(std_ac_chrominance_nrcodes)-1);
1518
+ s->func(s->context, (void*)std_ac_chrominance_values, sizeof(std_ac_chrominance_values));
1519
+ s->func(s->context, (void*)head2, sizeof(head2));
1520
+ }
1521
+
1522
+ // Encode 8x8 macroblocks
1523
+ {
1524
+ static const unsigned short fillBits[] = {0x7F, 7};
1525
+ int DCY=0, DCU=0, DCV=0;
1526
+ int bitBuf=0, bitCnt=0;
1527
+ // comp == 2 is grey+alpha (alpha is ignored)
1528
+ int ofsG = comp > 2 ? 1 : 0, ofsB = comp > 2 ? 2 : 0;
1529
+ const unsigned char *dataR = (const unsigned char *)data;
1530
+ const unsigned char *dataG = dataR + ofsG;
1531
+ const unsigned char *dataB = dataR + ofsB;
1532
+ int x, y, pos;
1533
+ if(subsample) {
1534
+ for(y = 0; y < height; y += 16) {
1535
+ for(x = 0; x < width; x += 16) {
1536
+ float Y[256], U[256], V[256];
1537
+ for(row = y, pos = 0; row < y+16; ++row) {
1538
+ // row >= height => use last input row
1539
+ int clamped_row = (row < height) ? row : height - 1;
1540
+ int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp;
1541
+ for(col = x; col < x+16; ++col, ++pos) {
1542
+ // if col >= width => use pixel from last input column
1543
+ int p = base_p + ((col < width) ? col : (width-1))*comp;
1544
+ float r = dataR[p], g = dataG[p], b = dataB[p];
1545
+ Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128;
1546
+ U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b;
1547
+ V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b;
1548
+ }
1549
+ }
1550
+ DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+0, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT);
1551
+ DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+8, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT);
1552
+ DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+128, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT);
1553
+ DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+136, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT);
1554
+
1555
+ // subsample U,V
1556
+ {
1557
+ float subU[64], subV[64];
1558
+ int yy, xx;
1559
+ for(yy = 0, pos = 0; yy < 8; ++yy) {
1560
+ for(xx = 0; xx < 8; ++xx, ++pos) {
1561
+ int j = yy*32+xx*2;
1562
+ subU[pos] = (U[j+0] + U[j+1] + U[j+16] + U[j+17]) * 0.25f;
1563
+ subV[pos] = (V[j+0] + V[j+1] + V[j+16] + V[j+17]) * 0.25f;
1564
+ }
1565
+ }
1566
+ DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subU, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT);
1567
+ DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subV, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT);
1568
+ }
1569
+ }
1570
+ }
1571
+ } else {
1572
+ for(y = 0; y < height; y += 8) {
1573
+ for(x = 0; x < width; x += 8) {
1574
+ float Y[64], U[64], V[64];
1575
+ for(row = y, pos = 0; row < y+8; ++row) {
1576
+ // row >= height => use last input row
1577
+ int clamped_row = (row < height) ? row : height - 1;
1578
+ int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp;
1579
+ for(col = x; col < x+8; ++col, ++pos) {
1580
+ // if col >= width => use pixel from last input column
1581
+ int p = base_p + ((col < width) ? col : (width-1))*comp;
1582
+ float r = dataR[p], g = dataG[p], b = dataB[p];
1583
+ Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128;
1584
+ U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b;
1585
+ V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b;
1586
+ }
1587
+ }
1588
+
1589
+ DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y, 8, fdtbl_Y, DCY, YDC_HT, YAC_HT);
1590
+ DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, U, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT);
1591
+ DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, V, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT);
1592
+ }
1593
+ }
1594
+ }
1595
+
1596
+ // Do the bit alignment of the EOI marker
1597
+ stbiw__jpg_writeBits(s, &bitBuf, &bitCnt, fillBits);
1598
+ }
1599
+
1600
+ // EOI
1601
+ stbiw__putc(s, 0xFF);
1602
+ stbiw__putc(s, 0xD9);
1603
+
1604
+ return 1;
1605
+ }
1606
+
1607
+ STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality)
1608
+ {
1609
+ stbi__write_context s = { 0 };
1610
+ stbi__start_write_callbacks(&s, func, context);
1611
+ return stbi_write_jpg_core(&s, x, y, comp, (void *) data, quality);
1612
+ }
1613
+
1614
+
1615
+ #ifndef STBI_WRITE_NO_STDIO
1616
+ STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality)
1617
+ {
1618
+ stbi__write_context s = { 0 };
1619
+ if (stbi__start_write_file(&s,filename)) {
1620
+ int r = stbi_write_jpg_core(&s, x, y, comp, data, quality);
1621
+ stbi__end_write_file(&s);
1622
+ return r;
1623
+ } else
1624
+ return 0;
1625
+ }
1626
+ #endif
1627
+
1628
+ #endif // STB_IMAGE_WRITE_IMPLEMENTATION
1629
+
1630
+ /* Revision history
1631
+ 1.16 (2021-07-11)
1632
+ make Deflate code emit uncompressed blocks when it would otherwise expand
1633
+ support writing BMPs with alpha channel
1634
+ 1.15 (2020-07-13) unknown
1635
+ 1.14 (2020-02-02) updated JPEG writer to downsample chroma channels
1636
+ 1.13
1637
+ 1.12
1638
+ 1.11 (2019-08-11)
1639
+
1640
+ 1.10 (2019-02-07)
1641
+ support utf8 filenames in Windows; fix warnings and platform ifdefs
1642
+ 1.09 (2018-02-11)
1643
+ fix typo in zlib quality API, improve STB_I_W_STATIC in C++
1644
+ 1.08 (2018-01-29)
1645
+ add stbi__flip_vertically_on_write, external zlib, zlib quality, choose PNG filter
1646
+ 1.07 (2017-07-24)
1647
+ doc fix
1648
+ 1.06 (2017-07-23)
1649
+ writing JPEG (using Jon Olick's code)
1650
+ 1.05 ???
1651
+ 1.04 (2017-03-03)
1652
+ monochrome BMP expansion
1653
+ 1.03 ???
1654
+ 1.02 (2016-04-02)
1655
+ avoid allocating large structures on the stack
1656
+ 1.01 (2016-01-16)
1657
+ STBIW_REALLOC_SIZED: support allocators with no realloc support
1658
+ avoid race-condition in crc initialization
1659
+ minor compile issues
1660
+ 1.00 (2015-09-14)
1661
+ installable file IO function
1662
+ 0.99 (2015-09-13)
1663
+ warning fixes; TGA rle support
1664
+ 0.98 (2015-04-08)
1665
+ added STBIW_MALLOC, STBIW_ASSERT etc
1666
+ 0.97 (2015-01-18)
1667
+ fixed HDR asserts, rewrote HDR rle logic
1668
+ 0.96 (2015-01-17)
1669
+ add HDR output
1670
+ fix monochrome BMP
1671
+ 0.95 (2014-08-17)
1672
+ add monochrome TGA output
1673
+ 0.94 (2014-05-31)
1674
+ rename private functions to avoid conflicts with stb_image.h
1675
+ 0.93 (2014-05-27)
1676
+ warning fixes
1677
+ 0.92 (2010-08-01)
1678
+ casts to unsigned char to fix warnings
1679
+ 0.91 (2010-07-17)
1680
+ first public release
1681
+ 0.90 first internal release
1682
+ */
1683
+
1684
+ /*
1685
+ ------------------------------------------------------------------------------
1686
+ This software is available under 2 licenses -- choose whichever you prefer.
1687
+ ------------------------------------------------------------------------------
1688
+ ALTERNATIVE A - MIT License
1689
+ Copyright (c) 2017 Sean Barrett
1690
+ Permission is hereby granted, free of charge, to any person obtaining a copy of
1691
+ this software and associated documentation files (the "Software"), to deal in
1692
+ the Software without restriction, including without limitation the rights to
1693
+ use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
1694
+ of the Software, and to permit persons to whom the Software is furnished to do
1695
+ so, subject to the following conditions:
1696
+ The above copyright notice and this permission notice shall be included in all
1697
+ copies or substantial portions of the Software.
1698
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
1699
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
1700
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
1701
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
1702
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
1703
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
1704
+ SOFTWARE.
1705
+ ------------------------------------------------------------------------------
1706
+ ALTERNATIVE B - Public Domain (www.unlicense.org)
1707
+ This is free and unencumbered software released into the public domain.
1708
+ Anyone is free to copy, modify, publish, use, compile, sell, or distribute this
1709
+ software, either in source code form or as a compiled binary, for any purpose,
1710
+ commercial or non-commercial, and by any means.
1711
+ In jurisdictions that recognize copyright laws, the author or authors of this
1712
+ software dedicate any and all copyright interest in the software to the public
1713
+ domain. We make this dedication for the benefit of the public at large and to
1714
+ the detriment of our heirs and successors. We intend this dedication to be an
1715
+ overt act of relinquishment in perpetuity of all present and future rights to
1716
+ this software under copyright law.
1717
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
1718
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
1719
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
1720
+ AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
1721
+ ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
1722
+ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1723
+ ------------------------------------------------------------------------------
1724
+ */
gaussiancity/extensions/grid_encoder/__init__.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ # @File: __init__.py
4
+ # @Author: Jiaxiang Tang (@ashawkey)
5
+ # @Date: 2023-04-15 10:39:28
6
+ # @Last Modified by: Haozhe Xie
7
+ # @Last Modified at: 2023-04-15 13:08:46
8
+ # @Email: ashawkey1999@gmail.com
9
+ # @Ref: https://github.com/ashawkey/torch-ngp
10
+
11
+ import math
12
+ import numpy as np
13
+ import torch
14
+
15
+ import grid_encoder_ext
16
+
17
+
18
class GridEncoderFunction(torch.autograd.Function):
    """Autograd wrapper around the ``grid_encoder_ext`` forward/backward kernels."""

    @staticmethod
    def forward(
        ctx,
        inputs,
        embeddings,
        offsets,
        per_level_scale,
        base_resolution,
        calc_grad_inputs=False,
        gridtype=0,
        align_corners=False,
    ):
        """Encode ``inputs`` with the multi-level grid.

        Args:
            inputs: [B, D] float coordinates, expected in [0, 1].
            embeddings: [sO, C] float table holding all levels back to back.
            offsets: [L + 1] int tensor; level ``l`` owns rows
                ``offsets[l]:offsets[l + 1]`` of ``embeddings``.
            per_level_scale: resolution multiplier between consecutive levels.
            base_resolution: resolution of level 0.
            calc_grad_inputs: also compute d(outputs)/d(inputs) so that
                ``backward`` can return a gradient for ``inputs``.
            gridtype: integer grid type id forwarded to the extension.
            align_corners: forwarded to the extension.

        Returns:
            [B, L * C] tensor with the dtype of ``embeddings``.
        """
        inputs = inputs.contiguous()
        # batch size, coord dim
        B, D = inputs.shape
        # number of levels
        L = offsets.shape[0] - 1
        # embedding dim for each level
        C = embeddings.shape[1]
        # resolution multiplier at each level, apply log2 for later CUDA exp2f
        S = math.log2(per_level_scale)
        # base resolution
        H = base_resolution
        # L first, optimize cache for cuda kernel, but needs an extra permute later
        outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)

        if calc_grad_inputs:
            dy_dx = torch.empty(
                B, L * D * C, device=inputs.device, dtype=embeddings.dtype
            )
        else:
            # One-element placeholder: the extension always takes a tensor here.
            dy_dx = torch.empty(1, device=inputs.device, dtype=embeddings.dtype)

        grid_encoder_ext.forward(
            inputs,
            embeddings,
            offsets,
            outputs,
            B,
            D,
            C,
            L,
            S,
            H,
            calc_grad_inputs,
            dy_dx,
            gridtype,
            align_corners,
        )
        # permute back to [B, L * C]
        outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
        ctx.dims = [B, D, C, L, S, H, gridtype]
        ctx.calc_grad_inputs = calc_grad_inputs
        ctx.align_corners = align_corners

        return outputs

    @staticmethod
    def backward(ctx, grad):
        """Back-propagate ``grad`` ([B, L * C]) to the embeddings and inputs.

        Returns one entry per ``forward`` argument; only ``embeddings`` and,
        when ``calc_grad_inputs`` was set, ``inputs`` receive a gradient.
        """
        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
        B, D, C, L, S, H, gridtype = ctx.dims
        calc_grad_inputs = ctx.calc_grad_inputs
        align_corners = ctx.align_corners

        # grad: [B, L * C] --> [L, B, C]
        # reshape (not view): the incoming gradient is not guaranteed to be
        # contiguous, and view() raises on incompatible strides.
        grad = grad.reshape(B, L, C).permute(1, 0, 2).contiguous()
        grad_embeddings = torch.zeros_like(embeddings)

        if calc_grad_inputs:
            grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
        else:
            # One-element placeholder: the extension always takes a tensor here.
            grad_inputs = torch.zeros(1, device=inputs.device, dtype=embeddings.dtype)

        grid_encoder_ext.backward(
            grad,
            inputs,
            embeddings,
            offsets,
            grad_embeddings,
            B,
            D,
            C,
            L,
            S,
            H,
            calc_grad_inputs,
            dy_dx,
            grad_inputs,
            gridtype,
            align_corners,
        )

        if calc_grad_inputs:
            grad_inputs = grad_inputs.to(inputs.dtype)
            return grad_inputs, grad_embeddings, None, None, None, None, None, None
        return None, grad_embeddings, None, None, None, None, None, None
123
+
124
+
125
class GridEncoder(torch.nn.Module):
    """Multi-resolution grid encoder backed by ``GridEncoderFunction``."""

    def __init__(
        self,
        in_channels,
        n_levels,
        lvl_channels,
        desired_resolution,
        per_level_scale=2,
        base_resolution=16,
        log2_hashmap_size=19,
        gridtype="hash",
        align_corners=False,
    ):
        """Allocate the embedding table and the per-level row offsets.

        Args:
            in_channels: dimensionality of the input coordinates.
            n_levels: number of resolution levels.
            lvl_channels: number of feature channels stored per level.
            desired_resolution: target resolution of the finest level; the
                scale handed to the kernels is derived from it.
            per_level_scale: growth factor used to size each level's table.
            base_resolution: resolution of level 0.
            log2_hashmap_size: log2 of the maximum number of rows per level.
            gridtype: ``"hash"`` maps to grid type id 0, anything else to 1.
            align_corners: forwarded to the kernels and used when sizing levels.
        """
        super().__init__()
        self.in_channels = in_channels
        self.n_levels = n_levels  # number of resolution levels
        self.lvl_channels = lvl_channels  # encode channels per level
        if n_levels > 1:
            self.per_level_scale = 2 ** (
                math.log2(desired_resolution / base_resolution) / (n_levels - 1)
            )
        else:
            # A single level has no growth step; avoid dividing by zero.
            # The kernels only evaluate scale ** level with level == 0.
            self.per_level_scale = 1.0
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.output_dim = n_levels * lvl_channels
        self.gridtype = gridtype
        self.gridtype_id = 0 if gridtype == "hash" else 1
        self.align_corners = align_corners

        # allocate parameters
        # NOTE(review): the table sizes below use the `per_level_scale`
        # argument, while the kernels receive `self.per_level_scale` (derived
        # from `desired_resolution`). Kept as is so the parameter shape, and
        # therefore existing checkpoints, stay unchanged -- confirm intent.
        offsets = []
        offset = 0
        self.max_params = 2**log2_hashmap_size
        for i in range(n_levels):
            resolution = int(math.ceil(base_resolution * per_level_scale**i))
            side = resolution if align_corners else resolution + 1
            # limit max number
            params_in_level = min(self.max_params, side**in_channels)
            # round up to a multiple of 8
            params_in_level = int(math.ceil(params_in_level / 8) * 8)
            offsets.append(offset)
            offset += params_in_level

        offsets.append(offset)
        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
        self.register_buffer("offsets", offsets)

        self.n_params = offsets[-1] * lvl_channels
        self.embeddings = torch.nn.Parameter(torch.empty(offset, lvl_channels))
        self._init_weights()

    def _init_weights(self):
        """Initialise the embeddings uniformly in [-1e-4, 1e-4]."""
        self.embeddings.data.uniform_(-1e-4, 1e-4)

    def forward(self, inputs, bound=1):
        """Encode positions.

        Args:
            inputs: [..., in_channels] positions in [-bound, bound].
            bound: half-extent used to map ``inputs`` to [0, 1].

        Returns:
            [..., n_levels * lvl_channels] encoded features.
        """
        inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]
        prefix_shape = list(inputs.shape[:-1])
        # reshape (not view): element-wise ops keep the strides of a permuted
        # input, and view() would raise on such a tensor.
        inputs = inputs.reshape(-1, self.in_channels)
        outputs = GridEncoderFunction.apply(
            inputs,
            self.embeddings,
            self.offsets,
            self.per_level_scale,
            self.base_resolution,
            inputs.requires_grad,
            self.gridtype_id,
            self.align_corners,
        )
        return outputs.view(prefix_shape + [self.output_dim])
gaussiancity/extensions/grid_encoder/bindings.cpp ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * @File: grid_encoder_ext_cuda.cpp
3
+ * @Author: Jiaxiang Tang (@ashawkey)
4
+ * @Date: 2023-04-15 10:39:17
5
+ * @Last Modified by: Haozhe Xie
6
+ * @Last Modified at: 2023-04-15 11:01:32
7
+ * @Email: ashawkey1999@gmail.com
8
+ * @Ref: https://github.com/ashawkey/torch-ngp
9
+ */
10
+
11
+ #include <stdint.h>
12
+ #include <torch/extension.h>
13
+ #include <torch/torch.h>
14
+
15
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], int32
// outputs: [L, B, C], float (level-major; the Python wrapper permutes it to
//          [B, L * C] afterwards)
// H: base resolution
20
+ void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings,
21
+ const at::Tensor offsets, at::Tensor outputs,
22
+ const uint32_t B, const uint32_t D, const uint32_t C,
23
+ const uint32_t L, const float S, const uint32_t H,
24
+ const bool calc_grad_inputs, at::Tensor dy_dx,
25
+ const uint32_t gridtype, const bool align_corners);
26
+ void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs,
27
+ const at::Tensor embeddings, const at::Tensor offsets,
28
+ at::Tensor grad_embeddings, const uint32_t B,
29
+ const uint32_t D, const uint32_t C, const uint32_t L,
30
+ const float S, const uint32_t H,
31
+ const bool calc_grad_inputs, const at::Tensor dy_dx,
32
+ at::Tensor grad_inputs, const uint32_t gridtype,
33
+ const bool align_corners);
34
+
35
// Python bindings: expose the two entry points declared above as
// `forward` and `backward` of the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &grid_encode_forward,
        "grid_encode_forward (CUDA)");
  m.def("backward", &grid_encode_backward,
        "grid_encode_backward (CUDA)");
}
gaussiancity/extensions/grid_encoder/grid_encoder_ext.cu ADDED
@@ -0,0 +1,605 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * @File: grid_encoder_ext.cu
3
+ * @Author: Jiaxiang Tang (@ashawkey)
4
+ * @Date: 2023-04-15 10:43:16
5
+ * @Last Modified by: Haozhe Xie
6
+ * @Last Modified at: 2023-04-29 11:47:54
7
+ * @Email: ashawkey1999@gmail.com
8
+ * @Ref: https://github.com/ashawkey/torch-ngp
9
+ */
10
+
11
+ #include <cuda.h>
12
+ #include <cuda_fp16.h>
13
+ #include <cuda_runtime.h>
14
+
15
+ #include <ATen/cuda/CUDAContext.h>
16
+ #include <torch/torch.h>
17
+
18
+ #include <algorithm>
19
+ #include <stdexcept>
20
+
21
+ #include <cstdio>
22
+ #include <stdint.h>
23
+
24
// Tensor argument checks. Each one goes through TORCH_CHECK and names the
// offending expression (#x) in the failure message.

// The tensor must live on a CUDA device.
#define CHECK_CUDA(x)                                                          \
  TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
// The tensor must be contiguous in memory.
#define CHECK_CONTIGUOUS(x)                                                    \
  TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
// The tensor must have scalar type Int (32-bit integer).
#define CHECK_IS_INT(x)                                                        \
  TORCH_CHECK(x.scalar_type() == at::ScalarType::Int,                          \
              #x " must be an int tensor")
// The tensor must have scalar type Float, Half or Double.
#define CHECK_IS_FLOATING(x)                                                   \
  TORCH_CHECK(x.scalar_type() == at::ScalarType::Float ||                      \
                  x.scalar_type() == at::ScalarType::Half ||                   \
                  x.scalar_type() == at::ScalarType::Double,                   \
              #x " must be a floating tensor")
36
+
37
+ // just for compatability of half precision in
38
+ // AT_DISPATCH_FLOATING_TYPES_AND_HALF...
39
+ static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
40
+ // requires CUDA >= 10 and ARCH >= 70
41
+ // this is very slow compared to float or __half2, and never used.
42
+ // return atomicAdd(reinterpret_cast<__half*>(address), val);
43
+ }
44
+
45
+ template <typename T>
46
+ static inline __host__ __device__ T div_round_up(T val, T divisor) {
47
+ return (val + divisor - 1) / divisor;
48
+ }
49
+
50
+ template <uint32_t D>
51
+ __device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
52
+ static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions.");
53
+
54
+ // While 1 is technically not a good prime for hashing (or a prime at all), it
55
+ // helps memory coherence and is sufficient for our use case of obtaining a
56
+ // uniformly colliding index from high-dimensional coordinates.
57
+ constexpr uint32_t primes[7] = {1, 2654435761, 805459861, 3674653429,
58
+ 2097192037, 1434869437, 2165219737};
59
+
60
+ uint32_t result = 0;
61
+ #pragma unroll
62
+ for (uint32_t i = 0; i < D; ++i) {
63
+ result ^= pos_grid[i] * primes[i];
64
+ }
65
+
66
+ return result;
67
+ }
68
+
69
+ template <uint32_t D, uint32_t C>
70
+ __device__ uint32_t get_grid_index(const uint32_t gridtype,
71
+ const bool align_corners, const uint32_t ch,
72
+ const uint32_t hashmap_size,
73
+ const uint32_t resolution,
74
+ const uint32_t pos_grid[D]) {
75
+ uint32_t stride = 1;
76
+ uint32_t index = 0;
77
+
78
+ #pragma unroll
79
+ for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
80
+ index += pos_grid[d] * stride;
81
+ stride *= align_corners ? resolution : (resolution + 1);
82
+ }
83
+
84
+ // NOTE: for NeRF, the hash is in fact not necessary. Check
85
+ // https://github.com/NVlabs/instant-ngp/issues/97. gridtype: 0 == hash, 1 ==
86
+ // tiled
87
+ if (gridtype == 0 && stride > hashmap_size) {
88
+ index = fast_hash<D>(pos_grid);
89
+ }
90
+
91
+ return (index % hashmap_size) * C + ch;
92
+ }
93
+
94
// Forward kernel: one thread encodes one sample (index b, along x) at one
// level (blockIdx.y).
//
// For its sample the thread multilinearly interpolates the 2^D surrounding
// grid entries of that level and writes C channels to `outputs`, which is
// indexed as [L, B, C]. When calc_grad_inputs is set it also stores the
// derivative of those C channels w.r.t. each of the D input coordinates in
// `dy_dx`, indexed as [B, L, D, C].
//
// Samples with any coordinate outside [0, 1] produce all-zero outputs (and
// all-zero dy_dx entries).
template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void
kernel_grid(const float *__restrict__ inputs, const scalar_t *__restrict__ grid,
            const int *__restrict__ offsets, scalar_t *__restrict__ outputs,
            const uint32_t B, const uint32_t L, const float S, const uint32_t H,
            const bool calc_grad_inputs, scalar_t *__restrict__ dy_dx,
            const uint32_t gridtype, const bool align_corners) {
  const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;

  if (b >= B)
    return;

  const uint32_t level = blockIdx.y;

  // locate: advance the pointers to this level's table, this sample's
  // coordinates and this (level, sample) output slot
  grid += (uint32_t)offsets[level] * C;
  inputs += b * D;
  outputs += level * B * C + b * C;

  // check input range (should be in [0, 1])
  bool flag_oob = false;
#pragma unroll
  for (uint32_t d = 0; d < D; d++) {
    if (inputs[d] < 0 || inputs[d] > 1) {
      flag_oob = true;
    }
  }
  // if input out of bound, just set output to 0
  if (flag_oob) {
#pragma unroll
    for (uint32_t ch = 0; ch < C; ch++) {
      outputs[ch] = 0;
    }
    if (calc_grad_inputs) {
      dy_dx += b * D * L * C + level * D * C; // B L D C
#pragma unroll
      for (uint32_t d = 0; d < D; d++) {
#pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
          dy_dx[d * C + ch] = 0;
        }
      }
    }
    return;
  }

  // number of table rows owned by this level
  const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
  // S is passed as a log2 value, so exp2f(level * S) is the level's multiplier
  const float scale = exp2f(level * S) * H - 1.0f;
  const uint32_t resolution = (uint32_t)ceil(scale) + 1;

  // calculate coordinate: pos_grid is the integer cell corner, pos the
  // fractional offset inside the cell
  float pos[D];
  uint32_t pos_grid[D];

#pragma unroll
  for (uint32_t d = 0; d < D; d++) {
    pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
    pos_grid[d] = floorf(pos[d]);
    pos[d] -= (float)pos_grid[d];
  }

  // printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1],
  // pos_grid[0], pos_grid[1]);

  // interpolate
  scalar_t results[C] = {0}; // temp results in register

  // bit d of idx selects the lower (0) or upper (1) corner along dimension d
#pragma unroll
  for (uint32_t idx = 0; idx < (1 << D); idx++) {
    float w = 1;
    uint32_t pos_grid_local[D];

#pragma unroll
    for (uint32_t d = 0; d < D; d++) {
      if ((idx & (1 << d)) == 0) {
        w *= 1 - pos[d];
        pos_grid_local[d] = pos_grid[d];
      } else {
        w *= pos[d];
        pos_grid_local[d] = pos_grid[d] + 1;
      }
    }

    uint32_t index = get_grid_index<D, C>(
        gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);

    // writing to register (fast)
#pragma unroll
    for (uint32_t ch = 0; ch < C; ch++) {
      results[ch] += w * grid[index + ch];
    }

    // printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx,
    // index, w, grid[index]);
  }

  // writing to global memory (slow)
#pragma unroll
  for (uint32_t ch = 0; ch < C; ch++) {
    outputs[ch] = results[ch];
  }

  // prepare dy_dx for calc_grad_inputs
  // differentiable (soft) indexing:
  // https://discuss.pytorch.org/t/differentiable-indexing/17647/9
  if (calc_grad_inputs) {

    dy_dx += b * D * L * C + level * D * C; // B L D C

    // gd is the input dimension being differentiated
#pragma unroll
    for (uint32_t gd = 0; gd < D; gd++) {

      scalar_t results_grad[C] = {0};

      // enumerate the corners over the remaining D - 1 dimensions
#pragma unroll
      for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
        // the weight starts at `scale` because pos depends on the input
        // through the factor `scale`
        float w = scale;
        uint32_t pos_grid_local[D];

#pragma unroll
        for (uint32_t nd = 0; nd < D - 1; nd++) {
          // map the compact index nd to a real dimension, skipping gd
          const uint32_t d = (nd >= gd) ? (nd + 1) : nd;

          if ((idx & (1 << nd)) == 0) {
            w *= 1 - pos[d];
            pos_grid_local[d] = pos_grid[d];
          } else {
            w *= pos[d];
            pos_grid_local[d] = pos_grid[d] + 1;
          }
        }

        // the two neighbours along gd; their difference is the slope
        pos_grid_local[gd] = pos_grid[gd];
        uint32_t index_left =
            get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size,
                                 resolution, pos_grid_local);
        pos_grid_local[gd] = pos_grid[gd] + 1;
        uint32_t index_right =
            get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size,
                                 resolution, pos_grid_local);

#pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
          results_grad[ch] +=
              w * (grid[index_right + ch] - grid[index_left + ch]);
        }
      }

#pragma unroll
      for (uint32_t ch = 0; ch < C; ch++) {
        dy_dx[gd * C + ch] = results_grad[ch];
      }
    }
  }
}
249
+
250
// Backward kernel for the grid entries: each thread handles N_C consecutive
// channels (starting at `ch`) of one sample `b` at one level (blockIdx.y).
//
// It recomputes the interpolation weights of the 2^D surrounding corners and
// atomically accumulates weight * grad into `grad_grid`. `grad` is indexed as
// [L, B, C]. Samples with a coordinate outside [0, 1] contribute nothing.
template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
__global__ void kernel_grid_backward(
    const scalar_t *__restrict__ grad, const float *__restrict__ inputs,
    const scalar_t *__restrict__ grid, const int *__restrict__ offsets,
    scalar_t *__restrict__ grad_grid, const uint32_t B, const uint32_t L,
    const float S, const uint32_t H, const uint32_t gridtype,
    const bool align_corners) {
  // the flat thread id counts groups of N_C channels; derive the sample index
  const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
  if (b >= B)
    return;

  const uint32_t level = blockIdx.y;
  // first channel handled by this thread within sample b
  const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;

  // locate
  grad_grid += offsets[level] * C;
  inputs += b * D;
  grad += level * B * C + b * C + ch; // L, B, C

  // number of table rows owned by this level
  const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
  // S is passed as a log2 value, so exp2f(level * S) is the level's multiplier
  const float scale = exp2f(level * S) * H - 1.0f;
  const uint32_t resolution = (uint32_t)ceil(scale) + 1;

  // check input range (should be in [0, 1])
#pragma unroll
  for (uint32_t d = 0; d < D; d++) {
    if (inputs[d] < 0 || inputs[d] > 1) {
      return; // grad is init as 0, so we simply return.
    }
  }

  // calculate coordinate: pos_grid is the integer cell corner, pos the
  // fractional offset inside the cell
  float pos[D];
  uint32_t pos_grid[D];

#pragma unroll
  for (uint32_t d = 0; d < D; d++) {
    pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
    pos_grid[d] = floorf(pos[d]);
    pos[d] -= (float)pos_grid[d];
  }

  scalar_t grad_cur[N_C] = {0}; // fetch to register
#pragma unroll
  for (uint32_t c = 0; c < N_C; c++) {
    grad_cur[c] = grad[c];
  }

  // interpolate: bit d of idx selects the lower (0) or upper (1) corner along
  // dimension d
#pragma unroll
  for (uint32_t idx = 0; idx < (1 << D); idx++) {
    float w = 1;
    uint32_t pos_grid_local[D];

#pragma unroll
    for (uint32_t d = 0; d < D; d++) {
      if ((idx & (1 << d)) == 0) {
        w *= 1 - pos[d];
        pos_grid_local[d] = pos_grid[d];
      } else {
        w *= pos[d];
        pos_grid_local[d] = pos_grid[d] + 1;
      }
    }

    uint32_t index = get_grid_index<D, C>(
        gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);

    // atomicAdd for __half is slow (especially for large values), so we use
    // __half2 if N_C % 2 == 0
    // TODO: use float which is better than __half, if N_C % 2 != 0
    if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
#pragma unroll
      for (uint32_t c = 0; c < N_C; c += 2) {
        // process two __half at once (by interpreting as a __half2)
        __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
        atomicAdd((__half2 *)&grad_grid[index + c], v);
      }
      // float, or __half when N_C % 2 != 0 (which means C == 1)
    } else {
#pragma unroll
      for (uint32_t c = 0; c < N_C; c++) {
        atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
      }
    }
  }
}
337
+
338
/**
 * Backward pass with respect to the inputs: for every (sample, dimension)
 * pair, sums grad * dy_dx over all levels and channels.
 *
 * One thread computes one element of grad_inputs.
 *
 * @param grad         upstream gradient, indexed as [L, B, C]
 * @param dy_dx        per-sample derivatives, indexed as [B, L, D, C]
 * @param grad_inputs  output, indexed as [B, D]
 * @param B            number of samples
 * @param L            number of levels
 */
template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_input_backward(const scalar_t *__restrict__ grad,
                                      const scalar_t *__restrict__ dy_dx,
                                      scalar_t *__restrict__ grad_inputs,
                                      uint32_t B, uint32_t L) {
  const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= B * D) {
    return;
  }

  // Split the flat thread index into (sample, input dimension).
  const uint32_t sample = tid / D;
  const uint32_t dim = tid % D;

  // Derivatives that belong to this sample.
  const scalar_t *sample_dy_dx = dy_dx + sample * L * D * C;

  scalar_t acc = 0;

#pragma unroll
  for (uint32_t level = 0; level < L; ++level) {
#pragma unroll
    for (uint32_t c = 0; c < C; ++c) {
      acc += grad[level * B * C + sample * C + c] *
             sample_dy_dx[level * D * C + dim * C + c];
    }
  }

  grad_inputs[tid] = acc;
}
364
+
365
/**
 * Launches the forward kernel specialised for the runtime channel count C.
 *
 * The grid is (ceil(B / N_THREAD), L): x covers the samples, y the levels.
 *
 * @throws std::runtime_error if C is not 1, 2, 4, or 8.
 */
template <typename scalar_t, uint32_t D>
void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings,
                         const int *offsets, scalar_t *outputs,
                         const uint32_t B, const uint32_t C, const uint32_t L,
                         const float S, const uint32_t H,
                         const bool calc_grad_inputs, scalar_t *dy_dx,
                         const uint32_t gridtype, const bool align_corners) {
  static constexpr uint32_t N_THREAD = 512;
  const dim3 blocks_hashgrid = {div_round_up(B, N_THREAD), L, 1};

// Expands to one `case` label that launches kernel_grid with N_FEAT channels.
#define GRID_ENCODER_FORWARD_CASE(N_FEAT)                                      \
  case N_FEAT:                                                                 \
    kernel_grid<scalar_t, D, N_FEAT><<<blocks_hashgrid, N_THREAD>>>(           \
        inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs,    \
        dy_dx, gridtype, align_corners);                                       \
    break

  switch (C) {
    GRID_ENCODER_FORWARD_CASE(1);
    GRID_ENCODER_FORWARD_CASE(2);
    GRID_ENCODER_FORWARD_CASE(4);
    GRID_ENCODER_FORWARD_CASE(8);
  default:
    throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
  }

#undef GRID_ENCODER_FORWARD_CASE
}
399
+
400
// Dispatches the forward grid encoding on the runtime input dimension D.
//
// inputs:     [B, D], float, in [0, 1]
// embeddings: [sO, C], scalar_t
// offsets:    [L + 1], int
// outputs:    [L, B, C], scalar_t (L first, so only one level of the hashmap
//             needs to fit into cache at a time)
// H:          base resolution
// dy_dx:      [B, L * D * C], scalar_t
//
// Throws std::runtime_error when D is not 2, 3, 4, or 5. An unsupported C is
// reported separately by kernel_grid_wrapper.
template <typename scalar_t>
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings,
                              const int *offsets, scalar_t *outputs,
                              const uint32_t B, const uint32_t D,
                              const uint32_t C, const uint32_t L, const float S,
                              const uint32_t H, const bool calc_grad_inputs,
                              scalar_t *dy_dx, const uint32_t gridtype,
                              const bool align_corners) {
  switch (D) {
  case 2:
    kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C,
                                     L, S, H, calc_grad_inputs, dy_dx, gridtype,
                                     align_corners);
    break;
  case 3:
    kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C,
                                     L, S, H, calc_grad_inputs, dy_dx, gridtype,
                                     align_corners);
    break;
  case 4:
    kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C,
                                     L, S, H, calc_grad_inputs, dy_dx, gridtype,
                                     align_corners);
    break;
  case 5:
    kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C,
                                     L, S, H, calc_grad_inputs, dy_dx, gridtype,
                                     align_corners);
    break;
  default:
    // The switch is over D, so the message must name D and its supported
    // values rather than repeating the message used for C.
    throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
  }
}
438
+
439
/**
 * Launches the backward kernels specialised for the runtime channel count C.
 *
 * kernel_grid_backward always runs; kernel_input_backward runs only when
 * calc_grad_inputs is set. Each thread of the grid kernel handles
 * min(2, C) channels, which determines the x extent of the launch grid.
 *
 * @throws std::runtime_error if C is not 1, 2, 4, or 8.
 */
template <typename scalar_t, uint32_t D>
void kernel_grid_backward_wrapper(
    const scalar_t *grad, const float *inputs, const scalar_t *embeddings,
    const int *offsets, scalar_t *grad_embeddings, const uint32_t B,
    const uint32_t C, const uint32_t L, const float S, const uint32_t H,
    const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs,
    const uint32_t gridtype, const bool align_corners) {
  static constexpr uint32_t N_THREAD = 256;
  const uint32_t N_C = std::min(2u, C); // n_features_per_thread
  const dim3 blocks_hashgrid = {div_round_up(B * C / N_C, N_THREAD), L, 1};

// Expands to one `case` label: N_FEAT channels, N_PER_THREAD of them per
// thread in the grid kernel.
#define GRID_ENCODER_BACKWARD_CASE(N_FEAT, N_PER_THREAD)                       \
  case N_FEAT:                                                                 \
    kernel_grid_backward<scalar_t, D, N_FEAT, N_PER_THREAD>                    \
        <<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets,     \
                                        grad_embeddings, B, L, S, H,           \
                                        gridtype, align_corners);              \
    if (calc_grad_inputs) {                                                    \
      kernel_input_backward<scalar_t, D, N_FEAT>                               \
          <<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx,           \
                                                        grad_inputs, B, L);    \
    }                                                                          \
    break

  switch (C) {
    GRID_ENCODER_BACKWARD_CASE(1, 1);
    GRID_ENCODER_BACKWARD_CASE(2, 2);
    GRID_ENCODER_BACKWARD_CASE(4, 2);
    GRID_ENCODER_BACKWARD_CASE(8, 2);
  default:
    throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
  }

#undef GRID_ENCODER_BACKWARD_CASE
}
490
+
491
// Dispatches the backward grid encoding on the runtime input dimension D.
//
// grad:            [L, B, C], scalar_t
// inputs:          [B, D], float, in [0, 1]
// embeddings:      [sO, C], scalar_t
// offsets:         [L + 1], int
// grad_embeddings: [sO, C], scalar_t
// H:               base resolution
//
// Throws std::runtime_error when D is not 2, 3, 4, or 5. An unsupported C is
// reported separately by kernel_grid_backward_wrapper.
template <typename scalar_t>
void grid_encode_backward_cuda(
    const scalar_t *grad, const float *inputs, const scalar_t *embeddings,
    const int *offsets, scalar_t *grad_embeddings, const uint32_t B,
    const uint32_t D, const uint32_t C, const uint32_t L, const float S,
    const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx,
    scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
  switch (D) {
  case 2:
    kernel_grid_backward_wrapper<scalar_t, 2>(
        grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H,
        calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners);
    break;
  case 3:
    kernel_grid_backward_wrapper<scalar_t, 3>(
        grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H,
        calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners);
    break;
  case 4:
    kernel_grid_backward_wrapper<scalar_t, 4>(
        grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H,
        calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners);
    break;
  case 5:
    kernel_grid_backward_wrapper<scalar_t, 5>(
        grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H,
        calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners);
    break;
  default:
    // The switch is over D, so the message must name D and its supported
    // values rather than repeating the message used for C.
    throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
  }
}
529
+
530
/**
 * Tensor-level entry point of the forward grid encoding.
 *
 * Validates the tensors, then dispatches on the element type of `embeddings`
 * (float, double, or half) and forwards the raw pointers to
 * grid_encode_forward_cuda. `inputs` is always read as float and `offsets`
 * as int; `outputs` and `dy_dx` are read with the dispatched type.
 *
 * NOTE(review): the CHECK_* macros are defined outside this view; presumably
 * they raise when the named property does not hold -- confirm.
 *
 * @param B  number of samples         @param D  input dimensions
 * @param C  channels per level        @param L  number of levels
 * @param S  per-level scale exponent  @param H  base resolution
 */
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings,
                         const at::Tensor offsets, at::Tensor outputs,
                         const uint32_t B, const uint32_t D, const uint32_t C,
                         const uint32_t L, const float S, const uint32_t H,
                         const bool calc_grad_inputs, at::Tensor dy_dx,
                         const uint32_t gridtype, const bool align_corners) {
  CHECK_CUDA(inputs);
  CHECK_CUDA(embeddings);
  CHECK_CUDA(offsets);
  CHECK_CUDA(outputs);
  CHECK_CUDA(dy_dx);

  CHECK_CONTIGUOUS(inputs);
  CHECK_CONTIGUOUS(embeddings);
  CHECK_CONTIGUOUS(offsets);
  CHECK_CONTIGUOUS(outputs);
  CHECK_CONTIGUOUS(dy_dx);

  CHECK_IS_FLOATING(inputs);
  CHECK_IS_FLOATING(embeddings);
  CHECK_IS_INT(offsets);
  CHECK_IS_FLOATING(outputs);
  CHECK_IS_FLOATING(dy_dx);

  // dy_dx is checked and its pointer taken even when calc_grad_inputs is
  // false, so callers must always pass a valid CUDA tensor for it.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      embeddings.scalar_type(), "grid_encode_forward", ([&] {
        grid_encode_forward_cuda<scalar_t>(
            inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(),
            offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L,
            S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), gridtype,
            align_corners);
      }));
}
563
+
564
/**
 * Tensor-level entry point of the backward grid encoding.
 *
 * Validates the tensors, then dispatches on the element type of `grad`
 * (float, double, or half) and forwards the raw pointers to
 * grid_encode_backward_cuda. `inputs` is always read as float and `offsets`
 * as int; every other tensor is read with the dispatched type.
 *
 * NOTE(review): the CHECK_* macros are defined outside this view; presumably
 * they raise when the named property does not hold -- confirm.
 *
 * @param grad_embeddings  receives the gradient of the embedding table
 * @param grad_inputs      receives the gradient of the inputs when
 *                         calc_grad_inputs is set
 */
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs,
                          const at::Tensor embeddings, const at::Tensor offsets,
                          at::Tensor grad_embeddings, const uint32_t B,
                          const uint32_t D, const uint32_t C, const uint32_t L,
                          const float S, const uint32_t H,
                          const bool calc_grad_inputs, const at::Tensor dy_dx,
                          at::Tensor grad_inputs, const uint32_t gridtype,
                          const bool align_corners) {
  CHECK_CUDA(grad);
  CHECK_CUDA(inputs);
  CHECK_CUDA(embeddings);
  CHECK_CUDA(offsets);
  CHECK_CUDA(grad_embeddings);
  CHECK_CUDA(dy_dx);
  CHECK_CUDA(grad_inputs);

  CHECK_CONTIGUOUS(grad);
  CHECK_CONTIGUOUS(inputs);
  CHECK_CONTIGUOUS(embeddings);
  CHECK_CONTIGUOUS(offsets);
  CHECK_CONTIGUOUS(grad_embeddings);
  CHECK_CONTIGUOUS(dy_dx);
  CHECK_CONTIGUOUS(grad_inputs);

  CHECK_IS_FLOATING(grad);
  CHECK_IS_FLOATING(inputs);
  CHECK_IS_FLOATING(embeddings);
  CHECK_IS_INT(offsets);
  CHECK_IS_FLOATING(grad_embeddings);
  CHECK_IS_FLOATING(dy_dx);
  CHECK_IS_FLOATING(grad_inputs);

  // dy_dx and grad_inputs are checked and their pointers taken even when
  // calc_grad_inputs is false, so both must always be valid CUDA tensors.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad.scalar_type(), "grid_encode_backward", ([&] {
        grid_encode_backward_cuda<scalar_t>(
            grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(),
            embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(),
            grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H,
            calc_grad_inputs, dy_dx.data_ptr<scalar_t>(),
            grad_inputs.data_ptr<scalar_t>(), gridtype, align_corners);
      }));
}
gaussiancity/extensions/grid_encoder/setup.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ # @File: setup.py
4
+ # @Author: Jiaxiang Tang (@ashawkey)
5
+ # @Date: 2023-04-15 10:33:32
6
+ # @Last Modified by: Haozhe Xie
7
+ # @Last Modified at: 2024-09-18 10:08:45
8
+ # @Email: ashawkey1999@gmail.com
9
+ # @Ref: https://github.com/ashawkey/torch-ngp
10
+
11
+ from setuptools import setup
12
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
13
+
14
# Optimisation level and language standard passed to both compilers.
_COMMON_FLAGS = ["-O3", "-std=c++17"]

# `-U` undefines the __CUDA_NO_HALF* macros for nvcc.
# NOTE(review): presumably needed because the default PyTorch build flags
# define them, which would hide the half operators the kernels use -- confirm.
_HALF_FLAGS = [
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
]

_grid_encoder_ext = CUDAExtension(
    name="grid_encoder_ext",
    sources=["grid_encoder_ext.cu", "bindings.cpp"],
    extra_compile_args={
        "cxx": list(_COMMON_FLAGS),
        "nvcc": _COMMON_FLAGS + _HALF_FLAGS,
    },
)

setup(
    name="grid_encoder",
    version="1.0.0",
    ext_modules=[_grid_encoder_ext],
    cmdclass={"build_ext": BuildExtension},
)
gaussiancity/extensions/voxlib/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ # @File: __init__.py
4
+ # @Author: NVIDIA Corporation
5
+ # @Date: 2021-10-13 00:00:00
6
+ # @Last Modified by: Haozhe Xie
7
+ # @Last Modified at: 2024-10-13 03:14:15
8
+ # @Email: root@haozhexie.com
9
+
10
+ from voxlib import ray_voxel_intersection_perspective
11
+ from voxlib import points_to_volume
12
+ from voxlib import maps_to_volume
gaussiancity/extensions/voxlib/bindings.cpp ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * @File: bindings.cpp
3
+ * @Author: NVIDIA Corporation
4
+ * @Date: 2021-10-13 00:00:00
5
+ * @Last Modified by: Haozhe Xie
6
+ * @Last Modified at: 2024-10-13 03:03:45
7
+ * @Email: root@haozhexie.com
8
+ */
9
+
10
+ #include <pybind11/pybind11.h>
11
+ #include <pybind11/stl.h>
12
+ #include <torch/extension.h>
13
+ #include <vector>
14
+
15
// Forward declarations of the CUDA entry points. Their definitions live in
// the .cu files of this extension; only the Python bindings are created here.

// Fast voxel traversal along rays
// Returns {out_voxel_id, out_depth, out_raydirs}.
std::vector<torch::Tensor> ray_voxel_intersection_perspective_cuda(
    const torch::Tensor &in_voxel, const torch::Tensor &cam_ori,
    const torch::Tensor &cam_dir, const torch::Tensor &cam_up, float cam_f,
    const std::vector<float> &cam_c, const std::vector<int> &img_dims,
    int max_samples);

// Writes each point's ID into an int32 volume of shape [h, w, d].
torch::Tensor points_to_volume_cuda(const torch::Tensor &points,
                                    const torch::Tensor &pt_ids,
                                    const torch::Tensor &scales, int h, int w,
                                    int d);

// Builds an int16 instance volume from the instance map and the two
// height fields (td_hf = upper bound, bu_hf = lower bound of each column).
torch::Tensor
maps_to_volume_cuda(const torch::Tensor &inst_map, const torch::Tensor &td_hf,
                    const torch::Tensor &bu_hf,
                    const torch::Tensor &pts_map,
                    const torch::Tensor &scales);

// Python module definition. The functions are exposed without the `_cuda`
// suffix and take positional arguments only (no argument names are bound).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("ray_voxel_intersection_perspective",
        &ray_voxel_intersection_perspective_cuda,
        "Ray-voxel intersections given perspective camera parameters (CUDA)");
  m.def("points_to_volume", &points_to_volume_cuda,
        "Generate 3D volume from points (CUDA)");
  m.def("maps_to_volume", &maps_to_volume_cuda,
        "Generate 3D volume from maps (CUDA)");
}
gaussiancity/extensions/voxlib/maps_to_volume.cu ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * @File: maps_to_volume.cu
3
+ * @Author: Haozhe Xie
4
+ * @Date: 2024-10-09 15:42:49
5
+ * @Last Modified by: Haozhe Xie
6
+ * @Last Modified at: 2024-10-13 12:26:15
7
+ * @Email: root@haozhexie.com
8
+ */
9
+
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <torch/extension.h>
12
+
13
+ #include "voxlib_common.h"
14
+
15
+ #define TILE_DIM 16
16
+ #define BLDG_MAX_HEIGHT 504
17
+ #define BLDG_INS_MIN_ID 10
18
+ #define BLDG_FACADE_SEM 2
19
+ #define BLDG_ROOF_OFFSET 1
20
+
21
/**
 * Fills a [height, width, depth] instance volume from 2D maps, keeping only
 * the shell of every object so that objects are hollow.
 *
 * One thread handles one (row j, column i) pixel and walks its vertical
 * column from bu_hf (lower height) to td_hf (upper height) in steps of the
 * per-class scale. A voxel is written when it is the topmost step of the
 * column, when the pixel lies within `scale` of the map border, or when any
 * of the 8 neighbours at distance `scale` has a different top height or a
 * different instance ID.
 *
 * @param height, width  size of the 2D maps
 * @param depth          size of the volume along the vertical axis
 * @param scales         step size per semantic class, indexed by class ID
 * @param inst_map       instance ID per pixel; IDs >= BLDG_INS_MIN_ID are
 *                       building instances
 * @param td_hf          upper height per pixel
 * @param bu_hf          lower height per pixel
 * @param pts_map        pixels set to false are skipped entirely
 * @param volume         output, expected to be zero-initialised
 */
__global__ void maps_to_volume_cuda_kernel(int height, int width, int depth,
                                           const int8_t *__restrict__ scales,
                                           const short *__restrict__ inst_map,
                                           const short *__restrict__ td_hf,
                                           const short *__restrict__ bu_hf,
                                           const bool *__restrict__ pts_map,
                                           short *__restrict__ volume) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x; // width
  size_t j = blockIdx.y * blockDim.y + threadIdx.y; // height
  if (i >= width || j >= height) {
    return;
  }

  const size_t pix = j * width + i;
  // Fix: nonzero is not supported for tensors with more than INT_MAX elements
  if (!pts_map[pix]) {
    return;
  }

  const short hgt_up = td_hf[pix];
  const short hgt_lw = bu_hf[pix];
  const short inst = inst_map[pix];
  // WARN: The semantic labels for buildings would be merged to facade.
  const short sem_cls = inst < BLDG_INS_MIN_ID ? inst : BLDG_FACADE_SEM;
  const short scale = scales[sem_cls];
  if (scale <= 0) {
    // A non-positive step would never terminate the column loop below.
    return;
  }

  // Pixels this close to the map border are always treated as shell; this
  // also guarantees the neighbour reads below stay inside the maps.
  const bool on_map_border = (i < scale) || (i >= width - scale - 1) ||
                             (j < scale) || (j >= height - scale - 1);

  // The neighbourhood test does not depend on the height step, so it is
  // evaluated once per pixel rather than once per step.
  bool nbr_differs = false;
  if (!on_map_border) {
    for (int dy = -1; dy <= 1 && !nbr_differs; ++dy) {
      for (int dx = -1; dx <= 1; ++dx) {
        if (dx == 0 && dy == 0) {
          continue;
        }
        const int64_t nj = static_cast<int64_t>(j) + dy * scale;
        const int64_t ni = static_cast<int64_t>(i) + dx * scale;
        const int64_t nbr = nj * width + ni;
        if (td_hf[nbr] != hgt_up || inst_map[nbr] != inst) {
          nbr_differs = true;
          break;
        }
      }
    }
  }

  int64_t vol_offset = static_cast<int64_t>(j) * width * depth + i * depth;
  for (int k = hgt_lw; k <= hgt_up; k += scale) {
    // Make all objects hollow: interior voxels are skipped.
    const bool is_top = k > hgt_up - scale;
    if (!is_top && !on_map_border && !nbr_differs) {
      continue;
    }
    // Heights outside [0, depth) would land in another column or outside the
    // volume, so they are dropped instead of written.
    if (k < 0 || k >= depth) {
      continue;
    }

    // Building Roof Handler (Recover roof instance ID)
    if (is_top && sem_cls == BLDG_FACADE_SEM) {
      volume[vol_offset + k] = inst + BLDG_ROOF_OFFSET;
    } else {
      volume[vol_offset + k] = inst;
    }
  }
}
102
+
103
/**
 * Builds an int16 volume of shape [height, width, BLDG_MAX_HEIGHT] from the
 * instance map and the two height fields.
 *
 * All inputs must be contiguous CUDA tensors because the kernel indexes them
 * as flat row-major arrays. `scales` is read as int8, `pts_map` as bool and
 * the remaining maps as int16 (data_ptr raises on any other dtype).
 *
 * @param inst_map  instance IDs; its first two sizes define height and width
 * @param td_hf     upper height per pixel
 * @param bu_hf     lower height per pixel
 * @param pts_map   mask of pixels to process
 * @param scales    per-class step size
 * @return          the generated volume, on the device of inst_map
 */
torch::Tensor maps_to_volume_cuda(const torch::Tensor &inst_map,
                                  const torch::Tensor &td_hf,
                                  const torch::Tensor &bu_hf,
                                  const torch::Tensor &pts_map,
                                  const torch::Tensor &scales) {
  CHECK_INPUT(inst_map);
  CHECK_INPUT(td_hf);
  CHECK_INPUT(bu_hf);
  CHECK_INPUT(pts_map);
  CHECK_INPUT(scales);

  TORCH_CHECK(inst_map.dim() >= 2,
              "inst_map must have at least 2 dimensions (height, width)");
  // The kernel reads every map at the same flat index, so they must all
  // hold the same number of elements.
  TORCH_CHECK(td_hf.numel() == inst_map.numel(),
              "td_hf must have as many elements as inst_map");
  TORCH_CHECK(bu_hf.numel() == inst_map.numel(),
              "bu_hf must have as many elements as inst_map");
  TORCH_CHECK(pts_map.numel() == inst_map.numel(),
              "pts_map must have as many elements as inst_map");

  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
  torch::Device device = inst_map.device();

  int height = inst_map.size(0);
  int width = inst_map.size(1);
  int depth = BLDG_MAX_HEIGHT;

  torch::Tensor volume =
      torch::zeros({height, width, depth},
                   torch::TensorOptions().dtype(torch::kInt16).device(device));
  if (height == 0 || width == 0) {
    // A grid with a zero dimension is an invalid launch configuration;
    // there is nothing to fill anyway.
    return volume;
  }

  dim3 blockDim(TILE_DIM, TILE_DIM);
  dim3 gridDim((width + blockDim.x - 1) / blockDim.x,
               (height + blockDim.y - 1) / blockDim.y);

  maps_to_volume_cuda_kernel<<<gridDim, blockDim, 0, stream>>>(
      height, width, depth, scales.data_ptr<int8_t>(),
      inst_map.data_ptr<short>(), td_hf.data_ptr<short>(),
      bu_hf.data_ptr<short>(), pts_map.data_ptr<bool>(),
      volume.data_ptr<short>());

  // Launch errors are reported on stdout; the volume is returned regardless.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("Error in maps_to_volume_cuda_kernel: %s\n",
           cudaGetErrorString(err));
  }
  return volume;
}
gaussiancity/extensions/voxlib/points_to_volume.cu ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * @File: points_to_volume.cu
3
+ * @Author: Haozhe Xie
4
+ * @Date: 2024-02-24 14:09:38
5
+ * @Last Modified by: Haozhe Xie
6
+ * @Last Modified at: 2024-10-13 12:29:46
7
+ * @Email: root@haozhexie.com
8
+ */
9
+
10
+ #include <cmath>
11
+ #include <cstdio>
12
+ #include <cstdlib>
13
+
14
+ #include <ATen/cuda/CUDAContext.h>
15
+ #include <torch/extension.h>
16
+
17
+ #include "voxlib_common.h"
18
+
19
+ #define THREADS_PER_BLOCK 256
20
+
21
/**
 * Writes the ID of every point into the axis-aligned box it occupies.
 *
 * One thread handles one point. The box starts at (x, y, z), extends by
 * (sx, sy, sz) voxels and is clipped to the volume; points whose start lies
 * outside the volume are skipped. When boxes overlap, the surviving ID is
 * whichever thread writes last (plain stores, no atomics).
 *
 * @param n_pts    number of points
 * @param h, w, d  volume size; the volume is indexed as [y, x, z]
 * @param points   3 shorts (x, y, z) per point
 * @param pt_ids   ID per point
 * @param scales   3 shorts (sx, sy, sz) per point
 * @param volume   output of h * w * d ints
 */
__global__ void points_to_volume_cuda_cuda_kernel(
    size_t n_pts, int h, int w, int d, const short *__restrict__ points,
    const int *__restrict__ pt_ids, const short *__restrict__ scales,
    int *__restrict__ volume) {
  // Widen before multiplying so the index cannot wrap at 32 bits.
  const size_t idx =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx >= n_pts) {
    return;
  }

  const int pid = pt_ids[idx];
  // Kept as size_t: narrowing idx * 3 to int overflows for large inputs.
  const size_t base = idx * 3;
  const short x = points[base];
  const short y = points[base + 1];
  const short z = points[base + 2];
  const short sx = scales[base];
  const short sy = scales[base + 1];
  const short sz = scales[base + 2];

  if (x >= w || y >= h || z >= d || x < 0 || y < 0 || z < 0) {
    return;
  }

  for (int j = x; j < x + sx && j < w; ++j) {
    for (int k = y; k < y + sy && k < h; ++k) {
      // Offset of voxel (k, j, 0); 64-bit so large volumes do not overflow.
      const int64_t column = (static_cast<int64_t>(k) * w + j) * d;
      for (int l = z; l < z + sz && l < d; ++l) {
        volume[column + l] = pid;
      }
    }
  }
}
51
+
52
/**
 * Builds an int32 volume of shape [h, w, d] in which every voxel covered by
 * a point's box holds that point's ID (0 where no point writes).
 *
 * All inputs must be contiguous CUDA tensors because the kernel indexes them
 * as flat arrays. `points` and `scales` are read as int16 with 3 values per
 * point and `pt_ids` as int32 (data_ptr raises on any other dtype).
 *
 * @param points   point coordinates; size(0) is the number of points
 * @param pt_ids   ID per point
 * @param scales   box extent per point
 * @param h, w, d  volume size
 * @return         the generated volume, on the device of points
 */
torch::Tensor points_to_volume_cuda(const torch::Tensor &points,
                                    const torch::Tensor &pt_ids,
                                    const torch::Tensor &scales, int h, int w,
                                    int d) {
  CHECK_INPUT(points);
  CHECK_INPUT(pt_ids);
  CHECK_INPUT(scales);

  const int64_t n_pts_signed = points.size(0);
  // The kernel reads 3 values per point from points/scales and 1 from pt_ids.
  TORCH_CHECK(points.numel() >= 3 * n_pts_signed,
              "points must hold 3 values per point");
  TORCH_CHECK(scales.numel() >= 3 * n_pts_signed,
              "scales must hold 3 values per point");
  TORCH_CHECK(pt_ids.numel() >= n_pts_signed,
              "pt_ids must hold one ID per point");

  size_t n_pts = static_cast<size_t>(n_pts_signed);
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
  torch::Device device = points.device();

  torch::Tensor volume = torch::zeros(
      {h, w, d}, torch::TensorOptions().dtype(torch::kInt32).device(device));
  if (n_pts == 0) {
    // Launching zero blocks is an invalid configuration; an empty point set
    // simply yields the all-zero volume.
    return volume;
  }

  const unsigned int n_blocks = static_cast<unsigned int>(
      (n_pts + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK);
  points_to_volume_cuda_cuda_kernel<<<n_blocks, THREADS_PER_BLOCK, 0, stream>>>(
      n_pts, h, w, d, points.data_ptr<short>(), pt_ids.data_ptr<int>(),
      scales.data_ptr<short>(), volume.data_ptr<int>());

  // Launch errors are reported on stdout; the volume is returned regardless.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("Error in points_to_volume_cuda_cuda_kernel: %s\n",
           cudaGetErrorString(err));
  }
  return volume;
}
gaussiancity/extensions/voxlib/ray_voxel_intersection.cu ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * @File: ray_voxel_intersection.cu
3
+ * @Author: NVIDIA Corporation
4
+ * @Date: 2021-10-13 00:00:00
5
+ * @Last Modified by: Haozhe Xie
6
+ * @Last Modified at: 2024-03-27 11:02:41
7
+ * @Email: root@haozhexie.com
8
+ */
9
+
10
+ #include <torch/types.h>
11
+
12
+ #include <ATen/ATen.h>
13
+ #include <ATen/AccumulateType.h>
14
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
15
+ #include <ATen/cuda/CUDAContext.h>
16
+
17
+ #include <cuda.h>
18
+ #include <cuda_runtime.h>
19
+ #include <curand.h>
20
+ #include <curand_kernel.h>
21
+ #include <time.h>
22
+
23
+ //#include <pybind11/numpy.h>
24
+ #include <pybind11/pybind11.h>
25
+ #include <pybind11/stl.h>
26
+ #include <vector>
27
+
28
+ #include "voxlib_common.h"
29
+
30
// Parameters of the ray-voxel intersection kernel, passed to it by value.
struct RVIP_Params {
  int voxel_dims[3];    // size of the voxel grid along each axis
  int voxel_strides[3]; // element stride of the voxel tensor along each axis
  int max_samples;      // number of intersections recorded per ray
  int img_dims[2];      // image size as (height, width)
  // Camera parameters
  float cam_ori[3];  // camera origin
  float cam_fwd[3];  // viewing direction (normalized by the host code)
  float cam_side[3]; // cross(cam_fwd, up input), normalized by the host code
  float cam_up[3];   // cross(cam_side, cam_fwd), normalized by the host code
  float cam_c[2];    // principal point as (row, column)
  float cam_f;       // focal length, scales cam_fwd in the ray direction
  // unsigned long seed;
};
44
+
45
/*
    out_voxel_id: torch CUDA int32  [   img_dims[0], img_dims[1], max_samples, 1]
    out_depth:    torch CUDA float  [2, img_dims[0], img_dims[1], max_samples, 1]
    out_raydirs:  torch CUDA float  [   img_dims[0], img_dims[1], 1, 3]

    Image coordinates refer to the center of the pixel. Voxel coordinate
    [0, 0, 0] is at the corner of the corner block (instead of at its center).

    One thread traces one pixel's ray through the voxel grid, stepping one
    voxel boundary at a time along whichever axis is crossed next. For each
    of the max_samples slots it records the ID of the next non-zero voxel,
    the ray parameter t at which the ray enters it (out_depth[0]) and the
    parameter t2 of the following boundary crossing (out_depth[1]). Once the
    ray has left the grid, the remaining slots get ID 0 and NaN depths.
*/
template <int TILE_DIM>
static __global__ void ray_voxel_intersection_perspective_kernel(
    int32_t *__restrict__ out_voxel_id, float *__restrict__ out_depth,
    float *__restrict__ out_raydirs, const int32_t *__restrict__ in_voxel,
    const RVIP_Params p) {

  // img_coords[0] is the row (from the y launch index), [1] the column.
  int img_coords[2];
  img_coords[1] = blockIdx.x * TILE_DIM + threadIdx.x;
  img_coords[0] = blockIdx.y * TILE_DIM + threadIdx.y;
  if (img_coords[0] >= p.img_dims[0] || img_coords[1] >= p.img_dims[1]) {
    return;
  }
  int pix_index = img_coords[0] * p.img_dims[1] + img_coords[1];

  // Calculate ray origin and direction
  float rayori[3], raydir[3];
  rayori[0] = p.cam_ori[0];
  rayori[1] = p.cam_ori[1];
  rayori[2] = p.cam_ori[2];

  // Camera intrinsics
  float ndc_imcoords[2];
  ndc_imcoords[0] = p.cam_c[0] - (float)img_coords[0]; // Flip height
  ndc_imcoords[1] = (float)img_coords[1] - p.cam_c[1];

  raydir[0] = p.cam_up[0] * ndc_imcoords[0] + p.cam_side[0] * ndc_imcoords[1] +
              p.cam_fwd[0] * p.cam_f;
  raydir[1] = p.cam_up[1] * ndc_imcoords[0] + p.cam_side[1] * ndc_imcoords[1] +
              p.cam_fwd[1] * p.cam_f;
  raydir[2] = p.cam_up[2] * ndc_imcoords[0] + p.cam_side[2] * ndc_imcoords[1] +
              p.cam_fwd[2] * p.cam_f;
  normalize<float, 3>(raydir);

  // Save out_raydirs
  out_raydirs[pix_index * 3] = raydir[0];
  out_raydirs[pix_index * 3 + 1] = raydir[1];
  out_raydirs[pix_index * 3 + 2] = raydir[2];

  // axis_t[i]: ray parameter of the next boundary crossing along axis i.
  // axis_int[i]: index of the current voxel along axis i.
  float axis_t[3];
  int axis_int[3];
  // int axis_intbound[3];

  // Current voxel
  axis_int[0] = floorf(rayori[0]);
  axis_int[1] = floorf(rayori[1]);
  axis_int[2] = floorf(rayori[2]);

#pragma unroll
  for (int i = 0; i < 3; i++) {
    if (raydir[i] > 0) {
      // Initial t value
      // Handle boundary case where rayori[i] is a whole number. Always round Up
      // for the next block
      // axis_t[i] = (ceilf(nextafterf(rayori[i], HUGE_VALF)) - rayori[i]) /
      // raydir[i];
      axis_t[i] = ((float)(axis_int[i] + 1) - rayori[i]) / raydir[i];
    } else if (raydir[i] < 0) {
      axis_t[i] = ((float)axis_int[i] - rayori[i]) / raydir[i];
    } else {
      // The ray never crosses a boundary along this axis.
      axis_t[i] = HUGE_VALF;
    }
  }

  // Fused raymarching and sampling
  // `quit` persists across samples: once the ray has left the grid, every
  // remaining slot is written with the defaults below.
  bool quit = false;
  for (int cur_plane = 0; cur_plane < p.max_samples;
       cur_plane++) { // Last cycle is for calculating p2
    float t = nanf("0");
    float t2 = nanf("0");
    int32_t blk_id = 0;
    // Find the next intersection
    // The loop always steps before it tests, so the voxel that contains the
    // ray origin itself is never tested.
    while (!quit) {
      // Find the next smallest t
      float tnow;
      // Hand unroll
      if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) {
        // Update current t
        tnow = axis_t[0];
        // Update t candidates
        if (raydir[0] > 0) {
          axis_int[0] += 1;
          if (axis_int[0] >= p.voxel_dims[0]) {
            quit = true;
          }
          axis_t[0] = ((float)(axis_int[0] + 1) - rayori[0]) / raydir[0];
        } else {
          axis_int[0] -= 1;
          if (axis_int[0] < 0) {
            quit = true;
          }
          axis_t[0] = ((float)axis_int[0] - rayori[0]) / raydir[0];
        }
      } else if (axis_t[1] <= axis_t[2]) {
        tnow = axis_t[1];
        if (raydir[1] > 0) {
          axis_int[1] += 1;
          if (axis_int[1] >= p.voxel_dims[1]) {
            quit = true;
          }
          axis_t[1] = ((float)(axis_int[1] + 1) - rayori[1]) / raydir[1];
        } else {
          axis_int[1] -= 1;
          if (axis_int[1] < 0) {
            quit = true;
          }
          axis_t[1] = ((float)axis_int[1] - rayori[1]) / raydir[1];
        }
      } else {
        tnow = axis_t[2];
        if (raydir[2] > 0) {
          axis_int[2] += 1;
          if (axis_int[2] >= p.voxel_dims[2]) {
            quit = true;
          }
          axis_t[2] = ((float)(axis_int[2] + 1) - rayori[2]) / raydir[2];
        } else {
          axis_int[2] -= 1;
          if (axis_int[2] < 0) {
            quit = true;
          }
          axis_t[2] = ((float)axis_int[2] - rayori[2]) / raydir[2];
        }
      }

      if (quit) {
        break;
      }

      // Skip empty space
      // Could there be deadlock if the ray direction is away from the world?
      // (The ray may start outside the grid; it keeps stepping until it
      // enters the grid or one of the exit tests above sets `quit`.)
      if (axis_int[0] < 0 || axis_int[0] >= p.voxel_dims[0] ||
          axis_int[1] < 0 || axis_int[1] >= p.voxel_dims[1] ||
          axis_int[2] < 0 || axis_int[2] >= p.voxel_dims[2]) {
        continue;
      }

      // Test intersection using voxel grid
      int64_t in_voxel_idx =
          static_cast<int64_t>(axis_int[0]) * p.voxel_strides[0] +
          static_cast<int64_t>(axis_int[1]) * p.voxel_strides[1] +
          static_cast<int64_t>(axis_int[2]) * p.voxel_strides[2];
      blk_id = in_voxel[in_voxel_idx];
      if (blk_id == 0) {
        continue;
      }

      // Now that there is an intersection
      t = tnow;
      // Calculate t2
      /*
      #pragma unroll
      for (int i=0; i<3; i++) {
          if (axis_t[i] <= axis_t[(i+1)%3] && axis_t[i] <= axis_t[(i+2)%3]) {
              t2 = axis_t[i];
              break;
          }
      }
      */
      // Hand unroll
      if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) {
        t2 = axis_t[0];
      } else if (axis_t[1] <= axis_t[2]) {
        t2 = axis_t[1];
      } else {
        t2 = axis_t[2];
      }
      break;
    } // while !quit (ray marching loop)

    // First half of out_depth holds t, second half t2.
    out_depth[pix_index * p.max_samples + cur_plane] = t;
    out_depth[p.img_dims[0] * p.img_dims[1] * p.max_samples +
              pix_index * p.max_samples + cur_plane] = t2;
    out_voxel_id[pix_index * p.max_samples + cur_plane] = blk_id;
  } // cur_plane
}
227
+
228
/*
    Traces one ray per pixel of a perspective camera through a voxel grid.

    out:
        out_voxel_id: torch CUDA int32  [   img_dims[0], img_dims[1], max_samples, 1]
        out_depth:    torch CUDA float  [2, img_dims[0], img_dims[1], max_samples, 1]
        out_raydirs:  torch CUDA float  [   img_dims[0], img_dims[1], 1, 3]
    in:
        in_voxel:    torch CUDA int32   [X, Y, Z]  (e.g. [40, 512, 512]);
                     may be non-contiguous, its strides are honoured
        cam_ori:     torch float        [3]
        cam_dir:     torch float        [3]
        cam_up:      torch float        [3]
        cam_f:       float
        cam_c:       float [2]
        img_dims:    int [2]
        max_samples: int

    Invalid arguments raise through TORCH_CHECK, so they are rejected in
    release builds as well.
*/
std::vector<torch::Tensor> ray_voxel_intersection_perspective_cuda(
    const torch::Tensor &in_voxel, const torch::Tensor &cam_ori,
    const torch::Tensor &cam_dir, const torch::Tensor &cam_up, float cam_f,
    const std::vector<float> &cam_c, const std::vector<int> &img_dims,
    int max_samples) {
  CHECK_CUDA(in_voxel);

  // `assert` is compiled out when NDEBUG is defined, which would let invalid
  // arguments reach the raw pointer reads below.
  TORCH_CHECK(in_voxel.dtype() == torch::kInt32,
              "in_voxel must be an int32 tensor"); // Minecraft compatibility
  TORCH_CHECK(in_voxel.dim() == 3, "in_voxel must be a 3D tensor");
  TORCH_CHECK(cam_ori.dtype() == torch::kFloat32 && cam_ori.numel() == 3,
              "cam_ori must be a float32 tensor with 3 elements");
  TORCH_CHECK(cam_dir.dtype() == torch::kFloat32 && cam_dir.numel() == 3,
              "cam_dir must be a float32 tensor with 3 elements");
  TORCH_CHECK(cam_up.dtype() == torch::kFloat32 && cam_up.numel() == 3,
              "cam_up must be a float32 tensor with 3 elements");
  TORCH_CHECK(cam_c.size() == 2, "cam_c must have 2 elements");
  TORCH_CHECK(img_dims.size() == 2, "img_dims must have 2 elements");

  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
  torch::Device device = in_voxel.device();

  RVIP_Params p;

  // Calculate camera rays
  // The vectors are read through raw pointers on the host, so they need to
  // be host-resident and densely packed.
  const torch::Tensor cam_ori_c = cam_ori.cpu().contiguous();
  const torch::Tensor cam_dir_c = cam_dir.cpu().contiguous();
  const torch::Tensor cam_up_c = cam_up.cpu().contiguous();

  // Get the coordinate frame of camera space in world space
  normalize<float, 3>(p.cam_fwd, cam_dir_c.data_ptr<float>());
  cross<float>(p.cam_side, p.cam_fwd, cam_up_c.data_ptr<float>());
  normalize<float, 3>(p.cam_side);
  cross<float>(p.cam_up, p.cam_side, p.cam_fwd);
  normalize<float, 3>(p.cam_up); // Not absolutely necessary as both vectors are
                                 // normalized. But just in case...

  copyarr<float, 3>(p.cam_ori, cam_ori_c.data_ptr<float>());

  p.cam_f = cam_f;
  p.cam_c[0] = cam_c[0];
  p.cam_c[1] = cam_c[1];
  p.max_samples = max_samples;

  p.voxel_dims[0] = in_voxel.size(0);
  p.voxel_dims[1] = in_voxel.size(1);
  p.voxel_dims[2] = in_voxel.size(2);
  p.voxel_strides[0] = in_voxel.stride(0);
  p.voxel_strides[1] = in_voxel.stride(1);
  p.voxel_strides[2] = in_voxel.stride(2);

  p.img_dims[0] = img_dims[0];
  p.img_dims[1] = img_dims[1];

  // Create output tensors
  // For Minecraft Seg Mask
  torch::Tensor out_voxel_id =
      torch::empty({p.img_dims[0], p.img_dims[1], p.max_samples, 1},
                   torch::TensorOptions().dtype(torch::kInt32).device(device));

  torch::Tensor out_depth;
  // Produce two sets of localcoords, one for entry point, the other one for
  // exit point. They share the same corner_ids.
  out_depth = torch::empty(
      {2, p.img_dims[0], p.img_dims[1], p.max_samples, 1},
      torch::TensorOptions().dtype(torch::kFloat32).device(device));

  torch::Tensor out_raydirs = torch::empty({p.img_dims[0], p.img_dims[1], 1, 3},
                                           torch::TensorOptions()
                                               .dtype(torch::kFloat32)
                                               .device(device)
                                               .requires_grad(false));

  if (p.img_dims[0] == 0 || p.img_dims[1] == 0) {
    // A grid with a zero dimension is an invalid launch configuration; the
    // (empty) outputs are already complete.
    return {out_voxel_id, out_depth, out_raydirs};
  }

  const int TILE_DIM = 8;
  dim3 dimGrid((p.img_dims[1] + TILE_DIM - 1) / TILE_DIM,
               (p.img_dims[0] + TILE_DIM - 1) / TILE_DIM, 1);
  dim3 dimBlock(TILE_DIM, TILE_DIM, 1);

  ray_voxel_intersection_perspective_kernel<TILE_DIM>
      <<<dimGrid, dimBlock, 0, stream>>>(
          out_voxel_id.data_ptr<int32_t>(), out_depth.data_ptr<float>(),
          out_raydirs.data_ptr<float>(), in_voxel.data_ptr<int32_t>(), p);

  // Launch errors are reported on stdout; the outputs are returned regardless.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("Error in ray_voxel_intersection_perspective_kernel: %s\n",
           cudaGetErrorString(err));
  }
  return {out_voxel_id, out_depth, out_raydirs};
}
+ }
gaussiancity/extensions/voxlib/setup.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ # @File: setup.py
4
+ # @Author: NVIDIA Corporation
5
+ # @Date: 2021-10-13 00:00:00
6
+ # @Last Modified by: Haozhe Xie
7
+ # @Last Modified at: 2024-10-13 03:00:47
8
+ # @Email: root@haozhexie.com
9
+
10
+ from setuptools import setup
11
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
12
+
13
# Extra flags handed to the host C++ compiler and to nvcc, respectively.
cxx_args = ["-fopenmp"]
nvcc_args = []

# Translation units compiled into the ``voxlib`` CUDA extension.
_VOXLIB_SOURCES = [
    "bindings.cpp",
    "ray_voxel_intersection.cu",
    "points_to_volume.cu",
    "maps_to_volume.cu",
]

_voxlib_extension = CUDAExtension(
    "voxlib",
    _VOXLIB_SOURCES,
    extra_compile_args={"cxx": cxx_args, "nvcc": nvcc_args},
)

setup(
    name="voxlib_ext",
    version="3.0.0",
    ext_modules=[_voxlib_extension],
    cmdclass={"build_ext": BuildExtension},
)
gaussiancity/extensions/voxlib/voxlib_common.h ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, check out LICENSE.md
5
+ #ifndef VOXLIB_COMMON_H
6
+ #define VOXLIB_COMMON_H
7
+
8
+ #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
9
+ #define CHECK_CONTIGUOUS(x) \
10
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
11
+ #define CHECK_INPUT(x) \
12
+ CHECK_CUDA(x); \
13
+ CHECK_CONTIGUOUS(x)
14
+ #define CHECK_CPU(x) \
15
+ TORCH_CHECK(x.device().is_cpu(), #x " must be a CPU tensor")
16
+
17
+ #include <cuda.h>
18
+ #include <cuda_runtime.h>
19
+ // CUDA vector math functions
20
// Integer division of `a` by `b` with a downward correction.
//
// The built-in `/` truncates toward zero; whenever the truncated quotient
// multiplied back by `b` exceeds `a`, the quotient is lowered by one. For a
// positive `b` this yields floor(a / b), e.g. floor_div(-7, 2) == -4.
// NOTE(review): for b < 0 and a > 0 the correction never fires
// (floor_div(7, -2) == -3), so callers presumably pass a positive divisor.
// `b == 0` is not guarded.
__host__ __device__ __forceinline__ int floor_div(int a, int b) {
  const int quotient = a / b;
  return (quotient * b > a) ? quotient - 1 : quotient;
}
29
+
30
// Cross product of two 3-vectors: r = a x b (host-only).
//
// Components are written in index order (r[0], r[1], r[2]) and `a` / `b` are
// re-read for every component, so `r` should not alias `a` or `b`.
template <typename scalar_t>
__host__ __forceinline__ void cross(scalar_t *r, const scalar_t *a,
                                    const scalar_t *b) {
  for (int i = 0; i < 3; ++i) {
    // Cyclic successors of i: (1, 2), (2, 0), (0, 1).
    const int j = (i + 1) % 3;
    const int k = (i + 2) % 3;
    r[i] = a[j] * b[k] - a[k] * b[j];
  }
}
37
+
38
// Dot product of two 3-element float vectors.
// The terms are accumulated left to right: (a0*b0 + a1*b1) + a2*b2.
__device__ __host__ __forceinline__ float dot(const float *a, const float *b) {
  float acc = a[0] * b[0];
  acc = acc + a[1] * b[1];
  acc = acc + a[2] * b[2];
  return acc;
}
41
+
42
// Copies the first `ndim` elements of `a` into `r`, front to back.
// `ndim` is a compile-time constant so the loop can be fully unrolled.
template <typename scalar_t, int ndim>
__device__ __host__ __forceinline__ void copyarr(scalar_t *r,
                                                 const scalar_t *a) {
#pragma unroll
  for (int idx = 0; idx != ndim; ++idx) {
    *(r + idx) = *(a + idx);
  }
}
50
+
51
// TODO: use rsqrt to speed up
// In-place version: rescales the `ndim`-element vector `a` to unit L2 length.
//
// The length is computed with sqrtf, i.e. in single precision regardless of
// `scalar_t`. A zero-length input is not guarded and divides by zero.
template <typename scalar_t, int ndim>
__device__ __host__ __forceinline__ void normalize(scalar_t *a) {
  scalar_t sq_norm = 0.0f;
#pragma unroll
  for (int d = 0; d < ndim; ++d) {
    sq_norm += a[d] * a[d];
  }
  const scalar_t norm = sqrtf(sq_norm);
#pragma unroll
  for (int d = 0; d < ndim; ++d) {
    a[d] /= norm;
  }
}
66
+
67
// normalize + copy: writes the unit-length version of `a` into `r` and leaves
// `a` untouched.
//
// The length is computed with sqrtf, i.e. in single precision regardless of
// `scalar_t`. A zero-length input is not guarded and divides by zero.
template <typename scalar_t, int ndim>
__device__ __host__ __forceinline__ void normalize(scalar_t *r,
                                                   const scalar_t *a) {
  scalar_t sq_norm = 0.0f;
#pragma unroll
  for (int d = 0; d < ndim; ++d) {
    sq_norm += a[d] * a[d];
  }
  const scalar_t norm = sqrtf(sq_norm);
#pragma unroll
  for (int d = 0; d < ndim; ++d) {
    r[d] = a[d] / norm;
  }
}
82
+
83
+ #endif // VOXLIB_COMMON_H
gaussiancity/generator.py ADDED
@@ -0,0 +1,536 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ # @File: generator.py
4
+ # @Author: Haozhe Xie
5
+ # @Date: 2024-03-09 20:36:52
6
+ # @Last Modified by: Haozhe Xie
7
+ # @Last Modified at: 2024-09-23 20:49:35
8
+ # @Email: root@haozhexie.com
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+ import extensions.grid_encoder
15
+ import gaussiancity.pt_v3
16
+
17
+
18
class Generator(torch.nn.Module):
    """Point-wise generator that maps layout features to Gaussian attributes.

    Pipeline (see ``forward``): optional projection encoder -> concatenation
    with ``rel_xyz`` -> positional encoder -> optional PointTransformerV3 ->
    ``GaussianAttrMLP`` head.

    Args:
        cfg: Config node. Reads ``ENCODER`` ("GLOBAL", "LOCAL" or ``None``),
            ``ENCODER_OUT_DIM``, ``POS_EMD`` ("HASH_GRID" or "SIN_COS") with
            the matching ``HASH_GRID_*`` / ``SIN_COS_FREQ_BENDS`` entries, the
            ``PTV3`` sub-node, ``Z_DIM``, ``MLP_HIDDEN_DIM``,
            ``MLP_N_SHARED_LAYERS``, ``ATTR_FACTORS`` and ``ATTR_N_LAYERS``.
        n_classes: Number of semantic classes (input channels of the
            segmentation branch and width of the one-hot class vectors).
        proj_size: Forwarded to the hash-grid encoder as
            ``desired_resolution``.

    Raises:
        ValueError: If ``cfg.ENCODER`` or ``cfg.POS_EMD`` is not recognised.
    """

    def __init__(self, cfg, n_classes, proj_size):
        super().__init__()
        self.cfg = cfg
        self.n_classes = n_classes

        # The projection encoder supplies ENCODER_OUT_DIM - 3 channels; the
        # remaining three are the relative xyz coordinates added in forward().
        encoder_kind = cfg.ENCODER
        if encoder_kind == "GLOBAL":
            self.proj_encoder = GlobalEncoder(
                n_classes, cfg.GLOBAL_ENCODER_N_BLOCKS, cfg.ENCODER_OUT_DIM - 3
            )
        elif encoder_kind == "LOCAL":
            self.proj_encoder = LocalEncoder(n_classes, cfg.ENCODER_OUT_DIM - 3)
        elif encoder_kind is None:
            self.proj_encoder = None
            assert cfg.ENCODER_OUT_DIM == 3
        else:
            raise ValueError("Unknown encoder: %s" % cfg.ENCODER)

        pos_kind = cfg.POS_EMD
        if pos_kind == "HASH_GRID":
            feat_dim = cfg.HASH_GRID_N_LEVELS * cfg.HASH_GRID_LEVEL_DIM
            self.pos_encoder = extensions.grid_encoder.GridEncoder(
                in_channels=cfg.ENCODER_OUT_DIM,
                desired_resolution=proj_size,
                n_levels=cfg.HASH_GRID_N_LEVELS,
                lvl_channels=cfg.HASH_GRID_LEVEL_DIM,
            )
        elif pos_kind == "SIN_COS":
            # One sine and one cosine per input channel and frequency band.
            feat_dim = 2 * cfg.ENCODER_OUT_DIM * cfg.SIN_COS_FREQ_BENDS
            self.pos_encoder = SinCosEncoder(cfg.SIN_COS_FREQ_BENDS)
        else:
            raise ValueError("Unknown positional encoder: %s" % cfg.POS_EMD)

        ptv3 = cfg.PTV3
        if ptv3.ENABLED:
            self.pt_net = gaussiancity.pt_v3.PointTransformerV3(
                in_channels=feat_dim,
                order=ptv3.ORDER,
                stride=ptv3.STRIDE,
                enc_depths=ptv3.ENC_DEPTHS,
                enc_channels=ptv3.ENC_CHANNELS,
                enc_num_head=ptv3.ENC_N_HEAD,
                enc_patch_size=ptv3.ENC_PATCH_SIZE,
                dec_depths=ptv3.DEC_DEPTHS,
                dec_channels=ptv3.DEC_CHANNELS,
                dec_num_head=ptv3.DEC_N_HEAD,
                dec_patch_size=ptv3.DEC_PATCH_SIZE,
                enable_flash=ptv3.ENABLE_FLASH_ATTN,
            )
            # The backbone output is concatenated to the positional features.
            feat_dim += ptv3.DEC_CHANNELS[0]
        else:
            self.pt_net = None

        self.ga_mlp = GaussianAttrMLP(
            n_classes,
            feat_dim,
            cfg.Z_DIM,
            cfg.MLP_HIDDEN_DIM,
            cfg.MLP_N_SHARED_LAYERS,
            cfg.ATTR_FACTORS,
            cfg.ATTR_N_LAYERS,
        )

    def forward(self, proj_uv, rel_xyz, batch_idx, onehots, z, proj_hf, proj_seg):
        """Predict Gaussian attributes for a batch of points.

        Args:
            proj_uv: Sampling grid for ``F.grid_sample`` (LOCAL encoder); its
                ``size(1)`` is used as the number of points.
            rel_xyz: Relative point coordinates, concatenated as the last
                three feature channels.
            batch_idx: Forwarded unchanged to the point transformer.
            onehots: One-hot class vectors forwarded to the attribute MLP.
            z: Style codes forwarded to the attribute MLP (may be ``None``).
            proj_hf: Projected height field fed to the projection encoder.
            proj_seg: Projected segmentation fed to the projection encoder.

        Returns:
            The output of ``GaussianAttrMLP.forward``.
        """
        # Ref: https://github.com/hzxie/CityDreamer/blob/master/models/gancraft.py#L381
        n_batch, n_pts = rel_xyz.size(0), rel_xyz.size(1)
        device = proj_uv.device

        if self.cfg.ENCODER == "GLOBAL":
            # One feature vector per sample, shared by all of its points.
            global_feat = self.proj_encoder(proj_hf, proj_seg)
            pt_feat = global_feat.unsqueeze(dim=1).repeat(1, proj_uv.size(1), 1)
        elif self.cfg.ENCODER == "LOCAL":
            # Sample the dense feature map at every point's uv location.
            feat_map = self.proj_encoder(proj_hf, proj_seg)
            sampled = F.grid_sample(
                feat_map, proj_uv.unsqueeze(dim=1), align_corners=True
            )
            pt_feat = sampled.squeeze(dim=2).permute(0, 2, 1)
        elif self.cfg.ENCODER is None:
            pt_feat = torch.empty(n_batch, n_pts, 0, device=device)

        # [B, n_pts, cfg.ENCODER_OUT_DIM - 3] -> [B, n_pts, cfg.ENCODER_OUT_DIM]
        pt_feat = torch.cat([pt_feat, rel_xyz], dim=2)
        pos_feat = self.pos_encoder(pt_feat)

        if self.pt_net is None:
            # Zero-width placeholder keeps the concatenation below uniform.
            ctx_feat = torch.empty(n_batch, n_pts, 0, device=device)
        else:
            ctx_feat = self.pt_net(batch_idx, pos_feat, rel_xyz)

        return self.ga_mlp(torch.cat([pos_feat, ctx_feat], dim=-1), onehots, z)
109
+
110
+
111
class GlobalEncoder(torch.nn.Module):
    """Encode the projected height field and segmentation into one vector.

    Args:
        n_classes: Channel count of the segmentation input.
        n_blocks: ``n_blocks - 1`` ``SRTConvBlock`` stages are stacked after
            the two stem convolutions; each stage doubles the channel width.
        out_channels: Length of the returned feature vector.
    """

    def __init__(self, n_classes, n_blocks, out_channels):
        super().__init__()
        stem_cfg = dict(kernel_size=3, stride=2, padding=1)
        self.hf_conv = torch.nn.Conv2d(1, 8, **stem_cfg)
        self.seg_conv = torch.nn.Conv2d(n_classes, 8, **stem_cfg)

        # 8 + 8 channels come out of the concatenated stems.
        width = 16
        stages = []
        for _ in range(1, n_blocks):
            stages.append(SRTConvBlock(in_channels=width, out_channels=None))
            width *= 2

        self.conv_blocks = torch.nn.Sequential(*stages)
        self.fc1 = torch.nn.Linear(width, 16)
        self.fc2 = torch.nn.Linear(16, out_channels)
        self.act = torch.nn.LeakyReLU(0.2)

    def forward(self, proj_hf, proj_seg):
        """Return a ``tanh``-squashed feature of shape ``[N, out_channels]``."""
        feat = torch.cat(
            [self.act(self.hf_conv(proj_hf)), self.act(self.seg_conv(proj_seg))],
            dim=1,
        )
        for stage in self.conv_blocks:
            feat = self.act(stage(feat))

        # Average over all spatial positions: [N, C, H, W] -> [N, C].
        feat = feat.permute(0, 2, 3, 1)
        pooled = feat.reshape(feat.shape[0], -1, feat.shape[-1]).mean(dim=1)
        hidden = self.act(self.fc1(pooled))
        return torch.tanh(self.fc2(hidden))
147
+
148
+
149
class LocalEncoder(torch.nn.Module):
    """Encode the projected maps into a dense per-pixel feature map.

    Two stride-2 stems (height field and segmentation) are concatenated,
    refined by three ``ResConvBlock`` stages with one 2x average pooling, and
    upsampled again by two stride-2 transposed convolutions.

    Args:
        n_classes: Channel count of the segmentation input.
        out_channels: Channel count of the returned feature map.
    """

    def __init__(self, n_classes, out_channels):
        super().__init__()
        stem_cfg = dict(kernel_size=7, stride=2, padding=3)
        upsample_cfg = dict(kernel_size=4, stride=2, padding=1)

        self.hf_conv = torch.nn.Conv2d(1, 32, **stem_cfg)
        self.seg_conv = torch.nn.Conv2d(n_classes, 32, **stem_cfg)
        self.bn1 = torch.nn.GroupNorm(32, 64)
        self.conv2 = ResConvBlock(64, 128)
        self.conv3 = ResConvBlock(128, 256)
        self.conv4 = ResConvBlock(256, 512)
        self.dconv5 = torch.nn.ConvTranspose2d(512, 128, **upsample_cfg)
        self.dconv6 = torch.nn.ConvTranspose2d(128, 32, **upsample_cfg)
        self.dconv7 = torch.nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, proj_hf, proj_seg):
        """Return a ``tanh``-squashed map with ``out_channels`` channels."""
        stems = torch.cat([self.hf_conv(proj_hf), self.seg_conv(proj_seg)], dim=1)
        # [N, 64, H/2, W/2]
        feat = F.relu(self.bn1(stems), inplace=True)
        # [N, 128, H/4, W/4]
        feat = F.avg_pool2d(self.conv2(feat), 2, stride=2)
        # conv3 / conv4 keep the resolution (256 then 512 channels); dconv5 and
        # dconv6 each double it (128 then 32 channels); dconv7 is a 1x1
        # projection to out_channels.
        for layer in (self.conv3, self.conv4, self.dconv5, self.dconv6, self.dconv7):
            feat = layer(feat)

        return torch.tanh(feat)
186
+
187
+
188
class SRTConvBlock(torch.nn.Module):
    """Two bias-free 3x3 convolutions, each followed by ReLU.

    The first convolution keeps the resolution (stride 1); the second one
    halves it (stride 2).

    Args:
        in_channels: Input channel count.
        hidden_channels: Width after the first convolution; defaults to
            ``in_channels``.
        out_channels: Width after the second convolution; defaults to
            ``2 * hidden_channels``.
    """

    def __init__(self, in_channels, hidden_channels=None, out_channels=None):
        super().__init__()
        hidden_channels = in_channels if hidden_channels is None else hidden_channels
        out_channels = 2 * hidden_channels if out_channels is None else out_channels

        shared_cfg = dict(kernel_size=3, padding=1, bias=False)
        self.layers = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, hidden_channels, stride=1, **shared_cfg),
            torch.nn.ReLU(),
            torch.nn.Conv2d(hidden_channels, out_channels, stride=2, **shared_cfg),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        """Apply the block to a ``[N, C, H, W]`` tensor."""
        return self.layers(x)
219
+
220
+
221
class ResConvBlock(torch.nn.Module):
    """Pre-activation residual block with a multi-scale concatenated output.

    Three chained 3x3 convolutions produce ``out_channels // 2``,
    ``out_channels // 4`` and ``out_channels // 4`` channels; their outputs
    are concatenated and added to a shortcut. When the channel counts differ
    the shortcut is GroupNorm -> ReLU -> 1x1 convolution, otherwise identity.

    Args:
        in_channels: Input channel count (GroupNorm uses 32 groups).
        out_channels: Output channel count.
        bias: Whether the three 3x3 convolutions carry a bias term.
    """

    def __init__(self, in_channels, out_channels, bias=False):
        super().__init__()
        half, quarter = out_channels // 2, out_channels // 4

        def conv3x3(n_in, n_out):
            return torch.nn.Conv2d(
                n_in, n_out, kernel_size=3, stride=1, padding=1, bias=bias
            )

        self.conv1 = conv3x3(in_channels, half)
        self.conv2 = conv3x3(half, quarter)
        self.conv3 = conv3x3(quarter, quarter)
        self.bn1 = torch.nn.GroupNorm(32, in_channels)
        self.bn2 = torch.nn.GroupNorm(32, half)
        self.bn3 = torch.nn.GroupNorm(32, quarter)
        self.bn4 = torch.nn.GroupNorm(32, in_channels)

        # bn4 is registered both directly and as the first shortcut layer.
        self.downsample = (
            torch.nn.Sequential(
                self.bn4,
                torch.nn.ReLU(True),
                torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, bias=False
                ),
            )
            if in_channels != out_channels
            else None
        )

    def forward(self, x):
        """Return a tensor with ``out_channels`` channels and the input's H, W."""
        branch1 = self.conv1(F.relu(self.bn1(x), True))
        branch2 = self.conv2(F.relu(self.bn2(branch1), True))
        branch3 = self.conv3(F.relu(self.bn3(branch2), True))
        merged = torch.cat((branch1, branch2, branch3), dim=1)

        shortcut = x if self.downsample is None else self.downsample(x)
        merged += shortcut
        return merged
289
+
290
+
291
class SinCosEncoder(torch.nn.Module):
    """Sinusoidal positional encoding with power-of-two frequencies.

    Args:
        n_freq_bands: Number of frequency bands; band ``k`` multiplies the
            input by ``2 ** k``.
    """

    def __init__(self, n_freq_bands=8):
        super().__init__()
        exponents = torch.linspace(0, n_freq_bands - 1, steps=n_freq_bands)
        # Kept as a plain tensor attribute (not a registered buffer).
        self.freq_bands = 2.0**exponents

    def forward(self, features):
        """Encode ``features`` along the last dimension.

        Returns:
            All sine terms (one chunk per band) followed by all cosine terms,
            i.e. the last dimension grows by a factor ``2 * n_freq_bands``.
        """
        scaled = [features * band for band in self.freq_bands]
        waves = [torch.sin(s) for s in scaled] + [torch.cos(s) for s in scaled]
        return torch.cat(waves, dim=-1)
308
+
309
+
310
class GaussianAttrMLP(torch.nn.Module):
    r"""MLP with affine modulation that predicts per-point Gaussian attributes.

    A shared trunk (``fc_1`` plus ``fc_2`` .. ``fc_<n_shared_layers>``) is
    followed by one branch per attribute listed in ``factors``. Each branch
    has ``n_layers[k]`` hidden layers and a final ``fc_out_<k>`` projection
    (1 channel for ``"opacity"``, 3 channels otherwise).

    Args:
        n_classes (int): Width of the one-hot class vectors.
        in_dim (int): Width of the incoming point features.
        z_dim (int or None): Style code width. When ``None`` the hidden
            layers are plain ``torch.nn.Linear`` layers instead of
            ``ModLinear`` and no style code is used.
        hidden_dim (int): Width of every hidden layer.
        n_shared_layers (int): Number of layers in the shared trunk,
            counting ``fc_1``.
        factors (dict or None): Attribute name -> output scaling factor.
            Allowed keys: "xyz", "rgb", "scale", "opacity".
        n_layers (dict or None): Attribute name -> number of hidden layers in
            that attribute's branch.
    """

    def __init__(
        self,
        n_classes,
        in_dim,
        z_dim,
        hidden_dim,
        n_shared_layers,
        factors=None,
        n_layers=None,
    ):
        super(GaussianAttrMLP, self).__init__()
        # ``None`` defaults avoid sharing one mutable dict between instances.
        factors = {} if factors is None else factors
        n_layers = {} if n_layers is None else n_layers

        self.factors = factors
        self.n_layers = n_layers
        self.n_shared_layers = n_shared_layers
        self.act = torch.nn.LeakyReLU(negative_slope=0.2)
        self.fc_m_a = torch.nn.Linear(
            n_classes,
            hidden_dim,
            bias=False,
        )
        self.fc_1 = torch.nn.Linear(
            in_dim,
            hidden_dim,
        )
        for i in range(2, n_shared_layers + 1):
            setattr(self, "fc_%d" % i, self._make_hidden_layer(hidden_dim, z_dim))

        for k in factors.keys():
            assert k in ["xyz", "rgb", "scale", "opacity"], "Unknown key: %s" % k
            for i in range(n_layers[k]):
                setattr(
                    self,
                    "fc_%d_%s_%d" % (n_shared_layers + 1, k, i),
                    self._make_hidden_layer(hidden_dim, z_dim),
                )
            setattr(
                self,
                "fc_out_%s" % k,
                torch.nn.Linear(
                    hidden_dim,
                    1 if k == "opacity" else 3,
                ),
            )

    @staticmethod
    def _make_hidden_layer(hidden_dim, z_dim):
        """Build one hidden layer: style-modulated if ``z_dim`` is given."""
        if z_dim is None:
            return torch.nn.Linear(hidden_dim, hidden_dim)

        return ModLinear(
            hidden_dim,
            hidden_dim,
            z_dim,
            bias=False,
            mod_bias=True,
            output_mode=True,
        )

    def forward(self, pt_feat, onehots, zs):
        """Predict attributes for every point.

        Args:
            pt_feat: Point features of shape ``[B, N, in_dim]``.
            onehots: One-hot class vectors, projected by ``fc_m_a`` and added
                to the first-layer features.
            zs: ``None`` (no style modulation) or a dict whose values are
                dicts with a style code ``"z"`` and an index ``"idx"`` that
                selects the points the code applies to.

        Returns:
            dict: Attribute name -> tensor. With ``zs`` given, points not
            covered by any ``idx`` keep the value zero.
        """
        b, n, _ = pt_feat.size()

        f = self.fc_1(pt_feat)
        f = f + self.fc_m_a(onehots)
        f = self.act(f)
        if zs is None:
            return self._instance_forward(f)

        output = {
            k: torch.zeros(b, n, 1 if k == "opacity" else 3, device=pt_feat.device)
            for k in self.factors.keys()
        }
        for style in zs.values():
            idx = style["idx"]
            inst_output = self._instance_forward(f[idx].unsqueeze(dim=0), style["z"])
            for k, attr in inst_output.items():
                output[k][idx] = attr

        return output

    def _instance_forward(self, f, z=None):
        """Run the trunk and every attribute branch on features ``f``.

        Args:
            f: Activated first-layer features.
            z: Style code for the modulated layers, or ``None`` when the
                layers are plain linear layers.

        Returns:
            dict: Attribute name -> rescaled prediction.
        """
        for i in range(2, self.n_shared_layers + 1):
            fc = getattr(self, "fc_%d" % i)
            f = self.act(fc(f, z) if z is not None else fc(f))

        output = {}
        for k in self.factors.keys():
            _f = f
            for i in range(self.n_layers[k]):
                _fc = getattr(self, "fc_%d_%s_%d" % (self.n_shared_layers + 1, k, i))
                # Chain on the branch feature ``_f`` in both modes; feeding
                # the trunk feature ``f`` here would discard earlier layers.
                _f = self.act(_fc(_f, z) if z is not None else _fc(_f))

            fc_out = getattr(self, "fc_out_%s" % k)
            output[k] = fc_out(_f)

        if "xyz" in self.factors:
            # Offsets within +/- factor / 2.
            output["xyz"] = (torch.sigmoid(output["xyz"]) - 0.5) * self.factors["xyz"]
        if "rgb" in self.factors:
            output["rgb"] = (torch.sigmoid(output["rgb"]) - 0.5) * self.factors["rgb"]
        if "scale" in self.factors:
            # Multiplier within [1 - factor, 1 + factor].
            output["scale"] = 1 + output["scale"].clamp(-1, 1) * self.factors["scale"]
        if "opacity" in self.factors:
            # Values within (1 - factor, 1).
            output["opacity"] = torch.sigmoid(output["opacity"]) * self.factors[
                "opacity"
            ] + (1 - self.factors["opacity"])

        return output
431
+
432
+
433
class ModLinear(torch.nn.Module):
    r"""Linear layer with affine modulation (Based on StyleGAN2 mod demod).
    Equivalent to affine modulation following linear, but faster when the same modulation parameters are shared across
    multiple inputs.
    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        style_features (int): Number of style features.
        bias (bool): Apply additive bias before the activation function?
        mod_bias (bool): Whether to modulate bias.
        output_mode (bool): If True, modulate output instead of input.
        weight_gain (float): Initialization gain
        bias_init (float): Initial value of the additive bias.
    """

    def __init__(
        self,
        in_features,
        out_features,
        style_features,
        bias=True,
        mod_bias=True,
        output_mode=False,
        weight_gain=1,
        bias_init=0,
    ):
        super().__init__()
        init_scale = weight_gain / np.sqrt(in_features)
        self.weight = torch.nn.Parameter(
            torch.randn([out_features, in_features]) * init_scale
        )
        if bias:
            self.bias = torch.nn.Parameter(
                torch.full([out_features], np.float32(bias_init))
            )
        else:
            self.bias = None

        # Style -> multiplicative modulation of the weight's input axis.
        self.weight_alpha = torch.nn.Parameter(
            torch.randn([in_features, style_features]) / np.sqrt(style_features)
        )
        self.bias_alpha = torch.nn.Parameter(
            torch.full([in_features], 1, dtype=torch.float)
        )  # init to 1

        self.mod_bias = mod_bias
        self.output_mode = output_mode
        if mod_bias:
            # Style -> additive modulation, applied to the output in
            # output mode and to the input otherwise.
            beta_dims = out_features if output_mode else in_features
            self.weight_beta = torch.nn.Parameter(
                torch.randn([beta_dims, style_features]) / np.sqrt(style_features)
            )
            self.bias_beta = torch.nn.Parameter(
                torch.full([beta_dims], 0, dtype=torch.float)
            )
        else:
            self.weight_beta = None
            self.bias_beta = None

    @staticmethod
    def _linear_f(x, w, b):
        """Apply ``x @ w.T (+ b)`` over the last axis, keeping leading axes."""
        w = w.to(x.dtype)
        lead_shape = x.shape[:-1]
        flat = x.reshape(-1, x.shape[-1])
        if b is None:
            flat = flat.matmul(w.t())
        else:
            flat = torch.addmm(b.to(flat.dtype).unsqueeze(0), flat, w.t())
        return flat.reshape(*lead_shape, -1)

    # x: B, ... , Cin
    # z: B, ... , Cz
    def forward(self, x, z):
        """Apply the style-modulated linear map.

        Args:
            x: Input of shape ``[B, ..., in_features]``.
            z: Style code of shape ``[B, ..., style_features]``.

        Returns:
            Tensor of shape ``[B, ..., out_features]``.
        """
        lead_shape = x.shape[:-1]
        x = x.reshape(x.shape[0], -1, x.shape[-1])
        z = z.reshape(z.shape[0], -1, z.shape[-1])

        # Per-sample scaling of the weight along its input axis.
        alpha = self._linear_f(z, self.weight_alpha, self.bias_alpha)
        w = self.weight.to(x.dtype).unsqueeze(0) * alpha  # [O I] -> [B O I]

        beta = None
        if self.mod_bias:
            beta = self._linear_f(z, self.weight_beta, self.bias_beta)
            if not self.output_mode:
                x = x + beta

        b = None if self.bias is None else self.bias.to(x.dtype)[None, None, :]
        if self.mod_bias and self.output_mode:
            b = beta if b is None else b + beta

        # [B ? I] @ [B I O] = [B ? O]
        w_t = w.transpose(1, 2)
        x = x.bmm(w_t) if b is None else torch.baddbmm(b, x, w_t)

        return x.reshape(*lead_shape, x.shape[-1])
gaussiancity/inference.py ADDED
@@ -0,0 +1,582 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ # @File: inference.py
4
+ # @Author: Haozhe Xie
5
+ # @Date: 2024-03-02 16:30:00
6
+ # @Last Modified by: Haozhe Xie
7
+ # @Last Modified at: 2024-10-13 15:17:20
8
+ # @Email: root@haozhexie.com
9
+
10
+ import cv2
11
+ import math
12
+ import numpy as np
13
+ import scipy.spatial.transform
14
+ import torch
15
+
16
+ from tqdm import tqdm
17
+
18
# Semantic class name -> class ID. IDs follow the listing order (0..7).
CLASSES = {
    name: class_id
    for class_id, name in enumerate(
        (
            "NULL",
            "ROAD",
            "BLDG_FACADE",
            "GREEN_LANDS",
            "CONSTRUCTION",
            "COAST_ZONES",
            "ZONE",
            "BLDG_ROOF",
        )
    )
}
# Per-class point stride / scale (see get_point_map and _get_point_scales).
SCALES = {
    "ROAD": 2,
    "BLDG_FACADE": 1,
    "BLDG_ROOF": 1,
    "GREEN_LANDS": 2,
    "CONSTRUCTION": 1,
    "COAST_ZONES": 4,
    "ZONE": 2,
}
CONSTANTS = {
    # Row-major 3x3 camera intrinsics (reshaped to (3, 3) in generate_city).
    "CAM_K": [1528.1469407006614, 0, 480, 0, 1528.1469407006614, 270, 0, 0, 1],
    "SENSOR_SIZE": [960, 540],
    # Instance IDs at or above the first entry are treated as buildings.
    "BLDG_INST_RANGE": [100, 16384],
    "PROJECTION_SIZE": 2048,
    "POINT_SCALE_FACTOR": 0.5,
    # Classes whose z-scale is forced to 1 in _get_point_scales.
    "SPECIAL_Z_SCALE_CLASSES": [
        CLASSES[name] for name in ("ROAD", "COAST_ZONES", "ZONE")
    ],
}
49
+
50
+
51
def get_instance_seg_map(seg_map):
    """Convert a semantic segmentation map into an instance map.

    Building pixels (facades, plus constructions which are merged into
    facades) are split into 4-connected components and each component gets
    its own even instance ID; every other pixel keeps its class ID.

    Args:
        seg_map (np.ndarray): 2-D integer map of class IDs. NOTE: it is
            modified in place (construction pixels become facade pixels and
            facade pixels are then zeroed).

    Returns:
        np.ndarray: ``int32`` map of the same shape.
    """
    # Mapping constructions to buildings
    seg_map[seg_map == CLASSES["CONSTRUCTION"]] = CLASSES["BLDG_FACADE"]
    # Use connected components to get building instances
    _, labels, _, _ = cv2.connectedComponentsWithStats(
        (seg_map == CLASSES["BLDG_FACADE"]).astype(np.uint8), connectivity=4
    )
    # Remove non-building instance masks
    labels[seg_map != CLASSES["BLDG_FACADE"]] = 0
    # Building instance mask
    building_mask = labels != 0
    # Make building instance IDs even numbers: component k (k >= 1) becomes
    # (k + BLDG_INST_RANGE[0]) * 2, so the smallest building ID is 202.
    # NOTE(review): _get_local_layout registers the roof under facade ID + 1.
    labels = (labels + CONSTANTS["BLDG_INST_RANGE"][0]) * 2

    seg_map[seg_map == CLASSES["BLDG_FACADE"]] = 0
    seg_map = seg_map * (1 - building_mask) + labels * building_mask
    # The result is cast to int32 below; make sure no ID overflows it.
    assert np.max(labels) < 2147483648
    return seg_map.astype(np.int32)
71
+
72
+
73
def get_point_map(seg_map):
    """Mark the pixels of ``seg_map`` at which points are to be placed.

    Every class other than ``NULL`` is sampled on a regular lattice whose
    stride is ``SCALES[class]``; a lattice node counts only where the pixel
    actually belongs to that class.

    Args:
        seg_map (np.ndarray): 2-D map of class IDs (all present in CLASSES).

    Returns:
        np.ndarray: Boolean map with the shape of ``seg_map``.
    """
    class_names = {class_id: name for name, class_id in CLASSES.items()}
    pts_map = np.zeros(seg_map.shape, dtype=bool)
    for class_id in np.unique(seg_map):
        name = class_names[class_id]
        if name != "NULL":
            lattice = _get_point_map(seg_map.shape, SCALES[name])
            pts_map |= lattice & (seg_map == class_id)

    return pts_map
87
+
88
+
89
def _get_point_map(map_size, stride):
    """Return a boolean lattice that is True at every ``stride``-th row/column.

    Args:
        map_size: Shape of the map; the first two entries are rows and columns.
        stride (int): Spacing between lattice nodes, starting at index 0.

    Returns:
        np.ndarray: Boolean array of shape ``map_size``.
    """
    lattice = np.zeros(map_size, dtype=bool)
    rows = np.arange(0, map_size[0], stride)
    cols = np.arange(0, map_size[1], stride)
    # Open mesh: selects every (row, col) combination.
    lattice[np.ix_(rows, cols)] = True
    return lattice
96
+
97
+
98
def get_centers(ins_map, td_hf):
    """Compute a bounding descriptor for every instance in ``ins_map``.

    Building instances (ID >= ``BLDG_INST_RANGE[0]``) use the bounding box of
    their external contours and the maximum height under their mask plus one.
    All other instances use the box ``[0, PROJECTION_SIZE]`` on both axes and
    the maximum of the whole height field.

    Args:
        ins_map (np.ndarray): 2-D instance ID map.
        td_hf (np.ndarray): Height field aligned with ``ins_map``.

    Returns:
        dict: Instance ID -> ``float32`` array
        ``[center_x, center_y, size_x, size_y, max_z]``.
    """
    first_bldg_id = CONSTANTS["BLDG_INST_RANGE"][0]
    full_extent = CONSTANTS["PROJECTION_SIZE"]

    centers = {}
    for inst_id in tqdm(np.unique(ins_map), desc="Calculating centers ..."):
        if inst_id < first_bldg_id:
            x_lo, x_hi = 0, full_extent
            y_lo, y_hi = 0, full_extent
            z_hi = np.max(td_hf)
        else:
            inst_mask = ins_map == inst_id
            contours, _ = cv2.findContours(
                inst_mask.astype(np.uint8),
                cv2.RETR_EXTERNAL,
                cv2.CHAIN_APPROX_SIMPLE,
            )
            # Contour points are (x, y) pairs.
            outline = np.vstack(contours).reshape(-1, 2)
            xs, ys = outline[:, 0], outline[:, 1]
            x_lo, x_hi = np.min(xs), np.max(xs)
            y_lo, y_hi = np.min(ys), np.max(ys)
            z_hi = np.max(td_hf[inst_mask]) + 1

        centers[inst_id] = np.array(
            [
                (x_lo + x_hi) / 2,
                (y_lo + y_hi) / 2,
                (x_hi - x_lo),
                (y_hi - y_lo),
                z_hi,
            ],
            dtype=np.float32,
        )

    return centers
130
+
131
+
132
def generate_city(
    fgm, bgm, city_layout, cx, cy, radius, altitude, azimuth, style_lut=None
):
    """Render one view of the city from an orbiting camera.

    Args:
        fgm: Model passed on under the key ``"BLDG"``.
        bgm: Model passed on under the key ``"REST"``.
        city_layout (dict): Full-city maps (see ``_get_local_layout``).
        cx, cy: Centre of the local patch in city pixel coordinates.
        radius, altitude, azimuth: Orbit camera parameters (see
            ``_get_orbit_camera_pose``).
        style_lut (dict, optional): Per-instance style codes; generated with
            ``_get_style_lut`` when omitted.

    Returns:
        Whatever ``_render`` returns for the predicted Gaussian attributes.
    """
    import gaussiancity.extensions.diff_gaussian_rasterization as dgr

    device = torch.device("cuda")
    bldg_range = CONSTANTS["BLDG_INST_RANGE"]
    half_proj_size = CONSTANTS["PROJECTION_SIZE"] // 2
    models = {"BLDG": fgm, "REST": bgm}

    intrinsics = np.array(CONSTANTS["CAM_K"], dtype=np.float32).reshape((3, 3))
    rasterizer = dgr.GaussianRasterizerWrapper(
        intrinsics,
        CONSTANTS["SENSOR_SIZE"],
        flip_lr=True,
        flip_ud=False,
        device=device,
    )
    layout = _get_local_layout(
        city_layout, cx, cy, half_proj_size, bldg_range, device
    )

    # Per-point attributes: class, one-hot class and 3-D scale.
    bev_pts = _get_bev_points(layout, SCALES, CLASSES)
    pt_classes = _instances_to_classes(bev_pts[:, [3]], bldg_range, CLASSES)
    pt_onehots = _get_onehot_seg(pt_classes, len(CLASSES))
    pt_scales = _get_point_scales(
        pt_classes, SCALES, CLASSES, CONSTANTS["SPECIAL_Z_SCALE_CLASSES"]
    )
    # [N, XYZ + Inst + Scale3D + N_CLASSES]
    bev_pts = torch.cat([bev_pts, pt_scales, pt_onehots], dim=1)

    if style_lut is None:
        style_lut = _get_style_lut(
            layout["CTR"],
            models,
            {"BLDG": bldg_range, "REST": [0, bldg_range[0]]},
            device,
        )

    cam_look_at, cam_pose = _get_orbit_camera_pose(
        radius, altitude, azimuth, half_proj_size, device
    )
    visible_idx = _get_visible_points(
        bev_pts[:, :3],
        pt_scales,
        CONSTANTS["CAM_K"],
        CONSTANTS["SENSOR_SIZE"],
        cam_pose[:3],
        cam_look_at,
    )
    gs_attrs = _get_gs_attrs(
        bev_pts[visible_idx],
        layout["TD_HF"].float(),
        layout["SEG"].float(),
        style_lut,
        layout["CTR"],
        models,
        CONSTANTS["POINT_SCALE_FACTOR"],
        bldg_range,
    )
    return _render(gs_attrs, rasterizer, cam_pose)
200
+
201
+
202
def _get_local_layout(city_layout, cx, cy, half_proj_size, bldg_inst_range, device):
    """Crop the city layout around ``(cx, cy)`` and move it to the GPU.

    Args:
        city_layout (dict): Holds the 2-D maps "TD_HF", "BU_HF", "SEG", "INS"
            and "PTS" plus "CTR", a mapping instance ID -> center array.
        cx, cy: Centre of the crop in city pixel coordinates.
        half_proj_size: Half of the crop's side length.
        bldg_inst_range: IDs >= ``bldg_inst_range[0]`` are buildings.
        device: CUDA device for the returned tensors.

    Returns:
        dict: The cropped maps as ``[1, 1, H, W]`` tensors ("SEG" is
        converted to one-hot) and "CTR" with per-instance center tensors.
    """
    x_min, x_max = cx - half_proj_size, cx + half_proj_size
    y_min, y_max = cy - half_proj_size, cy + half_proj_size

    map_keys = ("TD_HF", "BU_HF", "SEG", "INS", "PTS")
    _layout = {}
    for name, full_map in city_layout.items():
        if name in map_keys:
            crop = full_map[None, None, y_min:y_max, x_min:x_max]
            _layout[name] = torch.from_numpy(crop).cuda(device)
    _layout["SEG"] = _get_onehot_seg(_layout["SEG"], len(CLASSES))

    _centers = {}
    for inst in torch.unique(_layout["INS"]).tolist():
        center = torch.from_numpy(city_layout["CTR"][inst]).cuda(device)
        if inst >= bldg_inst_range[0]:
            # Shift the building center into crop-local coordinates.
            center[0] -= x_min
            center[1] -= y_min
            _centers[inst] = center
            _centers[inst + 1] = center  # Fix the centers for BLDG_ROOF
        else:
            # Non-building instances get the crop origin as x/y entries.
            center[0] = x_min
            center[1] = y_min
            _centers[inst] = center

    _layout["CTR"] = _centers
    return _layout
229
+
230
+
231
def _get_onehot_seg(seg_map, n_classes):
    """One-hot encode a class-ID tensor along dimension 1.

    Args:
        seg_map (torch.Tensor): Class IDs shaped ``[N, 1, H, W]`` or
            ``[N, 1]`` (dimension 1 is expected to be a singleton).
        n_classes (int): Number of classes / output channels.

    Returns:
        torch.Tensor: Boolean tensor shaped ``[N, n_classes, ...]``.
    """
    n_samples, trailing = seg_map.shape[0], seg_map.shape[2:]
    one_hot = torch.zeros(
        (n_samples, n_classes, *trailing), device=seg_map.device, dtype=torch.bool
    )
    for class_id in range(n_classes):
        one_hot[:, class_id : class_id + 1] = seg_map == class_id

    return one_hot
242
+
243
+
244
def _get_style_lut(centers, models, inst_ranges, device, z_dim=256):
    """Build the instance ID -> style code lookup table.

    Every instance first receives a uniform random code of shape
    ``[1, z_dim]``. Then, per model (``None`` models are skipped):

    * if the model has no style input (``cfg.Z_DIM is None``), all IDs in
      that model's instance range are removed from the table;
    * otherwise, if the model carries a ``z`` mapping, every instance's code
      is replaced by a randomly chosen entry of that mapping.

    Args:
        centers (dict): Its keys are the instance IDs.
        models (dict): Name -> wrapped model exposing ``.module``.
        inst_ranges (dict): Name -> ``[start, stop)`` instance ID range.
        device: Device of the random codes.
        z_dim (int): Width of the random codes.

    Returns:
        dict: Instance ID -> style code tensor.
    """
    lut = {}
    for ins in centers.keys():
        lut[ins] = torch.rand(1, z_dim, device=device)

    for name, model in models.items():
        if model is None:
            continue

        if model.module.cfg.Z_DIM is None:
            for inst_id in range(*inst_ranges[name]):
                lut.pop(inst_id, None)
            continue

        if hasattr(model.module, "z"):
            zs = model.module.z
            picked = {
                ins: zs[np.random.choice(list(zs.keys()))].unsqueeze(0)
                for ins in centers.keys()
            }
            lut.update(picked)

    return lut
266
+
267
+
268
def _get_orbit_camera_pose(radius, altitude, azimuth, half_proj_size, device):
    """Place a camera on a circle around the patch centre, looking at it.

    Args:
        radius: Horizontal distance of the camera from the patch centre.
        altitude: Camera height (z coordinate).
        azimuth: Angle on the circle, in degrees.
        half_proj_size: x and y coordinate of the patch centre.
        device: Device of the returned tensors.

    Returns:
        tuple: ``(look_at, pose)``: the 3-element look-at point (centre at
        height 1) and the 7-element pose ``[x, y, z, *quaternion]``.
    """
    cx = cy = half_proj_size
    azimuth_rad = np.deg2rad(azimuth)

    cam_pos = np.array(
        [
            cx + radius * math.cos(azimuth_rad),
            cy + radius * math.sin(azimuth_rad),
            altitude,
        ],
        dtype=np.float32,
    )
    cam_look_at = np.array([cx, cy, 1], dtype=np.float32)
    quat = _get_quat_from_look_at(cam_pos, cam_look_at)

    look_at_tensor = torch.tensor([*cam_look_at], device=device)
    pose_tensor = torch.tensor([*cam_pos, *quat], device=device)
    return look_at_tensor, pose_tensor
280
+
281
+
282
def _get_quat_from_look_at(cam_pos, cam_look_at):
    """Return the orientation quaternion of a camera looking at a target.

    The rotation matrix has the forward, right and up vectors as its columns
    (in that order), with world up fixed to +z.

    Args:
        cam_pos (np.ndarray): Camera position (3 floats).
        cam_look_at (np.ndarray): Target position (3 floats).

    Returns:
        np.ndarray: Quaternion as returned by scipy's ``Rotation.as_quat()``.

    Note:
        A view direction parallel to +z/-z yields a zero right vector and
        therefore NaNs; this case is not guarded.
    """
    forward = cam_look_at - cam_pos
    forward /= np.linalg.norm(forward)

    world_up = np.array([0, 0, 1])
    right = np.cross(world_up, forward)
    right /= np.linalg.norm(right)
    # Re-derive up so that the three axes are mutually orthogonal.
    up = np.cross(forward, right)

    rot_mat = np.stack((forward, right, up), axis=-1)
    rotation = scipy.spatial.transform.Rotation.from_matrix(rot_mat)
    return rotation.as_quat()
291
+
292
+
293
def _get_bev_points(layout, scales, classes):
    """Lift the 2-D layout maps into a list of occupied 3-D points.

    Args:
        layout (dict): Holds the tensors "INS", "TD_HF", "BU_HF" and "PTS".
        scales (dict): Class name -> scale; classes without an entry use 0.
        classes (dict): Class name -> class ID; its iteration order defines
            the order of the scale vector passed to the extension.

    Returns:
        torch.Tensor: ``[N, 4]`` tensor whose columns are the three volume
        indices of every non-zero voxel followed by the voxel's value.
    """
    import gaussiancity.extensions.voxlib

    assert torch.max(layout["INS"]) < 16384
    # torch.nonzero(torch.zeros(2048, 2048, 512).cuda())
    # -> nonzero is not supported for tensors with more than INT_MAX elements
    # torch.nonzero(torch.zeros(2048, 2048, 508).cuda())
    # -> an illegal memory access was encountered
    assert torch.max(layout["TD_HF"]) <= 500

    ins_map = layout["INS"]
    class_scales = torch.tensor(
        [scales.get(name, 0) for name in classes],
        dtype=torch.int8,
        device=ins_map.device,
    )
    volume = gaussiancity.extensions.voxlib.maps_to_volume(
        ins_map.squeeze().short(),
        layout["TD_HF"].squeeze().short(),
        layout["BU_HF"].squeeze().short(),
        layout["PTS"].squeeze().bool(),
        class_scales,
    )

    occupied = torch.nonzero(volume, as_tuple=False)
    values = volume[occupied[:, 0], occupied[:, 1], occupied[:, 2]]
    return torch.cat([occupied.short(), values.unsqueeze(dim=1)], dim=1)
321
+
322
+
323
+ def _instances_to_classes(instances, bldg_inst_range, bldg_classes):
324
+ bldg_facade_idx = (instances >= bldg_inst_range[0]) & (instances % 2 == 0)
325
+ bldg_roof_idx = (instances >= bldg_inst_range[0]) & (instances % 2 == 1)
326
+
327
+ classes = instances.clone()
328
+ classes[bldg_facade_idx] = bldg_classes["BLDG_FACADE"]
329
+ classes[bldg_roof_idx] = bldg_classes["BLDG_ROOF"]
330
+ return classes
331
+
332
+
333
+ def _get_point_scales(pt_classes, scales, classes, special_z_scale_classes=[]):
334
+ pt_scales = pt_classes.clone()
335
+ for k, v in scales.items():
336
+ pt_scales[pt_classes == classes[k]] = v
337
+
338
+ pt_scales_3d = torch.ones_like(pt_scales).repeat(1, 3) * pt_scales
339
+ # Set the z-scale = 1 for roads, zones, and waters
340
+ pt_scales_3d[..., 2][
341
+ torch.isin(
342
+ pt_classes.squeeze(dim=-1),
343
+ torch.tensor(
344
+ list(special_z_scale_classes),
345
+ device=pt_classes.device,
346
+ ),
347
+ )
348
+ ] = 1
349
+ return pt_scales_3d
350
+
351
+
352
def _get_visible_points(points, scales, K, sensor_size, cam_pos, cam_look_at):
    """Return the IDs of the points that are hit by at least one camera ray.

    Every point is given a unique ID, the points are rasterised into a dense
    volume and rays are cast through it. The values of the resulting map are
    point IDs, so its unique non-negative values are the visible points
    (``-1`` marks rays that hit nothing).

    Args:
        points: Point coordinates, forwarded to ``_get_volume``.
        scales: Per-point scales, forwarded to ``_get_volume``.
        K: Camera intrinsics, forwarded to ``_get_ray_voxel_intersection``.
        sensor_size: Sensor size, forwarded to
            ``_get_ray_voxel_intersection``.
        cam_pos: Camera position in the same frame as ``points``.
        cam_look_at: Point the camera looks at.

    Returns:
        torch.Tensor: Unique IDs (>= 0) of the visible points.
    """
    volume, offsets = _get_volume(points, scales)

    # The volume lives in a local frame, so the camera is shifted with it.
    local_cam_pos = cam_pos - offsets
    view_dir = cam_look_at - cam_pos
    hit_ids = _get_ray_voxel_intersection(
        K, sensor_size, local_cam_pos, view_dir, volume
    )
    # NOTE: an instance segmentation map could be derived here as a side
    # product by indexing the instance IDs with `hit_ids`.

    # Free the dense volume right away to avoid running out of memory.
    del volume
    torch.cuda.empty_cache()

    unique_ids = torch.unique(hit_ids)
    return unique_ids[unique_ids >= 0]
372
+
373
+
374
def _get_volume(points, scales):
    """Rasterise points into a dense volume of point IDs.

    The points are moved into a local frame anchored at their minimum
    coordinates by ``_get_localized_pt_cords`` and handed to
    ``voxlib.points_to_volume`` together with 1-based point IDs.

    Args:
        points: Tensor whose first three columns are the x, y and z
            coordinates.
        scales: Per-point scales, forwarded to ``voxlib.points_to_volume``.

    Returns:
        tuple: ``(volume, offsets)`` where ``offsets`` is the int16 tensor
        ``[x_min, y_min, z_min]``.
    """
    import gaussiancity.extensions.voxlib

    lows, highs = [], []
    for axis in range(3):
        column = points[:, axis]
        lows.append(torch.min(column).item())
        highs.append(torch.max(column).item())

    offsets = torch.tensor(lows, dtype=torch.int16, device=points.device)
    # Move the coordinates into the local frame of the volume.
    points = _get_localized_pt_cords(points, offsets)

    # Extent of the (still empty) volume; the depth has one extra cell.
    w = highs[0] - lows[0] + 1
    h = highs[1] - lows[1] + 1
    d = highs[2] - lows[2] + 2

    # NOTE: point IDs start from 1 so that they cannot collide with the
    # NULL class; they must fit into int32.
    n_points = points.shape[0]
    assert n_points < 2147483648
    pt_ids = torch.arange(
        start=1, end=n_points + 1, dtype=torch.int32, device=points.device
    ).unsqueeze(dim=1)

    volume = gaussiancity.extensions.voxlib.points_to_volume(
        points.contiguous(), pt_ids, scales, h, w, d
    )
    return volume, offsets
397
+
398
+
399
+ def _get_localized_pt_cords(points, offsets):
400
+ points[:, 0] -= offsets[0]
401
+ points[:, 1] -= offsets[1]
402
+ points[:, 2] -= offsets[2] - 1
403
+ return points
404
+
405
+
406
def _get_ray_voxel_intersection(K, sensor_size, cam_origin, viewdir, volume):
    """Cast camera rays through ``volume`` and return the first hit per ray.

    Thin wrapper around ``voxlib.ray_voxel_intersection_perspective``. The
    first two components of the camera origin, the view direction, the
    principal point and the sensor size are swapped before the call.

    Args:
        K: Flat intrinsics; elements 0, 2 and 5 are used -- presumably the
            focal length and the principal point of a row-major 3x3 matrix
            (TODO confirm).
        sensor_size: Two-element sensor size.
        cam_origin: Camera position in the frame of ``volume``.
        viewdir: Viewing direction.
        volume: Dense volume holding 1-based point IDs.

    Returns:
        torch.Tensor: Per-ray point ID shifted down by one.
    """
    import gaussiancity.extensions.voxlib

    N_MAX_SAMPLES = 1
    swapped_axes = [1, 0, 2]
    up_axis = torch.tensor([0, 0, 1], dtype=torch.float32)

    voxel_id, _unused_a, _unused_b = (
        gaussiancity.extensions.voxlib.ray_voxel_intersection_perspective(
            volume,
            cam_origin[swapped_axes].float(),
            viewdir[swapped_axes].float(),
            up_axis,
            K[0],
            [K[5], K[2]],
            [sensor_size[1], sensor_size[0]],
            N_MAX_SAMPLES,
        )
    )
    # NOTE: The point ID for the NULL class is -1, the remaining point IDs
    # run from 0 to N - 1. ray_voxel_intersection_perspective does not seem
    # to accept negative values, hence the IDs in the volume are 1-based and
    # shifted back here.
    return voxel_id.squeeze() - 1
423
+
424
+
425
def get_hf_seg_tensor(part_hf, part_seg, layout_cfg, output_device):
    """Stack a height field and a one-hot segmentation map into one tensor.

    The height field is divided by ``CONSTANTS["LAYOUT_MAX_HEIGHT"]`` and the
    segmentation map is expanded to ``CONSTANTS["LAYOUT_N_CLASSES"]`` one-hot
    channels.

    Args:
        part_hf: NumPy array with the height field.
        part_seg: NumPy array with the class ID of every cell.
        layout_cfg: Unused by this function.
        output_device: Device of the returned tensor.

    Returns:
        torch.Tensor: Height channel followed by the one-hot channels,
        concatenated along dimension 1 with a leading batch dimension of 1.
    """
    height = torch.from_numpy(part_hf).unsqueeze(0).unsqueeze(0).to(output_device)
    height = height / CONSTANTS["LAYOUT_MAX_HEIGHT"]

    labels = torch.from_numpy(part_seg).unsqueeze(0).to(output_device)
    onehots = _masks_to_onehots(labels, CONSTANTS["LAYOUT_N_CLASSES"])

    return torch.cat([height, onehots], dim=1)
431
+
432
+
433
+ def _masks_to_onehots(masks, n_class, ignored_classes=[]):
434
+ b, h, w = masks.shape
435
+ n_class_actual = n_class - len(ignored_classes)
436
+ one_hot_masks = torch.zeros(
437
+ (b, n_class_actual, h, w), dtype=torch.float32, device=masks.device
438
+ )
439
+
440
+ n_class_cnt = 0
441
+ for i in range(n_class):
442
+ if i not in ignored_classes:
443
+ one_hot_masks[:, n_class_cnt] = masks == i
444
+ n_class_cnt += 1
445
+ return one_hot_masks
446
+
447
+
448
def _get_gs_attrs(
    pts,
    proj_hf,
    proj_seg,
    style_lut,
    centers,
    models,
    scale_factor,
    bldg_inst_range,
):
    """Assemble the Gaussian attributes for all points.

    The points are split into buildings (instance ID in column 3 greater than
    or equal to ``bldg_inst_range[0]``) and the rest; colours are predicted
    by ``models["BLDG"]`` and ``models["REST"]`` respectively. In the output
    the building points come first, followed by the remaining points.

    Args:
        pts: ``[N, 4 + 3 + N_CLASSES]`` tensor: XYZ, instance ID, 3 scales
            and the class one-hot vector.
        proj_hf: Projected height field, forwarded to the models.
        proj_seg: Projected segmentation map, forwarded to the models.
        style_lut: Instance ID -> style code lookup table.
        centers: Instance ID -> ``(cx, cy, w, h, d)``.
        models: Mapping with the entries ``"BLDG"`` and ``"REST"``.
        scale_factor: Factor applied to the point scales.
        bldg_inst_range: Sequence whose first element is the smallest
            building instance ID.

    Returns:
        torch.Tensor: Per point: XYZ (3), opacity (1, always 1), scales (3),
        rotation (4, always ``[1, 0, 0, 0]``) and RGB.
    """
    n_pts, _ = pts.shape
    device = pts.device

    is_bldg = pts[:, 3] >= bldg_inst_range[0]
    groups = {"BLDG": pts[is_bldg], "REST": pts[~is_bldg]}

    # First the network inputs of both groups, then the colours -- in the
    # same order in which they are concatenated below.
    inputs = {
        name: _get_pt_input_attrs(
            group[:, :4],
            centers,
            style_lut,
            models[name].module.cfg.Z_DIM,
            bldg_inst_range,
        )
        for name, group in groups.items()
    }
    colors = [
        _get_gs_colors(group, inputs[name], proj_hf, proj_seg, models[name])
        for name, group in groups.items()
    ]

    ordered = list(groups.values())
    abs_xyz = torch.cat([group[:, :3] for group in ordered], dim=0)
    scales = torch.cat([group[:, 4:7] for group in ordered], dim=0) * scale_factor
    rgb = torch.cat(colors, dim=0)

    # Attributes with default values
    opacity = torch.ones((n_pts, 1), device=device)
    rotations = torch.zeros(n_pts, 4, device=device)
    rotations[:, 0] = 1

    return torch.cat((abs_xyz, opacity, scales, rotations, rgb), dim=-1)
499
+
500
+
501
+ def _get_pt_input_attrs(pts, centers, style_lut, z_dim, bldg_inst_range):
502
+ n_pts = pts.shape[0]
503
+ instances = torch.unique(pts[:, -1])
504
+ rel_xyz = torch.zeros(1, n_pts, 3, dtype=torch.float32, device=pts.device)
505
+ batch_idx = torch.zeros(1, n_pts, dtype=torch.int32, device=pts.device)
506
+ zs = {} if z_dim is not None else None
507
+ for idx, ins in enumerate(instances):
508
+ ins = ins.item()
509
+ is_pts = pts[:, -1] == ins
510
+ cx, cy, w, h, d = centers[ins]
511
+
512
+ if ins >= bldg_inst_range[0]:
513
+ rel_xyz[:, is_pts, 0] = (pts[is_pts, 0] - cx) / w * 2 if w > 0 else 0
514
+ rel_xyz[:, is_pts, 1] = (pts[is_pts, 1] - cy) / h * 2 if h > 0 else 0
515
+ else:
516
+ # Make the BG contiguous
517
+ period_x = torch.ceil((pts[is_pts, 0] / w / 2) - 0.5)
518
+ period_y = torch.ceil((pts[is_pts, 1] / h / 2) - 0.5)
519
+ rel_xyz[:, is_pts, 0] = (
520
+ (pts[is_pts, 0] - 2 * period_x * w) * (-1) ** period_x
521
+ ) / w
522
+ rel_xyz[:, is_pts, 1] = (
523
+ (pts[is_pts, 1] - 2 * period_y * h) * (-1) ** period_y
524
+ ) / h
525
+
526
+ rel_xyz[:, is_pts, 2] = (
527
+ torch.clip(pts[is_pts, 2] / d * 2 - 1, -1, 1) if d > 0 else 0
528
+ )
529
+ batch_idx[:, is_pts] = idx
530
+ if zs is not None:
531
+ zs[ins] = {"z": style_lut[ins], "idx": is_pts.unsqueeze(dim=0)}
532
+
533
+ return rel_xyz, batch_idx, zs
534
+
535
+
536
+ def _get_gs_colors(pts, pt_attrs, proj_hf, proj_seg, model):
537
+ if pts.shape[0] == 0:
538
+ return torch.empty(0, 3, dtype=torch.float32, device=pts.device)
539
+
540
+ abs_xyz, onehots = pts[None, :, :3], pts[None, :, 7:]
541
+ rel_xyz, batch_idx, zs = pt_attrs
542
+ proj_uv = None
543
+ if model.module.cfg.ENCODER is not None:
544
+ proj_uv = get_projection_uv(abs_xyz)
545
+
546
+ with torch.no_grad():
547
+ # TODO: Optimize the _instance_forward in Generator
548
+ gs_attrs = model(
549
+ proj_uv, rel_xyz, batch_idx, onehots.float(), zs, proj_hf, proj_seg
550
+ )
551
+
552
+ return gs_attrs["rgb"].squeeze(dim=0)
553
+
554
+
555
def get_projection_uv(xyz, proj_tlp=None, proj_size=2048):
    """Map XY coordinates to normalised projection coordinates.

    The first two components of ``xyz`` (optionally relative to the top-left
    point ``proj_tlp``) are divided by ``proj_size`` and rescaled so that
    ``0`` maps to ``-1`` and ``proj_size`` maps to ``1``.

    Integer coordinates are now converted to float in both branches. Before,
    only the ``proj_tlp is None`` branch did so and the other branch failed
    on the in-place true division of an integer tensor. Floating-point
    inputs keep their dtype in the ``proj_tlp`` branch, as before.

    Args:
        xyz: Tensor of shape ``(B, N, >=2)``.
        proj_tlp: Optional ``(B, 2)`` tensor subtracted from the XY
            coordinates of every point of the corresponding batch item.
        proj_size: Edge length of the projection.

    Returns:
        torch.Tensor: ``(B, N, 2)`` coordinates; ``xyz`` is not modified.
    """
    n_pts = xyz.size(1)
    if proj_tlp is None:
        proj_uv = xyz[..., :2].clone().float()
    else:
        proj_uv = xyz[..., :2] - proj_tlp.unsqueeze(dim=1)
        if not proj_uv.is_floating_point():
            proj_uv = proj_uv.float()

    assert proj_uv.size() == (xyz.size(0), n_pts, 2)
    # Normalize to [-1, 1]
    return proj_uv / proj_size * 2 - 1
567
+
568
+
569
def _render(gs_attrs, rasterizator, cam_pose):
    """Rasterise the Gaussians and return an 8-bit image.

    The rasteriser output is mapped with ``x / 2 + 0.5`` (presumably from
    ``[-1, 1]`` to ``[0, 1]`` -- TODO confirm), then brightness and contrast
    are both raised by a factor of 1.2.

    Args:
        gs_attrs: Gaussian attributes, forwarded to ``rasterizator``.
        rasterizator: Callable taking the attributes, the camera position
            and the camera quaternion.
        cam_pose: Sequence/tensor: 3 position values followed by the
            quaternion.

    Returns:
        numpy.ndarray: ``uint8`` image with the channels in the last axis.
    """
    import torchvision.transforms.functional as F

    position, quaternion = cam_pose[:3], cam_pose[3:]
    with torch.no_grad():
        img = rasterizator(gs_attrs, position, quaternion)

    img = img.squeeze() / 2 + 0.5
    img = F.adjust_contrast(F.adjust_brightness(img, 1.2), 1.2)

    frame = (img * 255).permute(1, 2, 0)
    return frame.cpu().numpy().astype(np.uint8)
gaussiancity/pt_v3.py ADDED
@@ -0,0 +1,1344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ #
3
+ # @File: pt_v3.py
4
+ # @Author: Xiaoyang Wu <xiaoyang.wu.cs@gmail.com>
5
+ # @Date: 2024-04-01 16:31:36
6
+ # @Last Modified by: Haozhe Xie
7
+ # @Last Modified at: 2024-05-15 22:05:09
8
+ # @Email: root@haozhexie.com
9
+ # Ref:
10
+ # - https://github.com/Pointcept/PointTransformerV3/blob/main/model.py
11
+ # - https://huggingface.co/spaces/Roll20/pet_score/blame/main/lib/timm/models/layers/drop.py
12
+
13
+ import addict
14
+ import collections
15
+ import functools
16
+ import flash_attn
17
+ import math
18
+ import torch
19
+ import spconv.pytorch as spconv
20
+ import torch_scatter
21
+ import typing
22
+
23
+
24
@torch.inference_mode()
def offset2bincount(offset):
    """Convert cumulative offsets into per-sample point counts.

    ``offset[i]`` is the end index of sample ``i``; the count of a sample is
    the difference to the previous offset (0 for the first sample).
    """
    leading_zero = torch.zeros(1, device=offset.device, dtype=torch.long)
    return torch.diff(offset, prepend=leading_zero)
29
+
30
+
31
@torch.inference_mode()
def offset2batch(offset):
    """Convert cumulative offsets into a per-point batch index.

    Sample ``i`` contributes ``offset[i] - offset[i - 1]`` entries with the
    value ``i``.
    """
    leading_zero = torch.zeros(1, device=offset.device, dtype=torch.long)
    counts = torch.diff(offset, prepend=leading_zero)
    sample_ids = torch.arange(len(counts), device=offset.device, dtype=torch.long)
    return sample_ids.repeat_interleave(counts)
37
+
38
+
39
@torch.inference_mode()
def batch2offset(batch):
    """Convert a per-point batch index into cumulative offsets.

    The result holds, for every sample, the running total of points up to and
    including that sample.
    """
    counts = batch.bincount()
    return counts.cumsum(dim=0).long()
42
+
43
+
44
class KeyLUT:
    """Lookup tables for interleaving three coordinates into one key.

    The encode tables hold the keys of the values 0..255 for each of the
    three axes (other axes zero); the decode tables hold the x, y and z parts
    of the keys 0..511. Tables are built on the CPU and copied to other
    devices on first request.
    """

    def __init__(self):
        cpu = torch.device("cpu")
        axis_values = torch.arange(256, dtype=torch.int64)
        zeros = torch.zeros(256, dtype=torch.int64)
        keys = torch.arange(512, dtype=torch.int64)

        # device -> tuple of tables
        self._encode = {
            cpu: (
                self.xyz2key(axis_values, zeros, zeros, 8),
                self.xyz2key(zeros, axis_values, zeros, 8),
                self.xyz2key(zeros, zeros, axis_values, 8),
            )
        }
        self._decode = {cpu: self.key2xyz(keys, 9)}

    @staticmethod
    def _on_device(cache, device):
        """Return the tables of ``cache`` for ``device``, copying on demand."""
        if device not in cache:
            cpu_tables = cache[torch.device("cpu")]
            cache[device] = tuple(table.to(device) for table in cpu_tables)
        return cache[device]

    def encode_lut(self, device=torch.device("cpu")):
        """Return the three per-axis encode tables on ``device``."""
        return self._on_device(self._encode, device)

    def decode_lut(self, device=torch.device("cpu")):
        """Return the three decode tables on ``device``."""
        return self._on_device(self._decode, device)

    def xyz2key(self, x, y, z, depth):
        """Interleave the lowest ``depth`` bits of ``x``, ``y`` and ``z``.

        Bit ``i`` of ``x`` / ``y`` / ``z`` ends up at bit ``3 * i + 2`` /
        ``3 * i + 1`` / ``3 * i`` of the key.
        """
        key = torch.zeros_like(x)
        for level in range(depth):
            bit = 1 << level
            base_shift = 2 * level
            key = (
                key
                | ((x & bit) << (base_shift + 2))
                | ((y & bit) << (base_shift + 1))
                | ((z & bit) << base_shift)
            )
        return key

    def key2xyz(self, key, depth):
        """Split a key back into its ``x``, ``y`` and ``z`` parts.

        Inverse of :meth:`xyz2key` for ``depth`` bits per axis.
        """
        x, y, z = (torch.zeros_like(key) for _ in range(3))
        for level in range(depth):
            src_bit = 3 * level
            back_shift = 2 * level
            x = x | ((key & (1 << (src_bit + 2))) >> (back_shift + 2))
            y = y | ((key & (1 << (src_bit + 1))) >> (back_shift + 1))
            z = z | ((key & (1 << src_bit)) >> back_shift)
        return x, y, z
93
+
94
+
95
class Serializator:
    """Turn integer grid coordinates into 1-D serialization codes.

    Supported orders are ``"cord"``, ``"z"``, ``"z-trans"``, ``"hilbert"`` and
    ``"hilbert-trans"``; the ``-trans`` variants swap the first two
    coordinate columns before encoding.

    Changes compared to the previous revision:

    * the :class:`KeyLUT` used by the z-orders is built once and reused
      (it used to be rebuilt on every ``encode`` call, discarding its
      per-device cache);
    * ``_xyz2key`` accepts an ``int`` batch index, as its docstring promises;
    * the Hilbert encoder returns a 1-D tensor for a single point as well
      (it used to collapse to a 0-dim tensor).
    """

    def _get_key_lut(self):
        """Return the shared :class:`KeyLUT`, building it on first use."""
        if getattr(self, "key_lut", None) is None:
            self.key_lut = KeyLUT()
        return self.key_lut

    def encode(self, grid_coord, grid_size=0.01, batch=None, depth=16, order="cord"):
        """Encode ``grid_coord`` with the requested ``order``.

        Args:
            grid_coord: ``(N, 3)`` integer grid coordinates.
            grid_size: Grid size; only used by the ``"cord"`` order.
            batch: Optional ``(N,)`` batch indices. If given, they are placed
                above the ``3 * depth`` position bits of the code.
            depth: Number of bits per axis.
            order: Name of the serialization order.

        Returns:
            torch.Tensor: One int64 code per point.
        """
        assert order in {"cord", "z", "z-trans", "hilbert", "hilbert-trans"}
        if order == "cord":
            code = self.cord_encode(grid_coord, grid_size)
        elif order == "z":
            code = self.z_order_encode(grid_coord, depth=depth)
        elif order == "z-trans":
            code = self.z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth)
        elif order == "hilbert":
            code = self.hilbert_encode(grid_coord, depth=depth)
        elif order == "hilbert-trans":
            code = self.hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth)
        else:
            # Only reachable when assertions are disabled.
            raise NotImplementedError

        if batch is not None:
            batch = batch.long()
            code = batch << depth * 3 | code

        return code

    def cord_encode(self, grid_coord: torch.Tensor, grid_size: float):
        """Encode as ``x / grid_size**2 + y / grid_size + z`` (cast to int64)."""
        x, y, z = (
            grid_coord[:, 0].long(),
            grid_coord[:, 1].long(),
            grid_coord[:, 2].long(),
        )
        # we block the support to batch, maintain batched code in Point class
        code = x / grid_size**2 + y / grid_size + z
        return code.long()

    def z_order_encode(self, grid_coord: torch.Tensor, depth: int = 16):
        """Encode ``grid_coord`` along the z-order (Morton) curve."""
        x, y, z = (
            grid_coord[:, 0].long(),
            grid_coord[:, 1].long(),
            grid_coord[:, 2].long(),
        )
        # we block the support to batch, maintain batched code in Point class
        code = self._xyz2key(x, y, z, b=None, depth=depth)
        return code

    def _xyz2key(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        z: torch.Tensor,
        b: typing.Optional[typing.Union[torch.Tensor, int]] = None,
        depth: int = 16,
    ):
        r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
        based on pre-computed look up tables. The speed of this function is much
        faster than the method based on for-loop.

        Args:
            x (torch.Tensor): The x coordinate.
            y (torch.Tensor): The y coordinate.
            z (torch.Tensor): The z coordinate.
            b (torch.Tensor or int): The batch index of the coordinates, and should be
                smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
                :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
            depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
        """
        EX, EY, EZ = self._get_key_lut().encode_lut(x.device)
        x, y, z = x.long(), y.long(), z.long()

        # The tables cover 8 bits per axis: encode the low 8 bits first ...
        mask = 255 if depth > 8 else (1 << depth) - 1
        key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
        if depth > 8:
            # ... then the remaining high bits, placed above the first 24 bits.
            mask = (1 << (depth - 8)) - 1
            key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
            key = key16 << 24 | key

        if b is not None:
            if isinstance(b, torch.Tensor):
                b = b.long()
            key = b << 48 | key

        return key

    def hilbert_encode(self, grid_coord: torch.Tensor, depth: int = 16):
        """Encode ``grid_coord`` along a 3-D Hilbert curve."""
        return self._hilbert_encode(grid_coord, num_dims=3, num_bits=depth)

    def _hilbert_encode(self, locs, num_dims, num_bits):
        """Encode an array of locations in a hypercube into a Hilbert integer.

        This is a vectorized-ish version of the Hilbert curve implementation by John
        Skilling as described in:

        Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
        Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.

        Params:
        -------
        locs - A tensor of locations in a hypercube of num_dims dimensions, in
               which each dimension runs from 0 to 2**num_bits-1. The last
               dimension must have size num_dims; the leading dimensions are
               flattened.
        num_dims - The dimensionality of the hypercube. Integer.
        num_bits - The number of bits for each dimension. Integer.

        Returns:
        --------
        A 1-D int64 tensor with one Hilbert integer per location.
        """

        # Keep around the original shape for later.
        orig_shape = locs.shape
        bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
        bitpack_mask_rev = bitpack_mask.flip(-1)

        if orig_shape[-1] != num_dims:
            raise ValueError(
                """
                The shape of locs was surprising in that the last dimension was of size
                %d, but num_dims=%d. These need to be equal.
                """
                % (orig_shape[-1], num_dims)
            )

        if num_dims * num_bits > 63:
            raise ValueError(
                """
                num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
                into a int64. Are you sure you need that many points on your Hilbert
                curve?
                """
                % (num_dims, num_bits, num_dims * num_bits)
            )

        # Treat the location integers as 64-bit unsigned and then split them up into
        # a sequence of uint8s. Preserve the association by dimension.
        locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)

        # Now turn these into bits and truncate to num_bits.
        gray = (
            locs_uint8.unsqueeze(-1)
            .bitwise_and(bitpack_mask_rev)
            .ne(0)
            .byte()
            .flatten(-2, -1)[..., -num_bits:]
        )

        # Run the decoding process the other way.
        # Iterate forwards through the bits.
        for bit in range(0, num_bits):
            # Iterate forwards through the dimensions.
            for dim in range(0, num_dims):
                # Identify which ones have this bit active.
                mask = gray[:, dim, bit]

                # Where this bit is on, invert the 0 dimension for lower bits.
                gray[:, 0, bit + 1 :] = torch.logical_xor(
                    gray[:, 0, bit + 1 :], mask[:, None]
                )

                # Where the bit is off, exchange the lower bits with the 0 dimension.
                to_flip = torch.logical_and(
                    torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
                    torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
                )
                gray[:, dim, bit + 1 :] = torch.logical_xor(
                    gray[:, dim, bit + 1 :], to_flip
                )
                gray[:, 0, bit + 1 :] = torch.logical_xor(
                    gray[:, 0, bit + 1 :], to_flip
                )

        # Now flatten out.
        # Fix: shape '[-1, 0]' is invalid for input of size 192
        gray = gray.swapaxes(1, 2).reshape((gray.size(0), -1))

        # Convert Gray back to binary.
        hh_bin = self._gray2binary(gray)

        # Pad back out to 64 bits.
        extra_dims = 64 - gray.size(1)
        padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)

        # Convert binary values into uint8s: (N, 8) bytes, least significant first.
        hh_uint8 = (
            (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask)
            .sum(2)
            .type(torch.uint8)
        )

        # Convert uint8s into 64-bit integers. Only the trailing dimension is
        # dropped so that a single location still yields a 1-D result.
        hh_uint64 = hh_uint8.view(torch.int64).squeeze(-1)

        return hh_uint64

    def _gray2binary(self, gray, axis=-1):
        """Convert an array of Gray codes back into binary values.

        Parameters:
        -----------
        gray: A tensor of gray codes.
        axis: The axis along which to perform Gray decoding. Default=-1.

        Returns:
        --------
        Returns a tensor of binary values.
        """

        # Loop the log2(bits) number of times necessary, with shift and xor.
        shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
        while shift > 0:
            gray = torch.logical_xor(gray, self._right_shift(gray, shift))
            shift = torch.div(shift, 2, rounding_mode="floor")
        return gray

    def _right_shift(self, binary, k=1, axis=-1):
        """Right shift an array of binary values.

        Parameters:
        -----------
        binary: A tensor of binary values.

        k: The number of bits to shift. Default 1.

        axis: The axis along which to shift. Default -1.

        Returns:
        --------
        Returns a tensor with zero prepended and the ends truncated, along
        whatever axis was specified.
        NOTE(review): the zero padding is always applied to the last
        dimension, so only ``axis=-1`` is consistent -- confirm before using
        another axis.
        """

        # If we're shifting the whole thing, just return zeros.
        if binary.shape[axis] <= k:
            return torch.zeros_like(binary)

        # Determine the slicing pattern to eliminate just the last k entries.
        slicing = [slice(None)] * len(binary.shape)
        slicing[axis] = slice(None, -k)
        shifted = torch.nn.functional.pad(
            binary[tuple(slicing)], (k, 0), mode="constant", value=0
        )

        return shifted
341
+
342
+
343
class PointModule(torch.nn.Module):
    r"""Marker base class for modules that operate on a whole ``Point``.

    ``PointSequential`` passes the ``Point`` object itself (instead of only
    its features) to every module that derives from this class.
    """

    def __init__(self, *args, **kwargs):
        # No state of its own; everything is forwarded to torch.nn.Module.
        super().__init__(*args, **kwargs)
350
+
351
+
352
class Point(addict.Dict):
    """
    Point Structure of Pointcept

    A Point (point cloud) in Pointcept is a dictionary that contains various properties of
    a batched point cloud. The property with the following names have a specific definition
    as follows:

    - "coord": original coordinate of point cloud;
    - "grid_coord": grid coordinate for specific grid size (related to GridSampling);
    Point also support the following optional attributes:
    - "offset": if not exist, initialized as batch size is 1;
    - "batch": if not exist, initialized as batch size is 1;
    - "feat": feature of point cloud, default input of model;
    - "grid_size": Grid size of point cloud (related to GridSampling);
    (related to Serialization)
    - "serialized_depth": depth of serialization, 2 ** depth * grid_size describe the maximum of point cloud range;
    - "serialized_code": a list of serialization codes;
    - "serialized_order": a list of serialization order determined by code;
    - "serialized_inverse": a list of inverse mapping determined by code;
    (related to Sparsify: SpConv)
    - "sparse_shape": Sparse shape for Sparse Conv Tensor;
    - "sparse_conv_feat": SparseConvTensor init with information provide by Point;
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.serializator = Serializator()
        # If one of "offset" or "batch" do not exist, generate by the existing one
        if "batch" not in self.keys() and "offset" in self.keys():
            self["batch"] = offset2batch(self.offset)
        elif "offset" not in self.keys() and "batch" in self.keys():
            self["offset"] = batch2offset(self.batch)

    def _ensure_grid_coord(self):
        """Derive "grid_coord" from "coord" and "grid_size" if it is missing."""
        if "grid_coord" in self.keys():
            return
        # if you don't want to operate GridSampling in data augmentation,
        # please add the following augmentation into your pipeline:
        # dict(type="Copy", keys_dict={"grid_size": 0.01}),
        # (adjust `grid_size` to what you want)
        assert {"grid_size", "coord"}.issubset(self.keys())
        self["grid_coord"] = torch.div(
            self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
        ).int()

    def serialization(self, order="z", depth=None, shuffle_orders=False):
        """
        Point Cloud Serialization

        rely on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]

        Args:
            order: Name of one serialization order, or a sequence of names.
                A single name is treated as a one-element sequence (a plain
                string used to be iterated character by character, so only
                one-letter names such as ``"z"`` worked).
            depth: Bits per axis; derived from the largest grid coordinate
                when ``None``.
            shuffle_orders: Randomly permute the rows (orders) of the result.

        Stores "serialized_depth", "serialized_code", "serialized_order" and
        "serialized_inverse".
        """
        assert "batch" in self.keys()
        self._ensure_grid_coord()

        if isinstance(order, str):
            order = [order]

        if depth is None:
            # Adaptive measure the depth of serialization cube (length = 2 ^ depth)
            depth = int(self.grid_coord.max()).bit_length()

        self["serialized_depth"] = depth
        # Maximum bit length for serialization code is 63 (int64)
        assert depth * 3 + len(self.offset).bit_length() <= 63
        # Here we follow OCNN and set the depth limitation to 16 (48bit) for the point position.
        # Although depth is limited to less than 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3
        # cube with a grid size of 0.01 meter. We consider it is enough for the current stage.
        # We can unlock the limitation by optimizing the z-order encoding function if necessary.
        assert depth <= 16

        # The serialization codes are arranged as following structures:
        # [Order1 ([n]),
        #  Order2 ([n]),
        #   ...
        #  OrderN ([n])] (k, n)
        code = [
            self.serializator.encode(
                self.grid_coord, self.grid_size, self.batch, depth, order=order_
            )
            for order_ in order
        ]
        code = torch.stack(code)
        order = torch.argsort(code)
        # inverse[k, order[k, i]] = i
        inverse = torch.zeros_like(order).scatter_(
            dim=1,
            index=order,
            src=torch.arange(0, code.shape[1], device=order.device).repeat(
                code.shape[0], 1
            ),
        )

        if shuffle_orders:
            perm = torch.randperm(code.shape[0])
            code = code[perm]
            order = order[perm]
            inverse = inverse[perm]

        self["serialized_code"] = code
        self["serialized_order"] = order
        self["serialized_inverse"] = inverse

    def sparsify(self, pad=96):
        """
        Point Cloud Sparsification

        Point cloud is sparse, here we use "sparsify" to specifically refer to
        preparing "spconv.SparseConvTensor" for SpConv.

        rely on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]

        pad: padding added to the largest grid coordinate to obtain the
            sparse shape (ignored if "sparse_shape" is already set).

        Stores "sparse_shape" and "sparse_conv_feat".
        """
        assert {"feat", "batch"}.issubset(self.keys())
        self._ensure_grid_coord()

        if "sparse_shape" in self.keys():
            sparse_shape = self.sparse_shape
        else:
            sparse_shape = torch.add(
                torch.max(self.grid_coord, dim=0).values, pad
            ).tolist()

        sparse_conv_feat = spconv.SparseConvTensor(
            features=self.feat,
            indices=torch.cat(
                [self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1
            ).contiguous(),
            spatial_shape=sparse_shape,
            # The batch index of the last point is taken as the largest one.
            batch_size=self.batch[-1].tolist() + 1,
        )
        self["sparse_shape"] = sparse_shape
        self["sparse_conv_feat"] = sparse_conv_feat
486
+
487
+
488
class PointSequential(PointModule):
    r"""Sequential container that understands ``Point`` and spconv inputs.

    Modules are registered in the order in which they are passed to the
    constructor; a single ``collections.OrderedDict`` of modules is accepted
    as well. Keyword arguments add further named modules.
    """

    def __init__(self, name="", *args, **kwargs):
        super().__init__()
        self.name = name

        single_mapping = len(args) == 1 and isinstance(
            args[0], collections.OrderedDict
        )
        if single_mapping:
            named_modules = args[0].items()
        else:
            named_modules = ((str(idx), module) for idx, module in enumerate(args))
        for key, module in named_modules:
            self.add_module(key, module)

        for key, module in kwargs.items():
            if key in self._modules:
                raise ValueError("name exists.")
            self.add_module(key, module)

    def __getitem__(self, idx):
        """Return the ``idx``-th module; negative indices count from the end."""
        size = len(self)
        if idx < -size or idx >= size:
            raise IndexError("index {} is out of range".format(idx))
        return list(self._modules.values())[idx % size]

    def __len__(self):
        return len(self._modules)

    def add(self, module, name=None):
        """Append ``module``; it is named by its position if ``name`` is None."""
        key = str(len(self._modules)) if name is None else name
        if key in self._modules:
            raise KeyError("name exists")
        self.add_module(key, module)

    def forward(self, x):
        """Run all modules in order, adapting the call to module/input type."""
        for layer in self._modules.values():
            if isinstance(layer, PointModule):
                # Point module: receives and returns the whole input.
                x = layer(x)
                continue

            if spconv.modules.is_spconv_module(layer):
                # Spconv module: runs on the sparse tensor of a Point.
                if isinstance(x, Point):
                    x.sparse_conv_feat = layer(x.sparse_conv_feat)
                    x.feat = x.sparse_conv_feat.features
                else:
                    x = layer(x)
                continue

            if isinstance(layer, torch.nn.BatchNorm1d) and isinstance(x, Point):
                # Fix: Expected more than 1 value per channel when training
                # -> the layer is skipped when there is a single feature row.
                # NOTE(review): unlike the generic branch below, the sparse
                # tensor is not refreshed here -- confirm this is intended.
                if x.feat.size(0) != 1:
                    x.feat = layer(x.feat)
                continue

            # Plain PyTorch module: runs on the features only.
            if isinstance(x, Point):
                x.feat = layer(x.feat)
                if "sparse_conv_feat" in x.keys():
                    x.sparse_conv_feat = x.sparse_conv_feat.replace_feature(x.feat)
            elif isinstance(x, spconv.SparseConvTensor):
                if x.indices.shape[0] != 0:
                    x = x.replace_feature(layer(x.features))
            else:
                x = layer(x)

        return x
557
+
558
+
559
class PDNorm(PointModule):
    """Normalisation layer that can be selected per ``condition`` string.

    With ``decouple=True`` one normalisation layer is created per entry of
    ``conditions`` and the one matching ``point.condition`` is used. With
    ``decouple=False`` a single shared layer is used. With ``adaptive=True``
    the normalised features are additionally scaled and shifted by a
    projection of ``point.context``.

    Args:
        num_features: width passed to ``norm_layer`` when instantiating.
        norm_layer: factory called as ``norm_layer(num_features)``; an already
            built ``torch.nn.Module`` is also accepted when ``decouple`` is
            false.
        context_channels: input width of the modulation projection.
        conditions: sequence of accepted condition names.
        decouple: keep one normalisation layer per condition.
        adaptive: enable the context-driven scale/shift.
    """

    def __init__(
        self,
        num_features,
        norm_layer,
        context_channels=256,
        conditions=("ScanNet", "S3DIS", "Structured3D"),
        decouple=True,
        adaptive=False,
    ):
        super().__init__()
        self.conditions = conditions
        self.decouple = decouple
        self.adaptive = adaptive
        if self.decouple:
            self.norm = torch.nn.ModuleList(
                [norm_layer(num_features) for _ in conditions]
            )
        elif isinstance(norm_layer, torch.nn.Module):
            # A ready-made layer was supplied; use it as is.
            self.norm = norm_layer
        else:
            # Instantiate the factory. Storing the factory itself would make
            # forward() call ``norm_layer(point.feat)``, i.e. try to construct
            # a layer from the feature tensor instead of normalising it.
            self.norm = norm_layer(num_features)
        if self.adaptive:
            self.modulation = torch.nn.Sequential(
                torch.nn.SiLU(),
                torch.nn.Linear(context_channels, 2 * num_features, bias=True),
            )

    def forward(self, point):
        """Normalise ``point.feat`` in place and return ``point``.

        ``point`` must provide ``feat`` and ``condition`` (a string, or an
        indexable whose first element is used), plus ``context`` when the
        layer is adaptive.
        """
        assert {"feat", "condition"}.issubset(point.keys())
        if isinstance(point.condition, str):
            condition = point.condition
        else:
            condition = point.condition[0]
        if self.decouple:
            assert condition in self.conditions
            norm = self.norm[self.conditions.index(condition)]
        else:
            norm = self.norm
        point.feat = norm(point.feat)
        if self.adaptive:
            assert "context" in point.keys()
            # The projection yields 2 * num_features channels: shift, then scale.
            shift, scale = self.modulation(point.context).chunk(2, dim=1)
            point.feat = point.feat * (1.0 + scale) + shift
        return point
602
+
603
+
604
class RPE(torch.nn.Module):
    """Learned relative-position bias table, indexed per axis and summed.

    The table holds ``rpe_num = 2 * pos_bnd + 1`` rows for each of the three
    axes (``3 * rpe_num`` rows in total) and one column per attention head.
    """

    def __init__(self, patch_size, num_heads):
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads
        # Largest relative offset that gets its own table row; anything
        # beyond it is clamped in forward().
        self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
        self.rpe_num = 2 * self.pos_bnd + 1
        table = torch.zeros(3 * self.rpe_num, num_heads)
        self.rpe_table = torch.nn.Parameter(table)
        torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)

    def forward(self, coord):
        """Look up the bias for integer relative coordinates.

        ``coord`` is indexed as (N, K, K, 3) by the reductions below; the
        result is permuted to (N, H, K, K).
        """
        bound = self.pos_bnd
        # Each axis reads from its own contiguous slice of the table.
        axis_offset = torch.arange(3, device=coord.device) * self.rpe_num
        idx = coord.clamp(-bound, bound) + bound + axis_offset
        gathered = self.rpe_table.index_select(0, idx.reshape(-1))
        # Restore the index layout, then sum the three per-axis biases.
        per_head = gathered.view(idx.shape + (-1,)).sum(3)
        return per_head.permute(0, 3, 1, 2)
624
+
625
+
626
class SerializedAttention(PointModule):
    """Multi-head self-attention over fixed-size patches of a serialized point cloud.

    Points are reordered with ``point.serialized_order[order_index]``, grouped
    into patches of ``patch_size`` consecutive points, attended within each
    patch, and scattered back with ``point.serialized_inverse[order_index]``.

    NOTE(review): with ``enable_flash=True`` the constructor asserts that
    ``enable_rpe``, ``upcast_attention`` and ``upcast_softmax`` are all
    ``False``, yet the latter two default to ``True`` -- callers must pass
    them explicitly (as ``PointTransformerV3`` does through its own defaults).
    """

    def __init__(
        self,
        channels,
        num_heads,
        patch_size,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        order_index=0,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    ):
        super().__init__()
        assert channels % num_heads == 0
        self.channels = channels
        self.num_heads = num_heads
        # Default scale is 1 / sqrt(head_dim).
        self.scale = qk_scale or (channels // num_heads) ** -0.5
        self.order_index = order_index
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.enable_rpe = enable_rpe
        self.enable_flash = enable_flash
        if enable_flash:
            assert (
                enable_rpe is False
            ), "Set enable_rpe to False when enable Flash Attention"
            assert (
                upcast_attention is False
            ), "Set upcast_attention to False when enable Flash Attention"
            assert (
                upcast_softmax is False
            ), "Set upcast_softmax to False when enable Flash Attention"
            assert flash_attn is not None, "Make sure flash_attn is installed."
            self.patch_size = patch_size
            # Flash attention takes the dropout probability directly.
            self.attn_drop = attn_drop
        else:
            # when disable flash attention, we still don't want to use mask
            # consequently, patch size will auto set to the
            # min number of patch_size_max and number of points
            self.patch_size_max = patch_size
            self.patch_size = 0
            self.attn_drop = torch.nn.Dropout(attn_drop)

        self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
        self.proj = torch.nn.Linear(channels, channels)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None

    @torch.no_grad()
    def get_rel_pos(self, point, order):
        """Return pairwise grid-coordinate differences within each patch.

        The result is cached on ``point`` under ``rel_pos_<order_index>`` and
        reused on later calls.
        """
        K = self.patch_size
        rel_pos_key = f"rel_pos_{self.order_index}"
        if rel_pos_key not in point.keys():
            grid_coord = point.grid_coord[order]
            # (N', K, 3): one row of K points per patch.
            grid_coord = grid_coord.reshape(-1, K, 3)
            # Broadcast to (N', K, K, 3): coordinate of i minus coordinate of j.
            point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)
        return point[rel_pos_key]

    @torch.no_grad()
    def get_padding_and_inverse(self, point):
        """Compute (and cache on ``point``) the patch padding index maps.

        Returns:
            ``(pad, unpad, cu_seqlens)``: ``pad`` indexes the unpadded
            sequence to produce the padded one, ``unpad`` gives each original
            position's index in the padded sequence, and ``cu_seqlens`` holds
            the int32 patch boundaries in the padded sequence.

        NOTE(review): the cache keys do not include the patch size, so the
        first values computed for a given ``point`` are reused afterwards.
        """
        pad_key = "pad"
        unpad_key = "unpad"
        cu_seqlens_key = "cu_seqlens_key"
        if (
            pad_key not in point.keys()
            or unpad_key not in point.keys()
            or cu_seqlens_key not in point.keys()
        ):
            offset = point.offset
            bincount = offset2bincount(offset)
            # Round every per-batch count up to a multiple of patch_size.
            bincount_pad = (
                torch.div(
                    bincount + self.patch_size - 1,
                    self.patch_size,
                    rounding_mode="trunc",
                )
                * self.patch_size
            )
            # only pad point when num of points larger than patch_size
            mask_pad = bincount > self.patch_size
            bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
            # Prefix sums with a leading zero: start offset of every batch.
            _offset = torch.nn.functional.pad(offset, (1, 0))
            _offset_pad = torch.nn.functional.pad(
                torch.cumsum(bincount_pad, dim=0), (1, 0)
            )
            pad = torch.arange(_offset_pad[-1], device=offset.device)
            unpad = torch.arange(_offset[-1], device=offset.device)
            cu_seqlens = []
            for i in range(len(offset)):
                # Shift this batch's original indices into the padded layout.
                unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i]
                if bincount[i] != bincount_pad[i]:
                    # Fill the empty tail of the last patch with indices taken
                    # from the same positions of the previous patch.
                    pad[
                        _offset_pad[i + 1]
                        - self.patch_size
                        + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                    ] = pad[
                        _offset_pad[i + 1]
                        - 2 * self.patch_size
                        + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                        - self.patch_size
                    ]
                # Convert padded positions back to unpadded indices.
                pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]
                cu_seqlens.append(
                    torch.arange(
                        _offset_pad[i],
                        _offset_pad[i + 1],
                        step=self.patch_size,
                        dtype=torch.int32,
                        device=offset.device,
                    )
                )
            point[pad_key] = pad
            point[unpad_key] = unpad
            # Append the total padded length as the final boundary.
            point[cu_seqlens_key] = torch.nn.functional.pad(
                torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1]
            )
        return point[pad_key], point[unpad_key], point[cu_seqlens_key]

    def forward(self, point):
        """Apply patch attention to ``point.feat`` and return ``point``."""
        if not self.enable_flash:
            # Without flash attention the patch size shrinks to the smallest
            # batch so that every patch is full and no mask is needed.
            self.patch_size = min(
                offset2bincount(point.offset).min().tolist(), self.patch_size_max
            )

        H = self.num_heads
        K = self.patch_size
        C = self.channels

        pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)

        order = point.serialized_order[self.order_index][pad]
        inverse = unpad[point.serialized_inverse[self.order_index]]

        # padding and reshape feat and batch for serialized point patch
        qkv = self.qkv(point.feat)[order]

        if not self.enable_flash:
            # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
            q, k, v = (
                qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
            )
            # attn
            if self.upcast_attention:
                q = q.float()
                k = k.float()
            attn = (q * self.scale) @ k.transpose(-2, -1)  # (N', H, K, K)
            if self.enable_rpe:
                attn = attn + self.rpe(self.get_rel_pos(point, order))
            if self.upcast_softmax:
                attn = attn.float()
            attn = self.softmax(attn)
            attn = self.attn_drop(attn).to(qkv.dtype)
            feat = (attn @ v).transpose(1, 2).reshape(-1, C)
        else:
            # The packed qkv is cast to half precision for the flash kernel
            # and the result is cast back to the projection dtype below.
            feat = flash_attn.flash_attn_varlen_qkvpacked_func(
                qkv.half().reshape(-1, 3, H, C // H),
                cu_seqlens,
                max_seqlen=self.patch_size,
                dropout_p=self.attn_drop if self.training else 0,
                softmax_scale=self.scale,
            ).reshape(-1, C)
            feat = feat.to(qkv.dtype)
        # Drop the padding and restore the original point order.
        feat = feat[inverse]

        # ffn
        feat = self.proj(feat)
        feat = self.proj_drop(feat)
        point.feat = feat
        return point
800
+
801
+
802
class MLP(torch.nn.Module):
    """Two-layer feed-forward network: linear -> activation -> linear.

    ``hidden_channels`` and ``out_channels`` fall back to ``in_channels``
    when left unset (or falsy). ``drop`` is accepted for signature
    compatibility only; no dropout layer is created or applied.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_channels=None,
        act_layer=torch.nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        if not out_channels:
            out_channels = in_channels
        if not hidden_channels:
            hidden_channels = in_channels
        self.fc1 = torch.nn.Linear(in_channels, hidden_channels)
        self.act = act_layer()
        self.fc2 = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x):
        """Return ``fc2(act(fc1(x)))``."""
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
826
+
827
+
828
class Block(PointModule):
    """Transformer block: conditional positional encoding, attention, MLP.

    Each of the three stages is wrapped in a residual connection. The
    attention and MLP stages are normalised before the layer when
    ``pre_norm`` is true and after the residual sum otherwise.
    """

    def __init__(
        self,
        channels,
        num_heads,
        patch_size=48,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        norm_layer=torch.nn.LayerNorm,
        act_layer=torch.nn.GELU,
        pre_norm=True,
        order_index=0,
        cpe_indice_key=None,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    ):
        super().__init__()
        self.channels = channels
        self.pre_norm = pre_norm

        # Positional encoding: sparse 3x3x3 conv, then linear, then norm.
        cpe_conv = spconv.SubMConv3d(
            channels,
            channels,
            kernel_size=3,
            bias=True,
            indice_key=cpe_indice_key,
        )
        cpe_linear = torch.nn.Linear(channels, channels)
        cpe_norm = norm_layer(channels)
        self.cpe = PointSequential(cpe_conv, cpe_linear, cpe_norm)

        self.norm1 = PointSequential(norm_layer(channels))
        self.attn = SerializedAttention(
            channels=channels,
            patch_size=patch_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            order_index=order_index,
            enable_rpe=enable_rpe,
            enable_flash=enable_flash,
            upcast_attention=upcast_attention,
            upcast_softmax=upcast_softmax,
        )
        self.norm2 = PointSequential(norm_layer(channels))

        feed_forward = MLP(
            in_channels=channels,
            hidden_channels=int(channels * mlp_ratio),
            out_channels=channels,
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.mlp = PointSequential(feed_forward)

        # Stochastic depth is only instantiated for a positive rate.
        if drop_path > 0.0:
            path_dropper = DropPath(drop_path)
        else:
            path_dropper = torch.nn.Identity()
        self.drop_path = PointSequential(path_dropper)

    def _residual(self, point, norm, layer):
        """Apply ``layer`` with drop-path and a residual sum, normalising per ``pre_norm``."""
        skip = point.feat
        if self.pre_norm:
            point = norm(point)
        point = self.drop_path(layer(point))
        point.feat = skip + point.feat
        if not self.pre_norm:
            point = norm(point)
        return point

    def forward(self, point: Point):
        """Run the block on ``point`` and return it with features updated."""
        # Positional encoding stage: plain residual, no drop-path.
        skip = point.feat
        point = self.cpe(point)
        point.feat = skip + point.feat

        point = self._residual(point, self.norm1, self.attn)
        point = self._residual(point, self.norm2, self.mlp)

        # Keep the sparse tensor in sync with the final dense features.
        point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
        return point
916
+
917
+
918
class DropPath(torch.nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob: probability of zeroing a sample's path. ``None`` (the
            default) and ``0.0`` both disable dropping.
        scale_by_keep: divide surviving samples by the keep probability.
    """

    def __init__(self, drop_prob=None, scale_by_keep=True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def _drop_path(
        self,
        x,
        drop_prob: float = 0.0,
        training: bool = False,
        scale_by_keep: bool = True,
    ):
        """Zero whole samples of ``x`` (first dimension) with probability ``drop_prob``.

        Returns ``x`` unchanged outside training or when ``drop_prob`` is
        ``None`` or zero. Otherwise one Bernoulli draw is made per entry of
        the first dimension and broadcast over the remaining dimensions.
        """
        # ``None`` is the constructor default; treating it as "no drop" avoids
        # the ``1 - None`` TypeError the default would otherwise raise in
        # training mode.
        if drop_prob is None or drop_prob == 0.0 or not training:
            return x
        keep_prob = 1 - drop_prob
        shape = (x.shape[0],) + (1,) * (
            x.ndim - 1
        )  # work with diff dim tensors, not just 2D ConvNets
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        # Rescale survivors; skipped when nothing can survive (keep_prob of 0).
        if keep_prob > 0.0 and scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor

    def forward(self, x):
        """Apply stochastic depth to ``x`` using the module's settings."""
        return self._drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
955
+
956
+
957
class SerializedPooling(PointModule):
    """Downsample a serialized point cloud by merging points that share a coarser code.

    Serialization codes are right-shifted by ``3 * pooling_depth`` bits and
    points with equal shifted codes form one cluster. Features are projected
    and reduced per cluster with ``reduce``; coordinates are averaged.

    Args:
        in_channels: input feature width.
        out_channels: output feature width of the projection.
        stride: power-of-two pooling stride (2, 4, 8, ...).
        norm_layer: optional factory called as ``norm_layer(out_channels)``.
        act_layer: optional factory called as ``act_layer()``.
        reduce: one of ``"sum"``, ``"mean"``, ``"min"``, ``"max"``.
        shuffle_orders: randomly permute the serialization orders afterwards.
        traceable: store the cluster map and the parent point on the result.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=2,
        norm_layer=None,
        act_layer=None,
        reduce="max",
        shuffle_orders=True,
        traceable=True,  # record parent and cluster
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        assert stride == 2 ** (math.ceil(stride) - 1).bit_length()  # 2, 4, 8
        # TODO: add support to grid pool (any stride)
        self.stride = stride
        assert reduce in ["sum", "mean", "min", "max"]
        self.reduce = reduce
        self.shuffle_orders = shuffle_orders
        self.traceable = traceable

        self.proj = torch.nn.Linear(in_channels, out_channels)
        # Always define both attributes: forward() tests them against None,
        # which would raise AttributeError if they were left unset.
        self.norm = None
        self.act = None
        if norm_layer is not None:
            self.norm = PointSequential(norm_layer(out_channels))
        if act_layer is not None:
            self.act = PointSequential(act_layer())

    def forward(self, point: Point):
        """Return a new, coarser ``Point`` built from ``point``."""
        # Validate before the first use of ``serialized_depth`` so a missing
        # serialization is reported with this message.
        assert {
            "serialized_code",
            "serialized_order",
            "serialized_inverse",
            "serialized_depth",
        }.issubset(
            point.keys()
        ), "Run point.serialization() point cloud before SerializedPooling"

        pooling_depth = (math.ceil(self.stride) - 1).bit_length()
        if pooling_depth > point.serialized_depth:
            pooling_depth = 0

        # Three code bits are dropped per pooling level.
        code = point.serialized_code >> pooling_depth * 3
        code_, cluster, counts = torch.unique(
            code[0],
            sorted=True,
            return_inverse=True,
            return_counts=True,
        )
        # indices of point sorted by cluster, for torch_scatter.segment_csr
        _, indices = torch.sort(cluster)
        # index pointer for sorted point, for torch_scatter.segment_csr
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        # head_indices of each cluster, for reduce attr e.g. code, batch
        head_indices = indices[idx_ptr[:-1]]
        # generate down code, order, inverse
        code = code[:, head_indices]
        order = torch.argsort(code)
        inverse = torch.zeros_like(order).scatter_(
            dim=1,
            index=order,
            src=torch.arange(0, code.shape[1], device=order.device).repeat(
                code.shape[0], 1
            ),
        )

        if self.shuffle_orders:
            perm = torch.randperm(code.shape[0])
            code = code[perm]
            order = order[perm]
            inverse = inverse[perm]

        # collect information
        point_dict = addict.Dict(
            feat=torch_scatter.segment_csr(
                self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
            ),
            coord=torch_scatter.segment_csr(
                point.coord[indices], idx_ptr, reduce="mean"
            ),
            grid_coord=point.grid_coord[head_indices] >> pooling_depth,
            serialized_code=code,
            serialized_order=order,
            serialized_inverse=inverse,
            serialized_depth=point.serialized_depth - pooling_depth,
            batch=point.batch[head_indices],
        )

        if "condition" in point.keys():
            point_dict["condition"] = point.condition
        if "context" in point.keys():
            point_dict["context"] = point.context

        if self.traceable:
            point_dict["pooling_inverse"] = cluster
            point_dict["pooling_parent"] = point

        point = Point(point_dict)
        # Fix: Expected more than 1 value per channel when training
        if self.norm is not None and point.feat.size(0) != 1:
            point = self.norm(point)
        if self.act is not None:
            point = self.act(point)
        point.sparsify()

        return point
1064
+
1065
+
1066
class SerializedUnpooling(PointModule):
    """Upsample pooled features back onto the parent point cloud.

    The pooled point and its recorded parent are each projected to
    ``out_channels``; the pooled features are then gathered through the
    stored cluster map and added to the parent's features.
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        norm_layer=None,
        act_layer=None,
        traceable=False,  # record parent and cluster
    ):
        super().__init__()
        coarse_linear = torch.nn.Linear(in_channels, out_channels)
        self.proj = PointSequential(coarse_linear)
        skip_linear = torch.nn.Linear(skip_channels, out_channels)
        self.proj_skip = PointSequential(skip_linear)

        use_norm = norm_layer is not None
        use_act = act_layer is not None
        if use_norm:
            self.proj.add(norm_layer(out_channels))
            self.proj_skip.add(norm_layer(out_channels))
        if use_act:
            self.proj.add(act_layer())
            self.proj_skip.add(act_layer())

        self.traceable = traceable

    def forward(self, point):
        """Return the parent point with the upsampled features added in.

        ``point`` must carry ``pooling_parent`` and ``pooling_inverse``;
        both entries are removed from it.
        """
        assert "pooling_parent" in point.keys()
        assert "pooling_inverse" in point.keys()
        fine = point.pop("pooling_parent")
        cluster_of = point.pop("pooling_inverse")

        point = self.proj(point)
        fine = self.proj_skip(fine)
        # Each parent point receives the feature of the cluster it fell into.
        fine.feat = fine.feat + point.feat[cluster_of]

        if self.traceable:
            fine["unpooling_parent"] = point
        return fine
1102
+
1103
+
1104
class Embedding(PointModule):
    """Stem that lifts input features to ``embed_channels`` with a sparse conv.

    The stem is a ``PointSequential`` holding the convolution under the key
    ``"conv"``, optionally followed by ``"norm"`` and ``"act"`` entries.
    """

    def __init__(
        self,
        in_channels,
        embed_channels,
        norm_layer=None,
        act_layer=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.embed_channels = embed_channels

        # TODO: check remove spconv
        stem_conv = spconv.SubMConv3d(
            in_channels,
            embed_channels,
            kernel_size=5,
            padding=1,
            bias=False,
            indice_key="stem",
        )
        self.stem = PointSequential(conv=stem_conv)

        use_norm = norm_layer is not None
        if use_norm:
            self.stem.add(norm_layer(embed_channels), name="norm")
        use_act = act_layer is not None
        if use_act:
            self.stem.add(act_layer(), name="act")

    def forward(self, point: Point):
        """Run ``point`` through the stem and return the result."""
        return self.stem(point)
1135
+
1136
+
1137
class PointTransformerV3(PointModule):
    """Encoder/decoder point-cloud backbone assembled from ``Block`` stages.

    The encoder has ``len(enc_depths)`` stages; every stage after the first
    starts with a ``SerializedPooling`` layer. Unless ``cls_mode`` is set, a
    decoder with one stage fewer mirrors it, each stage starting with a
    ``SerializedUnpooling`` layer that merges the matching encoder output.

    NOTE(review): the default ``order=("cord")`` is a plain string (the
    parentheses do not form a tuple); the constructor wraps strings into a
    one-element list.
    """

    def __init__(
        self,
        in_channels=6,
        order=("cord"),
        stride=(2, 2, 2, 2),
        enc_depths=(2, 2, 2, 6, 2),
        enc_channels=(32, 64, 128, 256, 512),
        enc_num_head=(2, 4, 8, 16, 32),
        enc_patch_size=(1024, 1024, 1024, 1024, 1024),
        dec_depths=(2, 2, 2, 2),
        dec_channels=(64, 64, 128, 256),
        dec_num_head=(4, 4, 8, 16),
        dec_patch_size=(1024, 1024, 1024, 1024),
        mlp_ratio=4,
        grid_size=0.01,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.3,
        pre_norm=True,
        shuffle_orders=True,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=False,
        upcast_softmax=False,
        cls_mode=False,
        pdnorm_bn=False,
        pdnorm_ln=False,
        pdnorm_decouple=True,
        pdnorm_adaptive=False,
        pdnorm_affine=True,
        pdnorm_conditions=("ScanNet", "S3DIS", "Structured3D"),
    ):
        super().__init__()
        self.num_stages = len(enc_depths)
        self.order = [order] if isinstance(order, str) else order
        self.cls_mode = cls_mode
        self.shuffle_orders = shuffle_orders
        self.grid_size = grid_size

        # Per-stage settings must agree in length; decoder settings are only
        # checked when a decoder is built.
        assert self.num_stages == len(stride) + 1
        assert self.num_stages == len(enc_depths)
        assert self.num_stages == len(enc_channels)
        assert self.num_stages == len(enc_num_head)
        assert self.num_stages == len(enc_patch_size)
        assert self.cls_mode or self.num_stages == len(dec_depths) + 1
        assert self.cls_mode or self.num_stages == len(dec_channels) + 1
        assert self.cls_mode or self.num_stages == len(dec_num_head) + 1
        assert self.cls_mode or self.num_stages == len(dec_patch_size) + 1

        # norm layers
        if pdnorm_bn:
            bn_layer = functools.partial(
                PDNorm,
                norm_layer=functools.partial(
                    torch.nn.BatchNorm1d, eps=1e-3, momentum=0.01, affine=pdnorm_affine
                ),
                conditions=pdnorm_conditions,
                decouple=pdnorm_decouple,
                adaptive=pdnorm_adaptive,
            )
        else:
            bn_layer = functools.partial(torch.nn.BatchNorm1d, eps=1e-3, momentum=0.01)
        if pdnorm_ln:
            ln_layer = functools.partial(
                PDNorm,
                norm_layer=functools.partial(
                    torch.nn.LayerNorm, elementwise_affine=pdnorm_affine
                ),
                conditions=pdnorm_conditions,
                decouple=pdnorm_decouple,
                adaptive=pdnorm_adaptive,
            )
        else:
            ln_layer = torch.nn.LayerNorm
        # activation layers
        act_layer = torch.nn.GELU
        self.embedding = Embedding(
            in_channels=in_channels,
            embed_channels=enc_channels[0],
            norm_layer=bn_layer,
            act_layer=act_layer,
        )

        # encoder
        # Drop-path rate grows linearly from 0 to ``drop_path`` over all
        # encoder blocks.
        enc_drop_path = [
            x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))
        ]
        self.enc = PointSequential(name="encoder")
        for s in range(self.num_stages):
            # Rates belonging to the blocks of stage ``s``.
            enc_drop_path_ = enc_drop_path[
                sum(enc_depths[:s]) : sum(enc_depths[: s + 1])
            ]
            enc = PointSequential(name="encoder_layer_%d" % s)
            if s > 0:
                enc.add(
                    SerializedPooling(
                        in_channels=enc_channels[s - 1],
                        out_channels=enc_channels[s],
                        stride=stride[s - 1],
                        norm_layer=bn_layer,
                        act_layer=act_layer,
                    ),
                    name="down",
                )
            for i in range(enc_depths[s]):
                enc.add(
                    Block(
                        channels=enc_channels[s],
                        num_heads=enc_num_head[s],
                        patch_size=enc_patch_size[s],
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        attn_drop=attn_drop,
                        proj_drop=proj_drop,
                        drop_path=enc_drop_path_[i],
                        norm_layer=ln_layer,
                        act_layer=act_layer,
                        pre_norm=pre_norm,
                        # Blocks cycle through the serialization orders.
                        order_index=i % len(self.order),
                        cpe_indice_key=f"stage{s}",
                        enable_rpe=enable_rpe,
                        enable_flash=enable_flash,
                        upcast_attention=upcast_attention,
                        upcast_softmax=upcast_softmax,
                    ),
                    name=f"block{i}",
                )
            if len(enc) != 0:
                self.enc.add(module=enc, name=f"enc{s}")

        # decoder
        if not self.cls_mode:
            dec_drop_path = [
                x.item() for x in torch.linspace(0, drop_path, sum(dec_depths))
            ]
            self.dec = PointSequential(name="decoder")
            # Append the deepest encoder width so that ``dec_channels[s + 1]``
            # is defined for the first (deepest) decoder stage.
            dec_channels = list(dec_channels) + [enc_channels[-1]]
            # Stages are built, and therefore executed, from deepest to
            # shallowest.
            for s in reversed(range(self.num_stages - 1)):
                dec_drop_path_ = dec_drop_path[
                    sum(dec_depths[:s]) : sum(dec_depths[: s + 1])
                ]
                dec_drop_path_.reverse()
                dec = PointSequential(name="decoder_layer_%d" % s)
                dec.add(
                    SerializedUnpooling(
                        in_channels=dec_channels[s + 1],
                        skip_channels=enc_channels[s],
                        out_channels=dec_channels[s],
                        norm_layer=bn_layer,
                        act_layer=act_layer,
                    ),
                    name="up",
                )
                for i in range(dec_depths[s]):
                    dec.add(
                        Block(
                            channels=dec_channels[s],
                            num_heads=dec_num_head[s],
                            patch_size=dec_patch_size[s],
                            mlp_ratio=mlp_ratio,
                            qkv_bias=qkv_bias,
                            qk_scale=qk_scale,
                            attn_drop=attn_drop,
                            proj_drop=proj_drop,
                            drop_path=dec_drop_path_[i],
                            norm_layer=ln_layer,
                            act_layer=act_layer,
                            pre_norm=pre_norm,
                            order_index=i % len(self.order),
                            cpe_indice_key=f"stage{s}",
                            enable_rpe=enable_rpe,
                            enable_flash=enable_flash,
                            upcast_attention=upcast_attention,
                            upcast_softmax=upcast_softmax,
                        ),
                        name=f"block{i}",
                    )
                self.dec.add(module=dec, name=f"dec{s}")

    def forward(self, batch, feat, coord):
        """
        Run the backbone on one batched point cloud.

        A ``Point`` is built from the three inputs plus ``self.grid_size``,
        then serialized and sparsified before being passed through the
        embedding, the encoder and (unless ``cls_mode``) the decoder.

        Args:
            batch: per-point batch indices.
            feat: per-point input features.
            coord: per-point coordinates.

        Each input has ``squeeze(dim=0)`` applied, so a leading dimension of
        size 1 is presumably expected -- TODO confirm against callers.
        See https://github.com/Pointcept/Pointcept?tab=readme-ov-file#offset
        for the "offset"/"batch" convention.

        Returns:
            The final ``point.feat`` with a leading dimension re-added via
            ``unsqueeze(dim=0)``.
        """
        point = Point(
            {
                "batch": batch.squeeze(dim=0),
                "feat": feat.squeeze(dim=0),
                "coord": coord.squeeze(dim=0),
                "grid_size": self.grid_size,
            }
        )
        point.serialization(order=self.order, shuffle_orders=self.shuffle_orders)
        point.sparsify()

        point = self.embedding(point)
        point = self.enc(point)
        if not self.cls_mode:
            point = self.dec(point)

        return point.feat.unsqueeze(dim=0)
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu121
2
+ easydict
3
+ gradio
4
+ numpy<2.0.0
5
+ opencv-python
6
+ pillow
7
+ scipy
8
+ torch==2.2.2
9
+ torchvision==0.17.2
10
+
11
+ addict
12
+ spconv-cu121
13
+ torch_scatter
14
+