Spaces:
Running on Zero
Running on Zero
lixi042 commited on
Commit ·
510e990
1
Parent(s): a4415c0
Initial commit: Argus metric panoramic 3D reconstruction demo
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +3 -0
- app.py +1499 -0
- argus/__init__.py +2 -0
- argus/heads/__init__.py +2 -0
- argus/heads/camera_head.py +142 -0
- argus/heads/dpt_head.py +474 -0
- argus/heads/head_act.py +122 -0
- argus/heads/utils.py +142 -0
- argus/layers/__init__.py +8 -0
- argus/layers/attention.py +93 -0
- argus/layers/block.py +247 -0
- argus/layers/drop_path.py +34 -0
- argus/layers/layer_scale.py +22 -0
- argus/layers/mlp.py +40 -0
- argus/layers/patch_embed.py +85 -0
- argus/layers/rope.py +188 -0
- argus/layers/swiglu_ffn.py +67 -0
- argus/layers/vision_transformer.py +401 -0
- argus/models/__init__.py +2 -0
- argus/models/aggregator.py +502 -0
- argus/models/argus.py +234 -0
- argus/utils/__init__.py +2 -0
- argus/utils/data_io.py +152 -0
- argus/utils/geometry.py +201 -0
- argus/utils/normalization.py +65 -0
- argus/utils/pose_enc.py +105 -0
- argus/utils/rotation.py +118 -0
- assets/argus_logo.png +3 -0
- examples/far_4/0.jpg +3 -0
- examples/far_4/1.jpg +3 -0
- examples/far_4/2.jpg +3 -0
- examples/far_4/3.jpg +3 -0
- examples/scene_00008/1757748389.jpg +3 -0
- examples/scene_00008/1757748429.jpg +3 -0
- examples/scene_00008/1757748477.jpg +3 -0
- examples/scene_00008/1757748528.jpg +3 -0
- examples/scene_00008/1757748562.jpg +3 -0
- examples/scene_00008/1757748600.jpg +3 -0
- examples/scene_00008/1757748638.jpg +3 -0
- examples/scene_00008/1757748685.jpg +3 -0
- examples/scene_00008/1757748728.jpg +3 -0
- examples/scene_00008/1757748770.jpg +3 -0
- examples/scene_00008/1757748817.jpg +3 -0
- examples/scene_00008/1757748866.jpg +3 -0
- examples/scene_00008/1757748907.jpg +3 -0
- examples/scene_00008/1757748959.jpg +3 -0
- examples/scene_00008/1757749004.jpg +3 -0
- examples/scene_00008/1757749043.jpg +3 -0
- examples/scene_00008/1757749091.jpg +3 -0
- examples/scene_00008/1757749140.jpg +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
|
@@ -0,0 +1,1499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Standard library imports
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import shutil
|
| 5 |
+
import glob
|
| 6 |
+
import gc
|
| 7 |
+
import time
|
| 8 |
+
import base64
|
| 9 |
+
import argparse
|
| 10 |
+
import tempfile
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
# Third-party library imports
|
| 15 |
+
import cv2
|
| 16 |
+
import torch
|
| 17 |
+
import trimesh
|
| 18 |
+
import numpy as np
|
| 19 |
+
import gradio as gr
|
| 20 |
+
import matplotlib
|
| 21 |
+
import matplotlib.pyplot as plt
|
| 22 |
+
from scipy.spatial.transform import Rotation
|
| 23 |
+
|
| 24 |
+
# Custom module imports
|
| 25 |
+
from argus.models.argus import Argus
|
| 26 |
+
from argus.utils.pose_enc import pose_encoding_to_extri360
|
| 27 |
+
from argus.utils.geometry import unproject_depth_to_world_points
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# -------------------------- Argument Parsing --------------------------
|
| 31 |
+
def parse_args():
|
| 32 |
+
parser = argparse.ArgumentParser(description="Argus Gradio Demo")
|
| 33 |
+
parser.add_argument(
|
| 34 |
+
"--model_path",
|
| 35 |
+
type=str,
|
| 36 |
+
default=None,
|
| 37 |
+
help="Path to pre-trained model weights (.pt file). "
|
| 38 |
+
"If not specified, auto-downloads from HuggingFace.",
|
| 39 |
+
)
|
| 40 |
+
parser.add_argument(
|
| 41 |
+
"--img_size",
|
| 42 |
+
type=int,
|
| 43 |
+
default=560,
|
| 44 |
+
help="Input panoramic image target width (height = width // 2)",
|
| 45 |
+
)
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--crop_ratio",
|
| 48 |
+
type=float,
|
| 49 |
+
default=0.15,
|
| 50 |
+
help="Vertical crop ratio for panoramic image preprocessing (0-0.5)",
|
| 51 |
+
)
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
"--port",
|
| 54 |
+
type=int,
|
| 55 |
+
default=7860,
|
| 56 |
+
help="Port number for Gradio server",
|
| 57 |
+
)
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--share",
|
| 60 |
+
action="store_true",
|
| 61 |
+
default=False,
|
| 62 |
+
help="Enable Gradio public sharing link",
|
| 63 |
+
)
|
| 64 |
+
parser.add_argument(
|
| 65 |
+
"--server_name",
|
| 66 |
+
type=str,
|
| 67 |
+
default="0.0.0.0",
|
| 68 |
+
help="Server host address (0.0.0.0 for all interfaces)",
|
| 69 |
+
)
|
| 70 |
+
parser.add_argument(
|
| 71 |
+
"--device",
|
| 72 |
+
type=str,
|
| 73 |
+
default=None,
|
| 74 |
+
help="Device to use (cuda/cpu). Default: auto-detect",
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--examples_dir",
|
| 78 |
+
type=str,
|
| 79 |
+
default="examples",
|
| 80 |
+
help="Directory containing example scenes",
|
| 81 |
+
)
|
| 82 |
+
parser.add_argument(
|
| 83 |
+
"--save_tmp",
|
| 84 |
+
type=str,
|
| 85 |
+
default=None,
|
| 86 |
+
help="Directory to persist intermediate files (images, predictions, GLB). "
|
| 87 |
+
"If not set, uses system temp dir and cleans up automatically.",
|
| 88 |
+
)
|
| 89 |
+
return parser.parse_args()
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
args = parse_args()
|
| 93 |
+
|
| 94 |
+
# -------------------------- Global Configuration --------------------------
|
| 95 |
+
# Device configuration: use specified device or auto-detect
|
| 96 |
+
DEVICE = args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu")
|
| 97 |
+
# Input panoramic image target size (ERP: W=img_size, H=img_size//2)
|
| 98 |
+
IMG_SIZE = args.img_size
|
| 99 |
+
# Vertical crop ratio for panoramic image preprocessing
|
| 100 |
+
CROP_RATIO = args.crop_ratio
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def resolve_model_path(model_path: str) -> str:
|
| 104 |
+
"""
|
| 105 |
+
Resolve model path: if a local file is specified and exists, use it directly;
|
| 106 |
+
otherwise download from HuggingFace Hub.
|
| 107 |
+
Requires `huggingface-cli login` for gated repos.
|
| 108 |
+
"""
|
| 109 |
+
if model_path is not None and os.path.isfile(model_path):
|
| 110 |
+
return model_path
|
| 111 |
+
|
| 112 |
+
if model_path is not None:
|
| 113 |
+
print(f"Specified model path '{model_path}' not found.")
|
| 114 |
+
|
| 115 |
+
print("Downloading model from HuggingFace (RealseeTechnology/argus-realsee3d)...")
|
| 116 |
+
try:
|
| 117 |
+
from huggingface_hub import hf_hub_download
|
| 118 |
+
downloaded_path = hf_hub_download(
|
| 119 |
+
repo_id="RealseeTechnology/argus-realsee3d",
|
| 120 |
+
filename="argus_realsee3d.pt",
|
| 121 |
+
)
|
| 122 |
+
print(f"Model downloaded to: {downloaded_path}")
|
| 123 |
+
return downloaded_path
|
| 124 |
+
except Exception as e:
|
| 125 |
+
error_msg = str(e)
|
| 126 |
+
if "GatedRepoError" in type(e).__name__ or "401" in error_msg:
|
| 127 |
+
raise RuntimeError(
|
| 128 |
+
"Cannot access gated model repo. Please authenticate first:\n"
|
| 129 |
+
" 1. Run: hf auth login\n"
|
| 130 |
+
" 2. Accept the model license at: https://huggingface.co/RealseeTechnology/argus-realsee3d\n"
|
| 131 |
+
" 3. Re-run this script.\n"
|
| 132 |
+
"Or download manually and specify --model_path."
|
| 133 |
+
) from e
|
| 134 |
+
raise
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# Pre-trained model path (auto-download if not found locally)
|
| 138 |
+
MODEL_PATH = resolve_model_path(args.model_path)
|
| 139 |
+
|
| 140 |
+
# -------------------------- Model Initialization --------------------------
|
| 141 |
+
print("Initializing and loading Argus model...")
|
| 142 |
+
# Initialize Argus model with metric scale and learning ref reorder
|
| 143 |
+
model = Argus(reorder_by_learning_ref=True, restore_metric_scale=True)
|
| 144 |
+
# Load model weights (non-strict to ignore unused parameters)
|
| 145 |
+
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE)["model"], strict=False)
|
| 146 |
+
# Set model to evaluation mode and move to target device
|
| 147 |
+
model.eval()
|
| 148 |
+
model = model.to(DEVICE)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# -------------------------- Image Preprocessing --------------------------
|
| 152 |
+
def load_and_preprocess_images(image_path_list, target_size=IMG_SIZE):
|
| 153 |
+
"""
|
| 154 |
+
Load and preprocess panoramic images for model inference
|
| 155 |
+
Args:
|
| 156 |
+
image_path_list (list): List of input image file paths
|
| 157 |
+
target_size (int): Target width of panoramic image (height = target_size//2)
|
| 158 |
+
Returns:
|
| 159 |
+
torch.Tensor: Preprocessed tensor with shape (S, C, H, W)
|
| 160 |
+
S: sequence length, C: 3(RGB), H/W: image size
|
| 161 |
+
"""
|
| 162 |
+
images = []
|
| 163 |
+
pano_W, pano_H = target_size, target_size // 2
|
| 164 |
+
|
| 165 |
+
# Load and resize each image
|
| 166 |
+
for image_path in image_path_list:
|
| 167 |
+
img = cv2.imread(image_path) # Load as BGR (H, W, C)
|
| 168 |
+
h, w = img.shape[:2]
|
| 169 |
+
if w != pano_W or h != pano_H:
|
| 170 |
+
img = cv2.resize(img, (pano_W, pano_H), interpolation=cv2.INTER_AREA)
|
| 171 |
+
images.append(img)
|
| 172 |
+
|
| 173 |
+
# Stack and preprocess: crop vertical → BGR2RGB → normalize → reshape
|
| 174 |
+
images = np.stack(images) # (S, H, W, C)
|
| 175 |
+
# Crop top/bottom 15% of height and convert BGR to RGB
|
| 176 |
+
images = np.ascontiguousarray(
|
| 177 |
+
images[:, int(pano_H * CROP_RATIO) : int(pano_H * (1 - CROP_RATIO)), :, ::-1]
|
| 178 |
+
)
|
| 179 |
+
# Convert to tensor and normalize to [0,1]
|
| 180 |
+
images = torch.from_numpy(images).float() / 255.0
|
| 181 |
+
# Reshape to (S, C, H, W) for PyTorch model input
|
| 182 |
+
images = images.permute(0, 3, 1, 2)
|
| 183 |
+
|
| 184 |
+
return images
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# -------------------------- Point Cloud Utils --------------------------
|
| 188 |
+
def save_point_cloud_to_ply(points: np.ndarray, save_path: str):
|
| 189 |
+
"""
|
| 190 |
+
Save 3D point cloud (N,3) to PLY format (ASCII) for universal compatibility
|
| 191 |
+
Args:
|
| 192 |
+
points (np.ndarray): 3D point cloud with shape [N, 3] (x, y, z for each point)
|
| 193 |
+
save_path (str): Output PLY file path
|
| 194 |
+
Raises:
|
| 195 |
+
ValueError: If input points shape is not [N, 3]
|
| 196 |
+
"""
|
| 197 |
+
# Validate input point cloud shape
|
| 198 |
+
if points.ndim != 2 or points.shape[1] != 3:
|
| 199 |
+
raise ValueError(f"Point cloud must be [N,3], got {points.shape}")
|
| 200 |
+
|
| 201 |
+
num_points = points.shape[0]
|
| 202 |
+
# PLY format header (follow official specification)
|
| 203 |
+
ply_header = f"""ply
|
| 204 |
+
format ascii 1.0
|
| 205 |
+
element vertex {num_points}
|
| 206 |
+
property float x
|
| 207 |
+
property float y
|
| 208 |
+
property float z
|
| 209 |
+
end_header
|
| 210 |
+
"""
|
| 211 |
+
# Write header and point data to file
|
| 212 |
+
with open(save_path, "w", encoding="utf-8") as f:
|
| 213 |
+
f.write(ply_header)
|
| 214 |
+
np.savetxt(f, points, fmt="%.6f %.6f %.6f")
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# -------------------------- Core Model Inference --------------------------
|
| 218 |
+
def run_model(target_dir, model) -> dict:
|
| 219 |
+
"""
|
| 220 |
+
Run Argus model inference on images in target_dir/images
|
| 221 |
+
Args:
|
| 222 |
+
target_dir (str): Root directory containing 'images' subfolder
|
| 223 |
+
model (Argus): Pre-initialized Argus model
|
| 224 |
+
Returns:
|
| 225 |
+
dict: Model predictions with tensor converted to numpy array
|
| 226 |
+
Raises:
|
| 227 |
+
ValueError: If CUDA unavailable or no images found in target_dir
|
| 228 |
+
"""
|
| 229 |
+
print(f"Processing images from {target_dir}")
|
| 230 |
+
|
| 231 |
+
# Enforce CUDA for inference
|
| 232 |
+
if not torch.cuda.is_available():
|
| 233 |
+
raise ValueError("CUDA is not available. Inference requires GPU acceleration.")
|
| 234 |
+
|
| 235 |
+
model = model.to(DEVICE)
|
| 236 |
+
model.eval()
|
| 237 |
+
|
| 238 |
+
# Load and sort input images
|
| 239 |
+
image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
|
| 240 |
+
print(f"Found {len(image_names)} input images")
|
| 241 |
+
if len(image_names) == 0:
|
| 242 |
+
raise ValueError("No images found in target_dir/images. Check your upload.")
|
| 243 |
+
|
| 244 |
+
# Preprocess images and move to device
|
| 245 |
+
images = load_and_preprocess_images(image_names, target_size=IMG_SIZE).to(DEVICE)
|
| 246 |
+
print(f"Preprocessed images shape: {images.shape}")
|
| 247 |
+
|
| 248 |
+
# Mixed precision inference for speed and memory efficiency
|
| 249 |
+
print("Running model inference...")
|
| 250 |
+
dtype = (
|
| 251 |
+
torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
torch.cuda.synchronize()
|
| 255 |
+
t0 = time.perf_counter()
|
| 256 |
+
|
| 257 |
+
with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
|
| 258 |
+
predictions = model(images)
|
| 259 |
+
|
| 260 |
+
torch.cuda.synchronize()
|
| 261 |
+
t1 = time.perf_counter()
|
| 262 |
+
inference_time = t1 - t0
|
| 263 |
+
print(f"Inference time: {inference_time:.3f} s")
|
| 264 |
+
|
| 265 |
+
# Convert pose encoding to extrinsic/intrinsic matrices
|
| 266 |
+
print("Converting pose encoding to extrinsic matrices...")
|
| 267 |
+
extrinsic, conf = pose_encoding_to_extri360(pose_encoding=predictions["pose_enc"])
|
| 268 |
+
predictions["extrinsic"] = extrinsic[:, :, :3, :]
|
| 269 |
+
|
| 270 |
+
# Unproject depth map to 3D world coordinates
|
| 271 |
+
print("Computing 3D world points from depth map...")
|
| 272 |
+
world_points = unproject_depth_to_world_points(
|
| 273 |
+
predictions["depth"], predictions["extrinsic"], size=IMG_SIZE
|
| 274 |
+
)
|
| 275 |
+
predictions["world_points_from_depth"] = world_points
|
| 276 |
+
|
| 277 |
+
# Convert all torch tensors to numpy arrays and remove batch dimension
|
| 278 |
+
print("Converting model outputs to numpy arrays...")
|
| 279 |
+
for key in predictions.keys():
|
| 280 |
+
if isinstance(predictions[key], torch.Tensor):
|
| 281 |
+
predictions[key] = predictions[key].cpu().float().numpy().squeeze(0)
|
| 282 |
+
elif isinstance(predictions[key], list):
|
| 283 |
+
for i in range(len(predictions[key])):
|
| 284 |
+
if isinstance(predictions[key][i], torch.Tensor):
|
| 285 |
+
predictions[key][i] = (
|
| 286 |
+
predictions[key][i].cpu().float().numpy().squeeze(0)
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
print(f"Model prediction keys: {predictions.keys()}")
|
| 290 |
+
# Clear CUDA cache to save memory
|
| 291 |
+
torch.cuda.empty_cache()
|
| 292 |
+
return predictions, inference_time
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# -------------------------- Upload File Handling --------------------------
|
| 296 |
+
def handle_uploads(input_images):
|
| 297 |
+
"""
|
| 298 |
+
Create directory for uploaded images and copy files to target path.
|
| 299 |
+
Uses system temp dir by default; uses --save_tmp dir if specified.
|
| 300 |
+
Args:
|
| 301 |
+
input_images: Gradio uploaded file data
|
| 302 |
+
Returns:
|
| 303 |
+
tuple: (target_dir, sorted_image_paths)
|
| 304 |
+
"""
|
| 305 |
+
start_time = time.time()
|
| 306 |
+
gc.collect()
|
| 307 |
+
torch.cuda.empty_cache()
|
| 308 |
+
|
| 309 |
+
# Create target directory: persistent if --save_tmp is set, otherwise temp
|
| 310 |
+
if args.save_tmp:
|
| 311 |
+
os.makedirs(args.save_tmp, exist_ok=True)
|
| 312 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
| 313 |
+
target_dir = os.path.join(args.save_tmp, f"input_images_{timestamp}")
|
| 314 |
+
else:
|
| 315 |
+
target_dir = tempfile.mkdtemp(prefix="argus_")
|
| 316 |
+
target_img_dir = os.path.join(target_dir, "images")
|
| 317 |
+
|
| 318 |
+
# Clean up if directory exists (edge case)
|
| 319 |
+
if os.path.exists(target_dir) and args.save_tmp:
|
| 320 |
+
shutil.rmtree(target_dir)
|
| 321 |
+
os.makedirs(target_dir, exist_ok=True)
|
| 322 |
+
os.makedirs(target_img_dir, exist_ok=True)
|
| 323 |
+
|
| 324 |
+
# Copy uploaded images to target directory
|
| 325 |
+
image_paths = []
|
| 326 |
+
if input_images is not None:
|
| 327 |
+
for file_data in input_images:
|
| 328 |
+
# Get file path from Gradio file data
|
| 329 |
+
file_path = file_data["name"] if isinstance(file_data, dict) else file_data
|
| 330 |
+
dst_path = os.path.join(target_img_dir, os.path.basename(file_path))
|
| 331 |
+
shutil.copy(file_path, dst_path)
|
| 332 |
+
image_paths.append(dst_path)
|
| 333 |
+
|
| 334 |
+
# Sort images for consistent processing
|
| 335 |
+
image_paths = sorted(image_paths)
|
| 336 |
+
print(
|
| 337 |
+
f"Files copied to {target_img_dir} | Time cost: {time.time() - start_time:.3f}s"
|
| 338 |
+
)
|
| 339 |
+
return target_dir, image_paths
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def update_gallery_on_upload(input_images):
|
| 343 |
+
"""
|
| 344 |
+
Update image gallery immediately after file upload
|
| 345 |
+
Args:
|
| 346 |
+
input_images: Gradio uploaded file data
|
| 347 |
+
Returns:
|
| 348 |
+
tuple: Gradio component update values
|
| 349 |
+
"""
|
| 350 |
+
if not input_images:
|
| 351 |
+
return None, None, None, None
|
| 352 |
+
target_dir, image_paths = handle_uploads(input_images)
|
| 353 |
+
return (
|
| 354 |
+
None,
|
| 355 |
+
target_dir,
|
| 356 |
+
image_paths,
|
| 357 |
+
"Upload complete. Click 'Reconstruct' to begin 3D processing.",
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
# -------------------------- 3D Reconstruction Pipeline --------------------------
|
| 362 |
+
def gradio_demo(
|
| 363 |
+
target_dir,
|
| 364 |
+
conf_thres=5.0,
|
| 365 |
+
frame_filter="All",
|
| 366 |
+
show_cam=True,
|
| 367 |
+
show_index=True,
|
| 368 |
+
ceiling_remove=25,
|
| 369 |
+
):
|
| 370 |
+
"""
|
| 371 |
+
Main 3D reconstruction pipeline for Gradio interface
|
| 372 |
+
Args:
|
| 373 |
+
target_dir (str): Directory with input images
|
| 374 |
+
conf_thres (float): Confidence threshold for point cloud filtering
|
| 375 |
+
frame_filter (str): Filter frames to show in 3D model
|
| 376 |
+
show_cam (bool): Whether to show camera poses in 3D model
|
| 377 |
+
show_index (bool): Whether to show frame indices in 3D model
|
| 378 |
+
ceiling_remove (float): Percentage of top Y-coordinate points to remove as ceiling (0-100, 0=disabled)
|
| 379 |
+
Returns:
|
| 380 |
+
tuple: Gradio component update values (3D model, logs, dropdown, etc.)
|
| 381 |
+
"""
|
| 382 |
+
# Validate target directory
|
| 383 |
+
if not os.path.isdir(target_dir) or target_dir == "None":
|
| 384 |
+
return (
|
| 385 |
+
None,
|
| 386 |
+
"No valid target directory. Please upload images first.",
|
| 387 |
+
None,
|
| 388 |
+
None,
|
| 389 |
+
None,
|
| 390 |
+
"",
|
| 391 |
+
None,
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
start_time = time.time()
|
| 395 |
+
gc.collect()
|
| 396 |
+
torch.cuda.empty_cache()
|
| 397 |
+
|
| 398 |
+
# Prepare frame filter dropdown options
|
| 399 |
+
target_img_dir = os.path.join(target_dir, "images")
|
| 400 |
+
all_files = (
|
| 401 |
+
sorted(os.listdir(target_img_dir)) if os.path.isdir(target_img_dir) else []
|
| 402 |
+
)
|
| 403 |
+
all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
|
| 404 |
+
frame_filter_choices = ["All"] + all_files
|
| 405 |
+
|
| 406 |
+
# Run model inference
|
| 407 |
+
with torch.no_grad():
|
| 408 |
+
predictions, inference_time = run_model(target_dir, model)
|
| 409 |
+
|
| 410 |
+
# Save predictions to NPZ for later visualization update
|
| 411 |
+
pred_save_path = os.path.join(target_dir, "predictions.npz")
|
| 412 |
+
np.savez(pred_save_path, **predictions)
|
| 413 |
+
|
| 414 |
+
# Default frame filter to All if None
|
| 415 |
+
frame_filter = frame_filter if frame_filter is not None else "All"
|
| 416 |
+
|
| 417 |
+
# Generate unique GLB filename with parameters
|
| 418 |
+
glb_filename = f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_index{show_index}_ceiling{ceiling_remove}.glb"
|
| 419 |
+
glbfile = os.path.join(target_dir, glb_filename)
|
| 420 |
+
|
| 421 |
+
# Convert model predictions to GLB 3D model
|
| 422 |
+
glbscene = predictions_to_glb(
|
| 423 |
+
predictions,
|
| 424 |
+
conf_thres=conf_thres,
|
| 425 |
+
filter_by_frames=frame_filter,
|
| 426 |
+
show_cam=show_cam,
|
| 427 |
+
show_index=show_index,
|
| 428 |
+
ceiling_remove=ceiling_remove,
|
| 429 |
+
target_dir=target_dir,
|
| 430 |
+
)
|
| 431 |
+
glbscene.export(file_obj=glbfile)
|
| 432 |
+
|
| 433 |
+
# Prepare measure view
|
| 434 |
+
measure_img, _ = update_measure_view(predictions, 0)
|
| 435 |
+
# Create view selector based on number of input images
|
| 436 |
+
num_views = (
|
| 437 |
+
predictions["images"].shape[0] if predictions["images"].shape[0] > 0 else 1
|
| 438 |
+
)
|
| 439 |
+
view_choices = [f"View {i + 1}" for i in range(num_views)]
|
| 440 |
+
measure_selector = gr.Dropdown(choices=view_choices, value=view_choices[0])
|
| 441 |
+
|
| 442 |
+
# Clean up memory
|
| 443 |
+
gc.collect()
|
| 444 |
+
torch.cuda.empty_cache()
|
| 445 |
+
|
| 446 |
+
total_time = time.time() - start_time
|
| 447 |
+
log_msg = f"Reconstruction Success ({len(all_files)} frames). Inference: {inference_time:.2f}s | Total: {total_time:.2f}s"
|
| 448 |
+
print(f"Reconstruction complete | Inference: {inference_time:.2f}s | Total: {total_time:.2f}s")
|
| 449 |
+
|
| 450 |
+
return (
|
| 451 |
+
glbfile,
|
| 452 |
+
log_msg,
|
| 453 |
+
gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
|
| 454 |
+
predictions,
|
| 455 |
+
measure_img,
|
| 456 |
+
"",
|
| 457 |
+
measure_selector,
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
# -------------------------- UI Utility Functions --------------------------
|
| 462 |
+
def clear_fields():
|
| 463 |
+
"""Clear 3D model viewer for Gradio interface"""
|
| 464 |
+
return None
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def update_log():
|
| 468 |
+
"""Update log message during model processing"""
|
| 469 |
+
return "Loading and Reconstructing..."
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def update_visualization(
|
| 473 |
+
target_dir,
|
| 474 |
+
conf_thres,
|
| 475 |
+
frame_filter,
|
| 476 |
+
show_cam,
|
| 477 |
+
show_index,
|
| 478 |
+
ceiling_remove,
|
| 479 |
+
is_example,
|
| 480 |
+
):
|
| 481 |
+
"""
|
| 482 |
+
Update 3D visualization when parameters change (without re-running model)
|
| 483 |
+
Args:
|
| 484 |
+
is_example (str): Whether it's example data (skip if "True")
|
| 485 |
+
Returns:
|
| 486 |
+
tuple: (GLB file path, log message)
|
| 487 |
+
"""
|
| 488 |
+
# Skip if loading example data
|
| 489 |
+
if is_example == "True":
|
| 490 |
+
return (
|
| 491 |
+
None,
|
| 492 |
+
"No reconstruction available. Please click the Reconstruct button first.",
|
| 493 |
+
)
|
| 494 |
+
# Validate target directory and prediction file
|
| 495 |
+
if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
|
| 496 |
+
return None, "No valid reconstruction. Please upload and reconstruct first."
|
| 497 |
+
|
| 498 |
+
pred_path = os.path.join(target_dir, "predictions.npz")
|
| 499 |
+
if not os.path.exists(pred_path):
|
| 500 |
+
return None, f"No prediction file found at {pred_path}. Run Reconstruct first."
|
| 501 |
+
|
| 502 |
+
# Load saved predictions
|
| 503 |
+
key_list = [
|
| 504 |
+
"pose_enc",
|
| 505 |
+
"depth",
|
| 506 |
+
"depth_conf",
|
| 507 |
+
"images",
|
| 508 |
+
"extrinsic",
|
| 509 |
+
"world_points_from_depth",
|
| 510 |
+
]
|
| 511 |
+
loaded = np.load(pred_path)
|
| 512 |
+
predictions = {key: np.array(loaded[key]) for key in key_list if key in loaded}
|
| 513 |
+
|
| 514 |
+
# Generate GLB file (create if not exists)
|
| 515 |
+
glb_filename = f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_index{show_index}_ceiling{ceiling_remove}.glb"
|
| 516 |
+
glbfile = os.path.join(target_dir, glb_filename)
|
| 517 |
+
|
| 518 |
+
if not os.path.exists(glbfile):
|
| 519 |
+
glbscene = predictions_to_glb(
|
| 520 |
+
predictions,
|
| 521 |
+
conf_thres=conf_thres,
|
| 522 |
+
filter_by_frames=frame_filter,
|
| 523 |
+
show_cam=show_cam,
|
| 524 |
+
show_index=show_index,
|
| 525 |
+
ceiling_remove=ceiling_remove,
|
| 526 |
+
target_dir=target_dir,
|
| 527 |
+
)
|
| 528 |
+
glbscene.export(file_obj=glbfile)
|
| 529 |
+
|
| 530 |
+
return glbfile, "Visualization updated successfully"
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
# -------------------------- Metric Measurement --------------------------
|
| 534 |
+
def update_measure_view(predictions, view_index):
|
| 535 |
+
"""
|
| 536 |
+
Update measure view with depth confidence mask overlay
|
| 537 |
+
Args:
|
| 538 |
+
predictions (dict): Model predictions with images and depth confidence
|
| 539 |
+
view_index (int): Index of the view to show
|
| 540 |
+
Returns:
|
| 541 |
+
tuple: (processed_image, empty_list)
|
| 542 |
+
"""
|
| 543 |
+
# Get image and depth confidence
|
| 544 |
+
image = predictions["images"][view_index].transpose(1, 2, 0).copy()
|
| 545 |
+
depth_conf = predictions["depth_conf"][view_index].copy()
|
| 546 |
+
|
| 547 |
+
# Convert image to uint8 format
|
| 548 |
+
if image.dtype != np.uint8:
|
| 549 |
+
image = (
|
| 550 |
+
(image * 255).astype(np.uint8)
|
| 551 |
+
if image.max() <= 1.0
|
| 552 |
+
else image.astype(np.uint8)
|
| 553 |
+
)
|
| 554 |
+
|
| 555 |
+
# Create depth confidence mask (filter low confidence areas)
|
| 556 |
+
depth_conf_norm = (depth_conf - depth_conf.min()) / (
|
| 557 |
+
depth_conf.max() - depth_conf.min()
|
| 558 |
+
)
|
| 559 |
+
mask = depth_conf_norm > 0.05
|
| 560 |
+
invalid_mask = ~mask
|
| 561 |
+
|
| 562 |
+
# Apply red overlay to invalid areas (low confidence)
|
| 563 |
+
if invalid_mask.any():
|
| 564 |
+
overlay_color = np.array([255, 220, 220], dtype=np.uint8)
|
| 565 |
+
alpha = 0.5 # Transparency
|
| 566 |
+
for c in range(3):
|
| 567 |
+
image[:, :, c] = np.where(
|
| 568 |
+
invalid_mask,
|
| 569 |
+
(1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
|
| 570 |
+
image[:, :, c],
|
| 571 |
+
).astype(np.uint8)
|
| 572 |
+
|
| 573 |
+
return image, []
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
def navigate_measure_view(processed_data, current_selector_value, direction):
|
| 577 |
+
"""
|
| 578 |
+
Navigate between different measure views (previous/next)
|
| 579 |
+
Args:
|
| 580 |
+
direction (int): -1 for previous, +1 for next
|
| 581 |
+
Returns:
|
| 582 |
+
tuple: (new_selector_value, measure_image, empty_points)
|
| 583 |
+
"""
|
| 584 |
+
if processed_data["images"].shape[0] == 0:
|
| 585 |
+
return "View 1", None, []
|
| 586 |
+
|
| 587 |
+
# Parse current view index from selector
|
| 588 |
+
try:
|
| 589 |
+
current_view = int(current_selector_value.split()[1]) - 1
|
| 590 |
+
except:
|
| 591 |
+
current_view = 0
|
| 592 |
+
|
| 593 |
+
# Calculate new view index (circular navigation)
|
| 594 |
+
num_views = processed_data["images"].shape[0]
|
| 595 |
+
new_view = (current_view + direction) % num_views
|
| 596 |
+
|
| 597 |
+
# Update selector and image
|
| 598 |
+
new_selector = f"View {new_view + 1}"
|
| 599 |
+
measure_image, _ = update_measure_view(processed_data, new_view)
|
| 600 |
+
return new_selector, measure_image, []
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
def measure(
|
| 604 |
+
processed_data, measure_points, current_view_selector, event: gr.SelectData
|
| 605 |
+
):
|
| 606 |
+
"""
|
| 607 |
+
Core metric measurement function: click to select points and calculate 3D distance
|
| 608 |
+
Args:
|
| 609 |
+
event (gr.SelectData): Gradio click event data (image coordinates)
|
| 610 |
+
Returns:
|
| 611 |
+
tuple: (annotated_image, measure_points, measurement_text)
|
| 612 |
+
"""
|
| 613 |
+
try:
|
| 614 |
+
# Get current view index
|
| 615 |
+
try:
|
| 616 |
+
current_view = int(current_view_selector.split()[1]) - 1
|
| 617 |
+
except:
|
| 618 |
+
current_view = 0
|
| 619 |
+
# Validate view index
|
| 620 |
+
current_view = (
|
| 621 |
+
0
|
| 622 |
+
if current_view < 0 or current_view >= processed_data["images"].shape[0]
|
| 623 |
+
else current_view
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
# Get clicked 2D point
|
| 627 |
+
point2d = event.index[0], event.index[1]
|
| 628 |
+
measure_points.append(point2d)
|
| 629 |
+
print(f"Measuring: clicked point {point2d} (view {current_view + 1})")
|
| 630 |
+
|
| 631 |
+
# Get base image and 3D points
|
| 632 |
+
image, _ = update_measure_view(processed_data, current_view)
|
| 633 |
+
image = image.copy()
|
| 634 |
+
points3d = processed_data["world_points_from_depth"][current_view]
|
| 635 |
+
|
| 636 |
+
# Draw blue circles for clicked points
|
| 637 |
+
for p in measure_points:
|
| 638 |
+
if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
|
| 639 |
+
image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)
|
| 640 |
+
|
| 641 |
+
# Calculate depth for single point
|
| 642 |
+
depth_text = ""
|
| 643 |
+
depth = processed_data["depth"][current_view].squeeze(axis=-1)
|
| 644 |
+
for i, p in enumerate(measure_points):
|
| 645 |
+
try:
|
| 646 |
+
if 0 <= p[1] < depth.shape[0] and 0 <= p[0] < depth.shape[1]:
|
| 647 |
+
d = depth[p[1], p[0]]
|
| 648 |
+
depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
|
| 649 |
+
else:
|
| 650 |
+
d = np.linalg.norm(points3d[p[1], p[0]], ord=2)
|
| 651 |
+
depth_text += f"- **P{i + 1} dist: {d:.2f}m.**\n"
|
| 652 |
+
except:
|
| 653 |
+
depth_text += f"- **P{i + 1}: Depth unavailable**\n"
|
| 654 |
+
|
| 655 |
+
# Calculate 3D distance for two points
|
| 656 |
+
if len(measure_points) == 2:
|
| 657 |
+
p1, p2 = measure_points
|
| 658 |
+
# Draw blue line between two points
|
| 659 |
+
if all(
|
| 660 |
+
0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]
|
| 661 |
+
for p in [p1, p2]
|
| 662 |
+
):
|
| 663 |
+
image = cv2.line(image, p1, p2, color=(255, 0, 0), thickness=2)
|
| 664 |
+
# Calculate 3D Euclidean distance
|
| 665 |
+
try:
|
| 666 |
+
p1_3d = points3d[p1[1], p1[0]]
|
| 667 |
+
p2_3d = points3d[p2[1], p2[0]]
|
| 668 |
+
distance = np.linalg.norm(p1_3d - p2_3d)
|
| 669 |
+
distance_text = f"- **Distance: {distance:.2f}m**"
|
| 670 |
+
except:
|
| 671 |
+
distance_text = "- **Distance: Unable to compute**"
|
| 672 |
+
# Reset points after measurement
|
| 673 |
+
measure_points = []
|
| 674 |
+
return [image, measure_points, depth_text + distance_text]
|
| 675 |
+
|
| 676 |
+
return [image, measure_points, depth_text]
|
| 677 |
+
except Exception as e:
|
| 678 |
+
print(f"Measurement error: {str(e)}")
|
| 679 |
+
return None, [], f"Measure error: {str(e)}"
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
# -------------------------- Example Data Loader --------------------------
|
| 683 |
+
def get_scene_info(examples_dir):
|
| 684 |
+
"""
|
| 685 |
+
Load example scene information from examples directory
|
| 686 |
+
Args:
|
| 687 |
+
examples_dir (str): Directory containing example scenes
|
| 688 |
+
Returns:
|
| 689 |
+
list: List of scene dicts with name, path, thumbnail, image files
|
| 690 |
+
"""
|
| 691 |
+
scenes = []
|
| 692 |
+
if not os.path.exists(examples_dir):
|
| 693 |
+
return scenes
|
| 694 |
+
|
| 695 |
+
# Iterate over example scene folders
|
| 696 |
+
for scene_folder in sorted(os.listdir(examples_dir)):
|
| 697 |
+
scene_path = os.path.join(examples_dir, scene_folder)
|
| 698 |
+
if not os.path.isdir(scene_path):
|
| 699 |
+
continue
|
| 700 |
+
# Load all image files
|
| 701 |
+
img_exts = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
|
| 702 |
+
image_files = []
|
| 703 |
+
for ext in img_exts:
|
| 704 |
+
image_files.extend(glob.glob(os.path.join(scene_path, ext)))
|
| 705 |
+
image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
|
| 706 |
+
# Skip empty folders
|
| 707 |
+
if not image_files:
|
| 708 |
+
continue
|
| 709 |
+
# Sort images and get thumbnail
|
| 710 |
+
image_files = sorted(image_files)
|
| 711 |
+
scenes.append(
|
| 712 |
+
{
|
| 713 |
+
"name": scene_folder,
|
| 714 |
+
"path": scene_path,
|
| 715 |
+
"thumbnail": image_files[0],
|
| 716 |
+
"num_images": len(image_files),
|
| 717 |
+
"image_files": image_files,
|
| 718 |
+
}
|
| 719 |
+
)
|
| 720 |
+
return scenes
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
def example_pipeline(
|
| 724 |
+
scene,
|
| 725 |
+
conf_thres=5.0,
|
| 726 |
+
show_cam=True,
|
| 727 |
+
show_index=True,
|
| 728 |
+
ceiling_remove=25,
|
| 729 |
+
):
|
| 730 |
+
"""
|
| 731 |
+
Pipeline for loading example scenes and running reconstruction
|
| 732 |
+
Args:
|
| 733 |
+
scene (dict): Example scene info from get_scene_info
|
| 734 |
+
Returns:
|
| 735 |
+
tuple: Gradio component update values
|
| 736 |
+
"""
|
| 737 |
+
input_image_paths = scene["image_files"]
|
| 738 |
+
target_dir, image_paths = handle_uploads(input_image_paths)
|
| 739 |
+
frame_filter = "All" # Default to all frames for examples
|
| 740 |
+
# Run reconstruction
|
| 741 |
+
(
|
| 742 |
+
glbfile,
|
| 743 |
+
log_msg,
|
| 744 |
+
dropdown,
|
| 745 |
+
predictions,
|
| 746 |
+
measure_img,
|
| 747 |
+
measure_text,
|
| 748 |
+
measure_selector,
|
| 749 |
+
) = gradio_demo(
|
| 750 |
+
target_dir, conf_thres, frame_filter, show_cam, show_index, ceiling_remove
|
| 751 |
+
)
|
| 752 |
+
return (
|
| 753 |
+
glbfile,
|
| 754 |
+
log_msg,
|
| 755 |
+
target_dir,
|
| 756 |
+
dropdown,
|
| 757 |
+
image_paths,
|
| 758 |
+
predictions,
|
| 759 |
+
measure_img,
|
| 760 |
+
measure_text,
|
| 761 |
+
measure_selector,
|
| 762 |
+
)
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
# -------------------------- 3D Visualization Utilities --------------------------
|
| 766 |
+
class SevenSegmentDigit:
|
| 767 |
+
"""7-segment display definition for digital watch style 3D point cloud generation"""
|
| 768 |
+
# 7 segments definition: A(top), B(upper right), C(lower right), D(bottom), E(lower left), F(upper left), G(middle)
|
| 769 |
+
SEGMENTS = {
|
| 770 |
+
'A': np.array([(x, 0.5, 0) for x in np.linspace(-0.4, 0.4, 80) for y in np.linspace(0.45, 0.55, 10)]),
|
| 771 |
+
'B': np.array([(x, y, 0) for x in np.linspace(0.4, 0.5, 10) for y in np.linspace(0, 0.5, 80)]),
|
| 772 |
+
'C': np.array([(x, y, 0) for x in np.linspace(0.4, 0.5, 10) for y in np.linspace(-0.5, 0, 80)]),
|
| 773 |
+
'D': np.array([(x, y, 0) for x in np.linspace(-0.4, 0.4, 80) for y in np.linspace(-0.55, -0.45, 10)]),
|
| 774 |
+
'E': np.array([(x, y, 0) for x in np.linspace(-0.5, -0.4, 10) for y in np.linspace(-0.5, 0, 80)]),
|
| 775 |
+
'F': np.array([(x, y, 0) for x in np.linspace(-0.5, -0.4, 10) for y in np.linspace(0, 0.5, 80)]),
|
| 776 |
+
'G': np.array([(x, y, 0) for x in np.linspace(-0.4, 0.4, 80) for y in np.linspace(-0.05, 0.05, 10)])
|
| 777 |
+
}
|
| 778 |
+
|
| 779 |
+
# Segment mapping for standard 0-9 digits (specify lit segments for each digit)
|
| 780 |
+
DIGIT_SEGMENTS = {
|
| 781 |
+
0: ['A', 'B', 'C', 'D', 'E', 'F'],
|
| 782 |
+
1: ['B', 'C'],
|
| 783 |
+
2: ['A', 'B', 'G', 'E', 'D'],
|
| 784 |
+
3: ['A', 'B', 'G', 'C', 'D'],
|
| 785 |
+
4: ['F', 'G', 'B', 'C'],
|
| 786 |
+
5: ['A', 'F', 'G', 'C', 'D'],
|
| 787 |
+
6: ['A', 'F', 'G', 'C', 'D', 'E'],
|
| 788 |
+
7: ['A', 'B', 'C'],
|
| 789 |
+
8: ['A', 'B', 'C', 'D', 'E', 'F', 'G'],
|
| 790 |
+
9: ['A', 'B', 'C', 'D', 'F', 'G']
|
| 791 |
+
}
|
| 792 |
+
|
| 793 |
+
@classmethod
|
| 794 |
+
def get_digit_points(cls, digit, scale=0.05):
|
| 795 |
+
"""
|
| 796 |
+
Generate 3D point cloud for a single digital watch style digit (0-9)
|
| 797 |
+
Args:
|
| 798 |
+
digit (int): Target digit (0-9 only)
|
| 799 |
+
scale (float): Scale factor for point cloud size
|
| 800 |
+
Returns:
|
| 801 |
+
np.ndarray: N×3 array of 3D points for the digit
|
| 802 |
+
Raises:
|
| 803 |
+
ValueError: If digit is not in 0-9 range
|
| 804 |
+
"""
|
| 805 |
+
if not 0 <= digit <= 9:
|
| 806 |
+
raise ValueError(f"Digit must be 0-9, got {digit}")
|
| 807 |
+
|
| 808 |
+
# Combine lit segments for the target digit
|
| 809 |
+
segments = cls.DIGIT_SEGMENTS[digit]
|
| 810 |
+
points = np.vstack([cls.SEGMENTS[seg] for seg in segments])
|
| 811 |
+
|
| 812 |
+
# Scale point cloud and center to origin
|
| 813 |
+
points = points * scale
|
| 814 |
+
points -= points.mean(axis=0)
|
| 815 |
+
|
| 816 |
+
# Remove duplicate points and supplement sparse points (ensure dense distribution)
|
| 817 |
+
points = np.unique(points.round(6), axis=0)
|
| 818 |
+
if len(points) < 200:
|
| 819 |
+
points = trimesh.sample.sample_surface(trimesh.Trimesh(points), 500)[0]
|
| 820 |
+
|
| 821 |
+
return points
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
def create_number_point_cloud(number, scale=0.05):
|
| 825 |
+
"""
|
| 826 |
+
Generate 3D point cloud for multi-digit number (digital watch style), facing +Y axis
|
| 827 |
+
Args:
|
| 828 |
+
number (int): Non-negative target integer (any digit length)
|
| 829 |
+
scale (float): Scale factor for single digit point cloud size
|
| 830 |
+
Returns:
|
| 831 |
+
trimesh.PointCloud: Colored (red) 3D point cloud of the number
|
| 832 |
+
Raises:
|
| 833 |
+
ValueError: If number is negative or non-integer
|
| 834 |
+
"""
|
| 835 |
+
if not isinstance(number, int) or number < 0:
|
| 836 |
+
raise ValueError(f"Number must be non-negative integer, got {number}")
|
| 837 |
+
|
| 838 |
+
# Split number into individual digits and handle 0 specially
|
| 839 |
+
digits = [int(d) for d in str(number)] if number != 0 else [0]
|
| 840 |
+
all_points, spacing = [], scale * 1.2
|
| 841 |
+
total_width = (len(digits)-1) * spacing
|
| 842 |
+
|
| 843 |
+
# Arrange digits horizontally and center the whole number
|
| 844 |
+
for idx, d in enumerate(digits):
|
| 845 |
+
digit_points = SevenSegmentDigit.get_digit_points(d, scale)
|
| 846 |
+
digit_points[:, 0] += -total_width/2 + idx * spacing
|
| 847 |
+
all_points.append(digit_points)
|
| 848 |
+
|
| 849 |
+
# Merge all digit points and apply rotation to face +Y axis
|
| 850 |
+
all_points = np.vstack(all_points)
|
| 851 |
+
rotation = np.array([[1, 0, 0],
|
| 852 |
+
[0, 0, -1],
|
| 853 |
+
[0, 1, 0]])
|
| 854 |
+
all_points = np.dot(all_points, rotation.T)
|
| 855 |
+
|
| 856 |
+
# Create red point cloud (classic digital watch color)
|
| 857 |
+
colors = np.full((len(all_points), 3), [255, 0, 0], dtype=np.uint8)
|
| 858 |
+
|
| 859 |
+
return trimesh.PointCloud(all_points, colors)
|
| 860 |
+
|
| 861 |
+
|
| 862 |
+
def predictions_to_glb(
|
| 863 |
+
predictions,
|
| 864 |
+
conf_thres=50.0,
|
| 865 |
+
filter_by_frames="all",
|
| 866 |
+
show_cam=True,
|
| 867 |
+
show_index=True,
|
| 868 |
+
ceiling_remove=25,
|
| 869 |
+
target_dir=None,
|
| 870 |
+
prediction_mode="Predicted Pointmap",
|
| 871 |
+
) -> trimesh.Scene:
|
| 872 |
+
"""
|
| 873 |
+
Convert VGGT model predictions to a 3D trimesh Scene (exportable to GLB)
|
| 874 |
+
Integrates colored point cloud, camera meshes and digital camera indexes
|
| 875 |
+
Args:
|
| 876 |
+
predictions (dict): Model prediction dict with keys:
|
| 877 |
+
- world_points: 3D point coordinates (S, H, W, 3)
|
| 878 |
+
- world_points_conf: Confidence scores (S, H, W)
|
| 879 |
+
- images: Input images (S, H, W, 3)
|
| 880 |
+
- extrinsic: Camera extrinsic matrices (S, 3, 4)
|
| 881 |
+
conf_thres (float): Low-confidence point filter (percentile, 0-100)
|
| 882 |
+
filter_by_frames (str): Frame filter ("all" or specific frame index like "0:")
|
| 883 |
+
show_cam (bool): Whether to add camera mesh visualization to scene
|
| 884 |
+
show_index (bool): Whether to add digital index point cloud above cameras
|
| 885 |
+
ceiling_remove (float): Percentage of top Y-coordinate points to remove as ceiling (0-100, 0=disabled)
|
| 886 |
+
target_dir (str): Directory for intermediate files (images)
|
| 887 |
+
prediction_mode (str): Prediction branch ("Predicted Pointmap" / others for depth-based)
|
| 888 |
+
Returns:
|
| 889 |
+
trimesh.Scene: 3D scene with point cloud, cameras and indexes (if enabled)
|
| 890 |
+
Raises:
|
| 891 |
+
ValueError: If predictions is not a dictionary
|
| 892 |
+
"""
|
| 893 |
+
if not isinstance(predictions, dict):
|
| 894 |
+
raise ValueError("predictions must be a dictionary")
|
| 895 |
+
|
| 896 |
+
conf_thres = 10.0 if conf_thres is None else conf_thres
|
| 897 |
+
print("Building GLB scene")
|
| 898 |
+
selected_frame_idx = None
|
| 899 |
+
|
| 900 |
+
# Parse selected frame index from filter string (e.g., "0:" -> 0)
|
| 901 |
+
if filter_by_frames not in ["all", "All"]:
|
| 902 |
+
try:
|
| 903 |
+
selected_frame_idx = int(filter_by_frames.split(":")[0])
|
| 904 |
+
except (ValueError, IndexError):
|
| 905 |
+
pass
|
| 906 |
+
|
| 907 |
+
# Select prediction branch (Pointmap direct / Depthmap derived)
|
| 908 |
+
if "Pointmap" in prediction_mode:
|
| 909 |
+
print("Using Pointmap Branch")
|
| 910 |
+
if "world_points" in predictions:
|
| 911 |
+
pred_world_points = predictions["world_points"]
|
| 912 |
+
pred_world_points_conf = predictions.get("world_points_conf", np.ones_like(pred_world_points[..., 0]))
|
| 913 |
+
else:
|
| 914 |
+
print("Warning: world_points not found, falling back to depth-based world points")
|
| 915 |
+
pred_world_points = predictions["world_points_from_depth"]
|
| 916 |
+
pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0]))
|
| 917 |
+
else:
|
| 918 |
+
print("Using Depthmap and Camera Branch")
|
| 919 |
+
pred_world_points = predictions["world_points_from_depth"]
|
| 920 |
+
pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0]))
|
| 921 |
+
|
| 922 |
+
# Extract core prediction data: images and camera extrinsic matrices
|
| 923 |
+
images = predictions["images"]
|
| 924 |
+
camera_matrices = predictions["extrinsic"]
|
| 925 |
+
|
| 926 |
+
# Filter prediction data to selected single frame if specified
|
| 927 |
+
if selected_frame_idx is not None:
|
| 928 |
+
pred_world_points = pred_world_points[selected_frame_idx][None]
|
| 929 |
+
pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None]
|
| 930 |
+
images = images[selected_frame_idx][None]
|
| 931 |
+
camera_matrices = camera_matrices[selected_frame_idx][None]
|
| 932 |
+
|
| 933 |
+
# Reshape 3D points and convert image colors to 8-bit RGB (match point cloud)
|
| 934 |
+
vertices_3d = pred_world_points.reshape(-1, 3)
|
| 935 |
+
if images.ndim == 4 and images.shape[1] == 3: # Convert NCHW to NHWC format
|
| 936 |
+
colors_rgb = np.transpose(images, (0, 2, 3, 1))
|
| 937 |
+
else: # Direct use if already NHWC format
|
| 938 |
+
colors_rgb = images
|
| 939 |
+
colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)
|
| 940 |
+
|
| 941 |
+
# Filter points by confidence threshold (remove low-confidence points)
|
| 942 |
+
conf = pred_world_points_conf.reshape(-1)
|
| 943 |
+
conf_threshold = 0.0 if conf_thres == 0.0 else np.percentile(conf, conf_thres)
|
| 944 |
+
conf_mask = (conf >= conf_threshold) & (conf > 1e-5)
|
| 945 |
+
|
| 946 |
+
vertices_3d = vertices_3d[conf_mask]
|
| 947 |
+
colors_rgb = colors_rgb[conf_mask]
|
| 948 |
+
|
| 949 |
+
# Create dummy point if no valid points left (avoid scene empty error)
|
| 950 |
+
if vertices_3d is None or np.asarray(vertices_3d).size == 0:
|
| 951 |
+
vertices_3d = np.array([[1, 0, 0]])
|
| 952 |
+
colors_rgb = np.array([[255, 255, 255]])
|
| 953 |
+
scene_scale = 1
|
| 954 |
+
else:
|
| 955 |
+
# Calculate scene scale by 5th/95th percentile bounding box diagonal
|
| 956 |
+
lower_percentile = np.percentile(vertices_3d, 5, axis=0)
|
| 957 |
+
upper_percentile = np.percentile(vertices_3d, 95, axis=0)
|
| 958 |
+
scene_scale = np.linalg.norm(upper_percentile - lower_percentile)
|
| 959 |
+
|
| 960 |
+
# Initialize 3D scene and colormap for camera unique colors
|
| 961 |
+
colormap = matplotlib.colormaps.get_cmap("gist_rainbow")
|
| 962 |
+
scene_3d = trimesh.Scene()
|
| 963 |
+
|
| 964 |
+
# Filter out ceiling points (remove top N% of Y-coordinates by percentile)
|
| 965 |
+
if ceiling_remove > 0 and vertices_3d.size > 1:
|
| 966 |
+
y_coords = vertices_3d[:, 1]
|
| 967 |
+
y_percentile = np.percentile(y_coords, ceiling_remove)
|
| 968 |
+
mask = y_coords > y_percentile
|
| 969 |
+
vertices_3d = vertices_3d[mask]
|
| 970 |
+
colors_rgb = colors_rgb[mask]
|
| 971 |
+
|
| 972 |
+
# Add colored 3D point cloud to the scene
|
| 973 |
+
point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
|
| 974 |
+
scene_3d.add_geometry(point_cloud_data)
|
| 975 |
+
|
| 976 |
+
# Convert 3x4 camera extrinsics to 4x4 homogeneous matrices
|
| 977 |
+
num_cameras = len(camera_matrices)
|
| 978 |
+
extrinsics_matrices = np.zeros((num_cameras, 4, 4))
|
| 979 |
+
extrinsics_matrices[:, :3, :4] = camera_matrices
|
| 980 |
+
extrinsics_matrices[:, 3, 3] = 1
|
| 981 |
+
|
| 982 |
+
# Add camera meshes and digital index point clouds to the scene
|
| 983 |
+
for i in range(num_cameras):
|
| 984 |
+
camera_to_world = extrinsics_matrices[i]
|
| 985 |
+
rgba_color = colormap(i / num_cameras) # Unique color for each camera
|
| 986 |
+
current_color = tuple(int(255 * x) for x in rgba_color[:3])
|
| 987 |
+
|
| 988 |
+
# Add camera mesh to scene
|
| 989 |
+
if show_cam:
|
| 990 |
+
integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale)
|
| 991 |
+
|
| 992 |
+
# Add digital index point cloud above each camera (red, digital watch style)
|
| 993 |
+
if show_index:
|
| 994 |
+
camera_center = camera_to_world[:3, 3]
|
| 995 |
+
y_offset = 0.5 # Y-axis offset for index position (above camera)
|
| 996 |
+
number_position = camera_center + np.array([0, y_offset, 0])
|
| 997 |
+
|
| 998 |
+
# Generate index point cloud and translate to target position
|
| 999 |
+
number_scale = 0.3
|
| 1000 |
+
number_pc = create_number_point_cloud(number=i, scale=number_scale)
|
| 1001 |
+
number_pc.apply_translation(number_position)
|
| 1002 |
+
scene_3d.add_geometry(number_pc)
|
| 1003 |
+
|
| 1004 |
+
# Align the whole scene to the first camera's viewing perspective
|
| 1005 |
+
scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)
|
| 1006 |
+
|
| 1007 |
+
print("GLB Scene built successfully")
|
| 1008 |
+
return scene_3d
|
| 1009 |
+
|
| 1010 |
+
|
| 1011 |
+
def integrate_camera_into_scene(
|
| 1012 |
+
scene: trimesh.Scene, transform: np.ndarray, face_colors: tuple, scene_scale: float
|
| 1013 |
+
):
|
| 1014 |
+
"""
|
| 1015 |
+
Add a 3D cone-shaped camera mesh to the 3D scene with specified transform and color
|
| 1016 |
+
Args:
|
| 1017 |
+
scene (trimesh.Scene): Target 3D scene to add camera mesh
|
| 1018 |
+
transform (np.ndarray): 4x4 camera-to-world transformation matrix
|
| 1019 |
+
face_colors (tuple): RGB color tuple (0-255) for camera mesh faces
|
| 1020 |
+
scene_scale (float): Overall scale of the 3D scene (for camera size adaptation)
|
| 1021 |
+
"""
|
| 1022 |
+
# Set camera mesh size based on scene scale
|
| 1023 |
+
cam_width = scene_scale * 0.02
|
| 1024 |
+
cam_height = scene_scale * 0.02
|
| 1025 |
+
|
| 1026 |
+
# 45° Z-axis rotation for camera cone shape and backward translation
|
| 1027 |
+
rot_45_degree = np.eye(4)
|
| 1028 |
+
rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
|
| 1029 |
+
rot_45_degree[2, 3] = -cam_height
|
| 1030 |
+
|
| 1031 |
+
# Combine OpenGL conversion, rotation and camera transform matrices
|
| 1032 |
+
opengl_transform = get_opengl_conversion_matrix()
|
| 1033 |
+
complete_transform = transform @ opengl_transform @ rot_45_degree
|
| 1034 |
+
camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)
|
| 1035 |
+
|
| 1036 |
+
# Slight Z-axis rotation for camera mesh detail enhancement
|
| 1037 |
+
slight_rotation = np.eye(4)
|
| 1038 |
+
slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()
|
| 1039 |
+
|
| 1040 |
+
# Combine original, scaled and rotated cone vertices for dense camera mesh
|
| 1041 |
+
vertices_combined = np.concatenate(
|
| 1042 |
+
[
|
| 1043 |
+
camera_cone_shape.vertices,
|
| 1044 |
+
0.95 * camera_cone_shape.vertices,
|
| 1045 |
+
transform_points(slight_rotation, camera_cone_shape.vertices),
|
| 1046 |
+
]
|
| 1047 |
+
)
|
| 1048 |
+
vertices_transformed = transform_points(complete_transform, vertices_combined)
|
| 1049 |
+
|
| 1050 |
+
# Compute camera mesh faces from cone shape
|
| 1051 |
+
mesh_faces = compute_camera_faces(camera_cone_shape)
|
| 1052 |
+
|
| 1053 |
+
# Create camera mesh with specified color and add to scene
|
| 1054 |
+
camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces)
|
| 1055 |
+
camera_mesh.visual.face_colors[:, :3] = face_colors
|
| 1056 |
+
scene.add_geometry(camera_mesh)
|
| 1057 |
+
|
| 1058 |
+
|
| 1059 |
+
def apply_scene_alignment(
|
| 1060 |
+
scene_3d: trimesh.Scene, extrinsics_matrices: np.ndarray
|
| 1061 |
+
) -> trimesh.Scene:
|
| 1062 |
+
"""
|
| 1063 |
+
Align the 3D scene to the first camera's viewing perspective with OpenGL conversion
|
| 1064 |
+
Args:
|
| 1065 |
+
scene_3d (trimesh.Scene): Unaligned 3D scene
|
| 1066 |
+
extrinsics_matrices (np.ndarray): N×4×4 camera extrinsic matrices
|
| 1067 |
+
Returns:
|
| 1068 |
+
trimesh.Scene: Aligned 3D scene
|
| 1069 |
+
"""
|
| 1070 |
+
# Get OpenGL coordinate conversion matrix and 180° Y-axis rotation for alignment
|
| 1071 |
+
opengl_conversion_matrix = get_opengl_conversion_matrix()
|
| 1072 |
+
align_rotation = np.eye(4)
|
| 1073 |
+
align_rotation[:3, :3] = Rotation.from_euler("y", 180, degrees=True).as_matrix()
|
| 1074 |
+
|
| 1075 |
+
# Combine transformation matrices and apply to the whole scene
|
| 1076 |
+
initial_transformation = np.linalg.inv(extrinsics_matrices[0]) @ opengl_conversion_matrix @ align_rotation
|
| 1077 |
+
scene_3d.apply_transform(initial_transformation)
|
| 1078 |
+
return scene_3d
|
| 1079 |
+
|
| 1080 |
+
|
| 1081 |
+
def get_opengl_conversion_matrix() -> np.ndarray:
|
| 1082 |
+
"""
|
| 1083 |
+
Create 4x4 OpenGL coordinate system conversion matrix (flip Y and Z axes)
|
| 1084 |
+
Returns:
|
| 1085 |
+
np.ndarray: 4x4 identity-based conversion matrix
|
| 1086 |
+
"""
|
| 1087 |
+
matrix = np.identity(4)
|
| 1088 |
+
matrix[1, 1] = -1 # Flip Y axis
|
| 1089 |
+
matrix[2, 2] = -1 # Flip Z axis
|
| 1090 |
+
return matrix
|
| 1091 |
+
|
| 1092 |
+
|
| 1093 |
+
def transform_points(
|
| 1094 |
+
transformation: np.ndarray, points: np.ndarray, dim: int = None
|
| 1095 |
+
) -> np.ndarray:
|
| 1096 |
+
"""
|
| 1097 |
+
Apply 4x4 homogeneous transformation matrix to a set of 3D points
|
| 1098 |
+
Args:
|
| 1099 |
+
transformation (np.ndarray): 4x4 transformation matrix
|
| 1100 |
+
points (np.ndarray): N×3 array of 3D points to transform
|
| 1101 |
+
dim (int, optional): Target dimension of output points (default: 3)
|
| 1102 |
+
Returns:
|
| 1103 |
+
np.ndarray: N×dim array of transformed points (same shape as input except last dim)
|
| 1104 |
+
"""
|
| 1105 |
+
points = np.asarray(points)
|
| 1106 |
+
initial_shape = points.shape[:-1]
|
| 1107 |
+
dim = dim or points.shape[-1]
|
| 1108 |
+
|
| 1109 |
+
# Transpose matrix and apply affine transformation to points
|
| 1110 |
+
transformation = transformation.swapaxes(-1, -2)
|
| 1111 |
+
points = points @ transformation[..., :-1, :] + transformation[..., -1:, :]
|
| 1112 |
+
|
| 1113 |
+
# Reshape transformed points to original shape (excluding last dimension)
|
| 1114 |
+
result = points[..., :dim].reshape(*initial_shape, dim)
|
| 1115 |
+
return result
|
| 1116 |
+
|
| 1117 |
+
|
| 1118 |
+
def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
|
| 1119 |
+
"""
|
| 1120 |
+
Compute face indices for camera mesh from original cone shape faces (enhance detail)
|
| 1121 |
+
Args:
|
| 1122 |
+
cone_shape (trimesh.Trimesh): Original cone mesh for camera base shape
|
| 1123 |
+
Returns:
|
| 1124 |
+
np.ndarray: M×3 array of face indices for the camera mesh
|
| 1125 |
+
"""
|
| 1126 |
+
faces_list = []
|
| 1127 |
+
num_vertices_cone = len(cone_shape.vertices)
|
| 1128 |
+
|
| 1129 |
+
# Generate enhanced faces from cone faces (skip origin vertex 0)
|
| 1130 |
+
for face in cone_shape.faces:
|
| 1131 |
+
if 0 in face:
|
| 1132 |
+
continue
|
| 1133 |
+
v1, v2, v3 = face
|
| 1134 |
+
v1_offset, v2_offset, v3_offset = face + num_vertices_cone
|
| 1135 |
+
v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone
|
| 1136 |
+
|
| 1137 |
+
# Add multiple face variations for dense camera mesh
|
| 1138 |
+
faces_list.extend(
|
| 1139 |
+
[
|
| 1140 |
+
(v1, v2, v2_offset),
|
| 1141 |
+
(v1, v1_offset, v3),
|
| 1142 |
+
(v3_offset, v2, v3),
|
| 1143 |
+
(v1, v2, v2_offset_2),
|
| 1144 |
+
(v1, v1_offset_2, v3),
|
| 1145 |
+
(v3_offset_2, v2, v3),
|
| 1146 |
+
]
|
| 1147 |
+
)
|
| 1148 |
+
|
| 1149 |
+
# Add reversed faces for double-sided rendering
|
| 1150 |
+
faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
|
| 1151 |
+
return np.array(faces_list)
|
| 1152 |
+
|
| 1153 |
+
|
| 1154 |
+
# -------------------------- Gradio UI Construction --------------------------
|
| 1155 |
+
if __name__ == "__main__":
|
| 1156 |
+
# Gradio theme configuration
|
| 1157 |
+
theme = gr.themes.Ocean()
|
| 1158 |
+
theme.set(
|
| 1159 |
+
checkbox_label_background_fill_selected="*button_primary_background_fill",
|
| 1160 |
+
checkbox_label_text_color_selected="*button_primary_text_color",
|
| 1161 |
+
)
|
| 1162 |
+
|
| 1163 |
+
with gr.Blocks(
|
| 1164 |
+
theme=theme,
|
| 1165 |
+
title="Argus - 3D Reconstruction",
|
| 1166 |
+
css="""
|
| 1167 |
+
.custom-log * {
|
| 1168 |
+
font-style: italic;
|
| 1169 |
+
font-size: 20px !important;
|
| 1170 |
+
background-image: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 1171 |
+
-webkit-background-clip: text;
|
| 1172 |
+
background-clip: text;
|
| 1173 |
+
font-weight: 600 !important;
|
| 1174 |
+
color: transparent !important;
|
| 1175 |
+
text-align: center !important;
|
| 1176 |
+
}
|
| 1177 |
+
.example-log * {
|
| 1178 |
+
font-size: 15px !important;
|
| 1179 |
+
background-image: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 1180 |
+
-webkit-background-clip: text;
|
| 1181 |
+
background-clip: text;
|
| 1182 |
+
color: transparent !important;
|
| 1183 |
+
font-weight: 500 !important;
|
| 1184 |
+
}
|
| 1185 |
+
.header-banner {
|
| 1186 |
+
background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
|
| 1187 |
+
border-radius: 16px;
|
| 1188 |
+
padding: 32px 24px 24px;
|
| 1189 |
+
margin-bottom: 16px;
|
| 1190 |
+
border: 1px solid #e2e8f0;
|
| 1191 |
+
text-align: center;
|
| 1192 |
+
}
|
| 1193 |
+
.header-banner h1 {
|
| 1194 |
+
font-size: 28px;
|
| 1195 |
+
font-weight: 700;
|
| 1196 |
+
color: #1e293b;
|
| 1197 |
+
margin: 12px 0 8px;
|
| 1198 |
+
}
|
| 1199 |
+
.header-banner .links {
|
| 1200 |
+
margin-top: 12px;
|
| 1201 |
+
font-size: 15px;
|
| 1202 |
+
}
|
| 1203 |
+
.header-banner .links a {
|
| 1204 |
+
margin: 0 10px;
|
| 1205 |
+
color: #4f46e5;
|
| 1206 |
+
text-decoration: none;
|
| 1207 |
+
font-weight: 500;
|
| 1208 |
+
}
|
| 1209 |
+
.header-banner .links a:hover {
|
| 1210 |
+
text-decoration: underline;
|
| 1211 |
+
}
|
| 1212 |
+
.instructions {
|
| 1213 |
+
font-size: 14px;
|
| 1214 |
+
color: #475569;
|
| 1215 |
+
line-height: 1.7;
|
| 1216 |
+
padding: 12px 20px;
|
| 1217 |
+
background: #f8fafc;
|
| 1218 |
+
border-radius: 10px;
|
| 1219 |
+
border: 1px solid #e2e8f0;
|
| 1220 |
+
}
|
| 1221 |
+
.instructions ol {
|
| 1222 |
+
padding-left: 20px;
|
| 1223 |
+
margin: 8px 0;
|
| 1224 |
+
}
|
| 1225 |
+
.instructions li {
|
| 1226 |
+
margin-bottom: 4px;
|
| 1227 |
+
}
|
| 1228 |
+
.param-group {
|
| 1229 |
+
padding: 8px 0;
|
| 1230 |
+
}
|
| 1231 |
+
footer {visibility: hidden;}
|
| 1232 |
+
""",
|
| 1233 |
+
) as demo:
|
| 1234 |
+
# Hidden state components for data passing
|
| 1235 |
+
is_example = gr.Textbox(label="is_example", visible=False, value="None")
|
| 1236 |
+
processed_data_state = gr.State(value=None)
|
| 1237 |
+
measure_points_state = gr.State(value=[])
|
| 1238 |
+
target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
|
| 1239 |
+
|
| 1240 |
+
# Load and display logo (base64 encoded)
|
| 1241 |
+
root_dir = Path(__file__).parent
|
| 1242 |
+
logo_path = root_dir / "assets" / "argus_logo.png"
|
| 1243 |
+
if logo_path.exists():
|
| 1244 |
+
with open(logo_path, "rb") as f:
|
| 1245 |
+
logo_base64 = base64.b64encode(f.read()).decode()
|
| 1246 |
+
logo_src = f"data:image/png;base64,{logo_base64}"
|
| 1247 |
+
else:
|
| 1248 |
+
logo_src = "" # Fallback if logo not found
|
| 1249 |
+
|
| 1250 |
+
# UI Header and Instructions
|
| 1251 |
+
gr.HTML(
|
| 1252 |
+
f"""
|
| 1253 |
+
<div class="header-banner">
|
| 1254 |
+
<div style="display: flex; justify-content: center;">
|
| 1255 |
+
<img src="{logo_src}" alt="Argus Logo" style="height: 72px; border-radius: 8px;">
|
| 1256 |
+
</div>
|
| 1257 |
+
<h1>Argus: Metric Panoramic 3D Reconstruction for Indoor Scenes</h1>
|
| 1258 |
+
<div class="links">
|
| 1259 |
+
<a href="https://github.com/realsee-developer/Argus" target="_blank">🌟 GitHub</a>
|
| 1260 |
+
<a href="https://argus-paper.realsee.ai" target="_blank">🚀 Project Page</a>
|
| 1261 |
+
<a href="https://arxiv.org/abs/2606.30047" target="_blank">📄 Paper</a>
|
| 1262 |
+
</div>
|
| 1263 |
+
</div>
|
| 1264 |
+
<div class="instructions">
|
| 1265 |
+
<ol>
|
| 1266 |
+
<li><strong>Upload</strong> a set of ERP panoramic images on the left.</li>
|
| 1267 |
+
<li><strong>Click "Reconstruct"</strong> to run the 3D reconstruction pipeline.</li>
|
| 1268 |
+
<li><strong>Explore</strong> the 3D model — rotate, pan, zoom, and download the GLB.</li>
|
| 1269 |
+
<li><strong>Measure</strong> — switch to the Metric tab and click two points to measure real-world distance.</li>
|
| 1270 |
+
</ol>
|
| 1271 |
+
</div>
|
| 1272 |
+
"""
|
| 1273 |
+
)
|
| 1274 |
+
|
| 1275 |
+
# Main UI Layout (2 columns: upload/gallery | 3D model/measurement)
|
| 1276 |
+
with gr.Row(equal_height=False):
|
| 1277 |
+
with gr.Column(scale=2, min_width=280):
|
| 1278 |
+
input_images = gr.File(
|
| 1279 |
+
file_count="multiple", label="📁 Upload Panoramic Images", interactive=True
|
| 1280 |
+
)
|
| 1281 |
+
image_gallery = gr.Gallery(
|
| 1282 |
+
label="Preview",
|
| 1283 |
+
columns=3,
|
| 1284 |
+
height="280px",
|
| 1285 |
+
object_fit="contain",
|
| 1286 |
+
preview=True,
|
| 1287 |
+
)
|
| 1288 |
+
|
| 1289 |
+
with gr.Column(scale=5):
|
| 1290 |
+
# Log output
|
| 1291 |
+
log_output = gr.Markdown(
|
| 1292 |
+
"Upload panoramic images (ERP), then click Reconstruct.",
|
| 1293 |
+
elem_classes=["custom-log"],
|
| 1294 |
+
)
|
| 1295 |
+
# Tabbed interface: 3D Model + Metric Measure
|
| 1296 |
+
with gr.Tabs():
|
| 1297 |
+
with gr.Tab("🏠 3D Model"):
|
| 1298 |
+
reconstruction_output = gr.Model3D(
|
| 1299 |
+
height=540, zoom_speed=0.5, pan_speed=0.5
|
| 1300 |
+
)
|
| 1301 |
+
with gr.Tab("📏 Metric Measure"):
|
| 1302 |
+
gr.Markdown(
|
| 1303 |
+
"Click two points on the panorama to measure the real-world distance between them."
|
| 1304 |
+
)
|
| 1305 |
+
with gr.Row():
|
| 1306 |
+
prev_measure_btn = gr.Button(
|
| 1307 |
+
"◀ Prev", size="sm", scale=1
|
| 1308 |
+
)
|
| 1309 |
+
measure_view_selector = gr.Dropdown(
|
| 1310 |
+
choices=["View 1"],
|
| 1311 |
+
value="View 1",
|
| 1312 |
+
label="Select View",
|
| 1313 |
+
scale=3,
|
| 1314 |
+
interactive=True,
|
| 1315 |
+
allow_custom_value=True,
|
| 1316 |
+
)
|
| 1317 |
+
next_measure_btn = gr.Button("Next ▶", size="sm", scale=1)
|
| 1318 |
+
measure_image = gr.Image(
|
| 1319 |
+
type="numpy",
|
| 1320 |
+
show_label=False,
|
| 1321 |
+
format="webp",
|
| 1322 |
+
interactive=False,
|
| 1323 |
+
sources=[],
|
| 1324 |
+
)
|
| 1325 |
+
measure_text = gr.Markdown("")
|
| 1326 |
+
|
| 1327 |
+
# Action buttons
|
| 1328 |
+
with gr.Row():
|
| 1329 |
+
submit_btn = gr.Button("🔨 Reconstruct", scale=2, variant="primary")
|
| 1330 |
+
clear_btn = gr.ClearButton(
|
| 1331 |
+
[
|
| 1332 |
+
input_images,
|
| 1333 |
+
reconstruction_output,
|
| 1334 |
+
log_output,
|
| 1335 |
+
target_dir_output,
|
| 1336 |
+
image_gallery,
|
| 1337 |
+
],
|
| 1338 |
+
value="🗑️ Clear",
|
| 1339 |
+
scale=1,
|
| 1340 |
+
)
|
| 1341 |
+
|
| 1342 |
+
# Reconstruction parameters
|
| 1343 |
+
gr.Markdown("**Visualization Settings**")
|
| 1344 |
+
with gr.Row():
|
| 1345 |
+
conf_thres = gr.Slider(
|
| 1346 |
+
0, 100, 5, 1, label="Confidence Threshold (%)"
|
| 1347 |
+
)
|
| 1348 |
+
ceiling_remove = gr.Slider(
|
| 1349 |
+
0, 100, 25, 1, label="Ceiling Remove (%)"
|
| 1350 |
+
)
|
| 1351 |
+
with gr.Row():
|
| 1352 |
+
frame_filter = gr.Dropdown(
|
| 1353 |
+
["All"], "All", label="Show Points from Frame", scale=2
|
| 1354 |
+
)
|
| 1355 |
+
show_cam = gr.Checkbox(True, label="Show Camera")
|
| 1356 |
+
show_index = gr.Checkbox(True, label="Show Index")
|
| 1357 |
+
|
| 1358 |
+
# Example Scenes Section
|
| 1359 |
+
gr.Markdown("---")
|
| 1360 |
+
gr.Markdown("### 🖼️ Example Scenes")
|
| 1361 |
+
gr.Markdown("Click any thumbnail to load and reconstruct.", elem_classes=["example-log"])
|
| 1362 |
+
example_scenes = get_scene_info(args.examples_dir)
|
| 1363 |
+
# Create 4-column example thumbnail grid
|
| 1364 |
+
if example_scenes:
|
| 1365 |
+
for i in range(0, len(example_scenes), 4):
|
| 1366 |
+
with gr.Row():
|
| 1367 |
+
for j in range(4):
|
| 1368 |
+
idx = i + j
|
| 1369 |
+
if idx < len(example_scenes):
|
| 1370 |
+
scene = example_scenes[idx]
|
| 1371 |
+
with gr.Column(scale=1):
|
| 1372 |
+
scene_state = gr.State(value=scene)
|
| 1373 |
+
scene_img = gr.Image(
|
| 1374 |
+
value=scene["thumbnail"],
|
| 1375 |
+
height=150,
|
| 1376 |
+
interactive=False,
|
| 1377 |
+
show_label=False,
|
| 1378 |
+
sources=[],
|
| 1379 |
+
)
|
| 1380 |
+
gr.Markdown(
|
| 1381 |
+
f"**{scene['name']}** \n {scene['num_images']} images"
|
| 1382 |
+
)
|
| 1383 |
+
# Bind thumbnail click to example pipeline
|
| 1384 |
+
scene_img.select(
|
| 1385 |
+
example_pipeline,
|
| 1386 |
+
[scene_state],
|
| 1387 |
+
[
|
| 1388 |
+
reconstruction_output,
|
| 1389 |
+
log_output,
|
| 1390 |
+
target_dir_output,
|
| 1391 |
+
frame_filter,
|
| 1392 |
+
image_gallery,
|
| 1393 |
+
processed_data_state,
|
| 1394 |
+
measure_image,
|
| 1395 |
+
measure_text,
|
| 1396 |
+
measure_view_selector,
|
| 1397 |
+
],
|
| 1398 |
+
)
|
| 1399 |
+
else:
|
| 1400 |
+
with gr.Column(scale=1):
|
| 1401 |
+
pass # Empty column for grid alignment
|
| 1402 |
+
|
| 1403 |
+
# -------------------------- Gradio Event Bindings --------------------------
|
| 1404 |
+
# Reconstruct button logic
|
| 1405 |
+
submit_btn.click(clear_fields, [], [reconstruction_output]).then(
|
| 1406 |
+
update_log, [], [log_output]
|
| 1407 |
+
).then(
|
| 1408 |
+
gradio_demo,
|
| 1409 |
+
[
|
| 1410 |
+
target_dir_output,
|
| 1411 |
+
conf_thres,
|
| 1412 |
+
frame_filter,
|
| 1413 |
+
show_cam,
|
| 1414 |
+
show_index,
|
| 1415 |
+
ceiling_remove,
|
| 1416 |
+
],
|
| 1417 |
+
[
|
| 1418 |
+
reconstruction_output,
|
| 1419 |
+
log_output,
|
| 1420 |
+
frame_filter,
|
| 1421 |
+
processed_data_state,
|
| 1422 |
+
measure_image,
|
| 1423 |
+
measure_text,
|
| 1424 |
+
measure_view_selector,
|
| 1425 |
+
],
|
| 1426 |
+
).then(
|
| 1427 |
+
lambda: "False", [], [is_example]
|
| 1428 |
+
)
|
| 1429 |
+
|
| 1430 |
+
# Real-time parameter update for 3D visualization
|
| 1431 |
+
for param in [conf_thres, frame_filter, show_cam, show_index, ceiling_remove]:
|
| 1432 |
+
param.change(
|
| 1433 |
+
update_visualization,
|
| 1434 |
+
[
|
| 1435 |
+
target_dir_output,
|
| 1436 |
+
conf_thres,
|
| 1437 |
+
frame_filter,
|
| 1438 |
+
show_cam,
|
| 1439 |
+
show_index,
|
| 1440 |
+
ceiling_remove,
|
| 1441 |
+
is_example,
|
| 1442 |
+
],
|
| 1443 |
+
[reconstruction_output, log_output],
|
| 1444 |
+
)
|
| 1445 |
+
|
| 1446 |
+
# Auto-update gallery on file upload
|
| 1447 |
+
input_images.change(
|
| 1448 |
+
update_gallery_on_upload,
|
| 1449 |
+
[input_images],
|
| 1450 |
+
[reconstruction_output, target_dir_output, image_gallery, log_output],
|
| 1451 |
+
)
|
| 1452 |
+
|
| 1453 |
+
# Metric measure event bindings
|
| 1454 |
+
measure_image.select(
|
| 1455 |
+
measure,
|
| 1456 |
+
[processed_data_state, measure_points_state, measure_view_selector],
|
| 1457 |
+
[measure_image, measure_points_state, measure_text],
|
| 1458 |
+
)
|
| 1459 |
+
# Measure view navigation
|
| 1460 |
+
prev_measure_btn.click(
|
| 1461 |
+
lambda d, s: navigate_measure_view(d, s, -1),
|
| 1462 |
+
[processed_data_state, measure_view_selector],
|
| 1463 |
+
[measure_view_selector, measure_image, measure_points_state],
|
| 1464 |
+
)
|
| 1465 |
+
next_measure_btn.click(
|
| 1466 |
+
lambda d, s: navigate_measure_view(d, s, 1),
|
| 1467 |
+
[processed_data_state, measure_view_selector],
|
| 1468 |
+
[measure_view_selector, measure_image, measure_points_state],
|
| 1469 |
+
)
|
| 1470 |
+
# Update measure view when selector changes
|
| 1471 |
+
measure_view_selector.change(
|
| 1472 |
+
lambda d, s: (
|
| 1473 |
+
update_measure_view(d, int(s.split()[1]) - 1) if s else (None, [])
|
| 1474 |
+
),
|
| 1475 |
+
[processed_data_state, measure_view_selector],
|
| 1476 |
+
[measure_image, measure_points_state],
|
| 1477 |
+
)
|
| 1478 |
+
|
| 1479 |
+
# Footer acknowledgement
|
| 1480 |
+
gr.HTML(
|
| 1481 |
+
"""
|
| 1482 |
+
<hr style="margin-top: 40px; margin-bottom: 20px; border-color: #e2e8f0;">
|
| 1483 |
+
<div style="text-align: center; font-size: 13px; color: #94a3b8; margin-bottom: 20px;">
|
| 1484 |
+
<p style="margin-bottom: 8px; font-weight: 500; color: #64748b;">Acknowledgements</p>
|
| 1485 |
+
<p>Built upon
|
| 1486 |
+
<a href="https://github.com/facebookresearch/vggt" style="color: #6366f1;">VGGT</a> &
|
| 1487 |
+
<a href="https://github.com/facebookresearch/map-anything" style="color: #6366f1;">Map-Anything</a>
|
| 1488 |
+
</p>
|
| 1489 |
+
</div>
|
| 1490 |
+
"""
|
| 1491 |
+
)
|
| 1492 |
+
|
| 1493 |
+
# Launch Gradio demo
|
| 1494 |
+
demo.queue(max_size=20).launch(
|
| 1495 |
+
show_error=True,
|
| 1496 |
+
share=args.share,
|
| 1497 |
+
server_name=args.server_name,
|
| 1498 |
+
server_port=args.port,
|
| 1499 |
+
)
|
argus/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 Realsee. All rights reserved.
|
| 2 |
+
# Licensed under the Apache License, Version 2.0.
|
argus/heads/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 Realsee. All rights reserved.
|
| 2 |
+
# Licensed under the Apache License, Version 2.0.
|
argus/heads/camera_head.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from argus.layers import Mlp
|
| 6 |
+
from argus.layers.block import Block
|
| 7 |
+
from argus.heads.head_act import activate_pose
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CameraHead(nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
CameraHead predicts camera parameters from token representations using iterative refinement.
|
| 13 |
+
|
| 14 |
+
It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
dim_in: int = 2048,
|
| 20 |
+
trunk_depth: int = 4,
|
| 21 |
+
num_heads: int = 16,
|
| 22 |
+
mlp_ratio: int = 4,
|
| 23 |
+
init_values: float = 0.01,
|
| 24 |
+
trans_act: str = "linear",
|
| 25 |
+
quat_act: str = "linear",
|
| 26 |
+
):
|
| 27 |
+
super().__init__()
|
| 28 |
+
|
| 29 |
+
self.target_dim = 9
|
| 30 |
+
self.trans_act = trans_act
|
| 31 |
+
self.quat_act = quat_act
|
| 32 |
+
self.trunk_depth = trunk_depth
|
| 33 |
+
|
| 34 |
+
# Build the trunk using a sequence of transformer blocks.
|
| 35 |
+
self.trunk = nn.Sequential(
|
| 36 |
+
*[
|
| 37 |
+
Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
|
| 38 |
+
for _ in range(trunk_depth)
|
| 39 |
+
]
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Normalizations for camera token and trunk output.
|
| 43 |
+
self.token_norm = nn.LayerNorm(dim_in)
|
| 44 |
+
self.trunk_norm = nn.LayerNorm(dim_in)
|
| 45 |
+
|
| 46 |
+
# Learnable empty camera pose token.
|
| 47 |
+
self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
|
| 48 |
+
self.embed_pose = nn.Linear(self.target_dim, dim_in)
|
| 49 |
+
|
| 50 |
+
# Module for producing modulation parameters: shift, scale, and a gate.
|
| 51 |
+
self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
|
| 52 |
+
|
| 53 |
+
# Adaptive layer normalization without affine parameters.
|
| 54 |
+
self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
|
| 55 |
+
self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
|
| 56 |
+
|
| 57 |
+
# conf branch for T and R
|
| 58 |
+
self.conf_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=2, drop=0)
|
| 59 |
+
|
| 60 |
+
def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
|
| 61 |
+
"""
|
| 62 |
+
Forward pass to predict camera parameters.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
aggregated_tokens_list (list): List of token tensors from the network;
|
| 66 |
+
the last tensor is used for prediction.
|
| 67 |
+
num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
list: A list of predicted camera encodings (post-activation) from each iteration.
|
| 71 |
+
"""
|
| 72 |
+
# Use tokens from the last block for camera prediction.
|
| 73 |
+
tokens = aggregated_tokens_list[-1]
|
| 74 |
+
|
| 75 |
+
# Extract the camera tokens
|
| 76 |
+
pose_tokens = tokens[:, :, 0]
|
| 77 |
+
pose_tokens = self.token_norm(pose_tokens)
|
| 78 |
+
|
| 79 |
+
pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
|
| 80 |
+
return pred_pose_enc_list
|
| 81 |
+
|
| 82 |
+
def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
|
| 83 |
+
"""
|
| 84 |
+
Iteratively refine camera pose predictions.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
|
| 88 |
+
num_iterations (int): Number of refinement iterations.
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
list: List of activated camera encodings from each iteration.
|
| 92 |
+
"""
|
| 93 |
+
B, S, C = pose_tokens.shape
|
| 94 |
+
pred_pose_enc = None
|
| 95 |
+
pred_pose_enc_conf = None
|
| 96 |
+
pred_pose_enc_list = []
|
| 97 |
+
|
| 98 |
+
for _ in range(num_iterations):
|
| 99 |
+
# Use a learned empty pose for the first iteration.
|
| 100 |
+
if pred_pose_enc is None:
|
| 101 |
+
module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
|
| 102 |
+
else:
|
| 103 |
+
# Detach the previous prediction to avoid backprop through time.
|
| 104 |
+
pred_pose_enc = pred_pose_enc.detach()
|
| 105 |
+
module_input = self.embed_pose(pred_pose_enc)
|
| 106 |
+
|
| 107 |
+
# Generate modulation parameters and split them into shift, scale, and gate components.
|
| 108 |
+
shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
|
| 109 |
+
|
| 110 |
+
# Adaptive layer normalization and modulation.
|
| 111 |
+
pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
|
| 112 |
+
pose_tokens_modulated = pose_tokens_modulated + pose_tokens
|
| 113 |
+
|
| 114 |
+
pose_tokens_modulated = self.trunk(pose_tokens_modulated)
|
| 115 |
+
# Compute the delta update for the pose encoding.
|
| 116 |
+
pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
|
| 117 |
+
pred_pose_enc_conf_delta = self.conf_branch(self.trunk_norm(pose_tokens_modulated))
|
| 118 |
+
|
| 119 |
+
if pred_pose_enc is None:
|
| 120 |
+
pred_pose_enc = pred_pose_enc_delta
|
| 121 |
+
pred_pose_enc_conf = pred_pose_enc_conf_delta
|
| 122 |
+
else:
|
| 123 |
+
pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
|
| 124 |
+
pred_pose_enc_conf = pred_pose_enc_conf + pred_pose_enc_conf_delta
|
| 125 |
+
|
| 126 |
+
# Apply final activation functions for translation, quaternion
|
| 127 |
+
activated_pose = activate_pose(
|
| 128 |
+
pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act
|
| 129 |
+
)
|
| 130 |
+
activated_conf = 1 + pred_pose_enc_conf.exp()
|
| 131 |
+
activated_pose = torch.cat([activated_pose, activated_conf], dim=-1)
|
| 132 |
+
pred_pose_enc_list.append(activated_pose)
|
| 133 |
+
|
| 134 |
+
return pred_pose_enc_list
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
| 138 |
+
"""
|
| 139 |
+
Modulate the input tensor using scaling and shifting parameters.
|
| 140 |
+
"""
|
| 141 |
+
# modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
|
| 142 |
+
return x * (1 + scale) + shift
|
argus/heads/dpt_head.py
ADDED
|
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List, Dict, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from .head_act import activate_head
|
| 8 |
+
from .utils import create_uv_grid, position_grid_to_embed
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DPTHead(nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
DPT Head for dense prediction tasks.
|
| 14 |
+
|
| 15 |
+
This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
|
| 16 |
+
(https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
|
| 17 |
+
backbone and produces dense predictions by fusing multi-scale features.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
dim_in (int): Input dimension (channels).
|
| 21 |
+
patch_size (int, optional): Patch size. Default is 14.
|
| 22 |
+
output_dim (int, optional): Number of output channels. Default is 4.
|
| 23 |
+
activation (str, optional): Activation type. Default is "inv_log".
|
| 24 |
+
conf_activation (str, optional): Confidence activation type. Default is "expp1".
|
| 25 |
+
features (int, optional): Feature channels for intermediate representations. Default is 256.
|
| 26 |
+
out_channels (List[int], optional): Output channels for each intermediate layer.
|
| 27 |
+
intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
|
| 28 |
+
pos_embed (bool, optional): Whether to use positional embedding. Default is True.
|
| 29 |
+
feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
|
| 30 |
+
down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
dim_in: int,
|
| 36 |
+
patch_size: int = 14,
|
| 37 |
+
output_dim: int = 4,
|
| 38 |
+
activation: str = "inv_log",
|
| 39 |
+
conf_activation: str = "expp1",
|
| 40 |
+
features: int = 256,
|
| 41 |
+
out_channels: List[int] = [256, 512, 1024, 1024],
|
| 42 |
+
intermediate_layer_idx: List[int] = [4, 11, 17, 23],
|
| 43 |
+
pos_embed: bool = True,
|
| 44 |
+
feature_only: bool = False,
|
| 45 |
+
down_ratio: int = 1,
|
| 46 |
+
) -> None:
|
| 47 |
+
super(DPTHead, self).__init__()
|
| 48 |
+
self.patch_size = patch_size
|
| 49 |
+
self.activation = activation
|
| 50 |
+
self.conf_activation = conf_activation
|
| 51 |
+
self.pos_embed = pos_embed
|
| 52 |
+
self.feature_only = feature_only
|
| 53 |
+
self.down_ratio = down_ratio
|
| 54 |
+
self.intermediate_layer_idx = intermediate_layer_idx
|
| 55 |
+
|
| 56 |
+
self.norm = nn.LayerNorm(dim_in)
|
| 57 |
+
|
| 58 |
+
# Projection layers for each output channel from tokens.
|
| 59 |
+
self.projects = nn.ModuleList(
|
| 60 |
+
[nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# Resize layers for upsampling feature maps.
|
| 64 |
+
self.resize_layers = nn.ModuleList(
|
| 65 |
+
[
|
| 66 |
+
nn.ConvTranspose2d(
|
| 67 |
+
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
|
| 68 |
+
),
|
| 69 |
+
nn.ConvTranspose2d(
|
| 70 |
+
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
|
| 71 |
+
),
|
| 72 |
+
nn.Identity(),
|
| 73 |
+
nn.Conv2d(
|
| 74 |
+
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
|
| 75 |
+
),
|
| 76 |
+
]
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
self.scratch = _make_scratch(out_channels, features, expand=False)
|
| 80 |
+
|
| 81 |
+
# Attach additional modules to scratch.
|
| 82 |
+
self.scratch.stem_transpose = None
|
| 83 |
+
self.scratch.refinenet1 = _make_fusion_block(features)
|
| 84 |
+
self.scratch.refinenet2 = _make_fusion_block(features)
|
| 85 |
+
self.scratch.refinenet3 = _make_fusion_block(features)
|
| 86 |
+
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
|
| 87 |
+
|
| 88 |
+
head_features_1 = features
|
| 89 |
+
head_features_2 = 32
|
| 90 |
+
|
| 91 |
+
if feature_only:
|
| 92 |
+
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
|
| 93 |
+
else:
|
| 94 |
+
self.scratch.output_conv1 = nn.Conv2d(
|
| 95 |
+
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
|
| 96 |
+
)
|
| 97 |
+
conv2_in_channels = head_features_1 // 2
|
| 98 |
+
|
| 99 |
+
self.scratch.output_conv2 = nn.Sequential(
|
| 100 |
+
nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
|
| 101 |
+
nn.ReLU(inplace=True),
|
| 102 |
+
nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
def forward(
|
| 106 |
+
self,
|
| 107 |
+
aggregated_tokens_list: List[torch.Tensor],
|
| 108 |
+
images: torch.Tensor,
|
| 109 |
+
patch_start_idx: int,
|
| 110 |
+
frames_chunk_size: int = 8,
|
| 111 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 112 |
+
"""
|
| 113 |
+
Forward pass through the DPT head, supports processing by chunking frames.
|
| 114 |
+
Args:
|
| 115 |
+
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
|
| 116 |
+
images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
|
| 117 |
+
patch_start_idx (int): Starting index for patch tokens in the token sequence.
|
| 118 |
+
Used to separate patch tokens from other tokens (e.g., camera or register tokens).
|
| 119 |
+
frames_chunk_size (int, optional): Number of frames to process in each chunk.
|
| 120 |
+
If None or larger than S, all frames are processed at once. Default: 8.
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
Tensor or Tuple[Tensor, Tensor]:
|
| 124 |
+
- If feature_only=True: Feature maps with shape [B, S, C, H, W]
|
| 125 |
+
- Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
|
| 126 |
+
"""
|
| 127 |
+
B, S, _, H, W = images.shape
|
| 128 |
+
|
| 129 |
+
# If frames_chunk_size is not specified or greater than S, process all frames at once
|
| 130 |
+
if frames_chunk_size is None or frames_chunk_size >= S:
|
| 131 |
+
return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
|
| 132 |
+
|
| 133 |
+
# Otherwise, process frames in chunks to manage memory usage
|
| 134 |
+
assert frames_chunk_size > 0
|
| 135 |
+
|
| 136 |
+
# Process frames in batches
|
| 137 |
+
all_preds = []
|
| 138 |
+
all_conf = []
|
| 139 |
+
|
| 140 |
+
for frames_start_idx in range(0, S, frames_chunk_size):
|
| 141 |
+
frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
|
| 142 |
+
|
| 143 |
+
# Process batch of frames
|
| 144 |
+
if self.feature_only:
|
| 145 |
+
chunk_output = self._forward_impl(
|
| 146 |
+
aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
|
| 147 |
+
)
|
| 148 |
+
all_preds.append(chunk_output)
|
| 149 |
+
else:
|
| 150 |
+
chunk_preds, chunk_conf = self._forward_impl(
|
| 151 |
+
aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
|
| 152 |
+
)
|
| 153 |
+
all_preds.append(chunk_preds)
|
| 154 |
+
all_conf.append(chunk_conf)
|
| 155 |
+
|
| 156 |
+
# Concatenate results along the sequence dimension
|
| 157 |
+
if self.feature_only:
|
| 158 |
+
return torch.cat(all_preds, dim=1)
|
| 159 |
+
else:
|
| 160 |
+
return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
|
| 161 |
+
|
| 162 |
+
def _forward_impl(
|
| 163 |
+
self,
|
| 164 |
+
aggregated_tokens_list: List[torch.Tensor],
|
| 165 |
+
images: torch.Tensor,
|
| 166 |
+
patch_start_idx: int,
|
| 167 |
+
frames_start_idx: int = None,
|
| 168 |
+
frames_end_idx: int = None,
|
| 169 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 170 |
+
"""
|
| 171 |
+
Implementation of the forward pass through the DPT head.
|
| 172 |
+
|
| 173 |
+
This method processes a specific chunk of frames from the sequence.
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
|
| 177 |
+
images (Tensor): Input images with shape [B, S, 3, H, W].
|
| 178 |
+
patch_start_idx (int): Starting index for patch tokens.
|
| 179 |
+
frames_start_idx (int, optional): Starting index for frames to process.
|
| 180 |
+
frames_end_idx (int, optional): Ending index for frames to process.
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
|
| 184 |
+
"""
|
| 185 |
+
if frames_start_idx is not None and frames_end_idx is not None:
|
| 186 |
+
images = images[:, frames_start_idx:frames_end_idx].contiguous()
|
| 187 |
+
|
| 188 |
+
B, S, _, H, W = images.shape
|
| 189 |
+
|
| 190 |
+
patch_h, patch_w = H // self.patch_size, W // self.patch_size
|
| 191 |
+
|
| 192 |
+
out = []
|
| 193 |
+
dpt_idx = 0
|
| 194 |
+
|
| 195 |
+
for layer_idx in self.intermediate_layer_idx:
|
| 196 |
+
x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
|
| 197 |
+
|
| 198 |
+
# Select frames if processing a chunk
|
| 199 |
+
if frames_start_idx is not None and frames_end_idx is not None:
|
| 200 |
+
x = x[:, frames_start_idx:frames_end_idx]
|
| 201 |
+
|
| 202 |
+
x = x.reshape(B * S, -1, x.shape[-1])
|
| 203 |
+
|
| 204 |
+
x = self.norm(x)
|
| 205 |
+
|
| 206 |
+
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
| 207 |
+
|
| 208 |
+
x = self.projects[dpt_idx](x)
|
| 209 |
+
if self.pos_embed:
|
| 210 |
+
x = self._apply_pos_embed(x, W, H)
|
| 211 |
+
x = self.resize_layers[dpt_idx](x)
|
| 212 |
+
|
| 213 |
+
out.append(x)
|
| 214 |
+
dpt_idx += 1
|
| 215 |
+
|
| 216 |
+
# Fuse features from multiple layers.
|
| 217 |
+
out = self.scratch_forward(out)
|
| 218 |
+
# Interpolate fused output to match target image resolution.
|
| 219 |
+
out = custom_interpolate(
|
| 220 |
+
out,
|
| 221 |
+
(int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
|
| 222 |
+
mode="bilinear",
|
| 223 |
+
align_corners=True,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
if self.pos_embed:
|
| 227 |
+
out = self._apply_pos_embed(out, W, H)
|
| 228 |
+
|
| 229 |
+
if self.feature_only:
|
| 230 |
+
return out.view(B, S, *out.shape[1:])
|
| 231 |
+
|
| 232 |
+
out = self.scratch.output_conv2(out)
|
| 233 |
+
preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
|
| 234 |
+
|
| 235 |
+
preds = preds.view(B, S, *preds.shape[1:])
|
| 236 |
+
conf = conf.view(B, S, *conf.shape[1:])
|
| 237 |
+
return preds, conf
|
| 238 |
+
|
| 239 |
+
def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
|
| 240 |
+
"""
|
| 241 |
+
Apply positional embedding to tensor x.
|
| 242 |
+
"""
|
| 243 |
+
patch_w = x.shape[-1]
|
| 244 |
+
patch_h = x.shape[-2]
|
| 245 |
+
pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
|
| 246 |
+
pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
|
| 247 |
+
pos_embed = pos_embed * ratio
|
| 248 |
+
pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
|
| 249 |
+
return x + pos_embed
|
| 250 |
+
|
| 251 |
+
def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
|
| 252 |
+
"""
|
| 253 |
+
Forward pass through the fusion blocks.
|
| 254 |
+
|
| 255 |
+
Args:
|
| 256 |
+
features (List[Tensor]): List of feature maps from different layers.
|
| 257 |
+
|
| 258 |
+
Returns:
|
| 259 |
+
Tensor: Fused feature map.
|
| 260 |
+
"""
|
| 261 |
+
layer_1, layer_2, layer_3, layer_4 = features
|
| 262 |
+
|
| 263 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
| 264 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
| 265 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
| 266 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
| 267 |
+
|
| 268 |
+
out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
| 269 |
+
del layer_4_rn, layer_4
|
| 270 |
+
|
| 271 |
+
out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
|
| 272 |
+
del layer_3_rn, layer_3
|
| 273 |
+
|
| 274 |
+
out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
|
| 275 |
+
del layer_2_rn, layer_2
|
| 276 |
+
|
| 277 |
+
out = self.scratch.refinenet1(out, layer_1_rn)
|
| 278 |
+
del layer_1_rn, layer_1
|
| 279 |
+
|
| 280 |
+
out = self.scratch.output_conv1(out)
|
| 281 |
+
return out
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
################################################################################
|
| 285 |
+
# Modules
|
| 286 |
+
################################################################################
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
|
| 290 |
+
return FeatureFusionBlock(
|
| 291 |
+
features,
|
| 292 |
+
nn.ReLU(inplace=True),
|
| 293 |
+
deconv=False,
|
| 294 |
+
bn=False,
|
| 295 |
+
expand=False,
|
| 296 |
+
align_corners=True,
|
| 297 |
+
size=size,
|
| 298 |
+
has_residual=has_residual,
|
| 299 |
+
groups=groups,
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
|
| 304 |
+
scratch = nn.Module()
|
| 305 |
+
out_shape1 = out_shape
|
| 306 |
+
out_shape2 = out_shape
|
| 307 |
+
out_shape3 = out_shape
|
| 308 |
+
if len(in_shape) >= 4:
|
| 309 |
+
out_shape4 = out_shape
|
| 310 |
+
|
| 311 |
+
if expand:
|
| 312 |
+
out_shape1 = out_shape
|
| 313 |
+
out_shape2 = out_shape * 2
|
| 314 |
+
out_shape3 = out_shape * 4
|
| 315 |
+
if len(in_shape) >= 4:
|
| 316 |
+
out_shape4 = out_shape * 8
|
| 317 |
+
|
| 318 |
+
scratch.layer1_rn = nn.Conv2d(
|
| 319 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 320 |
+
)
|
| 321 |
+
scratch.layer2_rn = nn.Conv2d(
|
| 322 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 323 |
+
)
|
| 324 |
+
scratch.layer3_rn = nn.Conv2d(
|
| 325 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 326 |
+
)
|
| 327 |
+
if len(in_shape) >= 4:
|
| 328 |
+
scratch.layer4_rn = nn.Conv2d(
|
| 329 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 330 |
+
)
|
| 331 |
+
return scratch
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
class ResidualConvUnit(nn.Module):
|
| 335 |
+
"""Residual convolution module."""
|
| 336 |
+
|
| 337 |
+
def __init__(self, features, activation, bn, groups=1):
|
| 338 |
+
"""Init.
|
| 339 |
+
|
| 340 |
+
Args:
|
| 341 |
+
features (int): number of features
|
| 342 |
+
"""
|
| 343 |
+
super().__init__()
|
| 344 |
+
|
| 345 |
+
self.bn = bn
|
| 346 |
+
self.groups = groups
|
| 347 |
+
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
| 348 |
+
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
| 349 |
+
|
| 350 |
+
self.norm1 = None
|
| 351 |
+
self.norm2 = None
|
| 352 |
+
|
| 353 |
+
self.activation = activation
|
| 354 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 355 |
+
|
| 356 |
+
def forward(self, x):
|
| 357 |
+
"""Forward pass.
|
| 358 |
+
|
| 359 |
+
Args:
|
| 360 |
+
x (tensor): input
|
| 361 |
+
|
| 362 |
+
Returns:
|
| 363 |
+
tensor: output
|
| 364 |
+
"""
|
| 365 |
+
|
| 366 |
+
out = self.activation(x)
|
| 367 |
+
out = self.conv1(out)
|
| 368 |
+
if self.norm1 is not None:
|
| 369 |
+
out = self.norm1(out)
|
| 370 |
+
|
| 371 |
+
out = self.activation(out)
|
| 372 |
+
out = self.conv2(out)
|
| 373 |
+
if self.norm2 is not None:
|
| 374 |
+
out = self.norm2(out)
|
| 375 |
+
|
| 376 |
+
return self.skip_add.add(out, x)
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
class FeatureFusionBlock(nn.Module):
|
| 380 |
+
"""Feature fusion block."""
|
| 381 |
+
|
| 382 |
+
def __init__(
|
| 383 |
+
self,
|
| 384 |
+
features,
|
| 385 |
+
activation,
|
| 386 |
+
deconv=False,
|
| 387 |
+
bn=False,
|
| 388 |
+
expand=False,
|
| 389 |
+
align_corners=True,
|
| 390 |
+
size=None,
|
| 391 |
+
has_residual=True,
|
| 392 |
+
groups=1,
|
| 393 |
+
):
|
| 394 |
+
"""Init.
|
| 395 |
+
|
| 396 |
+
Args:
|
| 397 |
+
features (int): number of features
|
| 398 |
+
"""
|
| 399 |
+
super(FeatureFusionBlock, self).__init__()
|
| 400 |
+
|
| 401 |
+
self.deconv = deconv
|
| 402 |
+
self.align_corners = align_corners
|
| 403 |
+
self.groups = groups
|
| 404 |
+
self.expand = expand
|
| 405 |
+
out_features = features
|
| 406 |
+
if self.expand == True:
|
| 407 |
+
out_features = features // 2
|
| 408 |
+
|
| 409 |
+
self.out_conv = nn.Conv2d(
|
| 410 |
+
features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
if has_residual:
|
| 414 |
+
self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
|
| 415 |
+
|
| 416 |
+
self.has_residual = has_residual
|
| 417 |
+
self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
|
| 418 |
+
|
| 419 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 420 |
+
self.size = size
|
| 421 |
+
|
| 422 |
+
def forward(self, *xs, size=None):
|
| 423 |
+
"""Forward pass.
|
| 424 |
+
|
| 425 |
+
Returns:
|
| 426 |
+
tensor: output
|
| 427 |
+
"""
|
| 428 |
+
output = xs[0]
|
| 429 |
+
|
| 430 |
+
if self.has_residual:
|
| 431 |
+
res = self.resConfUnit1(xs[1])
|
| 432 |
+
output = self.skip_add.add(output, res)
|
| 433 |
+
|
| 434 |
+
output = self.resConfUnit2(output)
|
| 435 |
+
|
| 436 |
+
if (size is None) and (self.size is None):
|
| 437 |
+
modifier = {"scale_factor": 2}
|
| 438 |
+
elif size is None:
|
| 439 |
+
modifier = {"size": self.size}
|
| 440 |
+
else:
|
| 441 |
+
modifier = {"size": size}
|
| 442 |
+
|
| 443 |
+
output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
|
| 444 |
+
output = self.out_conv(output)
|
| 445 |
+
|
| 446 |
+
return output
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def custom_interpolate(
|
| 450 |
+
x: torch.Tensor,
|
| 451 |
+
size: Tuple[int, int] = None,
|
| 452 |
+
scale_factor: float = None,
|
| 453 |
+
mode: str = "bilinear",
|
| 454 |
+
align_corners: bool = True,
|
| 455 |
+
) -> torch.Tensor:
|
| 456 |
+
"""
|
| 457 |
+
Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
|
| 458 |
+
"""
|
| 459 |
+
if size is None:
|
| 460 |
+
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
|
| 461 |
+
|
| 462 |
+
INT_MAX = 1610612736
|
| 463 |
+
|
| 464 |
+
input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
|
| 465 |
+
|
| 466 |
+
if input_elements > INT_MAX:
|
| 467 |
+
chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
|
| 468 |
+
interpolated_chunks = [
|
| 469 |
+
nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
|
| 470 |
+
]
|
| 471 |
+
x = torch.cat(interpolated_chunks, dim=0)
|
| 472 |
+
return x.contiguous()
|
| 473 |
+
else:
|
| 474 |
+
return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
|
argus/heads/head_act.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear"):
|
| 6 |
+
"""
|
| 7 |
+
Activate pose parameters with specified activation functions.
|
| 8 |
+
|
| 9 |
+
Args:
|
| 10 |
+
pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, xx]
|
| 11 |
+
trans_act: Activation type for translation component
|
| 12 |
+
quat_act: Activation type for quaternion component
|
| 13 |
+
|
| 14 |
+
Returns:
|
| 15 |
+
Activated pose parameters tensor
|
| 16 |
+
"""
|
| 17 |
+
T = pred_pose_enc[..., :3]
|
| 18 |
+
quat = pred_pose_enc[..., 3:7]
|
| 19 |
+
|
| 20 |
+
T = base_pose_act(T, trans_act)
|
| 21 |
+
quat = base_pose_act(quat, quat_act)
|
| 22 |
+
|
| 23 |
+
# Discard the remaining parameters
|
| 24 |
+
pred_pose_enc = torch.cat([T, quat], dim=-1)
|
| 25 |
+
|
| 26 |
+
return pred_pose_enc
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def base_pose_act(pose_enc, act_type="linear"):
|
| 30 |
+
"""
|
| 31 |
+
Apply basic activation function to pose parameters.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
pose_enc: Tensor containing encoded pose parameters
|
| 35 |
+
act_type: Activation type ("linear", "inv_log", "exp", "relu")
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
Activated pose parameters
|
| 39 |
+
"""
|
| 40 |
+
if act_type == "linear":
|
| 41 |
+
return pose_enc
|
| 42 |
+
elif act_type == "inv_log":
|
| 43 |
+
return inverse_log_transform(pose_enc)
|
| 44 |
+
elif act_type == "exp":
|
| 45 |
+
return torch.exp(pose_enc)
|
| 46 |
+
elif act_type == "relu":
|
| 47 |
+
return F.relu(pose_enc)
|
| 48 |
+
elif act_type == "expp1":
|
| 49 |
+
return 1 + pose_enc.exp()
|
| 50 |
+
elif act_type == "expp0":
|
| 51 |
+
return pose_enc.exp()
|
| 52 |
+
elif act_type == "sigmoid":
|
| 53 |
+
return torch.sigmoid(pose_enc)
|
| 54 |
+
else:
|
| 55 |
+
raise ValueError(f"Unknown act_type: {act_type}")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def activate_head(out, activation="norm_exp", conf_activation="expp1"):
|
| 59 |
+
"""
|
| 60 |
+
Process network output to extract 3D points and confidence values.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
out: Network output tensor (B, C, H, W)
|
| 64 |
+
activation: Activation type for 3D points
|
| 65 |
+
conf_activation: Activation type for confidence values
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
Tuple of (3D points tensor, confidence tensor)
|
| 69 |
+
"""
|
| 70 |
+
# Move channels from last dim to the 4th dimension => (B, H, W, C)
|
| 71 |
+
fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
|
| 72 |
+
|
| 73 |
+
# Split into xyz (first C-1 channels) and confidence (last channel)
|
| 74 |
+
xyz = fmap[:, :, :, :-1]
|
| 75 |
+
conf = fmap[:, :, :, -1]
|
| 76 |
+
|
| 77 |
+
if activation == "norm_exp":
|
| 78 |
+
d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
| 79 |
+
xyz_normed = xyz / d
|
| 80 |
+
pts3d = xyz_normed * torch.expm1(d)
|
| 81 |
+
elif activation == "norm":
|
| 82 |
+
pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
|
| 83 |
+
elif activation == "exp":
|
| 84 |
+
pts3d = torch.exp(xyz)
|
| 85 |
+
elif activation == "relu":
|
| 86 |
+
pts3d = F.relu(xyz)
|
| 87 |
+
elif activation == "inv_log":
|
| 88 |
+
pts3d = inverse_log_transform(xyz)
|
| 89 |
+
elif activation == "xy_inv_log":
|
| 90 |
+
xy, z = xyz.split([2, 1], dim=-1)
|
| 91 |
+
z = inverse_log_transform(z)
|
| 92 |
+
pts3d = torch.cat([xy * z, z], dim=-1)
|
| 93 |
+
elif activation == "sigmoid":
|
| 94 |
+
pts3d = torch.sigmoid(xyz)
|
| 95 |
+
elif activation == "linear":
|
| 96 |
+
pts3d = xyz
|
| 97 |
+
else:
|
| 98 |
+
raise ValueError(f"Unknown activation: {activation}")
|
| 99 |
+
|
| 100 |
+
if conf_activation == "expp1":
|
| 101 |
+
conf_out = 1 + conf.exp()
|
| 102 |
+
elif conf_activation == "expp0":
|
| 103 |
+
conf_out = conf.exp()
|
| 104 |
+
elif conf_activation == "sigmoid":
|
| 105 |
+
conf_out = torch.sigmoid(conf)
|
| 106 |
+
else:
|
| 107 |
+
raise ValueError(f"Unknown conf_activation: {conf_activation}")
|
| 108 |
+
|
| 109 |
+
return pts3d, conf_out
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def inverse_log_transform(y):
|
| 113 |
+
"""
|
| 114 |
+
Apply inverse log transform: sign(y) * (exp(|y|) - 1)
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
y: Input tensor
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
Transformed tensor
|
| 121 |
+
"""
|
| 122 |
+
return torch.sign(y) * (torch.expm1(torch.abs(y)))
|
argus/heads/utils.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
|
| 5 |
+
"""
|
| 6 |
+
Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
|
| 7 |
+
|
| 8 |
+
Args:
|
| 9 |
+
pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
|
| 10 |
+
embed_dim: Output channel dimension for embeddings
|
| 11 |
+
|
| 12 |
+
Returns:
|
| 13 |
+
Tensor of shape (H, W, embed_dim) with positional embeddings
|
| 14 |
+
"""
|
| 15 |
+
H, W, grid_dim = pos_grid.shape
|
| 16 |
+
assert grid_dim == 2
|
| 17 |
+
pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
|
| 18 |
+
|
| 19 |
+
# Process x and y coordinates separately
|
| 20 |
+
emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2]
|
| 21 |
+
emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2]
|
| 22 |
+
|
| 23 |
+
# Combine and reshape
|
| 24 |
+
emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
|
| 25 |
+
|
| 26 |
+
return emb.view(H, W, embed_dim) # [H, W, D]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
|
| 30 |
+
"""
|
| 31 |
+
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
- embed_dim: The embedding dimension.
|
| 35 |
+
- pos: The position to generate the embedding from.
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
- emb: The generated 1D positional embedding.
|
| 39 |
+
"""
|
| 40 |
+
assert embed_dim % 2 == 0
|
| 41 |
+
device = pos.device
|
| 42 |
+
omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
|
| 43 |
+
omega /= embed_dim / 2.0
|
| 44 |
+
omega = 1.0 / omega_0**omega # (D/2,)
|
| 45 |
+
|
| 46 |
+
pos = pos.reshape(-1) # (M,)
|
| 47 |
+
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 48 |
+
|
| 49 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
| 50 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
| 51 |
+
|
| 52 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
| 53 |
+
return emb.float()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# Inspired by https://github.com/microsoft/moge
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def create_uv_grid(
|
| 60 |
+
width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
|
| 61 |
+
) -> torch.Tensor:
|
| 62 |
+
"""
|
| 63 |
+
Create a normalized UV grid of shape (width, height, 2).
|
| 64 |
+
|
| 65 |
+
The grid spans horizontally and vertically according to an aspect ratio,
|
| 66 |
+
ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
|
| 67 |
+
corner is at (x_span, y_span), normalized by the diagonal of the plane.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
width (int): Number of points horizontally.
|
| 71 |
+
height (int): Number of points vertically.
|
| 72 |
+
aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
|
| 73 |
+
dtype (torch.dtype, optional): Data type of the resulting tensor.
|
| 74 |
+
device (torch.device, optional): Device on which the tensor is created.
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
torch.Tensor: A (width, height, 2) tensor of UV coordinates.
|
| 78 |
+
"""
|
| 79 |
+
# Derive aspect ratio if not explicitly provided
|
| 80 |
+
if aspect_ratio is None:
|
| 81 |
+
aspect_ratio = float(width) / float(height)
|
| 82 |
+
|
| 83 |
+
# Compute normalized spans for X and Y
|
| 84 |
+
diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
|
| 85 |
+
span_x = aspect_ratio / diag_factor
|
| 86 |
+
span_y = 1.0 / diag_factor
|
| 87 |
+
|
| 88 |
+
# Establish the linspace boundaries
|
| 89 |
+
left_x = -span_x * (width - 1) / width
|
| 90 |
+
right_x = span_x * (width - 1) / width
|
| 91 |
+
top_y = -span_y * (height - 1) / height
|
| 92 |
+
bottom_y = span_y * (height - 1) / height
|
| 93 |
+
|
| 94 |
+
# Generate 1D coordinates
|
| 95 |
+
x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
|
| 96 |
+
y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
|
| 97 |
+
|
| 98 |
+
# Create 2D meshgrid (width x height) and stack into UV
|
| 99 |
+
uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
|
| 100 |
+
uv_grid = torch.stack((uu, vv), dim=-1)
|
| 101 |
+
|
| 102 |
+
return uv_grid
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def reorder_by_reference(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
|
| 107 |
+
"""Reorder tensor views to place the selected reference view at the first position (index 0),
|
| 108 |
+
while keeping the remaining views in their original order (excluding the reference view).
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
x: Input tensor with shape (B, S, ...) where B = batch size, S = number of views,
|
| 112 |
+
and trailing dimensions can be arbitrary (e.g., N, C for patch tokens).
|
| 113 |
+
b_idx: 1D tensor of shape (B,) containing the index of the reference view for each batch element,
|
| 114 |
+
each value must be in the range [0, S-1].
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
Reordered tensor with the same shape as input, where the reference view is at position 0
|
| 118 |
+
and other views retain their original order (skipping the reference view).
|
| 119 |
+
|
| 120 |
+
Example:
|
| 121 |
+
If B=1, S=5, b_idx=[2], input view order is [0,1,2,3,4],
|
| 122 |
+
output order becomes [2,0,1,3,4].
|
| 123 |
+
"""
|
| 124 |
+
# Extract batch size (B) and number of views (S) from input shape
|
| 125 |
+
B, S = x.shape[0], x.shape[1]
|
| 126 |
+
|
| 127 |
+
# No reordering needed if only one view exists
|
| 128 |
+
if S <= 1:
|
| 129 |
+
return x
|
| 130 |
+
|
| 131 |
+
# Generate base index matrix (B, S): each row is [0, 1, ..., S-1] (same across batches)
|
| 132 |
+
idx = torch.arange(S, device=x.device).expand(B, -1)
|
| 133 |
+
|
| 134 |
+
# Create mask to exclude reference view indices (True for non-reference positions)
|
| 135 |
+
mask = idx != b_idx.unsqueeze(1)
|
| 136 |
+
|
| 137 |
+
# Build reorder indices: [reference_idx] + [all non-reference indices in original order]
|
| 138 |
+
# Reshape non-reference indices to (B, S-1) to match batch dimension, then concatenate
|
| 139 |
+
reorder_idx = torch.cat([b_idx.unsqueeze(1), idx[mask].reshape(B, S-1)], dim=1)
|
| 140 |
+
|
| 141 |
+
# Advanced indexing to reorder: batch indices (B,1) paired with reorder indices (B,S)
|
| 142 |
+
return x[torch.arange(B).unsqueeze(1), reorder_idx]
|
argus/layers/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from .mlp import Mlp
|
| 8 |
+
from .patch_embed import PatchEmbed
|
argus/layers/attention.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import warnings
|
| 13 |
+
|
| 14 |
+
from torch import Tensor
|
| 15 |
+
from torch import nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
XFORMERS_AVAILABLE = False
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Attention(nn.Module):
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
dim: int,
|
| 25 |
+
num_heads: int = 8,
|
| 26 |
+
qkv_bias: bool = True,
|
| 27 |
+
proj_bias: bool = True,
|
| 28 |
+
attn_drop: float = 0.0,
|
| 29 |
+
proj_drop: float = 0.0,
|
| 30 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 31 |
+
qk_norm: bool = False,
|
| 32 |
+
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
|
| 33 |
+
rope=None,
|
| 34 |
+
) -> None:
|
| 35 |
+
super().__init__()
|
| 36 |
+
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
| 37 |
+
self.num_heads = num_heads
|
| 38 |
+
self.head_dim = dim // num_heads
|
| 39 |
+
self.scale = self.head_dim**-0.5
|
| 40 |
+
self.fused_attn = fused_attn
|
| 41 |
+
|
| 42 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 43 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 44 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 45 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 46 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 47 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 48 |
+
self.rope = rope
|
| 49 |
+
|
| 50 |
+
def forward(self, x: Tensor, pos=None) -> Tensor:
|
| 51 |
+
B, N, C = x.shape
|
| 52 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 53 |
+
q, k, v = qkv.unbind(0)
|
| 54 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 55 |
+
|
| 56 |
+
if self.rope is not None:
|
| 57 |
+
q = self.rope(q, pos)
|
| 58 |
+
k = self.rope(k, pos)
|
| 59 |
+
|
| 60 |
+
if self.fused_attn:
|
| 61 |
+
x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0)
|
| 62 |
+
else:
|
| 63 |
+
q = q * self.scale
|
| 64 |
+
attn = q @ k.transpose(-2, -1)
|
| 65 |
+
attn = attn.softmax(dim=-1)
|
| 66 |
+
attn = self.attn_drop(attn)
|
| 67 |
+
x = attn @ v
|
| 68 |
+
|
| 69 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
| 70 |
+
x = self.proj(x)
|
| 71 |
+
x = self.proj_drop(x)
|
| 72 |
+
return x
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class MemEffAttention(Attention):
|
| 76 |
+
def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
|
| 77 |
+
assert pos is None
|
| 78 |
+
if not XFORMERS_AVAILABLE:
|
| 79 |
+
if attn_bias is not None:
|
| 80 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
| 81 |
+
return super().forward(x)
|
| 82 |
+
|
| 83 |
+
B, N, C = x.shape
|
| 84 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 85 |
+
|
| 86 |
+
q, k, v = unbind(qkv, 2)
|
| 87 |
+
|
| 88 |
+
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
| 89 |
+
x = x.reshape([B, N, C])
|
| 90 |
+
|
| 91 |
+
x = self.proj(x)
|
| 92 |
+
x = self.proj_drop(x)
|
| 93 |
+
return x
|
argus/layers/block.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
| 13 |
+
import warnings
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from torch import nn, Tensor
|
| 17 |
+
|
| 18 |
+
from .attention import Attention
|
| 19 |
+
from .drop_path import DropPath
|
| 20 |
+
from .layer_scale import LayerScale
|
| 21 |
+
from .mlp import Mlp
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
XFORMERS_AVAILABLE = False
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Block(nn.Module):
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
dim: int,
|
| 31 |
+
num_heads: int,
|
| 32 |
+
mlp_ratio: float = 4.0,
|
| 33 |
+
qkv_bias: bool = True,
|
| 34 |
+
proj_bias: bool = True,
|
| 35 |
+
ffn_bias: bool = True,
|
| 36 |
+
drop: float = 0.0,
|
| 37 |
+
attn_drop: float = 0.0,
|
| 38 |
+
init_values=None,
|
| 39 |
+
drop_path: float = 0.0,
|
| 40 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 41 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
| 42 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
| 43 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
| 44 |
+
qk_norm: bool = False,
|
| 45 |
+
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
|
| 46 |
+
rope=None,
|
| 47 |
+
) -> None:
|
| 48 |
+
super().__init__()
|
| 49 |
+
|
| 50 |
+
self.norm1 = norm_layer(dim)
|
| 51 |
+
|
| 52 |
+
self.attn = attn_class(
|
| 53 |
+
dim,
|
| 54 |
+
num_heads=num_heads,
|
| 55 |
+
qkv_bias=qkv_bias,
|
| 56 |
+
proj_bias=proj_bias,
|
| 57 |
+
attn_drop=attn_drop,
|
| 58 |
+
proj_drop=drop,
|
| 59 |
+
qk_norm=qk_norm,
|
| 60 |
+
fused_attn=fused_attn,
|
| 61 |
+
rope=rope,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 65 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 66 |
+
|
| 67 |
+
self.norm2 = norm_layer(dim)
|
| 68 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 69 |
+
self.mlp = ffn_layer(
|
| 70 |
+
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias
|
| 71 |
+
)
|
| 72 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 73 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 74 |
+
|
| 75 |
+
self.sample_drop_ratio = drop_path
|
| 76 |
+
|
| 77 |
+
def forward(self, x: Tensor, pos=None) -> Tensor:
|
| 78 |
+
def attn_residual_func(x: Tensor, pos=None) -> Tensor:
|
| 79 |
+
return self.ls1(self.attn(self.norm1(x), pos=pos))
|
| 80 |
+
|
| 81 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
| 82 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 83 |
+
|
| 84 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
| 85 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
| 86 |
+
x = drop_add_residual_stochastic_depth(
|
| 87 |
+
x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio
|
| 88 |
+
)
|
| 89 |
+
x = drop_add_residual_stochastic_depth(
|
| 90 |
+
x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio
|
| 91 |
+
)
|
| 92 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
| 93 |
+
x = x + self.drop_path1(attn_residual_func(x, pos=pos))
|
| 94 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
| 95 |
+
else:
|
| 96 |
+
x = x + attn_residual_func(x, pos=pos)
|
| 97 |
+
x = x + ffn_residual_func(x)
|
| 98 |
+
return x
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def drop_add_residual_stochastic_depth(
|
| 102 |
+
x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None
|
| 103 |
+
) -> Tensor:
|
| 104 |
+
# 1) extract subset using permutation
|
| 105 |
+
b, n, d = x.shape
|
| 106 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 107 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 108 |
+
x_subset = x[brange]
|
| 109 |
+
|
| 110 |
+
# 2) apply residual_func to get residual
|
| 111 |
+
if pos is not None:
|
| 112 |
+
# if necessary, apply rope to the subset
|
| 113 |
+
pos = pos[brange]
|
| 114 |
+
residual = residual_func(x_subset, pos=pos)
|
| 115 |
+
else:
|
| 116 |
+
residual = residual_func(x_subset)
|
| 117 |
+
|
| 118 |
+
x_flat = x.flatten(1)
|
| 119 |
+
residual = residual.flatten(1)
|
| 120 |
+
|
| 121 |
+
residual_scale_factor = b / sample_subset_size
|
| 122 |
+
|
| 123 |
+
# 3) add the residual
|
| 124 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
| 125 |
+
return x_plus_residual.view_as(x)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
| 129 |
+
b, n, d = x.shape
|
| 130 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 131 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 132 |
+
residual_scale_factor = b / sample_subset_size
|
| 133 |
+
return brange, residual_scale_factor
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
| 137 |
+
if scaling_vector is None:
|
| 138 |
+
x_flat = x.flatten(1)
|
| 139 |
+
residual = residual.flatten(1)
|
| 140 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
| 141 |
+
else:
|
| 142 |
+
x_plus_residual = scaled_index_add(
|
| 143 |
+
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
| 144 |
+
)
|
| 145 |
+
return x_plus_residual
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def get_attn_bias_and_cat(x_list, branges=None):
|
| 152 |
+
"""
|
| 153 |
+
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
| 154 |
+
"""
|
| 155 |
+
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
|
| 156 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
| 157 |
+
if all_shapes not in attn_bias_cache.keys():
|
| 158 |
+
seqlens = []
|
| 159 |
+
for b, x in zip(batch_sizes, x_list):
|
| 160 |
+
for _ in range(b):
|
| 161 |
+
seqlens.append(x.shape[1])
|
| 162 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
| 163 |
+
attn_bias._batch_sizes = batch_sizes
|
| 164 |
+
attn_bias_cache[all_shapes] = attn_bias
|
| 165 |
+
|
| 166 |
+
if branges is not None:
|
| 167 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
|
| 168 |
+
else:
|
| 169 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
| 170 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
| 171 |
+
|
| 172 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def drop_add_residual_stochastic_depth_list(
|
| 176 |
+
x_list: List[Tensor],
|
| 177 |
+
residual_func: Callable[[Tensor, Any], Tensor],
|
| 178 |
+
sample_drop_ratio: float = 0.0,
|
| 179 |
+
scaling_vector=None,
|
| 180 |
+
) -> Tensor:
|
| 181 |
+
# 1) generate random set of indices for dropping samples in the batch
|
| 182 |
+
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
|
| 183 |
+
branges = [s[0] for s in branges_scales]
|
| 184 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
| 185 |
+
|
| 186 |
+
# 2) get attention bias and index+concat the tensors
|
| 187 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
| 188 |
+
|
| 189 |
+
# 3) apply residual_func to get residual, and split the result
|
| 190 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
| 191 |
+
|
| 192 |
+
outputs = []
|
| 193 |
+
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
|
| 194 |
+
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
|
| 195 |
+
return outputs
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class NestedTensorBlock(Block):
|
| 199 |
+
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
|
| 200 |
+
"""
|
| 201 |
+
x_list contains a list of tensors to nest together and run
|
| 202 |
+
"""
|
| 203 |
+
assert isinstance(self.attn, MemEffAttention)
|
| 204 |
+
|
| 205 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
| 206 |
+
|
| 207 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 208 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
| 209 |
+
|
| 210 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 211 |
+
return self.mlp(self.norm2(x))
|
| 212 |
+
|
| 213 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
| 214 |
+
x_list,
|
| 215 |
+
residual_func=attn_residual_func,
|
| 216 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 217 |
+
scaling_vector=(self.ls1.gamma if isinstance(self.ls1, LayerScale) else None),
|
| 218 |
+
)
|
| 219 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
| 220 |
+
x_list,
|
| 221 |
+
residual_func=ffn_residual_func,
|
| 222 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 223 |
+
scaling_vector=(self.ls2.gamma if isinstance(self.ls1, LayerScale) else None),
|
| 224 |
+
)
|
| 225 |
+
return x_list
|
| 226 |
+
else:
|
| 227 |
+
|
| 228 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 229 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
| 230 |
+
|
| 231 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 232 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 233 |
+
|
| 234 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
| 235 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
| 236 |
+
x = x + ffn_residual_func(x)
|
| 237 |
+
return attn_bias.split(x)
|
| 238 |
+
|
| 239 |
+
def forward(self, x_or_x_list):
|
| 240 |
+
if isinstance(x_or_x_list, Tensor):
|
| 241 |
+
return super().forward(x_or_x_list)
|
| 242 |
+
elif isinstance(x_or_x_list, list):
|
| 243 |
+
if not XFORMERS_AVAILABLE:
|
| 244 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
| 245 |
+
return self.forward_nested(x_or_x_list)
|
| 246 |
+
else:
|
| 247 |
+
raise AssertionError
|
argus/layers/drop_path.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
| 15 |
+
if drop_prob == 0.0 or not training:
|
| 16 |
+
return x
|
| 17 |
+
keep_prob = 1 - drop_prob
|
| 18 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 19 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 20 |
+
if keep_prob > 0.0:
|
| 21 |
+
random_tensor.div_(keep_prob)
|
| 22 |
+
output = x * random_tensor
|
| 23 |
+
return output
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DropPath(nn.Module):
|
| 27 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 28 |
+
|
| 29 |
+
def __init__(self, drop_prob=None):
|
| 30 |
+
super(DropPath, self).__init__()
|
| 31 |
+
self.drop_prob = drop_prob
|
| 32 |
+
|
| 33 |
+
def forward(self, x):
|
| 34 |
+
return drop_path(x, self.drop_prob, self.training)
|
argus/layers/layer_scale.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
| 7 |
+
|
| 8 |
+
from typing import Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LayerScale(nn.Module):
|
| 16 |
+
def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None:
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.inplace = inplace
|
| 19 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
| 20 |
+
|
| 21 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 22 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
argus/layers/mlp.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from typing import Callable, Optional
|
| 12 |
+
|
| 13 |
+
from torch import Tensor, nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Mlp(nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
in_features: int,
|
| 20 |
+
hidden_features: Optional[int] = None,
|
| 21 |
+
out_features: Optional[int] = None,
|
| 22 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 23 |
+
drop: float = 0.0,
|
| 24 |
+
bias: bool = True,
|
| 25 |
+
) -> None:
|
| 26 |
+
super().__init__()
|
| 27 |
+
out_features = out_features or in_features
|
| 28 |
+
hidden_features = hidden_features or in_features
|
| 29 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
| 30 |
+
self.act = act_layer()
|
| 31 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
| 32 |
+
self.drop = nn.Dropout(drop)
|
| 33 |
+
|
| 34 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 35 |
+
x = self.fc1(x)
|
| 36 |
+
x = self.act(x)
|
| 37 |
+
x = self.drop(x)
|
| 38 |
+
x = self.fc2(x)
|
| 39 |
+
x = self.drop(x)
|
| 40 |
+
return x
|
argus/layers/patch_embed.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
from typing import Callable, Optional, Tuple, Union
|
| 11 |
+
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def make_2tuple(x):
|
| 17 |
+
if isinstance(x, tuple):
|
| 18 |
+
assert len(x) == 2
|
| 19 |
+
return x
|
| 20 |
+
|
| 21 |
+
assert isinstance(x, int)
|
| 22 |
+
return (x, x)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class PatchEmbed(nn.Module):
|
| 26 |
+
"""
|
| 27 |
+
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
img_size: Image size.
|
| 31 |
+
patch_size: Patch token size.
|
| 32 |
+
in_chans: Number of input image channels.
|
| 33 |
+
embed_dim: Number of linear projection output channels.
|
| 34 |
+
norm_layer: Normalization layer.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
| 40 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
| 41 |
+
in_chans: int = 3,
|
| 42 |
+
embed_dim: int = 768,
|
| 43 |
+
norm_layer: Optional[Callable] = None,
|
| 44 |
+
flatten_embedding: bool = True,
|
| 45 |
+
) -> None:
|
| 46 |
+
super().__init__()
|
| 47 |
+
|
| 48 |
+
image_HW = make_2tuple(img_size)
|
| 49 |
+
patch_HW = make_2tuple(patch_size)
|
| 50 |
+
patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
|
| 51 |
+
|
| 52 |
+
self.img_size = image_HW
|
| 53 |
+
self.patch_size = patch_HW
|
| 54 |
+
self.patches_resolution = patch_grid_size
|
| 55 |
+
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
|
| 56 |
+
|
| 57 |
+
self.in_chans = in_chans
|
| 58 |
+
self.embed_dim = embed_dim
|
| 59 |
+
|
| 60 |
+
self.flatten_embedding = flatten_embedding
|
| 61 |
+
|
| 62 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
|
| 63 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 64 |
+
|
| 65 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 66 |
+
_, _, H, W = x.shape
|
| 67 |
+
patch_H, patch_W = self.patch_size
|
| 68 |
+
|
| 69 |
+
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
|
| 70 |
+
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
|
| 71 |
+
|
| 72 |
+
x = self.proj(x) # B C H W
|
| 73 |
+
H, W = x.size(2), x.size(3)
|
| 74 |
+
x = x.flatten(2).transpose(1, 2) # B HW C
|
| 75 |
+
x = self.norm(x)
|
| 76 |
+
if not self.flatten_embedding:
|
| 77 |
+
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
|
| 78 |
+
return x
|
| 79 |
+
|
| 80 |
+
def flops(self) -> float:
|
| 81 |
+
Ho, Wo = self.patches_resolution
|
| 82 |
+
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
| 83 |
+
if self.norm is not None:
|
| 84 |
+
flops += Ho * Wo * self.embed_dim
|
| 85 |
+
return flops
|
argus/layers/rope.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Implementation of 2D Rotary Position Embeddings (RoPE).
|
| 8 |
+
|
| 9 |
+
# This module provides a clean implementation of 2D Rotary Position Embeddings,
|
| 10 |
+
# which extends the original RoPE concept to handle 2D spatial positions.
|
| 11 |
+
|
| 12 |
+
# Inspired by:
|
| 13 |
+
# https://github.com/meta-llama/codellama/blob/main/llama/model.py
|
| 14 |
+
# https://github.com/naver-ai/rope-vit
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
from typing import Dict, Tuple
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class PositionGetter:
|
| 25 |
+
"""Generates and caches 2D spatial positions for patches in a grid.
|
| 26 |
+
|
| 27 |
+
This class efficiently manages the generation of spatial coordinates for patches
|
| 28 |
+
in a 2D grid, caching results to avoid redundant computations.
|
| 29 |
+
|
| 30 |
+
Attributes:
|
| 31 |
+
position_cache: Dictionary storing precomputed position tensors for different
|
| 32 |
+
grid dimensions.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(self):
|
| 36 |
+
"""Initializes the position generator with an empty cache."""
|
| 37 |
+
self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
|
| 38 |
+
|
| 39 |
+
def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
|
| 40 |
+
"""Generates spatial positions for a batch of patches.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
batch_size: Number of samples in the batch.
|
| 44 |
+
height: Height of the grid in patches.
|
| 45 |
+
width: Width of the grid in patches.
|
| 46 |
+
device: Target device for the position tensor.
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
|
| 50 |
+
for each position in the grid, repeated for each batch item.
|
| 51 |
+
"""
|
| 52 |
+
if (height, width) not in self.position_cache:
|
| 53 |
+
y_coords = torch.arange(height, device=device)
|
| 54 |
+
x_coords = torch.arange(width, device=device)
|
| 55 |
+
positions = torch.cartesian_prod(y_coords, x_coords)
|
| 56 |
+
self.position_cache[height, width] = positions
|
| 57 |
+
|
| 58 |
+
cached_positions = self.position_cache[height, width]
|
| 59 |
+
return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class RotaryPositionEmbedding2D(nn.Module):
|
| 63 |
+
"""2D Rotary Position Embedding implementation.
|
| 64 |
+
|
| 65 |
+
This module applies rotary position embeddings to input tokens based on their
|
| 66 |
+
2D spatial positions. It handles the position-dependent rotation of features
|
| 67 |
+
separately for vertical and horizontal dimensions.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
frequency: Base frequency for the position embeddings. Default: 100.0
|
| 71 |
+
scaling_factor: Scaling factor for frequency computation. Default: 1.0
|
| 72 |
+
|
| 73 |
+
Attributes:
|
| 74 |
+
base_frequency: Base frequency for computing position embeddings.
|
| 75 |
+
scaling_factor: Factor to scale the computed frequencies.
|
| 76 |
+
frequency_cache: Cache for storing precomputed frequency components.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
|
| 80 |
+
"""Initializes the 2D RoPE module."""
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.base_frequency = frequency
|
| 83 |
+
self.scaling_factor = scaling_factor
|
| 84 |
+
self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
|
| 85 |
+
|
| 86 |
+
def _compute_frequency_components(
|
| 87 |
+
self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
|
| 88 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 89 |
+
"""Computes frequency components for rotary embeddings.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
dim: Feature dimension (must be even).
|
| 93 |
+
seq_len: Maximum sequence length.
|
| 94 |
+
device: Target device for computations.
|
| 95 |
+
dtype: Data type for the computed tensors.
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
Tuple of (cosine, sine) tensors for frequency components.
|
| 99 |
+
"""
|
| 100 |
+
cache_key = (dim, seq_len, device, dtype)
|
| 101 |
+
if cache_key not in self.frequency_cache:
|
| 102 |
+
# Compute frequency bands
|
| 103 |
+
exponents = torch.arange(0, dim, 2, device=device).float() / dim
|
| 104 |
+
inv_freq = 1.0 / (self.base_frequency**exponents)
|
| 105 |
+
|
| 106 |
+
# Generate position-dependent frequencies
|
| 107 |
+
positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
| 108 |
+
angles = torch.einsum("i,j->ij", positions, inv_freq)
|
| 109 |
+
|
| 110 |
+
# Compute and cache frequency components
|
| 111 |
+
angles = angles.to(dtype)
|
| 112 |
+
angles = torch.cat((angles, angles), dim=-1)
|
| 113 |
+
cos_components = angles.cos().to(dtype)
|
| 114 |
+
sin_components = angles.sin().to(dtype)
|
| 115 |
+
self.frequency_cache[cache_key] = (cos_components, sin_components)
|
| 116 |
+
|
| 117 |
+
return self.frequency_cache[cache_key]
|
| 118 |
+
|
| 119 |
+
@staticmethod
|
| 120 |
+
def _rotate_features(x: torch.Tensor) -> torch.Tensor:
|
| 121 |
+
"""Performs feature rotation by splitting and recombining feature dimensions.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
x: Input tensor to rotate.
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
Rotated feature tensor.
|
| 128 |
+
"""
|
| 129 |
+
feature_dim = x.shape[-1]
|
| 130 |
+
x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
|
| 131 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 132 |
+
|
| 133 |
+
def _apply_1d_rope(
|
| 134 |
+
self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
|
| 135 |
+
) -> torch.Tensor:
|
| 136 |
+
"""Applies 1D rotary position embeddings along one dimension.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
tokens: Input token features.
|
| 140 |
+
positions: Position indices.
|
| 141 |
+
cos_comp: Cosine components for rotation.
|
| 142 |
+
sin_comp: Sine components for rotation.
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
Tokens with applied rotary position embeddings.
|
| 146 |
+
"""
|
| 147 |
+
# Embed positions with frequency components
|
| 148 |
+
cos = F.embedding(positions, cos_comp)[:, None, :, :]
|
| 149 |
+
sin = F.embedding(positions, sin_comp)[:, None, :, :]
|
| 150 |
+
|
| 151 |
+
# Apply rotation
|
| 152 |
+
return (tokens * cos) + (self._rotate_features(tokens) * sin)
|
| 153 |
+
|
| 154 |
+
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
|
| 155 |
+
"""Applies 2D rotary position embeddings to input tokens.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
|
| 159 |
+
The feature dimension (dim) must be divisible by 4.
|
| 160 |
+
positions: Position tensor of shape (batch_size, n_tokens, 2) containing
|
| 161 |
+
the y and x coordinates for each token.
|
| 162 |
+
|
| 163 |
+
Returns:
|
| 164 |
+
Tensor of same shape as input with applied 2D rotary position embeddings.
|
| 165 |
+
|
| 166 |
+
Raises:
|
| 167 |
+
AssertionError: If input dimensions are invalid or positions are malformed.
|
| 168 |
+
"""
|
| 169 |
+
# Validate inputs
|
| 170 |
+
assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
|
| 171 |
+
assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
|
| 172 |
+
|
| 173 |
+
# Compute feature dimension for each spatial direction
|
| 174 |
+
feature_dim = tokens.size(-1) // 2
|
| 175 |
+
|
| 176 |
+
# Get frequency components
|
| 177 |
+
max_position = int(positions.max()) + 1
|
| 178 |
+
cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
|
| 179 |
+
|
| 180 |
+
# Split features for vertical and horizontal processing
|
| 181 |
+
vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
|
| 182 |
+
|
| 183 |
+
# Apply RoPE separately for each dimension
|
| 184 |
+
vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
|
| 185 |
+
horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
|
| 186 |
+
|
| 187 |
+
# Combine processed features
|
| 188 |
+
return torch.cat((vertical_features, horizontal_features), dim=-1)
|
argus/layers/swiglu_ffn.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import Callable, Optional
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
from torch import Tensor, nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SwiGLUFFN(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
in_features: int,
|
| 18 |
+
hidden_features: Optional[int] = None,
|
| 19 |
+
out_features: Optional[int] = None,
|
| 20 |
+
act_layer: Callable[..., nn.Module] = None,
|
| 21 |
+
drop: float = 0.0,
|
| 22 |
+
bias: bool = True,
|
| 23 |
+
) -> None:
|
| 24 |
+
super().__init__()
|
| 25 |
+
out_features = out_features or in_features
|
| 26 |
+
hidden_features = hidden_features or in_features
|
| 27 |
+
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
|
| 28 |
+
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
|
| 29 |
+
|
| 30 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 31 |
+
x12 = self.w12(x)
|
| 32 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
| 33 |
+
hidden = F.silu(x1) * x2
|
| 34 |
+
return self.w3(hidden)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 38 |
+
# try:
|
| 39 |
+
# if XFORMERS_ENABLED:
|
| 40 |
+
# from xformers.ops import SwiGLU
|
| 41 |
+
|
| 42 |
+
# XFORMERS_AVAILABLE = True
|
| 43 |
+
# warnings.warn("xFormers is available (SwiGLU)")
|
| 44 |
+
# else:
|
| 45 |
+
# warnings.warn("xFormers is disabled (SwiGLU)")
|
| 46 |
+
# raise ImportError
|
| 47 |
+
# except ImportError:
|
| 48 |
+
SwiGLU = SwiGLUFFN
|
| 49 |
+
XFORMERS_AVAILABLE = False
|
| 50 |
+
|
| 51 |
+
# warnings.warn("xFormers is not available (SwiGLU)")
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class SwiGLUFFNFused(SwiGLU):
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
in_features: int,
|
| 58 |
+
hidden_features: Optional[int] = None,
|
| 59 |
+
out_features: Optional[int] = None,
|
| 60 |
+
act_layer: Callable[..., nn.Module] = None,
|
| 61 |
+
drop: float = 0.0,
|
| 62 |
+
bias: bool = True,
|
| 63 |
+
) -> None:
|
| 64 |
+
out_features = out_features or in_features
|
| 65 |
+
hidden_features = hidden_features or in_features
|
| 66 |
+
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
| 67 |
+
super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias)
|
argus/layers/vision_transformer.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
from functools import partial
|
| 11 |
+
import math
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Sequence, Tuple, Union, Callable
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from torch.utils.checkpoint import checkpoint
|
| 18 |
+
from torch.nn.init import trunc_normal_
|
| 19 |
+
from .mlp import Mlp
|
| 20 |
+
from .patch_embed import PatchEmbed
|
| 21 |
+
from .swiglu_ffn import SwiGLUFFNFused
|
| 22 |
+
from .attention import MemEffAttention
|
| 23 |
+
from .block import NestedTensorBlock as Block
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger("dinov2")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
| 29 |
+
if not depth_first and include_root:
|
| 30 |
+
fn(module=module, name=name)
|
| 31 |
+
for child_name, child_module in module.named_children():
|
| 32 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
| 33 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
| 34 |
+
if depth_first and include_root:
|
| 35 |
+
fn(module=module, name=name)
|
| 36 |
+
return module
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class BlockChunk(nn.ModuleList):
|
| 40 |
+
def forward(self, x):
|
| 41 |
+
for b in self:
|
| 42 |
+
x = b(x)
|
| 43 |
+
return x
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class DinoVisionTransformer(nn.Module):
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
img_size=224,
|
| 50 |
+
patch_size=16,
|
| 51 |
+
in_chans=3,
|
| 52 |
+
embed_dim=768,
|
| 53 |
+
depth=12,
|
| 54 |
+
num_heads=12,
|
| 55 |
+
mlp_ratio=4.0,
|
| 56 |
+
qkv_bias=True,
|
| 57 |
+
ffn_bias=True,
|
| 58 |
+
proj_bias=True,
|
| 59 |
+
drop_path_rate=0.0,
|
| 60 |
+
drop_path_uniform=False,
|
| 61 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
| 62 |
+
embed_layer=PatchEmbed,
|
| 63 |
+
act_layer=nn.GELU,
|
| 64 |
+
block_fn=Block,
|
| 65 |
+
ffn_layer="mlp",
|
| 66 |
+
block_chunks=1,
|
| 67 |
+
num_register_tokens=0,
|
| 68 |
+
interpolate_antialias=False,
|
| 69 |
+
interpolate_offset=0.1,
|
| 70 |
+
qk_norm=False,
|
| 71 |
+
):
|
| 72 |
+
"""
|
| 73 |
+
Args:
|
| 74 |
+
img_size (int, tuple): input image size
|
| 75 |
+
patch_size (int, tuple): patch size
|
| 76 |
+
in_chans (int): number of input channels
|
| 77 |
+
embed_dim (int): embedding dimension
|
| 78 |
+
depth (int): depth of transformer
|
| 79 |
+
num_heads (int): number of attention heads
|
| 80 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
| 81 |
+
qkv_bias (bool): enable bias for qkv if True
|
| 82 |
+
proj_bias (bool): enable bias for proj in attn if True
|
| 83 |
+
ffn_bias (bool): enable bias for ffn if True
|
| 84 |
+
drop_path_rate (float): stochastic depth rate
|
| 85 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
| 86 |
+
weight_init (str): weight init scheme
|
| 87 |
+
init_values (float): layer-scale init values
|
| 88 |
+
embed_layer (nn.Module): patch embedding layer
|
| 89 |
+
act_layer (nn.Module): MLP activation layer
|
| 90 |
+
block_fn (nn.Module): transformer block class
|
| 91 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
| 92 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
| 93 |
+
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
|
| 94 |
+
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
|
| 95 |
+
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
|
| 96 |
+
"""
|
| 97 |
+
super().__init__()
|
| 98 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
| 99 |
+
|
| 100 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
| 101 |
+
self.num_tokens = 1
|
| 102 |
+
self.n_blocks = depth
|
| 103 |
+
self.num_heads = num_heads
|
| 104 |
+
self.patch_size = patch_size
|
| 105 |
+
self.num_register_tokens = num_register_tokens
|
| 106 |
+
self.interpolate_antialias = interpolate_antialias
|
| 107 |
+
self.interpolate_offset = interpolate_offset
|
| 108 |
+
self.use_reentrant = False # hardcoded to False
|
| 109 |
+
|
| 110 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
| 111 |
+
num_patches = self.patch_embed.num_patches
|
| 112 |
+
|
| 113 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 114 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
| 115 |
+
assert num_register_tokens >= 0
|
| 116 |
+
self.register_tokens = (
|
| 117 |
+
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
if drop_path_uniform is True:
|
| 121 |
+
dpr = [drop_path_rate] * depth
|
| 122 |
+
else:
|
| 123 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 124 |
+
|
| 125 |
+
if ffn_layer == "mlp":
|
| 126 |
+
logger.info("using MLP layer as FFN")
|
| 127 |
+
ffn_layer = Mlp
|
| 128 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
| 129 |
+
logger.info("using SwiGLU layer as FFN")
|
| 130 |
+
ffn_layer = SwiGLUFFNFused
|
| 131 |
+
elif ffn_layer == "identity":
|
| 132 |
+
logger.info("using Identity layer as FFN")
|
| 133 |
+
|
| 134 |
+
def f(*args, **kwargs):
|
| 135 |
+
return nn.Identity()
|
| 136 |
+
|
| 137 |
+
ffn_layer = f
|
| 138 |
+
else:
|
| 139 |
+
raise NotImplementedError
|
| 140 |
+
|
| 141 |
+
blocks_list = [
|
| 142 |
+
block_fn(
|
| 143 |
+
dim=embed_dim,
|
| 144 |
+
num_heads=num_heads,
|
| 145 |
+
mlp_ratio=mlp_ratio,
|
| 146 |
+
qkv_bias=qkv_bias,
|
| 147 |
+
proj_bias=proj_bias,
|
| 148 |
+
ffn_bias=ffn_bias,
|
| 149 |
+
drop_path=dpr[i],
|
| 150 |
+
norm_layer=norm_layer,
|
| 151 |
+
act_layer=act_layer,
|
| 152 |
+
ffn_layer=ffn_layer,
|
| 153 |
+
init_values=init_values,
|
| 154 |
+
qk_norm=qk_norm,
|
| 155 |
+
)
|
| 156 |
+
for i in range(depth)
|
| 157 |
+
]
|
| 158 |
+
if block_chunks > 0:
|
| 159 |
+
self.chunked_blocks = True
|
| 160 |
+
chunked_blocks = []
|
| 161 |
+
chunksize = depth // block_chunks
|
| 162 |
+
for i in range(0, depth, chunksize):
|
| 163 |
+
# this is to keep the block index consistent if we chunk the block list
|
| 164 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
| 165 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
| 166 |
+
else:
|
| 167 |
+
self.chunked_blocks = False
|
| 168 |
+
self.blocks = nn.ModuleList(blocks_list)
|
| 169 |
+
|
| 170 |
+
self.norm = norm_layer(embed_dim)
|
| 171 |
+
self.head = nn.Identity()
|
| 172 |
+
|
| 173 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
| 174 |
+
|
| 175 |
+
self.init_weights()
|
| 176 |
+
|
| 177 |
+
def init_weights(self):
|
| 178 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
| 179 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
| 180 |
+
if self.register_tokens is not None:
|
| 181 |
+
nn.init.normal_(self.register_tokens, std=1e-6)
|
| 182 |
+
named_apply(init_weights_vit_timm, self)
|
| 183 |
+
|
| 184 |
+
def interpolate_pos_encoding(self, x, w, h):
|
| 185 |
+
previous_dtype = x.dtype
|
| 186 |
+
npatch = x.shape[1] - 1
|
| 187 |
+
N = self.pos_embed.shape[1] - 1
|
| 188 |
+
if npatch == N and w == h:
|
| 189 |
+
return self.pos_embed
|
| 190 |
+
pos_embed = self.pos_embed.float()
|
| 191 |
+
class_pos_embed = pos_embed[:, 0]
|
| 192 |
+
patch_pos_embed = pos_embed[:, 1:]
|
| 193 |
+
dim = x.shape[-1]
|
| 194 |
+
w0 = w // self.patch_size
|
| 195 |
+
h0 = h // self.patch_size
|
| 196 |
+
M = int(math.sqrt(N)) # Recover the number of patches in each dimension
|
| 197 |
+
assert N == M * M
|
| 198 |
+
kwargs = {}
|
| 199 |
+
if self.interpolate_offset:
|
| 200 |
+
# Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
|
| 201 |
+
# Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
|
| 202 |
+
sx = float(w0 + self.interpolate_offset) / M
|
| 203 |
+
sy = float(h0 + self.interpolate_offset) / M
|
| 204 |
+
kwargs["scale_factor"] = (sx, sy)
|
| 205 |
+
else:
|
| 206 |
+
# Simply specify an output size instead of a scale factor
|
| 207 |
+
kwargs["size"] = (w0, h0)
|
| 208 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 209 |
+
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
| 210 |
+
mode="bicubic",
|
| 211 |
+
antialias=self.interpolate_antialias,
|
| 212 |
+
**kwargs,
|
| 213 |
+
)
|
| 214 |
+
assert (w0, h0) == patch_pos_embed.shape[-2:]
|
| 215 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 216 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
| 217 |
+
|
| 218 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
| 219 |
+
B, nc, w, h = x.shape
|
| 220 |
+
x = self.patch_embed(x)
|
| 221 |
+
if masks is not None:
|
| 222 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
| 223 |
+
|
| 224 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
| 225 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
| 226 |
+
|
| 227 |
+
if self.register_tokens is not None:
|
| 228 |
+
x = torch.cat((x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1)
|
| 229 |
+
|
| 230 |
+
return x
|
| 231 |
+
|
| 232 |
+
def forward_features_list(self, x_list, masks_list):
|
| 233 |
+
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
| 234 |
+
|
| 235 |
+
for blk in self.blocks:
|
| 236 |
+
if self.training:
|
| 237 |
+
x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
|
| 238 |
+
else:
|
| 239 |
+
x = blk(x)
|
| 240 |
+
|
| 241 |
+
all_x = x
|
| 242 |
+
output = []
|
| 243 |
+
for x, masks in zip(all_x, masks_list):
|
| 244 |
+
x_norm = self.norm(x)
|
| 245 |
+
output.append(
|
| 246 |
+
{
|
| 247 |
+
"x_norm_clstoken": x_norm[:, 0],
|
| 248 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
| 249 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
| 250 |
+
"x_prenorm": x,
|
| 251 |
+
"masks": masks,
|
| 252 |
+
}
|
| 253 |
+
)
|
| 254 |
+
return output
|
| 255 |
+
|
| 256 |
+
def forward_features(self, x, masks=None):
|
| 257 |
+
if isinstance(x, list):
|
| 258 |
+
return self.forward_features_list(x, masks)
|
| 259 |
+
|
| 260 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
| 261 |
+
|
| 262 |
+
for blk in self.blocks:
|
| 263 |
+
if self.training:
|
| 264 |
+
x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
|
| 265 |
+
else:
|
| 266 |
+
x = blk(x)
|
| 267 |
+
|
| 268 |
+
x_norm = self.norm(x)
|
| 269 |
+
return {
|
| 270 |
+
"x_norm_clstoken": x_norm[:, 0],
|
| 271 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
| 272 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
| 273 |
+
"x_prenorm": x,
|
| 274 |
+
"masks": masks,
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
| 278 |
+
x = self.prepare_tokens_with_masks(x)
|
| 279 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
| 280 |
+
output, total_block_len = [], len(self.blocks)
|
| 281 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 282 |
+
for i, blk in enumerate(self.blocks):
|
| 283 |
+
x = blk(x)
|
| 284 |
+
if i in blocks_to_take:
|
| 285 |
+
output.append(x)
|
| 286 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 287 |
+
return output
|
| 288 |
+
|
| 289 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
| 290 |
+
x = self.prepare_tokens_with_masks(x)
|
| 291 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
| 292 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
| 293 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 294 |
+
for block_chunk in self.blocks:
|
| 295 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
| 296 |
+
x = blk(x)
|
| 297 |
+
if i in blocks_to_take:
|
| 298 |
+
output.append(x)
|
| 299 |
+
i += 1
|
| 300 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 301 |
+
return output
|
| 302 |
+
|
| 303 |
+
def get_intermediate_layers(
|
| 304 |
+
self,
|
| 305 |
+
x: torch.Tensor,
|
| 306 |
+
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
| 307 |
+
reshape: bool = False,
|
| 308 |
+
return_class_token: bool = False,
|
| 309 |
+
norm=True,
|
| 310 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
| 311 |
+
if self.chunked_blocks:
|
| 312 |
+
outputs = self._get_intermediate_layers_chunked(x, n)
|
| 313 |
+
else:
|
| 314 |
+
outputs = self._get_intermediate_layers_not_chunked(x, n)
|
| 315 |
+
if norm:
|
| 316 |
+
outputs = [self.norm(out) for out in outputs]
|
| 317 |
+
class_tokens = [out[:, 0] for out in outputs]
|
| 318 |
+
outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
|
| 319 |
+
if reshape:
|
| 320 |
+
B, _, w, h = x.shape
|
| 321 |
+
outputs = [
|
| 322 |
+
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
| 323 |
+
for out in outputs
|
| 324 |
+
]
|
| 325 |
+
if return_class_token:
|
| 326 |
+
return tuple(zip(outputs, class_tokens))
|
| 327 |
+
return tuple(outputs)
|
| 328 |
+
|
| 329 |
+
def forward(self, *args, is_training=True, **kwargs):
|
| 330 |
+
ret = self.forward_features(*args, **kwargs)
|
| 331 |
+
if is_training:
|
| 332 |
+
return ret
|
| 333 |
+
else:
|
| 334 |
+
return self.head(ret["x_norm_clstoken"])
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
| 338 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
| 339 |
+
if isinstance(module, nn.Linear):
|
| 340 |
+
trunc_normal_(module.weight, std=0.02)
|
| 341 |
+
if module.bias is not None:
|
| 342 |
+
nn.init.zeros_(module.bias)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
|
| 346 |
+
model = DinoVisionTransformer(
|
| 347 |
+
patch_size=patch_size,
|
| 348 |
+
embed_dim=384,
|
| 349 |
+
depth=12,
|
| 350 |
+
num_heads=6,
|
| 351 |
+
mlp_ratio=4,
|
| 352 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 353 |
+
num_register_tokens=num_register_tokens,
|
| 354 |
+
**kwargs,
|
| 355 |
+
)
|
| 356 |
+
return model
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
|
| 360 |
+
model = DinoVisionTransformer(
|
| 361 |
+
patch_size=patch_size,
|
| 362 |
+
embed_dim=768,
|
| 363 |
+
depth=12,
|
| 364 |
+
num_heads=12,
|
| 365 |
+
mlp_ratio=4,
|
| 366 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 367 |
+
num_register_tokens=num_register_tokens,
|
| 368 |
+
**kwargs,
|
| 369 |
+
)
|
| 370 |
+
return model
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
|
| 374 |
+
model = DinoVisionTransformer(
|
| 375 |
+
patch_size=patch_size,
|
| 376 |
+
embed_dim=1024,
|
| 377 |
+
depth=24,
|
| 378 |
+
num_heads=16,
|
| 379 |
+
mlp_ratio=4,
|
| 380 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 381 |
+
num_register_tokens=num_register_tokens,
|
| 382 |
+
**kwargs,
|
| 383 |
+
)
|
| 384 |
+
return model
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
|
| 388 |
+
"""
|
| 389 |
+
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
|
| 390 |
+
"""
|
| 391 |
+
model = DinoVisionTransformer(
|
| 392 |
+
patch_size=patch_size,
|
| 393 |
+
embed_dim=1536,
|
| 394 |
+
depth=40,
|
| 395 |
+
num_heads=24,
|
| 396 |
+
mlp_ratio=4,
|
| 397 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 398 |
+
num_register_tokens=num_register_tokens,
|
| 399 |
+
**kwargs,
|
| 400 |
+
)
|
| 401 |
+
return model
|
argus/models/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 Realsee. All rights reserved.
|
| 2 |
+
# Licensed under the Apache License, Version 2.0.
|
argus/models/aggregator.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch.utils.checkpoint import checkpoint
|
| 6 |
+
from typing import Optional, Tuple, Union, List, Dict, Any
|
| 7 |
+
from argus.layers import Mlp
|
| 8 |
+
from argus.layers import PatchEmbed
|
| 9 |
+
from argus.layers.block import Block
|
| 10 |
+
from argus.layers.rope import RotaryPositionEmbedding2D, PositionGetter
|
| 11 |
+
from argus.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
|
| 12 |
+
from argus.heads.utils import reorder_by_reference
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
_RESNET_MEAN = [0.485, 0.456, 0.406]
|
| 17 |
+
_RESNET_STD = [0.229, 0.224, 0.225]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Aggregator(nn.Module):
|
| 21 |
+
"""
|
| 22 |
+
Args:
|
| 23 |
+
img_size (int): Image size in pixels.
|
| 24 |
+
patch_size (int): Size of each patch for PatchEmbed.
|
| 25 |
+
embed_dim (int): Dimension of the token embeddings.
|
| 26 |
+
depth (int): Number of blocks.
|
| 27 |
+
num_heads (int): Number of attention heads.
|
| 28 |
+
mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
|
| 29 |
+
num_register_tokens (int): Number of register tokens.
|
| 30 |
+
block_fn (nn.Module): The block type used for attention (Block by default).
|
| 31 |
+
qkv_bias (bool): Whether to include bias in QKV projections.
|
| 32 |
+
proj_bias (bool): Whether to include bias in the output projection.
|
| 33 |
+
ffn_bias (bool): Whether to include bias in MLP layers.
|
| 34 |
+
patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
|
| 35 |
+
aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
|
| 36 |
+
aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
|
| 37 |
+
qk_norm (bool): Whether to apply QK normalization.
|
| 38 |
+
rope_freq (int): Base frequency for rotary embedding. -1 to disable.
|
| 39 |
+
init_values (float): Init scale for layer scale.
|
| 40 |
+
reorder_by_learning_ref (bool): Whether to reorder features by learning reference view index.
|
| 41 |
+
ref_aa_block_num (int): Number of aa blocks for reference view learning.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
img_size=518,
|
| 47 |
+
patch_size=14,
|
| 48 |
+
embed_dim=1024,
|
| 49 |
+
depth=24,
|
| 50 |
+
num_heads=16,
|
| 51 |
+
mlp_ratio=4.0,
|
| 52 |
+
num_register_tokens=4,
|
| 53 |
+
block_fn=Block,
|
| 54 |
+
qkv_bias=True,
|
| 55 |
+
proj_bias=True,
|
| 56 |
+
ffn_bias=True,
|
| 57 |
+
patch_embed="dinov2_vitl14_reg",
|
| 58 |
+
aa_order=["frame", "global"],
|
| 59 |
+
aa_block_size=1,
|
| 60 |
+
qk_norm=True,
|
| 61 |
+
rope_freq=100,
|
| 62 |
+
init_values=0.01,
|
| 63 |
+
reorder_by_learning_ref=True,
|
| 64 |
+
ref_aa_block_num=2,
|
| 65 |
+
save_inference_memory=True,
|
| 66 |
+
):
|
| 67 |
+
super().__init__()
|
| 68 |
+
|
| 69 |
+
self.reorder_by_learning_ref = reorder_by_learning_ref
|
| 70 |
+
self.save_inference_memory = save_inference_memory
|
| 71 |
+
|
| 72 |
+
self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)
|
| 73 |
+
|
| 74 |
+
# Initialize rotary position embedding if frequency > 0
|
| 75 |
+
self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
|
| 76 |
+
self.position_getter = PositionGetter() if self.rope is not None else None
|
| 77 |
+
|
| 78 |
+
self.frame_blocks = nn.ModuleList(
|
| 79 |
+
[
|
| 80 |
+
block_fn(
|
| 81 |
+
dim=embed_dim,
|
| 82 |
+
num_heads=num_heads,
|
| 83 |
+
mlp_ratio=mlp_ratio,
|
| 84 |
+
qkv_bias=qkv_bias,
|
| 85 |
+
proj_bias=proj_bias,
|
| 86 |
+
ffn_bias=ffn_bias,
|
| 87 |
+
init_values=init_values,
|
| 88 |
+
qk_norm=qk_norm,
|
| 89 |
+
rope=self.rope,
|
| 90 |
+
)
|
| 91 |
+
for _ in range(depth)
|
| 92 |
+
]
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
self.global_blocks = nn.ModuleList(
|
| 96 |
+
[
|
| 97 |
+
block_fn(
|
| 98 |
+
dim=embed_dim,
|
| 99 |
+
num_heads=num_heads,
|
| 100 |
+
mlp_ratio=mlp_ratio,
|
| 101 |
+
qkv_bias=qkv_bias,
|
| 102 |
+
proj_bias=proj_bias,
|
| 103 |
+
ffn_bias=ffn_bias,
|
| 104 |
+
init_values=init_values,
|
| 105 |
+
qk_norm=qk_norm,
|
| 106 |
+
rope=self.rope,
|
| 107 |
+
)
|
| 108 |
+
for _ in range(depth)
|
| 109 |
+
]
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
self.depth = depth
|
| 113 |
+
self.aa_order = aa_order
|
| 114 |
+
self.patch_size = patch_size
|
| 115 |
+
self.aa_block_size = aa_block_size
|
| 116 |
+
|
| 117 |
+
# Validate that depth is divisible by aa_block_size
|
| 118 |
+
if self.depth % self.aa_block_size != 0:
|
| 119 |
+
raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
|
| 120 |
+
|
| 121 |
+
self.aa_block_num = self.depth // self.aa_block_size
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# Reference Learning Network
|
| 125 |
+
if self.reorder_by_learning_ref:
|
| 126 |
+
self.ref_aa_block_num = ref_aa_block_num
|
| 127 |
+
self.ref_frame_blocks = nn.ModuleList(
|
| 128 |
+
[
|
| 129 |
+
block_fn(
|
| 130 |
+
dim=embed_dim,
|
| 131 |
+
num_heads=num_heads,
|
| 132 |
+
mlp_ratio=mlp_ratio,
|
| 133 |
+
qkv_bias=qkv_bias,
|
| 134 |
+
proj_bias=proj_bias,
|
| 135 |
+
ffn_bias=ffn_bias,
|
| 136 |
+
init_values=init_values,
|
| 137 |
+
qk_norm=qk_norm,
|
| 138 |
+
rope=self.rope,
|
| 139 |
+
)
|
| 140 |
+
for _ in range(self.ref_aa_block_num)
|
| 141 |
+
]
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
self.ref_global_blocks = nn.ModuleList(
|
| 145 |
+
[
|
| 146 |
+
block_fn(
|
| 147 |
+
dim=embed_dim,
|
| 148 |
+
num_heads=num_heads,
|
| 149 |
+
mlp_ratio=mlp_ratio,
|
| 150 |
+
qkv_bias=qkv_bias,
|
| 151 |
+
proj_bias=proj_bias,
|
| 152 |
+
ffn_bias=ffn_bias,
|
| 153 |
+
init_values=init_values,
|
| 154 |
+
qk_norm=qk_norm,
|
| 155 |
+
rope=self.rope,
|
| 156 |
+
)
|
| 157 |
+
for _ in range(self.ref_aa_block_num)
|
| 158 |
+
]
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# Note: We have two camera tokens, one for the first frame and one for the rest
|
| 162 |
+
# The same applies for register tokens
|
| 163 |
+
self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
|
| 164 |
+
self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))
|
| 165 |
+
|
| 166 |
+
if self.reorder_by_learning_ref:
|
| 167 |
+
# describe the covisibility of the current frame with other frames
|
| 168 |
+
self.covisibility_token = nn.Parameter(torch.randn(1, 1, 1, embed_dim))
|
| 169 |
+
|
| 170 |
+
# The patch tokens start after the camera and register tokens
|
| 171 |
+
self.patch_start_idx = 1 + num_register_tokens
|
| 172 |
+
|
| 173 |
+
# Initialize parameters with small values
|
| 174 |
+
nn.init.normal_(self.camera_token, std=1e-6)
|
| 175 |
+
nn.init.normal_(self.register_token, std=1e-6)
|
| 176 |
+
if self.reorder_by_learning_ref:
|
| 177 |
+
nn.init.normal_(self.covisibility_token, std=1e-6)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# Register normalization constants as buffers
|
| 181 |
+
for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
|
| 182 |
+
self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False)
|
| 183 |
+
|
| 184 |
+
self.use_reentrant = False # hardcoded to False
|
| 185 |
+
|
| 186 |
+
def __build_patch_embed__(
|
| 187 |
+
self,
|
| 188 |
+
patch_embed,
|
| 189 |
+
img_size,
|
| 190 |
+
patch_size,
|
| 191 |
+
num_register_tokens,
|
| 192 |
+
interpolate_antialias=True,
|
| 193 |
+
interpolate_offset=0.0,
|
| 194 |
+
block_chunks=0,
|
| 195 |
+
init_values=1.0,
|
| 196 |
+
embed_dim=1024,
|
| 197 |
+
):
|
| 198 |
+
"""
|
| 199 |
+
Build the patch embed layer. If 'conv', we use a
|
| 200 |
+
simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
|
| 201 |
+
"""
|
| 202 |
+
|
| 203 |
+
if "conv" in patch_embed:
|
| 204 |
+
self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
|
| 205 |
+
else:
|
| 206 |
+
vit_models = {
|
| 207 |
+
"dinov2_vitl14_reg": vit_large,
|
| 208 |
+
"dinov2_vitb14_reg": vit_base,
|
| 209 |
+
"dinov2_vits14_reg": vit_small,
|
| 210 |
+
"dinov2_vitg2_reg": vit_giant2,
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
self.patch_embed = vit_models[patch_embed](
|
| 214 |
+
img_size=img_size,
|
| 215 |
+
patch_size=patch_size,
|
| 216 |
+
num_register_tokens=num_register_tokens,
|
| 217 |
+
interpolate_antialias=interpolate_antialias,
|
| 218 |
+
interpolate_offset=interpolate_offset,
|
| 219 |
+
block_chunks=block_chunks,
|
| 220 |
+
init_values=init_values,
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# Disable gradient updates for mask token
|
| 224 |
+
if hasattr(self.patch_embed, "mask_token"):
|
| 225 |
+
# self.patch_embed.mask_token.requires_grad_(False)
|
| 226 |
+
del self.patch_embed.mask_token
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# covisibility head
|
| 231 |
+
if self.reorder_by_learning_ref:
|
| 232 |
+
self.token_norm = nn.LayerNorm(embed_dim * 2)
|
| 233 |
+
self.covisibility_head = Mlp(in_features=embed_dim * 2, hidden_features=embed_dim * 2 // 2, out_features=1, drop=0)
|
| 234 |
+
|
| 235 |
+
def forward(self, images: torch.Tensor) -> Tuple[List[torch.Tensor], int]:
|
| 236 |
+
"""
|
| 237 |
+
Args:
|
| 238 |
+
images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
|
| 239 |
+
B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
|
| 240 |
+
|
| 241 |
+
Returns:
|
| 242 |
+
(list[torch.Tensor], int):
|
| 243 |
+
The list of outputs from the attention blocks,
|
| 244 |
+
and the patch_start_idx indicating where patch tokens begin.
|
| 245 |
+
"""
|
| 246 |
+
B, S, C_in, H, W = images.shape
|
| 247 |
+
|
| 248 |
+
if C_in != 3:
|
| 249 |
+
raise ValueError(f"Expected 3 input channels, got {C_in}")
|
| 250 |
+
|
| 251 |
+
# Normalize images and reshape for patch embed
|
| 252 |
+
images = (images - self._resnet_mean) / self._resnet_std
|
| 253 |
+
|
| 254 |
+
# Reshape to [B*S, C, H, W] for patch embedding
|
| 255 |
+
images = images.view(B * S, C_in, H, W)
|
| 256 |
+
patch_tokens = self.patch_embed(images)
|
| 257 |
+
|
| 258 |
+
if isinstance(patch_tokens, dict):
|
| 259 |
+
patch_tokens = patch_tokens["x_norm_patchtokens"]
|
| 260 |
+
|
| 261 |
+
_, P, C = patch_tokens.shape
|
| 262 |
+
|
| 263 |
+
################# ref learning
|
| 264 |
+
covisibility_scores = None
|
| 265 |
+
ref_idx = None
|
| 266 |
+
if self.reorder_by_learning_ref:
|
| 267 |
+
# expand covisibility token to match batch size and sequence length
|
| 268 |
+
covisibility_token = self.covisibility_token.expand(B, S, 1, C).view(B * S, 1, C).contiguous()
|
| 269 |
+
# Concatenate covisibility token with patch tokens
|
| 270 |
+
covisibility_patch_tokens = torch.cat([covisibility_token, patch_tokens], dim=1) # [BS,1+HW,C]
|
| 271 |
+
|
| 272 |
+
covisibility_pos = None
|
| 273 |
+
if self.rope is not None:
|
| 274 |
+
covisibility_pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
|
| 275 |
+
|
| 276 |
+
# do not use position embedding for special covisibility_token
|
| 277 |
+
# so set pos to 0 for the special tokens
|
| 278 |
+
covisibility_pos = covisibility_pos + 1
|
| 279 |
+
covisibility_pos_special = torch.zeros(B * S, 1, 2).to(images.device).to(covisibility_pos.dtype)
|
| 280 |
+
covisibility_pos = torch.cat([covisibility_pos_special, covisibility_pos], dim=1) # [BS, 1+HW, 2]
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# update P because we added special tokens
|
| 284 |
+
_, P_covis, C_covis = covisibility_patch_tokens.shape
|
| 285 |
+
|
| 286 |
+
frame_idx = 0
|
| 287 |
+
global_idx = 0
|
| 288 |
+
output_list = []
|
| 289 |
+
|
| 290 |
+
for ref_block_i in range(self.ref_aa_block_num):
|
| 291 |
+
for attn_type in self.aa_order:
|
| 292 |
+
if attn_type == "frame":
|
| 293 |
+
covisibility_patch_tokens, frame_idx, frame_intermediates = self._ref_process_frame_attention(
|
| 294 |
+
covisibility_patch_tokens, B, S, P_covis, C_covis, frame_idx, pos=covisibility_pos
|
| 295 |
+
)
|
| 296 |
+
elif attn_type == "global":
|
| 297 |
+
covisibility_patch_tokens, global_idx, global_intermediates = self._ref_process_global_attention(
|
| 298 |
+
covisibility_patch_tokens, B, S, P_covis, C_covis, global_idx, pos=covisibility_pos
|
| 299 |
+
)
|
| 300 |
+
else:
|
| 301 |
+
raise ValueError(f"Unknown attention type: {attn_type}")
|
| 302 |
+
|
| 303 |
+
for i in range(len(frame_intermediates)):
|
| 304 |
+
# concat frame and global intermediates, [B x S x P x 2C]
|
| 305 |
+
concat_inter = torch.cat([frame_intermediates[-1], global_intermediates[-1]], dim=-1)
|
| 306 |
+
output_list.append(concat_inter)
|
| 307 |
+
|
| 308 |
+
last_covisibility_patch_tokens = output_list[-1][:,:,0,:] # [B, S, C]
|
| 309 |
+
# normalize
|
| 310 |
+
last_covisibility_patch_tokens = self.token_norm(last_covisibility_patch_tokens)
|
| 311 |
+
|
| 312 |
+
covisibility_scores = self.covisibility_head(last_covisibility_patch_tokens).squeeze(-1) # [B, S]
|
| 313 |
+
# # cos
|
| 314 |
+
# feat_norm = F.normalize(covisibility_features, p=2, dim=-1, eps=1e-8) # [B, S, D]
|
| 315 |
+
# covisibility_scores = feat_norm @ feat_norm.transpose(-1, -2)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
ref_idx = covisibility_scores.argmax(-1) # [B, S] -> [B]
|
| 319 |
+
patch_tokens = patch_tokens.view(B,S,P,C)
|
| 320 |
+
patch_tokens = reorder_by_reference(patch_tokens, ref_idx)
|
| 321 |
+
patch_tokens = patch_tokens.view(B*S,P,C).contiguous()
|
| 322 |
+
|
| 323 |
+
####################
|
| 324 |
+
# Expand camera and register tokens to match batch size and sequence length
|
| 325 |
+
camera_token = slice_expand_and_flatten(self.camera_token, B, S)
|
| 326 |
+
register_token = slice_expand_and_flatten(self.register_token, B, S)
|
| 327 |
+
# Concatenate special tokens with patch tokens
|
| 328 |
+
tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1) # [BS,1+4+HW,C]
|
| 329 |
+
|
| 330 |
+
pos = None
|
| 331 |
+
if self.rope is not None:
|
| 332 |
+
pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
|
| 333 |
+
|
| 334 |
+
if self.patch_start_idx > 0:
|
| 335 |
+
# do not use position embedding for special tokens (camera and register tokens)
|
| 336 |
+
# so set pos to 0 for the special tokens
|
| 337 |
+
pos = pos + 1
|
| 338 |
+
pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
|
| 339 |
+
pos = torch.cat([pos_special, pos], dim=1) # [BS, 1+4+HW, 2]
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
# update P because we added special tokens
|
| 344 |
+
_, P, C = tokens.shape
|
| 345 |
+
|
| 346 |
+
frame_idx = 0
|
| 347 |
+
global_idx = 0
|
| 348 |
+
output_list = []
|
| 349 |
+
|
| 350 |
+
for block_i in range(self.aa_block_num):
|
| 351 |
+
for attn_type in self.aa_order:
|
| 352 |
+
if attn_type == "frame":
|
| 353 |
+
tokens, frame_idx, frame_intermediates = self._process_frame_attention(
|
| 354 |
+
tokens, B, S, P, C, frame_idx, pos=pos
|
| 355 |
+
)
|
| 356 |
+
elif attn_type == "global":
|
| 357 |
+
tokens, global_idx, global_intermediates = self._process_global_attention(
|
| 358 |
+
tokens, B, S, P, C, global_idx, pos=pos
|
| 359 |
+
)
|
| 360 |
+
else:
|
| 361 |
+
raise ValueError(f"Unknown attention type: {attn_type}")
|
| 362 |
+
|
| 363 |
+
for i in range(len(frame_intermediates)):
|
| 364 |
+
concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
|
| 365 |
+
if (not self.training ) and (self.save_inference_memory) and (block_i not in [4,11,17,23]):
|
| 366 |
+
# only save the useful indices of intermediates
|
| 367 |
+
output_list.append(torch.tensor(0))
|
| 368 |
+
else:
|
| 369 |
+
# concat frame and global intermediates, [B x S x P x 2C]
|
| 370 |
+
output_list.append(concat_inter)
|
| 371 |
+
|
| 372 |
+
del concat_inter
|
| 373 |
+
del frame_intermediates
|
| 374 |
+
del global_intermediates
|
| 375 |
+
return output_list, self.patch_start_idx, covisibility_scores, ref_idx
|
| 376 |
+
|
| 377 |
+
def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
|
| 378 |
+
"""
|
| 379 |
+
Process frame attention blocks. We keep tokens in shape (B*S, P, C).
|
| 380 |
+
"""
|
| 381 |
+
# If needed, reshape tokens or positions:
|
| 382 |
+
if tokens.shape != (B * S, P, C):
|
| 383 |
+
tokens = tokens.view(B, S, P, C).view(B * S, P, C)
|
| 384 |
+
|
| 385 |
+
if pos is not None and pos.shape != (B * S, P, 2):
|
| 386 |
+
pos = pos.view(B, S, P, 2).view(B * S, P, 2)
|
| 387 |
+
|
| 388 |
+
intermediates = []
|
| 389 |
+
|
| 390 |
+
# by default, self.aa_block_size=1, which processes one block at a time
|
| 391 |
+
for _ in range(self.aa_block_size):
|
| 392 |
+
if self.training:
|
| 393 |
+
tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant)
|
| 394 |
+
else:
|
| 395 |
+
tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
|
| 396 |
+
frame_idx += 1
|
| 397 |
+
intermediates.append(tokens.view(B, S, P, C))
|
| 398 |
+
|
| 399 |
+
return tokens, frame_idx, intermediates
|
| 400 |
+
|
| 401 |
+
def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
|
| 402 |
+
"""
|
| 403 |
+
Process global attention blocks. We keep tokens in shape (B, S*P, C).
|
| 404 |
+
"""
|
| 405 |
+
if tokens.shape != (B, S * P, C):
|
| 406 |
+
tokens = tokens.view(B, S, P, C).view(B, S * P, C)
|
| 407 |
+
|
| 408 |
+
if pos is not None and pos.shape != (B, S * P, 2):
|
| 409 |
+
pos = pos.view(B, S, P, 2).view(B, S * P, 2)
|
| 410 |
+
|
| 411 |
+
intermediates = []
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
# by default, self.aa_block_size=1, which processes one block at a time
|
| 415 |
+
for _ in range(self.aa_block_size):
|
| 416 |
+
if self.training:
|
| 417 |
+
tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant)
|
| 418 |
+
else:
|
| 419 |
+
tokens = self.global_blocks[global_idx](tokens, pos=pos)
|
| 420 |
+
global_idx += 1
|
| 421 |
+
|
| 422 |
+
intermediates.append(tokens.view(B, S, P, C))
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
return tokens, global_idx, intermediates
|
| 427 |
+
|
| 428 |
+
def _ref_process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
|
| 429 |
+
"""
|
| 430 |
+
Process frame attention blocks. We keep tokens in shape (B*S, P, C).
|
| 431 |
+
"""
|
| 432 |
+
# If needed, reshape tokens or positions:
|
| 433 |
+
if tokens.shape != (B * S, P, C):
|
| 434 |
+
tokens = tokens.view(B, S, P, C).view(B * S, P, C)
|
| 435 |
+
|
| 436 |
+
if pos is not None and pos.shape != (B * S, P, 2):
|
| 437 |
+
pos = pos.view(B, S, P, 2).view(B * S, P, 2)
|
| 438 |
+
|
| 439 |
+
intermediates = []
|
| 440 |
+
|
| 441 |
+
# by default, self.aa_block_size=1, which processes one block at a time
|
| 442 |
+
for _ in range(self.aa_block_size):
|
| 443 |
+
if self.training:
|
| 444 |
+
tokens = checkpoint(self.ref_frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant)
|
| 445 |
+
else:
|
| 446 |
+
tokens = self.ref_frame_blocks[frame_idx](tokens, pos=pos)
|
| 447 |
+
frame_idx += 1
|
| 448 |
+
intermediates.append(tokens.view(B, S, P, C))
|
| 449 |
+
|
| 450 |
+
return tokens, frame_idx, intermediates
|
| 451 |
+
|
| 452 |
+
def _ref_process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
|
| 453 |
+
"""
|
| 454 |
+
Process global attention blocks. We keep tokens in shape (B, S*P, C).
|
| 455 |
+
"""
|
| 456 |
+
if tokens.shape != (B, S * P, C):
|
| 457 |
+
tokens = tokens.view(B, S, P, C).view(B, S * P, C)
|
| 458 |
+
|
| 459 |
+
if pos is not None and pos.shape != (B, S * P, 2):
|
| 460 |
+
pos = pos.view(B, S, P, 2).view(B, S * P, 2)
|
| 461 |
+
|
| 462 |
+
intermediates = []
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
# by default, self.aa_block_size=1, which processes one block at a time
|
| 466 |
+
for _ in range(self.aa_block_size):
|
| 467 |
+
if self.training:
|
| 468 |
+
tokens = checkpoint(self.ref_global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant)
|
| 469 |
+
else:
|
| 470 |
+
tokens = self.ref_global_blocks[global_idx](tokens, pos=pos)
|
| 471 |
+
global_idx += 1
|
| 472 |
+
|
| 473 |
+
intermediates.append(tokens.view(B, S, P, C))
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
return tokens, global_idx, intermediates
|
| 478 |
+
|
| 479 |
+
def slice_expand_and_flatten(token_tensor, B, S):
|
| 480 |
+
"""
|
| 481 |
+
Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
|
| 482 |
+
1) Uses the first position (index=0) for the first frame only
|
| 483 |
+
2) Uses the second position (index=1) for all remaining frames (S-1 frames)
|
| 484 |
+
3) Expands both to match batch size B
|
| 485 |
+
4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
|
| 486 |
+
followed by (S-1) second-position tokens
|
| 487 |
+
5) Flattens to (B*S, X, C) for processing
|
| 488 |
+
|
| 489 |
+
Returns:
|
| 490 |
+
torch.Tensor: Processed tokens with shape (B*S, X, C)
|
| 491 |
+
"""
|
| 492 |
+
|
| 493 |
+
# Slice out the "query" tokens => shape (1, 1, ...)
|
| 494 |
+
query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
|
| 495 |
+
# Slice out the "other" tokens => shape (1, S-1, ...)
|
| 496 |
+
others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
|
| 497 |
+
# Concatenate => shape (B, S, ...)
|
| 498 |
+
combined = torch.cat([query, others], dim=1)
|
| 499 |
+
|
| 500 |
+
# Finally flatten => shape (B*S, ...)
|
| 501 |
+
combined = combined.view(B * S, *combined.shape[2:])
|
| 502 |
+
return combined
|
argus/models/argus.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from typing import Optional, Dict
|
| 4 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 5 |
+
|
| 6 |
+
# Import model components
|
| 7 |
+
from argus.models.aggregator import Aggregator
|
| 8 |
+
from argus.heads.camera_head import CameraHead
|
| 9 |
+
from argus.heads.dpt_head import DPTHead
|
| 10 |
+
from argus.heads.utils import reorder_by_reference
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Argus(nn.Module, PyTorchModelHubMixin):
|
| 14 |
+
"""
|
| 15 |
+
Argus multi-task vision model for camera pose estimation, depth prediction, and 3D points.
|
| 16 |
+
|
| 17 |
+
Integrates an aggregator backbone with task-specific heads for:
|
| 18 |
+
- Camera pose encoding
|
| 19 |
+
- Depth map prediction
|
| 20 |
+
- 3D camera/rotated/world point prediction
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
img_size: Input image size (height/width, assumes square) (default: 518)
|
| 24 |
+
patch_size: Patch size for vision transformer backbone (default: 14)
|
| 25 |
+
embed_dim: Embedding dimension for transformer features (default: 1024)
|
| 26 |
+
enable_camera: Enable camera pose estimation head (default: True)
|
| 27 |
+
enable_depth: Enable depth prediction head (default: True)
|
| 28 |
+
enable_cam_point: Enable camera coordinate 3D point prediction head (default: False)
|
| 29 |
+
enable_rotated_point: Enable rotated 3D point prediction head (default: False)
|
| 30 |
+
enable_point: Enable world coordinate 3D point prediction head (default: False, Please do not set it to True during training)
|
| 31 |
+
|
| 32 |
+
Note:
|
| 33 |
+
All heads share the same aggregated transformer features from the Aggregator backbone.
|
| 34 |
+
Each DPT-based head outputs both predictions and confidence scores.
|
| 35 |
+
"""
|
| 36 |
+
def __init__(
|
| 37 |
+
self,
|
| 38 |
+
img_size: int = 518,
|
| 39 |
+
patch_size: int = 14,
|
| 40 |
+
embed_dim: int = 1024,
|
| 41 |
+
enable_camera: bool = True,
|
| 42 |
+
enable_depth: bool = True,
|
| 43 |
+
enable_cam_point: bool = False,
|
| 44 |
+
enable_rotated_point: bool = False,
|
| 45 |
+
enable_point: bool = False,
|
| 46 |
+
reorder_by_learning_ref: bool = True,
|
| 47 |
+
restore_metric_scale: bool = False
|
| 48 |
+
) -> None:
|
| 49 |
+
super().__init__()
|
| 50 |
+
# For inference
|
| 51 |
+
self.restore_metric_scale = restore_metric_scale
|
| 52 |
+
self.reorder_by_learning_ref = reorder_by_learning_ref
|
| 53 |
+
|
| 54 |
+
# Backbone and geometry transformer
|
| 55 |
+
self.aggregator = Aggregator(
|
| 56 |
+
img_size=img_size,
|
| 57 |
+
patch_size=patch_size,
|
| 58 |
+
embed_dim=embed_dim,
|
| 59 |
+
reorder_by_learning_ref=reorder_by_learning_ref,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
# Task-specific prediction heads (lazy initialization based on flags)
|
| 63 |
+
self.camera_head: Optional[CameraHead] = CameraHead(dim_in=2 * embed_dim) if enable_camera else None
|
| 64 |
+
self.depth_head: Optional[DPTHead] = DPTHead(
|
| 65 |
+
dim_in=2 * embed_dim,
|
| 66 |
+
output_dim=2,
|
| 67 |
+
activation="exp",
|
| 68 |
+
conf_activation="expp1"
|
| 69 |
+
) if enable_depth else None
|
| 70 |
+
|
| 71 |
+
# 3D point prediction heads (shared architecture, different output semantics)
|
| 72 |
+
self.cam_point_head: Optional[DPTHead] = DPTHead(
|
| 73 |
+
dim_in=2 * embed_dim,
|
| 74 |
+
output_dim=4,
|
| 75 |
+
activation="inv_log",
|
| 76 |
+
conf_activation="expp1"
|
| 77 |
+
) if enable_cam_point else None
|
| 78 |
+
|
| 79 |
+
self.rotated_point_head: Optional[DPTHead] = DPTHead(
|
| 80 |
+
dim_in=2 * embed_dim,
|
| 81 |
+
output_dim=4,
|
| 82 |
+
activation="inv_log",
|
| 83 |
+
conf_activation="expp1"
|
| 84 |
+
) if enable_rotated_point else None
|
| 85 |
+
|
| 86 |
+
self.point_head: Optional[DPTHead] = DPTHead(
|
| 87 |
+
dim_in=2 * embed_dim,
|
| 88 |
+
output_dim=4,
|
| 89 |
+
activation="inv_log",
|
| 90 |
+
conf_activation="expp1"
|
| 91 |
+
) if enable_point else None
|
| 92 |
+
|
| 93 |
+
def forward(
|
| 94 |
+
self,
|
| 95 |
+
images: torch.Tensor,
|
| 96 |
+
) -> Dict[str, torch.Tensor]:
|
| 97 |
+
"""
|
| 98 |
+
Forward pass of the Argus model.
|
| 99 |
+
|
| 100 |
+
Automatically adds batch dimension if missing and processes multi-task predictions.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
images: Input RGB images with shape:
|
| 104 |
+
- [S, 3, H, W] (sequence without batch) or
|
| 105 |
+
- [B, S, 3, H, W] (batch of sequences)
|
| 106 |
+
Values in range [0, 1], where:
|
| 107 |
+
- B: batch size
|
| 108 |
+
- S: sequence length (number of frames)
|
| 109 |
+
- 3: RGB channels
|
| 110 |
+
- H/W: image height/width (matches img_size)
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
Dictionary of model predictions with task-specific outputs:
|
| 114 |
+
Common outputs:
|
| 115 |
+
- covisibility_scores: Covisibility scores from aggregator (shape varies)
|
| 116 |
+
- ref_idx: Reference frame indices (shape varies)
|
| 117 |
+
|
| 118 |
+
Camera head outputs (if enabled):
|
| 119 |
+
- pose_enc: Final camera pose encoding [B, S, 9]
|
| 120 |
+
- pose_enc_list: List of pose encodings from all iterations [List[torch.Tensor]]
|
| 121 |
+
|
| 122 |
+
Depth head outputs (if enabled):
|
| 123 |
+
- depth: Predicted depth maps [B, S, H, W, 1]
|
| 124 |
+
- depth_conf: Depth prediction confidence [B, S, H, W]
|
| 125 |
+
|
| 126 |
+
Camera point head outputs (if enabled):
|
| 127 |
+
- cam_points: 3D camera coordinates per pixel [B, S, H, W, 3]
|
| 128 |
+
- cam_points_conf: Camera point confidence [B, S, H, W]
|
| 129 |
+
|
| 130 |
+
Rotated point head outputs (if enabled):
|
| 131 |
+
- rotated_points: Rotated 3D coordinates per pixel [B, S, H, W, 3]
|
| 132 |
+
- rotated_points_conf: Rotated point confidence [B, S, H, W]
|
| 133 |
+
|
| 134 |
+
World point head outputs (if enabled):
|
| 135 |
+
- world_points: 3D world coordinates per pixel [B, S, H, W, 3]
|
| 136 |
+
- world_points_conf: World point confidence [B, S, H, W]
|
| 137 |
+
|
| 138 |
+
Inference-only outputs (not training):
|
| 139 |
+
- images: Original input images (for visualization) [B, S, 3, H, W]
|
| 140 |
+
"""
|
| 141 |
+
# Add batch dimension if missing (handle [S,3,H,W] -> [1,S,3,H,W])
|
| 142 |
+
if len(images.shape) == 4:
|
| 143 |
+
images = images.unsqueeze(0)
|
| 144 |
+
|
| 145 |
+
# Extract aggregated features from backbone
|
| 146 |
+
(
|
| 147 |
+
aggregated_tokens_list, # List of aggregated transformer tokens across iterations
|
| 148 |
+
patch_start_idx, # Patch start indices for feature reconstruction
|
| 149 |
+
covisibility_scores, # Covisibility scores between frames
|
| 150 |
+
ref_idx # Reference frame indices
|
| 151 |
+
) = self.aggregator(images)
|
| 152 |
+
|
| 153 |
+
# Initialize prediction dictionary
|
| 154 |
+
predictions: Dict[str, torch.Tensor] = {}
|
| 155 |
+
|
| 156 |
+
# Disable mixed precision for precise prediction calculations
|
| 157 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 158 |
+
# Add aggregator outputs to predictions
|
| 159 |
+
if covisibility_scores is not None:
|
| 160 |
+
predictions["covisibility_scores"] = covisibility_scores
|
| 161 |
+
if ref_idx is not None:
|
| 162 |
+
predictions["ref_idx"] = ref_idx
|
| 163 |
+
|
| 164 |
+
# Camera pose prediction (if enabled)
|
| 165 |
+
if self.camera_head is not None:
|
| 166 |
+
pose_enc_list = self.camera_head(aggregated_tokens_list)
|
| 167 |
+
predictions["pose_enc"] = pose_enc_list[-1] # Use final iteration encoding
|
| 168 |
+
predictions["pose_enc_list"] = pose_enc_list # Mutil-layer supervision
|
| 169 |
+
|
| 170 |
+
# Depth prediction (if enabled)
|
| 171 |
+
if self.depth_head is not None:
|
| 172 |
+
depth, depth_conf = self.depth_head(
|
| 173 |
+
aggregated_tokens_list,
|
| 174 |
+
images=images,
|
| 175 |
+
patch_start_idx=patch_start_idx
|
| 176 |
+
)
|
| 177 |
+
predictions["depth"] = depth
|
| 178 |
+
predictions["depth_conf"] = depth_conf
|
| 179 |
+
|
| 180 |
+
# Camera 3D point prediction (if enabled)
|
| 181 |
+
if self.cam_point_head is not None:
|
| 182 |
+
cam_pts3d, cam_pts3d_conf = self.cam_point_head(
|
| 183 |
+
aggregated_tokens_list,
|
| 184 |
+
images=images,
|
| 185 |
+
patch_start_idx=patch_start_idx
|
| 186 |
+
)
|
| 187 |
+
predictions["cam_points"] = cam_pts3d
|
| 188 |
+
predictions["cam_points_conf"] = cam_pts3d_conf
|
| 189 |
+
|
| 190 |
+
# Rotated 3D point prediction (if enabled)
|
| 191 |
+
if self.rotated_point_head is not None:
|
| 192 |
+
rotated_pts3d, rotated_pts3d_conf = self.rotated_point_head(
|
| 193 |
+
aggregated_tokens_list,
|
| 194 |
+
images=images,
|
| 195 |
+
patch_start_idx=patch_start_idx
|
| 196 |
+
)
|
| 197 |
+
predictions["rotated_points"] = rotated_pts3d
|
| 198 |
+
predictions["rotated_points_conf"] = rotated_pts3d_conf
|
| 199 |
+
|
| 200 |
+
# World 3D point prediction (if enabled)
|
| 201 |
+
if self.point_head is not None:
|
| 202 |
+
world_pts3d, world_pts3d_conf = self.point_head(
|
| 203 |
+
aggregated_tokens_list,
|
| 204 |
+
images=images,
|
| 205 |
+
patch_start_idx=patch_start_idx
|
| 206 |
+
)
|
| 207 |
+
predictions["world_points"] = world_pts3d
|
| 208 |
+
predictions["world_points_conf"] = world_pts3d_conf
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# Store input images for visualization during inference (skip in training)
|
| 212 |
+
if not self.training:
|
| 213 |
+
predictions["images"] = images
|
| 214 |
+
if "ref_idx" in predictions:
|
| 215 |
+
ref_idx = predictions["ref_idx"].detach()
|
| 216 |
+
# Reorder all spatial/temporal data (exclude adjacency matrix and IDs)
|
| 217 |
+
predictions["images"] = reorder_by_reference(predictions["images"], ref_idx)
|
| 218 |
+
|
| 219 |
+
if self.restore_metric_scale:
|
| 220 |
+
# Restore metric scale
|
| 221 |
+
abs_scale = 10.0
|
| 222 |
+
if self.camera_head is not None:
|
| 223 |
+
predictions["pose_enc"][...,:3] *= abs_scale
|
| 224 |
+
if self.depth_head is not None:
|
| 225 |
+
predictions["depth"] *= abs_scale
|
| 226 |
+
if self.cam_point_head is not None:
|
| 227 |
+
predictions["cam_points"] *= abs_scale
|
| 228 |
+
if self.rotated_point_head is not None:
|
| 229 |
+
predictions["rotated_points"] *= abs_scale
|
| 230 |
+
if self.point_head is not None:
|
| 231 |
+
predictions["world_points"] *= abs_scale
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
return predictions
|
argus/utils/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 Realsee. All rights reserved.
|
| 2 |
+
# Licensed under the Apache License, Version 2.0.
|
argus/utils/data_io.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 Realsee. All rights reserved.
|
| 2 |
+
# Licensed under the Apache License, Version 2.0.
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Shared I/O and preprocessing utilities for panoramic image data.
|
| 6 |
+
|
| 7 |
+
These functions are used by both evaluation and training pipelines.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
| 13 |
+
import cv2
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def read_image_cv2_360(path: str, rgb: bool = True, shape=(560, 280)) -> np.ndarray:
|
| 18 |
+
"""Read and resize a 360 panorama image.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
path: Path to the image file.
|
| 22 |
+
rgb: If True, convert BGR to RGB (default: True).
|
| 23 |
+
shape: Target (width, height) tuple.
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
Image as numpy array with shape (H, W, 3).
|
| 27 |
+
"""
|
| 28 |
+
img = cv2.imread(path)
|
| 29 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 30 |
+
if img.shape[1] != shape[0]:
|
| 31 |
+
img = cv2.resize(img, shape, interpolation=cv2.INTER_AREA)
|
| 32 |
+
return img
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def read_depth_360(path: str, depth_scale=5000.0, shape=(560, 280)) -> np.ndarray:
|
| 36 |
+
"""Read and normalize a 360 depth map.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
path: Path to the depth image file.
|
| 40 |
+
depth_scale: Scale factor to convert raw depth to meters.
|
| 41 |
+
shape: Target (width, height) tuple.
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
Depth map as float32 numpy array with shape (H, W).
|
| 45 |
+
"""
|
| 46 |
+
d = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
| 47 |
+
if d.shape[1] != shape[0]:
|
| 48 |
+
d = cv2.resize(d, shape, interpolation=cv2.INTER_NEAREST)
|
| 49 |
+
d = d.astype(np.float32) / depth_scale
|
| 50 |
+
return d
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def random_rotate_theta(W=560, max_shift_percent=0.5):
|
| 54 |
+
"""Generate a random rotation angle for panorama augmentation.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
W: Panorama width in pixels.
|
| 58 |
+
max_shift_percent: Maximum horizontal shift as fraction of width.
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
Rotation angle in radians.
|
| 62 |
+
"""
|
| 63 |
+
max_shift = int(W * max_shift_percent)
|
| 64 |
+
shift_pixels = np.random.randint(-max_shift, max_shift + 1)
|
| 65 |
+
theta = (shift_pixels * 2 * np.pi) / W
|
| 66 |
+
return theta
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def rotate_y(theta):
|
| 70 |
+
"""Create a 3x3 rotation matrix around the Y-axis.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
theta: Rotation angle in radians.
|
| 74 |
+
|
| 75 |
+
Returns:
|
| 76 |
+
3x3 rotation matrix as float64 numpy array.
|
| 77 |
+
"""
|
| 78 |
+
cos_theta = np.cos(theta)
|
| 79 |
+
sin_theta = np.sin(theta)
|
| 80 |
+
return np.array(
|
| 81 |
+
[[cos_theta, 0, -sin_theta], [0, 1, 0], [sin_theta, 0, cos_theta]],
|
| 82 |
+
dtype=np.float64,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def pano_depth_to_points(depth_map, pano_shape=(560, 280), crop=True, crop_ratio=0.15):
|
| 87 |
+
"""Convert a panorama depth map to 3D point cloud.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
depth_map: 2D depth map (H, W) or flattened array.
|
| 91 |
+
pano_shape: Original panorama (width, height) tuple.
|
| 92 |
+
crop: Whether the depth map has been vertically cropped.
|
| 93 |
+
crop_ratio: Crop ratio applied to top and bottom.
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
Point cloud as numpy array with shape (N, 3).
|
| 97 |
+
"""
|
| 98 |
+
w, h = pano_shape
|
| 99 |
+
|
| 100 |
+
if not crop:
|
| 101 |
+
px = np.tile(np.arange(w), int(h))
|
| 102 |
+
py = np.arange(0, int(h)).repeat(w)
|
| 103 |
+
else:
|
| 104 |
+
px = np.tile(np.arange(w), int(h * (1 - 2 * crop_ratio)))
|
| 105 |
+
py = np.arange(int(crop_ratio * h), int((1 - crop_ratio) * h)).repeat(w)
|
| 106 |
+
|
| 107 |
+
dist = depth_map.reshape(-1)
|
| 108 |
+
|
| 109 |
+
lat = (py / h - 0.5) * np.pi
|
| 110 |
+
long = (px / w - 0.5) * np.pi * 2.0
|
| 111 |
+
|
| 112 |
+
y = dist * np.sin(lat)
|
| 113 |
+
tmp = dist * np.cos(lat)
|
| 114 |
+
x = tmp * np.sin(long)
|
| 115 |
+
z = tmp * np.cos(long)
|
| 116 |
+
|
| 117 |
+
point_map = np.concatenate([i.reshape(-1, 1) for i in (x, y, z)], axis=-1)
|
| 118 |
+
|
| 119 |
+
return point_map # (h*w, 3)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def crop_panorama(pano, crop_ratio=0.15):
|
| 123 |
+
"""Crop the top and bottom of a panorama by a given ratio.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
pano: Input panorama array with shape (H, W, ...).
|
| 127 |
+
crop_ratio: Fraction to crop from top and bottom.
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
Cropped panorama.
|
| 131 |
+
"""
|
| 132 |
+
H, W = pano.shape[:2]
|
| 133 |
+
crop_H_top = int(crop_ratio * H)
|
| 134 |
+
crop_H_bottom = H - int(crop_ratio * H)
|
| 135 |
+
crop_pano = pano[crop_H_top:crop_H_bottom, ...]
|
| 136 |
+
return crop_pano
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def rotate_panorama(panorama, theta):
|
| 140 |
+
"""Horizontally rotate a panorama by shifting pixels.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
panorama: Input panorama array with shape (H, W, ...).
|
| 144 |
+
theta: Rotation angle in radians.
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
Shifted panorama.
|
| 148 |
+
"""
|
| 149 |
+
H, W = panorama.shape[:2]
|
| 150 |
+
shift_pixels = int((theta * W) / (2 * np.pi))
|
| 151 |
+
shifted = np.roll(panorama, shift_pixels, axis=1)
|
| 152 |
+
return shifted
|
argus/utils/geometry.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def closed_form_inverse_se3(se3, R=None, T=None):
|
| 6 |
+
"""
|
| 7 |
+
Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
|
| 8 |
+
|
| 9 |
+
If `R` and `T` are provided, they must correspond to the rotation and translation
|
| 10 |
+
components of `se3`. Otherwise, they will be extracted from `se3`.
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
|
| 14 |
+
R (optional): Nx3x3 array or tensor of rotation matrices.
|
| 15 |
+
T (optional): Nx3x1 array or tensor of translation vectors.
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
Inverted SE3 matrices with the same type and device as `se3`.
|
| 19 |
+
|
| 20 |
+
Shapes:
|
| 21 |
+
se3: (N, 4, 4)
|
| 22 |
+
R: (N, 3, 3)
|
| 23 |
+
T: (N, 3, 1)
|
| 24 |
+
"""
|
| 25 |
+
# Check if se3 is a numpy array or a torch tensor
|
| 26 |
+
is_numpy = isinstance(se3, np.ndarray)
|
| 27 |
+
|
| 28 |
+
# Validate shapes
|
| 29 |
+
if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
|
| 30 |
+
raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
|
| 31 |
+
|
| 32 |
+
# Extract R and T if not provided
|
| 33 |
+
if R is None:
|
| 34 |
+
R = se3[:, :3, :3] # (N,3,3)
|
| 35 |
+
if T is None:
|
| 36 |
+
T = se3[:, :3, 3:] # (N,3,1)
|
| 37 |
+
|
| 38 |
+
# Transpose R
|
| 39 |
+
if is_numpy:
|
| 40 |
+
# Compute the transpose of the rotation for NumPy
|
| 41 |
+
R_transposed = np.transpose(R, (0, 2, 1))
|
| 42 |
+
# -R^T t for NumPy
|
| 43 |
+
top_right = -np.matmul(R_transposed, T)
|
| 44 |
+
inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
|
| 45 |
+
else:
|
| 46 |
+
R_transposed = R.transpose(1, 2) # (N,3,3)
|
| 47 |
+
top_right = -torch.bmm(R_transposed, T) # (N,3,1)
|
| 48 |
+
inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
|
| 49 |
+
inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
|
| 50 |
+
|
| 51 |
+
inverted_matrix[:, :3, :3] = R_transposed
|
| 52 |
+
inverted_matrix[:, :3, 3:] = top_right
|
| 53 |
+
|
| 54 |
+
return inverted_matrix
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def pano_depth_to_points(depth_map, original_pano_shape=(560, 280), crop_ratio=0.15):
|
| 58 |
+
"""
|
| 59 |
+
Convert batched cropped panoramic depth maps to 3D point clouds (PyTorch implementation).
|
| 60 |
+
Assumption: Input depth maps are already cropped by crop_ratio on top and bottom.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
depth_map (torch.Tensor): Input cropped depth map, shape [B, S, H_crop, W, 1]
|
| 64 |
+
original_pano_shape (tuple): Original uncropped panorama size (W_ori, H_ori), default (560, 280)
|
| 65 |
+
crop_ratio (float): Crop ratio of original panorama (top and bottom respectively), default 0.15
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
torch.Tensor: 3D point cloud with shape [B, S, H_crop, W, 3]
|
| 69 |
+
"""
|
| 70 |
+
# Validate input shape
|
| 71 |
+
assert depth_map.dim() == 5 and depth_map.shape[-1] == 1, \
|
| 72 |
+
f"Input must be [B, S, H_crop, W, 1], got {depth_map.shape}"
|
| 73 |
+
|
| 74 |
+
B, S, H_crop, W, _ = depth_map.shape
|
| 75 |
+
W_ori, H_ori = original_pano_shape
|
| 76 |
+
device = depth_map.device # Align tensor device automatically
|
| 77 |
+
|
| 78 |
+
# Generate pixel grid coordinates (H_crop, W)
|
| 79 |
+
px_grid, py_grid = torch.meshgrid(
|
| 80 |
+
torch.arange(W, device=device),
|
| 81 |
+
torch.arange(H_crop, device=device),
|
| 82 |
+
indexing='xy' # Consistent with numpy's meshgrid
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# Restore to original panorama y-coordinates (compensate for cropping)
|
| 86 |
+
crop_top = int(crop_ratio * H_ori)
|
| 87 |
+
py_ori = py_grid + crop_top
|
| 88 |
+
|
| 89 |
+
# Compute spherical coordinates (lat: latitude, long: longitude)
|
| 90 |
+
lat = (py_ori / H_ori - 0.5) * torch.pi
|
| 91 |
+
long = (px_grid / W_ori - 0.5) * 2 * torch.pi
|
| 92 |
+
|
| 93 |
+
# Remove channel dim and compute 3D Cartesian coordinates
|
| 94 |
+
dist = depth_map.squeeze(-1) # [B, S, H_crop, W]
|
| 95 |
+
y = dist * torch.sin(lat)
|
| 96 |
+
tmp = dist * torch.cos(lat)
|
| 97 |
+
x = tmp * torch.sin(long)
|
| 98 |
+
z = tmp * torch.cos(long)
|
| 99 |
+
|
| 100 |
+
# Concatenate to form 3D point cloud
|
| 101 |
+
point_cloud = torch.stack([x, y, z], dim=-1)
|
| 102 |
+
|
| 103 |
+
return point_cloud
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def points_to_pano_depth(points):
|
| 107 |
+
"""
|
| 108 |
+
Convert 3D point cloud back to ray panoramic depth map.
|
| 109 |
+
Ignore the error in direction.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
points (torch.Tensor): Input 3D point cloud, shape [B, S, H, W, 3]
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
torch.Tensor: panoramic depth map, shape [B, S, H, W, 1]
|
| 116 |
+
"""
|
| 117 |
+
# Validate input shape and fill mode
|
| 118 |
+
assert points.dim() == 5 and points.shape[-1] == 3, \
|
| 119 |
+
f"Input point cloud must be [B, S, H, W, 3], got {points.shape}"
|
| 120 |
+
|
| 121 |
+
# Compute radial depth (dist = sqrt(x² + y² + z²))
|
| 122 |
+
dist = torch.norm(points, dim=-1, keepdim=True) # [B, S, H, W, 1]
|
| 123 |
+
|
| 124 |
+
return dist
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def camera_points_to_rotated_points(cam_points, R):
|
| 128 |
+
"""
|
| 129 |
+
Rotate batched panoramic camera point clouds with corresponding rotation matrices.
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
cam_points (torch.Tensor): Input camera 3D point cloud, shape [B, S, H, W, 3]
|
| 133 |
+
R (torch.Tensor): Corresponding rotation matrices, shape [B, S, 3, 3]
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
torch.Tensor: Rotated 3D point cloud, shape [B, S, H, W, 3] (same as input cam_points)
|
| 137 |
+
"""
|
| 138 |
+
# Validate input shapes and dimensions matching
|
| 139 |
+
assert cam_points.dim() == 5 and cam_points.shape[-1] == 3, \
|
| 140 |
+
f"Camera points must be [B, S, H, W, 3], got {cam_points.shape}"
|
| 141 |
+
assert R.dim() == 4 and R.shape[2:] == (3, 3), \
|
| 142 |
+
f"Rotation matrices R must be [B, S, 3, 3], got {R.shape}"
|
| 143 |
+
assert cam_points.shape[:2] == R.shape[:2], \
|
| 144 |
+
f"Batch/Sequence dim mismatch: cam_points {cam_points.shape[:2]} vs R {R.shape[:2]}"
|
| 145 |
+
|
| 146 |
+
# Expand dimensions for broadcasting (align spatial dimensions H, W)
|
| 147 |
+
cam_points_expanded = cam_points.unsqueeze(-1) # [B, S, H, W, 3, 1]
|
| 148 |
+
R_expanded = R.unsqueeze(2).unsqueeze(2) # [B, S, 1, 1, 3, 3]
|
| 149 |
+
|
| 150 |
+
# Batch matrix multiplication: R @ p (rotation operation)
|
| 151 |
+
rotated_points_expanded = torch.matmul(R_expanded, cam_points_expanded)
|
| 152 |
+
|
| 153 |
+
# Squeeze redundant dimension to recover original shape
|
| 154 |
+
rotated_points = rotated_points_expanded.squeeze(-1)
|
| 155 |
+
|
| 156 |
+
return rotated_points
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def rotated_points_to_world_points(rotated_points, t):
|
| 160 |
+
"""
|
| 161 |
+
Transform rotated camera points to world coordinates by adding translation vector.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
rotated_points (torch.Tensor): Rotated 3D point cloud, shape [B, S, H, W, 3]
|
| 165 |
+
t (torch.Tensor): Translation vector, shape [B, S, 3] (per batch-sequence translation)
|
| 166 |
+
|
| 167 |
+
Returns:
|
| 168 |
+
torch.Tensor: World-coordinate 3D point cloud, shape [B, S, H, W, 3] (same as input)
|
| 169 |
+
"""
|
| 170 |
+
# Validate input shapes and dimension matching
|
| 171 |
+
assert rotated_points.dim() == 5 and rotated_points.shape[-1] == 3, \
|
| 172 |
+
f"Rotated points must be [B, S, H, W, 3], got {rotated_points.shape}"
|
| 173 |
+
assert t.dim() == 3 and t.shape[-1] == 3, \
|
| 174 |
+
f"Translation t must be [B, S, 3], got {t.shape}"
|
| 175 |
+
assert rotated_points.shape[:2] == t.shape[:2], \
|
| 176 |
+
f"Batch/Sequence dim mismatch: rotated_points {rotated_points.shape[:2]} vs t {t.shape[:2]}"
|
| 177 |
+
|
| 178 |
+
# Expand translation dimensions for broadcasting with spatial dimensions (H, W)
|
| 179 |
+
# t: [B, S, 3] -> [B, S, 1, 1, 3] (broadcast to H and W)
|
| 180 |
+
t_expanded = t.unsqueeze(2).unsqueeze(2)
|
| 181 |
+
|
| 182 |
+
# Add translation (broadcasting automatically applies t to all H×W points per B-S pair)
|
| 183 |
+
world_points = rotated_points + t_expanded
|
| 184 |
+
|
| 185 |
+
return world_points
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def unproject_depth_to_world_points(depth, extrinsic, size=560):
|
| 190 |
+
'''
|
| 191 |
+
Args:
|
| 192 |
+
depth: [S, H, W, 1]
|
| 193 |
+
extrinsic: [S, 4, 4]
|
| 194 |
+
Returns:
|
| 195 |
+
world_points: [S, H, W, 3]
|
| 196 |
+
'''
|
| 197 |
+
camera_points = pano_depth_to_points(depth, original_pano_shape=(size, size//2))
|
| 198 |
+
rotated_points = camera_points_to_rotated_points(camera_points, extrinsic[:, :, :3, :3])
|
| 199 |
+
world_points = rotated_points_to_world_points(rotated_points, extrinsic[:, :, :3, 3])
|
| 200 |
+
|
| 201 |
+
return world_points
|
argus/utils/normalization.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
from argus.utils.geometry import closed_form_inverse_se3
|
| 4 |
+
|
| 5 |
+
def cal_scale_by_points(points: torch.Tensor, point_masks: torch.Tensor) -> torch.Tensor:
|
| 6 |
+
# Calculate average distance of valid 3D points (batch-wise)
|
| 7 |
+
dist = points.norm(dim=-1)
|
| 8 |
+
dist_sum = (dist * point_masks).sum(dim=[1, 2, 3]) # Shape: [B,]
|
| 9 |
+
valid_count = point_masks.sum(dim=[1, 2, 3])
|
| 10 |
+
avg_scale = (dist_sum / (valid_count + 1e-3)).clamp(min=1e-6, max=1e6)
|
| 11 |
+
return avg_scale
|
| 12 |
+
|
| 13 |
+
def normalize_camera_extrinsics_and_points_batch(
|
| 14 |
+
extrinsics: torch.Tensor,
|
| 15 |
+
cam_points: torch.Tensor,
|
| 16 |
+
depths: torch.Tensor,
|
| 17 |
+
point_masks: torch.Tensor,
|
| 18 |
+
scale_mode: str = "none",
|
| 19 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 20 |
+
# Basic input validation
|
| 21 |
+
assert extrinsics.ndim == 4 and extrinsics.shape[2:] == (4, 4), \
|
| 22 |
+
f"Extrinsics must be (B, S, 4, 4), got {extrinsics.shape}"
|
| 23 |
+
B, S = extrinsics.shape[:2]
|
| 24 |
+
device = extrinsics.device
|
| 25 |
+
|
| 26 |
+
# Step 1: Transform all extrinsics to reference frame (1st frame of each batch)
|
| 27 |
+
ref_extrinsics = extrinsics[:,0,:,:] # (B, 4, 4)
|
| 28 |
+
ref_extr_inv = closed_form_inverse_se3(ref_extrinsics)
|
| 29 |
+
new_extrinsics = torch.matmul(ref_extr_inv.unsqueeze(1), extrinsics) # (B, S, 4, 4) world coordinate
|
| 30 |
+
|
| 31 |
+
# Step 2: Clone tensors to avoid in-place modification
|
| 32 |
+
new_depths = depths.clone()
|
| 33 |
+
new_cam_points = cam_points.clone()
|
| 34 |
+
|
| 35 |
+
# Step 3: Compute rotated/world points from new extrinsics
|
| 36 |
+
R_new = new_extrinsics[:, :, :3, :3] # (B, S, 3, 3)
|
| 37 |
+
t_new = new_extrinsics[:, :, :3, 3] # (B, S, 3)
|
| 38 |
+
new_rotated_points = torch.matmul(R_new.unsqueeze(2).unsqueeze(3), new_cam_points.unsqueeze(-1)).squeeze(-1) # (B,S,1,1,3,3) × (B,S,H,W,3,1) -> (B,S,H,W,3)
|
| 39 |
+
new_world_points = new_rotated_points + t_new.unsqueeze(2).unsqueeze(3)
|
| 40 |
+
|
| 41 |
+
# Step 4: Apply scene scaling
|
| 42 |
+
if scale_mode == "avg_dist":
|
| 43 |
+
avg_scale = cal_scale_by_points(new_world_points, point_masks) # (B,)
|
| 44 |
+
# Reshape scale for broadcasting with different tensor shapes
|
| 45 |
+
scale_3d = avg_scale.view(-1, 1, 1) # For extrinsics (B, S, 4, 4)
|
| 46 |
+
scale_4d = avg_scale.view(-1, 1, 1, 1) # For depths (B, S, H, W)
|
| 47 |
+
scale_5d = avg_scale.view(-1, 1, 1, 1, 1) # For 3D points (B, S, H, W, 3)
|
| 48 |
+
new_extrinsics[:, :, :3, 3] /= scale_3d
|
| 49 |
+
new_depths /= scale_4d
|
| 50 |
+
new_cam_points /= scale_5d
|
| 51 |
+
new_rotated_points /= scale_5d
|
| 52 |
+
new_world_points /= scale_5d
|
| 53 |
+
elif scale_mode == "abs":
|
| 54 |
+
metric_scale = 10.0
|
| 55 |
+
new_extrinsics[:, :, :3, 3] /= metric_scale
|
| 56 |
+
new_depths /= metric_scale
|
| 57 |
+
new_cam_points /= metric_scale
|
| 58 |
+
new_rotated_points /= metric_scale
|
| 59 |
+
new_world_points /= metric_scale
|
| 60 |
+
elif scale_mode == "none":
|
| 61 |
+
pass
|
| 62 |
+
else:
|
| 63 |
+
raise ValueError(f"Unknown scale_mode: {scale_mode}")
|
| 64 |
+
|
| 65 |
+
return new_extrinsics, new_cam_points, new_rotated_points, new_world_points, new_depths
|
argus/utils/pose_enc.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Tuple, Union
|
| 3 |
+
from .rotation import quat_to_mat, mat_to_quat
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def extri_to_pose_encoding360(
|
| 7 |
+
extrinsics: torch.Tensor,
|
| 8 |
+
pose_encoding_type: Union[str, "absT_quaR"] = "absT_quaR"
|
| 9 |
+
) -> torch.Tensor:
|
| 10 |
+
"""
|
| 11 |
+
Convert camera extrinsic parameters to a compact pose encoding (absolute translation + quaternion rotation).
|
| 12 |
+
|
| 13 |
+
Transforms OpenCV-style camera extrinsics (3x4 [R|t] matrix) into a flattened encoding format
|
| 14 |
+
suitable for machine learning tasks like pose prediction or representation learning.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
extrinsics: Camera extrinsic matrices with shape [B, S, 3, 4] or [B, S, 4, 4]
|
| 18 |
+
- B: Batch size
|
| 19 |
+
- S: Sequence length (number of frames)
|
| 20 |
+
- 3x4/4x4: Extrinsic matrix in OpenCV coordinate system (x-right, y-down, z-forward)
|
| 21 |
+
representing the transformation from world to camera space ([R|t] where R=3x3 rotation, t=3x1 translation)
|
| 22 |
+
pose_encoding_type: Type of pose encoding format (only "absT_quaR" supported):
|
| 23 |
+
- "absT_quaR": Absolute translation (3D) + quaternion rotation (4D)
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
Encoded pose tensor with shape [B, S, 7]
|
| 27 |
+
- [:3]: Absolute translation vector (T) in world coordinates
|
| 28 |
+
- [3:7]: Rotation represented as unit quaternion (quat)
|
| 29 |
+
"""
|
| 30 |
+
# Extract rotation matrix (R) and translation vector (T) from extrinsics
|
| 31 |
+
# Handle both 3x4 and 4x4 extrinsic matrix inputs
|
| 32 |
+
R = extrinsics[:, :, :3, :3] # [B, S, 3, 3] - rotation matrix
|
| 33 |
+
T = extrinsics[:, :, :3, 3] # [B, S, 3] - translation vector
|
| 34 |
+
|
| 35 |
+
if pose_encoding_type == "absT_quaR":
|
| 36 |
+
# Convert rotation matrix to quaternion (4D)
|
| 37 |
+
quat = mat_to_quat(R)
|
| 38 |
+
|
| 39 |
+
# Concatenate translation and quaternion to form compact pose encoding
|
| 40 |
+
pose_encoding = torch.cat([T, quat], dim=-1).float()
|
| 41 |
+
else:
|
| 42 |
+
raise NotImplementedError(f"Pose encoding type '{pose_encoding_type}' not supported. Only 'absT_quaR' is implemented.")
|
| 43 |
+
|
| 44 |
+
return pose_encoding
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def pose_encoding_to_extri360(
|
| 48 |
+
pose_encoding: torch.Tensor,
|
| 49 |
+
pose_encoding_type: Union[str, "absT_quaR"] = "absT_quaR"
|
| 50 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 51 |
+
"""
|
| 52 |
+
Convert compact pose encoding back to full camera extrinsic parameters (inverse of extri_to_pose_encoding360).
|
| 53 |
+
|
| 54 |
+
Reconstructs the 4x4 homogeneous extrinsic matrix from the flattened pose encoding,
|
| 55 |
+
including extraction of confidence scores from the encoding's extra dimensions.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
pose_encoding: Encoded pose tensor with shape [B, S, 9]
|
| 59 |
+
- B: Batch size
|
| 60 |
+
- S: Sequence length (number of frames)
|
| 61 |
+
- [:3]: Absolute translation vector (T)
|
| 62 |
+
- [3:7]: Rotation quaternion (quat)
|
| 63 |
+
- [-2:]: Confidence scores for translation and rotation
|
| 64 |
+
pose_encoding_type: Type of pose encoding format (only "absT_quaR" supported):
|
| 65 |
+
- "absT_quaR": Absolute translation (3D) + quaternion rotation (4D)
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
Tuple containing:
|
| 69 |
+
1. extrinsics: Reconstructed camera extrinsic matrices with shape [B, S, 4, 4]
|
| 70 |
+
(homogeneous matrix in OpenCV coordinate system: [R|t; 0 0 0 1])
|
| 71 |
+
2. conf: Confidence scores with shape [B, S, 2]
|
| 72 |
+
- [:, :, 0]: Translation confidence
|
| 73 |
+
- [:, :, 1]: Rotation confidence
|
| 74 |
+
|
| 75 |
+
Raises:
|
| 76 |
+
NotImplementedError: If unsupported pose encoding type is provided
|
| 77 |
+
"""
|
| 78 |
+
if pose_encoding_type == "absT_quaR":
|
| 79 |
+
# Extract translation (T) and rotation quaternion (quat) from pose encoding
|
| 80 |
+
T = pose_encoding[..., :3] # [B, S, 3] - translation vector
|
| 81 |
+
quat = pose_encoding[..., 3:7] # [B, S, 4] - rotation quaternion
|
| 82 |
+
|
| 83 |
+
# Convert quaternion back to rotation matrix (3x3)
|
| 84 |
+
R = quat_to_mat(quat) # [B, S, 3, 3]
|
| 85 |
+
|
| 86 |
+
# Reconstruct 3x4 [R|t] matrix (rotation + translation)
|
| 87 |
+
extri_3x4 = torch.cat([R, T[..., None]], dim=-1) # [B, S, 3, 4]
|
| 88 |
+
|
| 89 |
+
# Add homogeneous row [0, 0, 0, 1] to form 4x4 extrinsic matrix
|
| 90 |
+
batch_size, seq_len = extri_3x4.shape[:2]
|
| 91 |
+
homogenous_row = torch.tensor(
|
| 92 |
+
[0, 0, 0, 1],
|
| 93 |
+
device=extri_3x4.device,
|
| 94 |
+
dtype=extri_3x4.dtype
|
| 95 |
+
).expand(batch_size, seq_len, 1, 4) # [B, S, 1, 4]
|
| 96 |
+
|
| 97 |
+
# Combine to form 4x4 homogeneous extrinsic matrix
|
| 98 |
+
extrinsics = torch.cat((extri_3x4, homogenous_row), dim=2) # [B, S, 4, 4]
|
| 99 |
+
|
| 100 |
+
# Extract confidence scores (last two dimensions of pose encoding)
|
| 101 |
+
conf = pose_encoding[..., -2:] # [B, S, 2]
|
| 102 |
+
|
| 103 |
+
return extrinsics, conf
|
| 104 |
+
|
| 105 |
+
raise NotImplementedError(f"Pose encoding type '{pose_encoding_type}' not supported. Only 'absT_quaR' is implemented.")
|
argus/utils/rotation.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
|
| 7 |
+
"""
|
| 8 |
+
Quaternion Order: XYZW or say ijkr, scalar-last
|
| 9 |
+
|
| 10 |
+
Convert rotations given as quaternions to rotation matrices.
|
| 11 |
+
Args:
|
| 12 |
+
quaternions: quaternions with real part last,
|
| 13 |
+
as tensor of shape (..., 4).
|
| 14 |
+
|
| 15 |
+
Returns:
|
| 16 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
| 17 |
+
"""
|
| 18 |
+
# Normalize quaternions to unit length
|
| 19 |
+
quaternions = F.normalize(quaternions, dim=-1)
|
| 20 |
+
|
| 21 |
+
i, j, k, r = torch.unbind(quaternions, -1)
|
| 22 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
| 23 |
+
|
| 24 |
+
o = torch.stack(
|
| 25 |
+
(
|
| 26 |
+
1 - two_s * (j * j + k * k),
|
| 27 |
+
two_s * (i * j - k * r),
|
| 28 |
+
two_s * (i * k + j * r),
|
| 29 |
+
two_s * (i * j + k * r),
|
| 30 |
+
1 - two_s * (i * i + k * k),
|
| 31 |
+
two_s * (j * k - i * r),
|
| 32 |
+
two_s * (i * k - j * r),
|
| 33 |
+
two_s * (j * k + i * r),
|
| 34 |
+
1 - two_s * (i * i + j * j),
|
| 35 |
+
),
|
| 36 |
+
-1,
|
| 37 |
+
)
|
| 38 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
|
| 42 |
+
"""
|
| 43 |
+
Convert rotations given as rotation matrices to quaternions.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
quaternions with real part last, as tensor of shape (..., 4).
|
| 50 |
+
Quaternion Order: XYZW or say ijkr, scalar-last
|
| 51 |
+
"""
|
| 52 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
| 53 |
+
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
| 54 |
+
|
| 55 |
+
batch_dim = matrix.shape[:-2]
|
| 56 |
+
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
|
| 57 |
+
|
| 58 |
+
q_abs = _sqrt_positive_part(
|
| 59 |
+
torch.stack(
|
| 60 |
+
[1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1
|
| 61 |
+
)
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# we produce the desired quaternion multiplied by each of r, i, j, k
|
| 65 |
+
quat_by_rijk = torch.stack(
|
| 66 |
+
[
|
| 67 |
+
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
| 68 |
+
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
| 69 |
+
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
| 70 |
+
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
| 71 |
+
],
|
| 72 |
+
dim=-2,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
|
| 76 |
+
# the candidate won't be picked.
|
| 77 |
+
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
| 78 |
+
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
| 79 |
+
|
| 80 |
+
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
| 81 |
+
# forall i; we pick the best-conditioned one (with the largest denominator)
|
| 82 |
+
out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
|
| 83 |
+
|
| 84 |
+
# Convert from rijk to ijkr
|
| 85 |
+
out = out[..., [1, 2, 3, 0]]
|
| 86 |
+
|
| 87 |
+
out = standardize_quaternion(out)
|
| 88 |
+
|
| 89 |
+
return out
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
| 93 |
+
"""
|
| 94 |
+
Returns torch.sqrt(torch.max(0, x))
|
| 95 |
+
but with a zero subgradient where x is 0.
|
| 96 |
+
"""
|
| 97 |
+
ret = torch.zeros_like(x)
|
| 98 |
+
positive_mask = x > 0
|
| 99 |
+
if torch.is_grad_enabled():
|
| 100 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
| 101 |
+
else:
|
| 102 |
+
ret = torch.where(positive_mask, torch.sqrt(x), ret)
|
| 103 |
+
return ret
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
|
| 107 |
+
"""
|
| 108 |
+
Convert a unit quaternion to a standard form: one in which the real
|
| 109 |
+
part is non negative.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
quaternions: Quaternions with real part last,
|
| 113 |
+
as tensor of shape (..., 4).
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
Standardized quaternions as tensor of shape (..., 4).
|
| 117 |
+
"""
|
| 118 |
+
return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
|
assets/argus_logo.png
ADDED
|
Git LFS Details
|
examples/far_4/0.jpg
ADDED
|
Git LFS Details
|
examples/far_4/1.jpg
ADDED
|
Git LFS Details
|
examples/far_4/2.jpg
ADDED
|
Git LFS Details
|
examples/far_4/3.jpg
ADDED
|
Git LFS Details
|
examples/scene_00008/1757748389.jpg
ADDED
|
Git LFS Details
|
examples/scene_00008/1757748429.jpg
ADDED
|
Git LFS Details
|
examples/scene_00008/1757748477.jpg
ADDED
|
Git LFS Details
|
examples/scene_00008/1757748528.jpg
ADDED
|
Git LFS Details
|
examples/scene_00008/1757748562.jpg
ADDED
|
Git LFS Details
|
examples/scene_00008/1757748600.jpg
ADDED
|
Git LFS Details
|
examples/scene_00008/1757748638.jpg
ADDED
|
Git LFS Details
|
examples/scene_00008/1757748685.jpg
ADDED
|
Git LFS Details
|
examples/scene_00008/1757748728.jpg
ADDED
|
Git LFS Details
|
examples/scene_00008/1757748770.jpg
ADDED
|
Git LFS Details
|
examples/scene_00008/1757748817.jpg
ADDED
|
Git LFS Details
|
examples/scene_00008/1757748866.jpg
ADDED
|
Git LFS Details
|
examples/scene_00008/1757748907.jpg
ADDED
|
Git LFS Details
|
examples/scene_00008/1757748959.jpg
ADDED
|
Git LFS Details
|
examples/scene_00008/1757749004.jpg
ADDED
|
Git LFS Details
|
examples/scene_00008/1757749043.jpg
ADDED
|
Git LFS Details
|
examples/scene_00008/1757749091.jpg
ADDED
|
Git LFS Details
|
examples/scene_00008/1757749140.jpg
ADDED
|
Git LFS Details
|