Spaces:
Running on Zero
Running on Zero
Commit ·
83d5461
verified ·
0
Parent(s):
fix: reinitialize the repo.
Browse files- .gitattributes +37 -0
- .gitignore +183 -0
- .gitmodules +3 -0
- ARTICLE.md +25 -0
- LICENSE +35 -0
- README.md +17 -0
- app.py +241 -0
- assets/CENTERS.pkl +3 -0
- assets/NYC-HghtFld.png +3 -0
- assets/NYC-SegMap.png +3 -0
- gaussiancity/__init__.py +0 -0
- gaussiancity/extensions/__init__.py +0 -0
- gaussiancity/extensions/diff_gaussian_rasterization/CMakeLists.txt +36 -0
- gaussiancity/extensions/diff_gaussian_rasterization/LICENSE.md +83 -0
- gaussiancity/extensions/diff_gaussian_rasterization/__init__.py +426 -0
- gaussiancity/extensions/diff_gaussian_rasterization/bindings.cpp +19 -0
- gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/auxiliary.h +169 -0
- gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/backward.cu +622 -0
- gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/backward.h +41 -0
- gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/config.h +19 -0
- gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/forward.cu +376 -0
- gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/forward.h +43 -0
- gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/rasterizer.h +52 -0
- gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/rasterizer_impl.cu +339 -0
- gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/rasterizer_impl.h +70 -0
- gaussiancity/extensions/diff_gaussian_rasterization/rasterize_points.cu +173 -0
- gaussiancity/extensions/diff_gaussian_rasterization/rasterize_points.h +46 -0
- gaussiancity/extensions/diff_gaussian_rasterization/setup.py +40 -0
- gaussiancity/extensions/diff_gaussian_rasterization/third_party/glm +1 -0
- gaussiancity/extensions/diff_gaussian_rasterization/third_party/stbi_image_write.h +1724 -0
- gaussiancity/extensions/grid_encoder/__init__.py +193 -0
- gaussiancity/extensions/grid_encoder/bindings.cpp +40 -0
- gaussiancity/extensions/grid_encoder/grid_encoder_ext.cu +605 -0
- gaussiancity/extensions/grid_encoder/setup.py +39 -0
- gaussiancity/extensions/voxlib/__init__.py +12 -0
- gaussiancity/extensions/voxlib/bindings.cpp +41 -0
- gaussiancity/extensions/voxlib/maps_to_volume.cu +142 -0
- gaussiancity/extensions/voxlib/points_to_volume.cu +79 -0
- gaussiancity/extensions/voxlib/ray_voxel_intersection.cu +332 -0
- gaussiancity/extensions/voxlib/setup.py +32 -0
- gaussiancity/extensions/voxlib/voxlib_common.h +83 -0
- gaussiancity/generator.py +536 -0
- gaussiancity/inference.py +582 -0
- gaussiancity/pt_v3.py +1344 -0
- requirements.txt +14 -0
.gitattributes
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.whl filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ---> Python
|
| 2 |
+
# Byte-compiled / optimized / DLL files
|
| 3 |
+
__pycache__/
|
| 4 |
+
*.py[cod]
|
| 5 |
+
*$py.class
|
| 6 |
+
|
| 7 |
+
# C extensions
|
| 8 |
+
*.so
|
| 9 |
+
|
| 10 |
+
# Distribution / packaging
|
| 11 |
+
.Python
|
| 12 |
+
build/
|
| 13 |
+
develop-eggs/
|
| 14 |
+
dist/
|
| 15 |
+
downloads/
|
| 16 |
+
eggs/
|
| 17 |
+
.eggs/
|
| 18 |
+
lib/
|
| 19 |
+
lib64/
|
| 20 |
+
parts/
|
| 21 |
+
sdist/
|
| 22 |
+
var/
|
| 23 |
+
wheels/
|
| 24 |
+
share/python-wheels/
|
| 25 |
+
*.egg-info/
|
| 26 |
+
.installed.cfg
|
| 27 |
+
*.egg
|
| 28 |
+
MANIFEST
|
| 29 |
+
|
| 30 |
+
# PyInstaller
|
| 31 |
+
# Usually these files are written by a python script from a template
|
| 32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 33 |
+
*.manifest
|
| 34 |
+
*.spec
|
| 35 |
+
|
| 36 |
+
# Installer logs
|
| 37 |
+
pip-log.txt
|
| 38 |
+
pip-delete-this-directory.txt
|
| 39 |
+
|
| 40 |
+
# Unit test / coverage reports
|
| 41 |
+
htmlcov/
|
| 42 |
+
.tox/
|
| 43 |
+
.nox/
|
| 44 |
+
.coverage
|
| 45 |
+
.coverage.*
|
| 46 |
+
.cache
|
| 47 |
+
nosetests.xml
|
| 48 |
+
coverage.xml
|
| 49 |
+
*.cover
|
| 50 |
+
*.py,cover
|
| 51 |
+
.hypothesis/
|
| 52 |
+
.pytest_cache/
|
| 53 |
+
cover/
|
| 54 |
+
|
| 55 |
+
# Translations
|
| 56 |
+
*.mo
|
| 57 |
+
*.pot
|
| 58 |
+
|
| 59 |
+
# Django stuff:
|
| 60 |
+
*.log
|
| 61 |
+
local_settings.py
|
| 62 |
+
db.sqlite3
|
| 63 |
+
db.sqlite3-journal
|
| 64 |
+
|
| 65 |
+
# Flask stuff:
|
| 66 |
+
instance/
|
| 67 |
+
.webassets-cache
|
| 68 |
+
|
| 69 |
+
# Scrapy stuff:
|
| 70 |
+
.scrapy
|
| 71 |
+
|
| 72 |
+
# Sphinx documentation
|
| 73 |
+
docs/_build/
|
| 74 |
+
|
| 75 |
+
# PyBuilder
|
| 76 |
+
.pybuilder/
|
| 77 |
+
target/
|
| 78 |
+
|
| 79 |
+
# Jupyter Notebook
|
| 80 |
+
.ipynb_checkpoints
|
| 81 |
+
|
| 82 |
+
# IPython
|
| 83 |
+
profile_default/
|
| 84 |
+
ipython_config.py
|
| 85 |
+
|
| 86 |
+
# pyenv
|
| 87 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 88 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 89 |
+
# .python-version
|
| 90 |
+
|
| 91 |
+
# pipenv
|
| 92 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 93 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 94 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 95 |
+
# install all needed dependencies.
|
| 96 |
+
#Pipfile.lock
|
| 97 |
+
|
| 98 |
+
# poetry
|
| 99 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 100 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 101 |
+
# commonly ignored for libraries.
|
| 102 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 103 |
+
#poetry.lock
|
| 104 |
+
|
| 105 |
+
# pdm
|
| 106 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 107 |
+
#pdm.lock
|
| 108 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 109 |
+
# in version control.
|
| 110 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 111 |
+
.pdm.toml
|
| 112 |
+
|
| 113 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 114 |
+
__pypackages__/
|
| 115 |
+
|
| 116 |
+
# Celery stuff
|
| 117 |
+
celerybeat-schedule
|
| 118 |
+
celerybeat.pid
|
| 119 |
+
|
| 120 |
+
# SageMath parsed files
|
| 121 |
+
*.sage.py
|
| 122 |
+
|
| 123 |
+
# Environments
|
| 124 |
+
.env
|
| 125 |
+
.venv
|
| 126 |
+
env/
|
| 127 |
+
venv/
|
| 128 |
+
ENV/
|
| 129 |
+
env.bak/
|
| 130 |
+
venv.bak/
|
| 131 |
+
|
| 132 |
+
# Spyder project settings
|
| 133 |
+
.spyderproject
|
| 134 |
+
.spyproject
|
| 135 |
+
|
| 136 |
+
# Rope project settings
|
| 137 |
+
.ropeproject
|
| 138 |
+
|
| 139 |
+
# mkdocs documentation
|
| 140 |
+
/site
|
| 141 |
+
|
| 142 |
+
# mypy
|
| 143 |
+
.mypy_cache/
|
| 144 |
+
.dmypy.json
|
| 145 |
+
dmypy.json
|
| 146 |
+
|
| 147 |
+
# Pyre type checker
|
| 148 |
+
.pyre/
|
| 149 |
+
|
| 150 |
+
# pytype static type analyzer
|
| 151 |
+
.pytype/
|
| 152 |
+
|
| 153 |
+
# Cython debug symbols
|
| 154 |
+
cython_debug/
|
| 155 |
+
|
| 156 |
+
# PyCharm
|
| 157 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 158 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 159 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 160 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 161 |
+
idea/
|
| 162 |
+
|
| 163 |
+
# VSCode
|
| 164 |
+
.vscode/
|
| 165 |
+
|
| 166 |
+
# ---> JupyterNotebooks
|
| 167 |
+
# gitignore template for Jupyter Notebooks
|
| 168 |
+
# website: http://jupyter.org/
|
| 169 |
+
|
| 170 |
+
.ipynb_checkpoints
|
| 171 |
+
*/.ipynb_checkpoints/*
|
| 172 |
+
|
| 173 |
+
# IPython
|
| 174 |
+
profile_default/
|
| 175 |
+
ipython_config.py
|
| 176 |
+
|
| 177 |
+
# User data
|
| 178 |
+
configs/
|
| 179 |
+
data/
|
| 180 |
+
notebooks/
|
| 181 |
+
output/
|
| 182 |
+
flagged/
|
| 183 |
+
*.pth
|
.gitmodules
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[submodule "gaussiancity/extensions/diff_gaussian_rasterization/third_party/glm"]
|
| 2 |
+
path = gaussiancity/extensions/diff_gaussian_rasterization/third_party/glm
|
| 3 |
+
url = https://github.com/g-truc/glm.git
|
ARTICLE.md
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### Citation 📝
|
| 2 |
+
|
| 3 |
+
If our work is useful for your research, please consider citing:
|
| 4 |
+
|
| 5 |
+
```bibtex
|
| 6 |
+
@inproceedings{xie2025gaussiancity,
|
| 7 |
+
title = {Generative Gaussian Splatting for Unbounded 3{D} City Generation},
|
| 8 |
+
author = {Xie, Haozhe and
|
| 9 |
+
Chen, Zhaoxi and
|
| 10 |
+
Hong, Fangzhou and
|
| 11 |
+
Liu, Ziwei},
|
| 12 |
+
booktitle = {CVPR},
|
| 13 |
+
year = {2025}
|
| 14 |
+
}
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
### License 📋
|
| 18 |
+
|
| 19 |
+
This project is licensed under [S-Lab License 1.0](https://huggingface.co/hzxie/city-dreamer/blob/main/LICENSE).
|
| 20 |
+
Redistribution and use for non-commercial purposes should follow this license.
|
| 21 |
+
|
| 22 |
+

|
| 23 |
+
|
| 24 |
+
---
|
| 25 |
+
|
LICENSE
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
S-Lab License 1.0
|
| 2 |
+
|
| 3 |
+
Copyright 2025 S-Lab
|
| 4 |
+
|
| 5 |
+
Redistribution and use for non-commercial purpose in source and
|
| 6 |
+
binary forms, with or without modification, are permitted provided
|
| 7 |
+
that the following conditions are met:
|
| 8 |
+
|
| 9 |
+
1. Redistributions of source code must retain the above copyright
|
| 10 |
+
notice, this list of conditions and the following disclaimer.
|
| 11 |
+
|
| 12 |
+
2. Redistributions in binary form must reproduce the above copyright
|
| 13 |
+
notice, this list of conditions and the following disclaimer in
|
| 14 |
+
the documentation and/or other materials provided with the
|
| 15 |
+
distribution.
|
| 16 |
+
|
| 17 |
+
3. Neither the name of the copyright holder nor the names of its
|
| 18 |
+
contributors may be used to endorse or promote products derived
|
| 19 |
+
from this software without specific prior written permission.
|
| 20 |
+
|
| 21 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
| 22 |
+
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
| 23 |
+
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
| 24 |
+
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
| 25 |
+
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
| 26 |
+
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
| 27 |
+
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
| 28 |
+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
| 29 |
+
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
| 30 |
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 31 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 32 |
+
|
| 33 |
+
In the event that redistribution and/or use for commercial purpose in
|
| 34 |
+
source or binary forms, with or without modification is required,
|
| 35 |
+
please contact the contributor(s) of the work.
|
README.md
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: GaussianCity
|
| 3 |
+
emoji: 🏙️
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.20.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
Official demo for **[Generative Gaussian Splatting for Unbounded 3D City Generation](https://github.com/hzxie/GaussianCity) (CVPR 2025).**
|
| 13 |
+
|
| 14 |
+
- 🔥 GaussianCity is a unbounded 3D city generator based on 3D Gaussian Splatting.
|
| 15 |
+
- 🤗 Try GaussianCity to generate photolistic 3D cities.
|
| 16 |
+
- ⚠️ Due to the limited computational resources at Hugging Face, this demo only generates **A SINGLE IMAGE** based on the New York City layout.
|
| 17 |
+
|
app.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# @File: app.py
|
| 4 |
+
# @Author: Haozhe Xie
|
| 5 |
+
# @Date: 2024-03-02 16:30:00
|
| 6 |
+
# @Last Modified by: Haozhe Xie
|
| 7 |
+
# @Last Modified at: 2024-10-13 15:36:50
|
| 8 |
+
# @Email: root@haozhexie.com
|
| 9 |
+
|
| 10 |
+
import gradio as gr
|
| 11 |
+
import logging
|
| 12 |
+
import numpy as np
|
| 13 |
+
import os
|
| 14 |
+
import pickle
|
| 15 |
+
import ssl
|
| 16 |
+
import subprocess
|
| 17 |
+
import sys
|
| 18 |
+
import urllib.request
|
| 19 |
+
|
| 20 |
+
from PIL import Image
|
| 21 |
+
|
| 22 |
+
# Reinstall PyTorch with CUDA 11.8 (Default version is 12.1)
|
| 23 |
+
# subprocess.call(
|
| 24 |
+
# [
|
| 25 |
+
# "pip",
|
| 26 |
+
# "install",
|
| 27 |
+
# "torch==2.2.2",
|
| 28 |
+
# "torchvision==0.17.2",
|
| 29 |
+
# "--index-url",
|
| 30 |
+
# "https://download.pytorch.org/whl/cu118",
|
| 31 |
+
# ]
|
| 32 |
+
# )
|
| 33 |
+
import torch
|
| 34 |
+
|
| 35 |
+
# Create a dummy decorator for Non-ZeroGPU environments
|
| 36 |
+
if os.environ.get("SPACES_ZERO_GPU") is not None:
|
| 37 |
+
import spaces
|
| 38 |
+
else:
|
| 39 |
+
|
| 40 |
+
class spaces:
|
| 41 |
+
@staticmethod
|
| 42 |
+
def GPU(func):
|
| 43 |
+
# This is a dummy wrapper that just calls the function.
|
| 44 |
+
def wrapper(*args, **kwargs):
|
| 45 |
+
return func(*args, **kwargs)
|
| 46 |
+
|
| 47 |
+
return wrapper
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# Fix: ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed
|
| 51 |
+
ssl._create_default_https_context = ssl._create_unverified_context
|
| 52 |
+
# Import GaussianCity modules
|
| 53 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), "gaussiancity"))
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _get_output(cmd):
|
| 57 |
+
try:
|
| 58 |
+
return subprocess.check_output(cmd).decode("utf-8")
|
| 59 |
+
except Exception as ex:
|
| 60 |
+
logging.exception(ex)
|
| 61 |
+
|
| 62 |
+
return None
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def install_cuda_toolkit():
|
| 66 |
+
# CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
|
| 67 |
+
CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
|
| 68 |
+
CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
|
| 69 |
+
subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
|
| 70 |
+
subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
|
| 71 |
+
subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
|
| 72 |
+
|
| 73 |
+
os.environ["CUDA_HOME"] = "/usr/local/cuda"
|
| 74 |
+
os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
|
| 75 |
+
os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
|
| 76 |
+
os.environ["CUDA_HOME"],
|
| 77 |
+
"" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
|
| 78 |
+
)
|
| 79 |
+
# Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
|
| 80 |
+
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def setup_runtime_env():
|
| 84 |
+
logging.info("Python Version: %s" % _get_output(["python", "--version"]))
|
| 85 |
+
logging.info("CUDA Version: %s" % _get_output(["nvcc", "--version"]))
|
| 86 |
+
logging.info("GCC Version: %s" % _get_output(["gcc", "--version"]))
|
| 87 |
+
logging.info("CUDA is available: %s" % torch.cuda.is_available())
|
| 88 |
+
logging.info("CUDA Device Capability: %s" % (torch.cuda.get_device_capability(),))
|
| 89 |
+
|
| 90 |
+
# Install Pre-compiled CUDA extensions (Not working)
|
| 91 |
+
# Ref: https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/110
|
| 92 |
+
#
|
| 93 |
+
# ext_dir = os.path.join(os.path.dirname(__file__), "wheels")
|
| 94 |
+
# for e in os.listdir(ext_dir):
|
| 95 |
+
# logging.info("Installing Extensions from %s" % e)
|
| 96 |
+
# subprocess.call(
|
| 97 |
+
# ["pip", "install", os.path.join(ext_dir, e)], stderr=subprocess.STDOUT
|
| 98 |
+
# )
|
| 99 |
+
# Compile CUDA extensions
|
| 100 |
+
ext_dir = os.path.join(os.path.dirname(__file__), "gaussiancity", "extensions")
|
| 101 |
+
for e in os.listdir(ext_dir):
|
| 102 |
+
if os.path.isdir(os.path.join(ext_dir, e)):
|
| 103 |
+
subprocess.call(["pip", "install", "."], cwd=os.path.join(ext_dir, e))
|
| 104 |
+
|
| 105 |
+
logging.info("Installed Python Packages: %s" % _get_output(["pip", "list"]))
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def get_models(file_name):
|
| 109 |
+
import gaussiancity.generator
|
| 110 |
+
|
| 111 |
+
if not os.path.exists(file_name):
|
| 112 |
+
urllib.request.urlretrieve(
|
| 113 |
+
"https://huggingface.co/hzxie/gaussian-city/resolve/main/%s" % file_name,
|
| 114 |
+
file_name,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 118 |
+
ckpt = torch.load(file_name, map_location=torch.device(device), weights_only=False)
|
| 119 |
+
model = gaussiancity.generator.Generator(
|
| 120 |
+
ckpt["cfg"].NETWORK.GAUSSIAN,
|
| 121 |
+
n_classes=ckpt["cfg"].DATASETS.GOOGLE_EARTH.N_CLASSES,
|
| 122 |
+
proj_size=ckpt["cfg"].DATASETS.GOOGLE_EARTH.PROJ_SIZE,
|
| 123 |
+
)
|
| 124 |
+
if torch.cuda.is_available():
|
| 125 |
+
model = torch.nn.DataParallel(model).cuda().eval()
|
| 126 |
+
|
| 127 |
+
model.load_state_dict(ckpt["gaussian_g"], strict=False)
|
| 128 |
+
return model
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def get_city_layout():
|
| 132 |
+
import gaussiancity.inference
|
| 133 |
+
|
| 134 |
+
layout = None
|
| 135 |
+
if os.path.exists("assets/NYC.pkl"):
|
| 136 |
+
with open("assets/NYC.pkl", "rb") as fp:
|
| 137 |
+
layout = pickle.load(fp)
|
| 138 |
+
else:
|
| 139 |
+
td_hf = np.array(Image.open("assets/NYC-HghtFld.png")).astype(np.int32)
|
| 140 |
+
# Fix: nonzero is not supported for tensors with more than INT_MAX elements
|
| 141 |
+
td_hf[td_hf > 500] = 500
|
| 142 |
+
bu_hf = np.zeros_like(td_hf)
|
| 143 |
+
seg_map = np.array(Image.open("assets/NYC-SegMap.png").convert("P")).astype(
|
| 144 |
+
np.int32
|
| 145 |
+
)
|
| 146 |
+
ins_map = gaussiancity.inference.get_instance_seg_map(seg_map.copy())
|
| 147 |
+
pts_map = gaussiancity.inference.get_point_map(seg_map)
|
| 148 |
+
layout = {
|
| 149 |
+
"TD_HF": td_hf,
|
| 150 |
+
"BU_HF": bu_hf,
|
| 151 |
+
"SEG": seg_map,
|
| 152 |
+
"INS": ins_map,
|
| 153 |
+
"PTS": pts_map,
|
| 154 |
+
}
|
| 155 |
+
with open("assets/NYC.pkl", "wb") as fp:
|
| 156 |
+
pickle.dump(layout, fp)
|
| 157 |
+
|
| 158 |
+
centers = None
|
| 159 |
+
if os.path.exists("assets/CENTERS.pkl"):
|
| 160 |
+
with open("assets/CENTERS.pkl", "rb") as fp:
|
| 161 |
+
centers = pickle.load(fp)
|
| 162 |
+
else:
|
| 163 |
+
centers = gaussiancity.inference.get_centers(layout["INS"], layout["TD_HF"])
|
| 164 |
+
with open("assets/CENTERS.pkl", "wb") as fp:
|
| 165 |
+
pickle.dump(centers, fp)
|
| 166 |
+
|
| 167 |
+
layout["CTR"] = centers
|
| 168 |
+
return layout
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@spaces.GPU
|
| 172 |
+
def get_generated_city(radius, altitude, azimuth, map_center):
|
| 173 |
+
logging.info("CUDA is available: %s" % torch.cuda.is_available())
|
| 174 |
+
logging.info("PyTorch is built with CUDA: %s" % torch.version.cuda)
|
| 175 |
+
# The import must be done after CUDA extension compilation
|
| 176 |
+
import gaussiancity.inference
|
| 177 |
+
|
| 178 |
+
return gaussiancity.inference.generate_city(
|
| 179 |
+
get_generated_city.fgm.to("cuda"),
|
| 180 |
+
get_generated_city.bgm.to("cuda"),
|
| 181 |
+
get_generated_city.city_layout,
|
| 182 |
+
map_center,
|
| 183 |
+
map_center,
|
| 184 |
+
radius,
|
| 185 |
+
altitude,
|
| 186 |
+
azimuth,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def main(debug):
|
| 191 |
+
title = "Generative Gaussian Splatting for Unbounded 3D City Generation"
|
| 192 |
+
with open("README.md", "r") as f:
|
| 193 |
+
markdown = f.read()
|
| 194 |
+
desc = markdown[markdown.rfind("---") + 3 :]
|
| 195 |
+
with open("ARTICLE.md", "r") as f:
|
| 196 |
+
arti = f.read()
|
| 197 |
+
|
| 198 |
+
app = gr.Interface(
|
| 199 |
+
get_generated_city,
|
| 200 |
+
[
|
| 201 |
+
gr.Slider(256, 768, value=512, step=4, label="Camera Radius (m)"),
|
| 202 |
+
gr.Slider(256, 768, value=512, step=4, label="Camera Altitude (m)"),
|
| 203 |
+
gr.Slider(0, 360, value=60, step=5, label="Camera Azimuth (°)"),
|
| 204 |
+
gr.Slider(1024, 7168, value=3570, step=4, label="Map Center (px)"),
|
| 205 |
+
],
|
| 206 |
+
[gr.Image(type="numpy", label="Generated City")],
|
| 207 |
+
title=title,
|
| 208 |
+
description=desc,
|
| 209 |
+
article=arti,
|
| 210 |
+
flagging_mode="never",
|
| 211 |
+
)
|
| 212 |
+
app.queue(api_open=False)
|
| 213 |
+
app.launch(debug=debug)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
if __name__ == "__main__":
|
| 217 |
+
logging.basicConfig(
|
| 218 |
+
format="[%(levelname)s] %(asctime)s %(message)s", level=logging.INFO
|
| 219 |
+
)
|
| 220 |
+
logging.info("Environment Variables: %s" % os.environ)
|
| 221 |
+
if _get_output(["nvcc", "--version"]) is None:
|
| 222 |
+
logging.info("Installing CUDA toolkit...")
|
| 223 |
+
install_cuda_toolkit()
|
| 224 |
+
else:
|
| 225 |
+
logging.info("Detected CUDA: %s" % _get_output(["nvcc", "--version"]))
|
| 226 |
+
|
| 227 |
+
logging.info("Compiling CUDA extensions...")
|
| 228 |
+
# setup_runtime_env()
|
| 229 |
+
|
| 230 |
+
logging.info("Downloading pretrained models...")
|
| 231 |
+
fgm = get_models("GaussianCity-Fgnd.pth")
|
| 232 |
+
bgm = get_models("GaussianCity-Bgnd.pth")
|
| 233 |
+
get_generated_city.fgm = fgm
|
| 234 |
+
get_generated_city.bgm = bgm
|
| 235 |
+
|
| 236 |
+
logging.info("Loading New York city layout to RAM...")
|
| 237 |
+
city_layout = get_city_layout()
|
| 238 |
+
get_generated_city.city_layout = city_layout
|
| 239 |
+
|
| 240 |
+
logging.info("Starting the main application...")
|
| 241 |
+
main(os.getenv("DEBUG") == "1")
|
assets/CENTERS.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cad871bfb3997485a6d464c1464c2a551b601a4444913b9ec808530d093eefd8
|
| 3 |
+
size 728474
|
assets/NYC-HghtFld.png
ADDED
|
Git LFS Details
|
assets/NYC-SegMap.png
ADDED
|
Git LFS Details
|
gaussiancity/__init__.py
ADDED
|
File without changes
|
gaussiancity/extensions/__init__.py
ADDED
|
File without changes
|
gaussiancity/extensions/diff_gaussian_rasterization/CMakeLists.txt
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (C) 2023, Inria
|
| 3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
| 4 |
+
# All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# This software is free for non-commercial, research and evaluation use
|
| 7 |
+
# under the terms of the LICENSE.md file.
|
| 8 |
+
#
|
| 9 |
+
# For inquiries contact george.drettakis@inria.fr
|
| 10 |
+
#
|
| 11 |
+
|
| 12 |
+
cmake_minimum_required(VERSION 3.20)
|
| 13 |
+
|
| 14 |
+
project(DiffRast LANGUAGES CUDA CXX)
|
| 15 |
+
|
| 16 |
+
set(CMAKE_CXX_STANDARD 17)
|
| 17 |
+
set(CMAKE_CXX_EXTENSIONS OFF)
|
| 18 |
+
set(CMAKE_CUDA_STANDARD 17)
|
| 19 |
+
|
| 20 |
+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
| 21 |
+
|
| 22 |
+
add_library(CudaRasterizer
|
| 23 |
+
cuda_rasterizer/backward.h
|
| 24 |
+
cuda_rasterizer/backward.cu
|
| 25 |
+
cuda_rasterizer/forward.h
|
| 26 |
+
cuda_rasterizer/forward.cu
|
| 27 |
+
cuda_rasterizer/auxiliary.h
|
| 28 |
+
cuda_rasterizer/rasterizer_impl.cu
|
| 29 |
+
cuda_rasterizer/rasterizer_impl.h
|
| 30 |
+
cuda_rasterizer/rasterizer.h
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
set_target_properties(CudaRasterizer PROPERTIES CUDA_ARCHITECTURES "70;75;86")
|
| 34 |
+
|
| 35 |
+
target_include_directories(CudaRasterizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rasterizer)
|
| 36 |
+
target_include_directories(CudaRasterizer PRIVATE third_party/glm ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
|
gaussiancity/extensions/diff_gaussian_rasterization/LICENSE.md
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Gaussian-Splatting License
|
| 2 |
+
===========================
|
| 3 |
+
|
| 4 |
+
**Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
|
| 5 |
+
The *Software* is in the process of being registered with the Agence pour la Protection des
|
| 6 |
+
Programmes (APP).
|
| 7 |
+
|
| 8 |
+
The *Software* is still being developed by the *Licensor*.
|
| 9 |
+
|
| 10 |
+
*Licensor*'s goal is to allow the research community to use, test and evaluate
|
| 11 |
+
the *Software*.
|
| 12 |
+
|
| 13 |
+
## 1. Definitions
|
| 14 |
+
|
| 15 |
+
*Licensee* means any person or entity that uses the *Software* and distributes
|
| 16 |
+
its *Work*.
|
| 17 |
+
|
| 18 |
+
*Licensor* means the owners of the *Software*, i.e Inria and MPII
|
| 19 |
+
|
| 20 |
+
*Software* means the original work of authorship made available under this
|
| 21 |
+
License ie gaussian-splatting.
|
| 22 |
+
|
| 23 |
+
*Work* means the *Software* and any additions to or derivative works of the
|
| 24 |
+
*Software* that are made available under this License.
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
## 2. Purpose
|
| 28 |
+
This license is intended to define the rights granted to the *Licensee* by
|
| 29 |
+
Licensors under the *Software*.
|
| 30 |
+
|
| 31 |
+
## 3. Rights granted
|
| 32 |
+
|
| 33 |
+
For the above reasons Licensors have decided to distribute the *Software*.
|
| 34 |
+
Licensors grant non-exclusive rights to use the *Software* for research purposes
|
| 35 |
+
to research users (both academic and industrial), free of charge, without right
|
| 36 |
+
to sublicense.. The *Software* may be used "non-commercially", i.e., for research
|
| 37 |
+
and/or evaluation purposes only.
|
| 38 |
+
|
| 39 |
+
Subject to the terms and conditions of this License, you are granted a
|
| 40 |
+
non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
|
| 41 |
+
publicly display, publicly perform and distribute its *Work* and any resulting
|
| 42 |
+
derivative works in any form.
|
| 43 |
+
|
| 44 |
+
## 4. Limitations
|
| 45 |
+
|
| 46 |
+
**4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
|
| 47 |
+
so under this License, (b) you include a complete copy of this License with
|
| 48 |
+
your distribution, and (c) you retain without modification any copyright,
|
| 49 |
+
patent, trademark, or attribution notices that are present in the *Work*.
|
| 50 |
+
|
| 51 |
+
**4.2 Derivative Works.** You may specify that additional or different terms apply
|
| 52 |
+
to the use, reproduction, and distribution of your derivative works of the *Work*
|
| 53 |
+
("Your Terms") only if (a) Your Terms provide that the use limitation in
|
| 54 |
+
Section 2 applies to your derivative works, and (b) you identify the specific
|
| 55 |
+
derivative works that are subject to Your Terms. Notwithstanding Your Terms,
|
| 56 |
+
this License (including the redistribution requirements in Section 3.1) will
|
| 57 |
+
continue to apply to the *Work* itself.
|
| 58 |
+
|
| 59 |
+
**4.3** Any other use without of prior consent of Licensors is prohibited. Research
|
| 60 |
+
users explicitly acknowledge having received from Licensors all information
|
| 61 |
+
allowing to appreciate the adequacy between of the *Software* and their needs and
|
| 62 |
+
to undertake all necessary precautions for its execution and use.
|
| 63 |
+
|
| 64 |
+
**4.4** The *Software* is provided both as a compiled library file and as source
|
| 65 |
+
code. In case of using the *Software* for a publication or other results obtained
|
| 66 |
+
through the use of the *Software*, users are strongly encouraged to cite the
|
| 67 |
+
corresponding publications as explained in the documentation of the *Software*.
|
| 68 |
+
|
| 69 |
+
## 5. Disclaimer
|
| 70 |
+
|
| 71 |
+
THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
|
| 72 |
+
WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
|
| 73 |
+
UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL
|
| 74 |
+
CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
|
| 75 |
+
OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
|
| 76 |
+
USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
|
| 77 |
+
ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
|
| 78 |
+
AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
| 79 |
+
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
|
| 80 |
+
GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
|
| 81 |
+
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
|
| 82 |
+
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
|
| 83 |
+
IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
|
gaussiancity/extensions/diff_gaussian_rasterization/__init__.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# @File: __init__.py
|
| 4 |
+
# @Author: Inria <george.drettakis@inria.fr>
|
| 5 |
+
# @Date: 2024-01-31 19:07:01
|
| 6 |
+
# @Last Modified by: Haozhe Xie
|
| 7 |
+
# @Last Modified at: 2024-05-01 14:14:49
|
| 8 |
+
# @Email: root@haozhexie.com
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
import numpy as np
|
| 12 |
+
import scipy.spatial.transform
|
| 13 |
+
import torch
|
| 14 |
+
import typing
|
| 15 |
+
|
| 16 |
+
import diff_gaussian_rasterization_ext as dgr_ext
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class RasterizeGaussiansFunction(torch.autograd.Function):
|
| 20 |
+
@staticmethod
|
| 21 |
+
def _cpu_deep_copy_tuple(input_tuple):
|
| 22 |
+
copied_tensors = [
|
| 23 |
+
item.cpu().clone() if isinstance(item, torch.Tensor) else item
|
| 24 |
+
for item in input_tuple
|
| 25 |
+
]
|
| 26 |
+
return tuple(copied_tensors)
|
| 27 |
+
|
| 28 |
+
@staticmethod
|
| 29 |
+
def forward(
|
| 30 |
+
ctx,
|
| 31 |
+
means3D,
|
| 32 |
+
means2D,
|
| 33 |
+
sh,
|
| 34 |
+
colors_precomp,
|
| 35 |
+
opacities,
|
| 36 |
+
scales,
|
| 37 |
+
rotations,
|
| 38 |
+
cov3Ds_precomp,
|
| 39 |
+
raster_settings,
|
| 40 |
+
):
|
| 41 |
+
# Restructure arguments the way that the C++ lib expects them
|
| 42 |
+
args = (
|
| 43 |
+
raster_settings.bg,
|
| 44 |
+
means3D,
|
| 45 |
+
colors_precomp,
|
| 46 |
+
opacities,
|
| 47 |
+
scales,
|
| 48 |
+
rotations,
|
| 49 |
+
raster_settings.scale_modifier,
|
| 50 |
+
cov3Ds_precomp,
|
| 51 |
+
raster_settings.view_matrix,
|
| 52 |
+
raster_settings.proj_matrix,
|
| 53 |
+
raster_settings.tanfovx,
|
| 54 |
+
raster_settings.tanfovy,
|
| 55 |
+
raster_settings.img_h,
|
| 56 |
+
raster_settings.img_w,
|
| 57 |
+
sh,
|
| 58 |
+
raster_settings.sh_degree,
|
| 59 |
+
raster_settings.campos,
|
| 60 |
+
raster_settings.prefiltered,
|
| 61 |
+
raster_settings.debug,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# Invoke C++/CUDA rasterizer
|
| 65 |
+
if raster_settings.debug:
|
| 66 |
+
cpu_args = RasterizeGaussiansFunction._cpu_deep_copy_tuple(
|
| 67 |
+
input_tuple=args
|
| 68 |
+
) # Copy them before they can be corrupted
|
| 69 |
+
try:
|
| 70 |
+
(
|
| 71 |
+
num_rendered,
|
| 72 |
+
color,
|
| 73 |
+
radii,
|
| 74 |
+
geom_buffer,
|
| 75 |
+
binning_buffer,
|
| 76 |
+
img_buffer,
|
| 77 |
+
) = dgr_ext.rasterize_gaussians(*args)
|
| 78 |
+
except Exception as ex:
|
| 79 |
+
torch.save(cpu_args, "snapshot_fw.dump")
|
| 80 |
+
print(
|
| 81 |
+
"\nAn error occured in forward. Please forward snapshot_fw.dump for debugging."
|
| 82 |
+
)
|
| 83 |
+
raise ex
|
| 84 |
+
else:
|
| 85 |
+
(
|
| 86 |
+
num_rendered,
|
| 87 |
+
color,
|
| 88 |
+
radii,
|
| 89 |
+
geom_buffer,
|
| 90 |
+
binning_buffer,
|
| 91 |
+
img_buffer,
|
| 92 |
+
) = dgr_ext.rasterize_gaussians(*args)
|
| 93 |
+
|
| 94 |
+
# Keep relevant tensors for backward
|
| 95 |
+
ctx.raster_settings = raster_settings
|
| 96 |
+
ctx.num_rendered = num_rendered
|
| 97 |
+
ctx.save_for_backward(
|
| 98 |
+
colors_precomp,
|
| 99 |
+
means3D,
|
| 100 |
+
scales,
|
| 101 |
+
rotations,
|
| 102 |
+
cov3Ds_precomp,
|
| 103 |
+
radii,
|
| 104 |
+
sh,
|
| 105 |
+
geom_buffer,
|
| 106 |
+
binning_buffer,
|
| 107 |
+
img_buffer,
|
| 108 |
+
)
|
| 109 |
+
return color, radii
|
| 110 |
+
|
| 111 |
+
@staticmethod
|
| 112 |
+
def backward(ctx, grad_out_color, _):
|
| 113 |
+
# Restore necessary values from context
|
| 114 |
+
num_rendered = ctx.num_rendered
|
| 115 |
+
raster_settings = ctx.raster_settings
|
| 116 |
+
(
|
| 117 |
+
colors_precomp,
|
| 118 |
+
means3D,
|
| 119 |
+
scales,
|
| 120 |
+
rotations,
|
| 121 |
+
cov3Ds_precomp,
|
| 122 |
+
radii,
|
| 123 |
+
sh,
|
| 124 |
+
geom_buffer,
|
| 125 |
+
binning_buffer,
|
| 126 |
+
img_buffer,
|
| 127 |
+
) = ctx.saved_tensors
|
| 128 |
+
|
| 129 |
+
# Restructure args as C++ method expects them
|
| 130 |
+
args = (
|
| 131 |
+
raster_settings.bg,
|
| 132 |
+
means3D,
|
| 133 |
+
radii,
|
| 134 |
+
colors_precomp,
|
| 135 |
+
scales,
|
| 136 |
+
rotations,
|
| 137 |
+
raster_settings.scale_modifier,
|
| 138 |
+
cov3Ds_precomp,
|
| 139 |
+
raster_settings.view_matrix,
|
| 140 |
+
raster_settings.proj_matrix,
|
| 141 |
+
raster_settings.tanfovx,
|
| 142 |
+
raster_settings.tanfovy,
|
| 143 |
+
grad_out_color,
|
| 144 |
+
sh,
|
| 145 |
+
raster_settings.sh_degree,
|
| 146 |
+
raster_settings.campos,
|
| 147 |
+
geom_buffer,
|
| 148 |
+
num_rendered,
|
| 149 |
+
binning_buffer,
|
| 150 |
+
img_buffer,
|
| 151 |
+
raster_settings.debug,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# Compute gradients for relevant tensors by invoking backward method
|
| 155 |
+
if raster_settings.debug:
|
| 156 |
+
cpu_args = RasterizeGaussiansFunction._cpu_deep_copy_tuple(
|
| 157 |
+
input_tuple=args
|
| 158 |
+
) # Copy them before they can be corrupted
|
| 159 |
+
try:
|
| 160 |
+
(
|
| 161 |
+
grad_means2D,
|
| 162 |
+
grad_colors_precomp,
|
| 163 |
+
grad_opacities,
|
| 164 |
+
grad_means3D,
|
| 165 |
+
grad_cov3Ds_precomp,
|
| 166 |
+
grad_sh,
|
| 167 |
+
grad_scales,
|
| 168 |
+
grad_rotations,
|
| 169 |
+
) = dgr_ext.rasterize_gaussians_backward(*args)
|
| 170 |
+
except Exception as ex:
|
| 171 |
+
torch.save(cpu_args, "snapshot_bw.dump")
|
| 172 |
+
print(
|
| 173 |
+
"\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n"
|
| 174 |
+
)
|
| 175 |
+
raise ex
|
| 176 |
+
else:
|
| 177 |
+
(
|
| 178 |
+
grad_means2D,
|
| 179 |
+
grad_colors_precomp,
|
| 180 |
+
grad_opacities,
|
| 181 |
+
grad_means3D,
|
| 182 |
+
grad_cov3Ds_precomp,
|
| 183 |
+
grad_sh,
|
| 184 |
+
grad_scales,
|
| 185 |
+
grad_rotations,
|
| 186 |
+
) = dgr_ext.rasterize_gaussians_backward(*args)
|
| 187 |
+
|
| 188 |
+
grads = (
|
| 189 |
+
grad_means3D,
|
| 190 |
+
grad_means2D,
|
| 191 |
+
grad_sh,
|
| 192 |
+
grad_colors_precomp,
|
| 193 |
+
grad_opacities,
|
| 194 |
+
grad_scales,
|
| 195 |
+
grad_rotations,
|
| 196 |
+
grad_cov3Ds_precomp,
|
| 197 |
+
None,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
return grads
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class GaussianRasterizationSettings(typing.NamedTuple):
|
| 204 |
+
img_h: int
|
| 205 |
+
img_w: int
|
| 206 |
+
tanfovx: float
|
| 207 |
+
tanfovy: float
|
| 208 |
+
bg: torch.Tensor
|
| 209 |
+
scale_modifier: float
|
| 210 |
+
view_matrix: torch.Tensor
|
| 211 |
+
proj_matrix: torch.Tensor
|
| 212 |
+
sh_degree: int
|
| 213 |
+
campos: torch.Tensor
|
| 214 |
+
prefiltered: bool
|
| 215 |
+
debug: bool
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class GaussianRasterizer(torch.nn.Module):
|
| 219 |
+
def __init__(self, raster_settings):
|
| 220 |
+
super(GaussianRasterizer, self).__init__()
|
| 221 |
+
self.raster_settings = raster_settings
|
| 222 |
+
|
| 223 |
+
def forward(
|
| 224 |
+
self,
|
| 225 |
+
means3D,
|
| 226 |
+
means2D,
|
| 227 |
+
opacities,
|
| 228 |
+
shs=None,
|
| 229 |
+
colors_precomp=None,
|
| 230 |
+
scales=None,
|
| 231 |
+
rotations=None,
|
| 232 |
+
cov3D_precomp=None,
|
| 233 |
+
):
|
| 234 |
+
raster_settings = self.raster_settings
|
| 235 |
+
|
| 236 |
+
if (shs is None and colors_precomp is None) or (
|
| 237 |
+
shs is not None and colors_precomp is not None
|
| 238 |
+
):
|
| 239 |
+
raise Exception(
|
| 240 |
+
"Please provide excatly one of either SHs or precomputed colors!"
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
if ((scales is None or rotations is None) and cov3D_precomp is None) or (
|
| 244 |
+
(scales is not None or rotations is not None) and cov3D_precomp is not None
|
| 245 |
+
):
|
| 246 |
+
raise Exception(
|
| 247 |
+
"Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!"
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
if shs is None:
|
| 251 |
+
shs = torch.Tensor([])
|
| 252 |
+
if colors_precomp is None:
|
| 253 |
+
colors_precomp = torch.Tensor([])
|
| 254 |
+
|
| 255 |
+
if scales is None:
|
| 256 |
+
scales = torch.Tensor([])
|
| 257 |
+
if rotations is None:
|
| 258 |
+
rotations = torch.Tensor([])
|
| 259 |
+
if cov3D_precomp is None:
|
| 260 |
+
cov3D_precomp = torch.Tensor([])
|
| 261 |
+
|
| 262 |
+
# Invoke C++/CUDA rasterization routine
|
| 263 |
+
return RasterizeGaussiansFunction.apply(
|
| 264 |
+
means3D,
|
| 265 |
+
means2D,
|
| 266 |
+
shs,
|
| 267 |
+
colors_precomp,
|
| 268 |
+
opacities,
|
| 269 |
+
scales,
|
| 270 |
+
rotations,
|
| 271 |
+
cov3D_precomp,
|
| 272 |
+
raster_settings,
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class GaussianRasterizerWrapper(torch.nn.Module):
|
| 277 |
+
# Carving flowers on a mountain of dung code.
|
| 278 |
+
#
|
| 279 |
+
# This class is a wrapper for the GaussianRasterizer class.
|
| 280 |
+
# It is used to port for the GaussianCity project.
|
| 281 |
+
def __init__(
|
| 282 |
+
self,
|
| 283 |
+
K,
|
| 284 |
+
sensor_size,
|
| 285 |
+
flip_lr=True,
|
| 286 |
+
flip_ud=False,
|
| 287 |
+
z_near=0.01,
|
| 288 |
+
z_far=50000.0,
|
| 289 |
+
device=torch.device("cuda"),
|
| 290 |
+
):
|
| 291 |
+
super(GaussianRasterizerWrapper, self).__init__()
|
| 292 |
+
self.flip_lr = flip_lr
|
| 293 |
+
self.flip_ud = flip_ud
|
| 294 |
+
self.z_near = z_near
|
| 295 |
+
self.z_far = z_far
|
| 296 |
+
self.device = device
|
| 297 |
+
# Shared camera parameters
|
| 298 |
+
self.K = K
|
| 299 |
+
self.sensor_size = sensor_size
|
| 300 |
+
self.fov_x, self.fov_y = self._intrinsic_to_fov()
|
| 301 |
+
self.P = self._get_projection_matrix()
|
| 302 |
+
|
| 303 |
+
def get_gaussian_rasterizer(self, cam_position, cam_quaternion):
|
| 304 |
+
# cam_position in (tx, ty, tz)
|
| 305 |
+
# cam_quaternion in (qx, qy, qz, qw)
|
| 306 |
+
return GaussianRasterizer(
|
| 307 |
+
raster_settings=self._get_gaussian_rasterization_settings(
|
| 308 |
+
cam_position, cam_quaternion
|
| 309 |
+
)
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
def forward(
|
| 313 |
+
self, points, cam_position=None, cam_quaternion=None, gaussian_rasterizer=None
|
| 314 |
+
):
|
| 315 |
+
# points: [N, M], M -> 0:3 xyz, 3:4 opacity, 4:7 scale, 7:11 rotation, 11:14 rgbs
|
| 316 |
+
_, M = points.shape
|
| 317 |
+
assert M == 14, "The input tensor should have 14 channels."
|
| 318 |
+
|
| 319 |
+
if gaussian_rasterizer is None:
|
| 320 |
+
gaussian_rasterizer = self.get_gaussian_rasterizer(
|
| 321 |
+
cam_position, cam_quaternion
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
return self._get_gaussian_rasterization(points, gaussian_rasterizer)
|
| 325 |
+
|
| 326 |
+
def _intrinsic_to_fov(self):
|
| 327 |
+
# graphdeco-inria/gaussian-splatting/utils/graphics_utils.py#L76
|
| 328 |
+
fx, fy = self.K[0, 0], self.K[1, 1]
|
| 329 |
+
fov_x = 2 * np.arctan2(self.sensor_size[0], (2 * fx))
|
| 330 |
+
fov_y = 2 * np.arctan2(self.sensor_size[1], (2 * fy))
|
| 331 |
+
return fov_x, fov_y
|
| 332 |
+
|
| 333 |
+
def _get_projection_matrix(self):
|
| 334 |
+
fx = self.K[0, 0]
|
| 335 |
+
fy = self.K[1, 1]
|
| 336 |
+
cx = self.K[0, 2]
|
| 337 |
+
cy = self.K[1, 2]
|
| 338 |
+
|
| 339 |
+
P = np.zeros((4, 4), dtype=np.float32)
|
| 340 |
+
P[0, 0] = 2.0 * fx / self.sensor_size[0]
|
| 341 |
+
P[1, 1] = 2.0 * fy / self.sensor_size[1]
|
| 342 |
+
P[0, 2] = (2.0 * cx / self.sensor_size[0]) - 1.0
|
| 343 |
+
P[1, 2] = (2.0 * cy / self.sensor_size[1]) - 1.0
|
| 344 |
+
P[2, 2] = -(self.z_far + self.z_near) / (self.z_far - self.z_near)
|
| 345 |
+
P[3, 2] = -1.0
|
| 346 |
+
P[2, 3] = -2.0 * self.z_far * self.z_near / (self.z_far - self.z_near)
|
| 347 |
+
return torch.from_numpy(P).to(self.device)
|
| 348 |
+
|
| 349 |
+
def _get_w2c_matrix(self, cam_position, cam_quaternion):
|
| 350 |
+
if type(cam_position) is torch.Tensor:
|
| 351 |
+
cam_position = cam_position.cpu().numpy()
|
| 352 |
+
if type(cam_quaternion) is torch.Tensor:
|
| 353 |
+
cam_quaternion = cam_quaternion.cpu().numpy()
|
| 354 |
+
|
| 355 |
+
R = scipy.spatial.transform.Rotation.from_quat(cam_quaternion).as_matrix()
|
| 356 |
+
# look_at = cam_position + R[:3, 0]
|
| 357 |
+
R = R[:, [1, 2, 0]] # [F|R|U] -> [R|U|F]
|
| 358 |
+
# graphdeco-inria/gaussian-splatting/blob/main/scene/cameras.py#L31
|
| 359 |
+
# The w2c matrix
|
| 360 |
+
Rt = np.zeros((4, 4), dtype=np.float32)
|
| 361 |
+
Rt[:3, :3] = R.transpose()
|
| 362 |
+
Rt[:3, [3]] = -R.transpose() @ cam_position[:, None]
|
| 363 |
+
Rt[3, 3] = 1.0
|
| 364 |
+
# The c2w matrix
|
| 365 |
+
# Rt[:3, :3] = R
|
| 366 |
+
# Rt[:3, 3] = cam_position
|
| 367 |
+
# Rt[3, 3] = 1.0
|
| 368 |
+
return torch.from_numpy(Rt).to(self.device)
|
| 369 |
+
|
| 370 |
+
def _world_to_pixel(self, world_coords, w2c):
|
| 371 |
+
# NOTE: The function is used to debug whether the w2c matrix is correct.
|
| 372 |
+
# Convert world coordinates to camera coordinates using the inverse of w2c
|
| 373 |
+
camera_coords = np.dot(w2c[:3, :3], world_coords) + w2c[:3, 3]
|
| 374 |
+
# camera_coords = np.dot(np.linalg.inv(c2w[:3, :3]), (world_coords- w2c[:3, 3]))
|
| 375 |
+
# Apply the camera intrinsic matrix K to obtain normalized image coordinates
|
| 376 |
+
homogeneous_coords = np.dot(self.K, camera_coords)
|
| 377 |
+
# Normalize homogeneous coordinates
|
| 378 |
+
normalized_coords = homogeneous_coords / homogeneous_coords[2]
|
| 379 |
+
# Convert normalized coordinates to pixel coordinates
|
| 380 |
+
return normalized_coords[:2].astype(int)
|
| 381 |
+
|
| 382 |
+
def _get_gaussian_rasterization_settings(self, cam_position, cam_quaternion):
|
| 383 |
+
BG_COLOR = torch.tensor(
|
| 384 |
+
[0.0, 0.0, 0.0], dtype=torch.float32, device=self.device
|
| 385 |
+
)
|
| 386 |
+
w2c = self._get_w2c_matrix(cam_position, cam_quaternion).transpose(0, 1)
|
| 387 |
+
prj_mtx = self.P.transpose(0, 1)
|
| 388 |
+
|
| 389 |
+
return GaussianRasterizationSettings(
|
| 390 |
+
img_h=self.sensor_size[1],
|
| 391 |
+
img_w=self.sensor_size[0],
|
| 392 |
+
tanfovx=math.tan(self.fov_x * 0.5),
|
| 393 |
+
tanfovy=math.tan(self.fov_y * 0.5),
|
| 394 |
+
bg=BG_COLOR,
|
| 395 |
+
scale_modifier=1.0,
|
| 396 |
+
view_matrix=w2c,
|
| 397 |
+
proj_matrix=w2c @ prj_mtx,
|
| 398 |
+
sh_degree=0,
|
| 399 |
+
campos=w2c.inverse()[3, :3],
|
| 400 |
+
prefiltered=False,
|
| 401 |
+
debug=False,
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
def _get_gaussian_rasterization(self, points, rasterizer):
|
| 405 |
+
xyz = points[:, 0:3]
|
| 406 |
+
opacity = points[:, 3:4]
|
| 407 |
+
scales = points[:, 4:7]
|
| 408 |
+
quaternion = points[:, 7:11]
|
| 409 |
+
rgbs = points[:, 11:]
|
| 410 |
+
|
| 411 |
+
rendered_image, _ = rasterizer(
|
| 412 |
+
means3D=xyz,
|
| 413 |
+
means2D=torch.zeros_like(xyz, dtype=torch.float32, device=self.device),
|
| 414 |
+
shs=None,
|
| 415 |
+
colors_precomp=rgbs,
|
| 416 |
+
opacities=opacity,
|
| 417 |
+
scales=scales,
|
| 418 |
+
rotations=quaternion,
|
| 419 |
+
cov3D_precomp=None,
|
| 420 |
+
)
|
| 421 |
+
if self.flip_lr:
|
| 422 |
+
rendered_image = torch.flip(rendered_image, dims=[2])
|
| 423 |
+
if self.flip_ud:
|
| 424 |
+
rendered_image = torch.flip(rendered_image, dims=[1])
|
| 425 |
+
|
| 426 |
+
return rendered_image
|
gaussiancity/extensions/diff_gaussian_rasterization/bindings.cpp
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright (C) 2023, Inria
|
| 3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This software is free for non-commercial, research and evaluation use
|
| 7 |
+
* under the terms of the LICENSE.md file.
|
| 8 |
+
*
|
| 9 |
+
* For inquiries contact george.drettakis@inria.fr
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#include "rasterize_points.h"
|
| 13 |
+
#include <torch/extension.h>
|
| 14 |
+
|
| 15 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 16 |
+
m.def("rasterize_gaussians", &RasterizeGaussiansCUDA);
|
| 17 |
+
m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA);
|
| 18 |
+
m.def("mark_visible", &markVisible);
|
| 19 |
+
}
|
gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/auxiliary.h
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright (C) 2023, Inria
|
| 3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This software is free for non-commercial, research and evaluation use
|
| 7 |
+
* under the terms of the LICENSE.md file.
|
| 8 |
+
*
|
| 9 |
+
* For inquiries contact george.drettakis@inria.fr
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#ifndef CUDA_RASTERIZER_AUXILIARY_H_INCLUDED
|
| 13 |
+
#define CUDA_RASTERIZER_AUXILIARY_H_INCLUDED
|
| 14 |
+
|
| 15 |
+
#include "config.h"
|
| 16 |
+
#include "stdio.h"
|
| 17 |
+
|
| 18 |
+
#define BLOCK_SIZE (BLOCK_X * BLOCK_Y)
|
| 19 |
+
#define NUM_WARPS (BLOCK_SIZE / 32)
|
| 20 |
+
|
| 21 |
+
// Spherical harmonics coefficients
|
| 22 |
+
__device__ const float SH_C0 = 0.28209479177387814f;
|
| 23 |
+
__device__ const float SH_C1 = 0.4886025119029199f;
|
| 24 |
+
__device__ const float SH_C2[] = {1.0925484305920792f, -1.0925484305920792f,
|
| 25 |
+
0.31539156525252005f, -1.0925484305920792f,
|
| 26 |
+
0.5462742152960396f};
|
| 27 |
+
__device__ const float SH_C3[] = {-0.5900435899266435f, 2.890611442640554f,
|
| 28 |
+
-0.4570457994644658f, 0.3731763325901154f,
|
| 29 |
+
-0.4570457994644658f, 1.445305721320277f,
|
| 30 |
+
-0.5900435899266435f};
|
| 31 |
+
|
| 32 |
+
__forceinline__ __device__ float ndc2Pix(float v, int S) {
|
| 33 |
+
return ((v + 1.0) * S - 1.0) * 0.5;
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
__forceinline__ __device__ void getRect(const float2 p, int max_radius,
|
| 37 |
+
uint2 &rect_min, uint2 &rect_max,
|
| 38 |
+
dim3 grid) {
|
| 39 |
+
rect_min = {min(grid.x, max((int)0, (int)((p.x - max_radius) / BLOCK_X))),
|
| 40 |
+
min(grid.y, max((int)0, (int)((p.y - max_radius) / BLOCK_Y)))};
|
| 41 |
+
rect_max = {
|
| 42 |
+
min(grid.x,
|
| 43 |
+
max((int)0, (int)((p.x + max_radius + BLOCK_X - 1) / BLOCK_X))),
|
| 44 |
+
min(grid.y,
|
| 45 |
+
max((int)0, (int)((p.y + max_radius + BLOCK_Y - 1) / BLOCK_Y)))};
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
__forceinline__ __device__ float3 transformPoint4x3(const float3 &p,
|
| 49 |
+
const float *matrix) {
|
| 50 |
+
float3 transformed = {
|
| 51 |
+
matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12],
|
| 52 |
+
matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13],
|
| 53 |
+
matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14],
|
| 54 |
+
};
|
| 55 |
+
return transformed;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
__forceinline__ __device__ float4 transformPoint4x4(const float3 &p,
|
| 59 |
+
const float *matrix) {
|
| 60 |
+
float4 transformed = {
|
| 61 |
+
matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12],
|
| 62 |
+
matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13],
|
| 63 |
+
matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14],
|
| 64 |
+
matrix[3] * p.x + matrix[7] * p.y + matrix[11] * p.z + matrix[15]};
|
| 65 |
+
return transformed;
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
__forceinline__ __device__ float3 transformVec4x3(const float3 &p,
|
| 69 |
+
const float *matrix) {
|
| 70 |
+
float3 transformed = {
|
| 71 |
+
matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z,
|
| 72 |
+
matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z,
|
| 73 |
+
matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z,
|
| 74 |
+
};
|
| 75 |
+
return transformed;
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
__forceinline__ __device__ float3
|
| 79 |
+
transformVec4x3Transpose(const float3 &p, const float *matrix) {
|
| 80 |
+
float3 transformed = {
|
| 81 |
+
matrix[0] * p.x + matrix[1] * p.y + matrix[2] * p.z,
|
| 82 |
+
matrix[4] * p.x + matrix[5] * p.y + matrix[6] * p.z,
|
| 83 |
+
matrix[8] * p.x + matrix[9] * p.y + matrix[10] * p.z,
|
| 84 |
+
};
|
| 85 |
+
return transformed;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
__forceinline__ __device__ float dnormvdz(float3 v, float3 dv) {
|
| 89 |
+
float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
|
| 90 |
+
float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
|
| 91 |
+
float dnormvdz =
|
| 92 |
+
(-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) *
|
| 93 |
+
invsum32;
|
| 94 |
+
return dnormvdz;
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
__forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv) {
|
| 98 |
+
float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
|
| 99 |
+
float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
|
| 100 |
+
|
| 101 |
+
float3 dnormvdv;
|
| 102 |
+
dnormvdv.x =
|
| 103 |
+
((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) *
|
| 104 |
+
invsum32;
|
| 105 |
+
dnormvdv.y =
|
| 106 |
+
(-v.x * v.y * dv.x + (sum2 - v.y * v.y) * dv.y - v.z * v.y * dv.z) *
|
| 107 |
+
invsum32;
|
| 108 |
+
dnormvdv.z =
|
| 109 |
+
(-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) *
|
| 110 |
+
invsum32;
|
| 111 |
+
return dnormvdv;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
__forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv) {
|
| 115 |
+
float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w;
|
| 116 |
+
float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
|
| 117 |
+
|
| 118 |
+
float4 vdv = {v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w};
|
| 119 |
+
float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w;
|
| 120 |
+
float4 dnormvdv;
|
| 121 |
+
dnormvdv.x = ((sum2 - v.x * v.x) * dv.x - v.x * (vdv_sum - vdv.x)) * invsum32;
|
| 122 |
+
dnormvdv.y = ((sum2 - v.y * v.y) * dv.y - v.y * (vdv_sum - vdv.y)) * invsum32;
|
| 123 |
+
dnormvdv.z = ((sum2 - v.z * v.z) * dv.z - v.z * (vdv_sum - vdv.z)) * invsum32;
|
| 124 |
+
dnormvdv.w = ((sum2 - v.w * v.w) * dv.w - v.w * (vdv_sum - vdv.w)) * invsum32;
|
| 125 |
+
return dnormvdv;
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
__forceinline__ __device__ float sigmoid(float x) {
|
| 129 |
+
return 1.0f / (1.0f + expf(-x));
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
__forceinline__ __device__ bool in_frustum(int idx, const float *orig_points,
|
| 133 |
+
const float *viewmatrix,
|
| 134 |
+
const float *projmatrix,
|
| 135 |
+
bool prefiltered, float3 &p_view) {
|
| 136 |
+
float3 p_orig = {orig_points[3 * idx], orig_points[3 * idx + 1],
|
| 137 |
+
orig_points[3 * idx + 2]};
|
| 138 |
+
|
| 139 |
+
// Bring points to screen space
|
| 140 |
+
float4 p_hom = transformPoint4x4(p_orig, projmatrix);
|
| 141 |
+
float p_w = 1.0f / (p_hom.w + 0.0000001f);
|
| 142 |
+
float3 p_proj = {p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w};
|
| 143 |
+
p_view = transformPoint4x3(p_orig, viewmatrix);
|
| 144 |
+
|
| 145 |
+
if (p_view.z <= 0.2f) // || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y <
|
| 146 |
+
// -1.3 || p_proj.y > 1.3)))
|
| 147 |
+
{
|
| 148 |
+
if (prefiltered) {
|
| 149 |
+
printf("Point is filtered although prefiltered is set. This shouldn't "
|
| 150 |
+
"happen!");
|
| 151 |
+
__trap();
|
| 152 |
+
}
|
| 153 |
+
return false;
|
| 154 |
+
}
|
| 155 |
+
return true;
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
#define CHECK_CUDA(A, debug) \
|
| 159 |
+
A; \
|
| 160 |
+
if (debug) { \
|
| 161 |
+
auto ret = cudaDeviceSynchronize(); \
|
| 162 |
+
if (ret != cudaSuccess) { \
|
| 163 |
+
std::cerr << "\n[CUDA ERROR] in " << __FILE__ << "\nLine " << __LINE__ \
|
| 164 |
+
<< ": " << cudaGetErrorString(ret); \
|
| 165 |
+
throw std::runtime_error(cudaGetErrorString(ret)); \
|
| 166 |
+
} \
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
#endif
|
gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/backward.cu
ADDED
|
@@ -0,0 +1,622 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright (C) 2023, Inria
|
| 3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This software is free for non-commercial, research and evaluation use
|
| 7 |
+
* under the terms of the LICENSE.md file.
|
| 8 |
+
*
|
| 9 |
+
* For inquiries contact george.drettakis@inria.fr
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#include "auxiliary.h"
|
| 13 |
+
#include "backward.h"
|
| 14 |
+
#include <cooperative_groups.h>
|
| 15 |
+
#include <cooperative_groups/reduce.h>
|
| 16 |
+
namespace cg = cooperative_groups;
|
| 17 |
+
|
| 18 |
+
// Backward pass for conversion of spherical harmonics to RGB for
|
| 19 |
+
// each Gaussian.
|
| 20 |
+
__device__ void computeColorFromSH(int idx, int deg, int max_coeffs,
|
| 21 |
+
const glm::vec3 *means, glm::vec3 campos,
|
| 22 |
+
const float *shs, const bool *clamped,
|
| 23 |
+
const glm::vec3 *dL_dcolor,
|
| 24 |
+
glm::vec3 *dL_dmeans, glm::vec3 *dL_dshs) {
|
| 25 |
+
// Compute intermediate values, as it is done during forward
|
| 26 |
+
glm::vec3 pos = means[idx];
|
| 27 |
+
glm::vec3 dir_orig = pos - campos;
|
| 28 |
+
glm::vec3 dir = dir_orig / glm::length(dir_orig);
|
| 29 |
+
|
| 30 |
+
glm::vec3 *sh = ((glm::vec3 *)shs) + idx * max_coeffs;
|
| 31 |
+
|
| 32 |
+
// Use PyTorch rule for clamping: if clamping was applied,
|
| 33 |
+
// gradient becomes 0.
|
| 34 |
+
glm::vec3 dL_dRGB = dL_dcolor[idx];
|
| 35 |
+
dL_dRGB.x *= clamped[3 * idx + 0] ? 0 : 1;
|
| 36 |
+
dL_dRGB.y *= clamped[3 * idx + 1] ? 0 : 1;
|
| 37 |
+
dL_dRGB.z *= clamped[3 * idx + 2] ? 0 : 1;
|
| 38 |
+
|
| 39 |
+
glm::vec3 dRGBdx(0, 0, 0);
|
| 40 |
+
glm::vec3 dRGBdy(0, 0, 0);
|
| 41 |
+
glm::vec3 dRGBdz(0, 0, 0);
|
| 42 |
+
float x = dir.x;
|
| 43 |
+
float y = dir.y;
|
| 44 |
+
float z = dir.z;
|
| 45 |
+
|
| 46 |
+
// Target location for this Gaussian to write SH gradients to
|
| 47 |
+
glm::vec3 *dL_dsh = dL_dshs + idx * max_coeffs;
|
| 48 |
+
|
| 49 |
+
// No tricks here, just high school-level calculus.
|
| 50 |
+
float dRGBdsh0 = SH_C0;
|
| 51 |
+
dL_dsh[0] = dRGBdsh0 * dL_dRGB;
|
| 52 |
+
if (deg > 0) {
|
| 53 |
+
float dRGBdsh1 = -SH_C1 * y;
|
| 54 |
+
float dRGBdsh2 = SH_C1 * z;
|
| 55 |
+
float dRGBdsh3 = -SH_C1 * x;
|
| 56 |
+
dL_dsh[1] = dRGBdsh1 * dL_dRGB;
|
| 57 |
+
dL_dsh[2] = dRGBdsh2 * dL_dRGB;
|
| 58 |
+
dL_dsh[3] = dRGBdsh3 * dL_dRGB;
|
| 59 |
+
|
| 60 |
+
dRGBdx = -SH_C1 * sh[3];
|
| 61 |
+
dRGBdy = -SH_C1 * sh[1];
|
| 62 |
+
dRGBdz = SH_C1 * sh[2];
|
| 63 |
+
|
| 64 |
+
if (deg > 1) {
|
| 65 |
+
float xx = x * x, yy = y * y, zz = z * z;
|
| 66 |
+
float xy = x * y, yz = y * z, xz = x * z;
|
| 67 |
+
|
| 68 |
+
float dRGBdsh4 = SH_C2[0] * xy;
|
| 69 |
+
float dRGBdsh5 = SH_C2[1] * yz;
|
| 70 |
+
float dRGBdsh6 = SH_C2[2] * (2.f * zz - xx - yy);
|
| 71 |
+
float dRGBdsh7 = SH_C2[3] * xz;
|
| 72 |
+
float dRGBdsh8 = SH_C2[4] * (xx - yy);
|
| 73 |
+
dL_dsh[4] = dRGBdsh4 * dL_dRGB;
|
| 74 |
+
dL_dsh[5] = dRGBdsh5 * dL_dRGB;
|
| 75 |
+
dL_dsh[6] = dRGBdsh6 * dL_dRGB;
|
| 76 |
+
dL_dsh[7] = dRGBdsh7 * dL_dRGB;
|
| 77 |
+
dL_dsh[8] = dRGBdsh8 * dL_dRGB;
|
| 78 |
+
|
| 79 |
+
dRGBdx += SH_C2[0] * y * sh[4] + SH_C2[2] * 2.f * -x * sh[6] +
|
| 80 |
+
SH_C2[3] * z * sh[7] + SH_C2[4] * 2.f * x * sh[8];
|
| 81 |
+
dRGBdy += SH_C2[0] * x * sh[4] + SH_C2[1] * z * sh[5] +
|
| 82 |
+
SH_C2[2] * 2.f * -y * sh[6] + SH_C2[4] * 2.f * -y * sh[8];
|
| 83 |
+
dRGBdz += SH_C2[1] * y * sh[5] + SH_C2[2] * 2.f * 2.f * z * sh[6] +
|
| 84 |
+
SH_C2[3] * x * sh[7];
|
| 85 |
+
|
| 86 |
+
if (deg > 2) {
|
| 87 |
+
float dRGBdsh9 = SH_C3[0] * y * (3.f * xx - yy);
|
| 88 |
+
float dRGBdsh10 = SH_C3[1] * xy * z;
|
| 89 |
+
float dRGBdsh11 = SH_C3[2] * y * (4.f * zz - xx - yy);
|
| 90 |
+
float dRGBdsh12 = SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy);
|
| 91 |
+
float dRGBdsh13 = SH_C3[4] * x * (4.f * zz - xx - yy);
|
| 92 |
+
float dRGBdsh14 = SH_C3[5] * z * (xx - yy);
|
| 93 |
+
float dRGBdsh15 = SH_C3[6] * x * (xx - 3.f * yy);
|
| 94 |
+
dL_dsh[9] = dRGBdsh9 * dL_dRGB;
|
| 95 |
+
dL_dsh[10] = dRGBdsh10 * dL_dRGB;
|
| 96 |
+
dL_dsh[11] = dRGBdsh11 * dL_dRGB;
|
| 97 |
+
dL_dsh[12] = dRGBdsh12 * dL_dRGB;
|
| 98 |
+
dL_dsh[13] = dRGBdsh13 * dL_dRGB;
|
| 99 |
+
dL_dsh[14] = dRGBdsh14 * dL_dRGB;
|
| 100 |
+
dL_dsh[15] = dRGBdsh15 * dL_dRGB;
|
| 101 |
+
|
| 102 |
+
dRGBdx += (SH_C3[0] * sh[9] * 3.f * 2.f * xy + SH_C3[1] * sh[10] * yz +
|
| 103 |
+
SH_C3[2] * sh[11] * -2.f * xy +
|
| 104 |
+
SH_C3[3] * sh[12] * -3.f * 2.f * xz +
|
| 105 |
+
SH_C3[4] * sh[13] * (-3.f * xx + 4.f * zz - yy) +
|
| 106 |
+
SH_C3[5] * sh[14] * 2.f * xz +
|
| 107 |
+
SH_C3[6] * sh[15] * 3.f * (xx - yy));
|
| 108 |
+
|
| 109 |
+
dRGBdy +=
|
| 110 |
+
(SH_C3[0] * sh[9] * 3.f * (xx - yy) + SH_C3[1] * sh[10] * xz +
|
| 111 |
+
SH_C3[2] * sh[11] * (-3.f * yy + 4.f * zz - xx) +
|
| 112 |
+
SH_C3[3] * sh[12] * -3.f * 2.f * yz +
|
| 113 |
+
SH_C3[4] * sh[13] * -2.f * xy + SH_C3[5] * sh[14] * -2.f * yz +
|
| 114 |
+
SH_C3[6] * sh[15] * -3.f * 2.f * xy);
|
| 115 |
+
|
| 116 |
+
dRGBdz += (SH_C3[1] * sh[10] * xy + SH_C3[2] * sh[11] * 4.f * 2.f * yz +
|
| 117 |
+
SH_C3[3] * sh[12] * 3.f * (2.f * zz - xx - yy) +
|
| 118 |
+
SH_C3[4] * sh[13] * 4.f * 2.f * xz +
|
| 119 |
+
SH_C3[5] * sh[14] * (xx - yy));
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
// The view direction is an input to the computation. View direction
|
| 125 |
+
// is influenced by the Gaussian's mean, so SHs gradients
|
| 126 |
+
// must propagate back into 3D position.
|
| 127 |
+
glm::vec3 dL_ddir(glm::dot(dRGBdx, dL_dRGB), glm::dot(dRGBdy, dL_dRGB),
|
| 128 |
+
glm::dot(dRGBdz, dL_dRGB));
|
| 129 |
+
|
| 130 |
+
// Account for normalization of direction
|
| 131 |
+
float3 dL_dmean = dnormvdv(float3{dir_orig.x, dir_orig.y, dir_orig.z},
|
| 132 |
+
float3{dL_ddir.x, dL_ddir.y, dL_ddir.z});
|
| 133 |
+
|
| 134 |
+
// Gradients of loss w.r.t. Gaussian means, but only the portion
|
| 135 |
+
// that is caused because the mean affects the view-dependent color.
|
| 136 |
+
// Additional mean gradient is accumulated in below methods.
|
| 137 |
+
dL_dmeans[idx] += glm::vec3(dL_dmean.x, dL_dmean.y, dL_dmean.z);
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
// Backward version of INVERSE 2D covariance matrix computation
|
| 141 |
+
// (due to length launched as separate kernel before other
|
| 142 |
+
// backward steps contained in preprocess)
|
| 143 |
+
__global__ void computeCov2DCUDA(int P, const float3 *means, const int *radii,
|
| 144 |
+
const float *cov3Ds, const float h_x,
|
| 145 |
+
float h_y, const float tan_fovx,
|
| 146 |
+
float tan_fovy, const float *view_matrix,
|
| 147 |
+
const float *dL_dconics, float3 *dL_dmeans,
|
| 148 |
+
float *dL_dcov) {
|
| 149 |
+
auto idx = cg::this_grid().thread_rank();
|
| 150 |
+
if (idx >= P || !(radii[idx] > 0))
|
| 151 |
+
return;
|
| 152 |
+
|
| 153 |
+
// Reading location of 3D covariance for this Gaussian
|
| 154 |
+
const float *cov3D = cov3Ds + 6 * idx;
|
| 155 |
+
|
| 156 |
+
// Fetch gradients, recompute 2D covariance and relevant
|
| 157 |
+
// intermediate forward results needed in the backward.
|
| 158 |
+
float3 mean = means[idx];
|
| 159 |
+
float3 dL_dconic = {dL_dconics[4 * idx], dL_dconics[4 * idx + 1],
|
| 160 |
+
dL_dconics[4 * idx + 3]};
|
| 161 |
+
float3 t = transformPoint4x3(mean, view_matrix);
|
| 162 |
+
|
| 163 |
+
const float limx = 1.3f * tan_fovx;
|
| 164 |
+
const float limy = 1.3f * tan_fovy;
|
| 165 |
+
const float txtz = t.x / t.z;
|
| 166 |
+
const float tytz = t.y / t.z;
|
| 167 |
+
t.x = min(limx, max(-limx, txtz)) * t.z;
|
| 168 |
+
t.y = min(limy, max(-limy, tytz)) * t.z;
|
| 169 |
+
|
| 170 |
+
const float x_grad_mul = txtz < -limx || txtz > limx ? 0 : 1;
|
| 171 |
+
const float y_grad_mul = tytz < -limy || tytz > limy ? 0 : 1;
|
| 172 |
+
|
| 173 |
+
glm::mat3 J = glm::mat3(h_x / t.z, 0.0f, -(h_x * t.x) / (t.z * t.z), 0.0f,
|
| 174 |
+
h_y / t.z, -(h_y * t.y) / (t.z * t.z), 0, 0, 0);
|
| 175 |
+
|
| 176 |
+
glm::mat3 W = glm::mat3(view_matrix[0], view_matrix[4], view_matrix[8],
|
| 177 |
+
view_matrix[1], view_matrix[5], view_matrix[9],
|
| 178 |
+
view_matrix[2], view_matrix[6], view_matrix[10]);
|
| 179 |
+
|
| 180 |
+
glm::mat3 Vrk = glm::mat3(cov3D[0], cov3D[1], cov3D[2], cov3D[1], cov3D[3],
|
| 181 |
+
cov3D[4], cov3D[2], cov3D[4], cov3D[5]);
|
| 182 |
+
|
| 183 |
+
glm::mat3 T = W * J;
|
| 184 |
+
|
| 185 |
+
glm::mat3 cov2D = glm::transpose(T) * glm::transpose(Vrk) * T;
|
| 186 |
+
|
| 187 |
+
// Use helper variables for 2D covariance entries. More compact.
|
| 188 |
+
float a = cov2D[0][0] += 0.3f;
|
| 189 |
+
float b = cov2D[0][1];
|
| 190 |
+
float c = cov2D[1][1] += 0.3f;
|
| 191 |
+
|
| 192 |
+
float denom = a * c - b * b;
|
| 193 |
+
float dL_da = 0, dL_db = 0, dL_dc = 0;
|
| 194 |
+
float denom2inv = 1.0f / ((denom * denom) + 0.0000001f);
|
| 195 |
+
|
| 196 |
+
if (denom2inv != 0) {
|
| 197 |
+
// Gradients of loss w.r.t. entries of 2D covariance matrix,
|
| 198 |
+
// given gradients of loss w.r.t. conic matrix (inverse covariance matrix).
|
| 199 |
+
// e.g., dL / da = dL / d_conic_a * d_conic_a / d_a
|
| 200 |
+
dL_da = denom2inv * (-c * c * dL_dconic.x + 2 * b * c * dL_dconic.y +
|
| 201 |
+
(denom - a * c) * dL_dconic.z);
|
| 202 |
+
dL_dc = denom2inv * (-a * a * dL_dconic.z + 2 * a * b * dL_dconic.y +
|
| 203 |
+
(denom - a * c) * dL_dconic.x);
|
| 204 |
+
dL_db = denom2inv * 2 *
|
| 205 |
+
(b * c * dL_dconic.x - (denom + 2 * b * b) * dL_dconic.y +
|
| 206 |
+
a * b * dL_dconic.z);
|
| 207 |
+
|
| 208 |
+
// Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry,
|
| 209 |
+
// given gradients w.r.t. 2D covariance matrix (diagonal).
|
| 210 |
+
// cov2D = transpose(T) * transpose(Vrk) * T;
|
| 211 |
+
dL_dcov[6 * idx + 0] =
|
| 212 |
+
(T[0][0] * T[0][0] * dL_da + T[0][0] * T[1][0] * dL_db +
|
| 213 |
+
T[1][0] * T[1][0] * dL_dc);
|
| 214 |
+
dL_dcov[6 * idx + 3] =
|
| 215 |
+
(T[0][1] * T[0][1] * dL_da + T[0][1] * T[1][1] * dL_db +
|
| 216 |
+
T[1][1] * T[1][1] * dL_dc);
|
| 217 |
+
dL_dcov[6 * idx + 5] =
|
| 218 |
+
(T[0][2] * T[0][2] * dL_da + T[0][2] * T[1][2] * dL_db +
|
| 219 |
+
T[1][2] * T[1][2] * dL_dc);
|
| 220 |
+
|
| 221 |
+
// Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry,
|
| 222 |
+
// given gradients w.r.t. 2D covariance matrix (off-diagonal).
|
| 223 |
+
// Off-diagonal elements appear twice --> double the gradient.
|
| 224 |
+
// cov2D = transpose(T) * transpose(Vrk) * T;
|
| 225 |
+
dL_dcov[6 * idx + 1] = 2 * T[0][0] * T[0][1] * dL_da +
|
| 226 |
+
(T[0][0] * T[1][1] + T[0][1] * T[1][0]) * dL_db +
|
| 227 |
+
2 * T[1][0] * T[1][1] * dL_dc;
|
| 228 |
+
dL_dcov[6 * idx + 2] = 2 * T[0][0] * T[0][2] * dL_da +
|
| 229 |
+
(T[0][0] * T[1][2] + T[0][2] * T[1][0]) * dL_db +
|
| 230 |
+
2 * T[1][0] * T[1][2] * dL_dc;
|
| 231 |
+
dL_dcov[6 * idx + 4] = 2 * T[0][2] * T[0][1] * dL_da +
|
| 232 |
+
(T[0][1] * T[1][2] + T[0][2] * T[1][1]) * dL_db +
|
| 233 |
+
2 * T[1][1] * T[1][2] * dL_dc;
|
| 234 |
+
} else {
|
| 235 |
+
for (int i = 0; i < 6; i++)
|
| 236 |
+
dL_dcov[6 * idx + i] = 0;
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
// Gradients of loss w.r.t. upper 2x3 portion of intermediate matrix T
|
| 240 |
+
// cov2D = transpose(T) * transpose(Vrk) * T;
|
| 241 |
+
float dL_dT00 =
|
| 242 |
+
2 * (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) *
|
| 243 |
+
dL_da +
|
| 244 |
+
(T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_db;
|
| 245 |
+
float dL_dT01 =
|
| 246 |
+
2 * (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) *
|
| 247 |
+
dL_da +
|
| 248 |
+
(T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_db;
|
| 249 |
+
float dL_dT02 =
|
| 250 |
+
2 * (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) *
|
| 251 |
+
dL_da +
|
| 252 |
+
(T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_db;
|
| 253 |
+
float dL_dT10 =
|
| 254 |
+
2 * (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) *
|
| 255 |
+
dL_dc +
|
| 256 |
+
(T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_db;
|
| 257 |
+
float dL_dT11 =
|
| 258 |
+
2 * (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) *
|
| 259 |
+
dL_dc +
|
| 260 |
+
(T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_db;
|
| 261 |
+
float dL_dT12 =
|
| 262 |
+
2 * (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) *
|
| 263 |
+
dL_dc +
|
| 264 |
+
(T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_db;
|
| 265 |
+
|
| 266 |
+
// Gradients of loss w.r.t. upper 3x2 non-zero entries of Jacobian matrix
|
| 267 |
+
// T = W * J
|
| 268 |
+
float dL_dJ00 = W[0][0] * dL_dT00 + W[0][1] * dL_dT01 + W[0][2] * dL_dT02;
|
| 269 |
+
float dL_dJ02 = W[2][0] * dL_dT00 + W[2][1] * dL_dT01 + W[2][2] * dL_dT02;
|
| 270 |
+
float dL_dJ11 = W[1][0] * dL_dT10 + W[1][1] * dL_dT11 + W[1][2] * dL_dT12;
|
| 271 |
+
float dL_dJ12 = W[2][0] * dL_dT10 + W[2][1] * dL_dT11 + W[2][2] * dL_dT12;
|
| 272 |
+
|
| 273 |
+
float tz = 1.f / t.z;
|
| 274 |
+
float tz2 = tz * tz;
|
| 275 |
+
float tz3 = tz2 * tz;
|
| 276 |
+
|
| 277 |
+
// Gradients of loss w.r.t. transformed Gaussian mean t
|
| 278 |
+
float dL_dtx = x_grad_mul * -h_x * tz2 * dL_dJ02;
|
| 279 |
+
float dL_dty = y_grad_mul * -h_y * tz2 * dL_dJ12;
|
| 280 |
+
float dL_dtz = -h_x * tz2 * dL_dJ00 - h_y * tz2 * dL_dJ11 +
|
| 281 |
+
(2 * h_x * t.x) * tz3 * dL_dJ02 +
|
| 282 |
+
(2 * h_y * t.y) * tz3 * dL_dJ12;
|
| 283 |
+
|
| 284 |
+
// Account for transformation of mean to t
|
| 285 |
+
// t = transformPoint4x3(mean, view_matrix);
|
| 286 |
+
float3 dL_dmean =
|
| 287 |
+
transformVec4x3Transpose({dL_dtx, dL_dty, dL_dtz}, view_matrix);
|
| 288 |
+
|
| 289 |
+
// Gradients of loss w.r.t. Gaussian means, but only the portion
|
| 290 |
+
// that is caused because the mean affects the covariance matrix.
|
| 291 |
+
// Additional mean gradient is accumulated in BACKWARD::preprocess.
|
| 292 |
+
dL_dmeans[idx] = dL_dmean;
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
// Backward pass for the conversion of scale and rotation to a
|
| 296 |
+
// 3D covariance matrix for each Gaussian.
|
| 297 |
+
__device__ void computeCov3D(int idx, const glm::vec3 scale, float mod,
|
| 298 |
+
const glm::vec4 rot, const float *dL_dcov3Ds,
|
| 299 |
+
glm::vec3 *dL_dscales, glm::vec4 *dL_drots) {
|
| 300 |
+
// Recompute (intermediate) results for the 3D covariance computation.
|
| 301 |
+
glm::vec4 q = rot; // / glm::length(rot);
|
| 302 |
+
float r = q.x;
|
| 303 |
+
float x = q.y;
|
| 304 |
+
float y = q.z;
|
| 305 |
+
float z = q.w;
|
| 306 |
+
|
| 307 |
+
glm::mat3 R = glm::mat3(1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z),
|
| 308 |
+
2.f * (x * z + r * y), 2.f * (x * y + r * z),
|
| 309 |
+
1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x),
|
| 310 |
+
2.f * (x * z - r * y), 2.f * (y * z + r * x),
|
| 311 |
+
1.f - 2.f * (x * x + y * y));
|
| 312 |
+
|
| 313 |
+
glm::mat3 S = glm::mat3(1.0f);
|
| 314 |
+
|
| 315 |
+
glm::vec3 s = mod * scale;
|
| 316 |
+
S[0][0] = s.x;
|
| 317 |
+
S[1][1] = s.y;
|
| 318 |
+
S[2][2] = s.z;
|
| 319 |
+
|
| 320 |
+
glm::mat3 M = S * R;
|
| 321 |
+
|
| 322 |
+
const float *dL_dcov3D = dL_dcov3Ds + 6 * idx;
|
| 323 |
+
|
| 324 |
+
glm::vec3 dunc(dL_dcov3D[0], dL_dcov3D[3], dL_dcov3D[5]);
|
| 325 |
+
glm::vec3 ounc = 0.5f * glm::vec3(dL_dcov3D[1], dL_dcov3D[2], dL_dcov3D[4]);
|
| 326 |
+
|
| 327 |
+
// Convert per-element covariance loss gradients to matrix form
|
| 328 |
+
glm::mat3 dL_dSigma =
|
| 329 |
+
glm::mat3(dL_dcov3D[0], 0.5f * dL_dcov3D[1], 0.5f * dL_dcov3D[2],
|
| 330 |
+
0.5f * dL_dcov3D[1], dL_dcov3D[3], 0.5f * dL_dcov3D[4],
|
| 331 |
+
0.5f * dL_dcov3D[2], 0.5f * dL_dcov3D[4], dL_dcov3D[5]);
|
| 332 |
+
|
| 333 |
+
// Compute loss gradient w.r.t. matrix M
|
| 334 |
+
// dSigma_dM = 2 * M
|
| 335 |
+
glm::mat3 dL_dM = 2.0f * M * dL_dSigma;
|
| 336 |
+
|
| 337 |
+
glm::mat3 Rt = glm::transpose(R);
|
| 338 |
+
glm::mat3 dL_dMt = glm::transpose(dL_dM);
|
| 339 |
+
|
| 340 |
+
// Gradients of loss w.r.t. scale
|
| 341 |
+
glm::vec3 *dL_dscale = dL_dscales + idx;
|
| 342 |
+
dL_dscale->x = glm::dot(Rt[0], dL_dMt[0]);
|
| 343 |
+
dL_dscale->y = glm::dot(Rt[1], dL_dMt[1]);
|
| 344 |
+
dL_dscale->z = glm::dot(Rt[2], dL_dMt[2]);
|
| 345 |
+
|
| 346 |
+
dL_dMt[0] *= s.x;
|
| 347 |
+
dL_dMt[1] *= s.y;
|
| 348 |
+
dL_dMt[2] *= s.z;
|
| 349 |
+
|
| 350 |
+
// Gradients of loss w.r.t. normalized quaternion
|
| 351 |
+
glm::vec4 dL_dq;
|
| 352 |
+
dL_dq.x = 2 * z * (dL_dMt[0][1] - dL_dMt[1][0]) +
|
| 353 |
+
2 * y * (dL_dMt[2][0] - dL_dMt[0][2]) +
|
| 354 |
+
2 * x * (dL_dMt[1][2] - dL_dMt[2][1]);
|
| 355 |
+
dL_dq.y = 2 * y * (dL_dMt[1][0] + dL_dMt[0][1]) +
|
| 356 |
+
2 * z * (dL_dMt[2][0] + dL_dMt[0][2]) +
|
| 357 |
+
2 * r * (dL_dMt[1][2] - dL_dMt[2][1]) -
|
| 358 |
+
4 * x * (dL_dMt[2][2] + dL_dMt[1][1]);
|
| 359 |
+
dL_dq.z = 2 * x * (dL_dMt[1][0] + dL_dMt[0][1]) +
|
| 360 |
+
2 * r * (dL_dMt[2][0] - dL_dMt[0][2]) +
|
| 361 |
+
2 * z * (dL_dMt[1][2] + dL_dMt[2][1]) -
|
| 362 |
+
4 * y * (dL_dMt[2][2] + dL_dMt[0][0]);
|
| 363 |
+
dL_dq.w = 2 * r * (dL_dMt[0][1] - dL_dMt[1][0]) +
|
| 364 |
+
2 * x * (dL_dMt[2][0] + dL_dMt[0][2]) +
|
| 365 |
+
2 * y * (dL_dMt[1][2] + dL_dMt[2][1]) -
|
| 366 |
+
4 * z * (dL_dMt[1][1] + dL_dMt[0][0]);
|
| 367 |
+
|
| 368 |
+
// Gradients of loss w.r.t. unnormalized quaternion
|
| 369 |
+
float4 *dL_drot = (float4 *)(dL_drots + idx);
|
| 370 |
+
*dL_drot = float4{dL_dq.x, dL_dq.y, dL_dq.z,
|
| 371 |
+
dL_dq.w}; // dnormvdv(float4{ rot.x, rot.y, rot.z, rot.w },
|
| 372 |
+
// float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w });
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
// Backward pass of the preprocessing steps, except
|
| 376 |
+
// for the covariance computation and inversion
|
| 377 |
+
// (those are handled by a previous kernel call)
|
| 378 |
+
template <int C>
|
| 379 |
+
__global__ void
|
| 380 |
+
preprocessCUDA(int P, int D, int M, const float3 *means, const int *radii,
|
| 381 |
+
const float *shs, const bool *clamped, const glm::vec3 *scales,
|
| 382 |
+
const glm::vec4 *rotations, const float scale_modifier,
|
| 383 |
+
const float *proj, const glm::vec3 *campos,
|
| 384 |
+
const float3 *dL_dmean2D, glm::vec3 *dL_dmeans, float *dL_dcolor,
|
| 385 |
+
float *dL_dcov3D, float *dL_dsh, glm::vec3 *dL_dscale,
|
| 386 |
+
glm::vec4 *dL_drot) {
|
| 387 |
+
auto idx = cg::this_grid().thread_rank();
|
| 388 |
+
if (idx >= P || !(radii[idx] > 0))
|
| 389 |
+
return;
|
| 390 |
+
|
| 391 |
+
float3 m = means[idx];
|
| 392 |
+
|
| 393 |
+
// Taking care of gradients from the screenspace points
|
| 394 |
+
float4 m_hom = transformPoint4x4(m, proj);
|
| 395 |
+
float m_w = 1.0f / (m_hom.w + 0.0000001f);
|
| 396 |
+
|
| 397 |
+
// Compute loss gradient w.r.t. 3D means due to gradients of 2D means
|
| 398 |
+
// from rendering procedure
|
| 399 |
+
glm::vec3 dL_dmean;
|
| 400 |
+
float mul1 =
|
| 401 |
+
(proj[0] * m.x + proj[4] * m.y + proj[8] * m.z + proj[12]) * m_w * m_w;
|
| 402 |
+
float mul2 =
|
| 403 |
+
(proj[1] * m.x + proj[5] * m.y + proj[9] * m.z + proj[13]) * m_w * m_w;
|
| 404 |
+
dL_dmean.x = (proj[0] * m_w - proj[3] * mul1) * dL_dmean2D[idx].x +
|
| 405 |
+
(proj[1] * m_w - proj[3] * mul2) * dL_dmean2D[idx].y;
|
| 406 |
+
dL_dmean.y = (proj[4] * m_w - proj[7] * mul1) * dL_dmean2D[idx].x +
|
| 407 |
+
(proj[5] * m_w - proj[7] * mul2) * dL_dmean2D[idx].y;
|
| 408 |
+
dL_dmean.z = (proj[8] * m_w - proj[11] * mul1) * dL_dmean2D[idx].x +
|
| 409 |
+
(proj[9] * m_w - proj[11] * mul2) * dL_dmean2D[idx].y;
|
| 410 |
+
|
| 411 |
+
// That's the second part of the mean gradient. Previous computation
|
| 412 |
+
// of cov2D and following SH conversion also affects it.
|
| 413 |
+
dL_dmeans[idx] += dL_dmean;
|
| 414 |
+
|
| 415 |
+
// Compute gradient updates due to computing colors from SHs
|
| 416 |
+
if (shs)
|
| 417 |
+
computeColorFromSH(idx, D, M, (glm::vec3 *)means, *campos, shs, clamped,
|
| 418 |
+
(glm::vec3 *)dL_dcolor, (glm::vec3 *)dL_dmeans,
|
| 419 |
+
(glm::vec3 *)dL_dsh);
|
| 420 |
+
|
| 421 |
+
// Compute gradient updates due to computing covariance from scale/rotation
|
| 422 |
+
if (scales)
|
| 423 |
+
computeCov3D(idx, scales[idx], scale_modifier, rotations[idx], dL_dcov3D,
|
| 424 |
+
dL_dscale, dL_drot);
|
| 425 |
+
}
|
| 426 |
+
|
| 427 |
+
// Backward version of the rendering procedure.
|
| 428 |
+
template <uint32_t C>
|
| 429 |
+
__global__ void __launch_bounds__(BLOCK_X *BLOCK_Y) renderCUDA(
|
| 430 |
+
const uint2 *__restrict__ ranges, const uint32_t *__restrict__ point_list,
|
| 431 |
+
int W, int H, const float *__restrict__ bg_color,
|
| 432 |
+
const float2 *__restrict__ points_xy_image,
|
| 433 |
+
const float4 *__restrict__ conic_opacity, const float *__restrict__ colors,
|
| 434 |
+
const float *__restrict__ final_Ts, const uint32_t *__restrict__ n_contrib,
|
| 435 |
+
const float *__restrict__ dL_dpixels, float3 *__restrict__ dL_dmean2D,
|
| 436 |
+
float4 *__restrict__ dL_dconic2D, float *__restrict__ dL_dopacity,
|
| 437 |
+
float *__restrict__ dL_dcolors) {
|
| 438 |
+
// We rasterize again. Compute necessary block info.
|
| 439 |
+
auto block = cg::this_thread_block();
|
| 440 |
+
const uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X;
|
| 441 |
+
const uint2 pix_min = {block.group_index().x * BLOCK_X,
|
| 442 |
+
block.group_index().y * BLOCK_Y};
|
| 443 |
+
const uint2 pix_max = {min(pix_min.x + BLOCK_X, W),
|
| 444 |
+
min(pix_min.y + BLOCK_Y, H)};
|
| 445 |
+
const uint2 pix = {pix_min.x + block.thread_index().x,
|
| 446 |
+
pix_min.y + block.thread_index().y};
|
| 447 |
+
const uint32_t pix_id = W * pix.y + pix.x;
|
| 448 |
+
const float2 pixf = {(float)pix.x, (float)pix.y};
|
| 449 |
+
|
| 450 |
+
const bool inside = pix.x < W && pix.y < H;
|
| 451 |
+
const uint2 range =
|
| 452 |
+
ranges[block.group_index().y * horizontal_blocks + block.group_index().x];
|
| 453 |
+
|
| 454 |
+
const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE);
|
| 455 |
+
|
| 456 |
+
bool done = !inside;
|
| 457 |
+
int toDo = range.y - range.x;
|
| 458 |
+
|
| 459 |
+
__shared__ int collected_id[BLOCK_SIZE];
|
| 460 |
+
__shared__ float2 collected_xy[BLOCK_SIZE];
|
| 461 |
+
__shared__ float4 collected_conic_opacity[BLOCK_SIZE];
|
| 462 |
+
__shared__ float collected_colors[C * BLOCK_SIZE];
|
| 463 |
+
|
| 464 |
+
// In the forward, we stored the final value for T, the
|
| 465 |
+
// product of all (1 - alpha) factors.
|
| 466 |
+
const float T_final = inside ? final_Ts[pix_id] : 0;
|
| 467 |
+
float T = T_final;
|
| 468 |
+
|
| 469 |
+
// We start from the back. The ID of the last contributing
|
| 470 |
+
// Gaussian is known from each pixel from the forward.
|
| 471 |
+
uint32_t contributor = toDo;
|
| 472 |
+
const int last_contributor = inside ? n_contrib[pix_id] : 0;
|
| 473 |
+
|
| 474 |
+
float accum_rec[C] = {0};
|
| 475 |
+
float dL_dpixel[C];
|
| 476 |
+
if (inside)
|
| 477 |
+
for (int i = 0; i < C; i++)
|
| 478 |
+
dL_dpixel[i] = dL_dpixels[i * H * W + pix_id];
|
| 479 |
+
|
| 480 |
+
float last_alpha = 0;
|
| 481 |
+
float last_color[C] = {0};
|
| 482 |
+
|
| 483 |
+
// Gradient of pixel coordinate w.r.t. normalized
|
| 484 |
+
// screen-space viewport corrdinates (-1 to 1)
|
| 485 |
+
const float ddelx_dx = 0.5 * W;
|
| 486 |
+
const float ddely_dy = 0.5 * H;
|
| 487 |
+
|
| 488 |
+
// Traverse all Gaussians
|
| 489 |
+
for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE) {
|
| 490 |
+
// Load auxiliary data into shared memory, start in the BACK
|
| 491 |
+
// and load them in revers order.
|
| 492 |
+
block.sync();
|
| 493 |
+
const int progress = i * BLOCK_SIZE + block.thread_rank();
|
| 494 |
+
if (range.x + progress < range.y) {
|
| 495 |
+
const int coll_id = point_list[range.y - progress - 1];
|
| 496 |
+
collected_id[block.thread_rank()] = coll_id;
|
| 497 |
+
collected_xy[block.thread_rank()] = points_xy_image[coll_id];
|
| 498 |
+
collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
|
| 499 |
+
for (int i = 0; i < C; i++)
|
| 500 |
+
collected_colors[i * BLOCK_SIZE + block.thread_rank()] =
|
| 501 |
+
colors[coll_id * C + i];
|
| 502 |
+
}
|
| 503 |
+
block.sync();
|
| 504 |
+
|
| 505 |
+
// Iterate over Gaussians
|
| 506 |
+
for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++) {
|
| 507 |
+
// Keep track of current Gaussian ID. Skip, if this one
|
| 508 |
+
// is behind the last contributor for this pixel.
|
| 509 |
+
contributor--;
|
| 510 |
+
if (contributor >= last_contributor)
|
| 511 |
+
continue;
|
| 512 |
+
|
| 513 |
+
// Compute blending values, as before.
|
| 514 |
+
const float2 xy = collected_xy[j];
|
| 515 |
+
const float2 d = {xy.x - pixf.x, xy.y - pixf.y};
|
| 516 |
+
const float4 con_o = collected_conic_opacity[j];
|
| 517 |
+
const float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) -
|
| 518 |
+
con_o.y * d.x * d.y;
|
| 519 |
+
if (power > 0.0f)
|
| 520 |
+
continue;
|
| 521 |
+
|
| 522 |
+
const float G = exp(power);
|
| 523 |
+
const float alpha = min(0.99f, con_o.w * G);
|
| 524 |
+
if (alpha < 1.0f / 255.0f)
|
| 525 |
+
continue;
|
| 526 |
+
|
| 527 |
+
T = T / (1.f - alpha);
|
| 528 |
+
const float dchannel_dcolor = alpha * T;
|
| 529 |
+
|
| 530 |
+
// Propagate gradients to per-Gaussian colors and keep
|
| 531 |
+
// gradients w.r.t. alpha (blending factor for a Gaussian/pixel
|
| 532 |
+
// pair).
|
| 533 |
+
float dL_dalpha = 0.0f;
|
| 534 |
+
const int global_id = collected_id[j];
|
| 535 |
+
for (int ch = 0; ch < C; ch++) {
|
| 536 |
+
const float c = collected_colors[ch * BLOCK_SIZE + j];
|
| 537 |
+
// Update last color (to be used in the next iteration)
|
| 538 |
+
accum_rec[ch] =
|
| 539 |
+
last_alpha * last_color[ch] + (1.f - last_alpha) * accum_rec[ch];
|
| 540 |
+
last_color[ch] = c;
|
| 541 |
+
|
| 542 |
+
const float dL_dchannel = dL_dpixel[ch];
|
| 543 |
+
dL_dalpha += (c - accum_rec[ch]) * dL_dchannel;
|
| 544 |
+
// Update the gradients w.r.t. color of the Gaussian.
|
| 545 |
+
// Atomic, since this pixel is just one of potentially
|
| 546 |
+
// many that were affected by this Gaussian.
|
| 547 |
+
atomicAdd(&(dL_dcolors[global_id * C + ch]),
|
| 548 |
+
dchannel_dcolor * dL_dchannel);
|
| 549 |
+
}
|
| 550 |
+
dL_dalpha *= T;
|
| 551 |
+
// Update last alpha (to be used in the next iteration)
|
| 552 |
+
last_alpha = alpha;
|
| 553 |
+
|
| 554 |
+
// Account for fact that alpha also influences how much of
|
| 555 |
+
// the background color is added if nothing left to blend
|
| 556 |
+
float bg_dot_dpixel = 0;
|
| 557 |
+
for (int i = 0; i < C; i++)
|
| 558 |
+
bg_dot_dpixel += bg_color[i] * dL_dpixel[i];
|
| 559 |
+
dL_dalpha += (-T_final / (1.f - alpha)) * bg_dot_dpixel;
|
| 560 |
+
|
| 561 |
+
// Helpful reusable temporary variables
|
| 562 |
+
const float dL_dG = con_o.w * dL_dalpha;
|
| 563 |
+
const float gdx = G * d.x;
|
| 564 |
+
const float gdy = G * d.y;
|
| 565 |
+
const float dG_ddelx = -gdx * con_o.x - gdy * con_o.y;
|
| 566 |
+
const float dG_ddely = -gdy * con_o.z - gdx * con_o.y;
|
| 567 |
+
|
| 568 |
+
// Update gradients w.r.t. 2D mean position of the Gaussian
|
| 569 |
+
atomicAdd(&dL_dmean2D[global_id].x, dL_dG * dG_ddelx * ddelx_dx);
|
| 570 |
+
atomicAdd(&dL_dmean2D[global_id].y, dL_dG * dG_ddely * ddely_dy);
|
| 571 |
+
|
| 572 |
+
// Update gradients w.r.t. 2D covariance (2x2 matrix, symmetric)
|
| 573 |
+
atomicAdd(&dL_dconic2D[global_id].x, -0.5f * gdx * d.x * dL_dG);
|
| 574 |
+
atomicAdd(&dL_dconic2D[global_id].y, -0.5f * gdx * d.y * dL_dG);
|
| 575 |
+
atomicAdd(&dL_dconic2D[global_id].w, -0.5f * gdy * d.y * dL_dG);
|
| 576 |
+
|
| 577 |
+
// Update gradients w.r.t. opacity of the Gaussian
|
| 578 |
+
atomicAdd(&(dL_dopacity[global_id]), G * dL_dalpha);
|
| 579 |
+
}
|
| 580 |
+
}
|
| 581 |
+
}
|
| 582 |
+
|
| 583 |
+
void BACKWARD::preprocess(
|
| 584 |
+
int P, int D, int M, const float3 *means3D, const int *radii,
|
| 585 |
+
const float *shs, const bool *clamped, const glm::vec3 *scales,
|
| 586 |
+
const glm::vec4 *rotations, const float scale_modifier, const float *cov3Ds,
|
| 587 |
+
const float *viewmatrix, const float *projmatrix, const float focal_x,
|
| 588 |
+
float focal_y, const float tan_fovx, float tan_fovy,
|
| 589 |
+
const glm::vec3 *campos, const float3 *dL_dmean2D, const float *dL_dconic,
|
| 590 |
+
glm::vec3 *dL_dmean3D, float *dL_dcolor, float *dL_dcov3D, float *dL_dsh,
|
| 591 |
+
glm::vec3 *dL_dscale, glm::vec4 *dL_drot) {
|
| 592 |
+
// Propagate gradients for the path of 2D conic matrix computation.
|
| 593 |
+
// Somewhat long, thus it is its own kernel rather than being part of
|
| 594 |
+
// "preprocess". When done, loss gradient w.r.t. 3D means has been
|
| 595 |
+
// modified and gradient w.r.t. 3D covariance matrix has been computed.
|
| 596 |
+
computeCov2DCUDA<<<(P + 255) / 256, 256>>>(
|
| 597 |
+
P, means3D, radii, cov3Ds, focal_x, focal_y, tan_fovx, tan_fovy,
|
| 598 |
+
viewmatrix, dL_dconic, (float3 *)dL_dmean3D, dL_dcov3D);
|
| 599 |
+
|
| 600 |
+
// Propagate gradients for remaining steps: finish 3D mean gradients,
|
| 601 |
+
// propagate color gradients to SH (if desireD), propagate 3D covariance
|
| 602 |
+
// matrix gradients to scale and rotation.
|
| 603 |
+
preprocessCUDA<NUM_CHANNELS><<<(P + 255) / 256, 256>>>(
|
| 604 |
+
P, D, M, (float3 *)means3D, radii, shs, clamped, (glm::vec3 *)scales,
|
| 605 |
+
(glm::vec4 *)rotations, scale_modifier, projmatrix, campos,
|
| 606 |
+
(float3 *)dL_dmean2D, (glm::vec3 *)dL_dmean3D, dL_dcolor, dL_dcov3D,
|
| 607 |
+
dL_dsh, dL_dscale, dL_drot);
|
| 608 |
+
}
|
| 609 |
+
|
| 610 |
+
void BACKWARD::render(const dim3 grid, const dim3 block, const uint2 *ranges,
|
| 611 |
+
const uint32_t *point_list, int W, int H,
|
| 612 |
+
const float *bg_color, const float2 *means2D,
|
| 613 |
+
const float4 *conic_opacity, const float *colors,
|
| 614 |
+
const float *final_Ts, const uint32_t *n_contrib,
|
| 615 |
+
const float *dL_dpixels, float3 *dL_dmean2D,
|
| 616 |
+
float4 *dL_dconic2D, float *dL_dopacity,
|
| 617 |
+
float *dL_dcolors) {
|
| 618 |
+
renderCUDA<NUM_CHANNELS>
|
| 619 |
+
<<<grid, block>>>(ranges, point_list, W, H, bg_color, means2D,
|
| 620 |
+
conic_opacity, colors, final_Ts, n_contrib, dL_dpixels,
|
| 621 |
+
dL_dmean2D, dL_dconic2D, dL_dopacity, dL_dcolors);
|
| 622 |
+
}
|
gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/backward.h
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright (C) 2023, Inria
|
| 3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This software is free for non-commercial, research and evaluation use
|
| 7 |
+
* under the terms of the LICENSE.md file.
|
| 8 |
+
*
|
| 9 |
+
* For inquiries contact george.drettakis@inria.fr
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#ifndef CUDA_RASTERIZER_BACKWARD_H_INCLUDED
|
| 13 |
+
#define CUDA_RASTERIZER_BACKWARD_H_INCLUDED
|
| 14 |
+
|
| 15 |
+
#include "cuda_runtime.h"
|
| 16 |
+
#include "device_launch_parameters.h"
|
| 17 |
+
#include <cuda.h>
|
| 18 |
+
#define GLM_FORCE_CUDA
|
| 19 |
+
#include <glm/glm.hpp>
|
| 20 |
+
|
| 21 |
+
namespace BACKWARD {
|
| 22 |
+
void render(const dim3 grid, dim3 block, const uint2 *ranges,
|
| 23 |
+
const uint32_t *point_list, int W, int H, const float *bg_color,
|
| 24 |
+
const float2 *means2D, const float4 *conic_opacity,
|
| 25 |
+
const float *colors, const float *final_Ts,
|
| 26 |
+
const uint32_t *n_contrib, const float *dL_dpixels,
|
| 27 |
+
float3 *dL_dmean2D, float4 *dL_dconic2D, float *dL_dopacity,
|
| 28 |
+
float *dL_dcolors);
|
| 29 |
+
|
| 30 |
+
void preprocess(int P, int D, int M, const float3 *means, const int *radii,
|
| 31 |
+
const float *shs, const bool *clamped, const glm::vec3 *scales,
|
| 32 |
+
const glm::vec4 *rotations, const float scale_modifier,
|
| 33 |
+
const float *cov3Ds, const float *view, const float *proj,
|
| 34 |
+
const float focal_x, float focal_y, const float tan_fovx,
|
| 35 |
+
float tan_fovy, const glm::vec3 *campos,
|
| 36 |
+
const float3 *dL_dmean2D, const float *dL_dconics,
|
| 37 |
+
glm::vec3 *dL_dmeans, float *dL_dcolor, float *dL_dcov3D,
|
| 38 |
+
float *dL_dsh, glm::vec3 *dL_dscale, glm::vec4 *dL_drot);
|
| 39 |
+
} // namespace BACKWARD
|
| 40 |
+
|
| 41 |
+
#endif
|
gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/config.h
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright (C) 2023, Inria
|
| 3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This software is free for non-commercial, research and evaluation use
|
| 7 |
+
* under the terms of the LICENSE.md file.
|
| 8 |
+
*
|
| 9 |
+
* For inquiries contact george.drettakis@inria.fr
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED
|
| 13 |
+
#define CUDA_RASTERIZER_CONFIG_H_INCLUDED
|
| 14 |
+
|
| 15 |
+
#define NUM_CHANNELS 3 // Default 3, RGB
|
| 16 |
+
#define BLOCK_X 16
|
| 17 |
+
#define BLOCK_Y 16
|
| 18 |
+
|
| 19 |
+
#endif
|
gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/forward.cu
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright (C) 2023, Inria
|
| 3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This software is free for non-commercial, research and evaluation use
|
| 7 |
+
* under the terms of the LICENSE.md file.
|
| 8 |
+
*
|
| 9 |
+
* For inquiries contact george.drettakis@inria.fr
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#include "auxiliary.h"
|
| 13 |
+
#include "forward.h"
|
| 14 |
+
#include <cooperative_groups.h>
|
| 15 |
+
#include <cooperative_groups/reduce.h>
|
| 16 |
+
namespace cg = cooperative_groups;
|
| 17 |
+
|
| 18 |
+
// Forward method for converting the input spherical harmonics
|
| 19 |
+
// coefficients of each Gaussian to a simple RGB color.
|
| 20 |
+
__device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs,
|
| 21 |
+
const glm::vec3 *means,
|
| 22 |
+
glm::vec3 campos, const float *shs,
|
| 23 |
+
bool *clamped) {
|
| 24 |
+
// The implementation is loosely based on code for
|
| 25 |
+
// "Differentiable Point-Based Radiance Fields for
|
| 26 |
+
// Efficient View Synthesis" by Zhang et al. (2022)
|
| 27 |
+
glm::vec3 pos = means[idx];
|
| 28 |
+
glm::vec3 dir = pos - campos;
|
| 29 |
+
dir = dir / glm::length(dir);
|
| 30 |
+
|
| 31 |
+
glm::vec3 *sh = ((glm::vec3 *)shs) + idx * max_coeffs;
|
| 32 |
+
glm::vec3 result = SH_C0 * sh[0];
|
| 33 |
+
|
| 34 |
+
if (deg > 0) {
|
| 35 |
+
float x = dir.x;
|
| 36 |
+
float y = dir.y;
|
| 37 |
+
float z = dir.z;
|
| 38 |
+
result = result - SH_C1 * y * sh[1] + SH_C1 * z * sh[2] - SH_C1 * x * sh[3];
|
| 39 |
+
|
| 40 |
+
if (deg > 1) {
|
| 41 |
+
float xx = x * x, yy = y * y, zz = z * z;
|
| 42 |
+
float xy = x * y, yz = y * z, xz = x * z;
|
| 43 |
+
result = result + SH_C2[0] * xy * sh[4] + SH_C2[1] * yz * sh[5] +
|
| 44 |
+
SH_C2[2] * (2.0f * zz - xx - yy) * sh[6] +
|
| 45 |
+
SH_C2[3] * xz * sh[7] + SH_C2[4] * (xx - yy) * sh[8];
|
| 46 |
+
|
| 47 |
+
if (deg > 2) {
|
| 48 |
+
result = result + SH_C3[0] * y * (3.0f * xx - yy) * sh[9] +
|
| 49 |
+
SH_C3[1] * xy * z * sh[10] +
|
| 50 |
+
SH_C3[2] * y * (4.0f * zz - xx - yy) * sh[11] +
|
| 51 |
+
SH_C3[3] * z * (2.0f * zz - 3.0f * xx - 3.0f * yy) * sh[12] +
|
| 52 |
+
SH_C3[4] * x * (4.0f * zz - xx - yy) * sh[13] +
|
| 53 |
+
SH_C3[5] * z * (xx - yy) * sh[14] +
|
| 54 |
+
SH_C3[6] * x * (xx - 3.0f * yy) * sh[15];
|
| 55 |
+
}
|
| 56 |
+
}
|
| 57 |
+
}
|
| 58 |
+
result += 0.5f;
|
| 59 |
+
|
| 60 |
+
// RGB colors are clamped to positive values. If values are
|
| 61 |
+
// clamped, we need to keep track of this for the backward pass.
|
| 62 |
+
clamped[3 * idx + 0] = (result.x < 0);
|
| 63 |
+
clamped[3 * idx + 1] = (result.y < 0);
|
| 64 |
+
clamped[3 * idx + 2] = (result.z < 0);
|
| 65 |
+
return glm::max(result, 0.0f);
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
// Forward version of 2D covariance matrix computation
|
| 69 |
+
__device__ float3 computeCov2D(const float3 &mean, float focal_x, float focal_y,
|
| 70 |
+
float tan_fovx, float tan_fovy,
|
| 71 |
+
const float *cov3D, const float *viewmatrix) {
|
| 72 |
+
// The following models the steps outlined by equations 29
|
| 73 |
+
// and 31 in "EWA Splatting" (Zwicker et al., 2002).
|
| 74 |
+
// Additionally considers aspect / scaling of viewport.
|
| 75 |
+
// Transposes used to account for row-/column-major conventions.
|
| 76 |
+
float3 t = transformPoint4x3(mean, viewmatrix);
|
| 77 |
+
|
| 78 |
+
const float limx = 1.3f * tan_fovx;
|
| 79 |
+
const float limy = 1.3f * tan_fovy;
|
| 80 |
+
const float txtz = t.x / t.z;
|
| 81 |
+
const float tytz = t.y / t.z;
|
| 82 |
+
t.x = min(limx, max(-limx, txtz)) * t.z;
|
| 83 |
+
t.y = min(limy, max(-limy, tytz)) * t.z;
|
| 84 |
+
|
| 85 |
+
glm::mat3 J =
|
| 86 |
+
glm::mat3(focal_x / t.z, 0.0f, -(focal_x * t.x) / (t.z * t.z), 0.0f,
|
| 87 |
+
focal_y / t.z, -(focal_y * t.y) / (t.z * t.z), 0, 0, 0);
|
| 88 |
+
|
| 89 |
+
glm::mat3 W = glm::mat3(viewmatrix[0], viewmatrix[4], viewmatrix[8],
|
| 90 |
+
viewmatrix[1], viewmatrix[5], viewmatrix[9],
|
| 91 |
+
viewmatrix[2], viewmatrix[6], viewmatrix[10]);
|
| 92 |
+
|
| 93 |
+
glm::mat3 T = W * J;
|
| 94 |
+
|
| 95 |
+
glm::mat3 Vrk = glm::mat3(cov3D[0], cov3D[1], cov3D[2], cov3D[1], cov3D[3],
|
| 96 |
+
cov3D[4], cov3D[2], cov3D[4], cov3D[5]);
|
| 97 |
+
|
| 98 |
+
glm::mat3 cov = glm::transpose(T) * glm::transpose(Vrk) * T;
|
| 99 |
+
|
| 100 |
+
// Apply low-pass filter: every Gaussian should be at least
|
| 101 |
+
// one pixel wide/high. Discard 3rd row and column.
|
| 102 |
+
cov[0][0] += 0.3f;
|
| 103 |
+
cov[1][1] += 0.3f;
|
| 104 |
+
return {float(cov[0][0]), float(cov[0][1]), float(cov[1][1])};
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
// Forward method for converting scale and rotation properties of each
|
| 108 |
+
// Gaussian to a 3D covariance matrix in world space. Also takes care
|
| 109 |
+
// of quaternion normalization.
|
| 110 |
+
__device__ void computeCov3D(const glm::vec3 scale, float mod,
|
| 111 |
+
const glm::vec4 rot, float *cov3D) {
|
| 112 |
+
// Create scaling matrix
|
| 113 |
+
glm::mat3 S = glm::mat3(1.0f);
|
| 114 |
+
S[0][0] = mod * scale.x;
|
| 115 |
+
S[1][1] = mod * scale.y;
|
| 116 |
+
S[2][2] = mod * scale.z;
|
| 117 |
+
|
| 118 |
+
// Normalize quaternion to get valid rotation
|
| 119 |
+
glm::vec4 q = rot; // / glm::length(rot);
|
| 120 |
+
float r = q.x;
|
| 121 |
+
float x = q.y;
|
| 122 |
+
float y = q.z;
|
| 123 |
+
float z = q.w;
|
| 124 |
+
|
| 125 |
+
// Compute rotation matrix from quaternion
|
| 126 |
+
glm::mat3 R = glm::mat3(1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z),
|
| 127 |
+
2.f * (x * z + r * y), 2.f * (x * y + r * z),
|
| 128 |
+
1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x),
|
| 129 |
+
2.f * (x * z - r * y), 2.f * (y * z + r * x),
|
| 130 |
+
1.f - 2.f * (x * x + y * y));
|
| 131 |
+
|
| 132 |
+
glm::mat3 M = S * R;
|
| 133 |
+
|
| 134 |
+
// Compute 3D world covariance matrix Sigma
|
| 135 |
+
glm::mat3 Sigma = glm::transpose(M) * M;
|
| 136 |
+
|
| 137 |
+
// Covariance is symmetric, only store upper right
|
| 138 |
+
cov3D[0] = Sigma[0][0];
|
| 139 |
+
cov3D[1] = Sigma[0][1];
|
| 140 |
+
cov3D[2] = Sigma[0][2];
|
| 141 |
+
cov3D[3] = Sigma[1][1];
|
| 142 |
+
cov3D[4] = Sigma[1][2];
|
| 143 |
+
cov3D[5] = Sigma[2][2];
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
// Perform initial steps for each Gaussian prior to rasterization.
|
| 147 |
+
template <int C>
|
| 148 |
+
__global__ void
|
| 149 |
+
preprocessCUDA(int P, int D, int M, const float *orig_points,
|
| 150 |
+
const glm::vec3 *scales, const float scale_modifier,
|
| 151 |
+
const glm::vec4 *rotations, const float *opacities,
|
| 152 |
+
const float *shs, bool *clamped, const float *cov3D_precomp,
|
| 153 |
+
const float *colors_precomp, const float *viewmatrix,
|
| 154 |
+
const float *projmatrix, const glm::vec3 *cam_pos, const int W,
|
| 155 |
+
int H, const float tan_fovx, float tan_fovy, const float focal_x,
|
| 156 |
+
float focal_y, int *radii, float2 *points_xy_image,
|
| 157 |
+
float *depths, float *cov3Ds, float *rgb, float4 *conic_opacity,
|
| 158 |
+
const dim3 grid, uint32_t *tiles_touched, bool prefiltered) {
|
| 159 |
+
auto idx = cg::this_grid().thread_rank();
|
| 160 |
+
if (idx >= P)
|
| 161 |
+
return;
|
| 162 |
+
|
| 163 |
+
// Initialize radius and touched tiles to 0. If this isn't changed,
|
| 164 |
+
// this Gaussian will not be processed further.
|
| 165 |
+
radii[idx] = 0;
|
| 166 |
+
tiles_touched[idx] = 0;
|
| 167 |
+
|
| 168 |
+
// Perform near culling, quit if outside.
|
| 169 |
+
float3 p_view;
|
| 170 |
+
if (!in_frustum(idx, orig_points, viewmatrix, projmatrix, prefiltered,
|
| 171 |
+
p_view))
|
| 172 |
+
return;
|
| 173 |
+
|
| 174 |
+
// Transform point by projecting
|
| 175 |
+
float3 p_orig = {orig_points[3 * idx], orig_points[3 * idx + 1],
|
| 176 |
+
orig_points[3 * idx + 2]};
|
| 177 |
+
float4 p_hom = transformPoint4x4(p_orig, projmatrix);
|
| 178 |
+
float p_w = 1.0f / (p_hom.w + 0.0000001f);
|
| 179 |
+
float3 p_proj = {p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w};
|
| 180 |
+
|
| 181 |
+
// If 3D covariance matrix is precomputed, use it, otherwise compute
|
| 182 |
+
// from scaling and rotation parameters.
|
| 183 |
+
const float *cov3D;
|
| 184 |
+
if (cov3D_precomp != nullptr) {
|
| 185 |
+
cov3D = cov3D_precomp + idx * 6;
|
| 186 |
+
} else {
|
| 187 |
+
computeCov3D(scales[idx], scale_modifier, rotations[idx], cov3Ds + idx * 6);
|
| 188 |
+
cov3D = cov3Ds + idx * 6;
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
// Compute 2D screen-space covariance matrix
|
| 192 |
+
float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D,
|
| 193 |
+
viewmatrix);
|
| 194 |
+
|
| 195 |
+
// Invert covariance (EWA algorithm)
|
| 196 |
+
float det = (cov.x * cov.z - cov.y * cov.y);
|
| 197 |
+
if (det == 0.0f)
|
| 198 |
+
return;
|
| 199 |
+
float det_inv = 1.f / det;
|
| 200 |
+
float3 conic = {cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv};
|
| 201 |
+
|
| 202 |
+
// Compute extent in screen space (by finding eigenvalues of
|
| 203 |
+
// 2D covariance matrix). Use extent to compute a bounding rectangle
|
| 204 |
+
// of screen-space tiles that this Gaussian overlaps with. Quit if
|
| 205 |
+
// rectangle covers 0 tiles.
|
| 206 |
+
float mid = 0.5f * (cov.x + cov.z);
|
| 207 |
+
float lambda1 = mid + sqrt(max(0.1f, mid * mid - det));
|
| 208 |
+
float lambda2 = mid - sqrt(max(0.1f, mid * mid - det));
|
| 209 |
+
float my_radius = ceil(3.f * sqrt(max(lambda1, lambda2)));
|
| 210 |
+
float2 point_image = {ndc2Pix(p_proj.x, W), ndc2Pix(p_proj.y, H)};
|
| 211 |
+
uint2 rect_min, rect_max;
|
| 212 |
+
getRect(point_image, my_radius, rect_min, rect_max, grid);
|
| 213 |
+
if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0)
|
| 214 |
+
return;
|
| 215 |
+
|
| 216 |
+
// If colors have been precomputed, use them, otherwise convert
|
| 217 |
+
// spherical harmonics coefficients to RGB color.
|
| 218 |
+
if (colors_precomp == nullptr) {
|
| 219 |
+
glm::vec3 result = computeColorFromSH(idx, D, M, (glm::vec3 *)orig_points,
|
| 220 |
+
*cam_pos, shs, clamped);
|
| 221 |
+
rgb[idx * C + 0] = result.x;
|
| 222 |
+
rgb[idx * C + 1] = result.y;
|
| 223 |
+
rgb[idx * C + 2] = result.z;
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
// Store some useful helper data for the next steps.
|
| 227 |
+
depths[idx] = p_view.z;
|
| 228 |
+
radii[idx] = my_radius;
|
| 229 |
+
points_xy_image[idx] = point_image;
|
| 230 |
+
// Inverse 2D covariance and opacity neatly pack into one float4
|
| 231 |
+
conic_opacity[idx] = {conic.x, conic.y, conic.z, opacities[idx]};
|
| 232 |
+
tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x);
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
// Main rasterization method. Collaboratively works on one tile per
|
| 236 |
+
// block, each thread treats one pixel. Alternates between fetching
|
| 237 |
+
// and rasterizing data.
|
| 238 |
+
template <uint32_t CHANNELS>
|
| 239 |
+
__global__ void __launch_bounds__(BLOCK_X *BLOCK_Y)
|
| 240 |
+
renderCUDA(const uint2 *__restrict__ ranges,
|
| 241 |
+
const uint32_t *__restrict__ point_list, int W, int H,
|
| 242 |
+
const float2 *__restrict__ points_xy_image,
|
| 243 |
+
const float *__restrict__ features,
|
| 244 |
+
const float4 *__restrict__ conic_opacity,
|
| 245 |
+
float *__restrict__ final_T, uint32_t *__restrict__ n_contrib,
|
| 246 |
+
const float *__restrict__ bg_color,
|
| 247 |
+
float *__restrict__ out_color) {
|
| 248 |
+
// Identify current tile and associated min/max pixel range.
|
| 249 |
+
auto block = cg::this_thread_block();
|
| 250 |
+
uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X;
|
| 251 |
+
uint2 pix_min = {block.group_index().x * BLOCK_X,
|
| 252 |
+
block.group_index().y * BLOCK_Y};
|
| 253 |
+
uint2 pix_max = {min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y, H)};
|
| 254 |
+
uint2 pix = {pix_min.x + block.thread_index().x,
|
| 255 |
+
pix_min.y + block.thread_index().y};
|
| 256 |
+
uint32_t pix_id = W * pix.y + pix.x;
|
| 257 |
+
float2 pixf = {(float)pix.x, (float)pix.y};
|
| 258 |
+
|
| 259 |
+
// Check if this thread is associated with a valid pixel or outside.
|
| 260 |
+
bool inside = pix.x < W && pix.y < H;
|
| 261 |
+
// Done threads can help with fetching, but don't rasterize
|
| 262 |
+
bool done = !inside;
|
| 263 |
+
|
| 264 |
+
// Load start/end range of IDs to process in bit sorted list.
|
| 265 |
+
uint2 range =
|
| 266 |
+
ranges[block.group_index().y * horizontal_blocks + block.group_index().x];
|
| 267 |
+
const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE);
|
| 268 |
+
int toDo = range.y - range.x;
|
| 269 |
+
|
| 270 |
+
// Allocate storage for batches of collectively fetched data.
|
| 271 |
+
__shared__ int collected_id[BLOCK_SIZE];
|
| 272 |
+
__shared__ float2 collected_xy[BLOCK_SIZE];
|
| 273 |
+
__shared__ float4 collected_conic_opacity[BLOCK_SIZE];
|
| 274 |
+
|
| 275 |
+
// Initialize helper variables
|
| 276 |
+
float T = 1.0f;
|
| 277 |
+
uint32_t contributor = 0;
|
| 278 |
+
uint32_t last_contributor = 0;
|
| 279 |
+
float C[CHANNELS] = {0};
|
| 280 |
+
|
| 281 |
+
// Iterate over batches until all done or range is complete
|
| 282 |
+
for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE) {
|
| 283 |
+
// End if entire block votes that it is done rasterizing
|
| 284 |
+
int num_done = __syncthreads_count(done);
|
| 285 |
+
if (num_done == BLOCK_SIZE)
|
| 286 |
+
break;
|
| 287 |
+
|
| 288 |
+
// Collectively fetch per-Gaussian data from global to shared
|
| 289 |
+
int progress = i * BLOCK_SIZE + block.thread_rank();
|
| 290 |
+
if (range.x + progress < range.y) {
|
| 291 |
+
int coll_id = point_list[range.x + progress];
|
| 292 |
+
collected_id[block.thread_rank()] = coll_id;
|
| 293 |
+
collected_xy[block.thread_rank()] = points_xy_image[coll_id];
|
| 294 |
+
collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
|
| 295 |
+
}
|
| 296 |
+
block.sync();
|
| 297 |
+
|
| 298 |
+
// Iterate over current batch
|
| 299 |
+
for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++) {
|
| 300 |
+
// Keep track of current position in range
|
| 301 |
+
contributor++;
|
| 302 |
+
|
| 303 |
+
// Resample using conic matrix (cf. "Surface
|
| 304 |
+
// Splatting" by Zwicker et al., 2001)
|
| 305 |
+
float2 xy = collected_xy[j];
|
| 306 |
+
float2 d = {xy.x - pixf.x, xy.y - pixf.y};
|
| 307 |
+
float4 con_o = collected_conic_opacity[j];
|
| 308 |
+
float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) -
|
| 309 |
+
con_o.y * d.x * d.y;
|
| 310 |
+
if (power > 0.0f)
|
| 311 |
+
continue;
|
| 312 |
+
|
| 313 |
+
// Eq. (2) from 3D Gaussian splatting paper.
|
| 314 |
+
// Obtain alpha by multiplying with Gaussian opacity
|
| 315 |
+
// and its exponential falloff from mean.
|
| 316 |
+
// Avoid numerical instabilities (see paper appendix).
|
| 317 |
+
float alpha = min(0.99f, con_o.w * exp(power));
|
| 318 |
+
if (alpha < 1.0f / 255.0f)
|
| 319 |
+
continue;
|
| 320 |
+
float test_T = T * (1 - alpha);
|
| 321 |
+
if (test_T < 0.0001f) {
|
| 322 |
+
done = true;
|
| 323 |
+
continue;
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
// Eq. (3) from 3D Gaussian splatting paper.
|
| 327 |
+
for (int ch = 0; ch < CHANNELS; ch++)
|
| 328 |
+
C[ch] += features[collected_id[j] * CHANNELS + ch] * alpha * T;
|
| 329 |
+
|
| 330 |
+
T = test_T;
|
| 331 |
+
|
| 332 |
+
// Keep track of last range entry to update this
|
| 333 |
+
// pixel.
|
| 334 |
+
last_contributor = contributor;
|
| 335 |
+
}
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
// All threads that treat valid pixel write out their final
|
| 339 |
+
// rendering data to the frame and auxiliary buffers.
|
| 340 |
+
if (inside) {
|
| 341 |
+
final_T[pix_id] = T;
|
| 342 |
+
n_contrib[pix_id] = last_contributor;
|
| 343 |
+
for (int ch = 0; ch < CHANNELS; ch++)
|
| 344 |
+
out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch];
|
| 345 |
+
}
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
void FORWARD::render(const dim3 grid, dim3 block, const uint2 *ranges,
|
| 349 |
+
const uint32_t *point_list, int W, int H,
|
| 350 |
+
const float2 *means2D, const float *colors,
|
| 351 |
+
const float4 *conic_opacity, float *final_T,
|
| 352 |
+
uint32_t *n_contrib, const float *bg_color,
|
| 353 |
+
float *out_color) {
|
| 354 |
+
renderCUDA<NUM_CHANNELS><<<grid, block>>>(ranges, point_list, W, H, means2D,
|
| 355 |
+
colors, conic_opacity, final_T,
|
| 356 |
+
n_contrib, bg_color, out_color);
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
void FORWARD::preprocess(int P, int D, int M, const float *means3D,
|
| 360 |
+
const glm::vec3 *scales, const float scale_modifier,
|
| 361 |
+
const glm::vec4 *rotations, const float *opacities,
|
| 362 |
+
const float *shs, bool *clamped,
|
| 363 |
+
const float *cov3D_precomp,
|
| 364 |
+
const float *colors_precomp, const float *viewmatrix,
|
| 365 |
+
const float *projmatrix, const glm::vec3 *cam_pos,
|
| 366 |
+
const int W, int H, const float focal_x, float focal_y,
|
| 367 |
+
const float tan_fovx, float tan_fovy, int *radii,
|
| 368 |
+
float2 *means2D, float *depths, float *cov3Ds,
|
| 369 |
+
float *rgb, float4 *conic_opacity, const dim3 grid,
|
| 370 |
+
uint32_t *tiles_touched, bool prefiltered) {
|
| 371 |
+
preprocessCUDA<NUM_CHANNELS><<<(P + 255) / 256, 256>>>(
|
| 372 |
+
P, D, M, means3D, scales, scale_modifier, rotations, opacities, shs,
|
| 373 |
+
clamped, cov3D_precomp, colors_precomp, viewmatrix, projmatrix, cam_pos,
|
| 374 |
+
W, H, tan_fovx, tan_fovy, focal_x, focal_y, radii, means2D, depths,
|
| 375 |
+
cov3Ds, rgb, conic_opacity, grid, tiles_touched, prefiltered);
|
| 376 |
+
}
|
gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/forward.h
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright (C) 2023, Inria
|
| 3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This software is free for non-commercial, research and evaluation use
|
| 7 |
+
* under the terms of the LICENSE.md file.
|
| 8 |
+
*
|
| 9 |
+
* For inquiries contact george.drettakis@inria.fr
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED
|
| 13 |
+
#define CUDA_RASTERIZER_FORWARD_H_INCLUDED
|
| 14 |
+
|
| 15 |
+
#include "cuda_runtime.h"
|
| 16 |
+
#include "device_launch_parameters.h"
|
| 17 |
+
#include <cuda.h>
|
| 18 |
+
#define GLM_FORCE_CUDA
|
| 19 |
+
#include <glm/glm.hpp>
|
| 20 |
+
|
| 21 |
+
namespace FORWARD {
|
| 22 |
+
// Perform initial steps for each Gaussian prior to rasterization.
|
| 23 |
+
void preprocess(int P, int D, int M, const float *orig_points,
|
| 24 |
+
const glm::vec3 *scales, const float scale_modifier,
|
| 25 |
+
const glm::vec4 *rotations, const float *opacities,
|
| 26 |
+
const float *shs, bool *clamped, const float *cov3D_precomp,
|
| 27 |
+
const float *colors_precomp, const float *viewmatrix,
|
| 28 |
+
const float *projmatrix, const glm::vec3 *cam_pos, const int W,
|
| 29 |
+
int H, const float focal_x, float focal_y, const float tan_fovx,
|
| 30 |
+
float tan_fovy, int *radii, float2 *points_xy_image,
|
| 31 |
+
float *depths, float *cov3Ds, float *colors,
|
| 32 |
+
float4 *conic_opacity, const dim3 grid, uint32_t *tiles_touched,
|
| 33 |
+
bool prefiltered);
|
| 34 |
+
|
| 35 |
+
// Main rasterization method.
|
| 36 |
+
void render(const dim3 grid, dim3 block, const uint2 *ranges,
|
| 37 |
+
const uint32_t *point_list, int W, int H,
|
| 38 |
+
const float2 *points_xy_image, const float *features,
|
| 39 |
+
const float4 *conic_opacity, float *final_T, uint32_t *n_contrib,
|
| 40 |
+
const float *bg_color, float *out_color);
|
| 41 |
+
} // namespace FORWARD
|
| 42 |
+
|
| 43 |
+
#endif
|
gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/rasterizer.h
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright (C) 2023, Inria
|
| 3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This software is free for non-commercial, research and evaluation use
|
| 7 |
+
* under the terms of the LICENSE.md file.
|
| 8 |
+
*
|
| 9 |
+
* For inquiries contact george.drettakis@inria.fr
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#ifndef CUDA_RASTERIZER_H_INCLUDED
|
| 13 |
+
#define CUDA_RASTERIZER_H_INCLUDED
|
| 14 |
+
|
| 15 |
+
#include <functional>
|
| 16 |
+
#include <vector>
|
| 17 |
+
|
| 18 |
+
namespace CudaRasterizer {
|
| 19 |
+
class Rasterizer {
|
| 20 |
+
public:
|
| 21 |
+
static void markVisible(int P, float *means3D, float *viewmatrix,
|
| 22 |
+
float *projmatrix, bool *present);
|
| 23 |
+
|
| 24 |
+
static int forward(std::function<char *(size_t)> geometryBuffer,
|
| 25 |
+
std::function<char *(size_t)> binningBuffer,
|
| 26 |
+
std::function<char *(size_t)> imageBuffer, const int P,
|
| 27 |
+
int D, int M, const float *background, const int width,
|
| 28 |
+
int height, const float *means3D, const float *shs,
|
| 29 |
+
const float *colors_precomp, const float *opacities,
|
| 30 |
+
const float *scales, const float scale_modifier,
|
| 31 |
+
const float *rotations, const float *cov3D_precomp,
|
| 32 |
+
const float *viewmatrix, const float *projmatrix,
|
| 33 |
+
const float *cam_pos, const float tan_fovx, float tan_fovy,
|
| 34 |
+
const bool prefiltered, float *out_color,
|
| 35 |
+
int *radii = nullptr, bool debug = false);
|
| 36 |
+
|
| 37 |
+
static void
|
| 38 |
+
backward(const int P, int D, int M, int R, const float *background,
|
| 39 |
+
const int width, int height, const float *means3D, const float *shs,
|
| 40 |
+
const float *colors_precomp, const float *scales,
|
| 41 |
+
const float scale_modifier, const float *rotations,
|
| 42 |
+
const float *cov3D_precomp, const float *viewmatrix,
|
| 43 |
+
const float *projmatrix, const float *campos, const float tan_fovx,
|
| 44 |
+
float tan_fovy, const int *radii, char *geom_buffer,
|
| 45 |
+
char *binning_buffer, char *image_buffer, const float *dL_dpix,
|
| 46 |
+
float *dL_dmean2D, float *dL_dconic, float *dL_dopacity,
|
| 47 |
+
float *dL_dcolor, float *dL_dmean3D, float *dL_dcov3D, float *dL_dsh,
|
| 48 |
+
float *dL_dscale, float *dL_drot, bool debug);
|
| 49 |
+
};
|
| 50 |
+
}; // namespace CudaRasterizer
|
| 51 |
+
|
| 52 |
+
#endif
|
gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/rasterizer_impl.cu
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright (C) 2023, Inria
|
| 3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This software is free for non-commercial, research and evaluation use
|
| 7 |
+
* under the terms of the LICENSE.md file.
|
| 8 |
+
*
|
| 9 |
+
* For inquiries contact george.drettakis@inria.fr
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#include "cuda_runtime.h"
|
| 13 |
+
#include "device_launch_parameters.h"
|
| 14 |
+
#include "rasterizer_impl.h"
|
| 15 |
+
#include <algorithm>
|
| 16 |
+
#include <cub/cub.cuh>
|
| 17 |
+
#include <cub/device/device_radix_sort.cuh>
|
| 18 |
+
#include <cuda.h>
|
| 19 |
+
#include <fstream>
|
| 20 |
+
#include <iostream>
|
| 21 |
+
#include <numeric>
|
| 22 |
+
#define GLM_FORCE_CUDA
|
| 23 |
+
#include <glm/glm.hpp>
|
| 24 |
+
|
| 25 |
+
#include <cooperative_groups.h>
|
| 26 |
+
#include <cooperative_groups/reduce.h>
|
| 27 |
+
namespace cg = cooperative_groups;
|
| 28 |
+
|
| 29 |
+
#include "auxiliary.h"
|
| 30 |
+
#include "backward.h"
|
| 31 |
+
#include "forward.h"
|
| 32 |
+
|
| 33 |
+
// Helper function to find the next-highest bit of the MSB
|
| 34 |
+
// on the CPU.
|
| 35 |
+
uint32_t getHigherMsb(uint32_t n) {
|
| 36 |
+
uint32_t msb = sizeof(n) * 4;
|
| 37 |
+
uint32_t step = msb;
|
| 38 |
+
while (step > 1) {
|
| 39 |
+
step /= 2;
|
| 40 |
+
if (n >> msb)
|
| 41 |
+
msb += step;
|
| 42 |
+
else
|
| 43 |
+
msb -= step;
|
| 44 |
+
}
|
| 45 |
+
if (n >> msb)
|
| 46 |
+
msb++;
|
| 47 |
+
return msb;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
// Wrapper method to call auxiliary coarse frustum containment test.
|
| 51 |
+
// Mark all Gaussians that pass it.
|
| 52 |
+
__global__ void checkFrustum(int P, const float *orig_points,
|
| 53 |
+
const float *viewmatrix, const float *projmatrix,
|
| 54 |
+
bool *present) {
|
| 55 |
+
auto idx = cg::this_grid().thread_rank();
|
| 56 |
+
if (idx >= P)
|
| 57 |
+
return;
|
| 58 |
+
|
| 59 |
+
float3 p_view;
|
| 60 |
+
present[idx] =
|
| 61 |
+
in_frustum(idx, orig_points, viewmatrix, projmatrix, false, p_view);
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
// Generates one key/value pair for all Gaussian / tile overlaps.
|
| 65 |
+
// Run once per Gaussian (1:N mapping).
|
| 66 |
+
__global__ void duplicateWithKeys(int P, const float2 *points_xy,
|
| 67 |
+
const float *depths, const uint32_t *offsets,
|
| 68 |
+
uint64_t *gaussian_keys_unsorted,
|
| 69 |
+
uint32_t *gaussian_values_unsorted,
|
| 70 |
+
int *radii, dim3 grid) {
|
| 71 |
+
auto idx = cg::this_grid().thread_rank();
|
| 72 |
+
if (idx >= P)
|
| 73 |
+
return;
|
| 74 |
+
|
| 75 |
+
// Generate no key/value pair for invisible Gaussians
|
| 76 |
+
if (radii[idx] > 0) {
|
| 77 |
+
// Find this Gaussian's offset in buffer for writing keys/values.
|
| 78 |
+
uint32_t off = (idx == 0) ? 0 : offsets[idx - 1];
|
| 79 |
+
uint2 rect_min, rect_max;
|
| 80 |
+
|
| 81 |
+
getRect(points_xy[idx], radii[idx], rect_min, rect_max, grid);
|
| 82 |
+
|
| 83 |
+
// For each tile that the bounding rect overlaps, emit a
|
| 84 |
+
// key/value pair. The key is | tile ID | depth |,
|
| 85 |
+
// and the value is the ID of the Gaussian. Sorting the values
|
| 86 |
+
// with this key yields Gaussian IDs in a list, such that they
|
| 87 |
+
// are first sorted by tile and then by depth.
|
| 88 |
+
for (int y = rect_min.y; y < rect_max.y; y++) {
|
| 89 |
+
for (int x = rect_min.x; x < rect_max.x; x++) {
|
| 90 |
+
uint64_t key = y * grid.x + x;
|
| 91 |
+
key <<= 32;
|
| 92 |
+
key |= *((uint32_t *)&depths[idx]);
|
| 93 |
+
gaussian_keys_unsorted[off] = key;
|
| 94 |
+
gaussian_values_unsorted[off] = idx;
|
| 95 |
+
off++;
|
| 96 |
+
}
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
// Check keys to see if it is at the start/end of one tile's range in
|
| 102 |
+
// the full sorted list. If yes, write start/end of this tile.
|
| 103 |
+
// Run once per instanced (duplicated) Gaussian ID.
|
| 104 |
+
__global__ void identifyTileRanges(int L, uint64_t *point_list_keys,
|
| 105 |
+
uint2 *ranges) {
|
| 106 |
+
auto idx = cg::this_grid().thread_rank();
|
| 107 |
+
if (idx >= L)
|
| 108 |
+
return;
|
| 109 |
+
|
| 110 |
+
// Read tile ID from key. Update start/end of tile range if at limit.
|
| 111 |
+
uint64_t key = point_list_keys[idx];
|
| 112 |
+
uint32_t currtile = key >> 32;
|
| 113 |
+
if (idx == 0)
|
| 114 |
+
ranges[currtile].x = 0;
|
| 115 |
+
else {
|
| 116 |
+
uint32_t prevtile = point_list_keys[idx - 1] >> 32;
|
| 117 |
+
if (currtile != prevtile) {
|
| 118 |
+
ranges[prevtile].y = idx;
|
| 119 |
+
ranges[currtile].x = idx;
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
if (idx == L - 1)
|
| 123 |
+
ranges[currtile].y = L;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
// Mark Gaussians as visible/invisible, based on view frustum testing
|
| 127 |
+
void CudaRasterizer::Rasterizer::markVisible(int P, float *means3D,
|
| 128 |
+
float *viewmatrix,
|
| 129 |
+
float *projmatrix, bool *present) {
|
| 130 |
+
checkFrustum<<<(P + 255) / 256, 256>>>(P, means3D, viewmatrix, projmatrix,
|
| 131 |
+
present);
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
CudaRasterizer::GeometryState
|
| 135 |
+
CudaRasterizer::GeometryState::fromChunk(char *&chunk, size_t P) {
|
| 136 |
+
GeometryState geom;
|
| 137 |
+
obtain(chunk, geom.depths, P, 128);
|
| 138 |
+
obtain(chunk, geom.clamped, P * 3, 128);
|
| 139 |
+
obtain(chunk, geom.internal_radii, P, 128);
|
| 140 |
+
obtain(chunk, geom.means2D, P, 128);
|
| 141 |
+
obtain(chunk, geom.cov3D, P * 6, 128);
|
| 142 |
+
obtain(chunk, geom.conic_opacity, P, 128);
|
| 143 |
+
obtain(chunk, geom.rgb, P * 3, 128);
|
| 144 |
+
obtain(chunk, geom.tiles_touched, P, 128);
|
| 145 |
+
cub::DeviceScan::InclusiveSum(nullptr, geom.scan_size, geom.tiles_touched,
|
| 146 |
+
geom.tiles_touched, P);
|
| 147 |
+
obtain(chunk, geom.scanning_space, geom.scan_size, 128);
|
| 148 |
+
obtain(chunk, geom.point_offsets, P, 128);
|
| 149 |
+
return geom;
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
CudaRasterizer::ImageState CudaRasterizer::ImageState::fromChunk(char *&chunk,
|
| 153 |
+
size_t N) {
|
| 154 |
+
ImageState img;
|
| 155 |
+
obtain(chunk, img.accum_alpha, N, 128);
|
| 156 |
+
obtain(chunk, img.n_contrib, N, 128);
|
| 157 |
+
obtain(chunk, img.ranges, N, 128);
|
| 158 |
+
return img;
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
CudaRasterizer::BinningState
|
| 162 |
+
CudaRasterizer::BinningState::fromChunk(char *&chunk, size_t P) {
|
| 163 |
+
BinningState binning;
|
| 164 |
+
obtain(chunk, binning.point_list, P, 128);
|
| 165 |
+
obtain(chunk, binning.point_list_unsorted, P, 128);
|
| 166 |
+
obtain(chunk, binning.point_list_keys, P, 128);
|
| 167 |
+
obtain(chunk, binning.point_list_keys_unsorted, P, 128);
|
| 168 |
+
cub::DeviceRadixSort::SortPairs(
|
| 169 |
+
nullptr, binning.sorting_size, binning.point_list_keys_unsorted,
|
| 170 |
+
binning.point_list_keys, binning.point_list_unsorted, binning.point_list,
|
| 171 |
+
P);
|
| 172 |
+
obtain(chunk, binning.list_sorting_space, binning.sorting_size, 128);
|
| 173 |
+
return binning;
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
// Forward rendering procedure for differentiable rasterization
|
| 177 |
+
// of Gaussians.
|
| 178 |
+
int CudaRasterizer::Rasterizer::forward(
|
| 179 |
+
std::function<char *(size_t)> geometryBuffer,
|
| 180 |
+
std::function<char *(size_t)> binningBuffer,
|
| 181 |
+
std::function<char *(size_t)> imageBuffer, const int P, int D, int M,
|
| 182 |
+
const float *background, const int width, int height, const float *means3D,
|
| 183 |
+
const float *shs, const float *colors_precomp, const float *opacities,
|
| 184 |
+
const float *scales, const float scale_modifier, const float *rotations,
|
| 185 |
+
const float *cov3D_precomp, const float *viewmatrix,
|
| 186 |
+
const float *projmatrix, const float *cam_pos, const float tan_fovx,
|
| 187 |
+
float tan_fovy, const bool prefiltered, float *out_color, int *radii,
|
| 188 |
+
bool debug) {
|
| 189 |
+
const float focal_y = height / (2.0f * tan_fovy);
|
| 190 |
+
const float focal_x = width / (2.0f * tan_fovx);
|
| 191 |
+
|
| 192 |
+
size_t chunk_size = required<GeometryState>(P);
|
| 193 |
+
char *chunkptr = geometryBuffer(chunk_size);
|
| 194 |
+
GeometryState geomState = GeometryState::fromChunk(chunkptr, P);
|
| 195 |
+
|
| 196 |
+
if (radii == nullptr) {
|
| 197 |
+
radii = geomState.internal_radii;
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X,
|
| 201 |
+
(height + BLOCK_Y - 1) / BLOCK_Y, 1);
|
| 202 |
+
dim3 block(BLOCK_X, BLOCK_Y, 1);
|
| 203 |
+
|
| 204 |
+
// Dynamically resize image-based auxiliary buffers during training
|
| 205 |
+
size_t img_chunk_size = required<ImageState>(width * height);
|
| 206 |
+
char *img_chunkptr = imageBuffer(img_chunk_size);
|
| 207 |
+
ImageState imgState = ImageState::fromChunk(img_chunkptr, width * height);
|
| 208 |
+
|
| 209 |
+
if (NUM_CHANNELS != 3 && colors_precomp == nullptr) {
|
| 210 |
+
throw std::runtime_error(
|
| 211 |
+
"For non-RGB, provide precomputed Gaussian colors!");
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
// Run preprocessing per-Gaussian (transformation, bounding, conversion of SHs
|
| 215 |
+
// to RGB)
|
| 216 |
+
CHECK_CUDA(FORWARD::preprocess(
|
| 217 |
+
P, D, M, means3D, (glm::vec3 *)scales, scale_modifier,
|
| 218 |
+
(glm::vec4 *)rotations, opacities, shs, geomState.clamped,
|
| 219 |
+
cov3D_precomp, colors_precomp, viewmatrix, projmatrix,
|
| 220 |
+
(glm::vec3 *)cam_pos, width, height, focal_x, focal_y,
|
| 221 |
+
tan_fovx, tan_fovy, radii, geomState.means2D, geomState.depths,
|
| 222 |
+
geomState.cov3D, geomState.rgb, geomState.conic_opacity,
|
| 223 |
+
tile_grid, geomState.tiles_touched, prefiltered),
|
| 224 |
+
debug)
|
| 225 |
+
|
| 226 |
+
// Compute prefix sum over full list of touched tile counts by Gaussians
|
| 227 |
+
// E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
|
| 228 |
+
CHECK_CUDA(cub::DeviceScan::InclusiveSum(
|
| 229 |
+
geomState.scanning_space, geomState.scan_size,
|
| 230 |
+
geomState.tiles_touched, geomState.point_offsets, P),
|
| 231 |
+
debug)
|
| 232 |
+
|
| 233 |
+
// Retrieve total number of Gaussian instances to launch and resize aux
|
| 234 |
+
// buffers
|
| 235 |
+
int num_rendered;
|
| 236 |
+
CHECK_CUDA(cudaMemcpy(&num_rendered, geomState.point_offsets + P - 1,
|
| 237 |
+
sizeof(int), cudaMemcpyDeviceToHost),
|
| 238 |
+
debug);
|
| 239 |
+
|
| 240 |
+
size_t binning_chunk_size = required<BinningState>(num_rendered);
|
| 241 |
+
char *binning_chunkptr = binningBuffer(binning_chunk_size);
|
| 242 |
+
BinningState binningState =
|
| 243 |
+
BinningState::fromChunk(binning_chunkptr, num_rendered);
|
| 244 |
+
|
| 245 |
+
// For each instance to be rendered, produce adequate [ tile | depth ] key
|
| 246 |
+
// and corresponding dublicated Gaussian indices to be sorted
|
| 247 |
+
duplicateWithKeys<<<(P + 255) / 256, 256>>>(
|
| 248 |
+
P, geomState.means2D, geomState.depths, geomState.point_offsets,
|
| 249 |
+
binningState.point_list_keys_unsorted, binningState.point_list_unsorted,
|
| 250 |
+
radii, tile_grid) CHECK_CUDA(, debug)
|
| 251 |
+
|
| 252 |
+
int bit = getHigherMsb(tile_grid.x * tile_grid.y);
|
| 253 |
+
|
| 254 |
+
// Sort complete list of (duplicated) Gaussian indices by keys
|
| 255 |
+
CHECK_CUDA(cub::DeviceRadixSort::SortPairs(
|
| 256 |
+
binningState.list_sorting_space, binningState.sorting_size,
|
| 257 |
+
binningState.point_list_keys_unsorted,
|
| 258 |
+
binningState.point_list_keys, binningState.point_list_unsorted,
|
| 259 |
+
binningState.point_list, num_rendered, 0, 32 + bit),
|
| 260 |
+
debug)
|
| 261 |
+
|
| 262 |
+
CHECK_CUDA(
|
| 263 |
+
cudaMemset(imgState.ranges, 0, tile_grid.x * tile_grid.y * sizeof(uint2)),
|
| 264 |
+
debug);
|
| 265 |
+
|
| 266 |
+
// Identify start and end of per-tile workloads in sorted list
|
| 267 |
+
if (num_rendered > 0)
|
| 268 |
+
identifyTileRanges<<<(num_rendered + 255) / 256, 256>>>(
|
| 269 |
+
num_rendered, binningState.point_list_keys, imgState.ranges);
|
| 270 |
+
CHECK_CUDA(, debug)
|
| 271 |
+
|
| 272 |
+
// Let each tile blend its range of Gaussians independently in parallel
|
| 273 |
+
const float *feature_ptr =
|
| 274 |
+
colors_precomp != nullptr ? colors_precomp : geomState.rgb;
|
| 275 |
+
CHECK_CUDA(FORWARD::render(tile_grid, block, imgState.ranges,
|
| 276 |
+
binningState.point_list, width, height,
|
| 277 |
+
geomState.means2D, feature_ptr,
|
| 278 |
+
geomState.conic_opacity, imgState.accum_alpha,
|
| 279 |
+
imgState.n_contrib, background, out_color),
|
| 280 |
+
debug)
|
| 281 |
+
|
| 282 |
+
return num_rendered;
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
// Produce necessary gradients for optimization, corresponding
|
| 286 |
+
// to forward render pass
|
| 287 |
+
void CudaRasterizer::Rasterizer::backward(
|
| 288 |
+
const int P, int D, int M, int R, const float *background, const int width,
|
| 289 |
+
int height, const float *means3D, const float *shs,
|
| 290 |
+
const float *colors_precomp, const float *scales,
|
| 291 |
+
const float scale_modifier, const float *rotations,
|
| 292 |
+
const float *cov3D_precomp, const float *viewmatrix,
|
| 293 |
+
const float *projmatrix, const float *campos, const float tan_fovx,
|
| 294 |
+
float tan_fovy, const int *radii, char *geom_buffer, char *binning_buffer,
|
| 295 |
+
char *img_buffer, const float *dL_dpix, float *dL_dmean2D, float *dL_dconic,
|
| 296 |
+
float *dL_dopacity, float *dL_dcolor, float *dL_dmean3D, float *dL_dcov3D,
|
| 297 |
+
float *dL_dsh, float *dL_dscale, float *dL_drot, bool debug) {
|
| 298 |
+
GeometryState geomState = GeometryState::fromChunk(geom_buffer, P);
|
| 299 |
+
BinningState binningState = BinningState::fromChunk(binning_buffer, R);
|
| 300 |
+
ImageState imgState = ImageState::fromChunk(img_buffer, width * height);
|
| 301 |
+
|
| 302 |
+
if (radii == nullptr) {
|
| 303 |
+
radii = geomState.internal_radii;
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
const float focal_y = height / (2.0f * tan_fovy);
|
| 307 |
+
const float focal_x = width / (2.0f * tan_fovx);
|
| 308 |
+
|
| 309 |
+
const dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X,
|
| 310 |
+
(height + BLOCK_Y - 1) / BLOCK_Y, 1);
|
| 311 |
+
const dim3 block(BLOCK_X, BLOCK_Y, 1);
|
| 312 |
+
|
| 313 |
+
// Compute loss gradients w.r.t. 2D mean position, conic matrix,
|
| 314 |
+
// opacity and RGB of Gaussians from per-pixel loss gradients.
|
| 315 |
+
// If we were given precomputed colors and not SHs, use them.
|
| 316 |
+
const float *color_ptr =
|
| 317 |
+
(colors_precomp != nullptr) ? colors_precomp : geomState.rgb;
|
| 318 |
+
CHECK_CUDA(BACKWARD::render(
|
| 319 |
+
tile_grid, block, imgState.ranges, binningState.point_list,
|
| 320 |
+
width, height, background, geomState.means2D,
|
| 321 |
+
geomState.conic_opacity, color_ptr, imgState.accum_alpha,
|
| 322 |
+
imgState.n_contrib, dL_dpix, (float3 *)dL_dmean2D,
|
| 323 |
+
(float4 *)dL_dconic, dL_dopacity, dL_dcolor),
|
| 324 |
+
debug)
|
| 325 |
+
|
| 326 |
+
// Take care of the rest of preprocessing. Was the precomputed covariance
|
| 327 |
+
// given to us or a scales/rot pair? If precomputed, pass that. If not,
|
| 328 |
+
// use the one we computed ourselves.
|
| 329 |
+
const float *cov3D_ptr =
|
| 330 |
+
(cov3D_precomp != nullptr) ? cov3D_precomp : geomState.cov3D;
|
| 331 |
+
CHECK_CUDA(BACKWARD::preprocess(
|
| 332 |
+
P, D, M, (float3 *)means3D, radii, shs, geomState.clamped,
|
| 333 |
+
(glm::vec3 *)scales, (glm::vec4 *)rotations, scale_modifier,
|
| 334 |
+
cov3D_ptr, viewmatrix, projmatrix, focal_x, focal_y, tan_fovx,
|
| 335 |
+
tan_fovy, (glm::vec3 *)campos, (float3 *)dL_dmean2D, dL_dconic,
|
| 336 |
+
(glm::vec3 *)dL_dmean3D, dL_dcolor, dL_dcov3D, dL_dsh,
|
| 337 |
+
(glm::vec3 *)dL_dscale, (glm::vec4 *)dL_drot),
|
| 338 |
+
debug)
|
| 339 |
+
}
|
gaussiancity/extensions/diff_gaussian_rasterization/cuda_rasterizer/rasterizer_impl.h
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright (C) 2023, Inria
|
| 3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This software is free for non-commercial, research and evaluation use
|
| 7 |
+
* under the terms of the LICENSE.md file.
|
| 8 |
+
*
|
| 9 |
+
* For inquiries contact george.drettakis@inria.fr
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#pragma once
|
| 13 |
+
|
| 14 |
+
#include "rasterizer.h"
|
| 15 |
+
#include <cuda_runtime_api.h>
|
| 16 |
+
#include <iostream>
|
| 17 |
+
#include <vector>
|
| 18 |
+
|
| 19 |
+
namespace CudaRasterizer {
|
| 20 |
+
template <typename T>
|
| 21 |
+
static void obtain(char *&chunk, T *&ptr, std::size_t count,
|
| 22 |
+
std::size_t alignment) {
|
| 23 |
+
std::size_t offset =
|
| 24 |
+
(reinterpret_cast<std::uintptr_t>(chunk) + alignment - 1) &
|
| 25 |
+
~(alignment - 1);
|
| 26 |
+
ptr = reinterpret_cast<T *>(offset);
|
| 27 |
+
chunk = reinterpret_cast<char *>(ptr + count);
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
struct GeometryState {
|
| 31 |
+
size_t scan_size;
|
| 32 |
+
float *depths;
|
| 33 |
+
char *scanning_space;
|
| 34 |
+
bool *clamped;
|
| 35 |
+
int *internal_radii;
|
| 36 |
+
float2 *means2D;
|
| 37 |
+
float *cov3D;
|
| 38 |
+
float4 *conic_opacity;
|
| 39 |
+
float *rgb;
|
| 40 |
+
uint32_t *point_offsets;
|
| 41 |
+
uint32_t *tiles_touched;
|
| 42 |
+
|
| 43 |
+
static GeometryState fromChunk(char *&chunk, size_t P);
|
| 44 |
+
};
|
| 45 |
+
|
| 46 |
+
struct ImageState {
|
| 47 |
+
uint2 *ranges;
|
| 48 |
+
uint32_t *n_contrib;
|
| 49 |
+
float *accum_alpha;
|
| 50 |
+
|
| 51 |
+
static ImageState fromChunk(char *&chunk, size_t N);
|
| 52 |
+
};
|
| 53 |
+
|
| 54 |
+
struct BinningState {
|
| 55 |
+
size_t sorting_size;
|
| 56 |
+
uint64_t *point_list_keys_unsorted;
|
| 57 |
+
uint64_t *point_list_keys;
|
| 58 |
+
uint32_t *point_list_unsorted;
|
| 59 |
+
uint32_t *point_list;
|
| 60 |
+
char *list_sorting_space;
|
| 61 |
+
|
| 62 |
+
static BinningState fromChunk(char *&chunk, size_t P);
|
| 63 |
+
};
|
| 64 |
+
|
| 65 |
+
template <typename T> size_t required(size_t P) {
|
| 66 |
+
char *size = nullptr;
|
| 67 |
+
T::fromChunk(size, P);
|
| 68 |
+
return ((size_t)size) + 128;
|
| 69 |
+
}
|
| 70 |
+
}; // namespace CudaRasterizer
|
gaussiancity/extensions/diff_gaussian_rasterization/rasterize_points.cu
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright (C) 2023, Inria
|
| 3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This software is free for non-commercial, research and evaluation use
|
| 7 |
+
* under the terms of the LICENSE.md file.
|
| 8 |
+
*
|
| 9 |
+
* For inquiries contact george.drettakis@inria.fr
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#include "cuda_rasterizer/config.h"
|
| 13 |
+
#include "cuda_rasterizer/rasterizer.h"
|
| 14 |
+
#include <cstdio>
|
| 15 |
+
#include <cuda_runtime_api.h>
|
| 16 |
+
#include <fstream>
|
| 17 |
+
#include <functional>
|
| 18 |
+
#include <iostream>
|
| 19 |
+
#include <math.h>
|
| 20 |
+
#include <memory>
|
| 21 |
+
#include <sstream>
|
| 22 |
+
#include <stdio.h>
|
| 23 |
+
#include <string>
|
| 24 |
+
#include <torch/extension.h>
|
| 25 |
+
#include <tuple>
|
| 26 |
+
|
| 27 |
+
std::function<char *(size_t N)> resizeFunctional(torch::Tensor &t) {
|
| 28 |
+
auto lambda = [&t](size_t N) {
|
| 29 |
+
t.resize_({(long long)N});
|
| 30 |
+
return reinterpret_cast<char *>(t.contiguous().data_ptr());
|
| 31 |
+
};
|
| 32 |
+
return lambda;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
|
| 36 |
+
torch::Tensor>
|
| 37 |
+
RasterizeGaussiansCUDA(
|
| 38 |
+
const torch::Tensor &background, const torch::Tensor &means3D,
|
| 39 |
+
const torch::Tensor &colors, const torch::Tensor &opacity,
|
| 40 |
+
const torch::Tensor &scales, const torch::Tensor &rotations,
|
| 41 |
+
const float scale_modifier, const torch::Tensor &cov3D_precomp,
|
| 42 |
+
const torch::Tensor &viewmatrix, const torch::Tensor &projmatrix,
|
| 43 |
+
const float tan_fovx, const float tan_fovy, const int image_height,
|
| 44 |
+
const int image_width, const torch::Tensor &sh, const int degree,
|
| 45 |
+
const torch::Tensor &campos, const bool prefiltered, const bool debug) {
|
| 46 |
+
if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
|
| 47 |
+
AT_ERROR("means3D must have dimensions (num_points, 3)");
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
const int P = means3D.size(0);
|
| 51 |
+
const int H = image_height;
|
| 52 |
+
const int W = image_width;
|
| 53 |
+
|
| 54 |
+
auto int_opts = means3D.options().dtype(torch::kInt32);
|
| 55 |
+
auto float_opts = means3D.options().dtype(torch::kFloat32);
|
| 56 |
+
|
| 57 |
+
torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts);
|
| 58 |
+
torch::Tensor radii =
|
| 59 |
+
torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
|
| 60 |
+
|
| 61 |
+
torch::Device device(torch::kCUDA);
|
| 62 |
+
torch::TensorOptions options(torch::kByte);
|
| 63 |
+
torch::Tensor geomBuffer = torch::empty({0}, options.device(device));
|
| 64 |
+
torch::Tensor binningBuffer = torch::empty({0}, options.device(device));
|
| 65 |
+
torch::Tensor imgBuffer = torch::empty({0}, options.device(device));
|
| 66 |
+
std::function<char *(size_t)> geomFunc = resizeFunctional(geomBuffer);
|
| 67 |
+
std::function<char *(size_t)> binningFunc = resizeFunctional(binningBuffer);
|
| 68 |
+
std::function<char *(size_t)> imgFunc = resizeFunctional(imgBuffer);
|
| 69 |
+
|
| 70 |
+
int rendered = 0;
|
| 71 |
+
if (P != 0) {
|
| 72 |
+
int M = 0;
|
| 73 |
+
if (sh.size(0) != 0) {
|
| 74 |
+
M = sh.size(1);
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
rendered = CudaRasterizer::Rasterizer::forward(
|
| 78 |
+
geomFunc, binningFunc, imgFunc, P, degree, M,
|
| 79 |
+
background.contiguous().data<float>(), W, H,
|
| 80 |
+
means3D.contiguous().data<float>(), sh.contiguous().data_ptr<float>(),
|
| 81 |
+
colors.contiguous().data<float>(), opacity.contiguous().data<float>(),
|
| 82 |
+
scales.contiguous().data_ptr<float>(), scale_modifier,
|
| 83 |
+
rotations.contiguous().data_ptr<float>(),
|
| 84 |
+
cov3D_precomp.contiguous().data<float>(),
|
| 85 |
+
viewmatrix.contiguous().data<float>(),
|
| 86 |
+
projmatrix.contiguous().data<float>(),
|
| 87 |
+
campos.contiguous().data<float>(), tan_fovx, tan_fovy, prefiltered,
|
| 88 |
+
out_color.contiguous().data<float>(), radii.contiguous().data<int>(),
|
| 89 |
+
debug);
|
| 90 |
+
}
|
| 91 |
+
return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer,
|
| 92 |
+
imgBuffer);
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
|
| 96 |
+
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
| 97 |
+
RasterizeGaussiansBackwardCUDA(
|
| 98 |
+
const torch::Tensor &background, const torch::Tensor &means3D,
|
| 99 |
+
const torch::Tensor &radii, const torch::Tensor &colors,
|
| 100 |
+
const torch::Tensor &scales, const torch::Tensor &rotations,
|
| 101 |
+
const float scale_modifier, const torch::Tensor &cov3D_precomp,
|
| 102 |
+
const torch::Tensor &viewmatrix, const torch::Tensor &projmatrix,
|
| 103 |
+
const float tan_fovx, const float tan_fovy,
|
| 104 |
+
const torch::Tensor &dL_dout_color, const torch::Tensor &sh,
|
| 105 |
+
const int degree, const torch::Tensor &campos,
|
| 106 |
+
const torch::Tensor &geomBuffer, const int R,
|
| 107 |
+
const torch::Tensor &binningBuffer, const torch::Tensor &imageBuffer,
|
| 108 |
+
const bool debug) {
|
| 109 |
+
const int P = means3D.size(0);
|
| 110 |
+
const int H = dL_dout_color.size(1);
|
| 111 |
+
const int W = dL_dout_color.size(2);
|
| 112 |
+
|
| 113 |
+
int M = 0;
|
| 114 |
+
if (sh.size(0) != 0) {
|
| 115 |
+
M = sh.size(1);
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options());
|
| 119 |
+
torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options());
|
| 120 |
+
torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options());
|
| 121 |
+
torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, means3D.options());
|
| 122 |
+
torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options());
|
| 123 |
+
torch::Tensor dL_dcov3D = torch::zeros({P, 6}, means3D.options());
|
| 124 |
+
torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options());
|
| 125 |
+
torch::Tensor dL_dscales = torch::zeros({P, 3}, means3D.options());
|
| 126 |
+
torch::Tensor dL_drotations = torch::zeros({P, 4}, means3D.options());
|
| 127 |
+
|
| 128 |
+
if (P != 0) {
|
| 129 |
+
CudaRasterizer::Rasterizer::backward(
|
| 130 |
+
P, degree, M, R, background.contiguous().data<float>(), W, H,
|
| 131 |
+
means3D.contiguous().data<float>(), sh.contiguous().data<float>(),
|
| 132 |
+
colors.contiguous().data<float>(), scales.data_ptr<float>(),
|
| 133 |
+
scale_modifier, rotations.data_ptr<float>(),
|
| 134 |
+
cov3D_precomp.contiguous().data<float>(),
|
| 135 |
+
viewmatrix.contiguous().data<float>(),
|
| 136 |
+
projmatrix.contiguous().data<float>(),
|
| 137 |
+
campos.contiguous().data<float>(), tan_fovx, tan_fovy,
|
| 138 |
+
radii.contiguous().data<int>(),
|
| 139 |
+
reinterpret_cast<char *>(geomBuffer.contiguous().data_ptr()),
|
| 140 |
+
reinterpret_cast<char *>(binningBuffer.contiguous().data_ptr()),
|
| 141 |
+
reinterpret_cast<char *>(imageBuffer.contiguous().data_ptr()),
|
| 142 |
+
dL_dout_color.contiguous().data<float>(),
|
| 143 |
+
dL_dmeans2D.contiguous().data<float>(),
|
| 144 |
+
dL_dconic.contiguous().data<float>(),
|
| 145 |
+
dL_dopacity.contiguous().data<float>(),
|
| 146 |
+
dL_dcolors.contiguous().data<float>(),
|
| 147 |
+
dL_dmeans3D.contiguous().data<float>(),
|
| 148 |
+
dL_dcov3D.contiguous().data<float>(), dL_dsh.contiguous().data<float>(),
|
| 149 |
+
dL_dscales.contiguous().data<float>(),
|
| 150 |
+
dL_drotations.contiguous().data<float>(), debug);
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D,
|
| 154 |
+
dL_dcov3D, dL_dsh, dL_dscales, dL_drotations);
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
torch::Tensor markVisible(torch::Tensor &means3D, torch::Tensor &viewmatrix,
|
| 158 |
+
torch::Tensor &projmatrix) {
|
| 159 |
+
const int P = means3D.size(0);
|
| 160 |
+
|
| 161 |
+
torch::Tensor present =
|
| 162 |
+
torch::full({P}, false, means3D.options().dtype(at::kBool));
|
| 163 |
+
|
| 164 |
+
if (P != 0) {
|
| 165 |
+
CudaRasterizer::Rasterizer::markVisible(
|
| 166 |
+
P, means3D.contiguous().data<float>(),
|
| 167 |
+
viewmatrix.contiguous().data<float>(),
|
| 168 |
+
projmatrix.contiguous().data<float>(),
|
| 169 |
+
present.contiguous().data<bool>());
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
return present;
|
| 173 |
+
}
|
gaussiancity/extensions/diff_gaussian_rasterization/rasterize_points.h
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
|
| 2 |
+
* Copyright (C) 2023, Inria
|
| 3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
| 4 |
+
* All rights reserved.
|
| 5 |
+
*
|
| 6 |
+
* This software is free for non-commercial, research and evaluation use
|
| 7 |
+
* under the terms of the LICENSE.md file.
|
| 8 |
+
*
|
| 9 |
+
* For inquiries contact george.drettakis@inria.fr
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
#pragma once
|
| 13 |
+
#include <cstdio>
|
| 14 |
+
#include <string>
|
| 15 |
+
#include <torch/extension.h>
|
| 16 |
+
#include <tuple>
|
| 17 |
+
|
| 18 |
+
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
|
| 19 |
+
torch::Tensor>
|
| 20 |
+
RasterizeGaussiansCUDA(
|
| 21 |
+
const torch::Tensor &background, const torch::Tensor &means3D,
|
| 22 |
+
const torch::Tensor &colors, const torch::Tensor &opacity,
|
| 23 |
+
const torch::Tensor &scales, const torch::Tensor &rotations,
|
| 24 |
+
const float scale_modifier, const torch::Tensor &cov3D_precomp,
|
| 25 |
+
const torch::Tensor &viewmatrix, const torch::Tensor &projmatrix,
|
| 26 |
+
const float tan_fovx, const float tan_fovy, const int image_height,
|
| 27 |
+
const int image_width, const torch::Tensor &sh, const int degree,
|
| 28 |
+
const torch::Tensor &campos, const bool prefiltered, const bool debug);
|
| 29 |
+
|
| 30 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
|
| 31 |
+
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
| 32 |
+
RasterizeGaussiansBackwardCUDA(
|
| 33 |
+
const torch::Tensor &background, const torch::Tensor &means3D,
|
| 34 |
+
const torch::Tensor &radii, const torch::Tensor &colors,
|
| 35 |
+
const torch::Tensor &scales, const torch::Tensor &rotations,
|
| 36 |
+
const float scale_modifier, const torch::Tensor &cov3D_precomp,
|
| 37 |
+
const torch::Tensor &viewmatrix, const torch::Tensor &projmatrix,
|
| 38 |
+
const float tan_fovx, const float tan_fovy,
|
| 39 |
+
const torch::Tensor &dL_dout_color, const torch::Tensor &sh,
|
| 40 |
+
const int degree, const torch::Tensor &campos,
|
| 41 |
+
const torch::Tensor &geomBuffer, const int R,
|
| 42 |
+
const torch::Tensor &binningBuffer, const torch::Tensor &imageBuffer,
|
| 43 |
+
const bool debug);
|
| 44 |
+
|
| 45 |
+
torch::Tensor markVisible(torch::Tensor &means3D, torch::Tensor &viewmatrix,
|
| 46 |
+
torch::Tensor &projmatrix);
|
gaussiancity/extensions/diff_gaussian_rasterization/setup.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (C) 2023, Inria
|
| 3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
| 4 |
+
# All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# This software is free for non-commercial, research and evaluation use
|
| 7 |
+
# under the terms of the LICENSE.md file.
|
| 8 |
+
#
|
| 9 |
+
# For inquiries contact george.drettakis@inria.fr
|
| 10 |
+
#
|
| 11 |
+
|
| 12 |
+
from setuptools import setup
|
| 13 |
+
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
|
| 14 |
+
import os
|
| 15 |
+
|
| 16 |
+
setup(
|
| 17 |
+
name="diff_gaussian_rasterization",
|
| 18 |
+
version="1.0.0",
|
| 19 |
+
ext_modules=[
|
| 20 |
+
CUDAExtension(
|
| 21 |
+
name="diff_gaussian_rasterization_ext",
|
| 22 |
+
sources=[
|
| 23 |
+
"cuda_rasterizer/rasterizer_impl.cu",
|
| 24 |
+
"cuda_rasterizer/forward.cu",
|
| 25 |
+
"cuda_rasterizer/backward.cu",
|
| 26 |
+
"rasterize_points.cu",
|
| 27 |
+
"bindings.cpp",
|
| 28 |
+
],
|
| 29 |
+
extra_compile_args={
|
| 30 |
+
"nvcc": [
|
| 31 |
+
"-I"
|
| 32 |
+
+ os.path.join(
|
| 33 |
+
os.path.dirname(os.path.abspath(__file__)), "third_party/glm/"
|
| 34 |
+
)
|
| 35 |
+
]
|
| 36 |
+
},
|
| 37 |
+
)
|
| 38 |
+
],
|
| 39 |
+
cmdclass={"build_ext": BuildExtension},
|
| 40 |
+
)
|
gaussiancity/extensions/diff_gaussian_rasterization/third_party/glm
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Subproject commit 2d4c4b4dd31fde06cfffad7915c2b3006402322f
|
gaussiancity/extensions/diff_gaussian_rasterization/third_party/stbi_image_write.h
ADDED
|
@@ -0,0 +1,1724 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* stb_image_write - v1.16 - public domain - http://nothings.org/stb
|
| 2 |
+
writes out PNG/BMP/TGA/JPEG/HDR images to C stdio - Sean Barrett 2010-2015
|
| 3 |
+
no warranty implied; use at your own risk
|
| 4 |
+
|
| 5 |
+
Before #including,
|
| 6 |
+
|
| 7 |
+
#define STB_IMAGE_WRITE_IMPLEMENTATION
|
| 8 |
+
|
| 9 |
+
in the file that you want to have the implementation.
|
| 10 |
+
|
| 11 |
+
Will probably not work correctly with strict-aliasing optimizations.
|
| 12 |
+
|
| 13 |
+
ABOUT:
|
| 14 |
+
|
| 15 |
+
This header file is a library for writing images to C stdio or a callback.
|
| 16 |
+
|
| 17 |
+
The PNG output is not optimal; it is 20-50% larger than the file
|
| 18 |
+
written by a decent optimizing implementation; though providing a custom
|
| 19 |
+
zlib compress function (see STBIW_ZLIB_COMPRESS) can mitigate that.
|
| 20 |
+
This library is designed for source code compactness and simplicity,
|
| 21 |
+
not optimal image file size or run-time performance.
|
| 22 |
+
|
| 23 |
+
BUILDING:
|
| 24 |
+
|
| 25 |
+
You can #define STBIW_ASSERT(x) before the #include to avoid using assert.h.
|
| 26 |
+
You can #define STBIW_MALLOC(), STBIW_REALLOC(), and STBIW_FREE() to replace
|
| 27 |
+
malloc,realloc,free.
|
| 28 |
+
You can #define STBIW_MEMMOVE() to replace memmove()
|
| 29 |
+
You can #define STBIW_ZLIB_COMPRESS to use a custom zlib-style compress function
|
| 30 |
+
for PNG compression (instead of the builtin one), it must have the following signature:
|
| 31 |
+
unsigned char * my_compress(unsigned char *data, int data_len, int *out_len, int quality);
|
| 32 |
+
The returned data will be freed with STBIW_FREE() (free() by default),
|
| 33 |
+
so it must be heap allocated with STBIW_MALLOC() (malloc() by default),
|
| 34 |
+
|
| 35 |
+
UNICODE:
|
| 36 |
+
|
| 37 |
+
If compiling for Windows and you wish to use Unicode filenames, compile
|
| 38 |
+
with
|
| 39 |
+
#define STBIW_WINDOWS_UTF8
|
| 40 |
+
and pass utf8-encoded filenames. Call stbiw_convert_wchar_to_utf8 to convert
|
| 41 |
+
Windows wchar_t filenames to utf8.
|
| 42 |
+
|
| 43 |
+
USAGE:
|
| 44 |
+
|
| 45 |
+
There are five functions, one for each image file format:
|
| 46 |
+
|
| 47 |
+
int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes);
|
| 48 |
+
int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data);
|
| 49 |
+
int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data);
|
| 50 |
+
int stbi_write_jpg(char const *filename, int w, int h, int comp, const void *data, int quality);
|
| 51 |
+
int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data);
|
| 52 |
+
|
| 53 |
+
void stbi_flip_vertically_on_write(int flag); // flag is non-zero to flip data vertically
|
| 54 |
+
|
| 55 |
+
There are also five equivalent functions that use an arbitrary write function. You are
|
| 56 |
+
expected to open/close your file-equivalent before and after calling these:
|
| 57 |
+
|
| 58 |
+
int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data, int stride_in_bytes);
|
| 59 |
+
int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data);
|
| 60 |
+
int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data);
|
| 61 |
+
int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data);
|
| 62 |
+
int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality);
|
| 63 |
+
|
| 64 |
+
where the callback is:
|
| 65 |
+
void stbi_write_func(void *context, void *data, int size);
|
| 66 |
+
|
| 67 |
+
You can configure it with these global variables:
|
| 68 |
+
int stbi_write_tga_with_rle; // defaults to true; set to 0 to disable RLE
|
| 69 |
+
int stbi_write_png_compression_level; // defaults to 8; set to higher for more compression
|
| 70 |
+
int stbi_write_force_png_filter; // defaults to -1; set to 0..5 to force a filter mode
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
You can define STBI_WRITE_NO_STDIO to disable the file variant of these
|
| 74 |
+
functions, so the library will not use stdio.h at all. However, this will
|
| 75 |
+
also disable HDR writing, because it requires stdio for formatted output.
|
| 76 |
+
|
| 77 |
+
Each function returns 0 on failure and non-0 on success.
|
| 78 |
+
|
| 79 |
+
The functions create an image file defined by the parameters. The image
|
| 80 |
+
is a rectangle of pixels stored from left-to-right, top-to-bottom.
|
| 81 |
+
Each pixel contains 'comp' channels of data stored interleaved with 8-bits
|
| 82 |
+
per channel, in the following order: 1=Y, 2=YA, 3=RGB, 4=RGBA. (Y is
|
| 83 |
+
monochrome color.) The rectangle is 'w' pixels wide and 'h' pixels tall.
|
| 84 |
+
The *data pointer points to the first byte of the top-left-most pixel.
|
| 85 |
+
For PNG, "stride_in_bytes" is the distance in bytes from the first byte of
|
| 86 |
+
a row of pixels to the first byte of the next row of pixels.
|
| 87 |
+
|
| 88 |
+
PNG creates output files with the same number of components as the input.
|
| 89 |
+
The BMP format expands Y to RGB in the file format and does not
|
| 90 |
+
output alpha.
|
| 91 |
+
|
| 92 |
+
PNG supports writing rectangles of data even when the bytes storing rows of
|
| 93 |
+
data are not consecutive in memory (e.g. sub-rectangles of a larger image),
|
| 94 |
+
by supplying the stride between the beginning of adjacent rows. The other
|
| 95 |
+
formats do not. (Thus you cannot write a native-format BMP through the BMP
|
| 96 |
+
writer, both because it is in BGR order and because it may have padding
|
| 97 |
+
at the end of the line.)
|
| 98 |
+
|
| 99 |
+
PNG allows you to set the deflate compression level by setting the global
|
| 100 |
+
variable 'stbi_write_png_compression_level' (it defaults to 8).
|
| 101 |
+
|
| 102 |
+
HDR expects linear float data. Since the format is always 32-bit rgb(e)
|
| 103 |
+
data, alpha (if provided) is discarded, and for monochrome data it is
|
| 104 |
+
replicated across all three channels.
|
| 105 |
+
|
| 106 |
+
TGA supports RLE or non-RLE compressed data. To use non-RLE-compressed
|
| 107 |
+
data, set the global variable 'stbi_write_tga_with_rle' to 0.
|
| 108 |
+
|
| 109 |
+
JPEG does ignore alpha channels in input data; quality is between 1 and 100.
|
| 110 |
+
Higher quality looks better but results in a bigger image.
|
| 111 |
+
JPEG baseline (no JPEG progressive).
|
| 112 |
+
|
| 113 |
+
CREDITS:
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
Sean Barrett - PNG/BMP/TGA
|
| 117 |
+
Baldur Karlsson - HDR
|
| 118 |
+
Jean-Sebastien Guay - TGA monochrome
|
| 119 |
+
Tim Kelsey - misc enhancements
|
| 120 |
+
Alan Hickman - TGA RLE
|
| 121 |
+
Emmanuel Julien - initial file IO callback implementation
|
| 122 |
+
Jon Olick - original jo_jpeg.cpp code
|
| 123 |
+
Daniel Gibson - integrate JPEG, allow external zlib
|
| 124 |
+
Aarni Koskela - allow choosing PNG filter
|
| 125 |
+
|
| 126 |
+
bugfixes:
|
| 127 |
+
github:Chribba
|
| 128 |
+
Guillaume Chereau
|
| 129 |
+
github:jry2
|
| 130 |
+
github:romigrou
|
| 131 |
+
Sergio Gonzalez
|
| 132 |
+
Jonas Karlsson
|
| 133 |
+
Filip Wasil
|
| 134 |
+
Thatcher Ulrich
|
| 135 |
+
github:poppolopoppo
|
| 136 |
+
Patrick Boettcher
|
| 137 |
+
github:xeekworx
|
| 138 |
+
Cap Petschulat
|
| 139 |
+
Simon Rodriguez
|
| 140 |
+
Ivan Tikhonov
|
| 141 |
+
github:ignotion
|
| 142 |
+
Adam Schackart
|
| 143 |
+
Andrew Kensler
|
| 144 |
+
|
| 145 |
+
LICENSE
|
| 146 |
+
|
| 147 |
+
See end of file for license information.
|
| 148 |
+
|
| 149 |
+
*/
|
| 150 |
+
|
| 151 |
+
#ifndef INCLUDE_STB_IMAGE_WRITE_H
|
| 152 |
+
#define INCLUDE_STB_IMAGE_WRITE_H
|
| 153 |
+
|
| 154 |
+
#include <stdlib.h>
|
| 155 |
+
|
| 156 |
+
// if STB_IMAGE_WRITE_STATIC causes problems, try defining STBIWDEF to 'inline' or 'static inline'
|
| 157 |
+
#ifndef STBIWDEF
|
| 158 |
+
#ifdef STB_IMAGE_WRITE_STATIC
|
| 159 |
+
#define STBIWDEF static
|
| 160 |
+
#else
|
| 161 |
+
#ifdef __cplusplus
|
| 162 |
+
#define STBIWDEF extern "C"
|
| 163 |
+
#else
|
| 164 |
+
#define STBIWDEF extern
|
| 165 |
+
#endif
|
| 166 |
+
#endif
|
| 167 |
+
#endif
|
| 168 |
+
|
| 169 |
+
#ifndef STB_IMAGE_WRITE_STATIC // C++ forbids static forward declarations
|
| 170 |
+
STBIWDEF int stbi_write_tga_with_rle;
|
| 171 |
+
STBIWDEF int stbi_write_png_compression_level;
|
| 172 |
+
STBIWDEF int stbi_write_force_png_filter;
|
| 173 |
+
#endif
|
| 174 |
+
|
| 175 |
+
#ifndef STBI_WRITE_NO_STDIO
|
| 176 |
+
STBIWDEF int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes);
|
| 177 |
+
STBIWDEF int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data);
|
| 178 |
+
STBIWDEF int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data);
|
| 179 |
+
STBIWDEF int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data);
|
| 180 |
+
STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality);
|
| 181 |
+
|
| 182 |
+
#ifdef STBIW_WINDOWS_UTF8
|
| 183 |
+
STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input);
|
| 184 |
+
#endif
|
| 185 |
+
#endif
|
| 186 |
+
|
| 187 |
+
typedef void stbi_write_func(void *context, void *data, int size);
|
| 188 |
+
|
| 189 |
+
STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data, int stride_in_bytes);
|
| 190 |
+
STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data);
|
| 191 |
+
STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data);
|
| 192 |
+
STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data);
|
| 193 |
+
STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality);
|
| 194 |
+
|
| 195 |
+
STBIWDEF void stbi_flip_vertically_on_write(int flip_boolean);
|
| 196 |
+
|
| 197 |
+
#endif//INCLUDE_STB_IMAGE_WRITE_H
|
| 198 |
+
|
| 199 |
+
#ifdef STB_IMAGE_WRITE_IMPLEMENTATION
|
| 200 |
+
|
| 201 |
+
#ifdef _WIN32
|
| 202 |
+
#ifndef _CRT_SECURE_NO_WARNINGS
|
| 203 |
+
#define _CRT_SECURE_NO_WARNINGS
|
| 204 |
+
#endif
|
| 205 |
+
#ifndef _CRT_NONSTDC_NO_DEPRECATE
|
| 206 |
+
#define _CRT_NONSTDC_NO_DEPRECATE
|
| 207 |
+
#endif
|
| 208 |
+
#endif
|
| 209 |
+
|
| 210 |
+
#ifndef STBI_WRITE_NO_STDIO
|
| 211 |
+
#include <stdio.h>
|
| 212 |
+
#endif // STBI_WRITE_NO_STDIO
|
| 213 |
+
|
| 214 |
+
#include <stdarg.h>
|
| 215 |
+
#include <stdlib.h>
|
| 216 |
+
#include <string.h>
|
| 217 |
+
#include <math.h>
|
| 218 |
+
|
| 219 |
+
#if defined(STBIW_MALLOC) && defined(STBIW_FREE) && (defined(STBIW_REALLOC) || defined(STBIW_REALLOC_SIZED))
|
| 220 |
+
// ok
|
| 221 |
+
#elif !defined(STBIW_MALLOC) && !defined(STBIW_FREE) && !defined(STBIW_REALLOC) && !defined(STBIW_REALLOC_SIZED)
|
| 222 |
+
// ok
|
| 223 |
+
#else
|
| 224 |
+
#error "Must define all or none of STBIW_MALLOC, STBIW_FREE, and STBIW_REALLOC (or STBIW_REALLOC_SIZED)."
|
| 225 |
+
#endif
|
| 226 |
+
|
| 227 |
+
#ifndef STBIW_MALLOC
|
| 228 |
+
#define STBIW_MALLOC(sz) malloc(sz)
|
| 229 |
+
#define STBIW_REALLOC(p,newsz) realloc(p,newsz)
|
| 230 |
+
#define STBIW_FREE(p) free(p)
|
| 231 |
+
#endif
|
| 232 |
+
|
| 233 |
+
#ifndef STBIW_REALLOC_SIZED
|
| 234 |
+
#define STBIW_REALLOC_SIZED(p,oldsz,newsz) STBIW_REALLOC(p,newsz)
|
| 235 |
+
#endif
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
#ifndef STBIW_MEMMOVE
|
| 239 |
+
#define STBIW_MEMMOVE(a,b,sz) memmove(a,b,sz)
|
| 240 |
+
#endif
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
#ifndef STBIW_ASSERT
|
| 244 |
+
#include <assert.h>
|
| 245 |
+
#define STBIW_ASSERT(x) assert(x)
|
| 246 |
+
#endif
|
| 247 |
+
|
| 248 |
+
#define STBIW_UCHAR(x) (unsigned char) ((x) & 0xff)
|
| 249 |
+
|
| 250 |
+
#ifdef STB_IMAGE_WRITE_STATIC
|
| 251 |
+
static int stbi_write_png_compression_level = 8;
|
| 252 |
+
static int stbi_write_tga_with_rle = 1;
|
| 253 |
+
static int stbi_write_force_png_filter = -1;
|
| 254 |
+
#else
|
| 255 |
+
int stbi_write_png_compression_level = 8;
|
| 256 |
+
int stbi_write_tga_with_rle = 1;
|
| 257 |
+
int stbi_write_force_png_filter = -1;
|
| 258 |
+
#endif
|
| 259 |
+
|
| 260 |
+
static int stbi__flip_vertically_on_write = 0;
|
| 261 |
+
|
| 262 |
+
STBIWDEF void stbi_flip_vertically_on_write(int flag)
|
| 263 |
+
{
|
| 264 |
+
stbi__flip_vertically_on_write = flag;
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
typedef struct
|
| 268 |
+
{
|
| 269 |
+
stbi_write_func *func;
|
| 270 |
+
void *context;
|
| 271 |
+
unsigned char buffer[64];
|
| 272 |
+
int buf_used;
|
| 273 |
+
} stbi__write_context;
|
| 274 |
+
|
| 275 |
+
// initialize a callback-based context
|
| 276 |
+
static void stbi__start_write_callbacks(stbi__write_context *s, stbi_write_func *c, void *context)
|
| 277 |
+
{
|
| 278 |
+
s->func = c;
|
| 279 |
+
s->context = context;
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
#ifndef STBI_WRITE_NO_STDIO
|
| 283 |
+
|
| 284 |
+
static void stbi__stdio_write(void *context, void *data, int size)
|
| 285 |
+
{
|
| 286 |
+
fwrite(data,1,size,(FILE*) context);
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
#if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8)
|
| 290 |
+
#ifdef __cplusplus
|
| 291 |
+
#define STBIW_EXTERN extern "C"
|
| 292 |
+
#else
|
| 293 |
+
#define STBIW_EXTERN extern
|
| 294 |
+
#endif
|
| 295 |
+
STBIW_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide);
|
| 296 |
+
STBIW_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default);
|
| 297 |
+
|
| 298 |
+
STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input)
|
| 299 |
+
{
|
| 300 |
+
return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL);
|
| 301 |
+
}
|
| 302 |
+
#endif
|
| 303 |
+
|
| 304 |
+
static FILE *stbiw__fopen(char const *filename, char const *mode)
|
| 305 |
+
{
|
| 306 |
+
FILE *f;
|
| 307 |
+
#if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8)
|
| 308 |
+
wchar_t wMode[64];
|
| 309 |
+
wchar_t wFilename[1024];
|
| 310 |
+
if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename)/sizeof(*wFilename)))
|
| 311 |
+
return 0;
|
| 312 |
+
|
| 313 |
+
if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode)/sizeof(*wMode)))
|
| 314 |
+
return 0;
|
| 315 |
+
|
| 316 |
+
#if defined(_MSC_VER) && _MSC_VER >= 1400
|
| 317 |
+
if (0 != _wfopen_s(&f, wFilename, wMode))
|
| 318 |
+
f = 0;
|
| 319 |
+
#else
|
| 320 |
+
f = _wfopen(wFilename, wMode);
|
| 321 |
+
#endif
|
| 322 |
+
|
| 323 |
+
#elif defined(_MSC_VER) && _MSC_VER >= 1400
|
| 324 |
+
if (0 != fopen_s(&f, filename, mode))
|
| 325 |
+
f=0;
|
| 326 |
+
#else
|
| 327 |
+
f = fopen(filename, mode);
|
| 328 |
+
#endif
|
| 329 |
+
return f;
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
static int stbi__start_write_file(stbi__write_context *s, const char *filename)
|
| 333 |
+
{
|
| 334 |
+
FILE *f = stbiw__fopen(filename, "wb");
|
| 335 |
+
stbi__start_write_callbacks(s, stbi__stdio_write, (void *) f);
|
| 336 |
+
return f != NULL;
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
static void stbi__end_write_file(stbi__write_context *s)
|
| 340 |
+
{
|
| 341 |
+
fclose((FILE *)s->context);
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
#endif // !STBI_WRITE_NO_STDIO
|
| 345 |
+
|
| 346 |
+
typedef unsigned int stbiw_uint32;
|
| 347 |
+
typedef int stb_image_write_test[sizeof(stbiw_uint32)==4 ? 1 : -1];
|
| 348 |
+
|
| 349 |
+
static void stbiw__writefv(stbi__write_context *s, const char *fmt, va_list v)
|
| 350 |
+
{
|
| 351 |
+
while (*fmt) {
|
| 352 |
+
switch (*fmt++) {
|
| 353 |
+
case ' ': break;
|
| 354 |
+
case '1': { unsigned char x = STBIW_UCHAR(va_arg(v, int));
|
| 355 |
+
s->func(s->context,&x,1);
|
| 356 |
+
break; }
|
| 357 |
+
case '2': { int x = va_arg(v,int);
|
| 358 |
+
unsigned char b[2];
|
| 359 |
+
b[0] = STBIW_UCHAR(x);
|
| 360 |
+
b[1] = STBIW_UCHAR(x>>8);
|
| 361 |
+
s->func(s->context,b,2);
|
| 362 |
+
break; }
|
| 363 |
+
case '4': { stbiw_uint32 x = va_arg(v,int);
|
| 364 |
+
unsigned char b[4];
|
| 365 |
+
b[0]=STBIW_UCHAR(x);
|
| 366 |
+
b[1]=STBIW_UCHAR(x>>8);
|
| 367 |
+
b[2]=STBIW_UCHAR(x>>16);
|
| 368 |
+
b[3]=STBIW_UCHAR(x>>24);
|
| 369 |
+
s->func(s->context,b,4);
|
| 370 |
+
break; }
|
| 371 |
+
default:
|
| 372 |
+
STBIW_ASSERT(0);
|
| 373 |
+
return;
|
| 374 |
+
}
|
| 375 |
+
}
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
static void stbiw__writef(stbi__write_context *s, const char *fmt, ...)
|
| 379 |
+
{
|
| 380 |
+
va_list v;
|
| 381 |
+
va_start(v, fmt);
|
| 382 |
+
stbiw__writefv(s, fmt, v);
|
| 383 |
+
va_end(v);
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
static void stbiw__write_flush(stbi__write_context *s)
|
| 387 |
+
{
|
| 388 |
+
if (s->buf_used) {
|
| 389 |
+
s->func(s->context, &s->buffer, s->buf_used);
|
| 390 |
+
s->buf_used = 0;
|
| 391 |
+
}
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
static void stbiw__putc(stbi__write_context *s, unsigned char c)
|
| 395 |
+
{
|
| 396 |
+
s->func(s->context, &c, 1);
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
static void stbiw__write1(stbi__write_context *s, unsigned char a)
|
| 400 |
+
{
|
| 401 |
+
if ((size_t)s->buf_used + 1 > sizeof(s->buffer))
|
| 402 |
+
stbiw__write_flush(s);
|
| 403 |
+
s->buffer[s->buf_used++] = a;
|
| 404 |
+
}
|
| 405 |
+
|
| 406 |
+
static void stbiw__write3(stbi__write_context *s, unsigned char a, unsigned char b, unsigned char c)
|
| 407 |
+
{
|
| 408 |
+
int n;
|
| 409 |
+
if ((size_t)s->buf_used + 3 > sizeof(s->buffer))
|
| 410 |
+
stbiw__write_flush(s);
|
| 411 |
+
n = s->buf_used;
|
| 412 |
+
s->buf_used = n+3;
|
| 413 |
+
s->buffer[n+0] = a;
|
| 414 |
+
s->buffer[n+1] = b;
|
| 415 |
+
s->buffer[n+2] = c;
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
static void stbiw__write_pixel(stbi__write_context *s, int rgb_dir, int comp, int write_alpha, int expand_mono, unsigned char *d)
|
| 419 |
+
{
|
| 420 |
+
unsigned char bg[3] = { 255, 0, 255}, px[3];
|
| 421 |
+
int k;
|
| 422 |
+
|
| 423 |
+
if (write_alpha < 0)
|
| 424 |
+
stbiw__write1(s, d[comp - 1]);
|
| 425 |
+
|
| 426 |
+
switch (comp) {
|
| 427 |
+
case 2: // 2 pixels = mono + alpha, alpha is written separately, so same as 1-channel case
|
| 428 |
+
case 1:
|
| 429 |
+
if (expand_mono)
|
| 430 |
+
stbiw__write3(s, d[0], d[0], d[0]); // monochrome bmp
|
| 431 |
+
else
|
| 432 |
+
stbiw__write1(s, d[0]); // monochrome TGA
|
| 433 |
+
break;
|
| 434 |
+
case 4:
|
| 435 |
+
if (!write_alpha) {
|
| 436 |
+
// composite against pink background
|
| 437 |
+
for (k = 0; k < 3; ++k)
|
| 438 |
+
px[k] = bg[k] + ((d[k] - bg[k]) * d[3]) / 255;
|
| 439 |
+
stbiw__write3(s, px[1 - rgb_dir], px[1], px[1 + rgb_dir]);
|
| 440 |
+
break;
|
| 441 |
+
}
|
| 442 |
+
/* FALLTHROUGH */
|
| 443 |
+
case 3:
|
| 444 |
+
stbiw__write3(s, d[1 - rgb_dir], d[1], d[1 + rgb_dir]);
|
| 445 |
+
break;
|
| 446 |
+
}
|
| 447 |
+
if (write_alpha > 0)
|
| 448 |
+
stbiw__write1(s, d[comp - 1]);
|
| 449 |
+
}
|
| 450 |
+
|
| 451 |
+
static void stbiw__write_pixels(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, void *data, int write_alpha, int scanline_pad, int expand_mono)
|
| 452 |
+
{
|
| 453 |
+
stbiw_uint32 zero = 0;
|
| 454 |
+
int i,j, j_end;
|
| 455 |
+
|
| 456 |
+
if (y <= 0)
|
| 457 |
+
return;
|
| 458 |
+
|
| 459 |
+
if (stbi__flip_vertically_on_write)
|
| 460 |
+
vdir *= -1;
|
| 461 |
+
|
| 462 |
+
if (vdir < 0) {
|
| 463 |
+
j_end = -1; j = y-1;
|
| 464 |
+
} else {
|
| 465 |
+
j_end = y; j = 0;
|
| 466 |
+
}
|
| 467 |
+
|
| 468 |
+
for (; j != j_end; j += vdir) {
|
| 469 |
+
for (i=0; i < x; ++i) {
|
| 470 |
+
unsigned char *d = (unsigned char *) data + (j*x+i)*comp;
|
| 471 |
+
stbiw__write_pixel(s, rgb_dir, comp, write_alpha, expand_mono, d);
|
| 472 |
+
}
|
| 473 |
+
stbiw__write_flush(s);
|
| 474 |
+
s->func(s->context, &zero, scanline_pad);
|
| 475 |
+
}
|
| 476 |
+
}
|
| 477 |
+
|
| 478 |
+
static int stbiw__outfile(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, int expand_mono, void *data, int alpha, int pad, const char *fmt, ...)
|
| 479 |
+
{
|
| 480 |
+
if (y < 0 || x < 0) {
|
| 481 |
+
return 0;
|
| 482 |
+
} else {
|
| 483 |
+
va_list v;
|
| 484 |
+
va_start(v, fmt);
|
| 485 |
+
stbiw__writefv(s, fmt, v);
|
| 486 |
+
va_end(v);
|
| 487 |
+
stbiw__write_pixels(s,rgb_dir,vdir,x,y,comp,data,alpha,pad, expand_mono);
|
| 488 |
+
return 1;
|
| 489 |
+
}
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
static int stbi_write_bmp_core(stbi__write_context *s, int x, int y, int comp, const void *data)
|
| 493 |
+
{
|
| 494 |
+
if (comp != 4) {
|
| 495 |
+
// write RGB bitmap
|
| 496 |
+
int pad = (-x*3) & 3;
|
| 497 |
+
return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *) data,0,pad,
|
| 498 |
+
"11 4 22 4" "4 44 22 444444",
|
| 499 |
+
'B', 'M', 14+40+(x*3+pad)*y, 0,0, 14+40, // file header
|
| 500 |
+
40, x,y, 1,24, 0,0,0,0,0,0); // bitmap header
|
| 501 |
+
} else {
|
| 502 |
+
// RGBA bitmaps need a v4 header
|
| 503 |
+
// use BI_BITFIELDS mode with 32bpp and alpha mask
|
| 504 |
+
// (straight BI_RGB with alpha mask doesn't work in most readers)
|
| 505 |
+
return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *)data,1,0,
|
| 506 |
+
"11 4 22 4" "4 44 22 444444 4444 4 444 444 444 444",
|
| 507 |
+
'B', 'M', 14+108+x*y*4, 0, 0, 14+108, // file header
|
| 508 |
+
108, x,y, 1,32, 3,0,0,0,0,0, 0xff0000,0xff00,0xff,0xff000000u, 0, 0,0,0, 0,0,0, 0,0,0, 0,0,0); // bitmap V4 header
|
| 509 |
+
}
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data)
|
| 513 |
+
{
|
| 514 |
+
stbi__write_context s = { 0 };
|
| 515 |
+
stbi__start_write_callbacks(&s, func, context);
|
| 516 |
+
return stbi_write_bmp_core(&s, x, y, comp, data);
|
| 517 |
+
}
|
| 518 |
+
|
| 519 |
+
#ifndef STBI_WRITE_NO_STDIO
|
| 520 |
+
STBIWDEF int stbi_write_bmp(char const *filename, int x, int y, int comp, const void *data)
|
| 521 |
+
{
|
| 522 |
+
stbi__write_context s = { 0 };
|
| 523 |
+
if (stbi__start_write_file(&s,filename)) {
|
| 524 |
+
int r = stbi_write_bmp_core(&s, x, y, comp, data);
|
| 525 |
+
stbi__end_write_file(&s);
|
| 526 |
+
return r;
|
| 527 |
+
} else
|
| 528 |
+
return 0;
|
| 529 |
+
}
|
| 530 |
+
#endif //!STBI_WRITE_NO_STDIO
|
| 531 |
+
|
| 532 |
+
static int stbi_write_tga_core(stbi__write_context *s, int x, int y, int comp, void *data)
|
| 533 |
+
{
|
| 534 |
+
int has_alpha = (comp == 2 || comp == 4);
|
| 535 |
+
int colorbytes = has_alpha ? comp-1 : comp;
|
| 536 |
+
int format = colorbytes < 2 ? 3 : 2; // 3 color channels (RGB/RGBA) = 2, 1 color channel (Y/YA) = 3
|
| 537 |
+
|
| 538 |
+
if (y < 0 || x < 0)
|
| 539 |
+
return 0;
|
| 540 |
+
|
| 541 |
+
if (!stbi_write_tga_with_rle) {
|
| 542 |
+
return stbiw__outfile(s, -1, -1, x, y, comp, 0, (void *) data, has_alpha, 0,
|
| 543 |
+
"111 221 2222 11", 0, 0, format, 0, 0, 0, 0, 0, x, y, (colorbytes + has_alpha) * 8, has_alpha * 8);
|
| 544 |
+
} else {
|
| 545 |
+
int i,j,k;
|
| 546 |
+
int jend, jdir;
|
| 547 |
+
|
| 548 |
+
stbiw__writef(s, "111 221 2222 11", 0,0,format+8, 0,0,0, 0,0,x,y, (colorbytes + has_alpha) * 8, has_alpha * 8);
|
| 549 |
+
|
| 550 |
+
if (stbi__flip_vertically_on_write) {
|
| 551 |
+
j = 0;
|
| 552 |
+
jend = y;
|
| 553 |
+
jdir = 1;
|
| 554 |
+
} else {
|
| 555 |
+
j = y-1;
|
| 556 |
+
jend = -1;
|
| 557 |
+
jdir = -1;
|
| 558 |
+
}
|
| 559 |
+
for (; j != jend; j += jdir) {
|
| 560 |
+
unsigned char *row = (unsigned char *) data + j * x * comp;
|
| 561 |
+
int len;
|
| 562 |
+
|
| 563 |
+
for (i = 0; i < x; i += len) {
|
| 564 |
+
unsigned char *begin = row + i * comp;
|
| 565 |
+
int diff = 1;
|
| 566 |
+
len = 1;
|
| 567 |
+
|
| 568 |
+
if (i < x - 1) {
|
| 569 |
+
++len;
|
| 570 |
+
diff = memcmp(begin, row + (i + 1) * comp, comp);
|
| 571 |
+
if (diff) {
|
| 572 |
+
const unsigned char *prev = begin;
|
| 573 |
+
for (k = i + 2; k < x && len < 128; ++k) {
|
| 574 |
+
if (memcmp(prev, row + k * comp, comp)) {
|
| 575 |
+
prev += comp;
|
| 576 |
+
++len;
|
| 577 |
+
} else {
|
| 578 |
+
--len;
|
| 579 |
+
break;
|
| 580 |
+
}
|
| 581 |
+
}
|
| 582 |
+
} else {
|
| 583 |
+
for (k = i + 2; k < x && len < 128; ++k) {
|
| 584 |
+
if (!memcmp(begin, row + k * comp, comp)) {
|
| 585 |
+
++len;
|
| 586 |
+
} else {
|
| 587 |
+
break;
|
| 588 |
+
}
|
| 589 |
+
}
|
| 590 |
+
}
|
| 591 |
+
}
|
| 592 |
+
|
| 593 |
+
if (diff) {
|
| 594 |
+
unsigned char header = STBIW_UCHAR(len - 1);
|
| 595 |
+
stbiw__write1(s, header);
|
| 596 |
+
for (k = 0; k < len; ++k) {
|
| 597 |
+
stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin + k * comp);
|
| 598 |
+
}
|
| 599 |
+
} else {
|
| 600 |
+
unsigned char header = STBIW_UCHAR(len - 129);
|
| 601 |
+
stbiw__write1(s, header);
|
| 602 |
+
stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin);
|
| 603 |
+
}
|
| 604 |
+
}
|
| 605 |
+
}
|
| 606 |
+
stbiw__write_flush(s);
|
| 607 |
+
}
|
| 608 |
+
return 1;
|
| 609 |
+
}
|
| 610 |
+
|
| 611 |
+
STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data)
|
| 612 |
+
{
|
| 613 |
+
stbi__write_context s = { 0 };
|
| 614 |
+
stbi__start_write_callbacks(&s, func, context);
|
| 615 |
+
return stbi_write_tga_core(&s, x, y, comp, (void *) data);
|
| 616 |
+
}
|
| 617 |
+
|
| 618 |
+
#ifndef STBI_WRITE_NO_STDIO
|
| 619 |
+
STBIWDEF int stbi_write_tga(char const *filename, int x, int y, int comp, const void *data)
|
| 620 |
+
{
|
| 621 |
+
stbi__write_context s = { 0 };
|
| 622 |
+
if (stbi__start_write_file(&s,filename)) {
|
| 623 |
+
int r = stbi_write_tga_core(&s, x, y, comp, (void *) data);
|
| 624 |
+
stbi__end_write_file(&s);
|
| 625 |
+
return r;
|
| 626 |
+
} else
|
| 627 |
+
return 0;
|
| 628 |
+
}
|
| 629 |
+
#endif
|
| 630 |
+
|
| 631 |
+
// *************************************************************************************************
|
| 632 |
+
// Radiance RGBE HDR writer
|
| 633 |
+
// by Baldur Karlsson
|
| 634 |
+
|
| 635 |
+
#define stbiw__max(a, b) ((a) > (b) ? (a) : (b))
|
| 636 |
+
|
| 637 |
+
#ifndef STBI_WRITE_NO_STDIO
|
| 638 |
+
|
| 639 |
+
static void stbiw__linear_to_rgbe(unsigned char *rgbe, float *linear)
|
| 640 |
+
{
|
| 641 |
+
int exponent;
|
| 642 |
+
float maxcomp = stbiw__max(linear[0], stbiw__max(linear[1], linear[2]));
|
| 643 |
+
|
| 644 |
+
if (maxcomp < 1e-32f) {
|
| 645 |
+
rgbe[0] = rgbe[1] = rgbe[2] = rgbe[3] = 0;
|
| 646 |
+
} else {
|
| 647 |
+
float normalize = (float) frexp(maxcomp, &exponent) * 256.0f/maxcomp;
|
| 648 |
+
|
| 649 |
+
rgbe[0] = (unsigned char)(linear[0] * normalize);
|
| 650 |
+
rgbe[1] = (unsigned char)(linear[1] * normalize);
|
| 651 |
+
rgbe[2] = (unsigned char)(linear[2] * normalize);
|
| 652 |
+
rgbe[3] = (unsigned char)(exponent + 128);
|
| 653 |
+
}
|
| 654 |
+
}
|
| 655 |
+
|
| 656 |
+
static void stbiw__write_run_data(stbi__write_context *s, int length, unsigned char databyte)
|
| 657 |
+
{
|
| 658 |
+
unsigned char lengthbyte = STBIW_UCHAR(length+128);
|
| 659 |
+
STBIW_ASSERT(length+128 <= 255);
|
| 660 |
+
s->func(s->context, &lengthbyte, 1);
|
| 661 |
+
s->func(s->context, &databyte, 1);
|
| 662 |
+
}
|
| 663 |
+
|
| 664 |
+
static void stbiw__write_dump_data(stbi__write_context *s, int length, unsigned char *data)
|
| 665 |
+
{
|
| 666 |
+
unsigned char lengthbyte = STBIW_UCHAR(length);
|
| 667 |
+
STBIW_ASSERT(length <= 128); // inconsistent with spec but consistent with official code
|
| 668 |
+
s->func(s->context, &lengthbyte, 1);
|
| 669 |
+
s->func(s->context, data, length);
|
| 670 |
+
}
|
| 671 |
+
|
| 672 |
+
static void stbiw__write_hdr_scanline(stbi__write_context *s, int width, int ncomp, unsigned char *scratch, float *scanline)
|
| 673 |
+
{
|
| 674 |
+
unsigned char scanlineheader[4] = { 2, 2, 0, 0 };
|
| 675 |
+
unsigned char rgbe[4];
|
| 676 |
+
float linear[3];
|
| 677 |
+
int x;
|
| 678 |
+
|
| 679 |
+
scanlineheader[2] = (width&0xff00)>>8;
|
| 680 |
+
scanlineheader[3] = (width&0x00ff);
|
| 681 |
+
|
| 682 |
+
/* skip RLE for images too small or large */
|
| 683 |
+
if (width < 8 || width >= 32768) {
|
| 684 |
+
for (x=0; x < width; x++) {
|
| 685 |
+
switch (ncomp) {
|
| 686 |
+
case 4: /* fallthrough */
|
| 687 |
+
case 3: linear[2] = scanline[x*ncomp + 2];
|
| 688 |
+
linear[1] = scanline[x*ncomp + 1];
|
| 689 |
+
linear[0] = scanline[x*ncomp + 0];
|
| 690 |
+
break;
|
| 691 |
+
default:
|
| 692 |
+
linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0];
|
| 693 |
+
break;
|
| 694 |
+
}
|
| 695 |
+
stbiw__linear_to_rgbe(rgbe, linear);
|
| 696 |
+
s->func(s->context, rgbe, 4);
|
| 697 |
+
}
|
| 698 |
+
} else {
|
| 699 |
+
int c,r;
|
| 700 |
+
/* encode into scratch buffer */
|
| 701 |
+
for (x=0; x < width; x++) {
|
| 702 |
+
switch(ncomp) {
|
| 703 |
+
case 4: /* fallthrough */
|
| 704 |
+
case 3: linear[2] = scanline[x*ncomp + 2];
|
| 705 |
+
linear[1] = scanline[x*ncomp + 1];
|
| 706 |
+
linear[0] = scanline[x*ncomp + 0];
|
| 707 |
+
break;
|
| 708 |
+
default:
|
| 709 |
+
linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0];
|
| 710 |
+
break;
|
| 711 |
+
}
|
| 712 |
+
stbiw__linear_to_rgbe(rgbe, linear);
|
| 713 |
+
scratch[x + width*0] = rgbe[0];
|
| 714 |
+
scratch[x + width*1] = rgbe[1];
|
| 715 |
+
scratch[x + width*2] = rgbe[2];
|
| 716 |
+
scratch[x + width*3] = rgbe[3];
|
| 717 |
+
}
|
| 718 |
+
|
| 719 |
+
s->func(s->context, scanlineheader, 4);
|
| 720 |
+
|
| 721 |
+
/* RLE each component separately */
|
| 722 |
+
for (c=0; c < 4; c++) {
|
| 723 |
+
unsigned char *comp = &scratch[width*c];
|
| 724 |
+
|
| 725 |
+
x = 0;
|
| 726 |
+
while (x < width) {
|
| 727 |
+
// find first run
|
| 728 |
+
r = x;
|
| 729 |
+
while (r+2 < width) {
|
| 730 |
+
if (comp[r] == comp[r+1] && comp[r] == comp[r+2])
|
| 731 |
+
break;
|
| 732 |
+
++r;
|
| 733 |
+
}
|
| 734 |
+
if (r+2 >= width)
|
| 735 |
+
r = width;
|
| 736 |
+
// dump up to first run
|
| 737 |
+
while (x < r) {
|
| 738 |
+
int len = r-x;
|
| 739 |
+
if (len > 128) len = 128;
|
| 740 |
+
stbiw__write_dump_data(s, len, &comp[x]);
|
| 741 |
+
x += len;
|
| 742 |
+
}
|
| 743 |
+
// if there's a run, output it
|
| 744 |
+
if (r+2 < width) { // same test as what we break out of in search loop, so only true if we break'd
|
| 745 |
+
// find next byte after run
|
| 746 |
+
while (r < width && comp[r] == comp[x])
|
| 747 |
+
++r;
|
| 748 |
+
// output run up to r
|
| 749 |
+
while (x < r) {
|
| 750 |
+
int len = r-x;
|
| 751 |
+
if (len > 127) len = 127;
|
| 752 |
+
stbiw__write_run_data(s, len, comp[x]);
|
| 753 |
+
x += len;
|
| 754 |
+
}
|
| 755 |
+
}
|
| 756 |
+
}
|
| 757 |
+
}
|
| 758 |
+
}
|
| 759 |
+
}
|
| 760 |
+
|
| 761 |
+
static int stbi_write_hdr_core(stbi__write_context *s, int x, int y, int comp, float *data)
|
| 762 |
+
{
|
| 763 |
+
if (y <= 0 || x <= 0 || data == NULL)
|
| 764 |
+
return 0;
|
| 765 |
+
else {
|
| 766 |
+
// Each component is stored separately. Allocate scratch space for full output scanline.
|
| 767 |
+
unsigned char *scratch = (unsigned char *) STBIW_MALLOC(x*4);
|
| 768 |
+
int i, len;
|
| 769 |
+
char buffer[128];
|
| 770 |
+
char header[] = "#?RADIANCE\n# Written by stb_image_write.h\nFORMAT=32-bit_rle_rgbe\n";
|
| 771 |
+
s->func(s->context, header, sizeof(header)-1);
|
| 772 |
+
|
| 773 |
+
#ifdef __STDC_LIB_EXT1__
|
| 774 |
+
len = sprintf_s(buffer, sizeof(buffer), "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n", y, x);
|
| 775 |
+
#else
|
| 776 |
+
len = sprintf(buffer, "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n", y, x);
|
| 777 |
+
#endif
|
| 778 |
+
s->func(s->context, buffer, len);
|
| 779 |
+
|
| 780 |
+
for(i=0; i < y; i++)
|
| 781 |
+
stbiw__write_hdr_scanline(s, x, comp, scratch, data + comp*x*(stbi__flip_vertically_on_write ? y-1-i : i));
|
| 782 |
+
STBIW_FREE(scratch);
|
| 783 |
+
return 1;
|
| 784 |
+
}
|
| 785 |
+
}
|
| 786 |
+
|
| 787 |
+
STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const float *data)
|
| 788 |
+
{
|
| 789 |
+
stbi__write_context s = { 0 };
|
| 790 |
+
stbi__start_write_callbacks(&s, func, context);
|
| 791 |
+
return stbi_write_hdr_core(&s, x, y, comp, (float *) data);
|
| 792 |
+
}
|
| 793 |
+
|
| 794 |
+
STBIWDEF int stbi_write_hdr(char const *filename, int x, int y, int comp, const float *data)
|
| 795 |
+
{
|
| 796 |
+
stbi__write_context s = { 0 };
|
| 797 |
+
if (stbi__start_write_file(&s,filename)) {
|
| 798 |
+
int r = stbi_write_hdr_core(&s, x, y, comp, (float *) data);
|
| 799 |
+
stbi__end_write_file(&s);
|
| 800 |
+
return r;
|
| 801 |
+
} else
|
| 802 |
+
return 0;
|
| 803 |
+
}
|
| 804 |
+
#endif // STBI_WRITE_NO_STDIO
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
//////////////////////////////////////////////////////////////////////////////
|
| 808 |
+
//
|
| 809 |
+
// PNG writer
|
| 810 |
+
//
|
| 811 |
+
|
| 812 |
+
#ifndef STBIW_ZLIB_COMPRESS
|
| 813 |
+
// stretchy buffer; stbiw__sbpush() == vector<>::push_back() -- stbiw__sbcount() == vector<>::size()
|
| 814 |
+
#define stbiw__sbraw(a) ((int *) (void *) (a) - 2)
|
| 815 |
+
#define stbiw__sbm(a) stbiw__sbraw(a)[0]
|
| 816 |
+
#define stbiw__sbn(a) stbiw__sbraw(a)[1]
|
| 817 |
+
|
| 818 |
+
#define stbiw__sbneedgrow(a,n) ((a)==0 || stbiw__sbn(a)+n >= stbiw__sbm(a))
|
| 819 |
+
#define stbiw__sbmaybegrow(a,n) (stbiw__sbneedgrow(a,(n)) ? stbiw__sbgrow(a,n) : 0)
|
| 820 |
+
#define stbiw__sbgrow(a,n) stbiw__sbgrowf((void **) &(a), (n), sizeof(*(a)))
|
| 821 |
+
|
| 822 |
+
#define stbiw__sbpush(a, v) (stbiw__sbmaybegrow(a,1), (a)[stbiw__sbn(a)++] = (v))
|
| 823 |
+
#define stbiw__sbcount(a) ((a) ? stbiw__sbn(a) : 0)
|
| 824 |
+
#define stbiw__sbfree(a) ((a) ? STBIW_FREE(stbiw__sbraw(a)),0 : 0)
|
| 825 |
+
|
| 826 |
+
static void *stbiw__sbgrowf(void **arr, int increment, int itemsize)
|
| 827 |
+
{
|
| 828 |
+
int m = *arr ? 2*stbiw__sbm(*arr)+increment : increment+1;
|
| 829 |
+
void *p = STBIW_REALLOC_SIZED(*arr ? stbiw__sbraw(*arr) : 0, *arr ? (stbiw__sbm(*arr)*itemsize + sizeof(int)*2) : 0, itemsize * m + sizeof(int)*2);
|
| 830 |
+
STBIW_ASSERT(p);
|
| 831 |
+
if (p) {
|
| 832 |
+
if (!*arr) ((int *) p)[1] = 0;
|
| 833 |
+
*arr = (void *) ((int *) p + 2);
|
| 834 |
+
stbiw__sbm(*arr) = m;
|
| 835 |
+
}
|
| 836 |
+
return *arr;
|
| 837 |
+
}
|
| 838 |
+
|
| 839 |
+
static unsigned char *stbiw__zlib_flushf(unsigned char *data, unsigned int *bitbuffer, int *bitcount)
|
| 840 |
+
{
|
| 841 |
+
while (*bitcount >= 8) {
|
| 842 |
+
stbiw__sbpush(data, STBIW_UCHAR(*bitbuffer));
|
| 843 |
+
*bitbuffer >>= 8;
|
| 844 |
+
*bitcount -= 8;
|
| 845 |
+
}
|
| 846 |
+
return data;
|
| 847 |
+
}
|
| 848 |
+
|
| 849 |
+
static int stbiw__zlib_bitrev(int code, int codebits)
|
| 850 |
+
{
|
| 851 |
+
int res=0;
|
| 852 |
+
while (codebits--) {
|
| 853 |
+
res = (res << 1) | (code & 1);
|
| 854 |
+
code >>= 1;
|
| 855 |
+
}
|
| 856 |
+
return res;
|
| 857 |
+
}
|
| 858 |
+
|
| 859 |
+
static unsigned int stbiw__zlib_countm(unsigned char *a, unsigned char *b, int limit)
|
| 860 |
+
{
|
| 861 |
+
int i;
|
| 862 |
+
for (i=0; i < limit && i < 258; ++i)
|
| 863 |
+
if (a[i] != b[i]) break;
|
| 864 |
+
return i;
|
| 865 |
+
}
|
| 866 |
+
|
| 867 |
+
static unsigned int stbiw__zhash(unsigned char *data)
|
| 868 |
+
{
|
| 869 |
+
stbiw_uint32 hash = data[0] + (data[1] << 8) + (data[2] << 16);
|
| 870 |
+
hash ^= hash << 3;
|
| 871 |
+
hash += hash >> 5;
|
| 872 |
+
hash ^= hash << 4;
|
| 873 |
+
hash += hash >> 17;
|
| 874 |
+
hash ^= hash << 25;
|
| 875 |
+
hash += hash >> 6;
|
| 876 |
+
return hash;
|
| 877 |
+
}
|
| 878 |
+
|
| 879 |
+
#define stbiw__zlib_flush() (out = stbiw__zlib_flushf(out, &bitbuf, &bitcount))
|
| 880 |
+
#define stbiw__zlib_add(code,codebits) \
|
| 881 |
+
(bitbuf |= (code) << bitcount, bitcount += (codebits), stbiw__zlib_flush())
|
| 882 |
+
#define stbiw__zlib_huffa(b,c) stbiw__zlib_add(stbiw__zlib_bitrev(b,c),c)
|
| 883 |
+
// default huffman tables
|
| 884 |
+
#define stbiw__zlib_huff1(n) stbiw__zlib_huffa(0x30 + (n), 8)
|
| 885 |
+
#define stbiw__zlib_huff2(n) stbiw__zlib_huffa(0x190 + (n)-144, 9)
|
| 886 |
+
#define stbiw__zlib_huff3(n) stbiw__zlib_huffa(0 + (n)-256,7)
|
| 887 |
+
#define stbiw__zlib_huff4(n) stbiw__zlib_huffa(0xc0 + (n)-280,8)
|
| 888 |
+
#define stbiw__zlib_huff(n) ((n) <= 143 ? stbiw__zlib_huff1(n) : (n) <= 255 ? stbiw__zlib_huff2(n) : (n) <= 279 ? stbiw__zlib_huff3(n) : stbiw__zlib_huff4(n))
|
| 889 |
+
#define stbiw__zlib_huffb(n) ((n) <= 143 ? stbiw__zlib_huff1(n) : stbiw__zlib_huff2(n))
|
| 890 |
+
|
| 891 |
+
#define stbiw__ZHASH 16384
|
| 892 |
+
|
| 893 |
+
#endif // STBIW_ZLIB_COMPRESS
|
| 894 |
+
|
| 895 |
+
STBIWDEF unsigned char * stbi_zlib_compress(unsigned char *data, int data_len, int *out_len, int quality)
|
| 896 |
+
{
|
| 897 |
+
#ifdef STBIW_ZLIB_COMPRESS
|
| 898 |
+
// user provided a zlib compress implementation, use that
|
| 899 |
+
return STBIW_ZLIB_COMPRESS(data, data_len, out_len, quality);
|
| 900 |
+
#else // use builtin
|
| 901 |
+
static unsigned short lengthc[] = { 3,4,5,6,7,8,9,10,11,13,15,17,19,23,27,31,35,43,51,59,67,83,99,115,131,163,195,227,258, 259 };
|
| 902 |
+
static unsigned char lengtheb[]= { 0,0,0,0,0,0,0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0 };
|
| 903 |
+
static unsigned short distc[] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577, 32768 };
|
| 904 |
+
static unsigned char disteb[] = { 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13 };
|
| 905 |
+
unsigned int bitbuf=0;
|
| 906 |
+
int i,j, bitcount=0;
|
| 907 |
+
unsigned char *out = NULL;
|
| 908 |
+
unsigned char ***hash_table = (unsigned char***) STBIW_MALLOC(stbiw__ZHASH * sizeof(unsigned char**));
|
| 909 |
+
if (hash_table == NULL)
|
| 910 |
+
return NULL;
|
| 911 |
+
if (quality < 5) quality = 5;
|
| 912 |
+
|
| 913 |
+
stbiw__sbpush(out, 0x78); // DEFLATE 32K window
|
| 914 |
+
stbiw__sbpush(out, 0x5e); // FLEVEL = 1
|
| 915 |
+
stbiw__zlib_add(1,1); // BFINAL = 1
|
| 916 |
+
stbiw__zlib_add(1,2); // BTYPE = 1 -- fixed huffman
|
| 917 |
+
|
| 918 |
+
for (i=0; i < stbiw__ZHASH; ++i)
|
| 919 |
+
hash_table[i] = NULL;
|
| 920 |
+
|
| 921 |
+
i=0;
|
| 922 |
+
while (i < data_len-3) {
|
| 923 |
+
// hash next 3 bytes of data to be compressed
|
| 924 |
+
int h = stbiw__zhash(data+i)&(stbiw__ZHASH-1), best=3;
|
| 925 |
+
unsigned char *bestloc = 0;
|
| 926 |
+
unsigned char **hlist = hash_table[h];
|
| 927 |
+
int n = stbiw__sbcount(hlist);
|
| 928 |
+
for (j=0; j < n; ++j) {
|
| 929 |
+
if (hlist[j]-data > i-32768) { // if entry lies within window
|
| 930 |
+
int d = stbiw__zlib_countm(hlist[j], data+i, data_len-i);
|
| 931 |
+
if (d >= best) { best=d; bestloc=hlist[j]; }
|
| 932 |
+
}
|
| 933 |
+
}
|
| 934 |
+
// when hash table entry is too long, delete half the entries
|
| 935 |
+
if (hash_table[h] && stbiw__sbn(hash_table[h]) == 2*quality) {
|
| 936 |
+
STBIW_MEMMOVE(hash_table[h], hash_table[h]+quality, sizeof(hash_table[h][0])*quality);
|
| 937 |
+
stbiw__sbn(hash_table[h]) = quality;
|
| 938 |
+
}
|
| 939 |
+
stbiw__sbpush(hash_table[h],data+i);
|
| 940 |
+
|
| 941 |
+
if (bestloc) {
|
| 942 |
+
// "lazy matching" - check match at *next* byte, and if it's better, do cur byte as literal
|
| 943 |
+
h = stbiw__zhash(data+i+1)&(stbiw__ZHASH-1);
|
| 944 |
+
hlist = hash_table[h];
|
| 945 |
+
n = stbiw__sbcount(hlist);
|
| 946 |
+
for (j=0; j < n; ++j) {
|
| 947 |
+
if (hlist[j]-data > i-32767) {
|
| 948 |
+
int e = stbiw__zlib_countm(hlist[j], data+i+1, data_len-i-1);
|
| 949 |
+
if (e > best) { // if next match is better, bail on current match
|
| 950 |
+
bestloc = NULL;
|
| 951 |
+
break;
|
| 952 |
+
}
|
| 953 |
+
}
|
| 954 |
+
}
|
| 955 |
+
}
|
| 956 |
+
|
| 957 |
+
if (bestloc) {
|
| 958 |
+
int d = (int) (data+i - bestloc); // distance back
|
| 959 |
+
STBIW_ASSERT(d <= 32767 && best <= 258);
|
| 960 |
+
for (j=0; best > lengthc[j+1]-1; ++j);
|
| 961 |
+
stbiw__zlib_huff(j+257);
|
| 962 |
+
if (lengtheb[j]) stbiw__zlib_add(best - lengthc[j], lengtheb[j]);
|
| 963 |
+
for (j=0; d > distc[j+1]-1; ++j);
|
| 964 |
+
stbiw__zlib_add(stbiw__zlib_bitrev(j,5),5);
|
| 965 |
+
if (disteb[j]) stbiw__zlib_add(d - distc[j], disteb[j]);
|
| 966 |
+
i += best;
|
| 967 |
+
} else {
|
| 968 |
+
stbiw__zlib_huffb(data[i]);
|
| 969 |
+
++i;
|
| 970 |
+
}
|
| 971 |
+
}
|
| 972 |
+
// write out final bytes
|
| 973 |
+
for (;i < data_len; ++i)
|
| 974 |
+
stbiw__zlib_huffb(data[i]);
|
| 975 |
+
stbiw__zlib_huff(256); // end of block
|
| 976 |
+
// pad with 0 bits to byte boundary
|
| 977 |
+
while (bitcount)
|
| 978 |
+
stbiw__zlib_add(0,1);
|
| 979 |
+
|
| 980 |
+
for (i=0; i < stbiw__ZHASH; ++i)
|
| 981 |
+
(void) stbiw__sbfree(hash_table[i]);
|
| 982 |
+
STBIW_FREE(hash_table);
|
| 983 |
+
|
| 984 |
+
// store uncompressed instead if compression was worse
|
| 985 |
+
if (stbiw__sbn(out) > data_len + 2 + ((data_len+32766)/32767)*5) {
|
| 986 |
+
stbiw__sbn(out) = 2; // truncate to DEFLATE 32K window and FLEVEL = 1
|
| 987 |
+
for (j = 0; j < data_len;) {
|
| 988 |
+
int blocklen = data_len - j;
|
| 989 |
+
if (blocklen > 32767) blocklen = 32767;
|
| 990 |
+
stbiw__sbpush(out, data_len - j == blocklen); // BFINAL = ?, BTYPE = 0 -- no compression
|
| 991 |
+
stbiw__sbpush(out, STBIW_UCHAR(blocklen)); // LEN
|
| 992 |
+
stbiw__sbpush(out, STBIW_UCHAR(blocklen >> 8));
|
| 993 |
+
stbiw__sbpush(out, STBIW_UCHAR(~blocklen)); // NLEN
|
| 994 |
+
stbiw__sbpush(out, STBIW_UCHAR(~blocklen >> 8));
|
| 995 |
+
memcpy(out+stbiw__sbn(out), data+j, blocklen);
|
| 996 |
+
stbiw__sbn(out) += blocklen;
|
| 997 |
+
j += blocklen;
|
| 998 |
+
}
|
| 999 |
+
}
|
| 1000 |
+
|
| 1001 |
+
{
|
| 1002 |
+
// compute adler32 on input
|
| 1003 |
+
unsigned int s1=1, s2=0;
|
| 1004 |
+
int blocklen = (int) (data_len % 5552);
|
| 1005 |
+
j=0;
|
| 1006 |
+
while (j < data_len) {
|
| 1007 |
+
for (i=0; i < blocklen; ++i) { s1 += data[j+i]; s2 += s1; }
|
| 1008 |
+
s1 %= 65521; s2 %= 65521;
|
| 1009 |
+
j += blocklen;
|
| 1010 |
+
blocklen = 5552;
|
| 1011 |
+
}
|
| 1012 |
+
stbiw__sbpush(out, STBIW_UCHAR(s2 >> 8));
|
| 1013 |
+
stbiw__sbpush(out, STBIW_UCHAR(s2));
|
| 1014 |
+
stbiw__sbpush(out, STBIW_UCHAR(s1 >> 8));
|
| 1015 |
+
stbiw__sbpush(out, STBIW_UCHAR(s1));
|
| 1016 |
+
}
|
| 1017 |
+
*out_len = stbiw__sbn(out);
|
| 1018 |
+
// make returned pointer freeable
|
| 1019 |
+
STBIW_MEMMOVE(stbiw__sbraw(out), out, *out_len);
|
| 1020 |
+
return (unsigned char *) stbiw__sbraw(out);
|
| 1021 |
+
#endif // STBIW_ZLIB_COMPRESS
|
| 1022 |
+
}
|
| 1023 |
+
|
| 1024 |
+
static unsigned int stbiw__crc32(unsigned char *buffer, int len)
|
| 1025 |
+
{
|
| 1026 |
+
#ifdef STBIW_CRC32
|
| 1027 |
+
return STBIW_CRC32(buffer, len);
|
| 1028 |
+
#else
|
| 1029 |
+
static unsigned int crc_table[256] =
|
| 1030 |
+
{
|
| 1031 |
+
0x00000000, 0x77073096, 0xEE0E612C, 0x990951BA, 0x076DC419, 0x706AF48F, 0xE963A535, 0x9E6495A3,
|
| 1032 |
+
0x0eDB8832, 0x79DCB8A4, 0xE0D5E91E, 0x97D2D988, 0x09B64C2B, 0x7EB17CBD, 0xE7B82D07, 0x90BF1D91,
|
| 1033 |
+
0x1DB71064, 0x6AB020F2, 0xF3B97148, 0x84BE41DE, 0x1ADAD47D, 0x6DDDE4EB, 0xF4D4B551, 0x83D385C7,
|
| 1034 |
+
0x136C9856, 0x646BA8C0, 0xFD62F97A, 0x8A65C9EC, 0x14015C4F, 0x63066CD9, 0xFA0F3D63, 0x8D080DF5,
|
| 1035 |
+
0x3B6E20C8, 0x4C69105E, 0xD56041E4, 0xA2677172, 0x3C03E4D1, 0x4B04D447, 0xD20D85FD, 0xA50AB56B,
|
| 1036 |
+
0x35B5A8FA, 0x42B2986C, 0xDBBBC9D6, 0xACBCF940, 0x32D86CE3, 0x45DF5C75, 0xDCD60DCF, 0xABD13D59,
|
| 1037 |
+
0x26D930AC, 0x51DE003A, 0xC8D75180, 0xBFD06116, 0x21B4F4B5, 0x56B3C423, 0xCFBA9599, 0xB8BDA50F,
|
| 1038 |
+
0x2802B89E, 0x5F058808, 0xC60CD9B2, 0xB10BE924, 0x2F6F7C87, 0x58684C11, 0xC1611DAB, 0xB6662D3D,
|
| 1039 |
+
0x76DC4190, 0x01DB7106, 0x98D220BC, 0xEFD5102A, 0x71B18589, 0x06B6B51F, 0x9FBFE4A5, 0xE8B8D433,
|
| 1040 |
+
0x7807C9A2, 0x0F00F934, 0x9609A88E, 0xE10E9818, 0x7F6A0DBB, 0x086D3D2D, 0x91646C97, 0xE6635C01,
|
| 1041 |
+
0x6B6B51F4, 0x1C6C6162, 0x856530D8, 0xF262004E, 0x6C0695ED, 0x1B01A57B, 0x8208F4C1, 0xF50FC457,
|
| 1042 |
+
0x65B0D9C6, 0x12B7E950, 0x8BBEB8EA, 0xFCB9887C, 0x62DD1DDF, 0x15DA2D49, 0x8CD37CF3, 0xFBD44C65,
|
| 1043 |
+
0x4DB26158, 0x3AB551CE, 0xA3BC0074, 0xD4BB30E2, 0x4ADFA541, 0x3DD895D7, 0xA4D1C46D, 0xD3D6F4FB,
|
| 1044 |
+
0x4369E96A, 0x346ED9FC, 0xAD678846, 0xDA60B8D0, 0x44042D73, 0x33031DE5, 0xAA0A4C5F, 0xDD0D7CC9,
|
| 1045 |
+
0x5005713C, 0x270241AA, 0xBE0B1010, 0xC90C2086, 0x5768B525, 0x206F85B3, 0xB966D409, 0xCE61E49F,
|
| 1046 |
+
0x5EDEF90E, 0x29D9C998, 0xB0D09822, 0xC7D7A8B4, 0x59B33D17, 0x2EB40D81, 0xB7BD5C3B, 0xC0BA6CAD,
|
| 1047 |
+
0xEDB88320, 0x9ABFB3B6, 0x03B6E20C, 0x74B1D29A, 0xEAD54739, 0x9DD277AF, 0x04DB2615, 0x73DC1683,
|
| 1048 |
+
0xE3630B12, 0x94643B84, 0x0D6D6A3E, 0x7A6A5AA8, 0xE40ECF0B, 0x9309FF9D, 0x0A00AE27, 0x7D079EB1,
|
| 1049 |
+
0xF00F9344, 0x8708A3D2, 0x1E01F268, 0x6906C2FE, 0xF762575D, 0x806567CB, 0x196C3671, 0x6E6B06E7,
|
| 1050 |
+
0xFED41B76, 0x89D32BE0, 0x10DA7A5A, 0x67DD4ACC, 0xF9B9DF6F, 0x8EBEEFF9, 0x17B7BE43, 0x60B08ED5,
|
| 1051 |
+
0xD6D6A3E8, 0xA1D1937E, 0x38D8C2C4, 0x4FDFF252, 0xD1BB67F1, 0xA6BC5767, 0x3FB506DD, 0x48B2364B,
|
| 1052 |
+
0xD80D2BDA, 0xAF0A1B4C, 0x36034AF6, 0x41047A60, 0xDF60EFC3, 0xA867DF55, 0x316E8EEF, 0x4669BE79,
|
| 1053 |
+
0xCB61B38C, 0xBC66831A, 0x256FD2A0, 0x5268E236, 0xCC0C7795, 0xBB0B4703, 0x220216B9, 0x5505262F,
|
| 1054 |
+
0xC5BA3BBE, 0xB2BD0B28, 0x2BB45A92, 0x5CB36A04, 0xC2D7FFA7, 0xB5D0CF31, 0x2CD99E8B, 0x5BDEAE1D,
|
| 1055 |
+
0x9B64C2B0, 0xEC63F226, 0x756AA39C, 0x026D930A, 0x9C0906A9, 0xEB0E363F, 0x72076785, 0x05005713,
|
| 1056 |
+
0x95BF4A82, 0xE2B87A14, 0x7BB12BAE, 0x0CB61B38, 0x92D28E9B, 0xE5D5BE0D, 0x7CDCEFB7, 0x0BDBDF21,
|
| 1057 |
+
0x86D3D2D4, 0xF1D4E242, 0x68DDB3F8, 0x1FDA836E, 0x81BE16CD, 0xF6B9265B, 0x6FB077E1, 0x18B74777,
|
| 1058 |
+
0x88085AE6, 0xFF0F6A70, 0x66063BCA, 0x11010B5C, 0x8F659EFF, 0xF862AE69, 0x616BFFD3, 0x166CCF45,
|
| 1059 |
+
0xA00AE278, 0xD70DD2EE, 0x4E048354, 0x3903B3C2, 0xA7672661, 0xD06016F7, 0x4969474D, 0x3E6E77DB,
|
| 1060 |
+
0xAED16A4A, 0xD9D65ADC, 0x40DF0B66, 0x37D83BF0, 0xA9BCAE53, 0xDEBB9EC5, 0x47B2CF7F, 0x30B5FFE9,
|
| 1061 |
+
0xBDBDF21C, 0xCABAC28A, 0x53B39330, 0x24B4A3A6, 0xBAD03605, 0xCDD70693, 0x54DE5729, 0x23D967BF,
|
| 1062 |
+
0xB3667A2E, 0xC4614AB8, 0x5D681B02, 0x2A6F2B94, 0xB40BBE37, 0xC30C8EA1, 0x5A05DF1B, 0x2D02EF8D
|
| 1063 |
+
};
|
| 1064 |
+
|
| 1065 |
+
unsigned int crc = ~0u;
|
| 1066 |
+
int i;
|
| 1067 |
+
for (i=0; i < len; ++i)
|
| 1068 |
+
crc = (crc >> 8) ^ crc_table[buffer[i] ^ (crc & 0xff)];
|
| 1069 |
+
return ~crc;
|
| 1070 |
+
#endif
|
| 1071 |
+
}
|
| 1072 |
+
|
| 1073 |
+
#define stbiw__wpng4(o,a,b,c,d) ((o)[0]=STBIW_UCHAR(a),(o)[1]=STBIW_UCHAR(b),(o)[2]=STBIW_UCHAR(c),(o)[3]=STBIW_UCHAR(d),(o)+=4)
|
| 1074 |
+
#define stbiw__wp32(data,v) stbiw__wpng4(data, (v)>>24,(v)>>16,(v)>>8,(v));
|
| 1075 |
+
#define stbiw__wptag(data,s) stbiw__wpng4(data, s[0],s[1],s[2],s[3])
|
| 1076 |
+
|
| 1077 |
+
static void stbiw__wpcrc(unsigned char **data, int len)
|
| 1078 |
+
{
|
| 1079 |
+
unsigned int crc = stbiw__crc32(*data - len - 4, len+4);
|
| 1080 |
+
stbiw__wp32(*data, crc);
|
| 1081 |
+
}
|
| 1082 |
+
|
| 1083 |
+
static unsigned char stbiw__paeth(int a, int b, int c)
|
| 1084 |
+
{
|
| 1085 |
+
int p = a + b - c, pa = abs(p-a), pb = abs(p-b), pc = abs(p-c);
|
| 1086 |
+
if (pa <= pb && pa <= pc) return STBIW_UCHAR(a);
|
| 1087 |
+
if (pb <= pc) return STBIW_UCHAR(b);
|
| 1088 |
+
return STBIW_UCHAR(c);
|
| 1089 |
+
}
|
| 1090 |
+
|
| 1091 |
+
// @OPTIMIZE: provide an option that always forces left-predict or paeth predict
|
| 1092 |
+
static void stbiw__encode_png_line(unsigned char *pixels, int stride_bytes, int width, int height, int y, int n, int filter_type, signed char *line_buffer)
|
| 1093 |
+
{
|
| 1094 |
+
static int mapping[] = { 0,1,2,3,4 };
|
| 1095 |
+
static int firstmap[] = { 0,1,0,5,6 };
|
| 1096 |
+
int *mymap = (y != 0) ? mapping : firstmap;
|
| 1097 |
+
int i;
|
| 1098 |
+
int type = mymap[filter_type];
|
| 1099 |
+
unsigned char *z = pixels + stride_bytes * (stbi__flip_vertically_on_write ? height-1-y : y);
|
| 1100 |
+
int signed_stride = stbi__flip_vertically_on_write ? -stride_bytes : stride_bytes;
|
| 1101 |
+
|
| 1102 |
+
if (type==0) {
|
| 1103 |
+
memcpy(line_buffer, z, width*n);
|
| 1104 |
+
return;
|
| 1105 |
+
}
|
| 1106 |
+
|
| 1107 |
+
// first loop isn't optimized since it's just one pixel
|
| 1108 |
+
for (i = 0; i < n; ++i) {
|
| 1109 |
+
switch (type) {
|
| 1110 |
+
case 1: line_buffer[i] = z[i]; break;
|
| 1111 |
+
case 2: line_buffer[i] = z[i] - z[i-signed_stride]; break;
|
| 1112 |
+
case 3: line_buffer[i] = z[i] - (z[i-signed_stride]>>1); break;
|
| 1113 |
+
case 4: line_buffer[i] = (signed char) (z[i] - stbiw__paeth(0,z[i-signed_stride],0)); break;
|
| 1114 |
+
case 5: line_buffer[i] = z[i]; break;
|
| 1115 |
+
case 6: line_buffer[i] = z[i]; break;
|
| 1116 |
+
}
|
| 1117 |
+
}
|
| 1118 |
+
switch (type) {
|
| 1119 |
+
case 1: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-n]; break;
|
| 1120 |
+
case 2: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-signed_stride]; break;
|
| 1121 |
+
case 3: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - ((z[i-n] + z[i-signed_stride])>>1); break;
|
| 1122 |
+
case 4: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], z[i-signed_stride], z[i-signed_stride-n]); break;
|
| 1123 |
+
case 5: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - (z[i-n]>>1); break;
|
| 1124 |
+
case 6: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], 0,0); break;
|
| 1125 |
+
}
|
| 1126 |
+
}
|
| 1127 |
+
|
| 1128 |
+
STBIWDEF unsigned char *stbi_write_png_to_mem(const unsigned char *pixels, int stride_bytes, int x, int y, int n, int *out_len)
|
| 1129 |
+
{
|
| 1130 |
+
int force_filter = stbi_write_force_png_filter;
|
| 1131 |
+
int ctype[5] = { -1, 0, 4, 2, 6 };
|
| 1132 |
+
unsigned char sig[8] = { 137,80,78,71,13,10,26,10 };
|
| 1133 |
+
unsigned char *out,*o, *filt, *zlib;
|
| 1134 |
+
signed char *line_buffer;
|
| 1135 |
+
int j,zlen;
|
| 1136 |
+
|
| 1137 |
+
if (stride_bytes == 0)
|
| 1138 |
+
stride_bytes = x * n;
|
| 1139 |
+
|
| 1140 |
+
if (force_filter >= 5) {
|
| 1141 |
+
force_filter = -1;
|
| 1142 |
+
}
|
| 1143 |
+
|
| 1144 |
+
filt = (unsigned char *) STBIW_MALLOC((x*n+1) * y); if (!filt) return 0;
|
| 1145 |
+
line_buffer = (signed char *) STBIW_MALLOC(x * n); if (!line_buffer) { STBIW_FREE(filt); return 0; }
|
| 1146 |
+
for (j=0; j < y; ++j) {
|
| 1147 |
+
int filter_type;
|
| 1148 |
+
if (force_filter > -1) {
|
| 1149 |
+
filter_type = force_filter;
|
| 1150 |
+
stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, force_filter, line_buffer);
|
| 1151 |
+
} else { // Estimate the best filter by running through all of them:
|
| 1152 |
+
int best_filter = 0, best_filter_val = 0x7fffffff, est, i;
|
| 1153 |
+
for (filter_type = 0; filter_type < 5; filter_type++) {
|
| 1154 |
+
stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, filter_type, line_buffer);
|
| 1155 |
+
|
| 1156 |
+
// Estimate the entropy of the line using this filter; the less, the better.
|
| 1157 |
+
est = 0;
|
| 1158 |
+
for (i = 0; i < x*n; ++i) {
|
| 1159 |
+
est += abs((signed char) line_buffer[i]);
|
| 1160 |
+
}
|
| 1161 |
+
if (est < best_filter_val) {
|
| 1162 |
+
best_filter_val = est;
|
| 1163 |
+
best_filter = filter_type;
|
| 1164 |
+
}
|
| 1165 |
+
}
|
| 1166 |
+
if (filter_type != best_filter) { // If the last iteration already got us the best filter, don't redo it
|
| 1167 |
+
stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, best_filter, line_buffer);
|
| 1168 |
+
filter_type = best_filter;
|
| 1169 |
+
}
|
| 1170 |
+
}
|
| 1171 |
+
// when we get here, filter_type contains the filter type, and line_buffer contains the data
|
| 1172 |
+
filt[j*(x*n+1)] = (unsigned char) filter_type;
|
| 1173 |
+
STBIW_MEMMOVE(filt+j*(x*n+1)+1, line_buffer, x*n);
|
| 1174 |
+
}
|
| 1175 |
+
STBIW_FREE(line_buffer);
|
| 1176 |
+
zlib = stbi_zlib_compress(filt, y*( x*n+1), &zlen, stbi_write_png_compression_level);
|
| 1177 |
+
STBIW_FREE(filt);
|
| 1178 |
+
if (!zlib) return 0;
|
| 1179 |
+
|
| 1180 |
+
// each tag requires 12 bytes of overhead
|
| 1181 |
+
out = (unsigned char *) STBIW_MALLOC(8 + 12+13 + 12+zlen + 12);
|
| 1182 |
+
if (!out) return 0;
|
| 1183 |
+
*out_len = 8 + 12+13 + 12+zlen + 12;
|
| 1184 |
+
|
| 1185 |
+
o=out;
|
| 1186 |
+
STBIW_MEMMOVE(o,sig,8); o+= 8;
|
| 1187 |
+
stbiw__wp32(o, 13); // header length
|
| 1188 |
+
stbiw__wptag(o, "IHDR");
|
| 1189 |
+
stbiw__wp32(o, x);
|
| 1190 |
+
stbiw__wp32(o, y);
|
| 1191 |
+
*o++ = 8;
|
| 1192 |
+
*o++ = STBIW_UCHAR(ctype[n]);
|
| 1193 |
+
*o++ = 0;
|
| 1194 |
+
*o++ = 0;
|
| 1195 |
+
*o++ = 0;
|
| 1196 |
+
stbiw__wpcrc(&o,13);
|
| 1197 |
+
|
| 1198 |
+
stbiw__wp32(o, zlen);
|
| 1199 |
+
stbiw__wptag(o, "IDAT");
|
| 1200 |
+
STBIW_MEMMOVE(o, zlib, zlen);
|
| 1201 |
+
o += zlen;
|
| 1202 |
+
STBIW_FREE(zlib);
|
| 1203 |
+
stbiw__wpcrc(&o, zlen);
|
| 1204 |
+
|
| 1205 |
+
stbiw__wp32(o,0);
|
| 1206 |
+
stbiw__wptag(o, "IEND");
|
| 1207 |
+
stbiw__wpcrc(&o,0);
|
| 1208 |
+
|
| 1209 |
+
STBIW_ASSERT(o == out + *out_len);
|
| 1210 |
+
|
| 1211 |
+
return out;
|
| 1212 |
+
}
|
| 1213 |
+
|
| 1214 |
+
#ifndef STBI_WRITE_NO_STDIO
|
| 1215 |
+
STBIWDEF int stbi_write_png(char const *filename, int x, int y, int comp, const void *data, int stride_bytes)
|
| 1216 |
+
{
|
| 1217 |
+
FILE *f;
|
| 1218 |
+
int len;
|
| 1219 |
+
unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len);
|
| 1220 |
+
if (png == NULL) return 0;
|
| 1221 |
+
|
| 1222 |
+
f = stbiw__fopen(filename, "wb");
|
| 1223 |
+
if (!f) { STBIW_FREE(png); return 0; }
|
| 1224 |
+
fwrite(png, 1, len, f);
|
| 1225 |
+
fclose(f);
|
| 1226 |
+
STBIW_FREE(png);
|
| 1227 |
+
return 1;
|
| 1228 |
+
}
|
| 1229 |
+
#endif
|
| 1230 |
+
|
| 1231 |
+
STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int stride_bytes)
|
| 1232 |
+
{
|
| 1233 |
+
int len;
|
| 1234 |
+
unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len);
|
| 1235 |
+
if (png == NULL) return 0;
|
| 1236 |
+
func(context, png, len);
|
| 1237 |
+
STBIW_FREE(png);
|
| 1238 |
+
return 1;
|
| 1239 |
+
}
|
| 1240 |
+
|
| 1241 |
+
|
| 1242 |
+
/* ***************************************************************************
|
| 1243 |
+
*
|
| 1244 |
+
* JPEG writer
|
| 1245 |
+
*
|
| 1246 |
+
* This is based on Jon Olick's jo_jpeg.cpp:
|
| 1247 |
+
* public domain Simple, Minimalistic JPEG writer - http://www.jonolick.com/code.html
|
| 1248 |
+
*/
|
| 1249 |
+
|
| 1250 |
+
static const unsigned char stbiw__jpg_ZigZag[] = { 0,1,5,6,14,15,27,28,2,4,7,13,16,26,29,42,3,8,12,17,25,30,41,43,9,11,18,
|
| 1251 |
+
24,31,40,44,53,10,19,23,32,39,45,52,54,20,22,33,38,46,51,55,60,21,34,37,47,50,56,59,61,35,36,48,49,57,58,62,63 };
|
| 1252 |
+
|
| 1253 |
+
static void stbiw__jpg_writeBits(stbi__write_context *s, int *bitBufP, int *bitCntP, const unsigned short *bs) {
|
| 1254 |
+
int bitBuf = *bitBufP, bitCnt = *bitCntP;
|
| 1255 |
+
bitCnt += bs[1];
|
| 1256 |
+
bitBuf |= bs[0] << (24 - bitCnt);
|
| 1257 |
+
while(bitCnt >= 8) {
|
| 1258 |
+
unsigned char c = (bitBuf >> 16) & 255;
|
| 1259 |
+
stbiw__putc(s, c);
|
| 1260 |
+
if(c == 255) {
|
| 1261 |
+
stbiw__putc(s, 0);
|
| 1262 |
+
}
|
| 1263 |
+
bitBuf <<= 8;
|
| 1264 |
+
bitCnt -= 8;
|
| 1265 |
+
}
|
| 1266 |
+
*bitBufP = bitBuf;
|
| 1267 |
+
*bitCntP = bitCnt;
|
| 1268 |
+
}
|
| 1269 |
+
|
| 1270 |
+
static void stbiw__jpg_DCT(float *d0p, float *d1p, float *d2p, float *d3p, float *d4p, float *d5p, float *d6p, float *d7p) {
|
| 1271 |
+
float d0 = *d0p, d1 = *d1p, d2 = *d2p, d3 = *d3p, d4 = *d4p, d5 = *d5p, d6 = *d6p, d7 = *d7p;
|
| 1272 |
+
float z1, z2, z3, z4, z5, z11, z13;
|
| 1273 |
+
|
| 1274 |
+
float tmp0 = d0 + d7;
|
| 1275 |
+
float tmp7 = d0 - d7;
|
| 1276 |
+
float tmp1 = d1 + d6;
|
| 1277 |
+
float tmp6 = d1 - d6;
|
| 1278 |
+
float tmp2 = d2 + d5;
|
| 1279 |
+
float tmp5 = d2 - d5;
|
| 1280 |
+
float tmp3 = d3 + d4;
|
| 1281 |
+
float tmp4 = d3 - d4;
|
| 1282 |
+
|
| 1283 |
+
// Even part
|
| 1284 |
+
float tmp10 = tmp0 + tmp3; // phase 2
|
| 1285 |
+
float tmp13 = tmp0 - tmp3;
|
| 1286 |
+
float tmp11 = tmp1 + tmp2;
|
| 1287 |
+
float tmp12 = tmp1 - tmp2;
|
| 1288 |
+
|
| 1289 |
+
d0 = tmp10 + tmp11; // phase 3
|
| 1290 |
+
d4 = tmp10 - tmp11;
|
| 1291 |
+
|
| 1292 |
+
z1 = (tmp12 + tmp13) * 0.707106781f; // c4
|
| 1293 |
+
d2 = tmp13 + z1; // phase 5
|
| 1294 |
+
d6 = tmp13 - z1;
|
| 1295 |
+
|
| 1296 |
+
// Odd part
|
| 1297 |
+
tmp10 = tmp4 + tmp5; // phase 2
|
| 1298 |
+
tmp11 = tmp5 + tmp6;
|
| 1299 |
+
tmp12 = tmp6 + tmp7;
|
| 1300 |
+
|
| 1301 |
+
// The rotator is modified from fig 4-8 to avoid extra negations.
|
| 1302 |
+
z5 = (tmp10 - tmp12) * 0.382683433f; // c6
|
| 1303 |
+
z2 = tmp10 * 0.541196100f + z5; // c2-c6
|
| 1304 |
+
z4 = tmp12 * 1.306562965f + z5; // c2+c6
|
| 1305 |
+
z3 = tmp11 * 0.707106781f; // c4
|
| 1306 |
+
|
| 1307 |
+
z11 = tmp7 + z3; // phase 5
|
| 1308 |
+
z13 = tmp7 - z3;
|
| 1309 |
+
|
| 1310 |
+
*d5p = z13 + z2; // phase 6
|
| 1311 |
+
*d3p = z13 - z2;
|
| 1312 |
+
*d1p = z11 + z4;
|
| 1313 |
+
*d7p = z11 - z4;
|
| 1314 |
+
|
| 1315 |
+
*d0p = d0; *d2p = d2; *d4p = d4; *d6p = d6;
|
| 1316 |
+
}
|
| 1317 |
+
|
| 1318 |
+
static void stbiw__jpg_calcBits(int val, unsigned short bits[2]) {
|
| 1319 |
+
int tmp1 = val < 0 ? -val : val;
|
| 1320 |
+
val = val < 0 ? val-1 : val;
|
| 1321 |
+
bits[1] = 1;
|
| 1322 |
+
while(tmp1 >>= 1) {
|
| 1323 |
+
++bits[1];
|
| 1324 |
+
}
|
| 1325 |
+
bits[0] = val & ((1<<bits[1])-1);
|
| 1326 |
+
}
|
| 1327 |
+
|
| 1328 |
+
static int stbiw__jpg_processDU(stbi__write_context *s, int *bitBuf, int *bitCnt, float *CDU, int du_stride, float *fdtbl, int DC, const unsigned short HTDC[256][2], const unsigned short HTAC[256][2]) {
|
| 1329 |
+
const unsigned short EOB[2] = { HTAC[0x00][0], HTAC[0x00][1] };
|
| 1330 |
+
const unsigned short M16zeroes[2] = { HTAC[0xF0][0], HTAC[0xF0][1] };
|
| 1331 |
+
int dataOff, i, j, n, diff, end0pos, x, y;
|
| 1332 |
+
int DU[64];
|
| 1333 |
+
|
| 1334 |
+
// DCT rows
|
| 1335 |
+
for(dataOff=0, n=du_stride*8; dataOff<n; dataOff+=du_stride) {
|
| 1336 |
+
stbiw__jpg_DCT(&CDU[dataOff], &CDU[dataOff+1], &CDU[dataOff+2], &CDU[dataOff+3], &CDU[dataOff+4], &CDU[dataOff+5], &CDU[dataOff+6], &CDU[dataOff+7]);
|
| 1337 |
+
}
|
| 1338 |
+
// DCT columns
|
| 1339 |
+
for(dataOff=0; dataOff<8; ++dataOff) {
|
| 1340 |
+
stbiw__jpg_DCT(&CDU[dataOff], &CDU[dataOff+du_stride], &CDU[dataOff+du_stride*2], &CDU[dataOff+du_stride*3], &CDU[dataOff+du_stride*4],
|
| 1341 |
+
&CDU[dataOff+du_stride*5], &CDU[dataOff+du_stride*6], &CDU[dataOff+du_stride*7]);
|
| 1342 |
+
}
|
| 1343 |
+
// Quantize/descale/zigzag the coefficients
|
| 1344 |
+
for(y = 0, j=0; y < 8; ++y) {
|
| 1345 |
+
for(x = 0; x < 8; ++x,++j) {
|
| 1346 |
+
float v;
|
| 1347 |
+
i = y*du_stride+x;
|
| 1348 |
+
v = CDU[i]*fdtbl[j];
|
| 1349 |
+
// DU[stbiw__jpg_ZigZag[j]] = (int)(v < 0 ? ceilf(v - 0.5f) : floorf(v + 0.5f));
|
| 1350 |
+
// ceilf() and floorf() are C99, not C89, but I /think/ they're not needed here anyway?
|
| 1351 |
+
DU[stbiw__jpg_ZigZag[j]] = (int)(v < 0 ? v - 0.5f : v + 0.5f);
|
| 1352 |
+
}
|
| 1353 |
+
}
|
| 1354 |
+
|
| 1355 |
+
// Encode DC
|
| 1356 |
+
diff = DU[0] - DC;
|
| 1357 |
+
if (diff == 0) {
|
| 1358 |
+
stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTDC[0]);
|
| 1359 |
+
} else {
|
| 1360 |
+
unsigned short bits[2];
|
| 1361 |
+
stbiw__jpg_calcBits(diff, bits);
|
| 1362 |
+
stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTDC[bits[1]]);
|
| 1363 |
+
stbiw__jpg_writeBits(s, bitBuf, bitCnt, bits);
|
| 1364 |
+
}
|
| 1365 |
+
// Encode ACs
|
| 1366 |
+
end0pos = 63;
|
| 1367 |
+
for(; (end0pos>0)&&(DU[end0pos]==0); --end0pos) {
|
| 1368 |
+
}
|
| 1369 |
+
// end0pos = first element in reverse order !=0
|
| 1370 |
+
if(end0pos == 0) {
|
| 1371 |
+
stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB);
|
| 1372 |
+
return DU[0];
|
| 1373 |
+
}
|
| 1374 |
+
for(i = 1; i <= end0pos; ++i) {
|
| 1375 |
+
int startpos = i;
|
| 1376 |
+
int nrzeroes;
|
| 1377 |
+
unsigned short bits[2];
|
| 1378 |
+
for (; DU[i]==0 && i<=end0pos; ++i) {
|
| 1379 |
+
}
|
| 1380 |
+
nrzeroes = i-startpos;
|
| 1381 |
+
if ( nrzeroes >= 16 ) {
|
| 1382 |
+
int lng = nrzeroes>>4;
|
| 1383 |
+
int nrmarker;
|
| 1384 |
+
for (nrmarker=1; nrmarker <= lng; ++nrmarker)
|
| 1385 |
+
stbiw__jpg_writeBits(s, bitBuf, bitCnt, M16zeroes);
|
| 1386 |
+
nrzeroes &= 15;
|
| 1387 |
+
}
|
| 1388 |
+
stbiw__jpg_calcBits(DU[i], bits);
|
| 1389 |
+
stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTAC[(nrzeroes<<4)+bits[1]]);
|
| 1390 |
+
stbiw__jpg_writeBits(s, bitBuf, bitCnt, bits);
|
| 1391 |
+
}
|
| 1392 |
+
if(end0pos != 63) {
|
| 1393 |
+
stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB);
|
| 1394 |
+
}
|
| 1395 |
+
return DU[0];
|
| 1396 |
+
}
|
| 1397 |
+
|
| 1398 |
+
static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, int comp, const void* data, int quality) {
|
| 1399 |
+
// Constants that don't pollute global namespace
|
| 1400 |
+
static const unsigned char std_dc_luminance_nrcodes[] = {0,0,1,5,1,1,1,1,1,1,0,0,0,0,0,0,0};
|
| 1401 |
+
static const unsigned char std_dc_luminance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11};
|
| 1402 |
+
static const unsigned char std_ac_luminance_nrcodes[] = {0,0,2,1,3,3,2,4,3,5,5,4,4,0,0,1,0x7d};
|
| 1403 |
+
static const unsigned char std_ac_luminance_values[] = {
|
| 1404 |
+
0x01,0x02,0x03,0x00,0x04,0x11,0x05,0x12,0x21,0x31,0x41,0x06,0x13,0x51,0x61,0x07,0x22,0x71,0x14,0x32,0x81,0x91,0xa1,0x08,
|
| 1405 |
+
0x23,0x42,0xb1,0xc1,0x15,0x52,0xd1,0xf0,0x24,0x33,0x62,0x72,0x82,0x09,0x0a,0x16,0x17,0x18,0x19,0x1a,0x25,0x26,0x27,0x28,
|
| 1406 |
+
0x29,0x2a,0x34,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58,0x59,
|
| 1407 |
+
0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x83,0x84,0x85,0x86,0x87,0x88,0x89,
|
| 1408 |
+
0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4,0xb5,0xb6,
|
| 1409 |
+
0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda,0xe1,0xe2,
|
| 1410 |
+
0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf1,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa
|
| 1411 |
+
};
|
| 1412 |
+
static const unsigned char std_dc_chrominance_nrcodes[] = {0,0,3,1,1,1,1,1,1,1,1,1,0,0,0,0,0};
|
| 1413 |
+
static const unsigned char std_dc_chrominance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11};
|
| 1414 |
+
static const unsigned char std_ac_chrominance_nrcodes[] = {0,0,2,1,2,4,4,3,4,7,5,4,4,0,1,2,0x77};
|
| 1415 |
+
static const unsigned char std_ac_chrominance_values[] = {
|
| 1416 |
+
0x00,0x01,0x02,0x03,0x11,0x04,0x05,0x21,0x31,0x06,0x12,0x41,0x51,0x07,0x61,0x71,0x13,0x22,0x32,0x81,0x08,0x14,0x42,0x91,
|
| 1417 |
+
0xa1,0xb1,0xc1,0x09,0x23,0x33,0x52,0xf0,0x15,0x62,0x72,0xd1,0x0a,0x16,0x24,0x34,0xe1,0x25,0xf1,0x17,0x18,0x19,0x1a,0x26,
|
| 1418 |
+
0x27,0x28,0x29,0x2a,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58,
|
| 1419 |
+
0x59,0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x82,0x83,0x84,0x85,0x86,0x87,
|
| 1420 |
+
0x88,0x89,0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4,
|
| 1421 |
+
0xb5,0xb6,0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda,
|
| 1422 |
+
0xe2,0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa
|
| 1423 |
+
};
|
| 1424 |
+
// Huffman tables
|
| 1425 |
+
static const unsigned short YDC_HT[256][2] = { {0,2},{2,3},{3,3},{4,3},{5,3},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9}};
|
| 1426 |
+
static const unsigned short UVDC_HT[256][2] = { {0,2},{1,2},{2,2},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9},{1022,10},{2046,11}};
|
| 1427 |
+
static const unsigned short YAC_HT[256][2] = {
|
| 1428 |
+
{10,4},{0,2},{1,2},{4,3},{11,4},{26,5},{120,7},{248,8},{1014,10},{65410,16},{65411,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1429 |
+
{12,4},{27,5},{121,7},{502,9},{2038,11},{65412,16},{65413,16},{65414,16},{65415,16},{65416,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1430 |
+
{28,5},{249,8},{1015,10},{4084,12},{65417,16},{65418,16},{65419,16},{65420,16},{65421,16},{65422,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1431 |
+
{58,6},{503,9},{4085,12},{65423,16},{65424,16},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1432 |
+
{59,6},{1016,10},{65430,16},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1433 |
+
{122,7},{2039,11},{65438,16},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1434 |
+
{123,7},{4086,12},{65446,16},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1435 |
+
{250,8},{4087,12},{65454,16},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1436 |
+
{504,9},{32704,15},{65462,16},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1437 |
+
{505,9},{65470,16},{65471,16},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1438 |
+
{506,9},{65479,16},{65480,16},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1439 |
+
{1017,10},{65488,16},{65489,16},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1440 |
+
{1018,10},{65497,16},{65498,16},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1441 |
+
{2040,11},{65506,16},{65507,16},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1442 |
+
{65515,16},{65516,16},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1443 |
+
{2041,11},{65525,16},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0}
|
| 1444 |
+
};
|
| 1445 |
+
static const unsigned short UVAC_HT[256][2] = {
|
| 1446 |
+
{0,2},{1,2},{4,3},{10,4},{24,5},{25,5},{56,6},{120,7},{500,9},{1014,10},{4084,12},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1447 |
+
{11,4},{57,6},{246,8},{501,9},{2038,11},{4085,12},{65416,16},{65417,16},{65418,16},{65419,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1448 |
+
{26,5},{247,8},{1015,10},{4086,12},{32706,15},{65420,16},{65421,16},{65422,16},{65423,16},{65424,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1449 |
+
{27,5},{248,8},{1016,10},{4087,12},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{65430,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1450 |
+
{58,6},{502,9},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{65438,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1451 |
+
{59,6},{1017,10},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{65446,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1452 |
+
{121,7},{2039,11},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{65454,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1453 |
+
{122,7},{2040,11},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{65462,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1454 |
+
{249,8},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{65470,16},{65471,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1455 |
+
{503,9},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{65479,16},{65480,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1456 |
+
{504,9},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{65488,16},{65489,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1457 |
+
{505,9},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{65497,16},{65498,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1458 |
+
{506,9},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{65506,16},{65507,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1459 |
+
{2041,11},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{65515,16},{65516,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1460 |
+
{16352,14},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{65525,16},{0,0},{0,0},{0,0},{0,0},{0,0},
|
| 1461 |
+
{1018,10},{32707,15},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0}
|
| 1462 |
+
};
|
| 1463 |
+
static const int YQT[] = {16,11,10,16,24,40,51,61,12,12,14,19,26,58,60,55,14,13,16,24,40,57,69,56,14,17,22,29,51,87,80,62,18,22,
|
| 1464 |
+
37,56,68,109,103,77,24,35,55,64,81,104,113,92,49,64,78,87,103,121,120,101,72,92,95,98,112,100,103,99};
|
| 1465 |
+
static const int UVQT[] = {17,18,24,47,99,99,99,99,18,21,26,66,99,99,99,99,24,26,56,99,99,99,99,99,47,66,99,99,99,99,99,99,
|
| 1466 |
+
99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99};
|
| 1467 |
+
static const float aasf[] = { 1.0f * 2.828427125f, 1.387039845f * 2.828427125f, 1.306562965f * 2.828427125f, 1.175875602f * 2.828427125f,
|
| 1468 |
+
1.0f * 2.828427125f, 0.785694958f * 2.828427125f, 0.541196100f * 2.828427125f, 0.275899379f * 2.828427125f };
|
| 1469 |
+
|
| 1470 |
+
int row, col, i, k, subsample;
|
| 1471 |
+
float fdtbl_Y[64], fdtbl_UV[64];
|
| 1472 |
+
unsigned char YTable[64], UVTable[64];
|
| 1473 |
+
|
| 1474 |
+
if(!data || !width || !height || comp > 4 || comp < 1) {
|
| 1475 |
+
return 0;
|
| 1476 |
+
}
|
| 1477 |
+
|
| 1478 |
+
quality = quality ? quality : 90;
|
| 1479 |
+
subsample = quality <= 90 ? 1 : 0;
|
| 1480 |
+
quality = quality < 1 ? 1 : quality > 100 ? 100 : quality;
|
| 1481 |
+
quality = quality < 50 ? 5000 / quality : 200 - quality * 2;
|
| 1482 |
+
|
| 1483 |
+
for(i = 0; i < 64; ++i) {
|
| 1484 |
+
int uvti, yti = (YQT[i]*quality+50)/100;
|
| 1485 |
+
YTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (yti < 1 ? 1 : yti > 255 ? 255 : yti);
|
| 1486 |
+
uvti = (UVQT[i]*quality+50)/100;
|
| 1487 |
+
UVTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (uvti < 1 ? 1 : uvti > 255 ? 255 : uvti);
|
| 1488 |
+
}
|
| 1489 |
+
|
| 1490 |
+
for(row = 0, k = 0; row < 8; ++row) {
|
| 1491 |
+
for(col = 0; col < 8; ++col, ++k) {
|
| 1492 |
+
fdtbl_Y[k] = 1 / (YTable [stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]);
|
| 1493 |
+
fdtbl_UV[k] = 1 / (UVTable[stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]);
|
| 1494 |
+
}
|
| 1495 |
+
}
|
| 1496 |
+
|
| 1497 |
+
// Write Headers
|
| 1498 |
+
{
|
| 1499 |
+
static const unsigned char head0[] = { 0xFF,0xD8,0xFF,0xE0,0,0x10,'J','F','I','F',0,1,1,0,0,1,0,1,0,0,0xFF,0xDB,0,0x84,0 };
|
| 1500 |
+
static const unsigned char head2[] = { 0xFF,0xDA,0,0xC,3,1,0,2,0x11,3,0x11,0,0x3F,0 };
|
| 1501 |
+
const unsigned char head1[] = { 0xFF,0xC0,0,0x11,8,(unsigned char)(height>>8),STBIW_UCHAR(height),(unsigned char)(width>>8),STBIW_UCHAR(width),
|
| 1502 |
+
3,1,(unsigned char)(subsample?0x22:0x11),0,2,0x11,1,3,0x11,1,0xFF,0xC4,0x01,0xA2,0 };
|
| 1503 |
+
s->func(s->context, (void*)head0, sizeof(head0));
|
| 1504 |
+
s->func(s->context, (void*)YTable, sizeof(YTable));
|
| 1505 |
+
stbiw__putc(s, 1);
|
| 1506 |
+
s->func(s->context, UVTable, sizeof(UVTable));
|
| 1507 |
+
s->func(s->context, (void*)head1, sizeof(head1));
|
| 1508 |
+
s->func(s->context, (void*)(std_dc_luminance_nrcodes+1), sizeof(std_dc_luminance_nrcodes)-1);
|
| 1509 |
+
s->func(s->context, (void*)std_dc_luminance_values, sizeof(std_dc_luminance_values));
|
| 1510 |
+
stbiw__putc(s, 0x10); // HTYACinfo
|
| 1511 |
+
s->func(s->context, (void*)(std_ac_luminance_nrcodes+1), sizeof(std_ac_luminance_nrcodes)-1);
|
| 1512 |
+
s->func(s->context, (void*)std_ac_luminance_values, sizeof(std_ac_luminance_values));
|
| 1513 |
+
stbiw__putc(s, 1); // HTUDCinfo
|
| 1514 |
+
s->func(s->context, (void*)(std_dc_chrominance_nrcodes+1), sizeof(std_dc_chrominance_nrcodes)-1);
|
| 1515 |
+
s->func(s->context, (void*)std_dc_chrominance_values, sizeof(std_dc_chrominance_values));
|
| 1516 |
+
stbiw__putc(s, 0x11); // HTUACinfo
|
| 1517 |
+
s->func(s->context, (void*)(std_ac_chrominance_nrcodes+1), sizeof(std_ac_chrominance_nrcodes)-1);
|
| 1518 |
+
s->func(s->context, (void*)std_ac_chrominance_values, sizeof(std_ac_chrominance_values));
|
| 1519 |
+
s->func(s->context, (void*)head2, sizeof(head2));
|
| 1520 |
+
}
|
| 1521 |
+
|
| 1522 |
+
// Encode 8x8 macroblocks
|
| 1523 |
+
{
|
| 1524 |
+
static const unsigned short fillBits[] = {0x7F, 7};
|
| 1525 |
+
int DCY=0, DCU=0, DCV=0;
|
| 1526 |
+
int bitBuf=0, bitCnt=0;
|
| 1527 |
+
// comp == 2 is grey+alpha (alpha is ignored)
|
| 1528 |
+
int ofsG = comp > 2 ? 1 : 0, ofsB = comp > 2 ? 2 : 0;
|
| 1529 |
+
const unsigned char *dataR = (const unsigned char *)data;
|
| 1530 |
+
const unsigned char *dataG = dataR + ofsG;
|
| 1531 |
+
const unsigned char *dataB = dataR + ofsB;
|
| 1532 |
+
int x, y, pos;
|
| 1533 |
+
if(subsample) {
|
| 1534 |
+
for(y = 0; y < height; y += 16) {
|
| 1535 |
+
for(x = 0; x < width; x += 16) {
|
| 1536 |
+
float Y[256], U[256], V[256];
|
| 1537 |
+
for(row = y, pos = 0; row < y+16; ++row) {
|
| 1538 |
+
// row >= height => use last input row
|
| 1539 |
+
int clamped_row = (row < height) ? row : height - 1;
|
| 1540 |
+
int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp;
|
| 1541 |
+
for(col = x; col < x+16; ++col, ++pos) {
|
| 1542 |
+
// if col >= width => use pixel from last input column
|
| 1543 |
+
int p = base_p + ((col < width) ? col : (width-1))*comp;
|
| 1544 |
+
float r = dataR[p], g = dataG[p], b = dataB[p];
|
| 1545 |
+
Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128;
|
| 1546 |
+
U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b;
|
| 1547 |
+
V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b;
|
| 1548 |
+
}
|
| 1549 |
+
}
|
| 1550 |
+
DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+0, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT);
|
| 1551 |
+
DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+8, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT);
|
| 1552 |
+
DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+128, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT);
|
| 1553 |
+
DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+136, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT);
|
| 1554 |
+
|
| 1555 |
+
// subsample U,V
|
| 1556 |
+
{
|
| 1557 |
+
float subU[64], subV[64];
|
| 1558 |
+
int yy, xx;
|
| 1559 |
+
for(yy = 0, pos = 0; yy < 8; ++yy) {
|
| 1560 |
+
for(xx = 0; xx < 8; ++xx, ++pos) {
|
| 1561 |
+
int j = yy*32+xx*2;
|
| 1562 |
+
subU[pos] = (U[j+0] + U[j+1] + U[j+16] + U[j+17]) * 0.25f;
|
| 1563 |
+
subV[pos] = (V[j+0] + V[j+1] + V[j+16] + V[j+17]) * 0.25f;
|
| 1564 |
+
}
|
| 1565 |
+
}
|
| 1566 |
+
DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subU, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT);
|
| 1567 |
+
DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subV, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT);
|
| 1568 |
+
}
|
| 1569 |
+
}
|
| 1570 |
+
}
|
| 1571 |
+
} else {
|
| 1572 |
+
for(y = 0; y < height; y += 8) {
|
| 1573 |
+
for(x = 0; x < width; x += 8) {
|
| 1574 |
+
float Y[64], U[64], V[64];
|
| 1575 |
+
for(row = y, pos = 0; row < y+8; ++row) {
|
| 1576 |
+
// row >= height => use last input row
|
| 1577 |
+
int clamped_row = (row < height) ? row : height - 1;
|
| 1578 |
+
int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp;
|
| 1579 |
+
for(col = x; col < x+8; ++col, ++pos) {
|
| 1580 |
+
// if col >= width => use pixel from last input column
|
| 1581 |
+
int p = base_p + ((col < width) ? col : (width-1))*comp;
|
| 1582 |
+
float r = dataR[p], g = dataG[p], b = dataB[p];
|
| 1583 |
+
Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128;
|
| 1584 |
+
U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b;
|
| 1585 |
+
V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b;
|
| 1586 |
+
}
|
| 1587 |
+
}
|
| 1588 |
+
|
| 1589 |
+
DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y, 8, fdtbl_Y, DCY, YDC_HT, YAC_HT);
|
| 1590 |
+
DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, U, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT);
|
| 1591 |
+
DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, V, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT);
|
| 1592 |
+
}
|
| 1593 |
+
}
|
| 1594 |
+
}
|
| 1595 |
+
|
| 1596 |
+
// Do the bit alignment of the EOI marker
|
| 1597 |
+
stbiw__jpg_writeBits(s, &bitBuf, &bitCnt, fillBits);
|
| 1598 |
+
}
|
| 1599 |
+
|
| 1600 |
+
// EOI
|
| 1601 |
+
stbiw__putc(s, 0xFF);
|
| 1602 |
+
stbiw__putc(s, 0xD9);
|
| 1603 |
+
|
| 1604 |
+
return 1;
|
| 1605 |
+
}
|
| 1606 |
+
|
| 1607 |
+
STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality)
|
| 1608 |
+
{
|
| 1609 |
+
stbi__write_context s = { 0 };
|
| 1610 |
+
stbi__start_write_callbacks(&s, func, context);
|
| 1611 |
+
return stbi_write_jpg_core(&s, x, y, comp, (void *) data, quality);
|
| 1612 |
+
}
|
| 1613 |
+
|
| 1614 |
+
|
| 1615 |
+
#ifndef STBI_WRITE_NO_STDIO
|
| 1616 |
+
STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality)
|
| 1617 |
+
{
|
| 1618 |
+
stbi__write_context s = { 0 };
|
| 1619 |
+
if (stbi__start_write_file(&s,filename)) {
|
| 1620 |
+
int r = stbi_write_jpg_core(&s, x, y, comp, data, quality);
|
| 1621 |
+
stbi__end_write_file(&s);
|
| 1622 |
+
return r;
|
| 1623 |
+
} else
|
| 1624 |
+
return 0;
|
| 1625 |
+
}
|
| 1626 |
+
#endif
|
| 1627 |
+
|
| 1628 |
+
#endif // STB_IMAGE_WRITE_IMPLEMENTATION
|
| 1629 |
+
|
| 1630 |
+
/* Revision history
|
| 1631 |
+
1.16 (2021-07-11)
|
| 1632 |
+
make Deflate code emit uncompressed blocks when it would otherwise expand
|
| 1633 |
+
support writing BMPs with alpha channel
|
| 1634 |
+
1.15 (2020-07-13) unknown
|
| 1635 |
+
1.14 (2020-02-02) updated JPEG writer to downsample chroma channels
|
| 1636 |
+
1.13
|
| 1637 |
+
1.12
|
| 1638 |
+
1.11 (2019-08-11)
|
| 1639 |
+
|
| 1640 |
+
1.10 (2019-02-07)
|
| 1641 |
+
support utf8 filenames in Windows; fix warnings and platform ifdefs
|
| 1642 |
+
1.09 (2018-02-11)
|
| 1643 |
+
fix typo in zlib quality API, improve STB_I_W_STATIC in C++
|
| 1644 |
+
1.08 (2018-01-29)
|
| 1645 |
+
add stbi__flip_vertically_on_write, external zlib, zlib quality, choose PNG filter
|
| 1646 |
+
1.07 (2017-07-24)
|
| 1647 |
+
doc fix
|
| 1648 |
+
1.06 (2017-07-23)
|
| 1649 |
+
writing JPEG (using Jon Olick's code)
|
| 1650 |
+
1.05 ???
|
| 1651 |
+
1.04 (2017-03-03)
|
| 1652 |
+
monochrome BMP expansion
|
| 1653 |
+
1.03 ???
|
| 1654 |
+
1.02 (2016-04-02)
|
| 1655 |
+
avoid allocating large structures on the stack
|
| 1656 |
+
1.01 (2016-01-16)
|
| 1657 |
+
STBIW_REALLOC_SIZED: support allocators with no realloc support
|
| 1658 |
+
avoid race-condition in crc initialization
|
| 1659 |
+
minor compile issues
|
| 1660 |
+
1.00 (2015-09-14)
|
| 1661 |
+
installable file IO function
|
| 1662 |
+
0.99 (2015-09-13)
|
| 1663 |
+
warning fixes; TGA rle support
|
| 1664 |
+
0.98 (2015-04-08)
|
| 1665 |
+
added STBIW_MALLOC, STBIW_ASSERT etc
|
| 1666 |
+
0.97 (2015-01-18)
|
| 1667 |
+
fixed HDR asserts, rewrote HDR rle logic
|
| 1668 |
+
0.96 (2015-01-17)
|
| 1669 |
+
add HDR output
|
| 1670 |
+
fix monochrome BMP
|
| 1671 |
+
0.95 (2014-08-17)
|
| 1672 |
+
add monochrome TGA output
|
| 1673 |
+
0.94 (2014-05-31)
|
| 1674 |
+
rename private functions to avoid conflicts with stb_image.h
|
| 1675 |
+
0.93 (2014-05-27)
|
| 1676 |
+
warning fixes
|
| 1677 |
+
0.92 (2010-08-01)
|
| 1678 |
+
casts to unsigned char to fix warnings
|
| 1679 |
+
0.91 (2010-07-17)
|
| 1680 |
+
first public release
|
| 1681 |
+
0.90 first internal release
|
| 1682 |
+
*/
|
| 1683 |
+
|
| 1684 |
+
/*
|
| 1685 |
+
------------------------------------------------------------------------------
|
| 1686 |
+
This software is available under 2 licenses -- choose whichever you prefer.
|
| 1687 |
+
------------------------------------------------------------------------------
|
| 1688 |
+
ALTERNATIVE A - MIT License
|
| 1689 |
+
Copyright (c) 2017 Sean Barrett
|
| 1690 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
| 1691 |
+
this software and associated documentation files (the "Software"), to deal in
|
| 1692 |
+
the Software without restriction, including without limitation the rights to
|
| 1693 |
+
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
| 1694 |
+
of the Software, and to permit persons to whom the Software is furnished to do
|
| 1695 |
+
so, subject to the following conditions:
|
| 1696 |
+
The above copyright notice and this permission notice shall be included in all
|
| 1697 |
+
copies or substantial portions of the Software.
|
| 1698 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 1699 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 1700 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 1701 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 1702 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 1703 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 1704 |
+
SOFTWARE.
|
| 1705 |
+
------------------------------------------------------------------------------
|
| 1706 |
+
ALTERNATIVE B - Public Domain (www.unlicense.org)
|
| 1707 |
+
This is free and unencumbered software released into the public domain.
|
| 1708 |
+
Anyone is free to copy, modify, publish, use, compile, sell, or distribute this
|
| 1709 |
+
software, either in source code form or as a compiled binary, for any purpose,
|
| 1710 |
+
commercial or non-commercial, and by any means.
|
| 1711 |
+
In jurisdictions that recognize copyright laws, the author or authors of this
|
| 1712 |
+
software dedicate any and all copyright interest in the software to the public
|
| 1713 |
+
domain. We make this dedication for the benefit of the public at large and to
|
| 1714 |
+
the detriment of our heirs and successors. We intend this dedication to be an
|
| 1715 |
+
overt act of relinquishment in perpetuity of all present and future rights to
|
| 1716 |
+
this software under copyright law.
|
| 1717 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 1718 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 1719 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 1720 |
+
AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
|
| 1721 |
+
ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
| 1722 |
+
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
| 1723 |
+
------------------------------------------------------------------------------
|
| 1724 |
+
*/
|
gaussiancity/extensions/grid_encoder/__init__.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# @File: __init__.py
|
| 4 |
+
# @Author: Jiaxiang Tang (@ashawkey)
|
| 5 |
+
# @Date: 2023-04-15 10:39:28
|
| 6 |
+
# @Last Modified by: Haozhe Xie
|
| 7 |
+
# @Last Modified at: 2023-04-15 13:08:46
|
| 8 |
+
# @Email: ashawkey1999@gmail.com
|
| 9 |
+
# @Ref: https://github.com/ashawkey/torch-ngp
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
import grid_encoder_ext
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class GridEncoderFunction(torch.autograd.Function):
|
| 19 |
+
@staticmethod
|
| 20 |
+
def forward(
|
| 21 |
+
ctx,
|
| 22 |
+
inputs,
|
| 23 |
+
embeddings,
|
| 24 |
+
offsets,
|
| 25 |
+
per_level_scale,
|
| 26 |
+
base_resolution,
|
| 27 |
+
calc_grad_inputs=False,
|
| 28 |
+
gridtype=0,
|
| 29 |
+
align_corners=False,
|
| 30 |
+
):
|
| 31 |
+
# inputs: [B, D], float in [0, 1]
|
| 32 |
+
# embeddings: [sO, C], float
|
| 33 |
+
# offsets: [L + 1], int
|
| 34 |
+
# RETURN: [B, F], float
|
| 35 |
+
inputs = inputs.contiguous()
|
| 36 |
+
# batch size, coord dim
|
| 37 |
+
B, D = inputs.shape
|
| 38 |
+
# level
|
| 39 |
+
L = offsets.shape[0] - 1
|
| 40 |
+
# embedding dim for each level
|
| 41 |
+
C = embeddings.shape[1]
|
| 42 |
+
# resolution multiplier at each level, apply log2 for later CUDA exp2f
|
| 43 |
+
S = math.log2(per_level_scale)
|
| 44 |
+
# base resolution
|
| 45 |
+
H = base_resolution
|
| 46 |
+
# L first, optimize cache for cuda kernel, but needs an extra permute later
|
| 47 |
+
outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
|
| 48 |
+
|
| 49 |
+
if calc_grad_inputs:
|
| 50 |
+
dy_dx = torch.empty(
|
| 51 |
+
B, L * D * C, device=inputs.device, dtype=embeddings.dtype
|
| 52 |
+
)
|
| 53 |
+
else:
|
| 54 |
+
dy_dx = torch.empty(
|
| 55 |
+
1, device=inputs.device, dtype=embeddings.dtype
|
| 56 |
+
) # placeholder... TODO: a better way?
|
| 57 |
+
|
| 58 |
+
grid_encoder_ext.forward(
|
| 59 |
+
inputs,
|
| 60 |
+
embeddings,
|
| 61 |
+
offsets,
|
| 62 |
+
outputs,
|
| 63 |
+
B,
|
| 64 |
+
D,
|
| 65 |
+
C,
|
| 66 |
+
L,
|
| 67 |
+
S,
|
| 68 |
+
H,
|
| 69 |
+
calc_grad_inputs,
|
| 70 |
+
dy_dx,
|
| 71 |
+
gridtype,
|
| 72 |
+
align_corners,
|
| 73 |
+
)
|
| 74 |
+
# permute back to [B, L * C]
|
| 75 |
+
outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
|
| 76 |
+
ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
|
| 77 |
+
ctx.dims = [B, D, C, L, S, H, gridtype]
|
| 78 |
+
ctx.calc_grad_inputs = calc_grad_inputs
|
| 79 |
+
ctx.align_corners = align_corners
|
| 80 |
+
|
| 81 |
+
return outputs
|
| 82 |
+
|
| 83 |
+
@staticmethod
|
| 84 |
+
def backward(ctx, grad):
|
| 85 |
+
inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
|
| 86 |
+
B, D, C, L, S, H, gridtype = ctx.dims
|
| 87 |
+
calc_grad_inputs = ctx.calc_grad_inputs
|
| 88 |
+
align_corners = ctx.align_corners
|
| 89 |
+
|
| 90 |
+
# grad: [B, L * C] --> [L, B, C]
|
| 91 |
+
grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
|
| 92 |
+
grad_embeddings = torch.zeros_like(embeddings)
|
| 93 |
+
|
| 94 |
+
if calc_grad_inputs:
|
| 95 |
+
grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
|
| 96 |
+
else:
|
| 97 |
+
grad_inputs = torch.zeros(1, device=inputs.device, dtype=embeddings.dtype)
|
| 98 |
+
|
| 99 |
+
grid_encoder_ext.backward(
|
| 100 |
+
grad,
|
| 101 |
+
inputs,
|
| 102 |
+
embeddings,
|
| 103 |
+
offsets,
|
| 104 |
+
grad_embeddings,
|
| 105 |
+
B,
|
| 106 |
+
D,
|
| 107 |
+
C,
|
| 108 |
+
L,
|
| 109 |
+
S,
|
| 110 |
+
H,
|
| 111 |
+
calc_grad_inputs,
|
| 112 |
+
dy_dx,
|
| 113 |
+
grad_inputs,
|
| 114 |
+
gridtype,
|
| 115 |
+
align_corners,
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
if calc_grad_inputs:
|
| 119 |
+
grad_inputs = grad_inputs.to(inputs.dtype)
|
| 120 |
+
return grad_inputs, grad_embeddings, None, None, None, None, None, None
|
| 121 |
+
else:
|
| 122 |
+
return None, grad_embeddings, None, None, None, None, None, None
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class GridEncoder(torch.nn.Module):
|
| 126 |
+
def __init__(
|
| 127 |
+
self,
|
| 128 |
+
in_channels,
|
| 129 |
+
n_levels,
|
| 130 |
+
lvl_channels,
|
| 131 |
+
desired_resolution,
|
| 132 |
+
per_level_scale=2,
|
| 133 |
+
base_resolution=16,
|
| 134 |
+
log2_hashmap_size=19,
|
| 135 |
+
gridtype="hash",
|
| 136 |
+
align_corners=False,
|
| 137 |
+
):
|
| 138 |
+
super(GridEncoder, self).__init__()
|
| 139 |
+
self.in_channels = in_channels
|
| 140 |
+
self.n_levels = n_levels # num levels, each level multiply resolution by 2
|
| 141 |
+
self.lvl_channels = lvl_channels # encode channels per level
|
| 142 |
+
self.per_level_scale = 2 ** (
|
| 143 |
+
math.log2(desired_resolution / base_resolution) / (n_levels - 1)
|
| 144 |
+
)
|
| 145 |
+
self.log2_hashmap_size = log2_hashmap_size
|
| 146 |
+
self.base_resolution = base_resolution
|
| 147 |
+
self.output_dim = n_levels * lvl_channels
|
| 148 |
+
self.gridtype = gridtype
|
| 149 |
+
self.gridtype_id = 0 if gridtype == "hash" else 1
|
| 150 |
+
self.align_corners = align_corners
|
| 151 |
+
|
| 152 |
+
# allocate parameters
|
| 153 |
+
offsets = []
|
| 154 |
+
offset = 0
|
| 155 |
+
self.max_params = 2**log2_hashmap_size
|
| 156 |
+
for i in range(n_levels):
|
| 157 |
+
resolution = int(math.ceil(base_resolution * per_level_scale**i))
|
| 158 |
+
params_in_level = min(
|
| 159 |
+
self.max_params,
|
| 160 |
+
(resolution if align_corners else resolution + 1) ** in_channels,
|
| 161 |
+
) # limit max number
|
| 162 |
+
params_in_level = int(math.ceil(params_in_level / 8) * 8) # make divisible
|
| 163 |
+
offsets.append(offset)
|
| 164 |
+
offset += params_in_level
|
| 165 |
+
|
| 166 |
+
offsets.append(offset)
|
| 167 |
+
offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
|
| 168 |
+
self.register_buffer("offsets", offsets)
|
| 169 |
+
|
| 170 |
+
self.n_params = offsets[-1] * lvl_channels
|
| 171 |
+
self.embeddings = torch.nn.Parameter(torch.empty(offset, lvl_channels))
|
| 172 |
+
self._init_weights()
|
| 173 |
+
|
| 174 |
+
def _init_weights(self):
|
| 175 |
+
self.embeddings.data.uniform_(-1e-4, 1e-4)
|
| 176 |
+
|
| 177 |
+
def forward(self, inputs, bound=1):
|
| 178 |
+
# inputs: [..., in_channels], normalized real world positions in [-bound, bound]
|
| 179 |
+
# return: [..., n_levels * lvl_channels]
|
| 180 |
+
inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
|
| 181 |
+
prefix_shape = list(inputs.shape[:-1])
|
| 182 |
+
inputs = inputs.view(-1, self.in_channels)
|
| 183 |
+
outputs = GridEncoderFunction.apply(
|
| 184 |
+
inputs,
|
| 185 |
+
self.embeddings,
|
| 186 |
+
self.offsets,
|
| 187 |
+
self.per_level_scale,
|
| 188 |
+
self.base_resolution,
|
| 189 |
+
inputs.requires_grad,
|
| 190 |
+
self.gridtype_id,
|
| 191 |
+
self.align_corners,
|
| 192 |
+
)
|
| 193 |
+
return outputs.view(prefix_shape + [self.output_dim])
|
gaussiancity/extensions/grid_encoder/bindings.cpp
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* @File: grid_encoder_ext_cuda.cpp
|
| 3 |
+
* @Author: Jiaxiang Tang (@ashawkey)
|
| 4 |
+
* @Date: 2023-04-15 10:39:17
|
| 5 |
+
* @Last Modified by: Haozhe Xie
|
| 6 |
+
* @Last Modified at: 2023-04-15 11:01:32
|
| 7 |
+
* @Email: ashawkey1999@gmail.com
|
| 8 |
+
* @Ref: https://github.com/ashawkey/torch-ngp
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#include <stdint.h>
|
| 12 |
+
#include <torch/extension.h>
|
| 13 |
+
#include <torch/torch.h>
|
| 14 |
+
|
| 15 |
+
// inputs: [B, D], float, in [0, 1]
|
| 16 |
+
// embeddings: [sO, C], float
|
| 17 |
+
// offsets: [L + 1], uint32_t
|
| 18 |
+
// outputs: [B, L * C], float
|
| 19 |
+
// H: base resolution
|
| 20 |
+
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings,
|
| 21 |
+
const at::Tensor offsets, at::Tensor outputs,
|
| 22 |
+
const uint32_t B, const uint32_t D, const uint32_t C,
|
| 23 |
+
const uint32_t L, const float S, const uint32_t H,
|
| 24 |
+
const bool calc_grad_inputs, at::Tensor dy_dx,
|
| 25 |
+
const uint32_t gridtype, const bool align_corners);
|
| 26 |
+
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs,
|
| 27 |
+
const at::Tensor embeddings, const at::Tensor offsets,
|
| 28 |
+
at::Tensor grad_embeddings, const uint32_t B,
|
| 29 |
+
const uint32_t D, const uint32_t C, const uint32_t L,
|
| 30 |
+
const float S, const uint32_t H,
|
| 31 |
+
const bool calc_grad_inputs, const at::Tensor dy_dx,
|
| 32 |
+
at::Tensor grad_inputs, const uint32_t gridtype,
|
| 33 |
+
const bool align_corners);
|
| 34 |
+
|
| 35 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 36 |
+
m.def("forward", &grid_encode_forward,
|
| 37 |
+
"grid_encode_forward (CUDA)");
|
| 38 |
+
m.def("backward", &grid_encode_backward,
|
| 39 |
+
"grid_encode_backward (CUDA)");
|
| 40 |
+
}
|
gaussiancity/extensions/grid_encoder/grid_encoder_ext.cu
ADDED
|
@@ -0,0 +1,605 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* @File: grid_encoder_ext.cu
|
| 3 |
+
* @Author: Jiaxiang Tang (@ashawkey)
|
| 4 |
+
* @Date: 2023-04-15 10:43:16
|
| 5 |
+
* @Last Modified by: Haozhe Xie
|
| 6 |
+
* @Last Modified at: 2023-04-29 11:47:54
|
| 7 |
+
* @Email: ashawkey1999@gmail.com
|
| 8 |
+
* @Ref: https://github.com/ashawkey/torch-ngp
|
| 9 |
+
*/
|
| 10 |
+
|
| 11 |
+
#include <cuda.h>
|
| 12 |
+
#include <cuda_fp16.h>
|
| 13 |
+
#include <cuda_runtime.h>
|
| 14 |
+
|
| 15 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 16 |
+
#include <torch/torch.h>
|
| 17 |
+
|
| 18 |
+
#include <algorithm>
|
| 19 |
+
#include <stdexcept>
|
| 20 |
+
|
| 21 |
+
#include <cstdio>
|
| 22 |
+
#include <stdint.h>
|
| 23 |
+
|
| 24 |
+
#define CHECK_CUDA(x) \
|
| 25 |
+
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
| 26 |
+
#define CHECK_CONTIGUOUS(x) \
|
| 27 |
+
TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
|
| 28 |
+
#define CHECK_IS_INT(x) \
|
| 29 |
+
TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \
|
| 30 |
+
#x " must be an int tensor")
|
| 31 |
+
#define CHECK_IS_FLOATING(x) \
|
| 32 |
+
TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || \
|
| 33 |
+
x.scalar_type() == at::ScalarType::Half || \
|
| 34 |
+
x.scalar_type() == at::ScalarType::Double, \
|
| 35 |
+
#x " must be a floating tensor")
|
| 36 |
+
|
| 37 |
+
// just for compatability of half precision in
|
| 38 |
+
// AT_DISPATCH_FLOATING_TYPES_AND_HALF...
|
| 39 |
+
static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
|
| 40 |
+
// requires CUDA >= 10 and ARCH >= 70
|
| 41 |
+
// this is very slow compared to float or __half2, and never used.
|
| 42 |
+
// return atomicAdd(reinterpret_cast<__half*>(address), val);
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
template <typename T>
|
| 46 |
+
static inline __host__ __device__ T div_round_up(T val, T divisor) {
|
| 47 |
+
return (val + divisor - 1) / divisor;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
template <uint32_t D>
|
| 51 |
+
__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
|
| 52 |
+
static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions.");
|
| 53 |
+
|
| 54 |
+
// While 1 is technically not a good prime for hashing (or a prime at all), it
|
| 55 |
+
// helps memory coherence and is sufficient for our use case of obtaining a
|
| 56 |
+
// uniformly colliding index from high-dimensional coordinates.
|
| 57 |
+
constexpr uint32_t primes[7] = {1, 2654435761, 805459861, 3674653429,
|
| 58 |
+
2097192037, 1434869437, 2165219737};
|
| 59 |
+
|
| 60 |
+
uint32_t result = 0;
|
| 61 |
+
#pragma unroll
|
| 62 |
+
for (uint32_t i = 0; i < D; ++i) {
|
| 63 |
+
result ^= pos_grid[i] * primes[i];
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
return result;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
template <uint32_t D, uint32_t C>
|
| 70 |
+
__device__ uint32_t get_grid_index(const uint32_t gridtype,
|
| 71 |
+
const bool align_corners, const uint32_t ch,
|
| 72 |
+
const uint32_t hashmap_size,
|
| 73 |
+
const uint32_t resolution,
|
| 74 |
+
const uint32_t pos_grid[D]) {
|
| 75 |
+
uint32_t stride = 1;
|
| 76 |
+
uint32_t index = 0;
|
| 77 |
+
|
| 78 |
+
#pragma unroll
|
| 79 |
+
for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
|
| 80 |
+
index += pos_grid[d] * stride;
|
| 81 |
+
stride *= align_corners ? resolution : (resolution + 1);
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
// NOTE: for NeRF, the hash is in fact not necessary. Check
|
| 85 |
+
// https://github.com/NVlabs/instant-ngp/issues/97. gridtype: 0 == hash, 1 ==
|
| 86 |
+
// tiled
|
| 87 |
+
if (gridtype == 0 && stride > hashmap_size) {
|
| 88 |
+
index = fast_hash<D>(pos_grid);
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
return (index % hashmap_size) * C + ch;
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
template <typename scalar_t, uint32_t D, uint32_t C>
|
| 95 |
+
__global__ void
|
| 96 |
+
kernel_grid(const float *__restrict__ inputs, const scalar_t *__restrict__ grid,
|
| 97 |
+
const int *__restrict__ offsets, scalar_t *__restrict__ outputs,
|
| 98 |
+
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
|
| 99 |
+
const bool calc_grad_inputs, scalar_t *__restrict__ dy_dx,
|
| 100 |
+
const uint32_t gridtype, const bool align_corners) {
|
| 101 |
+
const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
|
| 102 |
+
|
| 103 |
+
if (b >= B)
|
| 104 |
+
return;
|
| 105 |
+
|
| 106 |
+
const uint32_t level = blockIdx.y;
|
| 107 |
+
|
| 108 |
+
// locate
|
| 109 |
+
grid += (uint32_t)offsets[level] * C;
|
| 110 |
+
inputs += b * D;
|
| 111 |
+
outputs += level * B * C + b * C;
|
| 112 |
+
|
| 113 |
+
// check input range (should be in [0, 1])
|
| 114 |
+
bool flag_oob = false;
|
| 115 |
+
#pragma unroll
|
| 116 |
+
for (uint32_t d = 0; d < D; d++) {
|
| 117 |
+
if (inputs[d] < 0 || inputs[d] > 1) {
|
| 118 |
+
flag_oob = true;
|
| 119 |
+
}
|
| 120 |
+
}
|
| 121 |
+
// if input out of bound, just set output to 0
|
| 122 |
+
if (flag_oob) {
|
| 123 |
+
#pragma unroll
|
| 124 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
| 125 |
+
outputs[ch] = 0;
|
| 126 |
+
}
|
| 127 |
+
if (calc_grad_inputs) {
|
| 128 |
+
dy_dx += b * D * L * C + level * D * C; // B L D C
|
| 129 |
+
#pragma unroll
|
| 130 |
+
for (uint32_t d = 0; d < D; d++) {
|
| 131 |
+
#pragma unroll
|
| 132 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
| 133 |
+
dy_dx[d * C + ch] = 0;
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
}
|
| 137 |
+
return;
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
|
| 141 |
+
const float scale = exp2f(level * S) * H - 1.0f;
|
| 142 |
+
const uint32_t resolution = (uint32_t)ceil(scale) + 1;
|
| 143 |
+
|
| 144 |
+
// calculate coordinate
|
| 145 |
+
float pos[D];
|
| 146 |
+
uint32_t pos_grid[D];
|
| 147 |
+
|
| 148 |
+
#pragma unroll
|
| 149 |
+
for (uint32_t d = 0; d < D; d++) {
|
| 150 |
+
pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
|
| 151 |
+
pos_grid[d] = floorf(pos[d]);
|
| 152 |
+
pos[d] -= (float)pos_grid[d];
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
// printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1],
|
| 156 |
+
// pos_grid[0], pos_grid[1]);
|
| 157 |
+
|
| 158 |
+
// interpolate
|
| 159 |
+
scalar_t results[C] = {0}; // temp results in register
|
| 160 |
+
|
| 161 |
+
#pragma unroll
|
| 162 |
+
for (uint32_t idx = 0; idx < (1 << D); idx++) {
|
| 163 |
+
float w = 1;
|
| 164 |
+
uint32_t pos_grid_local[D];
|
| 165 |
+
|
| 166 |
+
#pragma unroll
|
| 167 |
+
for (uint32_t d = 0; d < D; d++) {
|
| 168 |
+
if ((idx & (1 << d)) == 0) {
|
| 169 |
+
w *= 1 - pos[d];
|
| 170 |
+
pos_grid_local[d] = pos_grid[d];
|
| 171 |
+
} else {
|
| 172 |
+
w *= pos[d];
|
| 173 |
+
pos_grid_local[d] = pos_grid[d] + 1;
|
| 174 |
+
}
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
uint32_t index = get_grid_index<D, C>(
|
| 178 |
+
gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
|
| 179 |
+
|
| 180 |
+
// writing to register (fast)
|
| 181 |
+
#pragma unroll
|
| 182 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
| 183 |
+
results[ch] += w * grid[index + ch];
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
// printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx,
|
| 187 |
+
// index, w, grid[index]);
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
// writing to global memory (slow)
|
| 191 |
+
#pragma unroll
|
| 192 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
| 193 |
+
outputs[ch] = results[ch];
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
// prepare dy_dx for calc_grad_inputs
|
| 197 |
+
// differentiable (soft) indexing:
|
| 198 |
+
// https://discuss.pytorch.org/t/differentiable-indexing/17647/9
|
| 199 |
+
if (calc_grad_inputs) {
|
| 200 |
+
|
| 201 |
+
dy_dx += b * D * L * C + level * D * C; // B L D C
|
| 202 |
+
|
| 203 |
+
#pragma unroll
|
| 204 |
+
for (uint32_t gd = 0; gd < D; gd++) {
|
| 205 |
+
|
| 206 |
+
scalar_t results_grad[C] = {0};
|
| 207 |
+
|
| 208 |
+
#pragma unroll
|
| 209 |
+
for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
|
| 210 |
+
float w = scale;
|
| 211 |
+
uint32_t pos_grid_local[D];
|
| 212 |
+
|
| 213 |
+
#pragma unroll
|
| 214 |
+
for (uint32_t nd = 0; nd < D - 1; nd++) {
|
| 215 |
+
const uint32_t d = (nd >= gd) ? (nd + 1) : nd;
|
| 216 |
+
|
| 217 |
+
if ((idx & (1 << nd)) == 0) {
|
| 218 |
+
w *= 1 - pos[d];
|
| 219 |
+
pos_grid_local[d] = pos_grid[d];
|
| 220 |
+
} else {
|
| 221 |
+
w *= pos[d];
|
| 222 |
+
pos_grid_local[d] = pos_grid[d] + 1;
|
| 223 |
+
}
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
pos_grid_local[gd] = pos_grid[gd];
|
| 227 |
+
uint32_t index_left =
|
| 228 |
+
get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size,
|
| 229 |
+
resolution, pos_grid_local);
|
| 230 |
+
pos_grid_local[gd] = pos_grid[gd] + 1;
|
| 231 |
+
uint32_t index_right =
|
| 232 |
+
get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size,
|
| 233 |
+
resolution, pos_grid_local);
|
| 234 |
+
|
| 235 |
+
#pragma unroll
|
| 236 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
| 237 |
+
results_grad[ch] +=
|
| 238 |
+
w * (grid[index_right + ch] - grid[index_left + ch]);
|
| 239 |
+
}
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
#pragma unroll
|
| 243 |
+
for (uint32_t ch = 0; ch < C; ch++) {
|
| 244 |
+
dy_dx[gd * C + ch] = results_grad[ch];
|
| 245 |
+
}
|
| 246 |
+
}
|
| 247 |
+
}
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
|
| 251 |
+
__global__ void kernel_grid_backward(
|
| 252 |
+
const scalar_t *__restrict__ grad, const float *__restrict__ inputs,
|
| 253 |
+
const scalar_t *__restrict__ grid, const int *__restrict__ offsets,
|
| 254 |
+
scalar_t *__restrict__ grad_grid, const uint32_t B, const uint32_t L,
|
| 255 |
+
const float S, const uint32_t H, const uint32_t gridtype,
|
| 256 |
+
const bool align_corners) {
|
| 257 |
+
const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
|
| 258 |
+
if (b >= B)
|
| 259 |
+
return;
|
| 260 |
+
|
| 261 |
+
const uint32_t level = blockIdx.y;
|
| 262 |
+
const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;
|
| 263 |
+
|
| 264 |
+
// locate
|
| 265 |
+
grad_grid += offsets[level] * C;
|
| 266 |
+
inputs += b * D;
|
| 267 |
+
grad += level * B * C + b * C + ch; // L, B, C
|
| 268 |
+
|
| 269 |
+
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
|
| 270 |
+
const float scale = exp2f(level * S) * H - 1.0f;
|
| 271 |
+
const uint32_t resolution = (uint32_t)ceil(scale) + 1;
|
| 272 |
+
|
| 273 |
+
// check input range (should be in [0, 1])
|
| 274 |
+
#pragma unroll
|
| 275 |
+
for (uint32_t d = 0; d < D; d++) {
|
| 276 |
+
if (inputs[d] < 0 || inputs[d] > 1) {
|
| 277 |
+
return; // grad is init as 0, so we simply return.
|
| 278 |
+
}
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
// calculate coordinate
|
| 282 |
+
float pos[D];
|
| 283 |
+
uint32_t pos_grid[D];
|
| 284 |
+
|
| 285 |
+
#pragma unroll
|
| 286 |
+
for (uint32_t d = 0; d < D; d++) {
|
| 287 |
+
pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
|
| 288 |
+
pos_grid[d] = floorf(pos[d]);
|
| 289 |
+
pos[d] -= (float)pos_grid[d];
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
scalar_t grad_cur[N_C] = {0}; // fetch to register
|
| 293 |
+
#pragma unroll
|
| 294 |
+
for (uint32_t c = 0; c < N_C; c++) {
|
| 295 |
+
grad_cur[c] = grad[c];
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
// interpolate
|
| 299 |
+
#pragma unroll
|
| 300 |
+
for (uint32_t idx = 0; idx < (1 << D); idx++) {
|
| 301 |
+
float w = 1;
|
| 302 |
+
uint32_t pos_grid_local[D];
|
| 303 |
+
|
| 304 |
+
#pragma unroll
|
| 305 |
+
for (uint32_t d = 0; d < D; d++) {
|
| 306 |
+
if ((idx & (1 << d)) == 0) {
|
| 307 |
+
w *= 1 - pos[d];
|
| 308 |
+
pos_grid_local[d] = pos_grid[d];
|
| 309 |
+
} else {
|
| 310 |
+
w *= pos[d];
|
| 311 |
+
pos_grid_local[d] = pos_grid[d] + 1;
|
| 312 |
+
}
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
uint32_t index = get_grid_index<D, C>(
|
| 316 |
+
gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);
|
| 317 |
+
|
| 318 |
+
// atomicAdd for __half is slow (especially for large values), so we use
|
| 319 |
+
// __half2 if N_C % 2 == 0
|
| 320 |
+
// TODO: use float which is better than __half, if N_C % 2 != 0
|
| 321 |
+
if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
|
| 322 |
+
#pragma unroll
|
| 323 |
+
for (uint32_t c = 0; c < N_C; c += 2) {
|
| 324 |
+
// process two __half at once (by interpreting as a __half2)
|
| 325 |
+
__half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
|
| 326 |
+
atomicAdd((__half2 *)&grad_grid[index + c], v);
|
| 327 |
+
}
|
| 328 |
+
// float, or __half when N_C % 2 != 0 (which means C == 1)
|
| 329 |
+
} else {
|
| 330 |
+
#pragma unroll
|
| 331 |
+
for (uint32_t c = 0; c < N_C; c++) {
|
| 332 |
+
atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
|
| 333 |
+
}
|
| 334 |
+
}
|
| 335 |
+
}
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
template <typename scalar_t, uint32_t D, uint32_t C>
|
| 339 |
+
__global__ void kernel_input_backward(const scalar_t *__restrict__ grad,
|
| 340 |
+
const scalar_t *__restrict__ dy_dx,
|
| 341 |
+
scalar_t *__restrict__ grad_inputs,
|
| 342 |
+
uint32_t B, uint32_t L) {
|
| 343 |
+
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
|
| 344 |
+
if (t >= B * D)
|
| 345 |
+
return;
|
| 346 |
+
|
| 347 |
+
const uint32_t b = t / D;
|
| 348 |
+
const uint32_t d = t - b * D;
|
| 349 |
+
|
| 350 |
+
dy_dx += b * L * D * C;
|
| 351 |
+
|
| 352 |
+
scalar_t result = 0;
|
| 353 |
+
|
| 354 |
+
#pragma unroll
|
| 355 |
+
for (int l = 0; l < L; l++) {
|
| 356 |
+
#pragma unroll
|
| 357 |
+
for (int ch = 0; ch < C; ch++) {
|
| 358 |
+
result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
|
| 359 |
+
}
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
grad_inputs[t] = result;
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
template <typename scalar_t, uint32_t D>
|
| 366 |
+
void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings,
|
| 367 |
+
const int *offsets, scalar_t *outputs,
|
| 368 |
+
const uint32_t B, const uint32_t C, const uint32_t L,
|
| 369 |
+
const float S, const uint32_t H,
|
| 370 |
+
const bool calc_grad_inputs, scalar_t *dy_dx,
|
| 371 |
+
const uint32_t gridtype, const bool align_corners) {
|
| 372 |
+
static constexpr uint32_t N_THREAD = 512;
|
| 373 |
+
const dim3 blocks_hashgrid = {div_round_up(B, N_THREAD), L, 1};
|
| 374 |
+
switch (C) {
|
| 375 |
+
case 1:
|
| 376 |
+
kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(
|
| 377 |
+
inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs,
|
| 378 |
+
dy_dx, gridtype, align_corners);
|
| 379 |
+
break;
|
| 380 |
+
case 2:
|
| 381 |
+
kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(
|
| 382 |
+
inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs,
|
| 383 |
+
dy_dx, gridtype, align_corners);
|
| 384 |
+
break;
|
| 385 |
+
case 4:
|
| 386 |
+
kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(
|
| 387 |
+
inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs,
|
| 388 |
+
dy_dx, gridtype, align_corners);
|
| 389 |
+
break;
|
| 390 |
+
case 8:
|
| 391 |
+
kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(
|
| 392 |
+
inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs,
|
| 393 |
+
dy_dx, gridtype, align_corners);
|
| 394 |
+
break;
|
| 395 |
+
default:
|
| 396 |
+
throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
|
| 397 |
+
}
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
// inputs: [B, D], float, in [0, 1]
|
| 401 |
+
// embeddings: [sO, C], float
|
| 402 |
+
// offsets: [L + 1], uint32_t
|
| 403 |
+
// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit
|
| 404 |
+
// into cache at a time.) H: base resolution dy_dx: [B, L * D * C]
|
| 405 |
+
template <typename scalar_t>
|
| 406 |
+
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings,
|
| 407 |
+
const int *offsets, scalar_t *outputs,
|
| 408 |
+
const uint32_t B, const uint32_t D,
|
| 409 |
+
const uint32_t C, const uint32_t L, const float S,
|
| 410 |
+
const uint32_t H, const bool calc_grad_inputs,
|
| 411 |
+
scalar_t *dy_dx, const uint32_t gridtype,
|
| 412 |
+
const bool align_corners) {
|
| 413 |
+
switch (D) {
|
| 414 |
+
case 2:
|
| 415 |
+
kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C,
|
| 416 |
+
L, S, H, calc_grad_inputs, dy_dx, gridtype,
|
| 417 |
+
align_corners);
|
| 418 |
+
break;
|
| 419 |
+
case 3:
|
| 420 |
+
kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C,
|
| 421 |
+
L, S, H, calc_grad_inputs, dy_dx, gridtype,
|
| 422 |
+
align_corners);
|
| 423 |
+
break;
|
| 424 |
+
case 4:
|
| 425 |
+
kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C,
|
| 426 |
+
L, S, H, calc_grad_inputs, dy_dx, gridtype,
|
| 427 |
+
align_corners);
|
| 428 |
+
break;
|
| 429 |
+
case 5:
|
| 430 |
+
kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C,
|
| 431 |
+
L, S, H, calc_grad_inputs, dy_dx, gridtype,
|
| 432 |
+
align_corners);
|
| 433 |
+
break;
|
| 434 |
+
default:
|
| 435 |
+
throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
|
| 436 |
+
}
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
template <typename scalar_t, uint32_t D>
|
| 440 |
+
void kernel_grid_backward_wrapper(
|
| 441 |
+
const scalar_t *grad, const float *inputs, const scalar_t *embeddings,
|
| 442 |
+
const int *offsets, scalar_t *grad_embeddings, const uint32_t B,
|
| 443 |
+
const uint32_t C, const uint32_t L, const float S, const uint32_t H,
|
| 444 |
+
const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs,
|
| 445 |
+
const uint32_t gridtype, const bool align_corners) {
|
| 446 |
+
static constexpr uint32_t N_THREAD = 256;
|
| 447 |
+
const uint32_t N_C = std::min(2u, C); // n_features_per_thread
|
| 448 |
+
const dim3 blocks_hashgrid = {div_round_up(B * C / N_C, N_THREAD), L, 1};
|
| 449 |
+
switch (C) {
|
| 450 |
+
case 1:
|
| 451 |
+
kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(
|
| 452 |
+
grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H,
|
| 453 |
+
gridtype, align_corners);
|
| 454 |
+
if (calc_grad_inputs)
|
| 455 |
+
kernel_input_backward<scalar_t, D, 1>
|
| 456 |
+
<<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx,
|
| 457 |
+
grad_inputs, B, L);
|
| 458 |
+
break;
|
| 459 |
+
case 2:
|
| 460 |
+
kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(
|
| 461 |
+
grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H,
|
| 462 |
+
gridtype, align_corners);
|
| 463 |
+
if (calc_grad_inputs)
|
| 464 |
+
kernel_input_backward<scalar_t, D, 2>
|
| 465 |
+
<<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx,
|
| 466 |
+
grad_inputs, B, L);
|
| 467 |
+
break;
|
| 468 |
+
case 4:
|
| 469 |
+
kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(
|
| 470 |
+
grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H,
|
| 471 |
+
gridtype, align_corners);
|
| 472 |
+
if (calc_grad_inputs)
|
| 473 |
+
kernel_input_backward<scalar_t, D, 4>
|
| 474 |
+
<<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx,
|
| 475 |
+
grad_inputs, B, L);
|
| 476 |
+
break;
|
| 477 |
+
case 8:
|
| 478 |
+
kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(
|
| 479 |
+
grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H,
|
| 480 |
+
gridtype, align_corners);
|
| 481 |
+
if (calc_grad_inputs)
|
| 482 |
+
kernel_input_backward<scalar_t, D, 8>
|
| 483 |
+
<<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx,
|
| 484 |
+
grad_inputs, B, L);
|
| 485 |
+
break;
|
| 486 |
+
default:
|
| 487 |
+
throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
|
| 488 |
+
}
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
+
// grad: [L, B, C], float
|
| 492 |
+
// inputs: [B, D], float, in [0, 1]
|
| 493 |
+
// embeddings: [sO, C], float
|
| 494 |
+
// offsets: [L + 1], uint32_t
|
| 495 |
+
// grad_embeddings: [sO, C]
|
| 496 |
+
// H: base resolution
|
| 497 |
+
template <typename scalar_t>
|
| 498 |
+
void grid_encode_backward_cuda(
|
| 499 |
+
const scalar_t *grad, const float *inputs, const scalar_t *embeddings,
|
| 500 |
+
const int *offsets, scalar_t *grad_embeddings, const uint32_t B,
|
| 501 |
+
const uint32_t D, const uint32_t C, const uint32_t L, const float S,
|
| 502 |
+
const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx,
|
| 503 |
+
scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
|
| 504 |
+
switch (D) {
|
| 505 |
+
case 2:
|
| 506 |
+
kernel_grid_backward_wrapper<scalar_t, 2>(
|
| 507 |
+
grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H,
|
| 508 |
+
calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners);
|
| 509 |
+
break;
|
| 510 |
+
case 3:
|
| 511 |
+
kernel_grid_backward_wrapper<scalar_t, 3>(
|
| 512 |
+
grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H,
|
| 513 |
+
calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners);
|
| 514 |
+
break;
|
| 515 |
+
case 4:
|
| 516 |
+
kernel_grid_backward_wrapper<scalar_t, 4>(
|
| 517 |
+
grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H,
|
| 518 |
+
calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners);
|
| 519 |
+
break;
|
| 520 |
+
case 5:
|
| 521 |
+
kernel_grid_backward_wrapper<scalar_t, 5>(
|
| 522 |
+
grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H,
|
| 523 |
+
calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners);
|
| 524 |
+
break;
|
| 525 |
+
default:
|
| 526 |
+
throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
|
| 527 |
+
}
|
| 528 |
+
}
|
| 529 |
+
|
| 530 |
+
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings,
|
| 531 |
+
const at::Tensor offsets, at::Tensor outputs,
|
| 532 |
+
const uint32_t B, const uint32_t D, const uint32_t C,
|
| 533 |
+
const uint32_t L, const float S, const uint32_t H,
|
| 534 |
+
const bool calc_grad_inputs, at::Tensor dy_dx,
|
| 535 |
+
const uint32_t gridtype, const bool align_corners) {
|
| 536 |
+
CHECK_CUDA(inputs);
|
| 537 |
+
CHECK_CUDA(embeddings);
|
| 538 |
+
CHECK_CUDA(offsets);
|
| 539 |
+
CHECK_CUDA(outputs);
|
| 540 |
+
CHECK_CUDA(dy_dx);
|
| 541 |
+
|
| 542 |
+
CHECK_CONTIGUOUS(inputs);
|
| 543 |
+
CHECK_CONTIGUOUS(embeddings);
|
| 544 |
+
CHECK_CONTIGUOUS(offsets);
|
| 545 |
+
CHECK_CONTIGUOUS(outputs);
|
| 546 |
+
CHECK_CONTIGUOUS(dy_dx);
|
| 547 |
+
|
| 548 |
+
CHECK_IS_FLOATING(inputs);
|
| 549 |
+
CHECK_IS_FLOATING(embeddings);
|
| 550 |
+
CHECK_IS_INT(offsets);
|
| 551 |
+
CHECK_IS_FLOATING(outputs);
|
| 552 |
+
CHECK_IS_FLOATING(dy_dx);
|
| 553 |
+
|
| 554 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
| 555 |
+
embeddings.scalar_type(), "grid_encode_forward", ([&] {
|
| 556 |
+
grid_encode_forward_cuda<scalar_t>(
|
| 557 |
+
inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(),
|
| 558 |
+
offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L,
|
| 559 |
+
S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), gridtype,
|
| 560 |
+
align_corners);
|
| 561 |
+
}));
|
| 562 |
+
}
|
| 563 |
+
|
| 564 |
+
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs,
|
| 565 |
+
const at::Tensor embeddings, const at::Tensor offsets,
|
| 566 |
+
at::Tensor grad_embeddings, const uint32_t B,
|
| 567 |
+
const uint32_t D, const uint32_t C, const uint32_t L,
|
| 568 |
+
const float S, const uint32_t H,
|
| 569 |
+
const bool calc_grad_inputs, const at::Tensor dy_dx,
|
| 570 |
+
at::Tensor grad_inputs, const uint32_t gridtype,
|
| 571 |
+
const bool align_corners) {
|
| 572 |
+
CHECK_CUDA(grad);
|
| 573 |
+
CHECK_CUDA(inputs);
|
| 574 |
+
CHECK_CUDA(embeddings);
|
| 575 |
+
CHECK_CUDA(offsets);
|
| 576 |
+
CHECK_CUDA(grad_embeddings);
|
| 577 |
+
CHECK_CUDA(dy_dx);
|
| 578 |
+
CHECK_CUDA(grad_inputs);
|
| 579 |
+
|
| 580 |
+
CHECK_CONTIGUOUS(grad);
|
| 581 |
+
CHECK_CONTIGUOUS(inputs);
|
| 582 |
+
CHECK_CONTIGUOUS(embeddings);
|
| 583 |
+
CHECK_CONTIGUOUS(offsets);
|
| 584 |
+
CHECK_CONTIGUOUS(grad_embeddings);
|
| 585 |
+
CHECK_CONTIGUOUS(dy_dx);
|
| 586 |
+
CHECK_CONTIGUOUS(grad_inputs);
|
| 587 |
+
|
| 588 |
+
CHECK_IS_FLOATING(grad);
|
| 589 |
+
CHECK_IS_FLOATING(inputs);
|
| 590 |
+
CHECK_IS_FLOATING(embeddings);
|
| 591 |
+
CHECK_IS_INT(offsets);
|
| 592 |
+
CHECK_IS_FLOATING(grad_embeddings);
|
| 593 |
+
CHECK_IS_FLOATING(dy_dx);
|
| 594 |
+
CHECK_IS_FLOATING(grad_inputs);
|
| 595 |
+
|
| 596 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
| 597 |
+
grad.scalar_type(), "grid_encode_backward", ([&] {
|
| 598 |
+
grid_encode_backward_cuda<scalar_t>(
|
| 599 |
+
grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(),
|
| 600 |
+
embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(),
|
| 601 |
+
grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H,
|
| 602 |
+
calc_grad_inputs, dy_dx.data_ptr<scalar_t>(),
|
| 603 |
+
grad_inputs.data_ptr<scalar_t>(), gridtype, align_corners);
|
| 604 |
+
}));
|
| 605 |
+
}
|
gaussiancity/extensions/grid_encoder/setup.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# @File: setup.py
|
| 4 |
+
# @Author: Jiaxiang Tang (@ashawkey)
|
| 5 |
+
# @Date: 2023-04-15 10:33:32
|
| 6 |
+
# @Last Modified by: Haozhe Xie
|
| 7 |
+
# @Last Modified at: 2024-09-18 10:08:45
|
| 8 |
+
# @Email: ashawkey1999@gmail.com
|
| 9 |
+
# @Ref: https://github.com/ashawkey/torch-ngp
|
| 10 |
+
|
| 11 |
+
from setuptools import setup
|
| 12 |
+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
| 13 |
+
|
| 14 |
+
setup(
|
| 15 |
+
name="grid_encoder",
|
| 16 |
+
version="1.0.0",
|
| 17 |
+
ext_modules=[
|
| 18 |
+
CUDAExtension(
|
| 19 |
+
name="grid_encoder_ext",
|
| 20 |
+
sources=[
|
| 21 |
+
"grid_encoder_ext.cu",
|
| 22 |
+
"bindings.cpp",
|
| 23 |
+
],
|
| 24 |
+
extra_compile_args={
|
| 25 |
+
"cxx": ["-O3", "-std=c++17"],
|
| 26 |
+
"nvcc": [
|
| 27 |
+
"-O3",
|
| 28 |
+
"-std=c++17",
|
| 29 |
+
"-U__CUDA_NO_HALF_OPERATORS__",
|
| 30 |
+
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
| 31 |
+
"-U__CUDA_NO_HALF2_OPERATORS__",
|
| 32 |
+
],
|
| 33 |
+
},
|
| 34 |
+
),
|
| 35 |
+
],
|
| 36 |
+
cmdclass={
|
| 37 |
+
"build_ext": BuildExtension,
|
| 38 |
+
},
|
| 39 |
+
)
|
gaussiancity/extensions/voxlib/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# @File: setup.py
|
| 4 |
+
# @Author: NVIDIA Corporation
|
| 5 |
+
# @Date: 2021-10-13 00:00:00
|
| 6 |
+
# @Last Modified by: Haozhe Xie
|
| 7 |
+
# @Last Modified at: 2024-10-13 03:14:15
|
| 8 |
+
# @Email: root@haozhexie.com
|
| 9 |
+
|
| 10 |
+
from voxlib import ray_voxel_intersection_perspective
|
| 11 |
+
from voxlib import points_to_volume
|
| 12 |
+
from voxlib import maps_to_volume
|
gaussiancity/extensions/voxlib/bindings.cpp
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* @File: bindings.cpp
|
| 3 |
+
* @Author: NVIDIA Corporation
|
| 4 |
+
* @Date: 2021-10-13 00:00:00
|
| 5 |
+
* @Last Modified by: Haozhe Xie
|
| 6 |
+
* @Last Modified at: 2024-10-13 03:03:45
|
| 7 |
+
* @Email: root@haozhexie.com
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#include <pybind11/pybind11.h>
|
| 11 |
+
#include <pybind11/stl.h>
|
| 12 |
+
#include <torch/extension.h>
|
| 13 |
+
#include <vector>
|
| 14 |
+
|
| 15 |
+
// Fast voxel traversal along rays
|
| 16 |
+
std::vector<torch::Tensor> ray_voxel_intersection_perspective_cuda(
|
| 17 |
+
const torch::Tensor &in_voxel, const torch::Tensor &cam_ori,
|
| 18 |
+
const torch::Tensor &cam_dir, const torch::Tensor &cam_up, float cam_f,
|
| 19 |
+
const std::vector<float> &cam_c, const std::vector<int> &img_dims,
|
| 20 |
+
int max_samples);
|
| 21 |
+
|
| 22 |
+
torch::Tensor points_to_volume_cuda(const torch::Tensor &points,
|
| 23 |
+
const torch::Tensor &pt_ids,
|
| 24 |
+
const torch::Tensor &scales, int h, int w,
|
| 25 |
+
int d);
|
| 26 |
+
|
| 27 |
+
torch::Tensor
|
| 28 |
+
maps_to_volume_cuda(const torch::Tensor &inst_map, const torch::Tensor &td_hf,
|
| 29 |
+
const torch::Tensor &bu_hf,
|
| 30 |
+
const torch::Tensor &pts_map,
|
| 31 |
+
const torch::Tensor &scales);
|
| 32 |
+
|
| 33 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 34 |
+
m.def("ray_voxel_intersection_perspective",
|
| 35 |
+
&ray_voxel_intersection_perspective_cuda,
|
| 36 |
+
"Ray-voxel intersections given perspective camera parameters (CUDA)");
|
| 37 |
+
m.def("points_to_volume", &points_to_volume_cuda,
|
| 38 |
+
"Generate 3D volume from points (CUDA)");
|
| 39 |
+
m.def("maps_to_volume", &maps_to_volume_cuda,
|
| 40 |
+
"Generate 3D volume from maps (CUDA)");
|
| 41 |
+
}
|
gaussiancity/extensions/voxlib/maps_to_volume.cu
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* @File: maps_to_volume.cu
|
| 3 |
+
* @Author: Haozhe Xie
|
| 4 |
+
* @Date: 2024-10-09 15:42:49
|
| 5 |
+
* @Last Modified by: Haozhe Xie
|
| 6 |
+
* @Last Modified at: 2024-10-13 12:26:15
|
| 7 |
+
* @Email: root@haozhexie.com
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 11 |
+
#include <torch/extension.h>
|
| 12 |
+
|
| 13 |
+
#include "voxlib_common.h"
|
| 14 |
+
|
| 15 |
+
#define TILE_DIM 16
|
| 16 |
+
#define BLDG_MAX_HEIGHT 504
|
| 17 |
+
#define BLDG_INS_MIN_ID 10
|
| 18 |
+
#define BLDG_FACADE_SEM 2
|
| 19 |
+
#define BLDG_ROOF_OFFSET 1
|
| 20 |
+
|
| 21 |
+
__global__ void maps_to_volume_cuda_kernel(int height, int width, int depth,
|
| 22 |
+
const int8_t *__restrict__ scales,
|
| 23 |
+
const short *__restrict__ inst_map,
|
| 24 |
+
const short *__restrict__ td_hf,
|
| 25 |
+
const short *__restrict__ bu_hf,
|
| 26 |
+
const bool *__restrict__ pts_map,
|
| 27 |
+
short *__restrict__ volume) {
|
| 28 |
+
size_t i = blockIdx.x * blockDim.x + threadIdx.x; // width
|
| 29 |
+
size_t j = blockIdx.y * blockDim.y + threadIdx.y; // height
|
| 30 |
+
|
| 31 |
+
if (i < width && j < height) {
|
| 32 |
+
bool has_pt = pts_map[j * width + i];
|
| 33 |
+
if (!has_pt) {
|
| 34 |
+
return;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
// Fix: nonzero is not supported for tensors with more than INT_MAX elements
|
| 38 |
+
short hgt_up = td_hf[j * width + i];
|
| 39 |
+
short hgt_lw = bu_hf[j * width + i];
|
| 40 |
+
short inst = inst_map[j * width + i];
|
| 41 |
+
// WARN: The semantic labels for buildings would be merged to facade.
|
| 42 |
+
short sem_cls = inst < BLDG_INS_MIN_ID ? inst : BLDG_FACADE_SEM;
|
| 43 |
+
short scale = scales[sem_cls];
|
| 44 |
+
|
| 45 |
+
int64_t vol_offset = static_cast<int64_t>(j) * width * depth + i * depth;
|
| 46 |
+
for (int k = hgt_lw; k <= hgt_up; k += scale) {
|
| 47 |
+
// Make all objects hallow
|
| 48 |
+
bool is_border_1 = (k > hgt_up - scale) || (i < scale) ||
|
| 49 |
+
(i >= width - scale - 1) || (j < scale) ||
|
| 50 |
+
(j >= height - scale - 1);
|
| 51 |
+
bool is_border_2 = false;
|
| 52 |
+
bool is_border_3 = false;
|
| 53 |
+
if (!is_border_1) {
|
| 54 |
+
// Check is_border_1 to Prevent OOB
|
| 55 |
+
short nbr_hd_hf[8] = {
|
| 56 |
+
td_hf[(j - scale) * width + (i - scale)],
|
| 57 |
+
td_hf[(j - scale) * width + i],
|
| 58 |
+
td_hf[(j - scale) * width + (i + scale)],
|
| 59 |
+
td_hf[j * width + (i - scale)],
|
| 60 |
+
td_hf[j * width + (i + scale)],
|
| 61 |
+
td_hf[(j + scale) * width + (i - scale)],
|
| 62 |
+
td_hf[(j + scale) * width + i],
|
| 63 |
+
td_hf[(j + scale) * width + (i + scale)],
|
| 64 |
+
};
|
| 65 |
+
for (int ni = 0; ni < 8; ++ni) {
|
| 66 |
+
if (nbr_hd_hf[ni] != hgt_up) {
|
| 67 |
+
is_border_2 = true;
|
| 68 |
+
break;
|
| 69 |
+
}
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
short nbr_inst[8] = {
|
| 73 |
+
inst_map[(j - scale) * width + (i - scale)],
|
| 74 |
+
inst_map[(j - scale) * width + i],
|
| 75 |
+
inst_map[(j - scale) * width + (i + scale)],
|
| 76 |
+
inst_map[j * width + (i - scale)],
|
| 77 |
+
inst_map[j * width + (i + scale)],
|
| 78 |
+
inst_map[(j + scale) * width + (i - scale)],
|
| 79 |
+
inst_map[(j + scale) * width + i],
|
| 80 |
+
inst_map[(j + scale) * width + (i + scale)],
|
| 81 |
+
};
|
| 82 |
+
for (int ni = 0; ni < 8; ++ni) {
|
| 83 |
+
if (nbr_inst[ni] != inst) {
|
| 84 |
+
is_border_3 = true;
|
| 85 |
+
break;
|
| 86 |
+
}
|
| 87 |
+
}
|
| 88 |
+
}
|
| 89 |
+
if (!is_border_1 && !is_border_2 && !is_border_3) {
|
| 90 |
+
continue;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
// Building Roof Handler (Recover roof instance ID)
|
| 94 |
+
if (k > hgt_up - scale && sem_cls == BLDG_FACADE_SEM) {
|
| 95 |
+
volume[vol_offset + k] = inst + 1;
|
| 96 |
+
} else {
|
| 97 |
+
volume[vol_offset + k] = inst;
|
| 98 |
+
}
|
| 99 |
+
}
|
| 100 |
+
}
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
torch::Tensor maps_to_volume_cuda(const torch::Tensor &inst_map,
|
| 104 |
+
const torch::Tensor &td_hf,
|
| 105 |
+
const torch::Tensor &bu_hf,
|
| 106 |
+
const torch::Tensor &pts_map,
|
| 107 |
+
const torch::Tensor &scales) {
|
| 108 |
+
CHECK_CUDA(inst_map);
|
| 109 |
+
CHECK_CUDA(td_hf);
|
| 110 |
+
CHECK_CUDA(bu_hf);
|
| 111 |
+
CHECK_CUDA(pts_map);
|
| 112 |
+
CHECK_CUDA(scales);
|
| 113 |
+
|
| 114 |
+
int curDevice = -1;
|
| 115 |
+
cudaGetDevice(&curDevice);
|
| 116 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
|
| 117 |
+
torch::Device device = inst_map.device();
|
| 118 |
+
|
| 119 |
+
int height = inst_map.size(0);
|
| 120 |
+
int width = inst_map.size(1);
|
| 121 |
+
int depth = BLDG_MAX_HEIGHT;
|
| 122 |
+
|
| 123 |
+
dim3 blockDim(TILE_DIM, TILE_DIM);
|
| 124 |
+
dim3 gridDim((width + blockDim.x - 1) / blockDim.x,
|
| 125 |
+
(height + blockDim.y - 1) / blockDim.y);
|
| 126 |
+
|
| 127 |
+
torch::Tensor volume =
|
| 128 |
+
torch::zeros({height, width, depth},
|
| 129 |
+
torch::TensorOptions().dtype(torch::kInt16).device(device));
|
| 130 |
+
maps_to_volume_cuda_kernel<<<gridDim, blockDim, 0, stream>>>(
|
| 131 |
+
height, width, depth, scales.data_ptr<int8_t>(),
|
| 132 |
+
inst_map.data_ptr<short>(), td_hf.data_ptr<short>(),
|
| 133 |
+
bu_hf.data_ptr<short>(), pts_map.data_ptr<bool>(),
|
| 134 |
+
volume.data_ptr<short>());
|
| 135 |
+
|
| 136 |
+
cudaError_t err = cudaGetLastError();
|
| 137 |
+
if (err != cudaSuccess) {
|
| 138 |
+
printf("Error in maps_to_volume_cuda_kernel: %s\n",
|
| 139 |
+
cudaGetErrorString(err));
|
| 140 |
+
}
|
| 141 |
+
return volume;
|
| 142 |
+
}
|
gaussiancity/extensions/voxlib/points_to_volume.cu
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* @File: points_to_volume.cu
|
| 3 |
+
* @Author: Haozhe Xie
|
| 4 |
+
* @Date: 2024-02-24 14:09:38
|
| 5 |
+
* @Last Modified by: Haozhe Xie
|
| 6 |
+
* @Last Modified at: 2024-10-13 12:29:46
|
| 7 |
+
* @Email: root@haozhexie.com
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#include <cmath>
|
| 11 |
+
#include <cstdio>
|
| 12 |
+
#include <cstdlib>
|
| 13 |
+
|
| 14 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 15 |
+
#include <torch/extension.h>
|
| 16 |
+
|
| 17 |
+
#include "voxlib_common.h"
|
| 18 |
+
|
| 19 |
+
#define THREADS_PER_BLOCK 256
|
| 20 |
+
|
| 21 |
+
__global__ void points_to_volume_cuda_cuda_kernel(
|
| 22 |
+
size_t n_pts, int h, int w, int d, const short *__restrict__ points,
|
| 23 |
+
const int *__restrict__ pt_ids, const short *__restrict__ scales,
|
| 24 |
+
int *__restrict__ volume) {
|
| 25 |
+
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 26 |
+
|
| 27 |
+
if (idx >= n_pts) {
|
| 28 |
+
return;
|
| 29 |
+
}
|
| 30 |
+
int pid = pt_ids[idx];
|
| 31 |
+
int idx3 = idx * 3;
|
| 32 |
+
short x = points[idx3];
|
| 33 |
+
short y = points[idx3 + 1];
|
| 34 |
+
short z = points[idx3 + 2];
|
| 35 |
+
short sx = scales[idx3];
|
| 36 |
+
short sy = scales[idx3 + 1];
|
| 37 |
+
short sz = scales[idx3 + 2];
|
| 38 |
+
|
| 39 |
+
if (x >= w || y >= h || z >= d || x < 0 || y < 0 || z < 0) {
|
| 40 |
+
return;
|
| 41 |
+
}
|
| 42 |
+
for (int j = x; j < x + sx && j < w; ++j) {
|
| 43 |
+
for (int k = y; k < y + sy && k < h; ++k) {
|
| 44 |
+
for (int l = z; l < z + sz && l < d; ++l) {
|
| 45 |
+
int64_t idx = static_cast<int64_t>(k) * w * d + j * d + l;
|
| 46 |
+
volume[idx] = pid;
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
}
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
torch::Tensor points_to_volume_cuda(const torch::Tensor &points,
|
| 53 |
+
const torch::Tensor &pt_ids,
|
| 54 |
+
const torch::Tensor &scales, int h, int w,
|
| 55 |
+
int d) {
|
| 56 |
+
CHECK_CUDA(points);
|
| 57 |
+
CHECK_CUDA(pt_ids);
|
| 58 |
+
CHECK_CUDA(scales);
|
| 59 |
+
|
| 60 |
+
size_t n_pts = points.size(0);
|
| 61 |
+
int curDevice = -1;
|
| 62 |
+
cudaGetDevice(&curDevice);
|
| 63 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
|
| 64 |
+
torch::Device device = points.device();
|
| 65 |
+
|
| 66 |
+
int n_blocks = (n_pts + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
|
| 67 |
+
torch::Tensor volume = torch::zeros(
|
| 68 |
+
{h, w, d}, torch::TensorOptions().dtype(torch::kInt32).device(device));
|
| 69 |
+
points_to_volume_cuda_cuda_kernel<<<n_blocks, THREADS_PER_BLOCK, 0, stream>>>(
|
| 70 |
+
n_pts, h, w, d, points.data_ptr<short>(), pt_ids.data_ptr<int>(),
|
| 71 |
+
scales.data_ptr<short>(), volume.data_ptr<int>());
|
| 72 |
+
|
| 73 |
+
cudaError_t err = cudaGetLastError();
|
| 74 |
+
if (err != cudaSuccess) {
|
| 75 |
+
printf("Error in points_to_volume_cuda_cuda_kernel: %s\n",
|
| 76 |
+
cudaGetErrorString(err));
|
| 77 |
+
}
|
| 78 |
+
return volume;
|
| 79 |
+
}
|
gaussiancity/extensions/voxlib/ray_voxel_intersection.cu
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* @File: ray_voxel_intersection.cu
|
| 3 |
+
* @Author: NVIDIA Corporation
|
| 4 |
+
* @Date: 2021-10-13 00:00:00
|
| 5 |
+
* @Last Modified by: Haozhe Xie
|
| 6 |
+
* @Last Modified at: 2024-03-27 11:02:41
|
| 7 |
+
* @Email: root@haozhexie.com
|
| 8 |
+
*/
|
| 9 |
+
|
| 10 |
+
#include <torch/types.h>
|
| 11 |
+
|
| 12 |
+
#include <ATen/ATen.h>
|
| 13 |
+
#include <ATen/AccumulateType.h>
|
| 14 |
+
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
| 15 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 16 |
+
|
| 17 |
+
#include <cuda.h>
|
| 18 |
+
#include <cuda_runtime.h>
|
| 19 |
+
#include <curand.h>
|
| 20 |
+
#include <curand_kernel.h>
|
| 21 |
+
#include <time.h>
|
| 22 |
+
|
| 23 |
+
//#include <pybind11/numpy.h>
|
| 24 |
+
#include <pybind11/pybind11.h>
|
| 25 |
+
#include <pybind11/stl.h>
|
| 26 |
+
#include <vector>
|
| 27 |
+
|
| 28 |
+
#include "voxlib_common.h"
|
| 29 |
+
|
| 30 |
+
struct RVIP_Params {
|
| 31 |
+
int voxel_dims[3];
|
| 32 |
+
int voxel_strides[3];
|
| 33 |
+
int max_samples;
|
| 34 |
+
int img_dims[2];
|
| 35 |
+
// Camera parameters
|
| 36 |
+
float cam_ori[3];
|
| 37 |
+
float cam_fwd[3];
|
| 38 |
+
float cam_side[3];
|
| 39 |
+
float cam_up[3];
|
| 40 |
+
float cam_c[2];
|
| 41 |
+
float cam_f;
|
| 42 |
+
// unsigned long seed;
|
| 43 |
+
};
|
| 44 |
+
|
| 45 |
+
/*
|
| 46 |
+
out_voxel_id: torch CUDA int32 [ img_dims[0], img_dims[1], max_samples, 1]
|
| 47 |
+
out_depth: torch CUDA float [2, img_dims[0], img_dims[1], max_samples, 1]
|
| 48 |
+
out_raydirs: torch CUDA float [ img_dims[0], img_dims[1], 1, 3]
|
| 49 |
+
Image coordinates refer to the center of the pixel [0, 0, 0] at voxel
|
| 50 |
+
coordinate is at the corner of the corner block (instead of at the center)
|
| 51 |
+
*/
|
| 52 |
+
template <int TILE_DIM>
|
| 53 |
+
static __global__ void ray_voxel_intersection_perspective_kernel(
|
| 54 |
+
int32_t *__restrict__ out_voxel_id, float *__restrict__ out_depth,
|
| 55 |
+
float *__restrict__ out_raydirs, const int32_t *__restrict__ in_voxel,
|
| 56 |
+
const RVIP_Params p) {
|
| 57 |
+
|
| 58 |
+
int img_coords[2];
|
| 59 |
+
img_coords[1] = blockIdx.x * TILE_DIM + threadIdx.x;
|
| 60 |
+
img_coords[0] = blockIdx.y * TILE_DIM + threadIdx.y;
|
| 61 |
+
if (img_coords[0] >= p.img_dims[0] || img_coords[1] >= p.img_dims[1]) {
|
| 62 |
+
return;
|
| 63 |
+
}
|
| 64 |
+
int pix_index = img_coords[0] * p.img_dims[1] + img_coords[1];
|
| 65 |
+
|
| 66 |
+
// Calculate ray origin and direction
|
| 67 |
+
float rayori[3], raydir[3];
|
| 68 |
+
rayori[0] = p.cam_ori[0];
|
| 69 |
+
rayori[1] = p.cam_ori[1];
|
| 70 |
+
rayori[2] = p.cam_ori[2];
|
| 71 |
+
|
| 72 |
+
// Camera intrinsics
|
| 73 |
+
float ndc_imcoords[2];
|
| 74 |
+
ndc_imcoords[0] = p.cam_c[0] - (float)img_coords[0]; // Flip height
|
| 75 |
+
ndc_imcoords[1] = (float)img_coords[1] - p.cam_c[1];
|
| 76 |
+
|
| 77 |
+
raydir[0] = p.cam_up[0] * ndc_imcoords[0] + p.cam_side[0] * ndc_imcoords[1] +
|
| 78 |
+
p.cam_fwd[0] * p.cam_f;
|
| 79 |
+
raydir[1] = p.cam_up[1] * ndc_imcoords[0] + p.cam_side[1] * ndc_imcoords[1] +
|
| 80 |
+
p.cam_fwd[1] * p.cam_f;
|
| 81 |
+
raydir[2] = p.cam_up[2] * ndc_imcoords[0] + p.cam_side[2] * ndc_imcoords[1] +
|
| 82 |
+
p.cam_fwd[2] * p.cam_f;
|
| 83 |
+
normalize<float, 3>(raydir);
|
| 84 |
+
|
| 85 |
+
// Save out_raydirs
|
| 86 |
+
out_raydirs[pix_index * 3] = raydir[0];
|
| 87 |
+
out_raydirs[pix_index * 3 + 1] = raydir[1];
|
| 88 |
+
out_raydirs[pix_index * 3 + 2] = raydir[2];
|
| 89 |
+
|
| 90 |
+
float axis_t[3];
|
| 91 |
+
int axis_int[3];
|
| 92 |
+
// int axis_intbound[3];
|
| 93 |
+
|
| 94 |
+
// Current voxel
|
| 95 |
+
axis_int[0] = floorf(rayori[0]);
|
| 96 |
+
axis_int[1] = floorf(rayori[1]);
|
| 97 |
+
axis_int[2] = floorf(rayori[2]);
|
| 98 |
+
|
| 99 |
+
#pragma unroll
|
| 100 |
+
for (int i = 0; i < 3; i++) {
|
| 101 |
+
if (raydir[i] > 0) {
|
| 102 |
+
// Initial t value
|
| 103 |
+
// Handle boundary case where rayori[i] is a whole number. Always round Up
|
| 104 |
+
// for the next block
|
| 105 |
+
// axis_t[i] = (ceilf(nextafterf(rayori[i], HUGE_VALF)) - rayori[i]) /
|
| 106 |
+
// raydir[i];
|
| 107 |
+
axis_t[i] = ((float)(axis_int[i] + 1) - rayori[i]) / raydir[i];
|
| 108 |
+
} else if (raydir[i] < 0) {
|
| 109 |
+
axis_t[i] = ((float)axis_int[i] - rayori[i]) / raydir[i];
|
| 110 |
+
} else {
|
| 111 |
+
axis_t[i] = HUGE_VALF;
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
// Fused raymarching and sampling
|
| 116 |
+
bool quit = false;
|
| 117 |
+
for (int cur_plane = 0; cur_plane < p.max_samples;
|
| 118 |
+
cur_plane++) { // Last cycle is for calculating p2
|
| 119 |
+
float t = nanf("0");
|
| 120 |
+
float t2 = nanf("0");
|
| 121 |
+
int32_t blk_id = 0;
|
| 122 |
+
// Find the next intersection
|
| 123 |
+
while (!quit) {
|
| 124 |
+
// Find the next smallest t
|
| 125 |
+
float tnow;
|
| 126 |
+
// Hand unroll
|
| 127 |
+
if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) {
|
| 128 |
+
// Update current t
|
| 129 |
+
tnow = axis_t[0];
|
| 130 |
+
// Update t candidates
|
| 131 |
+
if (raydir[0] > 0) {
|
| 132 |
+
axis_int[0] += 1;
|
| 133 |
+
if (axis_int[0] >= p.voxel_dims[0]) {
|
| 134 |
+
quit = true;
|
| 135 |
+
}
|
| 136 |
+
axis_t[0] = ((float)(axis_int[0] + 1) - rayori[0]) / raydir[0];
|
| 137 |
+
} else {
|
| 138 |
+
axis_int[0] -= 1;
|
| 139 |
+
if (axis_int[0] < 0) {
|
| 140 |
+
quit = true;
|
| 141 |
+
}
|
| 142 |
+
axis_t[0] = ((float)axis_int[0] - rayori[0]) / raydir[0];
|
| 143 |
+
}
|
| 144 |
+
} else if (axis_t[1] <= axis_t[2]) {
|
| 145 |
+
tnow = axis_t[1];
|
| 146 |
+
if (raydir[1] > 0) {
|
| 147 |
+
axis_int[1] += 1;
|
| 148 |
+
if (axis_int[1] >= p.voxel_dims[1]) {
|
| 149 |
+
quit = true;
|
| 150 |
+
}
|
| 151 |
+
axis_t[1] = ((float)(axis_int[1] + 1) - rayori[1]) / raydir[1];
|
| 152 |
+
} else {
|
| 153 |
+
axis_int[1] -= 1;
|
| 154 |
+
if (axis_int[1] < 0) {
|
| 155 |
+
quit = true;
|
| 156 |
+
}
|
| 157 |
+
axis_t[1] = ((float)axis_int[1] - rayori[1]) / raydir[1];
|
| 158 |
+
}
|
| 159 |
+
} else {
|
| 160 |
+
tnow = axis_t[2];
|
| 161 |
+
if (raydir[2] > 0) {
|
| 162 |
+
axis_int[2] += 1;
|
| 163 |
+
if (axis_int[2] >= p.voxel_dims[2]) {
|
| 164 |
+
quit = true;
|
| 165 |
+
}
|
| 166 |
+
axis_t[2] = ((float)(axis_int[2] + 1) - rayori[2]) / raydir[2];
|
| 167 |
+
} else {
|
| 168 |
+
axis_int[2] -= 1;
|
| 169 |
+
if (axis_int[2] < 0) {
|
| 170 |
+
quit = true;
|
| 171 |
+
}
|
| 172 |
+
axis_t[2] = ((float)axis_int[2] - rayori[2]) / raydir[2];
|
| 173 |
+
}
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
if (quit) {
|
| 177 |
+
break;
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
// Skip empty space
|
| 181 |
+
// Could there be deadlock if the ray direction is away from the world?
|
| 182 |
+
if (axis_int[0] < 0 || axis_int[0] >= p.voxel_dims[0] ||
|
| 183 |
+
axis_int[1] < 0 || axis_int[1] >= p.voxel_dims[1] ||
|
| 184 |
+
axis_int[2] < 0 || axis_int[2] >= p.voxel_dims[2]) {
|
| 185 |
+
continue;
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
// Test intersection using voxel grid
|
| 189 |
+
int64_t in_voxel_idx =
|
| 190 |
+
static_cast<int64_t>(axis_int[0]) * p.voxel_strides[0] +
|
| 191 |
+
static_cast<int64_t>(axis_int[1]) * p.voxel_strides[1] +
|
| 192 |
+
static_cast<int64_t>(axis_int[2]) * p.voxel_strides[2];
|
| 193 |
+
blk_id = in_voxel[in_voxel_idx];
|
| 194 |
+
if (blk_id == 0) {
|
| 195 |
+
continue;
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
// Now that there is an intersection
|
| 199 |
+
t = tnow;
|
| 200 |
+
// Calculate t2
|
| 201 |
+
/*
|
| 202 |
+
#pragma unroll
|
| 203 |
+
for (int i=0; i<3; i++) {
|
| 204 |
+
if (axis_t[i] <= axis_t[(i+1)%3] && axis_t[i] <= axis_t[(i+2)%3]) {
|
| 205 |
+
t2 = axis_t[i];
|
| 206 |
+
break;
|
| 207 |
+
}
|
| 208 |
+
}
|
| 209 |
+
*/
|
| 210 |
+
// Hand unroll
|
| 211 |
+
if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) {
|
| 212 |
+
t2 = axis_t[0];
|
| 213 |
+
} else if (axis_t[1] <= axis_t[2]) {
|
| 214 |
+
t2 = axis_t[1];
|
| 215 |
+
} else {
|
| 216 |
+
t2 = axis_t[2];
|
| 217 |
+
}
|
| 218 |
+
break;
|
| 219 |
+
} // while !quit (ray marching loop)
|
| 220 |
+
|
| 221 |
+
out_depth[pix_index * p.max_samples + cur_plane] = t;
|
| 222 |
+
out_depth[p.img_dims[0] * p.img_dims[1] * p.max_samples +
|
| 223 |
+
pix_index * p.max_samples + cur_plane] = t2;
|
| 224 |
+
out_voxel_id[pix_index * p.max_samples + cur_plane] = blk_id;
|
| 225 |
+
} // cur_plane
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
/*
|
| 229 |
+
out:
|
| 230 |
+
out_voxel_id: torch CUDA int32 [ img_dims[0], img_dims[1], max_samples, 1]
|
| 231 |
+
out_depth: torch CUDA float [2, img_dims[0], img_dims[1], max_samples, 1]
|
| 232 |
+
out_raydirs: torch CUDA float [ img_dims[0], img_dims[1], 1, 3]
|
| 233 |
+
in:
|
| 234 |
+
in_voxel: torch CUDA int32 [X, Y, Z] [40, 512, 512]
|
| 235 |
+
cam_ori: torch float [3]
|
| 236 |
+
cam_dir: torch float [3]
|
| 237 |
+
cam_up: torch float [3]
|
| 238 |
+
cam_f: float
|
| 239 |
+
cam_c: int [2]
|
| 240 |
+
img_dims: int [2]
|
| 241 |
+
max_samples: int
|
| 242 |
+
*/
|
| 243 |
+
std::vector<torch::Tensor> ray_voxel_intersection_perspective_cuda(
|
| 244 |
+
const torch::Tensor &in_voxel, const torch::Tensor &cam_ori,
|
| 245 |
+
const torch::Tensor &cam_dir, const torch::Tensor &cam_up, float cam_f,
|
| 246 |
+
const std::vector<float> &cam_c, const std::vector<int> &img_dims,
|
| 247 |
+
int max_samples) {
|
| 248 |
+
CHECK_CUDA(in_voxel);
|
| 249 |
+
|
| 250 |
+
int curDevice = -1;
|
| 251 |
+
cudaGetDevice(&curDevice);
|
| 252 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
|
| 253 |
+
torch::Device device = in_voxel.device();
|
| 254 |
+
|
| 255 |
+
assert(in_voxel.dtype() == torch::kInt32); // Minecraft compatibility
|
| 256 |
+
assert(in_voxel.dim() == 3);
|
| 257 |
+
assert(cam_ori.dtype() == torch::kFloat32);
|
| 258 |
+
assert(cam_ori.numel() == 3);
|
| 259 |
+
assert(cam_dir.dtype() == torch::kFloat32);
|
| 260 |
+
assert(cam_dir.numel() == 3);
|
| 261 |
+
assert(cam_up.dtype() == torch::kFloat32);
|
| 262 |
+
assert(cam_up.numel() == 3);
|
| 263 |
+
assert(img_dims.size() == 2);
|
| 264 |
+
|
| 265 |
+
RVIP_Params p;
|
| 266 |
+
|
| 267 |
+
// Calculate camera rays
|
| 268 |
+
const torch::Tensor cam_ori_c = cam_ori.cpu();
|
| 269 |
+
const torch::Tensor cam_dir_c = cam_dir.cpu();
|
| 270 |
+
const torch::Tensor cam_up_c = cam_up.cpu();
|
| 271 |
+
|
| 272 |
+
// Get the coordinate frame of camera space in world space
|
| 273 |
+
normalize<float, 3>(p.cam_fwd, cam_dir_c.data_ptr<float>());
|
| 274 |
+
cross<float>(p.cam_side, p.cam_fwd, cam_up_c.data_ptr<float>());
|
| 275 |
+
normalize<float, 3>(p.cam_side);
|
| 276 |
+
cross<float>(p.cam_up, p.cam_side, p.cam_fwd);
|
| 277 |
+
normalize<float, 3>(p.cam_up); // Not absolutely necessary as both vectors are
|
| 278 |
+
// normalized. But just in case...
|
| 279 |
+
|
| 280 |
+
copyarr<float, 3>(p.cam_ori, cam_ori_c.data_ptr<float>());
|
| 281 |
+
|
| 282 |
+
p.cam_f = cam_f;
|
| 283 |
+
p.cam_c[0] = cam_c[0];
|
| 284 |
+
p.cam_c[1] = cam_c[1];
|
| 285 |
+
p.max_samples = max_samples;
|
| 286 |
+
|
| 287 |
+
p.voxel_dims[0] = in_voxel.size(0);
|
| 288 |
+
p.voxel_dims[1] = in_voxel.size(1);
|
| 289 |
+
p.voxel_dims[2] = in_voxel.size(2);
|
| 290 |
+
p.voxel_strides[0] = in_voxel.stride(0);
|
| 291 |
+
p.voxel_strides[1] = in_voxel.stride(1);
|
| 292 |
+
p.voxel_strides[2] = in_voxel.stride(2);
|
| 293 |
+
|
| 294 |
+
p.img_dims[0] = img_dims[0];
|
| 295 |
+
p.img_dims[1] = img_dims[1];
|
| 296 |
+
|
| 297 |
+
// Create output tensors
|
| 298 |
+
// For Minecraft Seg Mask
|
| 299 |
+
torch::Tensor out_voxel_id =
|
| 300 |
+
torch::empty({p.img_dims[0], p.img_dims[1], p.max_samples, 1},
|
| 301 |
+
torch::TensorOptions().dtype(torch::kInt32).device(device));
|
| 302 |
+
|
| 303 |
+
torch::Tensor out_depth;
|
| 304 |
+
// Produce two sets of localcoords, one for entry point, the other one for
|
| 305 |
+
// exit point. They share the same corner_ids.
|
| 306 |
+
out_depth = torch::empty(
|
| 307 |
+
{2, p.img_dims[0], p.img_dims[1], p.max_samples, 1},
|
| 308 |
+
torch::TensorOptions().dtype(torch::kFloat32).device(device));
|
| 309 |
+
|
| 310 |
+
torch::Tensor out_raydirs = torch::empty({p.img_dims[0], p.img_dims[1], 1, 3},
|
| 311 |
+
torch::TensorOptions()
|
| 312 |
+
.dtype(torch::kFloat32)
|
| 313 |
+
.device(device)
|
| 314 |
+
.requires_grad(false));
|
| 315 |
+
|
| 316 |
+
const int TILE_DIM = 8;
|
| 317 |
+
dim3 dimGrid((p.img_dims[1] + TILE_DIM - 1) / TILE_DIM,
|
| 318 |
+
(p.img_dims[0] + TILE_DIM - 1) / TILE_DIM, 1);
|
| 319 |
+
dim3 dimBlock(TILE_DIM, TILE_DIM, 1);
|
| 320 |
+
|
| 321 |
+
ray_voxel_intersection_perspective_kernel<TILE_DIM>
|
| 322 |
+
<<<dimGrid, dimBlock, 0, stream>>>(
|
| 323 |
+
out_voxel_id.data_ptr<int32_t>(), out_depth.data_ptr<float>(),
|
| 324 |
+
out_raydirs.data_ptr<float>(), in_voxel.data_ptr<int32_t>(), p);
|
| 325 |
+
|
| 326 |
+
cudaError_t err = cudaGetLastError();
|
| 327 |
+
if (err != cudaSuccess) {
|
| 328 |
+
printf("Error in ray_voxel_intersection_perspective_kernel: %s\n",
|
| 329 |
+
cudaGetErrorString(err));
|
| 330 |
+
}
|
| 331 |
+
return {out_voxel_id, out_depth, out_raydirs};
|
| 332 |
+
}
|
gaussiancity/extensions/voxlib/setup.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# @File: setup.py
|
| 4 |
+
# @Author: NVIDIA Corporation
|
| 5 |
+
# @Date: 2021-10-13 00:00:00
|
| 6 |
+
# @Last Modified by: Haozhe Xie
|
| 7 |
+
# @Last Modified at: 2024-10-13 03:00:47
|
| 8 |
+
# @Email: root@haozhexie.com
|
| 9 |
+
|
| 10 |
+
from setuptools import setup
|
| 11 |
+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
| 12 |
+
|
| 13 |
+
cxx_args = ["-fopenmp"]
|
| 14 |
+
nvcc_args = []
|
| 15 |
+
|
| 16 |
+
setup(
|
| 17 |
+
name="voxlib_ext",
|
| 18 |
+
version="3.0.0",
|
| 19 |
+
ext_modules=[
|
| 20 |
+
CUDAExtension(
|
| 21 |
+
"voxlib",
|
| 22 |
+
[
|
| 23 |
+
"bindings.cpp",
|
| 24 |
+
"ray_voxel_intersection.cu",
|
| 25 |
+
"points_to_volume.cu",
|
| 26 |
+
"maps_to_volume.cu",
|
| 27 |
+
],
|
| 28 |
+
extra_compile_args={"cxx": cxx_args, "nvcc": nvcc_args},
|
| 29 |
+
)
|
| 30 |
+
],
|
| 31 |
+
cmdclass={"build_ext": BuildExtension},
|
| 32 |
+
)
|
gaussiancity/extensions/voxlib/voxlib_common.h
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
//
|
| 3 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
| 4 |
+
// To view a copy of this license, check out LICENSE.md
|
| 5 |
+
#ifndef VOXLIB_COMMON_H
|
| 6 |
+
#define VOXLIB_COMMON_H
|
| 7 |
+
|
| 8 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
| 9 |
+
#define CHECK_CONTIGUOUS(x) \
|
| 10 |
+
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
| 11 |
+
#define CHECK_INPUT(x) \
|
| 12 |
+
CHECK_CUDA(x); \
|
| 13 |
+
CHECK_CONTIGUOUS(x)
|
| 14 |
+
#define CHECK_CPU(x) \
|
| 15 |
+
TORCH_CHECK(x.device().is_cpu(), #x " must be a CPU tensor")
|
| 16 |
+
|
| 17 |
+
#include <cuda.h>
|
| 18 |
+
#include <cuda_runtime.h>
|
| 19 |
+
// CUDA vector math functions
|
| 20 |
+
__host__ __device__ __forceinline__ int floor_div(int a, int b) {
|
| 21 |
+
int c = a / b;
|
| 22 |
+
|
| 23 |
+
if (c * b > a) {
|
| 24 |
+
c--;
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
return c;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
template <typename scalar_t>
|
| 31 |
+
__host__ __forceinline__ void cross(scalar_t *r, const scalar_t *a,
|
| 32 |
+
const scalar_t *b) {
|
| 33 |
+
r[0] = a[1] * b[2] - a[2] * b[1];
|
| 34 |
+
r[1] = a[2] * b[0] - a[0] * b[2];
|
| 35 |
+
r[2] = a[0] * b[1] - a[1] * b[0];
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
__device__ __host__ __forceinline__ float dot(const float *a, const float *b) {
|
| 39 |
+
return a[0] * b[0] + a[1] * b[1] + a[2] * b[2];
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
template <typename scalar_t, int ndim>
|
| 43 |
+
__device__ __host__ __forceinline__ void copyarr(scalar_t *r,
|
| 44 |
+
const scalar_t *a) {
|
| 45 |
+
#pragma unroll
|
| 46 |
+
for (int i = 0; i < ndim; i++) {
|
| 47 |
+
r[i] = a[i];
|
| 48 |
+
}
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
// TODO: use rsqrt to speed up
|
| 52 |
+
// inplace version
|
| 53 |
+
template <typename scalar_t, int ndim>
|
| 54 |
+
__device__ __host__ __forceinline__ void normalize(scalar_t *a) {
|
| 55 |
+
scalar_t vec_len = 0.0f;
|
| 56 |
+
#pragma unroll
|
| 57 |
+
for (int i = 0; i < ndim; i++) {
|
| 58 |
+
vec_len += a[i] * a[i];
|
| 59 |
+
}
|
| 60 |
+
vec_len = sqrtf(vec_len);
|
| 61 |
+
#pragma unroll
|
| 62 |
+
for (int i = 0; i < ndim; i++) {
|
| 63 |
+
a[i] /= vec_len;
|
| 64 |
+
}
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
// normalize + copy
|
| 68 |
+
template <typename scalar_t, int ndim>
|
| 69 |
+
__device__ __host__ __forceinline__ void normalize(scalar_t *r,
|
| 70 |
+
const scalar_t *a) {
|
| 71 |
+
scalar_t vec_len = 0.0f;
|
| 72 |
+
#pragma unroll
|
| 73 |
+
for (int i = 0; i < ndim; i++) {
|
| 74 |
+
vec_len += a[i] * a[i];
|
| 75 |
+
}
|
| 76 |
+
vec_len = sqrtf(vec_len);
|
| 77 |
+
#pragma unroll
|
| 78 |
+
for (int i = 0; i < ndim; i++) {
|
| 79 |
+
r[i] = a[i] / vec_len;
|
| 80 |
+
}
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
#endif // VOXLIB_COMMON_H
|
gaussiancity/generator.py
ADDED
|
@@ -0,0 +1,536 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# @File: generator.py
|
| 4 |
+
# @Author: Haozhe Xie
|
| 5 |
+
# @Date: 2024-03-09 20:36:52
|
| 6 |
+
# @Last Modified by: Haozhe Xie
|
| 7 |
+
# @Last Modified at: 2024-09-23 20:49:35
|
| 8 |
+
# @Email: root@haozhexie.com
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
import extensions.grid_encoder
|
| 15 |
+
import gaussiancity.pt_v3
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Generator(torch.nn.Module):
|
| 19 |
+
def __init__(self, cfg, n_classes, proj_size):
|
| 20 |
+
super(Generator, self).__init__()
|
| 21 |
+
self.cfg = cfg
|
| 22 |
+
self.n_classes = n_classes
|
| 23 |
+
if cfg.ENCODER == "GLOBAL":
|
| 24 |
+
self.proj_encoder = GlobalEncoder(
|
| 25 |
+
n_classes, cfg.GLOBAL_ENCODER_N_BLOCKS, cfg.ENCODER_OUT_DIM - 3
|
| 26 |
+
)
|
| 27 |
+
elif cfg.ENCODER == "LOCAL":
|
| 28 |
+
self.proj_encoder = LocalEncoder(n_classes, cfg.ENCODER_OUT_DIM - 3)
|
| 29 |
+
elif cfg.ENCODER is None:
|
| 30 |
+
self.proj_encoder = None
|
| 31 |
+
assert cfg.ENCODER_OUT_DIM == 3
|
| 32 |
+
else:
|
| 33 |
+
raise ValueError("Unknown encoder: %s" % cfg.ENCODER)
|
| 34 |
+
|
| 35 |
+
if cfg.POS_EMD == "HASH_GRID":
|
| 36 |
+
pt_feat_dim = cfg.HASH_GRID_N_LEVELS * cfg.HASH_GRID_LEVEL_DIM
|
| 37 |
+
self.pos_encoder = extensions.grid_encoder.GridEncoder(
|
| 38 |
+
in_channels=cfg.ENCODER_OUT_DIM,
|
| 39 |
+
desired_resolution=proj_size,
|
| 40 |
+
n_levels=cfg.HASH_GRID_N_LEVELS,
|
| 41 |
+
lvl_channels=cfg.HASH_GRID_LEVEL_DIM,
|
| 42 |
+
)
|
| 43 |
+
elif cfg.POS_EMD == "SIN_COS":
|
| 44 |
+
pt_feat_dim = 2 * cfg.ENCODER_OUT_DIM * cfg.SIN_COS_FREQ_BENDS
|
| 45 |
+
self.pos_encoder = SinCosEncoder(cfg.SIN_COS_FREQ_BENDS)
|
| 46 |
+
else:
|
| 47 |
+
raise ValueError("Unknown positional encoder: %s" % cfg.POS_EMD)
|
| 48 |
+
|
| 49 |
+
if cfg.PTV3.ENABLED:
|
| 50 |
+
self.pt_net = gaussiancity.pt_v3.PointTransformerV3(
|
| 51 |
+
in_channels=pt_feat_dim,
|
| 52 |
+
order=cfg.PTV3.ORDER,
|
| 53 |
+
stride=cfg.PTV3.STRIDE,
|
| 54 |
+
enc_depths=cfg.PTV3.ENC_DEPTHS,
|
| 55 |
+
enc_channels=cfg.PTV3.ENC_CHANNELS,
|
| 56 |
+
enc_num_head=cfg.PTV3.ENC_N_HEAD,
|
| 57 |
+
enc_patch_size=cfg.PTV3.ENC_PATCH_SIZE,
|
| 58 |
+
dec_depths=cfg.PTV3.DEC_DEPTHS,
|
| 59 |
+
dec_channels=cfg.PTV3.DEC_CHANNELS,
|
| 60 |
+
dec_num_head=cfg.PTV3.DEC_N_HEAD,
|
| 61 |
+
dec_patch_size=cfg.PTV3.DEC_PATCH_SIZE,
|
| 62 |
+
enable_flash=cfg.PTV3.ENABLE_FLASH_ATTN,
|
| 63 |
+
)
|
| 64 |
+
pt_feat_dim += cfg.PTV3.DEC_CHANNELS[0]
|
| 65 |
+
else:
|
| 66 |
+
self.pt_net = None
|
| 67 |
+
|
| 68 |
+
self.ga_mlp = GaussianAttrMLP(
|
| 69 |
+
n_classes,
|
| 70 |
+
pt_feat_dim,
|
| 71 |
+
cfg.Z_DIM,
|
| 72 |
+
cfg.MLP_HIDDEN_DIM,
|
| 73 |
+
cfg.MLP_N_SHARED_LAYERS,
|
| 74 |
+
cfg.ATTR_FACTORS,
|
| 75 |
+
cfg.ATTR_N_LAYERS,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
def forward(self, proj_uv, rel_xyz, batch_idx, onehots, z, proj_hf, proj_seg):
|
| 79 |
+
# Ref: https://github.com/hzxie/CityDreamer/blob/master/models/gancraft.py#L381
|
| 80 |
+
if self.cfg.ENCODER == "GLOBAL":
|
| 81 |
+
proj_feat = self.proj_encoder(proj_hf, proj_seg)
|
| 82 |
+
pt_feat = proj_feat.unsqueeze(dim=1).repeat(1, proj_uv.size(1), 1)
|
| 83 |
+
elif self.cfg.ENCODER == "LOCAL":
|
| 84 |
+
proj_feat = self.proj_encoder(proj_hf, proj_seg)
|
| 85 |
+
pt_feat = (
|
| 86 |
+
F.grid_sample(proj_feat, proj_uv.unsqueeze(dim=1), align_corners=True)
|
| 87 |
+
.squeeze(dim=2)
|
| 88 |
+
.permute(0, 2, 1)
|
| 89 |
+
)
|
| 90 |
+
elif self.cfg.ENCODER is None:
|
| 91 |
+
pt_feat = torch.empty(
|
| 92 |
+
rel_xyz.size(0), rel_xyz.size(1), 0, device=proj_uv.device
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# print(pt_feat.size()) # torch.Size([B, n_pts, cfg.ENCODER_OUT_DIM - 3])
|
| 96 |
+
pt_feat = torch.cat([pt_feat, rel_xyz], dim=2)
|
| 97 |
+
# print(pt_feat.size()) # torch.Size([B, n_pts, cfg.ENCODER_OUT_DIM])
|
| 98 |
+
pt_feat1 = self.pos_encoder(pt_feat)
|
| 99 |
+
# print(pt_feat1.size()) # torch.Size([B, n_pts, pt_feat_dim])
|
| 100 |
+
if self.pt_net is None:
|
| 101 |
+
pt_feat2 = torch.empty(
|
| 102 |
+
rel_xyz.size(0), rel_xyz.size(1), 0, device=proj_uv.device
|
| 103 |
+
)
|
| 104 |
+
else:
|
| 105 |
+
pt_feat2 = self.pt_net(batch_idx, pt_feat1, rel_xyz)
|
| 106 |
+
|
| 107 |
+
# print(pt_feat2.size()) # torch.Size([B, n_pts, pt_feat_dim])
|
| 108 |
+
return self.ga_mlp(torch.cat([pt_feat1, pt_feat2], dim=-1), onehots, z)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class GlobalEncoder(torch.nn.Module):
|
| 112 |
+
def __init__(self, n_classes, n_blocks, out_channels):
|
| 113 |
+
super(GlobalEncoder, self).__init__()
|
| 114 |
+
self.hf_conv = torch.nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1)
|
| 115 |
+
self.seg_conv = torch.nn.Conv2d(
|
| 116 |
+
n_classes,
|
| 117 |
+
8,
|
| 118 |
+
kernel_size=3,
|
| 119 |
+
stride=2,
|
| 120 |
+
padding=1,
|
| 121 |
+
)
|
| 122 |
+
conv_blocks = []
|
| 123 |
+
cur_hidden_channels = 16
|
| 124 |
+
for _ in range(1, n_blocks):
|
| 125 |
+
conv_blocks.append(
|
| 126 |
+
SRTConvBlock(in_channels=cur_hidden_channels, out_channels=None)
|
| 127 |
+
)
|
| 128 |
+
cur_hidden_channels *= 2
|
| 129 |
+
|
| 130 |
+
self.conv_blocks = torch.nn.Sequential(*conv_blocks)
|
| 131 |
+
self.fc1 = torch.nn.Linear(cur_hidden_channels, 16)
|
| 132 |
+
self.fc2 = torch.nn.Linear(16, out_channels)
|
| 133 |
+
self.act = torch.nn.LeakyReLU(0.2)
|
| 134 |
+
|
| 135 |
+
def forward(self, proj_hf, proj_seg):
|
| 136 |
+
hf = self.act(self.hf_conv(proj_hf))
|
| 137 |
+
seg = self.act(self.seg_conv(proj_seg))
|
| 138 |
+
out = torch.cat([hf, seg], dim=1)
|
| 139 |
+
for layer in self.conv_blocks:
|
| 140 |
+
out = self.act(layer(out))
|
| 141 |
+
|
| 142 |
+
out = out.permute(0, 2, 3, 1)
|
| 143 |
+
out = torch.mean(out.reshape(out.shape[0], -1, out.shape[-1]), dim=1)
|
| 144 |
+
cond = self.act(self.fc1(out))
|
| 145 |
+
cond = torch.tanh(self.fc2(cond))
|
| 146 |
+
return cond
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class LocalEncoder(torch.nn.Module):
|
| 150 |
+
def __init__(self, n_classes, out_channels):
|
| 151 |
+
super(LocalEncoder, self).__init__()
|
| 152 |
+
self.hf_conv = torch.nn.Conv2d(1, 32, kernel_size=7, stride=2, padding=3)
|
| 153 |
+
self.seg_conv = torch.nn.Conv2d(
|
| 154 |
+
n_classes, 32, kernel_size=7, stride=2, padding=3
|
| 155 |
+
)
|
| 156 |
+
self.bn1 = torch.nn.GroupNorm(32, 64)
|
| 157 |
+
self.conv2 = ResConvBlock(64, 128)
|
| 158 |
+
self.conv3 = ResConvBlock(128, 256)
|
| 159 |
+
self.conv4 = ResConvBlock(256, 512)
|
| 160 |
+
self.dconv5 = torch.nn.ConvTranspose2d(
|
| 161 |
+
512, 128, kernel_size=4, stride=2, padding=1
|
| 162 |
+
)
|
| 163 |
+
self.dconv6 = torch.nn.ConvTranspose2d(
|
| 164 |
+
128, 32, kernel_size=4, stride=2, padding=1
|
| 165 |
+
)
|
| 166 |
+
self.dconv7 = torch.nn.Conv2d(32, out_channels, kernel_size=1)
|
| 167 |
+
|
| 168 |
+
def forward(self, proj_hf, proj_seg):
|
| 169 |
+
hf = self.hf_conv(proj_hf)
|
| 170 |
+
seg = self.seg_conv(proj_seg)
|
| 171 |
+
out = F.relu(self.bn1(torch.cat([hf, seg], dim=1)), inplace=True)
|
| 172 |
+
# print(out.size()) # torch.Size([N, 64, H/2, W/2])
|
| 173 |
+
out = F.avg_pool2d(self.conv2(out), 2, stride=2)
|
| 174 |
+
# print(out.size()) # torch.Size([N, 128, H/4, W/4])
|
| 175 |
+
out = self.conv3(out)
|
| 176 |
+
# print(out.size()) # torch.Size([N, 256, H/4, W/4])
|
| 177 |
+
out = self.conv4(out)
|
| 178 |
+
# print(out.size()) # torch.Size([N, 512, H/4, W/4])
|
| 179 |
+
out = self.dconv5(out)
|
| 180 |
+
# print(out.size()) # torch.Size([N, 128, H/2, W/2])
|
| 181 |
+
out = self.dconv6(out)
|
| 182 |
+
# print(out.size()) # torch.Size([N, 32, H, W])
|
| 183 |
+
out = self.dconv7(out)
|
| 184 |
+
# print(out.size()) # torch.Size([N, OUT_DIM - 1, H, W])
|
| 185 |
+
return torch.tanh(out)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class SRTConvBlock(torch.nn.Module):
|
| 189 |
+
def __init__(self, in_channels, hidden_channels=None, out_channels=None):
|
| 190 |
+
super(SRTConvBlock, self).__init__()
|
| 191 |
+
if hidden_channels is None:
|
| 192 |
+
hidden_channels = in_channels
|
| 193 |
+
if out_channels is None:
|
| 194 |
+
out_channels = 2 * hidden_channels
|
| 195 |
+
|
| 196 |
+
self.layers = torch.nn.Sequential(
|
| 197 |
+
torch.nn.Conv2d(
|
| 198 |
+
in_channels,
|
| 199 |
+
hidden_channels,
|
| 200 |
+
stride=1,
|
| 201 |
+
kernel_size=3,
|
| 202 |
+
padding=1,
|
| 203 |
+
bias=False,
|
| 204 |
+
),
|
| 205 |
+
torch.nn.ReLU(),
|
| 206 |
+
torch.nn.Conv2d(
|
| 207 |
+
hidden_channels,
|
| 208 |
+
out_channels,
|
| 209 |
+
stride=2,
|
| 210 |
+
kernel_size=3,
|
| 211 |
+
padding=1,
|
| 212 |
+
bias=False,
|
| 213 |
+
),
|
| 214 |
+
torch.nn.ReLU(),
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
def forward(self, x):
|
| 218 |
+
return self.layers(x)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class ResConvBlock(torch.nn.Module):
|
| 222 |
+
def __init__(self, in_channels, out_channels, bias=False):
|
| 223 |
+
super(ResConvBlock, self).__init__()
|
| 224 |
+
# conv3x3(in_planes, int(out_planes / 2))
|
| 225 |
+
self.conv1 = torch.nn.Conv2d(
|
| 226 |
+
in_channels,
|
| 227 |
+
out_channels // 2,
|
| 228 |
+
kernel_size=3,
|
| 229 |
+
stride=1,
|
| 230 |
+
padding=1,
|
| 231 |
+
bias=bias,
|
| 232 |
+
)
|
| 233 |
+
# conv3x3(int(out_planes / 2), int(out_planes / 4))
|
| 234 |
+
self.conv2 = torch.nn.Conv2d(
|
| 235 |
+
out_channels // 2,
|
| 236 |
+
out_channels // 4,
|
| 237 |
+
kernel_size=3,
|
| 238 |
+
stride=1,
|
| 239 |
+
padding=1,
|
| 240 |
+
bias=bias,
|
| 241 |
+
)
|
| 242 |
+
# conv3x3(int(out_planes / 4), int(out_planes / 4))
|
| 243 |
+
self.conv3 = torch.nn.Conv2d(
|
| 244 |
+
out_channels // 4,
|
| 245 |
+
out_channels // 4,
|
| 246 |
+
kernel_size=3,
|
| 247 |
+
stride=1,
|
| 248 |
+
padding=1,
|
| 249 |
+
bias=bias,
|
| 250 |
+
)
|
| 251 |
+
self.bn1 = torch.nn.GroupNorm(32, in_channels)
|
| 252 |
+
self.bn2 = torch.nn.GroupNorm(32, out_channels // 2)
|
| 253 |
+
self.bn3 = torch.nn.GroupNorm(32, out_channels // 4)
|
| 254 |
+
self.bn4 = torch.nn.GroupNorm(32, in_channels)
|
| 255 |
+
|
| 256 |
+
if in_channels != out_channels:
|
| 257 |
+
self.downsample = torch.nn.Sequential(
|
| 258 |
+
self.bn4,
|
| 259 |
+
torch.nn.ReLU(True),
|
| 260 |
+
torch.nn.Conv2d(
|
| 261 |
+
in_channels, out_channels, kernel_size=1, stride=1, bias=False
|
| 262 |
+
),
|
| 263 |
+
)
|
| 264 |
+
else:
|
| 265 |
+
self.downsample = None
|
| 266 |
+
|
| 267 |
+
def forward(self, x):
|
| 268 |
+
residual = x
|
| 269 |
+
# print(residual.size()) # torch.Size([N, 64, H, W])
|
| 270 |
+
out1 = self.bn1(x)
|
| 271 |
+
out1 = F.relu(out1, True)
|
| 272 |
+
out1 = self.conv1(out1)
|
| 273 |
+
# print(out1.size()) # torch.Size([N, 64, H, W])
|
| 274 |
+
out2 = self.bn2(out1)
|
| 275 |
+
out2 = F.relu(out2, True)
|
| 276 |
+
out2 = self.conv2(out2)
|
| 277 |
+
# print(out2.size()) # torch.Size([N, 32, H, W])
|
| 278 |
+
out3 = self.bn3(out2)
|
| 279 |
+
out3 = F.relu(out3, True)
|
| 280 |
+
out3 = self.conv3(out3)
|
| 281 |
+
# print(out3.size()) # torch.Size([N, 32, H, W])
|
| 282 |
+
out3 = torch.cat((out1, out2, out3), dim=1)
|
| 283 |
+
# print(out3.size()) # torch.Size([N, 128, H, W])
|
| 284 |
+
if self.downsample is not None:
|
| 285 |
+
residual = self.downsample(residual)
|
| 286 |
+
# print(residual.size()) # torch.Size([N, 128, H, W])
|
| 287 |
+
out3 += residual
|
| 288 |
+
return out3
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class SinCosEncoder(torch.nn.Module):
|
| 292 |
+
def __init__(self, n_freq_bands=8):
|
| 293 |
+
super(SinCosEncoder, self).__init__()
|
| 294 |
+
self.freq_bands = 2.0 ** torch.linspace(
|
| 295 |
+
0,
|
| 296 |
+
n_freq_bands - 1,
|
| 297 |
+
steps=n_freq_bands,
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
def forward(self, features):
|
| 301 |
+
cord_sin = torch.cat(
|
| 302 |
+
[torch.sin(features * fb) for fb in self.freq_bands], dim=-1
|
| 303 |
+
)
|
| 304 |
+
cord_cos = torch.cat(
|
| 305 |
+
[torch.cos(features * fb) for fb in self.freq_bands], dim=-1
|
| 306 |
+
)
|
| 307 |
+
return torch.cat([cord_sin, cord_cos], dim=-1)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class GaussianAttrMLP(torch.nn.Module):
|
| 311 |
+
r"""MLP with affine modulation."""
|
| 312 |
+
|
| 313 |
+
def __init__(
|
| 314 |
+
self,
|
| 315 |
+
n_classes,
|
| 316 |
+
in_dim,
|
| 317 |
+
z_dim,
|
| 318 |
+
hidden_dim,
|
| 319 |
+
n_shared_layers,
|
| 320 |
+
factors={},
|
| 321 |
+
n_layers={},
|
| 322 |
+
):
|
| 323 |
+
super(GaussianAttrMLP, self).__init__()
|
| 324 |
+
self.factors = factors
|
| 325 |
+
self.n_layers = n_layers
|
| 326 |
+
self.n_shared_layers = n_shared_layers
|
| 327 |
+
self.act = torch.nn.LeakyReLU(negative_slope=0.2)
|
| 328 |
+
self.fc_m_a = torch.nn.Linear(
|
| 329 |
+
n_classes,
|
| 330 |
+
hidden_dim,
|
| 331 |
+
bias=False,
|
| 332 |
+
)
|
| 333 |
+
self.fc_1 = torch.nn.Linear(
|
| 334 |
+
in_dim,
|
| 335 |
+
hidden_dim,
|
| 336 |
+
)
|
| 337 |
+
for i in range(2, n_shared_layers + 1):
|
| 338 |
+
setattr(
|
| 339 |
+
self,
|
| 340 |
+
"fc_%d" % i,
|
| 341 |
+
(
|
| 342 |
+
ModLinear(
|
| 343 |
+
hidden_dim,
|
| 344 |
+
hidden_dim,
|
| 345 |
+
z_dim,
|
| 346 |
+
bias=False,
|
| 347 |
+
mod_bias=True,
|
| 348 |
+
output_mode=True,
|
| 349 |
+
)
|
| 350 |
+
if z_dim is not None
|
| 351 |
+
else torch.nn.Linear(hidden_dim, hidden_dim)
|
| 352 |
+
),
|
| 353 |
+
)
|
| 354 |
+
for k in factors.keys():
|
| 355 |
+
assert k in ["xyz", "rgb", "scale", "opacity"], "Unknwon key: %s" % k
|
| 356 |
+
for i in range(n_layers[k]):
|
| 357 |
+
setattr(
|
| 358 |
+
self,
|
| 359 |
+
"fc_%d_%s_%d" % (n_shared_layers + 1, k, i),
|
| 360 |
+
(
|
| 361 |
+
ModLinear(
|
| 362 |
+
hidden_dim,
|
| 363 |
+
hidden_dim,
|
| 364 |
+
z_dim,
|
| 365 |
+
bias=False,
|
| 366 |
+
mod_bias=True,
|
| 367 |
+
output_mode=True,
|
| 368 |
+
)
|
| 369 |
+
if z_dim is not None
|
| 370 |
+
else torch.nn.Linear(hidden_dim, hidden_dim)
|
| 371 |
+
),
|
| 372 |
+
)
|
| 373 |
+
setattr(
|
| 374 |
+
self,
|
| 375 |
+
"fc_out_%s" % k,
|
| 376 |
+
torch.nn.Linear(
|
| 377 |
+
hidden_dim,
|
| 378 |
+
1 if k == "opacity" else 3,
|
| 379 |
+
),
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
def forward(self, pt_feat, onehots, zs):
|
| 383 |
+
b, n, _ = pt_feat.size()
|
| 384 |
+
|
| 385 |
+
f = self.fc_1(pt_feat)
|
| 386 |
+
f = f + self.fc_m_a(onehots)
|
| 387 |
+
f = self.act(f)
|
| 388 |
+
if zs is None:
|
| 389 |
+
output = self._instance_forward(f)
|
| 390 |
+
else:
|
| 391 |
+
output = {
|
| 392 |
+
k: torch.zeros(b, n, 1 if k == "opacity" else 3, device=pt_feat.device)
|
| 393 |
+
for k in self.factors.keys()
|
| 394 |
+
}
|
| 395 |
+
for v in zs.values():
|
| 396 |
+
z = v["z"]
|
| 397 |
+
idx = v["idx"]
|
| 398 |
+
_output = self._instance_forward(f[idx].unsqueeze(dim=0), z)
|
| 399 |
+
for k, v in _output.items():
|
| 400 |
+
output[k][idx] = v
|
| 401 |
+
|
| 402 |
+
return output
|
| 403 |
+
|
| 404 |
+
def _instance_forward(self, f, z=None):
|
| 405 |
+
for i in range(2, self.n_shared_layers + 1):
|
| 406 |
+
fc = getattr(self, "fc_%d" % i)
|
| 407 |
+
f = self.act(fc(f, z) if z is not None else fc(f))
|
| 408 |
+
|
| 409 |
+
output = {}
|
| 410 |
+
for k in self.factors.keys():
|
| 411 |
+
_f = f.clone()
|
| 412 |
+
for i in range(self.n_layers[k]):
|
| 413 |
+
_fc = getattr(self, "fc_%d_%s_%d" % (self.n_shared_layers + 1, k, i))
|
| 414 |
+
_f = self.act(_fc(_f, z) if z is not None else _fc(f))
|
| 415 |
+
|
| 416 |
+
fc_out = getattr(self, "fc_out_%s" % k)
|
| 417 |
+
output[k] = fc_out(_f)
|
| 418 |
+
|
| 419 |
+
if "xyz" in self.factors:
|
| 420 |
+
output["xyz"] = (torch.sigmoid(output["xyz"]) - 0.5) * self.factors["xyz"]
|
| 421 |
+
if "rgb" in self.factors:
|
| 422 |
+
output["rgb"] = (torch.sigmoid(output["rgb"]) - 0.5) * self.factors["rgb"]
|
| 423 |
+
if "scale" in self.factors:
|
| 424 |
+
output["scale"] = 1 + output["scale"].clamp(-1, 1) * self.factors["scale"]
|
| 425 |
+
if "opacity" in self.factors:
|
| 426 |
+
output["opacity"] = torch.sigmoid(output["opacity"]) * self.factors[
|
| 427 |
+
"opacity"
|
| 428 |
+
] + (1 - self.factors["opacity"])
|
| 429 |
+
|
| 430 |
+
return output
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
class ModLinear(torch.nn.Module):
|
| 434 |
+
r"""Linear layer with affine modulation (Based on StyleGAN2 mod demod).
|
| 435 |
+
Equivalent to affine modulation following linear, but faster when the same modulation parameters are shared across
|
| 436 |
+
multiple inputs.
|
| 437 |
+
Args:
|
| 438 |
+
in_features (int): Number of input features.
|
| 439 |
+
out_features (int): Number of output features.
|
| 440 |
+
style_features (int): Number of style features.
|
| 441 |
+
bias (bool): Apply additive bias before the activation function?
|
| 442 |
+
mod_bias (bool): Whether to modulate bias.
|
| 443 |
+
output_mode (bool): If True, modulate output instead of input.
|
| 444 |
+
weight_gain (float): Initialization gain
|
| 445 |
+
"""
|
| 446 |
+
|
| 447 |
+
def __init__(
|
| 448 |
+
self,
|
| 449 |
+
in_features,
|
| 450 |
+
out_features,
|
| 451 |
+
style_features,
|
| 452 |
+
bias=True,
|
| 453 |
+
mod_bias=True,
|
| 454 |
+
output_mode=False,
|
| 455 |
+
weight_gain=1,
|
| 456 |
+
bias_init=0,
|
| 457 |
+
):
|
| 458 |
+
super(ModLinear, self).__init__()
|
| 459 |
+
weight_gain = weight_gain / np.sqrt(in_features)
|
| 460 |
+
self.weight = torch.nn.Parameter(
|
| 461 |
+
torch.randn([out_features, in_features]) * weight_gain
|
| 462 |
+
)
|
| 463 |
+
self.bias = (
|
| 464 |
+
torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
|
| 465 |
+
if bias
|
| 466 |
+
else None
|
| 467 |
+
)
|
| 468 |
+
self.weight_alpha = torch.nn.Parameter(
|
| 469 |
+
torch.randn([in_features, style_features]) / np.sqrt(style_features)
|
| 470 |
+
)
|
| 471 |
+
self.bias_alpha = torch.nn.Parameter(
|
| 472 |
+
torch.full([in_features], 1, dtype=torch.float)
|
| 473 |
+
) # init to 1
|
| 474 |
+
self.weight_beta = None
|
| 475 |
+
self.bias_beta = None
|
| 476 |
+
self.mod_bias = mod_bias
|
| 477 |
+
self.output_mode = output_mode
|
| 478 |
+
if mod_bias:
|
| 479 |
+
if output_mode:
|
| 480 |
+
mod_bias_dims = out_features
|
| 481 |
+
else:
|
| 482 |
+
mod_bias_dims = in_features
|
| 483 |
+
self.weight_beta = torch.nn.Parameter(
|
| 484 |
+
torch.randn([mod_bias_dims, style_features]) / np.sqrt(style_features)
|
| 485 |
+
)
|
| 486 |
+
self.bias_beta = torch.nn.Parameter(
|
| 487 |
+
torch.full([mod_bias_dims], 0, dtype=torch.float)
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
@staticmethod
|
| 491 |
+
def _linear_f(x, w, b):
|
| 492 |
+
w = w.to(x.dtype)
|
| 493 |
+
x_shape = x.shape
|
| 494 |
+
x = x.reshape(-1, x_shape[-1])
|
| 495 |
+
if b is not None:
|
| 496 |
+
b = b.to(x.dtype)
|
| 497 |
+
x = torch.addmm(b.unsqueeze(0), x, w.t())
|
| 498 |
+
else:
|
| 499 |
+
x = x.matmul(w.t())
|
| 500 |
+
x = x.reshape(*x_shape[:-1], -1)
|
| 501 |
+
return x
|
| 502 |
+
|
| 503 |
+
# x: B, ... , Cin
|
| 504 |
+
# z: B, ... , Cz
|
| 505 |
+
def forward(self, x, z):
|
| 506 |
+
x_shape = x.shape
|
| 507 |
+
z_shape = z.shape
|
| 508 |
+
x = x.reshape(x_shape[0], -1, x_shape[-1])
|
| 509 |
+
z = z.reshape(z_shape[0], -1, z_shape[-1])
|
| 510 |
+
|
| 511 |
+
alpha = self._linear_f(z, self.weight_alpha, self.bias_alpha) # [B, ..., I]
|
| 512 |
+
w = self.weight.to(x.dtype) # [O I]
|
| 513 |
+
w = w.unsqueeze(0) * alpha
|
| 514 |
+
|
| 515 |
+
if self.mod_bias:
|
| 516 |
+
beta = self._linear_f(z, self.weight_beta, self.bias_beta) # [B, ..., I]
|
| 517 |
+
if not self.output_mode:
|
| 518 |
+
x = x + beta
|
| 519 |
+
|
| 520 |
+
b = self.bias
|
| 521 |
+
if b is not None:
|
| 522 |
+
b = b.to(x.dtype)[None, None, :]
|
| 523 |
+
if self.mod_bias and self.output_mode:
|
| 524 |
+
if b is None:
|
| 525 |
+
b = beta
|
| 526 |
+
else:
|
| 527 |
+
b = b + beta
|
| 528 |
+
|
| 529 |
+
# [B ? I] @ [B I O] = [B ? O]
|
| 530 |
+
if b is not None:
|
| 531 |
+
x = torch.baddbmm(b, x, w.transpose(1, 2))
|
| 532 |
+
else:
|
| 533 |
+
x = x.bmm(w.transpose(1, 2))
|
| 534 |
+
|
| 535 |
+
x = x.reshape(*x_shape[:-1], x.shape[-1])
|
| 536 |
+
return x
|
gaussiancity/inference.py
ADDED
|
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# @File: inference.py
|
| 4 |
+
# @Author: Haozhe Xie
|
| 5 |
+
# @Date: 2024-03-02 16:30:00
|
| 6 |
+
# @Last Modified by: Haozhe Xie
|
| 7 |
+
# @Last Modified at: 2024-10-13 15:17:20
|
| 8 |
+
# @Email: root@haozhexie.com
|
| 9 |
+
|
| 10 |
+
import cv2
|
| 11 |
+
import math
|
| 12 |
+
import numpy as np
|
| 13 |
+
import scipy.spatial.transform
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
|
| 18 |
+
CLASSES = {
|
| 19 |
+
"NULL": 0,
|
| 20 |
+
"ROAD": 1,
|
| 21 |
+
"BLDG_FACADE": 2,
|
| 22 |
+
"GREEN_LANDS": 3,
|
| 23 |
+
"CONSTRUCTION": 4,
|
| 24 |
+
"COAST_ZONES": 5,
|
| 25 |
+
"ZONE": 6,
|
| 26 |
+
"BLDG_ROOF": 7,
|
| 27 |
+
}
|
| 28 |
+
SCALES = {
|
| 29 |
+
"ROAD": 2,
|
| 30 |
+
"BLDG_FACADE": 1,
|
| 31 |
+
"BLDG_ROOF": 1,
|
| 32 |
+
"GREEN_LANDS": 2,
|
| 33 |
+
"CONSTRUCTION": 1,
|
| 34 |
+
"COAST_ZONES": 4,
|
| 35 |
+
"ZONE": 2,
|
| 36 |
+
}
|
| 37 |
+
CONSTANTS = {
|
| 38 |
+
"CAM_K": [1528.1469407006614, 0, 480, 0, 1528.1469407006614, 270, 0, 0, 1],
|
| 39 |
+
"SENSOR_SIZE": [960, 540],
|
| 40 |
+
"BLDG_INST_RANGE": [100, 16384],
|
| 41 |
+
"PROJECTION_SIZE": 2048,
|
| 42 |
+
"POINT_SCALE_FACTOR": 0.5,
|
| 43 |
+
"SPECIAL_Z_SCALE_CLASSES": [
|
| 44 |
+
CLASSES["ROAD"],
|
| 45 |
+
CLASSES["COAST_ZONES"],
|
| 46 |
+
CLASSES["ZONE"],
|
| 47 |
+
],
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_instance_seg_map(seg_map):
|
| 52 |
+
# Mapping constructions to buildings
|
| 53 |
+
seg_map[seg_map == CLASSES["CONSTRUCTION"]] = CLASSES["BLDG_FACADE"]
|
| 54 |
+
# Use connected components to get building instances
|
| 55 |
+
import pdb; pdb.set_trace()
|
| 56 |
+
_, labels, _, _ = cv2.connectedComponentsWithStats(
|
| 57 |
+
(seg_map == CLASSES["BLDG_FACADE"]).astype(np.uint8), connectivity=4
|
| 58 |
+
)
|
| 59 |
+
# Remove non-building instance masks
|
| 60 |
+
labels[seg_map != CLASSES["BLDG_FACADE"]] = 0
|
| 61 |
+
# Building instance mask
|
| 62 |
+
building_mask = labels != 0
|
| 63 |
+
# Make building instance IDs are even numbers and start from 10
|
| 64 |
+
# Assume the ID of a facade instance is 2k, the corresponding roof instance is 2k - 1.
|
| 65 |
+
labels = (labels + CONSTANTS["BLDG_INST_RANGE"][0]) * 2
|
| 66 |
+
|
| 67 |
+
seg_map[seg_map == CLASSES["BLDG_FACADE"]] = 0
|
| 68 |
+
seg_map = seg_map * (1 - building_mask) + labels * building_mask
|
| 69 |
+
assert np.max(labels) < 2147483648
|
| 70 |
+
return seg_map.astype(np.int32)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_point_map(seg_map):
|
| 74 |
+
inverted_index = {v: k for k, v in CLASSES.items()}
|
| 75 |
+
pts_map = np.zeros(seg_map.shape, dtype=bool)
|
| 76 |
+
for c in np.unique(seg_map):
|
| 77 |
+
cls_name = inverted_index[c]
|
| 78 |
+
if cls_name == "NULL":
|
| 79 |
+
continue
|
| 80 |
+
|
| 81 |
+
mask = seg_map == c
|
| 82 |
+
pt_map = _get_point_map(seg_map.shape, SCALES[cls_name])
|
| 83 |
+
pt_map[~mask] = False
|
| 84 |
+
pts_map += pt_map
|
| 85 |
+
|
| 86 |
+
return pts_map
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _get_point_map(map_size, stride):
|
| 90 |
+
pts_map = np.zeros(map_size, dtype=bool)
|
| 91 |
+
ys = np.arange(0, map_size[0], stride)
|
| 92 |
+
xs = np.arange(0, map_size[1], stride)
|
| 93 |
+
coords = np.stack(np.meshgrid(ys, xs), axis=-1).reshape(-1, 2)
|
| 94 |
+
pts_map[coords[:, 0], coords[:, 1]] = True
|
| 95 |
+
return pts_map
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_centers(ins_map, td_hf):
|
| 99 |
+
centers = {}
|
| 100 |
+
instances = np.unique(ins_map)
|
| 101 |
+
for i in tqdm(instances, desc="Calculating centers ..."):
|
| 102 |
+
if i >= CONSTANTS["BLDG_INST_RANGE"][0]:
|
| 103 |
+
ds_mask = ins_map == i
|
| 104 |
+
contours, _ = cv2.findContours(
|
| 105 |
+
ds_mask.astype(np.uint8),
|
| 106 |
+
cv2.RETR_EXTERNAL,
|
| 107 |
+
cv2.CHAIN_APPROX_SIMPLE,
|
| 108 |
+
)
|
| 109 |
+
contours = np.vstack(contours).reshape(-1, 2)
|
| 110 |
+
min_x, max_x = np.min(contours[:, 0]), np.max(contours[:, 0])
|
| 111 |
+
min_y, max_y = np.min(contours[:, 1]), np.max(contours[:, 1])
|
| 112 |
+
max_z = np.max(td_hf[ds_mask]) + 1
|
| 113 |
+
else:
|
| 114 |
+
min_x, max_x = 0, CONSTANTS["PROJECTION_SIZE"]
|
| 115 |
+
min_y, max_y = 0, CONSTANTS["PROJECTION_SIZE"]
|
| 116 |
+
max_z = np.max(td_hf)
|
| 117 |
+
|
| 118 |
+
centers[i] = np.array(
|
| 119 |
+
[
|
| 120 |
+
(min_x + max_x) / 2,
|
| 121 |
+
(min_y + max_y) / 2,
|
| 122 |
+
(max_x - min_x),
|
| 123 |
+
(max_y - min_y),
|
| 124 |
+
max_z,
|
| 125 |
+
],
|
| 126 |
+
dtype=np.float32,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
return centers
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def generate_city(
|
| 133 |
+
fgm, bgm, city_layout, cx, cy, radius, altitude, azimuth, style_lut=None
|
| 134 |
+
):
|
| 135 |
+
import gaussiancity.extensions.diff_gaussian_rasterization as dgr
|
| 136 |
+
|
| 137 |
+
device = torch.device("cuda")
|
| 138 |
+
gr = dgr.GaussianRasterizerWrapper(
|
| 139 |
+
np.array(CONSTANTS["CAM_K"], dtype=np.float32).reshape((3, 3)),
|
| 140 |
+
CONSTANTS["SENSOR_SIZE"],
|
| 141 |
+
flip_lr=True,
|
| 142 |
+
flip_ud=False,
|
| 143 |
+
device=device,
|
| 144 |
+
)
|
| 145 |
+
layout = _get_local_layout(
|
| 146 |
+
city_layout,
|
| 147 |
+
cx,
|
| 148 |
+
cy,
|
| 149 |
+
CONSTANTS["PROJECTION_SIZE"] // 2,
|
| 150 |
+
CONSTANTS["BLDG_INST_RANGE"],
|
| 151 |
+
device,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
bev_pts = _get_bev_points(layout, SCALES, CLASSES)
|
| 155 |
+
bev_pt_classes = _instances_to_classes(
|
| 156 |
+
bev_pts[:, [3]], CONSTANTS["BLDG_INST_RANGE"], CLASSES
|
| 157 |
+
)
|
| 158 |
+
bev_pt_classes_onehot = _get_onehot_seg(bev_pt_classes, len(CLASSES))
|
| 159 |
+
bev_pt_scales = _get_point_scales(
|
| 160 |
+
bev_pt_classes,
|
| 161 |
+
SCALES,
|
| 162 |
+
CLASSES,
|
| 163 |
+
CONSTANTS["SPECIAL_Z_SCALE_CLASSES"],
|
| 164 |
+
)
|
| 165 |
+
bev_pts = torch.cat([bev_pts, bev_pt_scales, bev_pt_classes_onehot], dim=1)
|
| 166 |
+
# print(bev_pts.shape) # [N, XYZ + Inst + Scale3D + N_CLASSES]
|
| 167 |
+
if style_lut is None:
|
| 168 |
+
style_lut = _get_style_lut(
|
| 169 |
+
layout["CTR"],
|
| 170 |
+
{"BLDG": fgm, "REST": bgm},
|
| 171 |
+
{
|
| 172 |
+
"BLDG": CONSTANTS["BLDG_INST_RANGE"],
|
| 173 |
+
"REST": [0, CONSTANTS["BLDG_INST_RANGE"][0]],
|
| 174 |
+
},
|
| 175 |
+
device,
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
cam_look_at, cam_pose = _get_orbit_camera_pose(
|
| 179 |
+
radius, altitude, azimuth, CONSTANTS["PROJECTION_SIZE"] // 2, device
|
| 180 |
+
)
|
| 181 |
+
vp_idx = _get_visible_points(
|
| 182 |
+
bev_pts[:, :3],
|
| 183 |
+
bev_pt_scales,
|
| 184 |
+
CONSTANTS["CAM_K"],
|
| 185 |
+
CONSTANTS["SENSOR_SIZE"],
|
| 186 |
+
cam_pose[:3],
|
| 187 |
+
cam_look_at,
|
| 188 |
+
)
|
| 189 |
+
gs_attrs = _get_gs_attrs(
|
| 190 |
+
bev_pts[vp_idx],
|
| 191 |
+
layout["TD_HF"].float(),
|
| 192 |
+
layout["SEG"].float(),
|
| 193 |
+
style_lut,
|
| 194 |
+
layout["CTR"],
|
| 195 |
+
{"BLDG": fgm, "REST": bgm},
|
| 196 |
+
CONSTANTS["POINT_SCALE_FACTOR"],
|
| 197 |
+
CONSTANTS["BLDG_INST_RANGE"],
|
| 198 |
+
)
|
| 199 |
+
return _render(gs_attrs, gr, cam_pose)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def _get_local_layout(city_layout, cx, cy, half_proj_size, bldg_inst_range, device):
|
| 203 |
+
x_min, x_max = cx - half_proj_size, cx + half_proj_size
|
| 204 |
+
y_min, y_max = cy - half_proj_size, cy + half_proj_size
|
| 205 |
+
|
| 206 |
+
_layout = {
|
| 207 |
+
k: torch.from_numpy(v[None, None, y_min:y_max, x_min:x_max]).cuda(device)
|
| 208 |
+
for k, v in city_layout.items()
|
| 209 |
+
if k in ["TD_HF", "BU_HF", "SEG", "INS", "PTS"]
|
| 210 |
+
}
|
| 211 |
+
_layout["SEG"] = _get_onehot_seg(_layout["SEG"], len(CLASSES))
|
| 212 |
+
|
| 213 |
+
_instances = torch.unique(_layout["INS"])
|
| 214 |
+
_centers = {}
|
| 215 |
+
for inst in _instances:
|
| 216 |
+
inst = inst.item()
|
| 217 |
+
if inst >= bldg_inst_range[0]:
|
| 218 |
+
_centers[inst] = torch.from_numpy(city_layout["CTR"][inst]).cuda(device)
|
| 219 |
+
_centers[inst][0] -= x_min
|
| 220 |
+
_centers[inst][1] -= y_min
|
| 221 |
+
_centers[inst + 1] = _centers[inst] # Fix the centers for BLDG_ROOF
|
| 222 |
+
else:
|
| 223 |
+
_centers[inst] = torch.from_numpy(city_layout["CTR"][inst]).cuda(device)
|
| 224 |
+
_centers[inst][0] = x_min
|
| 225 |
+
_centers[inst][1] = y_min
|
| 226 |
+
|
| 227 |
+
_layout["CTR"] = _centers
|
| 228 |
+
return _layout
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _get_onehot_seg(seg_map, n_classes):
|
| 232 |
+
shape = seg_map.shape
|
| 233 |
+
# shape -> NxCxHxW or NxC
|
| 234 |
+
# assert shape[1] == 1
|
| 235 |
+
output_shape = (shape[0], n_classes, *shape[2:])
|
| 236 |
+
|
| 237 |
+
one_hot_masks = torch.zeros(output_shape, device=seg_map.device, dtype=torch.bool)
|
| 238 |
+
for i in range(n_classes):
|
| 239 |
+
one_hot_masks[:, [i]] = seg_map == i
|
| 240 |
+
|
| 241 |
+
return one_hot_masks
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def _get_style_lut(centers, models, inst_ranges, device, z_dim=256):
|
| 245 |
+
lut = {ins: torch.rand(1, z_dim, device=device) for ins in centers.keys()}
|
| 246 |
+
for k, v in models.items():
|
| 247 |
+
if v is None:
|
| 248 |
+
continue
|
| 249 |
+
|
| 250 |
+
if v.module.cfg.Z_DIM is None:
|
| 251 |
+
for i in range(*inst_ranges[k]):
|
| 252 |
+
if i in lut:
|
| 253 |
+
del lut[i]
|
| 254 |
+
continue
|
| 255 |
+
|
| 256 |
+
if hasattr(v.module, "z"):
|
| 257 |
+
zs = v.module.z
|
| 258 |
+
lut.update(
|
| 259 |
+
{
|
| 260 |
+
ins: zs[np.random.choice(list(zs.keys()))].unsqueeze(0)
|
| 261 |
+
for ins in centers.keys()
|
| 262 |
+
}
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
return lut
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def _get_orbit_camera_pose(radius, altitude, azimuth, half_proj_size, device):
|
| 269 |
+
cx, cy = half_proj_size, half_proj_size
|
| 270 |
+
theta = np.deg2rad(azimuth)
|
| 271 |
+
cam_x = cx + radius * math.cos(theta)
|
| 272 |
+
cam_y = cy + radius * math.sin(theta)
|
| 273 |
+
|
| 274 |
+
cam_pos = np.array([cam_x, cam_y, altitude], dtype=np.float32)
|
| 275 |
+
cam_look_at = np.array([cx, cy, 1], dtype=np.float32)
|
| 276 |
+
quat = _get_quat_from_look_at(cam_pos, cam_look_at)
|
| 277 |
+
return torch.tensor([*cam_look_at], device=device), torch.tensor(
|
| 278 |
+
[*cam_pos, *quat], device=device
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def _get_quat_from_look_at(cam_pos, cam_look_at):
|
| 283 |
+
fwd_vec = cam_look_at - cam_pos
|
| 284 |
+
fwd_vec /= np.linalg.norm(fwd_vec)
|
| 285 |
+
up_vec = np.array([0, 0, 1])
|
| 286 |
+
right_vec = np.cross(up_vec, fwd_vec)
|
| 287 |
+
right_vec /= np.linalg.norm(right_vec)
|
| 288 |
+
up_vec = np.cross(fwd_vec, right_vec)
|
| 289 |
+
R = np.stack([fwd_vec, right_vec, up_vec], axis=1)
|
| 290 |
+
return scipy.spatial.transform.Rotation.from_matrix(R).as_quat()
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def _get_bev_points(layout, scales, classes):
|
| 294 |
+
import gaussiancity.extensions.voxlib
|
| 295 |
+
|
| 296 |
+
assert torch.max(layout["INS"]) < 16384
|
| 297 |
+
# torch.nonzero(torch.zeros(2048, 2048, 512).cuda())
|
| 298 |
+
# -> nonzero is not supported for tensors with more than INT_MAX elements
|
| 299 |
+
# torch.nonzero(torch.zeros(2048, 2048, 508).cuda())
|
| 300 |
+
# -> an illegal memory access was encountered
|
| 301 |
+
assert torch.max(layout["TD_HF"]) <= 500
|
| 302 |
+
|
| 303 |
+
volume = gaussiancity.extensions.voxlib.maps_to_volume(
|
| 304 |
+
layout["INS"].squeeze().short(),
|
| 305 |
+
layout["TD_HF"].squeeze().short(),
|
| 306 |
+
layout["BU_HF"].squeeze().short(),
|
| 307 |
+
layout["PTS"].squeeze().bool(),
|
| 308 |
+
torch.tensor(
|
| 309 |
+
[scales[k] if k in scales else 0 for k in classes.keys()],
|
| 310 |
+
dtype=torch.int8,
|
| 311 |
+
device=layout["INS"].device,
|
| 312 |
+
),
|
| 313 |
+
)
|
| 314 |
+
non_zero_indices = torch.nonzero(volume, as_tuple=False)
|
| 315 |
+
non_zero_values = volume[
|
| 316 |
+
non_zero_indices[:, 0], non_zero_indices[:, 1], non_zero_indices[:, 2]
|
| 317 |
+
]
|
| 318 |
+
return torch.cat(
|
| 319 |
+
[non_zero_indices.short(), non_zero_values.unsqueeze(dim=1)], dim=1
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def _instances_to_classes(instances, bldg_inst_range, bldg_classes):
|
| 324 |
+
bldg_facade_idx = (instances >= bldg_inst_range[0]) & (instances % 2 == 0)
|
| 325 |
+
bldg_roof_idx = (instances >= bldg_inst_range[0]) & (instances % 2 == 1)
|
| 326 |
+
|
| 327 |
+
classes = instances.clone()
|
| 328 |
+
classes[bldg_facade_idx] = bldg_classes["BLDG_FACADE"]
|
| 329 |
+
classes[bldg_roof_idx] = bldg_classes["BLDG_ROOF"]
|
| 330 |
+
return classes
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def _get_point_scales(pt_classes, scales, classes, special_z_scale_classes=[]):
|
| 334 |
+
pt_scales = pt_classes.clone()
|
| 335 |
+
for k, v in scales.items():
|
| 336 |
+
pt_scales[pt_classes == classes[k]] = v
|
| 337 |
+
|
| 338 |
+
pt_scales_3d = torch.ones_like(pt_scales).repeat(1, 3) * pt_scales
|
| 339 |
+
# Set the z-scale = 1 for roads, zones, and waters
|
| 340 |
+
pt_scales_3d[..., 2][
|
| 341 |
+
torch.isin(
|
| 342 |
+
pt_classes.squeeze(dim=-1),
|
| 343 |
+
torch.tensor(
|
| 344 |
+
list(special_z_scale_classes),
|
| 345 |
+
device=pt_classes.device,
|
| 346 |
+
),
|
| 347 |
+
)
|
| 348 |
+
] = 1
|
| 349 |
+
return pt_scales_3d
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def _get_visible_points(points, scales, K, sensor_size, cam_pos, cam_look_at):
|
| 353 |
+
## NOTE: Each point is assigned with a unique ID. The values in the rendered map
|
| 354 |
+
## denotes the visibility of the points. The values are the same as the point IDs.
|
| 355 |
+
# Generate 3D volume
|
| 356 |
+
volume, offsets = _get_volume(points, scales)
|
| 357 |
+
# Ray-voxel intersection
|
| 358 |
+
vp_map = _get_ray_voxel_intersection(
|
| 359 |
+
K, sensor_size, cam_pos - offsets, cam_look_at - cam_pos, volume
|
| 360 |
+
)
|
| 361 |
+
## Generate the instance segmentation map as a side product
|
| 362 |
+
# ins_map = instances[vp_map]
|
| 363 |
+
# null_mask = vp_map == -1
|
| 364 |
+
# ins_map[null_mask] = null_class_id
|
| 365 |
+
|
| 366 |
+
# Manually release the memory to avoid OOM
|
| 367 |
+
del volume
|
| 368 |
+
torch.cuda.empty_cache()
|
| 369 |
+
|
| 370 |
+
vp_idx = torch.unique(vp_map)
|
| 371 |
+
return vp_idx[vp_idx >= 0]
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def _get_volume(points, scales):
|
| 375 |
+
import gaussiancity.extensions.voxlib
|
| 376 |
+
|
| 377 |
+
x_min, x_max = torch.min(points[:, 0]).item(), torch.max(points[:, 0]).item()
|
| 378 |
+
y_min, y_max = torch.min(points[:, 1]).item(), torch.max(points[:, 1]).item()
|
| 379 |
+
z_min, z_max = torch.min(points[:, 2]).item(), torch.max(points[:, 2]).item()
|
| 380 |
+
offsets = torch.tensor(
|
| 381 |
+
[x_min, y_min, z_min], dtype=torch.int16, device=points.device
|
| 382 |
+
)
|
| 383 |
+
# Normalize points coordinates to local coordinate system
|
| 384 |
+
points = _get_localized_pt_cords(points, offsets)
|
| 385 |
+
# Generate an empty 3D volume
|
| 386 |
+
w, h, d = x_max - x_min + 1, y_max - y_min + 1, z_max - z_min + 2
|
| 387 |
+
# Generate point IDs
|
| 388 |
+
# NOTE: The point IDs start from 1 to avoid the conflict with the NULL class.
|
| 389 |
+
assert points.shape[0] < 2147483648
|
| 390 |
+
pt_ids = torch.arange(
|
| 391 |
+
start=1, end=points.shape[0] + 1, dtype=torch.int32, device=points.device
|
| 392 |
+
).unsqueeze(dim=1)
|
| 393 |
+
volume = gaussiancity.extensions.voxlib.points_to_volume(
|
| 394 |
+
points.contiguous(), pt_ids, scales, h, w, d
|
| 395 |
+
)
|
| 396 |
+
return volume, offsets
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def _get_localized_pt_cords(points, offsets):
|
| 400 |
+
points[:, 0] -= offsets[0]
|
| 401 |
+
points[:, 1] -= offsets[1]
|
| 402 |
+
points[:, 2] -= offsets[2] - 1
|
| 403 |
+
return points
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def _get_ray_voxel_intersection(K, sensor_size, cam_origin, viewdir, volume):
|
| 407 |
+
import gaussiancity.extensions.voxlib
|
| 408 |
+
|
| 409 |
+
N_MAX_SAMPLES = 1
|
| 410 |
+
voxel_id, _, _ = gaussiancity.extensions.voxlib.ray_voxel_intersection_perspective(
|
| 411 |
+
volume,
|
| 412 |
+
cam_origin[[1, 0, 2]].float(),
|
| 413 |
+
viewdir[[1, 0, 2]].float(),
|
| 414 |
+
torch.tensor([0, 0, 1], dtype=torch.float32),
|
| 415 |
+
K[0],
|
| 416 |
+
[K[5], K[2]],
|
| 417 |
+
[sensor_size[1], sensor_size[0]],
|
| 418 |
+
N_MAX_SAMPLES,
|
| 419 |
+
)
|
| 420 |
+
# NOTE: The point ID for NULL class is -1, the rest point IDs are from 0 to N - 1.
|
| 421 |
+
# The ray_voxel_intersection_perspective seems not accepting the negative values.
|
| 422 |
+
return voxel_id.squeeze() - 1
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def get_hf_seg_tensor(part_hf, part_seg, layout_cfg, output_device):
|
| 426 |
+
part_hf = torch.from_numpy(part_hf[None, None, ...]).to(output_device)
|
| 427 |
+
part_seg = torch.from_numpy(part_seg[None, None, ...]).to(output_device)
|
| 428 |
+
part_hf = part_hf / CONSTANTS["LAYOUT_MAX_HEIGHT"]
|
| 429 |
+
part_seg = _masks_to_onehots(part_seg[:, 0, :, :], CONSTANTS["LAYOUT_N_CLASSES"])
|
| 430 |
+
return torch.cat([part_hf, part_seg], dim=1)
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def _masks_to_onehots(masks, n_class, ignored_classes=[]):
|
| 434 |
+
b, h, w = masks.shape
|
| 435 |
+
n_class_actual = n_class - len(ignored_classes)
|
| 436 |
+
one_hot_masks = torch.zeros(
|
| 437 |
+
(b, n_class_actual, h, w), dtype=torch.float32, device=masks.device
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
n_class_cnt = 0
|
| 441 |
+
for i in range(n_class):
|
| 442 |
+
if i not in ignored_classes:
|
| 443 |
+
one_hot_masks[:, n_class_cnt] = masks == i
|
| 444 |
+
n_class_cnt += 1
|
| 445 |
+
return one_hot_masks
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def _get_gs_attrs(
|
| 449 |
+
pts,
|
| 450 |
+
proj_hf,
|
| 451 |
+
proj_seg,
|
| 452 |
+
style_lut,
|
| 453 |
+
centers,
|
| 454 |
+
models,
|
| 455 |
+
scale_factor,
|
| 456 |
+
bldg_inst_range,
|
| 457 |
+
):
|
| 458 |
+
n_pts, _ = pts.shape
|
| 459 |
+
# NOTE: 4: XYZ, Instance ID; 3: Scale; N_CLASSES: One-hot
|
| 460 |
+
# print(pts.shape) # [N, 4 + 3 + N_CLASSES]
|
| 461 |
+
bldg_selector = pts[:, 3] >= bldg_inst_range[0]
|
| 462 |
+
bldg_pts = pts[bldg_selector]
|
| 463 |
+
rest_pts = pts[~bldg_selector]
|
| 464 |
+
|
| 465 |
+
bldg_attrs = _get_pt_input_attrs(
|
| 466 |
+
bldg_pts[:, :4],
|
| 467 |
+
centers,
|
| 468 |
+
style_lut,
|
| 469 |
+
models["BLDG"].module.cfg.Z_DIM,
|
| 470 |
+
bldg_inst_range,
|
| 471 |
+
)
|
| 472 |
+
rest_attrs = _get_pt_input_attrs(
|
| 473 |
+
rest_pts[:, :4],
|
| 474 |
+
centers,
|
| 475 |
+
style_lut,
|
| 476 |
+
models["REST"].module.cfg.Z_DIM,
|
| 477 |
+
bldg_inst_range,
|
| 478 |
+
)
|
| 479 |
+
bldg_colors = _get_gs_colors(
|
| 480 |
+
bldg_pts, bldg_attrs, proj_hf, proj_seg, models["BLDG"]
|
| 481 |
+
)
|
| 482 |
+
rest_colors = _get_gs_colors(
|
| 483 |
+
rest_pts, rest_attrs, proj_hf, proj_seg, models["REST"]
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
abs_xyz = torch.cat([bldg_pts[:, :3], rest_pts[:, :3]], dim=0)
|
| 487 |
+
scales = torch.cat([bldg_pts[:, 4:7], rest_pts[:, 4:7]], dim=0) * scale_factor
|
| 488 |
+
rgb = torch.cat([bldg_colors, rest_colors], dim=0)
|
| 489 |
+
# Attributes with default values
|
| 490 |
+
opacity = torch.ones((n_pts, 1), device=pts.device)
|
| 491 |
+
rotations = torch.cat(
|
| 492 |
+
[
|
| 493 |
+
torch.ones(n_pts, 1, device=pts.device),
|
| 494 |
+
torch.zeros(n_pts, 3, device=pts.device),
|
| 495 |
+
],
|
| 496 |
+
dim=-1,
|
| 497 |
+
)
|
| 498 |
+
return torch.cat((abs_xyz, opacity, scales, rotations, rgb), dim=-1)
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def _get_pt_input_attrs(pts, centers, style_lut, z_dim, bldg_inst_range):
|
| 502 |
+
n_pts = pts.shape[0]
|
| 503 |
+
instances = torch.unique(pts[:, -1])
|
| 504 |
+
rel_xyz = torch.zeros(1, n_pts, 3, dtype=torch.float32, device=pts.device)
|
| 505 |
+
batch_idx = torch.zeros(1, n_pts, dtype=torch.int32, device=pts.device)
|
| 506 |
+
zs = {} if z_dim is not None else None
|
| 507 |
+
for idx, ins in enumerate(instances):
|
| 508 |
+
ins = ins.item()
|
| 509 |
+
is_pts = pts[:, -1] == ins
|
| 510 |
+
cx, cy, w, h, d = centers[ins]
|
| 511 |
+
|
| 512 |
+
if ins >= bldg_inst_range[0]:
|
| 513 |
+
rel_xyz[:, is_pts, 0] = (pts[is_pts, 0] - cx) / w * 2 if w > 0 else 0
|
| 514 |
+
rel_xyz[:, is_pts, 1] = (pts[is_pts, 1] - cy) / h * 2 if h > 0 else 0
|
| 515 |
+
else:
|
| 516 |
+
# Make the BG contiguous
|
| 517 |
+
period_x = torch.ceil((pts[is_pts, 0] / w / 2) - 0.5)
|
| 518 |
+
period_y = torch.ceil((pts[is_pts, 1] / h / 2) - 0.5)
|
| 519 |
+
rel_xyz[:, is_pts, 0] = (
|
| 520 |
+
(pts[is_pts, 0] - 2 * period_x * w) * (-1) ** period_x
|
| 521 |
+
) / w
|
| 522 |
+
rel_xyz[:, is_pts, 1] = (
|
| 523 |
+
(pts[is_pts, 1] - 2 * period_y * h) * (-1) ** period_y
|
| 524 |
+
) / h
|
| 525 |
+
|
| 526 |
+
rel_xyz[:, is_pts, 2] = (
|
| 527 |
+
torch.clip(pts[is_pts, 2] / d * 2 - 1, -1, 1) if d > 0 else 0
|
| 528 |
+
)
|
| 529 |
+
batch_idx[:, is_pts] = idx
|
| 530 |
+
if zs is not None:
|
| 531 |
+
zs[ins] = {"z": style_lut[ins], "idx": is_pts.unsqueeze(dim=0)}
|
| 532 |
+
|
| 533 |
+
return rel_xyz, batch_idx, zs
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def _get_gs_colors(pts, pt_attrs, proj_hf, proj_seg, model):
|
| 537 |
+
if pts.shape[0] == 0:
|
| 538 |
+
return torch.empty(0, 3, dtype=torch.float32, device=pts.device)
|
| 539 |
+
|
| 540 |
+
abs_xyz, onehots = pts[None, :, :3], pts[None, :, 7:]
|
| 541 |
+
rel_xyz, batch_idx, zs = pt_attrs
|
| 542 |
+
proj_uv = None
|
| 543 |
+
if model.module.cfg.ENCODER is not None:
|
| 544 |
+
proj_uv = get_projection_uv(abs_xyz)
|
| 545 |
+
|
| 546 |
+
with torch.no_grad():
|
| 547 |
+
# TODO: Optimize the _instance_forward in Generator
|
| 548 |
+
gs_attrs = model(
|
| 549 |
+
proj_uv, rel_xyz, batch_idx, onehots.float(), zs, proj_hf, proj_seg
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
return gs_attrs["rgb"].squeeze(dim=0)
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def get_projection_uv(xyz, proj_tlp=None, proj_size=2048):
|
| 556 |
+
n_pts = xyz.size(1)
|
| 557 |
+
if proj_tlp is None:
|
| 558 |
+
proj_uv = xyz[..., :2].clone().float()
|
| 559 |
+
else:
|
| 560 |
+
proj_uv = xyz[..., :2] - proj_tlp.unsqueeze(dim=1)
|
| 561 |
+
|
| 562 |
+
assert proj_uv.size() == (xyz.size(0), n_pts, 2)
|
| 563 |
+
proj_uv[..., 0] /= proj_size
|
| 564 |
+
proj_uv[..., 1] /= proj_size
|
| 565 |
+
# Normalize to [-1, 1]
|
| 566 |
+
return proj_uv * 2 - 1
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
def _render(gs_attrs, rasterizator, cam_pose):
|
| 570 |
+
import torchvision.transforms.functional as F
|
| 571 |
+
|
| 572 |
+
with torch.no_grad():
|
| 573 |
+
img = rasterizator(
|
| 574 |
+
gs_attrs,
|
| 575 |
+
cam_pose[:3], # Position
|
| 576 |
+
cam_pose[3:], # Quaternion
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
img = img.squeeze() / 2 + 0.5
|
| 580 |
+
img = F.adjust_brightness(img, 1.2)
|
| 581 |
+
img = F.adjust_contrast(img, 1.2)
|
| 582 |
+
return (img * 255).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
|
gaussiancity/pt_v3.py
ADDED
|
@@ -0,0 +1,1344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
#
|
| 3 |
+
# @File: pt_v3.py
|
| 4 |
+
# @Author: Xiaoyang Wu <xiaoyang.wu.cs@gmail.com>
|
| 5 |
+
# @Date: 2024-04-01 16:31:36
|
| 6 |
+
# @Last Modified by: Haozhe Xie
|
| 7 |
+
# @Last Modified at: 2024-05-15 22:05:09
|
| 8 |
+
# @Email: root@haozhexie.com
|
| 9 |
+
# Ref:
|
| 10 |
+
# - https://github.com/Pointcept/PointTransformerV3/blob/main/model.py
|
| 11 |
+
# - https://huggingface.co/spaces/Roll20/pet_score/blame/main/lib/timm/models/layers/drop.py
|
| 12 |
+
|
| 13 |
+
import addict
|
| 14 |
+
import collections
|
| 15 |
+
import functools
|
| 16 |
+
import flash_attn
|
| 17 |
+
import math
|
| 18 |
+
import torch
|
| 19 |
+
import spconv.pytorch as spconv
|
| 20 |
+
import torch_scatter
|
| 21 |
+
import typing
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@torch.inference_mode()
|
| 25 |
+
def offset2bincount(offset):
|
| 26 |
+
return torch.diff(
|
| 27 |
+
offset, prepend=torch.tensor([0], device=offset.device, dtype=torch.long)
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@torch.inference_mode()
|
| 32 |
+
def offset2batch(offset):
|
| 33 |
+
bincount = offset2bincount(offset)
|
| 34 |
+
return torch.arange(
|
| 35 |
+
len(bincount), device=offset.device, dtype=torch.long
|
| 36 |
+
).repeat_interleave(bincount)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@torch.inference_mode()
|
| 40 |
+
def batch2offset(batch):
|
| 41 |
+
return torch.cumsum(batch.bincount(), dim=0).long()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class KeyLUT:
|
| 45 |
+
def __init__(self):
|
| 46 |
+
r256 = torch.arange(256, dtype=torch.int64)
|
| 47 |
+
r512 = torch.arange(512, dtype=torch.int64)
|
| 48 |
+
zero = torch.zeros(256, dtype=torch.int64)
|
| 49 |
+
device = torch.device("cpu")
|
| 50 |
+
|
| 51 |
+
self._encode = {
|
| 52 |
+
device: (
|
| 53 |
+
self.xyz2key(r256, zero, zero, 8),
|
| 54 |
+
self.xyz2key(zero, r256, zero, 8),
|
| 55 |
+
self.xyz2key(zero, zero, r256, 8),
|
| 56 |
+
)
|
| 57 |
+
}
|
| 58 |
+
self._decode = {device: self.key2xyz(r512, 9)}
|
| 59 |
+
|
| 60 |
+
def encode_lut(self, device=torch.device("cpu")):
|
| 61 |
+
if device not in self._encode:
|
| 62 |
+
cpu = torch.device("cpu")
|
| 63 |
+
self._encode[device] = tuple(e.to(device) for e in self._encode[cpu])
|
| 64 |
+
return self._encode[device]
|
| 65 |
+
|
| 66 |
+
def decode_lut(self, device=torch.device("cpu")):
|
| 67 |
+
if device not in self._decode:
|
| 68 |
+
cpu = torch.device("cpu")
|
| 69 |
+
self._decode[device] = tuple(e.to(device) for e in self._decode[cpu])
|
| 70 |
+
return self._decode[device]
|
| 71 |
+
|
| 72 |
+
def xyz2key(self, x, y, z, depth):
|
| 73 |
+
key = torch.zeros_like(x)
|
| 74 |
+
for i in range(depth):
|
| 75 |
+
mask = 1 << i
|
| 76 |
+
key = (
|
| 77 |
+
key
|
| 78 |
+
| ((x & mask) << (2 * i + 2))
|
| 79 |
+
| ((y & mask) << (2 * i + 1))
|
| 80 |
+
| ((z & mask) << (2 * i + 0))
|
| 81 |
+
)
|
| 82 |
+
return key
|
| 83 |
+
|
| 84 |
+
def key2xyz(self, key, depth):
|
| 85 |
+
x = torch.zeros_like(key)
|
| 86 |
+
y = torch.zeros_like(key)
|
| 87 |
+
z = torch.zeros_like(key)
|
| 88 |
+
for i in range(depth):
|
| 89 |
+
x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2))
|
| 90 |
+
y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1))
|
| 91 |
+
z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0))
|
| 92 |
+
return x, y, z
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class Serializator:
|
| 96 |
+
def encode(self, grid_coord, grid_size=0.01, batch=None, depth=16, order="cord"):
|
| 97 |
+
assert order in {"cord", "z", "z-trans", "hilbert", "hilbert-trans"}
|
| 98 |
+
if order in ["z", "z-trans"]:
|
| 99 |
+
self.key_lut = KeyLUT()
|
| 100 |
+
if order == "cord":
|
| 101 |
+
code = self.cord_encode(grid_coord, grid_size)
|
| 102 |
+
elif order == "z":
|
| 103 |
+
code = self.z_order_encode(grid_coord, depth=depth)
|
| 104 |
+
elif order == "z-trans":
|
| 105 |
+
code = self.z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth)
|
| 106 |
+
elif order == "hilbert":
|
| 107 |
+
code = self.hilbert_encode(grid_coord, depth=depth)
|
| 108 |
+
elif order == "hilbert-trans":
|
| 109 |
+
code = self.hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth)
|
| 110 |
+
else:
|
| 111 |
+
raise NotImplementedError
|
| 112 |
+
|
| 113 |
+
if batch is not None:
|
| 114 |
+
batch = batch.long()
|
| 115 |
+
code = batch << depth * 3 | code
|
| 116 |
+
|
| 117 |
+
return code
|
| 118 |
+
|
| 119 |
+
def cord_encode(self, grid_coord: torch.Tensor, grid_size: float):
|
| 120 |
+
x, y, z = (
|
| 121 |
+
grid_coord[:, 0].long(),
|
| 122 |
+
grid_coord[:, 1].long(),
|
| 123 |
+
grid_coord[:, 2].long(),
|
| 124 |
+
)
|
| 125 |
+
# we block the support to batch, maintain batched code in Point class
|
| 126 |
+
code = x / grid_size**2 + y / grid_size + z
|
| 127 |
+
return code.long()
|
| 128 |
+
|
| 129 |
+
def z_order_encode(self, grid_coord: torch.Tensor, depth: int = 16):
|
| 130 |
+
x, y, z = (
|
| 131 |
+
grid_coord[:, 0].long(),
|
| 132 |
+
grid_coord[:, 1].long(),
|
| 133 |
+
grid_coord[:, 2].long(),
|
| 134 |
+
)
|
| 135 |
+
# we block the support to batch, maintain batched code in Point class
|
| 136 |
+
code = self._xyz2key(x, y, z, b=None, depth=depth)
|
| 137 |
+
return code
|
| 138 |
+
|
| 139 |
+
def _xyz2key(
|
| 140 |
+
self,
|
| 141 |
+
x: torch.Tensor,
|
| 142 |
+
y: torch.Tensor,
|
| 143 |
+
z: torch.Tensor,
|
| 144 |
+
b: typing.Optional[typing.Union[torch.Tensor, int]] = None,
|
| 145 |
+
depth: int = 16,
|
| 146 |
+
):
|
| 147 |
+
r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
|
| 148 |
+
based on pre-computed look up tables. The speed of this function is much
|
| 149 |
+
faster than the method based on for-loop.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
x (torch.Tensor): The x coordinate.
|
| 153 |
+
y (torch.Tensor): The y coordinate.
|
| 154 |
+
z (torch.Tensor): The z coordinate.
|
| 155 |
+
b (torch.Tensor or int): The batch index of the coordinates, and should be
|
| 156 |
+
smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
|
| 157 |
+
:attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
|
| 158 |
+
depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
|
| 159 |
+
"""
|
| 160 |
+
EX, EY, EZ = self.key_lut.encode_lut(x.device)
|
| 161 |
+
x, y, z = x.long(), y.long(), z.long()
|
| 162 |
+
|
| 163 |
+
mask = 255 if depth > 8 else (1 << depth) - 1
|
| 164 |
+
key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
|
| 165 |
+
if depth > 8:
|
| 166 |
+
mask = (1 << (depth - 8)) - 1
|
| 167 |
+
key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
|
| 168 |
+
key = key16 << 24 | key
|
| 169 |
+
|
| 170 |
+
if b is not None:
|
| 171 |
+
b = b.long()
|
| 172 |
+
key = b << 48 | key
|
| 173 |
+
|
| 174 |
+
return key
|
| 175 |
+
|
| 176 |
+
def hilbert_encode(self, grid_coord: torch.Tensor, depth: int = 16):
|
| 177 |
+
return self._hilbert_encode(grid_coord, num_dims=3, num_bits=depth)
|
| 178 |
+
|
| 179 |
+
def _hilbert_encode(self, locs, num_dims, num_bits):
|
| 180 |
+
"""Decode an array of locations in a hypercube into a Hilbert integer.
|
| 181 |
+
|
| 182 |
+
This is a vectorized-ish version of the Hilbert curve implementation by John
|
| 183 |
+
Skilling as described in:
|
| 184 |
+
|
| 185 |
+
Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
|
| 186 |
+
Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
|
| 187 |
+
|
| 188 |
+
Params:
|
| 189 |
+
-------
|
| 190 |
+
locs - An ndarray of locations in a hypercube of num_dims dimensions, in
|
| 191 |
+
which each dimension runs from 0 to 2**num_bits-1. The shape can
|
| 192 |
+
be arbitrary, as long as the last dimension of the same has size
|
| 193 |
+
num_dims.
|
| 194 |
+
num_dims - The dimensionality of the hypercube. Integer.
|
| 195 |
+
num_bits - The number of bits for each dimension. Integer.
|
| 196 |
+
|
| 197 |
+
Returns:
|
| 198 |
+
--------
|
| 199 |
+
The output is an ndarray of uint64 integers with the same shape as the
|
| 200 |
+
input, excluding the last dimension, which needs to be num_dims.
|
| 201 |
+
"""
|
| 202 |
+
|
| 203 |
+
# Keep around the original shape for later.
|
| 204 |
+
orig_shape = locs.shape
|
| 205 |
+
bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
|
| 206 |
+
bitpack_mask_rev = bitpack_mask.flip(-1)
|
| 207 |
+
|
| 208 |
+
if orig_shape[-1] != num_dims:
|
| 209 |
+
raise ValueError(
|
| 210 |
+
"""
|
| 211 |
+
The shape of locs was surprising in that the last dimension was of size
|
| 212 |
+
%d, but num_dims=%d. These need to be equal.
|
| 213 |
+
"""
|
| 214 |
+
% (orig_shape[-1], num_dims)
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
if num_dims * num_bits > 63:
|
| 218 |
+
raise ValueError(
|
| 219 |
+
"""
|
| 220 |
+
num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
|
| 221 |
+
into a int64. Are you sure you need that many points on your Hilbert
|
| 222 |
+
curve?
|
| 223 |
+
"""
|
| 224 |
+
% (num_dims, num_bits, num_dims * num_bits)
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
# Treat the location integers as 64-bit unsigned and then split them up into
|
| 228 |
+
# a sequence of uint8s. Preserve the association by dimension.
|
| 229 |
+
locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)
|
| 230 |
+
|
| 231 |
+
# Now turn these into bits and truncate to num_bits.
|
| 232 |
+
gray = (
|
| 233 |
+
locs_uint8.unsqueeze(-1)
|
| 234 |
+
.bitwise_and(bitpack_mask_rev)
|
| 235 |
+
.ne(0)
|
| 236 |
+
.byte()
|
| 237 |
+
.flatten(-2, -1)[..., -num_bits:]
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
# Run the decoding process the other way.
|
| 241 |
+
# Iterate forwards through the bits.
|
| 242 |
+
for bit in range(0, num_bits):
|
| 243 |
+
# Iterate forwards through the dimensions.
|
| 244 |
+
for dim in range(0, num_dims):
|
| 245 |
+
# Identify which ones have this bit active.
|
| 246 |
+
mask = gray[:, dim, bit]
|
| 247 |
+
|
| 248 |
+
# Where this bit is on, invert the 0 dimension for lower bits.
|
| 249 |
+
gray[:, 0, bit + 1 :] = torch.logical_xor(
|
| 250 |
+
gray[:, 0, bit + 1 :], mask[:, None]
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
# Where the bit is off, exchange the lower bits with the 0 dimension.
|
| 254 |
+
to_flip = torch.logical_and(
|
| 255 |
+
torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
|
| 256 |
+
torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
|
| 257 |
+
)
|
| 258 |
+
gray[:, dim, bit + 1 :] = torch.logical_xor(
|
| 259 |
+
gray[:, dim, bit + 1 :], to_flip
|
| 260 |
+
)
|
| 261 |
+
gray[:, 0, bit + 1 :] = torch.logical_xor(
|
| 262 |
+
gray[:, 0, bit + 1 :], to_flip
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
# Now flatten out.
|
| 266 |
+
# Fix: shape '[-1, 0]' is invalid for input of size 192
|
| 267 |
+
gray = gray.swapaxes(1, 2).reshape((gray.size(0), -1))
|
| 268 |
+
|
| 269 |
+
# Convert Gray back to binary.
|
| 270 |
+
hh_bin = self._gray2binary(gray)
|
| 271 |
+
|
| 272 |
+
# Pad back out to 64 bits.
|
| 273 |
+
extra_dims = 64 - gray.size(1)
|
| 274 |
+
padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)
|
| 275 |
+
|
| 276 |
+
# Convert binary values into uint8s.
|
| 277 |
+
hh_uint8 = (
|
| 278 |
+
(padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask)
|
| 279 |
+
.sum(2)
|
| 280 |
+
.squeeze()
|
| 281 |
+
.type(torch.uint8)
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
# Convert uint8s into uint64s.
|
| 285 |
+
hh_uint64 = hh_uint8.view(torch.int64).squeeze()
|
| 286 |
+
|
| 287 |
+
return hh_uint64
|
| 288 |
+
|
| 289 |
+
def _gray2binary(self, gray, axis=-1):
|
| 290 |
+
"""Convert an array of Gray codes back into binary values.
|
| 291 |
+
|
| 292 |
+
Parameters:
|
| 293 |
+
-----------
|
| 294 |
+
gray: An ndarray of gray codes.
|
| 295 |
+
axis: The axis along which to perform Gray decoding. Default=-1.
|
| 296 |
+
|
| 297 |
+
Returns:
|
| 298 |
+
--------
|
| 299 |
+
Returns an ndarray of binary values.
|
| 300 |
+
"""
|
| 301 |
+
|
| 302 |
+
# Loop the log2(bits) number of times necessary, with shift and xor.
|
| 303 |
+
shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
|
| 304 |
+
while shift > 0:
|
| 305 |
+
gray = torch.logical_xor(gray, self._right_shift(gray, shift))
|
| 306 |
+
shift = torch.div(shift, 2, rounding_mode="floor")
|
| 307 |
+
return gray
|
| 308 |
+
|
| 309 |
+
def _right_shift(self, binary, k=1, axis=-1):
|
| 310 |
+
"""Right shift an array of binary values.
|
| 311 |
+
|
| 312 |
+
Parameters:
|
| 313 |
+
-----------
|
| 314 |
+
binary: An ndarray of binary values.
|
| 315 |
+
|
| 316 |
+
k: The number of bits to shift. Default 1.
|
| 317 |
+
|
| 318 |
+
axis: The axis along which to shift. Default -1.
|
| 319 |
+
|
| 320 |
+
Returns:
|
| 321 |
+
--------
|
| 322 |
+
Returns an ndarray with zero prepended and the ends truncated, along
|
| 323 |
+
whatever axis was specified."""
|
| 324 |
+
|
| 325 |
+
# If we're shifting the whole thing, just return zeros.
|
| 326 |
+
if binary.shape[axis] <= k:
|
| 327 |
+
return torch.zeros_like(binary)
|
| 328 |
+
|
| 329 |
+
# Determine the padding pattern.
|
| 330 |
+
# padding = [(0,0)] * len(binary.shape)
|
| 331 |
+
# padding[axis] = (k,0)
|
| 332 |
+
|
| 333 |
+
# Determine the slicing pattern to eliminate just the last one.
|
| 334 |
+
slicing = [slice(None)] * len(binary.shape)
|
| 335 |
+
slicing[axis] = slice(None, -k)
|
| 336 |
+
shifted = torch.nn.functional.pad(
|
| 337 |
+
binary[tuple(slicing)], (k, 0), mode="constant", value=0
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
return shifted
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
class PointModule(torch.nn.Module):
|
| 344 |
+
r"""PointModule
|
| 345 |
+
placeholder, all module subclass from this will take Point in PointSequential.
|
| 346 |
+
"""
|
| 347 |
+
|
| 348 |
+
def __init__(self, *args, **kwargs):
|
| 349 |
+
super().__init__(*args, **kwargs)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class Point(addict.Dict):
|
| 353 |
+
"""
|
| 354 |
+
Point Structure of Pointcept
|
| 355 |
+
|
| 356 |
+
A Point (point cloud) in Pointcept is a dictionary that contains various properties of
|
| 357 |
+
a batched point cloud. The property with the following names have a specific definition
|
| 358 |
+
as follows:
|
| 359 |
+
|
| 360 |
+
- "coord": original coordinate of point cloud;
|
| 361 |
+
- "grid_coord": grid coordinate for specific grid size (related to GridSampling);
|
| 362 |
+
Point also support the following optional attributes:
|
| 363 |
+
- "offset": if not exist, initialized as batch size is 1;
|
| 364 |
+
- "batch": if not exist, initialized as batch size is 1;
|
| 365 |
+
- "feat": feature of point cloud, default input of model;
|
| 366 |
+
- "grid_size": Grid size of point cloud (related to GridSampling);
|
| 367 |
+
(related to Serialization)
|
| 368 |
+
- "serialized_depth": depth of serialization, 2 ** depth * grid_size describe the maximum of point cloud range;
|
| 369 |
+
- "serialized_code": a list of serialization codes;
|
| 370 |
+
- "serialized_order": a list of serialization order determined by code;
|
| 371 |
+
- "serialized_inverse": a list of inverse mapping determined by code;
|
| 372 |
+
(related to Sparsify: SpConv)
|
| 373 |
+
- "sparse_shape": Sparse shape for Sparse Conv Tensor;
|
| 374 |
+
- "sparse_conv_feat": SparseConvTensor init with information provide by Point;
|
| 375 |
+
"""
|
| 376 |
+
|
| 377 |
+
def __init__(self, *args, **kwargs):
|
| 378 |
+
super().__init__(*args, **kwargs)
|
| 379 |
+
self.serializator = Serializator()
|
| 380 |
+
# If one of "offset" or "batch" do not exist, generate by the existing one
|
| 381 |
+
if "batch" not in self.keys() and "offset" in self.keys():
|
| 382 |
+
self["batch"] = offset2batch(self.offset)
|
| 383 |
+
elif "offset" not in self.keys() and "batch" in self.keys():
|
| 384 |
+
self["offset"] = batch2offset(self.batch)
|
| 385 |
+
|
| 386 |
+
def serialization(self, order="z", depth=None, shuffle_orders=False):
|
| 387 |
+
"""
|
| 388 |
+
Point Cloud Serialization
|
| 389 |
+
|
| 390 |
+
relay on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]
|
| 391 |
+
"""
|
| 392 |
+
assert "batch" in self.keys()
|
| 393 |
+
if "grid_coord" not in self.keys():
|
| 394 |
+
# if you don't want to operate GridSampling in data augmentation,
|
| 395 |
+
# please add the following augmentation into your pipline:
|
| 396 |
+
# dict(type="Copy", keys_dict={"grid_size": 0.01}),
|
| 397 |
+
# (adjust `grid_size` to what your want)
|
| 398 |
+
assert {"grid_size", "coord"}.issubset(self.keys())
|
| 399 |
+
self["grid_coord"] = torch.div(
|
| 400 |
+
self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
|
| 401 |
+
).int()
|
| 402 |
+
|
| 403 |
+
if depth is None:
|
| 404 |
+
# Adaptive measure the depth of serialization cube (length = 2 ^ depth)
|
| 405 |
+
depth = int(self.grid_coord.max()).bit_length()
|
| 406 |
+
|
| 407 |
+
self["serialized_depth"] = depth
|
| 408 |
+
# Maximum bit length for serialization code is 63 (int64)
|
| 409 |
+
assert depth * 3 + len(self.offset).bit_length() <= 63
|
| 410 |
+
# Here we follow OCNN and set the depth limitation to 16 (48bit) for the point position.
|
| 411 |
+
# Although depth is limited to less than 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3
|
| 412 |
+
# cube with a grid size of 0.01 meter. We consider it is enough for the current stage.
|
| 413 |
+
# We can unlock the limitation by optimizing the z-order encoding function if necessary.
|
| 414 |
+
assert depth <= 16
|
| 415 |
+
|
| 416 |
+
# The serialization codes are arranged as following structures:
|
| 417 |
+
# [Order1 ([n]),
|
| 418 |
+
# Order2 ([n]),
|
| 419 |
+
# ...
|
| 420 |
+
# OrderN ([n])] (k, n)
|
| 421 |
+
code = [
|
| 422 |
+
self.serializator.encode(
|
| 423 |
+
self.grid_coord, self.grid_size, self.batch, depth, order=order_
|
| 424 |
+
)
|
| 425 |
+
for order_ in order
|
| 426 |
+
]
|
| 427 |
+
code = torch.stack(code)
|
| 428 |
+
order = torch.argsort(code)
|
| 429 |
+
inverse = torch.zeros_like(order).scatter_(
|
| 430 |
+
dim=1,
|
| 431 |
+
index=order,
|
| 432 |
+
src=torch.arange(0, code.shape[1], device=order.device).repeat(
|
| 433 |
+
code.shape[0], 1
|
| 434 |
+
),
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
if shuffle_orders:
|
| 438 |
+
perm = torch.randperm(code.shape[0])
|
| 439 |
+
code = code[perm]
|
| 440 |
+
order = order[perm]
|
| 441 |
+
inverse = inverse[perm]
|
| 442 |
+
|
| 443 |
+
self["serialized_code"] = code
|
| 444 |
+
self["serialized_order"] = order
|
| 445 |
+
self["serialized_inverse"] = inverse
|
| 446 |
+
|
| 447 |
+
def sparsify(self, pad=96):
|
| 448 |
+
"""
|
| 449 |
+
Point Cloud Serialization
|
| 450 |
+
|
| 451 |
+
Point cloud is sparse, here we use "sparsify" to specifically refer to
|
| 452 |
+
preparing "spconv.SparseConvTensor" for SpConv.
|
| 453 |
+
|
| 454 |
+
relay on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]
|
| 455 |
+
|
| 456 |
+
pad: padding sparse for sparse shape.
|
| 457 |
+
"""
|
| 458 |
+
assert {"feat", "batch"}.issubset(self.keys())
|
| 459 |
+
if "grid_coord" not in self.keys():
|
| 460 |
+
# if you don't want to operate GridSampling in data augmentation,
|
| 461 |
+
# please add the following augmentation into your pipline:
|
| 462 |
+
# dict(type="Copy", keys_dict={"grid_size": 0.01}),
|
| 463 |
+
# (adjust `grid_size` to what your want)
|
| 464 |
+
assert {"grid_size", "coord"}.issubset(self.keys())
|
| 465 |
+
self["grid_coord"] = torch.div(
|
| 466 |
+
self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
|
| 467 |
+
).int()
|
| 468 |
+
|
| 469 |
+
if "sparse_shape" in self.keys():
|
| 470 |
+
sparse_shape = self.sparse_shape
|
| 471 |
+
else:
|
| 472 |
+
sparse_shape = torch.add(
|
| 473 |
+
torch.max(self.grid_coord, dim=0).values, pad
|
| 474 |
+
).tolist()
|
| 475 |
+
|
| 476 |
+
sparse_conv_feat = spconv.SparseConvTensor(
|
| 477 |
+
features=self.feat,
|
| 478 |
+
indices=torch.cat(
|
| 479 |
+
[self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1
|
| 480 |
+
).contiguous(),
|
| 481 |
+
spatial_shape=sparse_shape,
|
| 482 |
+
batch_size=self.batch[-1].tolist() + 1,
|
| 483 |
+
)
|
| 484 |
+
self["sparse_shape"] = sparse_shape
|
| 485 |
+
self["sparse_conv_feat"] = sparse_conv_feat
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
class PointSequential(PointModule):
|
| 489 |
+
r"""A sequential container.
|
| 490 |
+
Modules will be added to it in the order they are passed in the constructor.
|
| 491 |
+
Alternatively, an ordered dict of modules can also be passed in.
|
| 492 |
+
"""
|
| 493 |
+
|
| 494 |
+
def __init__(self, name="", *args, **kwargs):
|
| 495 |
+
super().__init__()
|
| 496 |
+
self.name = name
|
| 497 |
+
if len(args) == 1 and isinstance(args[0], collections.OrderedDict):
|
| 498 |
+
for key, module in args[0].items():
|
| 499 |
+
self.add_module(key, module)
|
| 500 |
+
else:
|
| 501 |
+
for idx, module in enumerate(args):
|
| 502 |
+
self.add_module(str(idx), module)
|
| 503 |
+
for name, module in kwargs.items():
|
| 504 |
+
if name in self._modules:
|
| 505 |
+
raise ValueError("name exists.")
|
| 506 |
+
self.add_module(name, module)
|
| 507 |
+
|
| 508 |
+
def __getitem__(self, idx):
|
| 509 |
+
if not (-len(self) <= idx < len(self)):
|
| 510 |
+
raise IndexError("index {} is out of range".format(idx))
|
| 511 |
+
if idx < 0:
|
| 512 |
+
idx += len(self)
|
| 513 |
+
it = iter(self._modules.values())
|
| 514 |
+
for i in range(idx):
|
| 515 |
+
next(it)
|
| 516 |
+
return next(it)
|
| 517 |
+
|
| 518 |
+
def __len__(self):
|
| 519 |
+
return len(self._modules)
|
| 520 |
+
|
| 521 |
+
def add(self, module, name=None):
|
| 522 |
+
if name is None:
|
| 523 |
+
name = str(len(self._modules))
|
| 524 |
+
if name in self._modules:
|
| 525 |
+
raise KeyError("name exists")
|
| 526 |
+
self.add_module(name, module)
|
| 527 |
+
|
| 528 |
+
def forward(self, x):
|
| 529 |
+
for module in self._modules.values():
|
| 530 |
+
# Point module
|
| 531 |
+
if isinstance(module, PointModule):
|
| 532 |
+
x = module(x)
|
| 533 |
+
# Spconv module
|
| 534 |
+
elif spconv.modules.is_spconv_module(module):
|
| 535 |
+
if isinstance(x, Point):
|
| 536 |
+
x.sparse_conv_feat = module(x.sparse_conv_feat)
|
| 537 |
+
x.feat = x.sparse_conv_feat.features
|
| 538 |
+
else:
|
| 539 |
+
x = module(x)
|
| 540 |
+
# Fix: Expected more than 1 value per channel when training
|
| 541 |
+
elif isinstance(module, torch.nn.BatchNorm1d) and isinstance(x, Point):
|
| 542 |
+
if x.feat.size(0) != 1:
|
| 543 |
+
x.feat = module(x.feat)
|
| 544 |
+
# PyTorch module
|
| 545 |
+
else:
|
| 546 |
+
if isinstance(x, Point):
|
| 547 |
+
x.feat = module(x.feat)
|
| 548 |
+
if "sparse_conv_feat" in x.keys():
|
| 549 |
+
x.sparse_conv_feat = x.sparse_conv_feat.replace_feature(x.feat)
|
| 550 |
+
elif isinstance(x, spconv.SparseConvTensor):
|
| 551 |
+
if x.indices.shape[0] != 0:
|
| 552 |
+
x = x.replace_feature(module(x.features))
|
| 553 |
+
else:
|
| 554 |
+
x = module(x)
|
| 555 |
+
|
| 556 |
+
return x
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
class PDNorm(PointModule):
|
| 560 |
+
def __init__(
|
| 561 |
+
self,
|
| 562 |
+
num_features,
|
| 563 |
+
norm_layer,
|
| 564 |
+
context_channels=256,
|
| 565 |
+
conditions=("ScanNet", "S3DIS", "Structured3D"),
|
| 566 |
+
decouple=True,
|
| 567 |
+
adaptive=False,
|
| 568 |
+
):
|
| 569 |
+
super().__init__()
|
| 570 |
+
self.conditions = conditions
|
| 571 |
+
self.decouple = decouple
|
| 572 |
+
self.adaptive = adaptive
|
| 573 |
+
if self.decouple:
|
| 574 |
+
self.norm = torch.nn.ModuleList(
|
| 575 |
+
[norm_layer(num_features) for _ in conditions]
|
| 576 |
+
)
|
| 577 |
+
else:
|
| 578 |
+
self.norm = norm_layer
|
| 579 |
+
if self.adaptive:
|
| 580 |
+
self.modulation = torch.nn.Sequential(
|
| 581 |
+
torch.nn.SiLU(),
|
| 582 |
+
torch.nn.Linear(context_channels, 2 * num_features, bias=True),
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
def forward(self, point):
|
| 586 |
+
assert {"feat", "condition"}.issubset(point.keys())
|
| 587 |
+
if isinstance(point.condition, str):
|
| 588 |
+
condition = point.condition
|
| 589 |
+
else:
|
| 590 |
+
condition = point.condition[0]
|
| 591 |
+
if self.decouple:
|
| 592 |
+
assert condition in self.conditions
|
| 593 |
+
norm = self.norm[self.conditions.index(condition)]
|
| 594 |
+
else:
|
| 595 |
+
norm = self.norm
|
| 596 |
+
point.feat = norm(point.feat)
|
| 597 |
+
if self.adaptive:
|
| 598 |
+
assert "context" in point.keys()
|
| 599 |
+
shift, scale = self.modulation(point.context).chunk(2, dim=1)
|
| 600 |
+
point.feat = point.feat * (1.0 + scale) + shift
|
| 601 |
+
return point
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
class RPE(torch.nn.Module):
|
| 605 |
+
def __init__(self, patch_size, num_heads):
|
| 606 |
+
super().__init__()
|
| 607 |
+
self.patch_size = patch_size
|
| 608 |
+
self.num_heads = num_heads
|
| 609 |
+
self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
|
| 610 |
+
self.rpe_num = 2 * self.pos_bnd + 1
|
| 611 |
+
self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))
|
| 612 |
+
torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)
|
| 613 |
+
|
| 614 |
+
def forward(self, coord):
|
| 615 |
+
idx = (
|
| 616 |
+
coord.clamp(-self.pos_bnd, self.pos_bnd) # clamp into bnd
|
| 617 |
+
+ self.pos_bnd # relative position to positive index
|
| 618 |
+
+ torch.arange(3, device=coord.device) * self.rpe_num # x, y, z stride
|
| 619 |
+
)
|
| 620 |
+
out = self.rpe_table.index_select(0, idx.reshape(-1))
|
| 621 |
+
out = out.view(idx.shape + (-1,)).sum(3)
|
| 622 |
+
out = out.permute(0, 3, 1, 2) # (N, K, K, H) -> (N, H, K, K)
|
| 623 |
+
return out
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
class SerializedAttention(PointModule):
|
| 627 |
+
def __init__(
|
| 628 |
+
self,
|
| 629 |
+
channels,
|
| 630 |
+
num_heads,
|
| 631 |
+
patch_size,
|
| 632 |
+
qkv_bias=True,
|
| 633 |
+
qk_scale=None,
|
| 634 |
+
attn_drop=0.0,
|
| 635 |
+
proj_drop=0.0,
|
| 636 |
+
order_index=0,
|
| 637 |
+
enable_rpe=False,
|
| 638 |
+
enable_flash=True,
|
| 639 |
+
upcast_attention=True,
|
| 640 |
+
upcast_softmax=True,
|
| 641 |
+
):
|
| 642 |
+
super().__init__()
|
| 643 |
+
assert channels % num_heads == 0
|
| 644 |
+
self.channels = channels
|
| 645 |
+
self.num_heads = num_heads
|
| 646 |
+
self.scale = qk_scale or (channels // num_heads) ** -0.5
|
| 647 |
+
self.order_index = order_index
|
| 648 |
+
self.upcast_attention = upcast_attention
|
| 649 |
+
self.upcast_softmax = upcast_softmax
|
| 650 |
+
self.enable_rpe = enable_rpe
|
| 651 |
+
self.enable_flash = enable_flash
|
| 652 |
+
if enable_flash:
|
| 653 |
+
assert (
|
| 654 |
+
enable_rpe is False
|
| 655 |
+
), "Set enable_rpe to False when enable Flash Attention"
|
| 656 |
+
assert (
|
| 657 |
+
upcast_attention is False
|
| 658 |
+
), "Set upcast_attention to False when enable Flash Attention"
|
| 659 |
+
assert (
|
| 660 |
+
upcast_softmax is False
|
| 661 |
+
), "Set upcast_softmax to False when enable Flash Attention"
|
| 662 |
+
assert flash_attn is not None, "Make sure flash_attn is installed."
|
| 663 |
+
self.patch_size = patch_size
|
| 664 |
+
self.attn_drop = attn_drop
|
| 665 |
+
else:
|
| 666 |
+
# when disable flash attention, we still don't want to use mask
|
| 667 |
+
# consequently, patch size will auto set to the
|
| 668 |
+
# min number of patch_size_max and number of points
|
| 669 |
+
self.patch_size_max = patch_size
|
| 670 |
+
self.patch_size = 0
|
| 671 |
+
self.attn_drop = torch.nn.Dropout(attn_drop)
|
| 672 |
+
|
| 673 |
+
self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
|
| 674 |
+
self.proj = torch.nn.Linear(channels, channels)
|
| 675 |
+
self.proj_drop = torch.nn.Dropout(proj_drop)
|
| 676 |
+
self.softmax = torch.nn.Softmax(dim=-1)
|
| 677 |
+
self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None
|
| 678 |
+
|
| 679 |
+
@torch.no_grad()
|
| 680 |
+
def get_rel_pos(self, point, order):
|
| 681 |
+
K = self.patch_size
|
| 682 |
+
rel_pos_key = f"rel_pos_{self.order_index}"
|
| 683 |
+
if rel_pos_key not in point.keys():
|
| 684 |
+
grid_coord = point.grid_coord[order]
|
| 685 |
+
grid_coord = grid_coord.reshape(-1, K, 3)
|
| 686 |
+
point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)
|
| 687 |
+
return point[rel_pos_key]
|
| 688 |
+
|
| 689 |
+
@torch.no_grad()
|
| 690 |
+
def get_padding_and_inverse(self, point):
|
| 691 |
+
pad_key = "pad"
|
| 692 |
+
unpad_key = "unpad"
|
| 693 |
+
cu_seqlens_key = "cu_seqlens_key"
|
| 694 |
+
if (
|
| 695 |
+
pad_key not in point.keys()
|
| 696 |
+
or unpad_key not in point.keys()
|
| 697 |
+
or cu_seqlens_key not in point.keys()
|
| 698 |
+
):
|
| 699 |
+
offset = point.offset
|
| 700 |
+
bincount = offset2bincount(offset)
|
| 701 |
+
bincount_pad = (
|
| 702 |
+
torch.div(
|
| 703 |
+
bincount + self.patch_size - 1,
|
| 704 |
+
self.patch_size,
|
| 705 |
+
rounding_mode="trunc",
|
| 706 |
+
)
|
| 707 |
+
* self.patch_size
|
| 708 |
+
)
|
| 709 |
+
# only pad point when num of points larger than patch_size
|
| 710 |
+
mask_pad = bincount > self.patch_size
|
| 711 |
+
bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
|
| 712 |
+
_offset = torch.nn.functional.pad(offset, (1, 0))
|
| 713 |
+
_offset_pad = torch.nn.functional.pad(
|
| 714 |
+
torch.cumsum(bincount_pad, dim=0), (1, 0)
|
| 715 |
+
)
|
| 716 |
+
pad = torch.arange(_offset_pad[-1], device=offset.device)
|
| 717 |
+
unpad = torch.arange(_offset[-1], device=offset.device)
|
| 718 |
+
cu_seqlens = []
|
| 719 |
+
for i in range(len(offset)):
|
| 720 |
+
unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i]
|
| 721 |
+
if bincount[i] != bincount_pad[i]:
|
| 722 |
+
pad[
|
| 723 |
+
_offset_pad[i + 1]
|
| 724 |
+
- self.patch_size
|
| 725 |
+
+ (bincount[i] % self.patch_size) : _offset_pad[i + 1]
|
| 726 |
+
] = pad[
|
| 727 |
+
_offset_pad[i + 1]
|
| 728 |
+
- 2 * self.patch_size
|
| 729 |
+
+ (bincount[i] % self.patch_size) : _offset_pad[i + 1]
|
| 730 |
+
- self.patch_size
|
| 731 |
+
]
|
| 732 |
+
pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]
|
| 733 |
+
cu_seqlens.append(
|
| 734 |
+
torch.arange(
|
| 735 |
+
_offset_pad[i],
|
| 736 |
+
_offset_pad[i + 1],
|
| 737 |
+
step=self.patch_size,
|
| 738 |
+
dtype=torch.int32,
|
| 739 |
+
device=offset.device,
|
| 740 |
+
)
|
| 741 |
+
)
|
| 742 |
+
point[pad_key] = pad
|
| 743 |
+
point[unpad_key] = unpad
|
| 744 |
+
point[cu_seqlens_key] = torch.nn.functional.pad(
|
| 745 |
+
torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1]
|
| 746 |
+
)
|
| 747 |
+
return point[pad_key], point[unpad_key], point[cu_seqlens_key]
|
| 748 |
+
|
| 749 |
+
def forward(self, point):
|
| 750 |
+
if not self.enable_flash:
|
| 751 |
+
self.patch_size = min(
|
| 752 |
+
offset2bincount(point.offset).min().tolist(), self.patch_size_max
|
| 753 |
+
)
|
| 754 |
+
|
| 755 |
+
H = self.num_heads
|
| 756 |
+
K = self.patch_size
|
| 757 |
+
C = self.channels
|
| 758 |
+
|
| 759 |
+
pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)
|
| 760 |
+
|
| 761 |
+
order = point.serialized_order[self.order_index][pad]
|
| 762 |
+
inverse = unpad[point.serialized_inverse[self.order_index]]
|
| 763 |
+
|
| 764 |
+
# padding and reshape feat and batch for serialized point patch
|
| 765 |
+
qkv = self.qkv(point.feat)[order]
|
| 766 |
+
|
| 767 |
+
if not self.enable_flash:
|
| 768 |
+
# encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
|
| 769 |
+
q, k, v = (
|
| 770 |
+
qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
|
| 771 |
+
)
|
| 772 |
+
# attn
|
| 773 |
+
if self.upcast_attention:
|
| 774 |
+
q = q.float()
|
| 775 |
+
k = k.float()
|
| 776 |
+
attn = (q * self.scale) @ k.transpose(-2, -1) # (N', H, K, K)
|
| 777 |
+
if self.enable_rpe:
|
| 778 |
+
attn = attn + self.rpe(self.get_rel_pos(point, order))
|
| 779 |
+
if self.upcast_softmax:
|
| 780 |
+
attn = attn.float()
|
| 781 |
+
attn = self.softmax(attn)
|
| 782 |
+
attn = self.attn_drop(attn).to(qkv.dtype)
|
| 783 |
+
feat = (attn @ v).transpose(1, 2).reshape(-1, C)
|
| 784 |
+
else:
|
| 785 |
+
feat = flash_attn.flash_attn_varlen_qkvpacked_func(
|
| 786 |
+
qkv.half().reshape(-1, 3, H, C // H),
|
| 787 |
+
cu_seqlens,
|
| 788 |
+
max_seqlen=self.patch_size,
|
| 789 |
+
dropout_p=self.attn_drop if self.training else 0,
|
| 790 |
+
softmax_scale=self.scale,
|
| 791 |
+
).reshape(-1, C)
|
| 792 |
+
feat = feat.to(qkv.dtype)
|
| 793 |
+
feat = feat[inverse]
|
| 794 |
+
|
| 795 |
+
# ffn
|
| 796 |
+
feat = self.proj(feat)
|
| 797 |
+
feat = self.proj_drop(feat)
|
| 798 |
+
point.feat = feat
|
| 799 |
+
return point
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
class MLP(torch.nn.Module):
|
| 803 |
+
def __init__(
|
| 804 |
+
self,
|
| 805 |
+
in_channels,
|
| 806 |
+
hidden_channels=None,
|
| 807 |
+
out_channels=None,
|
| 808 |
+
act_layer=torch.nn.GELU,
|
| 809 |
+
drop=0.0,
|
| 810 |
+
):
|
| 811 |
+
super().__init__()
|
| 812 |
+
out_channels = out_channels or in_channels
|
| 813 |
+
hidden_channels = hidden_channels or in_channels
|
| 814 |
+
self.fc1 = torch.nn.Linear(in_channels, hidden_channels)
|
| 815 |
+
self.act = act_layer()
|
| 816 |
+
self.fc2 = torch.nn.Linear(hidden_channels, out_channels)
|
| 817 |
+
# self.drop = torch.nn.Dropout(drop)
|
| 818 |
+
|
| 819 |
+
def forward(self, x):
|
| 820 |
+
x = self.fc1(x)
|
| 821 |
+
x = self.act(x)
|
| 822 |
+
# x = self.drop(x)
|
| 823 |
+
x = self.fc2(x)
|
| 824 |
+
# x = self.drop(x)
|
| 825 |
+
return x
|
| 826 |
+
|
| 827 |
+
|
| 828 |
+
class Block(PointModule):
|
| 829 |
+
def __init__(
|
| 830 |
+
self,
|
| 831 |
+
channels,
|
| 832 |
+
num_heads,
|
| 833 |
+
patch_size=48,
|
| 834 |
+
mlp_ratio=4.0,
|
| 835 |
+
qkv_bias=True,
|
| 836 |
+
qk_scale=None,
|
| 837 |
+
attn_drop=0.0,
|
| 838 |
+
proj_drop=0.0,
|
| 839 |
+
drop_path=0.0,
|
| 840 |
+
norm_layer=torch.nn.LayerNorm,
|
| 841 |
+
act_layer=torch.nn.GELU,
|
| 842 |
+
pre_norm=True,
|
| 843 |
+
order_index=0,
|
| 844 |
+
cpe_indice_key=None,
|
| 845 |
+
enable_rpe=False,
|
| 846 |
+
enable_flash=True,
|
| 847 |
+
upcast_attention=True,
|
| 848 |
+
upcast_softmax=True,
|
| 849 |
+
):
|
| 850 |
+
super().__init__()
|
| 851 |
+
self.channels = channels
|
| 852 |
+
self.pre_norm = pre_norm
|
| 853 |
+
|
| 854 |
+
self.cpe = PointSequential(
|
| 855 |
+
spconv.SubMConv3d(
|
| 856 |
+
channels,
|
| 857 |
+
channels,
|
| 858 |
+
kernel_size=3,
|
| 859 |
+
bias=True,
|
| 860 |
+
indice_key=cpe_indice_key,
|
| 861 |
+
),
|
| 862 |
+
torch.nn.Linear(channels, channels),
|
| 863 |
+
norm_layer(channels),
|
| 864 |
+
)
|
| 865 |
+
|
| 866 |
+
self.norm1 = PointSequential(norm_layer(channels))
|
| 867 |
+
self.attn = SerializedAttention(
|
| 868 |
+
channels=channels,
|
| 869 |
+
patch_size=patch_size,
|
| 870 |
+
num_heads=num_heads,
|
| 871 |
+
qkv_bias=qkv_bias,
|
| 872 |
+
qk_scale=qk_scale,
|
| 873 |
+
attn_drop=attn_drop,
|
| 874 |
+
proj_drop=proj_drop,
|
| 875 |
+
order_index=order_index,
|
| 876 |
+
enable_rpe=enable_rpe,
|
| 877 |
+
enable_flash=enable_flash,
|
| 878 |
+
upcast_attention=upcast_attention,
|
| 879 |
+
upcast_softmax=upcast_softmax,
|
| 880 |
+
)
|
| 881 |
+
self.norm2 = PointSequential(norm_layer(channels))
|
| 882 |
+
self.mlp = PointSequential(
|
| 883 |
+
MLP(
|
| 884 |
+
in_channels=channels,
|
| 885 |
+
hidden_channels=int(channels * mlp_ratio),
|
| 886 |
+
out_channels=channels,
|
| 887 |
+
act_layer=act_layer,
|
| 888 |
+
drop=proj_drop,
|
| 889 |
+
)
|
| 890 |
+
)
|
| 891 |
+
self.drop_path = PointSequential(
|
| 892 |
+
DropPath(drop_path) if drop_path > 0.0 else torch.nn.Identity()
|
| 893 |
+
)
|
| 894 |
+
|
| 895 |
+
def forward(self, point: Point):
|
| 896 |
+
shortcut = point.feat
|
| 897 |
+
point = self.cpe(point)
|
| 898 |
+
point.feat = shortcut + point.feat
|
| 899 |
+
shortcut = point.feat
|
| 900 |
+
if self.pre_norm:
|
| 901 |
+
point = self.norm1(point)
|
| 902 |
+
point = self.drop_path(self.attn(point))
|
| 903 |
+
point.feat = shortcut + point.feat
|
| 904 |
+
if not self.pre_norm:
|
| 905 |
+
point = self.norm1(point)
|
| 906 |
+
|
| 907 |
+
shortcut = point.feat
|
| 908 |
+
if self.pre_norm:
|
| 909 |
+
point = self.norm2(point)
|
| 910 |
+
point = self.drop_path(self.mlp(point))
|
| 911 |
+
point.feat = shortcut + point.feat
|
| 912 |
+
if not self.pre_norm:
|
| 913 |
+
point = self.norm2(point)
|
| 914 |
+
point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
|
| 915 |
+
return point
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
class DropPath(torch.nn.Module):
|
| 919 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 920 |
+
|
| 921 |
+
def __init__(self, drop_prob=None, scale_by_keep=True):
|
| 922 |
+
super(DropPath, self).__init__()
|
| 923 |
+
self.drop_prob = drop_prob
|
| 924 |
+
self.scale_by_keep = scale_by_keep
|
| 925 |
+
|
| 926 |
+
def _drop_path(
|
| 927 |
+
self,
|
| 928 |
+
x,
|
| 929 |
+
drop_prob: float = 0.0,
|
| 930 |
+
training: bool = False,
|
| 931 |
+
scale_by_keep: bool = True,
|
| 932 |
+
):
|
| 933 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 934 |
+
|
| 935 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
| 936 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
| 937 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
| 938 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
| 939 |
+
'survival rate' as the argument.
|
| 940 |
+
|
| 941 |
+
"""
|
| 942 |
+
if drop_prob == 0.0 or not training:
|
| 943 |
+
return x
|
| 944 |
+
keep_prob = 1 - drop_prob
|
| 945 |
+
shape = (x.shape[0],) + (1,) * (
|
| 946 |
+
x.ndim - 1
|
| 947 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
| 948 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 949 |
+
if keep_prob > 0.0 and scale_by_keep:
|
| 950 |
+
random_tensor.div_(keep_prob)
|
| 951 |
+
return x * random_tensor
|
| 952 |
+
|
| 953 |
+
def forward(self, x):
|
| 954 |
+
return self._drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
| 955 |
+
|
| 956 |
+
|
| 957 |
+
class SerializedPooling(PointModule):
|
| 958 |
+
def __init__(
|
| 959 |
+
self,
|
| 960 |
+
in_channels,
|
| 961 |
+
out_channels,
|
| 962 |
+
stride=2,
|
| 963 |
+
norm_layer=None,
|
| 964 |
+
act_layer=None,
|
| 965 |
+
reduce="max",
|
| 966 |
+
shuffle_orders=True,
|
| 967 |
+
traceable=True, # record parent and cluster
|
| 968 |
+
):
|
| 969 |
+
super().__init__()
|
| 970 |
+
self.in_channels = in_channels
|
| 971 |
+
self.out_channels = out_channels
|
| 972 |
+
|
| 973 |
+
assert stride == 2 ** (math.ceil(stride) - 1).bit_length() # 2, 4, 8
|
| 974 |
+
# TODO: add support to grid pool (any stride)
|
| 975 |
+
self.stride = stride
|
| 976 |
+
assert reduce in ["sum", "mean", "min", "max"]
|
| 977 |
+
self.reduce = reduce
|
| 978 |
+
self.shuffle_orders = shuffle_orders
|
| 979 |
+
self.traceable = traceable
|
| 980 |
+
|
| 981 |
+
self.proj = torch.nn.Linear(in_channels, out_channels)
|
| 982 |
+
if norm_layer is not None:
|
| 983 |
+
self.norm = PointSequential(norm_layer(out_channels))
|
| 984 |
+
if act_layer is not None:
|
| 985 |
+
self.act = PointSequential(act_layer())
|
| 986 |
+
|
| 987 |
+
def forward(self, point: Point):
|
| 988 |
+
pooling_depth = (math.ceil(self.stride) - 1).bit_length()
|
| 989 |
+
if pooling_depth > point.serialized_depth:
|
| 990 |
+
pooling_depth = 0
|
| 991 |
+
assert {
|
| 992 |
+
"serialized_code",
|
| 993 |
+
"serialized_order",
|
| 994 |
+
"serialized_inverse",
|
| 995 |
+
"serialized_depth",
|
| 996 |
+
}.issubset(
|
| 997 |
+
point.keys()
|
| 998 |
+
), "Run point.serialization() point cloud before SerializedPooling"
|
| 999 |
+
|
| 1000 |
+
code = point.serialized_code >> pooling_depth * 3
|
| 1001 |
+
code_, cluster, counts = torch.unique(
|
| 1002 |
+
code[0],
|
| 1003 |
+
sorted=True,
|
| 1004 |
+
return_inverse=True,
|
| 1005 |
+
return_counts=True,
|
| 1006 |
+
)
|
| 1007 |
+
# indices of point sorted by cluster, for torch_scatter.segment_csr
|
| 1008 |
+
_, indices = torch.sort(cluster)
|
| 1009 |
+
# index pointer for sorted point, for torch_scatter.segment_csr
|
| 1010 |
+
idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
|
| 1011 |
+
# head_indices of each cluster, for reduce attr e.g. code, batch
|
| 1012 |
+
head_indices = indices[idx_ptr[:-1]]
|
| 1013 |
+
# generate down code, order, inverse
|
| 1014 |
+
code = code[:, head_indices]
|
| 1015 |
+
order = torch.argsort(code)
|
| 1016 |
+
inverse = torch.zeros_like(order).scatter_(
|
| 1017 |
+
dim=1,
|
| 1018 |
+
index=order,
|
| 1019 |
+
src=torch.arange(0, code.shape[1], device=order.device).repeat(
|
| 1020 |
+
code.shape[0], 1
|
| 1021 |
+
),
|
| 1022 |
+
)
|
| 1023 |
+
|
| 1024 |
+
if self.shuffle_orders:
|
| 1025 |
+
perm = torch.randperm(code.shape[0])
|
| 1026 |
+
code = code[perm]
|
| 1027 |
+
order = order[perm]
|
| 1028 |
+
inverse = inverse[perm]
|
| 1029 |
+
|
| 1030 |
+
# collect information
|
| 1031 |
+
point_dict = addict.Dict(
|
| 1032 |
+
feat=torch_scatter.segment_csr(
|
| 1033 |
+
self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
|
| 1034 |
+
),
|
| 1035 |
+
coord=torch_scatter.segment_csr(
|
| 1036 |
+
point.coord[indices], idx_ptr, reduce="mean"
|
| 1037 |
+
),
|
| 1038 |
+
grid_coord=point.grid_coord[head_indices] >> pooling_depth,
|
| 1039 |
+
serialized_code=code,
|
| 1040 |
+
serialized_order=order,
|
| 1041 |
+
serialized_inverse=inverse,
|
| 1042 |
+
serialized_depth=point.serialized_depth - pooling_depth,
|
| 1043 |
+
batch=point.batch[head_indices],
|
| 1044 |
+
)
|
| 1045 |
+
|
| 1046 |
+
if "condition" in point.keys():
|
| 1047 |
+
point_dict["condition"] = point.condition
|
| 1048 |
+
if "context" in point.keys():
|
| 1049 |
+
point_dict["context"] = point.context
|
| 1050 |
+
|
| 1051 |
+
if self.traceable:
|
| 1052 |
+
point_dict["pooling_inverse"] = cluster
|
| 1053 |
+
point_dict["pooling_parent"] = point
|
| 1054 |
+
|
| 1055 |
+
point = Point(point_dict)
|
| 1056 |
+
# Fix: Expected more than 1 value per channel when training
|
| 1057 |
+
if self.norm is not None and point.feat.size(0) != 1:
|
| 1058 |
+
point = self.norm(point)
|
| 1059 |
+
if self.act is not None:
|
| 1060 |
+
point = self.act(point)
|
| 1061 |
+
point.sparsify()
|
| 1062 |
+
|
| 1063 |
+
return point
|
| 1064 |
+
|
| 1065 |
+
|
| 1066 |
+
class SerializedUnpooling(PointModule):
|
| 1067 |
+
def __init__(
|
| 1068 |
+
self,
|
| 1069 |
+
in_channels,
|
| 1070 |
+
skip_channels,
|
| 1071 |
+
out_channels,
|
| 1072 |
+
norm_layer=None,
|
| 1073 |
+
act_layer=None,
|
| 1074 |
+
traceable=False, # record parent and cluster
|
| 1075 |
+
):
|
| 1076 |
+
super().__init__()
|
| 1077 |
+
self.proj = PointSequential(torch.nn.Linear(in_channels, out_channels))
|
| 1078 |
+
self.proj_skip = PointSequential(torch.nn.Linear(skip_channels, out_channels))
|
| 1079 |
+
|
| 1080 |
+
if norm_layer is not None:
|
| 1081 |
+
self.proj.add(norm_layer(out_channels))
|
| 1082 |
+
self.proj_skip.add(norm_layer(out_channels))
|
| 1083 |
+
|
| 1084 |
+
if act_layer is not None:
|
| 1085 |
+
self.proj.add(act_layer())
|
| 1086 |
+
self.proj_skip.add(act_layer())
|
| 1087 |
+
|
| 1088 |
+
self.traceable = traceable
|
| 1089 |
+
|
| 1090 |
+
def forward(self, point):
|
| 1091 |
+
assert "pooling_parent" in point.keys()
|
| 1092 |
+
assert "pooling_inverse" in point.keys()
|
| 1093 |
+
parent = point.pop("pooling_parent")
|
| 1094 |
+
inverse = point.pop("pooling_inverse")
|
| 1095 |
+
point = self.proj(point)
|
| 1096 |
+
parent = self.proj_skip(parent)
|
| 1097 |
+
parent.feat = parent.feat + point.feat[inverse]
|
| 1098 |
+
|
| 1099 |
+
if self.traceable:
|
| 1100 |
+
parent["unpooling_parent"] = point
|
| 1101 |
+
return parent
|
| 1102 |
+
|
| 1103 |
+
|
| 1104 |
+
class Embedding(PointModule):
|
| 1105 |
+
def __init__(
|
| 1106 |
+
self,
|
| 1107 |
+
in_channels,
|
| 1108 |
+
embed_channels,
|
| 1109 |
+
norm_layer=None,
|
| 1110 |
+
act_layer=None,
|
| 1111 |
+
):
|
| 1112 |
+
super().__init__()
|
| 1113 |
+
self.in_channels = in_channels
|
| 1114 |
+
self.embed_channels = embed_channels
|
| 1115 |
+
|
| 1116 |
+
# TODO: check remove spconv
|
| 1117 |
+
self.stem = PointSequential(
|
| 1118 |
+
conv=spconv.SubMConv3d(
|
| 1119 |
+
in_channels,
|
| 1120 |
+
embed_channels,
|
| 1121 |
+
kernel_size=5,
|
| 1122 |
+
padding=1,
|
| 1123 |
+
bias=False,
|
| 1124 |
+
indice_key="stem",
|
| 1125 |
+
)
|
| 1126 |
+
)
|
| 1127 |
+
if norm_layer is not None:
|
| 1128 |
+
self.stem.add(norm_layer(embed_channels), name="norm")
|
| 1129 |
+
if act_layer is not None:
|
| 1130 |
+
self.stem.add(act_layer(), name="act")
|
| 1131 |
+
|
| 1132 |
+
def forward(self, point: Point):
|
| 1133 |
+
point = self.stem(point)
|
| 1134 |
+
return point
|
| 1135 |
+
|
| 1136 |
+
|
| 1137 |
+
class PointTransformerV3(PointModule):
|
| 1138 |
+
def __init__(
|
| 1139 |
+
self,
|
| 1140 |
+
in_channels=6,
|
| 1141 |
+
order=("cord"),
|
| 1142 |
+
stride=(2, 2, 2, 2),
|
| 1143 |
+
enc_depths=(2, 2, 2, 6, 2),
|
| 1144 |
+
enc_channels=(32, 64, 128, 256, 512),
|
| 1145 |
+
enc_num_head=(2, 4, 8, 16, 32),
|
| 1146 |
+
enc_patch_size=(1024, 1024, 1024, 1024, 1024),
|
| 1147 |
+
dec_depths=(2, 2, 2, 2),
|
| 1148 |
+
dec_channels=(64, 64, 128, 256),
|
| 1149 |
+
dec_num_head=(4, 4, 8, 16),
|
| 1150 |
+
dec_patch_size=(1024, 1024, 1024, 1024),
|
| 1151 |
+
mlp_ratio=4,
|
| 1152 |
+
grid_size=0.01,
|
| 1153 |
+
qkv_bias=True,
|
| 1154 |
+
qk_scale=None,
|
| 1155 |
+
attn_drop=0.0,
|
| 1156 |
+
proj_drop=0.0,
|
| 1157 |
+
drop_path=0.3,
|
| 1158 |
+
pre_norm=True,
|
| 1159 |
+
shuffle_orders=True,
|
| 1160 |
+
enable_rpe=False,
|
| 1161 |
+
enable_flash=True,
|
| 1162 |
+
upcast_attention=False,
|
| 1163 |
+
upcast_softmax=False,
|
| 1164 |
+
cls_mode=False,
|
| 1165 |
+
pdnorm_bn=False,
|
| 1166 |
+
pdnorm_ln=False,
|
| 1167 |
+
pdnorm_decouple=True,
|
| 1168 |
+
pdnorm_adaptive=False,
|
| 1169 |
+
pdnorm_affine=True,
|
| 1170 |
+
pdnorm_conditions=("ScanNet", "S3DIS", "Structured3D"),
|
| 1171 |
+
):
|
| 1172 |
+
super().__init__()
|
| 1173 |
+
self.num_stages = len(enc_depths)
|
| 1174 |
+
self.order = [order] if isinstance(order, str) else order
|
| 1175 |
+
self.cls_mode = cls_mode
|
| 1176 |
+
self.shuffle_orders = shuffle_orders
|
| 1177 |
+
self.grid_size = grid_size
|
| 1178 |
+
|
| 1179 |
+
assert self.num_stages == len(stride) + 1
|
| 1180 |
+
assert self.num_stages == len(enc_depths)
|
| 1181 |
+
assert self.num_stages == len(enc_channels)
|
| 1182 |
+
assert self.num_stages == len(enc_num_head)
|
| 1183 |
+
assert self.num_stages == len(enc_patch_size)
|
| 1184 |
+
assert self.cls_mode or self.num_stages == len(dec_depths) + 1
|
| 1185 |
+
assert self.cls_mode or self.num_stages == len(dec_channels) + 1
|
| 1186 |
+
assert self.cls_mode or self.num_stages == len(dec_num_head) + 1
|
| 1187 |
+
assert self.cls_mode or self.num_stages == len(dec_patch_size) + 1
|
| 1188 |
+
|
| 1189 |
+
# norm layers
|
| 1190 |
+
if pdnorm_bn:
|
| 1191 |
+
bn_layer = functools.partial(
|
| 1192 |
+
PDNorm,
|
| 1193 |
+
norm_layer=functools.partial(
|
| 1194 |
+
torch.nn.BatchNorm1d, eps=1e-3, momentum=0.01, affine=pdnorm_affine
|
| 1195 |
+
),
|
| 1196 |
+
conditions=pdnorm_conditions,
|
| 1197 |
+
decouple=pdnorm_decouple,
|
| 1198 |
+
adaptive=pdnorm_adaptive,
|
| 1199 |
+
)
|
| 1200 |
+
else:
|
| 1201 |
+
bn_layer = functools.partial(torch.nn.BatchNorm1d, eps=1e-3, momentum=0.01)
|
| 1202 |
+
if pdnorm_ln:
|
| 1203 |
+
ln_layer = functools.partial(
|
| 1204 |
+
PDNorm,
|
| 1205 |
+
norm_layer=functools.partial(
|
| 1206 |
+
torch.nn.LayerNorm, elementwise_affine=pdnorm_affine
|
| 1207 |
+
),
|
| 1208 |
+
conditions=pdnorm_conditions,
|
| 1209 |
+
decouple=pdnorm_decouple,
|
| 1210 |
+
adaptive=pdnorm_adaptive,
|
| 1211 |
+
)
|
| 1212 |
+
else:
|
| 1213 |
+
ln_layer = torch.nn.LayerNorm
|
| 1214 |
+
# activation layers
|
| 1215 |
+
act_layer = torch.nn.GELU
|
| 1216 |
+
self.embedding = Embedding(
|
| 1217 |
+
in_channels=in_channels,
|
| 1218 |
+
embed_channels=enc_channels[0],
|
| 1219 |
+
norm_layer=bn_layer,
|
| 1220 |
+
act_layer=act_layer,
|
| 1221 |
+
)
|
| 1222 |
+
|
| 1223 |
+
# encoder
|
| 1224 |
+
enc_drop_path = [
|
| 1225 |
+
x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))
|
| 1226 |
+
]
|
| 1227 |
+
self.enc = PointSequential(name="encoder")
|
| 1228 |
+
for s in range(self.num_stages):
|
| 1229 |
+
enc_drop_path_ = enc_drop_path[
|
| 1230 |
+
sum(enc_depths[:s]) : sum(enc_depths[: s + 1])
|
| 1231 |
+
]
|
| 1232 |
+
enc = PointSequential(name="encoder_layer_%d" % s)
|
| 1233 |
+
if s > 0:
|
| 1234 |
+
enc.add(
|
| 1235 |
+
SerializedPooling(
|
| 1236 |
+
in_channels=enc_channels[s - 1],
|
| 1237 |
+
out_channels=enc_channels[s],
|
| 1238 |
+
stride=stride[s - 1],
|
| 1239 |
+
norm_layer=bn_layer,
|
| 1240 |
+
act_layer=act_layer,
|
| 1241 |
+
),
|
| 1242 |
+
name="down",
|
| 1243 |
+
)
|
| 1244 |
+
for i in range(enc_depths[s]):
|
| 1245 |
+
enc.add(
|
| 1246 |
+
Block(
|
| 1247 |
+
channels=enc_channels[s],
|
| 1248 |
+
num_heads=enc_num_head[s],
|
| 1249 |
+
patch_size=enc_patch_size[s],
|
| 1250 |
+
mlp_ratio=mlp_ratio,
|
| 1251 |
+
qkv_bias=qkv_bias,
|
| 1252 |
+
qk_scale=qk_scale,
|
| 1253 |
+
attn_drop=attn_drop,
|
| 1254 |
+
proj_drop=proj_drop,
|
| 1255 |
+
drop_path=enc_drop_path_[i],
|
| 1256 |
+
norm_layer=ln_layer,
|
| 1257 |
+
act_layer=act_layer,
|
| 1258 |
+
pre_norm=pre_norm,
|
| 1259 |
+
order_index=i % len(self.order),
|
| 1260 |
+
cpe_indice_key=f"stage{s}",
|
| 1261 |
+
enable_rpe=enable_rpe,
|
| 1262 |
+
enable_flash=enable_flash,
|
| 1263 |
+
upcast_attention=upcast_attention,
|
| 1264 |
+
upcast_softmax=upcast_softmax,
|
| 1265 |
+
),
|
| 1266 |
+
name=f"block{i}",
|
| 1267 |
+
)
|
| 1268 |
+
if len(enc) != 0:
|
| 1269 |
+
self.enc.add(module=enc, name=f"enc{s}")
|
| 1270 |
+
|
| 1271 |
+
# decoder
|
| 1272 |
+
if not self.cls_mode:
|
| 1273 |
+
dec_drop_path = [
|
| 1274 |
+
x.item() for x in torch.linspace(0, drop_path, sum(dec_depths))
|
| 1275 |
+
]
|
| 1276 |
+
self.dec = PointSequential(name="decoder")
|
| 1277 |
+
dec_channels = list(dec_channels) + [enc_channels[-1]]
|
| 1278 |
+
for s in reversed(range(self.num_stages - 1)):
|
| 1279 |
+
dec_drop_path_ = dec_drop_path[
|
| 1280 |
+
sum(dec_depths[:s]) : sum(dec_depths[: s + 1])
|
| 1281 |
+
]
|
| 1282 |
+
dec_drop_path_.reverse()
|
| 1283 |
+
dec = PointSequential(name="decoder_layer_%d" % s)
|
| 1284 |
+
dec.add(
|
| 1285 |
+
SerializedUnpooling(
|
| 1286 |
+
in_channels=dec_channels[s + 1],
|
| 1287 |
+
skip_channels=enc_channels[s],
|
| 1288 |
+
out_channels=dec_channels[s],
|
| 1289 |
+
norm_layer=bn_layer,
|
| 1290 |
+
act_layer=act_layer,
|
| 1291 |
+
),
|
| 1292 |
+
name="up",
|
| 1293 |
+
)
|
| 1294 |
+
for i in range(dec_depths[s]):
|
| 1295 |
+
dec.add(
|
| 1296 |
+
Block(
|
| 1297 |
+
channels=dec_channels[s],
|
| 1298 |
+
num_heads=dec_num_head[s],
|
| 1299 |
+
patch_size=dec_patch_size[s],
|
| 1300 |
+
mlp_ratio=mlp_ratio,
|
| 1301 |
+
qkv_bias=qkv_bias,
|
| 1302 |
+
qk_scale=qk_scale,
|
| 1303 |
+
attn_drop=attn_drop,
|
| 1304 |
+
proj_drop=proj_drop,
|
| 1305 |
+
drop_path=dec_drop_path_[i],
|
| 1306 |
+
norm_layer=ln_layer,
|
| 1307 |
+
act_layer=act_layer,
|
| 1308 |
+
pre_norm=pre_norm,
|
| 1309 |
+
order_index=i % len(self.order),
|
| 1310 |
+
cpe_indice_key=f"stage{s}",
|
| 1311 |
+
enable_rpe=enable_rpe,
|
| 1312 |
+
enable_flash=enable_flash,
|
| 1313 |
+
upcast_attention=upcast_attention,
|
| 1314 |
+
upcast_softmax=upcast_softmax,
|
| 1315 |
+
),
|
| 1316 |
+
name=f"block{i}",
|
| 1317 |
+
)
|
| 1318 |
+
self.dec.add(module=dec, name=f"dec{s}")
|
| 1319 |
+
|
| 1320 |
+
def forward(self, batch, feat, coord):
|
| 1321 |
+
"""
|
| 1322 |
+
A data_dict is a dictionary containing properties of a batched point cloud.
|
| 1323 |
+
It should contain the following properties for PTv3:
|
| 1324 |
+
1. "feat": feature of point cloud
|
| 1325 |
+
2. "grid_coord": discrete coordinate after grid sampling (voxelization) or "coord" + "grid_size"
|
| 1326 |
+
3. "offset" or "batch": https://github.com/Pointcept/Pointcept?tab=readme-ov-file#offset
|
| 1327 |
+
"""
|
| 1328 |
+
point = Point(
|
| 1329 |
+
{
|
| 1330 |
+
"batch": batch.squeeze(dim=0),
|
| 1331 |
+
"feat": feat.squeeze(dim=0),
|
| 1332 |
+
"coord": coord.squeeze(dim=0),
|
| 1333 |
+
"grid_size": self.grid_size,
|
| 1334 |
+
}
|
| 1335 |
+
)
|
| 1336 |
+
point.serialization(order=self.order, shuffle_orders=self.shuffle_orders)
|
| 1337 |
+
point.sparsify()
|
| 1338 |
+
|
| 1339 |
+
point = self.embedding(point)
|
| 1340 |
+
point = self.enc(point)
|
| 1341 |
+
if not self.cls_mode:
|
| 1342 |
+
point = self.dec(point)
|
| 1343 |
+
|
| 1344 |
+
return point.feat.unsqueeze(dim=0)
|
requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cu121
|
| 2 |
+
easydict
|
| 3 |
+
gradio
|
| 4 |
+
numpy<2.0.0
|
| 5 |
+
opencv-python
|
| 6 |
+
pillow
|
| 7 |
+
scipy
|
| 8 |
+
torch==2.2.2
|
| 9 |
+
torchvision==0.17.2
|
| 10 |
+
|
| 11 |
+
addict
|
| 12 |
+
spconv-cu121
|
| 13 |
+
torch_scatter
|
| 14 |
+
|