lixi042 commited on
Commit
510e990
·
1 Parent(s): a4415c0

Initial commit: Argus metric panoramic 3D reconstruction demo

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. app.py +1499 -0
  3. argus/__init__.py +2 -0
  4. argus/heads/__init__.py +2 -0
  5. argus/heads/camera_head.py +142 -0
  6. argus/heads/dpt_head.py +474 -0
  7. argus/heads/head_act.py +122 -0
  8. argus/heads/utils.py +142 -0
  9. argus/layers/__init__.py +8 -0
  10. argus/layers/attention.py +93 -0
  11. argus/layers/block.py +247 -0
  12. argus/layers/drop_path.py +34 -0
  13. argus/layers/layer_scale.py +22 -0
  14. argus/layers/mlp.py +40 -0
  15. argus/layers/patch_embed.py +85 -0
  16. argus/layers/rope.py +188 -0
  17. argus/layers/swiglu_ffn.py +67 -0
  18. argus/layers/vision_transformer.py +401 -0
  19. argus/models/__init__.py +2 -0
  20. argus/models/aggregator.py +502 -0
  21. argus/models/argus.py +234 -0
  22. argus/utils/__init__.py +2 -0
  23. argus/utils/data_io.py +152 -0
  24. argus/utils/geometry.py +201 -0
  25. argus/utils/normalization.py +65 -0
  26. argus/utils/pose_enc.py +105 -0
  27. argus/utils/rotation.py +118 -0
  28. assets/argus_logo.png +3 -0
  29. examples/far_4/0.jpg +3 -0
  30. examples/far_4/1.jpg +3 -0
  31. examples/far_4/2.jpg +3 -0
  32. examples/far_4/3.jpg +3 -0
  33. examples/scene_00008/1757748389.jpg +3 -0
  34. examples/scene_00008/1757748429.jpg +3 -0
  35. examples/scene_00008/1757748477.jpg +3 -0
  36. examples/scene_00008/1757748528.jpg +3 -0
  37. examples/scene_00008/1757748562.jpg +3 -0
  38. examples/scene_00008/1757748600.jpg +3 -0
  39. examples/scene_00008/1757748638.jpg +3 -0
  40. examples/scene_00008/1757748685.jpg +3 -0
  41. examples/scene_00008/1757748728.jpg +3 -0
  42. examples/scene_00008/1757748770.jpg +3 -0
  43. examples/scene_00008/1757748817.jpg +3 -0
  44. examples/scene_00008/1757748866.jpg +3 -0
  45. examples/scene_00008/1757748907.jpg +3 -0
  46. examples/scene_00008/1757748959.jpg +3 -0
  47. examples/scene_00008/1757749004.jpg +3 -0
  48. examples/scene_00008/1757749043.jpg +3 -0
  49. examples/scene_00008/1757749091.jpg +3 -0
  50. examples/scene_00008/1757749140.jpg +3 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
37
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
38
+ *.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,1499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Standard library imports
2
+ import os
3
+ import sys
4
+ import shutil
5
+ import glob
6
+ import gc
7
+ import time
8
+ import base64
9
+ import argparse
10
+ import tempfile
11
+ from datetime import datetime
12
+ from pathlib import Path
13
+
14
+ # Third-party library imports
15
+ import cv2
16
+ import torch
17
+ import trimesh
18
+ import numpy as np
19
+ import gradio as gr
20
+ import matplotlib
21
+ import matplotlib.pyplot as plt
22
+ from scipy.spatial.transform import Rotation
23
+
24
+ # Custom module imports
25
+ from argus.models.argus import Argus
26
+ from argus.utils.pose_enc import pose_encoding_to_extri360
27
+ from argus.utils.geometry import unproject_depth_to_world_points
28
+
29
+
30
+ # -------------------------- Argument Parsing --------------------------
31
+ def parse_args():
32
+ parser = argparse.ArgumentParser(description="Argus Gradio Demo")
33
+ parser.add_argument(
34
+ "--model_path",
35
+ type=str,
36
+ default=None,
37
+ help="Path to pre-trained model weights (.pt file). "
38
+ "If not specified, auto-downloads from HuggingFace.",
39
+ )
40
+ parser.add_argument(
41
+ "--img_size",
42
+ type=int,
43
+ default=560,
44
+ help="Input panoramic image target width (height = width // 2)",
45
+ )
46
+ parser.add_argument(
47
+ "--crop_ratio",
48
+ type=float,
49
+ default=0.15,
50
+ help="Vertical crop ratio for panoramic image preprocessing (0-0.5)",
51
+ )
52
+ parser.add_argument(
53
+ "--port",
54
+ type=int,
55
+ default=7860,
56
+ help="Port number for Gradio server",
57
+ )
58
+ parser.add_argument(
59
+ "--share",
60
+ action="store_true",
61
+ default=False,
62
+ help="Enable Gradio public sharing link",
63
+ )
64
+ parser.add_argument(
65
+ "--server_name",
66
+ type=str,
67
+ default="0.0.0.0",
68
+ help="Server host address (0.0.0.0 for all interfaces)",
69
+ )
70
+ parser.add_argument(
71
+ "--device",
72
+ type=str,
73
+ default=None,
74
+ help="Device to use (cuda/cpu). Default: auto-detect",
75
+ )
76
+ parser.add_argument(
77
+ "--examples_dir",
78
+ type=str,
79
+ default="examples",
80
+ help="Directory containing example scenes",
81
+ )
82
+ parser.add_argument(
83
+ "--save_tmp",
84
+ type=str,
85
+ default=None,
86
+ help="Directory to persist intermediate files (images, predictions, GLB). "
87
+ "If not set, uses system temp dir and cleans up automatically.",
88
+ )
89
+ return parser.parse_args()
90
+
91
+
92
+ args = parse_args()
93
+
94
+ # -------------------------- Global Configuration --------------------------
95
+ # Device configuration: use specified device or auto-detect
96
+ DEVICE = args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu")
97
+ # Input panoramic image target size (ERP: W=img_size, H=img_size//2)
98
+ IMG_SIZE = args.img_size
99
+ # Vertical crop ratio for panoramic image preprocessing
100
+ CROP_RATIO = args.crop_ratio
101
+
102
+
103
+ def resolve_model_path(model_path: str) -> str:
104
+ """
105
+ Resolve model path: if a local file is specified and exists, use it directly;
106
+ otherwise download from HuggingFace Hub.
107
+ Requires `huggingface-cli login` for gated repos.
108
+ """
109
+ if model_path is not None and os.path.isfile(model_path):
110
+ return model_path
111
+
112
+ if model_path is not None:
113
+ print(f"Specified model path '{model_path}' not found.")
114
+
115
+ print("Downloading model from HuggingFace (RealseeTechnology/argus-realsee3d)...")
116
+ try:
117
+ from huggingface_hub import hf_hub_download
118
+ downloaded_path = hf_hub_download(
119
+ repo_id="RealseeTechnology/argus-realsee3d",
120
+ filename="argus_realsee3d.pt",
121
+ )
122
+ print(f"Model downloaded to: {downloaded_path}")
123
+ return downloaded_path
124
+ except Exception as e:
125
+ error_msg = str(e)
126
+ if "GatedRepoError" in type(e).__name__ or "401" in error_msg:
127
+ raise RuntimeError(
128
+ "Cannot access gated model repo. Please authenticate first:\n"
129
+ " 1. Run: hf auth login\n"
130
+ " 2. Accept the model license at: https://huggingface.co/RealseeTechnology/argus-realsee3d\n"
131
+ " 3. Re-run this script.\n"
132
+ "Or download manually and specify --model_path."
133
+ ) from e
134
+ raise
135
+
136
+
137
+ # Pre-trained model path (auto-download if not found locally)
138
+ MODEL_PATH = resolve_model_path(args.model_path)
139
+
140
+ # -------------------------- Model Initialization --------------------------
141
+ print("Initializing and loading Argus model...")
142
+ # Initialize Argus model with metric scale and learning ref reorder
143
+ model = Argus(reorder_by_learning_ref=True, restore_metric_scale=True)
144
+ # Load model weights (non-strict to ignore unused parameters)
145
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE)["model"], strict=False)
146
+ # Set model to evaluation mode and move to target device
147
+ model.eval()
148
+ model = model.to(DEVICE)
149
+
150
+
151
+ # -------------------------- Image Preprocessing --------------------------
152
+ def load_and_preprocess_images(image_path_list, target_size=IMG_SIZE):
153
+ """
154
+ Load and preprocess panoramic images for model inference
155
+ Args:
156
+ image_path_list (list): List of input image file paths
157
+ target_size (int): Target width of panoramic image (height = target_size//2)
158
+ Returns:
159
+ torch.Tensor: Preprocessed tensor with shape (S, C, H, W)
160
+ S: sequence length, C: 3(RGB), H/W: image size
161
+ """
162
+ images = []
163
+ pano_W, pano_H = target_size, target_size // 2
164
+
165
+ # Load and resize each image
166
+ for image_path in image_path_list:
167
+ img = cv2.imread(image_path) # Load as BGR (H, W, C)
168
+ h, w = img.shape[:2]
169
+ if w != pano_W or h != pano_H:
170
+ img = cv2.resize(img, (pano_W, pano_H), interpolation=cv2.INTER_AREA)
171
+ images.append(img)
172
+
173
+ # Stack and preprocess: crop vertical → BGR2RGB → normalize → reshape
174
+ images = np.stack(images) # (S, H, W, C)
175
+ # Crop top/bottom 15% of height and convert BGR to RGB
176
+ images = np.ascontiguousarray(
177
+ images[:, int(pano_H * CROP_RATIO) : int(pano_H * (1 - CROP_RATIO)), :, ::-1]
178
+ )
179
+ # Convert to tensor and normalize to [0,1]
180
+ images = torch.from_numpy(images).float() / 255.0
181
+ # Reshape to (S, C, H, W) for PyTorch model input
182
+ images = images.permute(0, 3, 1, 2)
183
+
184
+ return images
185
+
186
+
187
+ # -------------------------- Point Cloud Utils --------------------------
188
+ def save_point_cloud_to_ply(points: np.ndarray, save_path: str):
189
+ """
190
+ Save 3D point cloud (N,3) to PLY format (ASCII) for universal compatibility
191
+ Args:
192
+ points (np.ndarray): 3D point cloud with shape [N, 3] (x, y, z for each point)
193
+ save_path (str): Output PLY file path
194
+ Raises:
195
+ ValueError: If input points shape is not [N, 3]
196
+ """
197
+ # Validate input point cloud shape
198
+ if points.ndim != 2 or points.shape[1] != 3:
199
+ raise ValueError(f"Point cloud must be [N,3], got {points.shape}")
200
+
201
+ num_points = points.shape[0]
202
+ # PLY format header (follow official specification)
203
+ ply_header = f"""ply
204
+ format ascii 1.0
205
+ element vertex {num_points}
206
+ property float x
207
+ property float y
208
+ property float z
209
+ end_header
210
+ """
211
+ # Write header and point data to file
212
+ with open(save_path, "w", encoding="utf-8") as f:
213
+ f.write(ply_header)
214
+ np.savetxt(f, points, fmt="%.6f %.6f %.6f")
215
+
216
+
217
+ # -------------------------- Core Model Inference --------------------------
218
+ def run_model(target_dir, model) -> dict:
219
+ """
220
+ Run Argus model inference on images in target_dir/images
221
+ Args:
222
+ target_dir (str): Root directory containing 'images' subfolder
223
+ model (Argus): Pre-initialized Argus model
224
+ Returns:
225
+ dict: Model predictions with tensor converted to numpy array
226
+ Raises:
227
+ ValueError: If CUDA unavailable or no images found in target_dir
228
+ """
229
+ print(f"Processing images from {target_dir}")
230
+
231
+ # Enforce CUDA for inference
232
+ if not torch.cuda.is_available():
233
+ raise ValueError("CUDA is not available. Inference requires GPU acceleration.")
234
+
235
+ model = model.to(DEVICE)
236
+ model.eval()
237
+
238
+ # Load and sort input images
239
+ image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
240
+ print(f"Found {len(image_names)} input images")
241
+ if len(image_names) == 0:
242
+ raise ValueError("No images found in target_dir/images. Check your upload.")
243
+
244
+ # Preprocess images and move to device
245
+ images = load_and_preprocess_images(image_names, target_size=IMG_SIZE).to(DEVICE)
246
+ print(f"Preprocessed images shape: {images.shape}")
247
+
248
+ # Mixed precision inference for speed and memory efficiency
249
+ print("Running model inference...")
250
+ dtype = (
251
+ torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
252
+ )
253
+
254
+ torch.cuda.synchronize()
255
+ t0 = time.perf_counter()
256
+
257
+ with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
258
+ predictions = model(images)
259
+
260
+ torch.cuda.synchronize()
261
+ t1 = time.perf_counter()
262
+ inference_time = t1 - t0
263
+ print(f"Inference time: {inference_time:.3f} s")
264
+
265
+ # Convert pose encoding to extrinsic/intrinsic matrices
266
+ print("Converting pose encoding to extrinsic matrices...")
267
+ extrinsic, conf = pose_encoding_to_extri360(pose_encoding=predictions["pose_enc"])
268
+ predictions["extrinsic"] = extrinsic[:, :, :3, :]
269
+
270
+ # Unproject depth map to 3D world coordinates
271
+ print("Computing 3D world points from depth map...")
272
+ world_points = unproject_depth_to_world_points(
273
+ predictions["depth"], predictions["extrinsic"], size=IMG_SIZE
274
+ )
275
+ predictions["world_points_from_depth"] = world_points
276
+
277
+ # Convert all torch tensors to numpy arrays and remove batch dimension
278
+ print("Converting model outputs to numpy arrays...")
279
+ for key in predictions.keys():
280
+ if isinstance(predictions[key], torch.Tensor):
281
+ predictions[key] = predictions[key].cpu().float().numpy().squeeze(0)
282
+ elif isinstance(predictions[key], list):
283
+ for i in range(len(predictions[key])):
284
+ if isinstance(predictions[key][i], torch.Tensor):
285
+ predictions[key][i] = (
286
+ predictions[key][i].cpu().float().numpy().squeeze(0)
287
+ )
288
+
289
+ print(f"Model prediction keys: {predictions.keys()}")
290
+ # Clear CUDA cache to save memory
291
+ torch.cuda.empty_cache()
292
+ return predictions, inference_time
293
+
294
+
295
+ # -------------------------- Upload File Handling --------------------------
296
+ def handle_uploads(input_images):
297
+ """
298
+ Create directory for uploaded images and copy files to target path.
299
+ Uses system temp dir by default; uses --save_tmp dir if specified.
300
+ Args:
301
+ input_images: Gradio uploaded file data
302
+ Returns:
303
+ tuple: (target_dir, sorted_image_paths)
304
+ """
305
+ start_time = time.time()
306
+ gc.collect()
307
+ torch.cuda.empty_cache()
308
+
309
+ # Create target directory: persistent if --save_tmp is set, otherwise temp
310
+ if args.save_tmp:
311
+ os.makedirs(args.save_tmp, exist_ok=True)
312
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
313
+ target_dir = os.path.join(args.save_tmp, f"input_images_{timestamp}")
314
+ else:
315
+ target_dir = tempfile.mkdtemp(prefix="argus_")
316
+ target_img_dir = os.path.join(target_dir, "images")
317
+
318
+ # Clean up if directory exists (edge case)
319
+ if os.path.exists(target_dir) and args.save_tmp:
320
+ shutil.rmtree(target_dir)
321
+ os.makedirs(target_dir, exist_ok=True)
322
+ os.makedirs(target_img_dir, exist_ok=True)
323
+
324
+ # Copy uploaded images to target directory
325
+ image_paths = []
326
+ if input_images is not None:
327
+ for file_data in input_images:
328
+ # Get file path from Gradio file data
329
+ file_path = file_data["name"] if isinstance(file_data, dict) else file_data
330
+ dst_path = os.path.join(target_img_dir, os.path.basename(file_path))
331
+ shutil.copy(file_path, dst_path)
332
+ image_paths.append(dst_path)
333
+
334
+ # Sort images for consistent processing
335
+ image_paths = sorted(image_paths)
336
+ print(
337
+ f"Files copied to {target_img_dir} | Time cost: {time.time() - start_time:.3f}s"
338
+ )
339
+ return target_dir, image_paths
340
+
341
+
342
+ def update_gallery_on_upload(input_images):
343
+ """
344
+ Update image gallery immediately after file upload
345
+ Args:
346
+ input_images: Gradio uploaded file data
347
+ Returns:
348
+ tuple: Gradio component update values
349
+ """
350
+ if not input_images:
351
+ return None, None, None, None
352
+ target_dir, image_paths = handle_uploads(input_images)
353
+ return (
354
+ None,
355
+ target_dir,
356
+ image_paths,
357
+ "Upload complete. Click 'Reconstruct' to begin 3D processing.",
358
+ )
359
+
360
+
361
+ # -------------------------- 3D Reconstruction Pipeline --------------------------
362
+ def gradio_demo(
363
+ target_dir,
364
+ conf_thres=5.0,
365
+ frame_filter="All",
366
+ show_cam=True,
367
+ show_index=True,
368
+ ceiling_remove=25,
369
+ ):
370
+ """
371
+ Main 3D reconstruction pipeline for Gradio interface
372
+ Args:
373
+ target_dir (str): Directory with input images
374
+ conf_thres (float): Confidence threshold for point cloud filtering
375
+ frame_filter (str): Filter frames to show in 3D model
376
+ show_cam (bool): Whether to show camera poses in 3D model
377
+ show_index (bool): Whether to show frame indices in 3D model
378
+ ceiling_remove (float): Percentage of top Y-coordinate points to remove as ceiling (0-100, 0=disabled)
379
+ Returns:
380
+ tuple: Gradio component update values (3D model, logs, dropdown, etc.)
381
+ """
382
+ # Validate target directory
383
+ if not os.path.isdir(target_dir) or target_dir == "None":
384
+ return (
385
+ None,
386
+ "No valid target directory. Please upload images first.",
387
+ None,
388
+ None,
389
+ None,
390
+ "",
391
+ None,
392
+ )
393
+
394
+ start_time = time.time()
395
+ gc.collect()
396
+ torch.cuda.empty_cache()
397
+
398
+ # Prepare frame filter dropdown options
399
+ target_img_dir = os.path.join(target_dir, "images")
400
+ all_files = (
401
+ sorted(os.listdir(target_img_dir)) if os.path.isdir(target_img_dir) else []
402
+ )
403
+ all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
404
+ frame_filter_choices = ["All"] + all_files
405
+
406
+ # Run model inference
407
+ with torch.no_grad():
408
+ predictions, inference_time = run_model(target_dir, model)
409
+
410
+ # Save predictions to NPZ for later visualization update
411
+ pred_save_path = os.path.join(target_dir, "predictions.npz")
412
+ np.savez(pred_save_path, **predictions)
413
+
414
+ # Default frame filter to All if None
415
+ frame_filter = frame_filter if frame_filter is not None else "All"
416
+
417
+ # Generate unique GLB filename with parameters
418
+ glb_filename = f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_index{show_index}_ceiling{ceiling_remove}.glb"
419
+ glbfile = os.path.join(target_dir, glb_filename)
420
+
421
+ # Convert model predictions to GLB 3D model
422
+ glbscene = predictions_to_glb(
423
+ predictions,
424
+ conf_thres=conf_thres,
425
+ filter_by_frames=frame_filter,
426
+ show_cam=show_cam,
427
+ show_index=show_index,
428
+ ceiling_remove=ceiling_remove,
429
+ target_dir=target_dir,
430
+ )
431
+ glbscene.export(file_obj=glbfile)
432
+
433
+ # Prepare measure view
434
+ measure_img, _ = update_measure_view(predictions, 0)
435
+ # Create view selector based on number of input images
436
+ num_views = (
437
+ predictions["images"].shape[0] if predictions["images"].shape[0] > 0 else 1
438
+ )
439
+ view_choices = [f"View {i + 1}" for i in range(num_views)]
440
+ measure_selector = gr.Dropdown(choices=view_choices, value=view_choices[0])
441
+
442
+ # Clean up memory
443
+ gc.collect()
444
+ torch.cuda.empty_cache()
445
+
446
+ total_time = time.time() - start_time
447
+ log_msg = f"Reconstruction Success ({len(all_files)} frames). Inference: {inference_time:.2f}s | Total: {total_time:.2f}s"
448
+ print(f"Reconstruction complete | Inference: {inference_time:.2f}s | Total: {total_time:.2f}s")
449
+
450
+ return (
451
+ glbfile,
452
+ log_msg,
453
+ gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
454
+ predictions,
455
+ measure_img,
456
+ "",
457
+ measure_selector,
458
+ )
459
+
460
+
461
+ # -------------------------- UI Utility Functions --------------------------
462
+ def clear_fields():
463
+ """Clear 3D model viewer for Gradio interface"""
464
+ return None
465
+
466
+
467
+ def update_log():
468
+ """Update log message during model processing"""
469
+ return "Loading and Reconstructing..."
470
+
471
+
472
+ def update_visualization(
473
+ target_dir,
474
+ conf_thres,
475
+ frame_filter,
476
+ show_cam,
477
+ show_index,
478
+ ceiling_remove,
479
+ is_example,
480
+ ):
481
+ """
482
+ Update 3D visualization when parameters change (without re-running model)
483
+ Args:
484
+ is_example (str): Whether it's example data (skip if "True")
485
+ Returns:
486
+ tuple: (GLB file path, log message)
487
+ """
488
+ # Skip if loading example data
489
+ if is_example == "True":
490
+ return (
491
+ None,
492
+ "No reconstruction available. Please click the Reconstruct button first.",
493
+ )
494
+ # Validate target directory and prediction file
495
+ if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
496
+ return None, "No valid reconstruction. Please upload and reconstruct first."
497
+
498
+ pred_path = os.path.join(target_dir, "predictions.npz")
499
+ if not os.path.exists(pred_path):
500
+ return None, f"No prediction file found at {pred_path}. Run Reconstruct first."
501
+
502
+ # Load saved predictions
503
+ key_list = [
504
+ "pose_enc",
505
+ "depth",
506
+ "depth_conf",
507
+ "images",
508
+ "extrinsic",
509
+ "world_points_from_depth",
510
+ ]
511
+ loaded = np.load(pred_path)
512
+ predictions = {key: np.array(loaded[key]) for key in key_list if key in loaded}
513
+
514
+ # Generate GLB file (create if not exists)
515
+ glb_filename = f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_index{show_index}_ceiling{ceiling_remove}.glb"
516
+ glbfile = os.path.join(target_dir, glb_filename)
517
+
518
+ if not os.path.exists(glbfile):
519
+ glbscene = predictions_to_glb(
520
+ predictions,
521
+ conf_thres=conf_thres,
522
+ filter_by_frames=frame_filter,
523
+ show_cam=show_cam,
524
+ show_index=show_index,
525
+ ceiling_remove=ceiling_remove,
526
+ target_dir=target_dir,
527
+ )
528
+ glbscene.export(file_obj=glbfile)
529
+
530
+ return glbfile, "Visualization updated successfully"
531
+
532
+
533
+ # -------------------------- Metric Measurement --------------------------
534
+ def update_measure_view(predictions, view_index):
535
+ """
536
+ Update measure view with depth confidence mask overlay
537
+ Args:
538
+ predictions (dict): Model predictions with images and depth confidence
539
+ view_index (int): Index of the view to show
540
+ Returns:
541
+ tuple: (processed_image, empty_list)
542
+ """
543
+ # Get image and depth confidence
544
+ image = predictions["images"][view_index].transpose(1, 2, 0).copy()
545
+ depth_conf = predictions["depth_conf"][view_index].copy()
546
+
547
+ # Convert image to uint8 format
548
+ if image.dtype != np.uint8:
549
+ image = (
550
+ (image * 255).astype(np.uint8)
551
+ if image.max() <= 1.0
552
+ else image.astype(np.uint8)
553
+ )
554
+
555
+ # Create depth confidence mask (filter low confidence areas)
556
+ depth_conf_norm = (depth_conf - depth_conf.min()) / (
557
+ depth_conf.max() - depth_conf.min()
558
+ )
559
+ mask = depth_conf_norm > 0.05
560
+ invalid_mask = ~mask
561
+
562
+ # Apply red overlay to invalid areas (low confidence)
563
+ if invalid_mask.any():
564
+ overlay_color = np.array([255, 220, 220], dtype=np.uint8)
565
+ alpha = 0.5 # Transparency
566
+ for c in range(3):
567
+ image[:, :, c] = np.where(
568
+ invalid_mask,
569
+ (1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
570
+ image[:, :, c],
571
+ ).astype(np.uint8)
572
+
573
+ return image, []
574
+
575
+
576
+ def navigate_measure_view(processed_data, current_selector_value, direction):
577
+ """
578
+ Navigate between different measure views (previous/next)
579
+ Args:
580
+ direction (int): -1 for previous, +1 for next
581
+ Returns:
582
+ tuple: (new_selector_value, measure_image, empty_points)
583
+ """
584
+ if processed_data["images"].shape[0] == 0:
585
+ return "View 1", None, []
586
+
587
+ # Parse current view index from selector
588
+ try:
589
+ current_view = int(current_selector_value.split()[1]) - 1
590
+ except:
591
+ current_view = 0
592
+
593
+ # Calculate new view index (circular navigation)
594
+ num_views = processed_data["images"].shape[0]
595
+ new_view = (current_view + direction) % num_views
596
+
597
+ # Update selector and image
598
+ new_selector = f"View {new_view + 1}"
599
+ measure_image, _ = update_measure_view(processed_data, new_view)
600
+ return new_selector, measure_image, []
601
+
602
+
603
+ def measure(
604
+ processed_data, measure_points, current_view_selector, event: gr.SelectData
605
+ ):
606
+ """
607
+ Core metric measurement function: click to select points and calculate 3D distance
608
+ Args:
609
+ event (gr.SelectData): Gradio click event data (image coordinates)
610
+ Returns:
611
+ tuple: (annotated_image, measure_points, measurement_text)
612
+ """
613
+ try:
614
+ # Get current view index
615
+ try:
616
+ current_view = int(current_view_selector.split()[1]) - 1
617
+ except:
618
+ current_view = 0
619
+ # Validate view index
620
+ current_view = (
621
+ 0
622
+ if current_view < 0 or current_view >= processed_data["images"].shape[0]
623
+ else current_view
624
+ )
625
+
626
+ # Get clicked 2D point
627
+ point2d = event.index[0], event.index[1]
628
+ measure_points.append(point2d)
629
+ print(f"Measuring: clicked point {point2d} (view {current_view + 1})")
630
+
631
+ # Get base image and 3D points
632
+ image, _ = update_measure_view(processed_data, current_view)
633
+ image = image.copy()
634
+ points3d = processed_data["world_points_from_depth"][current_view]
635
+
636
+ # Draw blue circles for clicked points
637
+ for p in measure_points:
638
+ if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
639
+ image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)
640
+
641
+ # Calculate depth for single point
642
+ depth_text = ""
643
+ depth = processed_data["depth"][current_view].squeeze(axis=-1)
644
+ for i, p in enumerate(measure_points):
645
+ try:
646
+ if 0 <= p[1] < depth.shape[0] and 0 <= p[0] < depth.shape[1]:
647
+ d = depth[p[1], p[0]]
648
+ depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
649
+ else:
650
+ d = np.linalg.norm(points3d[p[1], p[0]], ord=2)
651
+ depth_text += f"- **P{i + 1} dist: {d:.2f}m.**\n"
652
+ except:
653
+ depth_text += f"- **P{i + 1}: Depth unavailable**\n"
654
+
655
+ # Calculate 3D distance for two points
656
+ if len(measure_points) == 2:
657
+ p1, p2 = measure_points
658
+ # Draw blue line between two points
659
+ if all(
660
+ 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]
661
+ for p in [p1, p2]
662
+ ):
663
+ image = cv2.line(image, p1, p2, color=(255, 0, 0), thickness=2)
664
+ # Calculate 3D Euclidean distance
665
+ try:
666
+ p1_3d = points3d[p1[1], p1[0]]
667
+ p2_3d = points3d[p2[1], p2[0]]
668
+ distance = np.linalg.norm(p1_3d - p2_3d)
669
+ distance_text = f"- **Distance: {distance:.2f}m**"
670
+ except:
671
+ distance_text = "- **Distance: Unable to compute**"
672
+ # Reset points after measurement
673
+ measure_points = []
674
+ return [image, measure_points, depth_text + distance_text]
675
+
676
+ return [image, measure_points, depth_text]
677
+ except Exception as e:
678
+ print(f"Measurement error: {str(e)}")
679
+ return None, [], f"Measure error: {str(e)}"
680
+
681
+
682
+ # -------------------------- Example Data Loader --------------------------
683
+ def get_scene_info(examples_dir):
684
+ """
685
+ Load example scene information from examples directory
686
+ Args:
687
+ examples_dir (str): Directory containing example scenes
688
+ Returns:
689
+ list: List of scene dicts with name, path, thumbnail, image files
690
+ """
691
+ scenes = []
692
+ if not os.path.exists(examples_dir):
693
+ return scenes
694
+
695
+ # Iterate over example scene folders
696
+ for scene_folder in sorted(os.listdir(examples_dir)):
697
+ scene_path = os.path.join(examples_dir, scene_folder)
698
+ if not os.path.isdir(scene_path):
699
+ continue
700
+ # Load all image files
701
+ img_exts = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
702
+ image_files = []
703
+ for ext in img_exts:
704
+ image_files.extend(glob.glob(os.path.join(scene_path, ext)))
705
+ image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
706
+ # Skip empty folders
707
+ if not image_files:
708
+ continue
709
+ # Sort images and get thumbnail
710
+ image_files = sorted(image_files)
711
+ scenes.append(
712
+ {
713
+ "name": scene_folder,
714
+ "path": scene_path,
715
+ "thumbnail": image_files[0],
716
+ "num_images": len(image_files),
717
+ "image_files": image_files,
718
+ }
719
+ )
720
+ return scenes
721
+
722
+
723
+ def example_pipeline(
724
+ scene,
725
+ conf_thres=5.0,
726
+ show_cam=True,
727
+ show_index=True,
728
+ ceiling_remove=25,
729
+ ):
730
+ """
731
+ Pipeline for loading example scenes and running reconstruction
732
+ Args:
733
+ scene (dict): Example scene info from get_scene_info
734
+ Returns:
735
+ tuple: Gradio component update values
736
+ """
737
+ input_image_paths = scene["image_files"]
738
+ target_dir, image_paths = handle_uploads(input_image_paths)
739
+ frame_filter = "All" # Default to all frames for examples
740
+ # Run reconstruction
741
+ (
742
+ glbfile,
743
+ log_msg,
744
+ dropdown,
745
+ predictions,
746
+ measure_img,
747
+ measure_text,
748
+ measure_selector,
749
+ ) = gradio_demo(
750
+ target_dir, conf_thres, frame_filter, show_cam, show_index, ceiling_remove
751
+ )
752
+ return (
753
+ glbfile,
754
+ log_msg,
755
+ target_dir,
756
+ dropdown,
757
+ image_paths,
758
+ predictions,
759
+ measure_img,
760
+ measure_text,
761
+ measure_selector,
762
+ )
763
+
764
+
765
+ # -------------------------- 3D Visualization Utilities --------------------------
766
+ class SevenSegmentDigit:
767
+ """7-segment display definition for digital watch style 3D point cloud generation"""
768
+ # 7 segments definition: A(top), B(upper right), C(lower right), D(bottom), E(lower left), F(upper left), G(middle)
769
+ SEGMENTS = {
770
+ 'A': np.array([(x, 0.5, 0) for x in np.linspace(-0.4, 0.4, 80) for y in np.linspace(0.45, 0.55, 10)]),
771
+ 'B': np.array([(x, y, 0) for x in np.linspace(0.4, 0.5, 10) for y in np.linspace(0, 0.5, 80)]),
772
+ 'C': np.array([(x, y, 0) for x in np.linspace(0.4, 0.5, 10) for y in np.linspace(-0.5, 0, 80)]),
773
+ 'D': np.array([(x, y, 0) for x in np.linspace(-0.4, 0.4, 80) for y in np.linspace(-0.55, -0.45, 10)]),
774
+ 'E': np.array([(x, y, 0) for x in np.linspace(-0.5, -0.4, 10) for y in np.linspace(-0.5, 0, 80)]),
775
+ 'F': np.array([(x, y, 0) for x in np.linspace(-0.5, -0.4, 10) for y in np.linspace(0, 0.5, 80)]),
776
+ 'G': np.array([(x, y, 0) for x in np.linspace(-0.4, 0.4, 80) for y in np.linspace(-0.05, 0.05, 10)])
777
+ }
778
+
779
+ # Segment mapping for standard 0-9 digits (specify lit segments for each digit)
780
+ DIGIT_SEGMENTS = {
781
+ 0: ['A', 'B', 'C', 'D', 'E', 'F'],
782
+ 1: ['B', 'C'],
783
+ 2: ['A', 'B', 'G', 'E', 'D'],
784
+ 3: ['A', 'B', 'G', 'C', 'D'],
785
+ 4: ['F', 'G', 'B', 'C'],
786
+ 5: ['A', 'F', 'G', 'C', 'D'],
787
+ 6: ['A', 'F', 'G', 'C', 'D', 'E'],
788
+ 7: ['A', 'B', 'C'],
789
+ 8: ['A', 'B', 'C', 'D', 'E', 'F', 'G'],
790
+ 9: ['A', 'B', 'C', 'D', 'F', 'G']
791
+ }
792
+
793
+ @classmethod
794
+ def get_digit_points(cls, digit, scale=0.05):
795
+ """
796
+ Generate 3D point cloud for a single digital watch style digit (0-9)
797
+ Args:
798
+ digit (int): Target digit (0-9 only)
799
+ scale (float): Scale factor for point cloud size
800
+ Returns:
801
+ np.ndarray: N×3 array of 3D points for the digit
802
+ Raises:
803
+ ValueError: If digit is not in 0-9 range
804
+ """
805
+ if not 0 <= digit <= 9:
806
+ raise ValueError(f"Digit must be 0-9, got {digit}")
807
+
808
+ # Combine lit segments for the target digit
809
+ segments = cls.DIGIT_SEGMENTS[digit]
810
+ points = np.vstack([cls.SEGMENTS[seg] for seg in segments])
811
+
812
+ # Scale point cloud and center to origin
813
+ points = points * scale
814
+ points -= points.mean(axis=0)
815
+
816
+ # Remove duplicate points and supplement sparse points (ensure dense distribution)
817
+ points = np.unique(points.round(6), axis=0)
818
+ if len(points) < 200:
819
+ points = trimesh.sample.sample_surface(trimesh.Trimesh(points), 500)[0]
820
+
821
+ return points
822
+
823
+
824
+ def create_number_point_cloud(number, scale=0.05):
825
+ """
826
+ Generate 3D point cloud for multi-digit number (digital watch style), facing +Y axis
827
+ Args:
828
+ number (int): Non-negative target integer (any digit length)
829
+ scale (float): Scale factor for single digit point cloud size
830
+ Returns:
831
+ trimesh.PointCloud: Colored (red) 3D point cloud of the number
832
+ Raises:
833
+ ValueError: If number is negative or non-integer
834
+ """
835
+ if not isinstance(number, int) or number < 0:
836
+ raise ValueError(f"Number must be non-negative integer, got {number}")
837
+
838
+ # Split number into individual digits and handle 0 specially
839
+ digits = [int(d) for d in str(number)] if number != 0 else [0]
840
+ all_points, spacing = [], scale * 1.2
841
+ total_width = (len(digits)-1) * spacing
842
+
843
+ # Arrange digits horizontally and center the whole number
844
+ for idx, d in enumerate(digits):
845
+ digit_points = SevenSegmentDigit.get_digit_points(d, scale)
846
+ digit_points[:, 0] += -total_width/2 + idx * spacing
847
+ all_points.append(digit_points)
848
+
849
+ # Merge all digit points and apply rotation to face +Y axis
850
+ all_points = np.vstack(all_points)
851
+ rotation = np.array([[1, 0, 0],
852
+ [0, 0, -1],
853
+ [0, 1, 0]])
854
+ all_points = np.dot(all_points, rotation.T)
855
+
856
+ # Create red point cloud (classic digital watch color)
857
+ colors = np.full((len(all_points), 3), [255, 0, 0], dtype=np.uint8)
858
+
859
+ return trimesh.PointCloud(all_points, colors)
860
+
861
+
862
+ def predictions_to_glb(
863
+ predictions,
864
+ conf_thres=50.0,
865
+ filter_by_frames="all",
866
+ show_cam=True,
867
+ show_index=True,
868
+ ceiling_remove=25,
869
+ target_dir=None,
870
+ prediction_mode="Predicted Pointmap",
871
+ ) -> trimesh.Scene:
872
+ """
873
+ Convert VGGT model predictions to a 3D trimesh Scene (exportable to GLB)
874
+ Integrates colored point cloud, camera meshes and digital camera indexes
875
+ Args:
876
+ predictions (dict): Model prediction dict with keys:
877
+ - world_points: 3D point coordinates (S, H, W, 3)
878
+ - world_points_conf: Confidence scores (S, H, W)
879
+ - images: Input images (S, H, W, 3)
880
+ - extrinsic: Camera extrinsic matrices (S, 3, 4)
881
+ conf_thres (float): Low-confidence point filter (percentile, 0-100)
882
+ filter_by_frames (str): Frame filter ("all" or specific frame index like "0:")
883
+ show_cam (bool): Whether to add camera mesh visualization to scene
884
+ show_index (bool): Whether to add digital index point cloud above cameras
885
+ ceiling_remove (float): Percentage of top Y-coordinate points to remove as ceiling (0-100, 0=disabled)
886
+ target_dir (str): Directory for intermediate files (images)
887
+ prediction_mode (str): Prediction branch ("Predicted Pointmap" / others for depth-based)
888
+ Returns:
889
+ trimesh.Scene: 3D scene with point cloud, cameras and indexes (if enabled)
890
+ Raises:
891
+ ValueError: If predictions is not a dictionary
892
+ """
893
+ if not isinstance(predictions, dict):
894
+ raise ValueError("predictions must be a dictionary")
895
+
896
+ conf_thres = 10.0 if conf_thres is None else conf_thres
897
+ print("Building GLB scene")
898
+ selected_frame_idx = None
899
+
900
+ # Parse selected frame index from filter string (e.g., "0:" -> 0)
901
+ if filter_by_frames not in ["all", "All"]:
902
+ try:
903
+ selected_frame_idx = int(filter_by_frames.split(":")[0])
904
+ except (ValueError, IndexError):
905
+ pass
906
+
907
+ # Select prediction branch (Pointmap direct / Depthmap derived)
908
+ if "Pointmap" in prediction_mode:
909
+ print("Using Pointmap Branch")
910
+ if "world_points" in predictions:
911
+ pred_world_points = predictions["world_points"]
912
+ pred_world_points_conf = predictions.get("world_points_conf", np.ones_like(pred_world_points[..., 0]))
913
+ else:
914
+ print("Warning: world_points not found, falling back to depth-based world points")
915
+ pred_world_points = predictions["world_points_from_depth"]
916
+ pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0]))
917
+ else:
918
+ print("Using Depthmap and Camera Branch")
919
+ pred_world_points = predictions["world_points_from_depth"]
920
+ pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0]))
921
+
922
+ # Extract core prediction data: images and camera extrinsic matrices
923
+ images = predictions["images"]
924
+ camera_matrices = predictions["extrinsic"]
925
+
926
+ # Filter prediction data to selected single frame if specified
927
+ if selected_frame_idx is not None:
928
+ pred_world_points = pred_world_points[selected_frame_idx][None]
929
+ pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None]
930
+ images = images[selected_frame_idx][None]
931
+ camera_matrices = camera_matrices[selected_frame_idx][None]
932
+
933
+ # Reshape 3D points and convert image colors to 8-bit RGB (match point cloud)
934
+ vertices_3d = pred_world_points.reshape(-1, 3)
935
+ if images.ndim == 4 and images.shape[1] == 3: # Convert NCHW to NHWC format
936
+ colors_rgb = np.transpose(images, (0, 2, 3, 1))
937
+ else: # Direct use if already NHWC format
938
+ colors_rgb = images
939
+ colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)
940
+
941
+ # Filter points by confidence threshold (remove low-confidence points)
942
+ conf = pred_world_points_conf.reshape(-1)
943
+ conf_threshold = 0.0 if conf_thres == 0.0 else np.percentile(conf, conf_thres)
944
+ conf_mask = (conf >= conf_threshold) & (conf > 1e-5)
945
+
946
+ vertices_3d = vertices_3d[conf_mask]
947
+ colors_rgb = colors_rgb[conf_mask]
948
+
949
+ # Create dummy point if no valid points left (avoid scene empty error)
950
+ if vertices_3d is None or np.asarray(vertices_3d).size == 0:
951
+ vertices_3d = np.array([[1, 0, 0]])
952
+ colors_rgb = np.array([[255, 255, 255]])
953
+ scene_scale = 1
954
+ else:
955
+ # Calculate scene scale by 5th/95th percentile bounding box diagonal
956
+ lower_percentile = np.percentile(vertices_3d, 5, axis=0)
957
+ upper_percentile = np.percentile(vertices_3d, 95, axis=0)
958
+ scene_scale = np.linalg.norm(upper_percentile - lower_percentile)
959
+
960
+ # Initialize 3D scene and colormap for camera unique colors
961
+ colormap = matplotlib.colormaps.get_cmap("gist_rainbow")
962
+ scene_3d = trimesh.Scene()
963
+
964
+ # Filter out ceiling points (remove top N% of Y-coordinates by percentile)
965
+ if ceiling_remove > 0 and vertices_3d.size > 1:
966
+ y_coords = vertices_3d[:, 1]
967
+ y_percentile = np.percentile(y_coords, ceiling_remove)
968
+ mask = y_coords > y_percentile
969
+ vertices_3d = vertices_3d[mask]
970
+ colors_rgb = colors_rgb[mask]
971
+
972
+ # Add colored 3D point cloud to the scene
973
+ point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
974
+ scene_3d.add_geometry(point_cloud_data)
975
+
976
+ # Convert 3x4 camera extrinsics to 4x4 homogeneous matrices
977
+ num_cameras = len(camera_matrices)
978
+ extrinsics_matrices = np.zeros((num_cameras, 4, 4))
979
+ extrinsics_matrices[:, :3, :4] = camera_matrices
980
+ extrinsics_matrices[:, 3, 3] = 1
981
+
982
+ # Add camera meshes and digital index point clouds to the scene
983
+ for i in range(num_cameras):
984
+ camera_to_world = extrinsics_matrices[i]
985
+ rgba_color = colormap(i / num_cameras) # Unique color for each camera
986
+ current_color = tuple(int(255 * x) for x in rgba_color[:3])
987
+
988
+ # Add camera mesh to scene
989
+ if show_cam:
990
+ integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale)
991
+
992
+ # Add digital index point cloud above each camera (red, digital watch style)
993
+ if show_index:
994
+ camera_center = camera_to_world[:3, 3]
995
+ y_offset = 0.5 # Y-axis offset for index position (above camera)
996
+ number_position = camera_center + np.array([0, y_offset, 0])
997
+
998
+ # Generate index point cloud and translate to target position
999
+ number_scale = 0.3
1000
+ number_pc = create_number_point_cloud(number=i, scale=number_scale)
1001
+ number_pc.apply_translation(number_position)
1002
+ scene_3d.add_geometry(number_pc)
1003
+
1004
+ # Align the whole scene to the first camera's viewing perspective
1005
+ scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)
1006
+
1007
+ print("GLB Scene built successfully")
1008
+ return scene_3d
1009
+
1010
+
1011
+ def integrate_camera_into_scene(
1012
+ scene: trimesh.Scene, transform: np.ndarray, face_colors: tuple, scene_scale: float
1013
+ ):
1014
+ """
1015
+ Add a 3D cone-shaped camera mesh to the 3D scene with specified transform and color
1016
+ Args:
1017
+ scene (trimesh.Scene): Target 3D scene to add camera mesh
1018
+ transform (np.ndarray): 4x4 camera-to-world transformation matrix
1019
+ face_colors (tuple): RGB color tuple (0-255) for camera mesh faces
1020
+ scene_scale (float): Overall scale of the 3D scene (for camera size adaptation)
1021
+ """
1022
+ # Set camera mesh size based on scene scale
1023
+ cam_width = scene_scale * 0.02
1024
+ cam_height = scene_scale * 0.02
1025
+
1026
+ # 45° Z-axis rotation for camera cone shape and backward translation
1027
+ rot_45_degree = np.eye(4)
1028
+ rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
1029
+ rot_45_degree[2, 3] = -cam_height
1030
+
1031
+ # Combine OpenGL conversion, rotation and camera transform matrices
1032
+ opengl_transform = get_opengl_conversion_matrix()
1033
+ complete_transform = transform @ opengl_transform @ rot_45_degree
1034
+ camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)
1035
+
1036
+ # Slight Z-axis rotation for camera mesh detail enhancement
1037
+ slight_rotation = np.eye(4)
1038
+ slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()
1039
+
1040
+ # Combine original, scaled and rotated cone vertices for dense camera mesh
1041
+ vertices_combined = np.concatenate(
1042
+ [
1043
+ camera_cone_shape.vertices,
1044
+ 0.95 * camera_cone_shape.vertices,
1045
+ transform_points(slight_rotation, camera_cone_shape.vertices),
1046
+ ]
1047
+ )
1048
+ vertices_transformed = transform_points(complete_transform, vertices_combined)
1049
+
1050
+ # Compute camera mesh faces from cone shape
1051
+ mesh_faces = compute_camera_faces(camera_cone_shape)
1052
+
1053
+ # Create camera mesh with specified color and add to scene
1054
+ camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces)
1055
+ camera_mesh.visual.face_colors[:, :3] = face_colors
1056
+ scene.add_geometry(camera_mesh)
1057
+
1058
+
1059
+ def apply_scene_alignment(
1060
+ scene_3d: trimesh.Scene, extrinsics_matrices: np.ndarray
1061
+ ) -> trimesh.Scene:
1062
+ """
1063
+ Align the 3D scene to the first camera's viewing perspective with OpenGL conversion
1064
+ Args:
1065
+ scene_3d (trimesh.Scene): Unaligned 3D scene
1066
+ extrinsics_matrices (np.ndarray): N×4×4 camera extrinsic matrices
1067
+ Returns:
1068
+ trimesh.Scene: Aligned 3D scene
1069
+ """
1070
+ # Get OpenGL coordinate conversion matrix and 180° Y-axis rotation for alignment
1071
+ opengl_conversion_matrix = get_opengl_conversion_matrix()
1072
+ align_rotation = np.eye(4)
1073
+ align_rotation[:3, :3] = Rotation.from_euler("y", 180, degrees=True).as_matrix()
1074
+
1075
+ # Combine transformation matrices and apply to the whole scene
1076
+ initial_transformation = np.linalg.inv(extrinsics_matrices[0]) @ opengl_conversion_matrix @ align_rotation
1077
+ scene_3d.apply_transform(initial_transformation)
1078
+ return scene_3d
1079
+
1080
+
1081
+ def get_opengl_conversion_matrix() -> np.ndarray:
1082
+ """
1083
+ Create 4x4 OpenGL coordinate system conversion matrix (flip Y and Z axes)
1084
+ Returns:
1085
+ np.ndarray: 4x4 identity-based conversion matrix
1086
+ """
1087
+ matrix = np.identity(4)
1088
+ matrix[1, 1] = -1 # Flip Y axis
1089
+ matrix[2, 2] = -1 # Flip Z axis
1090
+ return matrix
1091
+
1092
+
1093
+ def transform_points(
1094
+ transformation: np.ndarray, points: np.ndarray, dim: int = None
1095
+ ) -> np.ndarray:
1096
+ """
1097
+ Apply 4x4 homogeneous transformation matrix to a set of 3D points
1098
+ Args:
1099
+ transformation (np.ndarray): 4x4 transformation matrix
1100
+ points (np.ndarray): N×3 array of 3D points to transform
1101
+ dim (int, optional): Target dimension of output points (default: 3)
1102
+ Returns:
1103
+ np.ndarray: N×dim array of transformed points (same shape as input except last dim)
1104
+ """
1105
+ points = np.asarray(points)
1106
+ initial_shape = points.shape[:-1]
1107
+ dim = dim or points.shape[-1]
1108
+
1109
+ # Transpose matrix and apply affine transformation to points
1110
+ transformation = transformation.swapaxes(-1, -2)
1111
+ points = points @ transformation[..., :-1, :] + transformation[..., -1:, :]
1112
+
1113
+ # Reshape transformed points to original shape (excluding last dimension)
1114
+ result = points[..., :dim].reshape(*initial_shape, dim)
1115
+ return result
1116
+
1117
+
1118
+ def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
1119
+ """
1120
+ Compute face indices for camera mesh from original cone shape faces (enhance detail)
1121
+ Args:
1122
+ cone_shape (trimesh.Trimesh): Original cone mesh for camera base shape
1123
+ Returns:
1124
+ np.ndarray: M×3 array of face indices for the camera mesh
1125
+ """
1126
+ faces_list = []
1127
+ num_vertices_cone = len(cone_shape.vertices)
1128
+
1129
+ # Generate enhanced faces from cone faces (skip origin vertex 0)
1130
+ for face in cone_shape.faces:
1131
+ if 0 in face:
1132
+ continue
1133
+ v1, v2, v3 = face
1134
+ v1_offset, v2_offset, v3_offset = face + num_vertices_cone
1135
+ v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone
1136
+
1137
+ # Add multiple face variations for dense camera mesh
1138
+ faces_list.extend(
1139
+ [
1140
+ (v1, v2, v2_offset),
1141
+ (v1, v1_offset, v3),
1142
+ (v3_offset, v2, v3),
1143
+ (v1, v2, v2_offset_2),
1144
+ (v1, v1_offset_2, v3),
1145
+ (v3_offset_2, v2, v3),
1146
+ ]
1147
+ )
1148
+
1149
+ # Add reversed faces for double-sided rendering
1150
+ faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
1151
+ return np.array(faces_list)
1152
+
1153
+
1154
+ # -------------------------- Gradio UI Construction --------------------------
1155
+ if __name__ == "__main__":
1156
+ # Gradio theme configuration
1157
+ theme = gr.themes.Ocean()
1158
+ theme.set(
1159
+ checkbox_label_background_fill_selected="*button_primary_background_fill",
1160
+ checkbox_label_text_color_selected="*button_primary_text_color",
1161
+ )
1162
+
1163
+ with gr.Blocks(
1164
+ theme=theme,
1165
+ title="Argus - 3D Reconstruction",
1166
+ css="""
1167
+ .custom-log * {
1168
+ font-style: italic;
1169
+ font-size: 20px !important;
1170
+ background-image: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
1171
+ -webkit-background-clip: text;
1172
+ background-clip: text;
1173
+ font-weight: 600 !important;
1174
+ color: transparent !important;
1175
+ text-align: center !important;
1176
+ }
1177
+ .example-log * {
1178
+ font-size: 15px !important;
1179
+ background-image: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
1180
+ -webkit-background-clip: text;
1181
+ background-clip: text;
1182
+ color: transparent !important;
1183
+ font-weight: 500 !important;
1184
+ }
1185
+ .header-banner {
1186
+ background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
1187
+ border-radius: 16px;
1188
+ padding: 32px 24px 24px;
1189
+ margin-bottom: 16px;
1190
+ border: 1px solid #e2e8f0;
1191
+ text-align: center;
1192
+ }
1193
+ .header-banner h1 {
1194
+ font-size: 28px;
1195
+ font-weight: 700;
1196
+ color: #1e293b;
1197
+ margin: 12px 0 8px;
1198
+ }
1199
+ .header-banner .links {
1200
+ margin-top: 12px;
1201
+ font-size: 15px;
1202
+ }
1203
+ .header-banner .links a {
1204
+ margin: 0 10px;
1205
+ color: #4f46e5;
1206
+ text-decoration: none;
1207
+ font-weight: 500;
1208
+ }
1209
+ .header-banner .links a:hover {
1210
+ text-decoration: underline;
1211
+ }
1212
+ .instructions {
1213
+ font-size: 14px;
1214
+ color: #475569;
1215
+ line-height: 1.7;
1216
+ padding: 12px 20px;
1217
+ background: #f8fafc;
1218
+ border-radius: 10px;
1219
+ border: 1px solid #e2e8f0;
1220
+ }
1221
+ .instructions ol {
1222
+ padding-left: 20px;
1223
+ margin: 8px 0;
1224
+ }
1225
+ .instructions li {
1226
+ margin-bottom: 4px;
1227
+ }
1228
+ .param-group {
1229
+ padding: 8px 0;
1230
+ }
1231
+ footer {visibility: hidden;}
1232
+ """,
1233
+ ) as demo:
1234
+ # Hidden state components for data passing
1235
+ is_example = gr.Textbox(label="is_example", visible=False, value="None")
1236
+ processed_data_state = gr.State(value=None)
1237
+ measure_points_state = gr.State(value=[])
1238
+ target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
1239
+
1240
+ # Load and display logo (base64 encoded)
1241
+ root_dir = Path(__file__).parent
1242
+ logo_path = root_dir / "assets" / "argus_logo.png"
1243
+ if logo_path.exists():
1244
+ with open(logo_path, "rb") as f:
1245
+ logo_base64 = base64.b64encode(f.read()).decode()
1246
+ logo_src = f"data:image/png;base64,{logo_base64}"
1247
+ else:
1248
+ logo_src = "" # Fallback if logo not found
1249
+
1250
+ # UI Header and Instructions
1251
+ gr.HTML(
1252
+ f"""
1253
+ <div class="header-banner">
1254
+ <div style="display: flex; justify-content: center;">
1255
+ <img src="{logo_src}" alt="Argus Logo" style="height: 72px; border-radius: 8px;">
1256
+ </div>
1257
+ <h1>Argus: Metric Panoramic 3D Reconstruction for Indoor Scenes</h1>
1258
+ <div class="links">
1259
+ <a href="https://github.com/realsee-developer/Argus" target="_blank">🌟 GitHub</a>
1260
+ <a href="https://argus-paper.realsee.ai" target="_blank">🚀 Project Page</a>
1261
+ <a href="https://arxiv.org/abs/2606.30047" target="_blank">📄 Paper</a>
1262
+ </div>
1263
+ </div>
1264
+ <div class="instructions">
1265
+ <ol>
1266
+ <li><strong>Upload</strong> a set of ERP panoramic images on the left.</li>
1267
+ <li><strong>Click "Reconstruct"</strong> to run the 3D reconstruction pipeline.</li>
1268
+ <li><strong>Explore</strong> the 3D model — rotate, pan, zoom, and download the GLB.</li>
1269
+ <li><strong>Measure</strong> — switch to the Metric tab and click two points to measure real-world distance.</li>
1270
+ </ol>
1271
+ </div>
1272
+ """
1273
+ )
1274
+
1275
+ # Main UI Layout (2 columns: upload/gallery | 3D model/measurement)
1276
+ with gr.Row(equal_height=False):
1277
+ with gr.Column(scale=2, min_width=280):
1278
+ input_images = gr.File(
1279
+ file_count="multiple", label="📁 Upload Panoramic Images", interactive=True
1280
+ )
1281
+ image_gallery = gr.Gallery(
1282
+ label="Preview",
1283
+ columns=3,
1284
+ height="280px",
1285
+ object_fit="contain",
1286
+ preview=True,
1287
+ )
1288
+
1289
+ with gr.Column(scale=5):
1290
+ # Log output
1291
+ log_output = gr.Markdown(
1292
+ "Upload panoramic images (ERP), then click Reconstruct.",
1293
+ elem_classes=["custom-log"],
1294
+ )
1295
+ # Tabbed interface: 3D Model + Metric Measure
1296
+ with gr.Tabs():
1297
+ with gr.Tab("🏠 3D Model"):
1298
+ reconstruction_output = gr.Model3D(
1299
+ height=540, zoom_speed=0.5, pan_speed=0.5
1300
+ )
1301
+ with gr.Tab("📏 Metric Measure"):
1302
+ gr.Markdown(
1303
+ "Click two points on the panorama to measure the real-world distance between them."
1304
+ )
1305
+ with gr.Row():
1306
+ prev_measure_btn = gr.Button(
1307
+ "◀ Prev", size="sm", scale=1
1308
+ )
1309
+ measure_view_selector = gr.Dropdown(
1310
+ choices=["View 1"],
1311
+ value="View 1",
1312
+ label="Select View",
1313
+ scale=3,
1314
+ interactive=True,
1315
+ allow_custom_value=True,
1316
+ )
1317
+ next_measure_btn = gr.Button("Next ▶", size="sm", scale=1)
1318
+ measure_image = gr.Image(
1319
+ type="numpy",
1320
+ show_label=False,
1321
+ format="webp",
1322
+ interactive=False,
1323
+ sources=[],
1324
+ )
1325
+ measure_text = gr.Markdown("")
1326
+
1327
+ # Action buttons
1328
+ with gr.Row():
1329
+ submit_btn = gr.Button("🔨 Reconstruct", scale=2, variant="primary")
1330
+ clear_btn = gr.ClearButton(
1331
+ [
1332
+ input_images,
1333
+ reconstruction_output,
1334
+ log_output,
1335
+ target_dir_output,
1336
+ image_gallery,
1337
+ ],
1338
+ value="🗑️ Clear",
1339
+ scale=1,
1340
+ )
1341
+
1342
+ # Reconstruction parameters
1343
+ gr.Markdown("**Visualization Settings**")
1344
+ with gr.Row():
1345
+ conf_thres = gr.Slider(
1346
+ 0, 100, 5, 1, label="Confidence Threshold (%)"
1347
+ )
1348
+ ceiling_remove = gr.Slider(
1349
+ 0, 100, 25, 1, label="Ceiling Remove (%)"
1350
+ )
1351
+ with gr.Row():
1352
+ frame_filter = gr.Dropdown(
1353
+ ["All"], "All", label="Show Points from Frame", scale=2
1354
+ )
1355
+ show_cam = gr.Checkbox(True, label="Show Camera")
1356
+ show_index = gr.Checkbox(True, label="Show Index")
1357
+
1358
+ # Example Scenes Section
1359
+ gr.Markdown("---")
1360
+ gr.Markdown("### 🖼️ Example Scenes")
1361
+ gr.Markdown("Click any thumbnail to load and reconstruct.", elem_classes=["example-log"])
1362
+ example_scenes = get_scene_info(args.examples_dir)
1363
+ # Create 4-column example thumbnail grid
1364
+ if example_scenes:
1365
+ for i in range(0, len(example_scenes), 4):
1366
+ with gr.Row():
1367
+ for j in range(4):
1368
+ idx = i + j
1369
+ if idx < len(example_scenes):
1370
+ scene = example_scenes[idx]
1371
+ with gr.Column(scale=1):
1372
+ scene_state = gr.State(value=scene)
1373
+ scene_img = gr.Image(
1374
+ value=scene["thumbnail"],
1375
+ height=150,
1376
+ interactive=False,
1377
+ show_label=False,
1378
+ sources=[],
1379
+ )
1380
+ gr.Markdown(
1381
+ f"**{scene['name']}** \n {scene['num_images']} images"
1382
+ )
1383
+ # Bind thumbnail click to example pipeline
1384
+ scene_img.select(
1385
+ example_pipeline,
1386
+ [scene_state],
1387
+ [
1388
+ reconstruction_output,
1389
+ log_output,
1390
+ target_dir_output,
1391
+ frame_filter,
1392
+ image_gallery,
1393
+ processed_data_state,
1394
+ measure_image,
1395
+ measure_text,
1396
+ measure_view_selector,
1397
+ ],
1398
+ )
1399
+ else:
1400
+ with gr.Column(scale=1):
1401
+ pass # Empty column for grid alignment
1402
+
1403
+ # -------------------------- Gradio Event Bindings --------------------------
1404
+ # Reconstruct button logic
1405
+ submit_btn.click(clear_fields, [], [reconstruction_output]).then(
1406
+ update_log, [], [log_output]
1407
+ ).then(
1408
+ gradio_demo,
1409
+ [
1410
+ target_dir_output,
1411
+ conf_thres,
1412
+ frame_filter,
1413
+ show_cam,
1414
+ show_index,
1415
+ ceiling_remove,
1416
+ ],
1417
+ [
1418
+ reconstruction_output,
1419
+ log_output,
1420
+ frame_filter,
1421
+ processed_data_state,
1422
+ measure_image,
1423
+ measure_text,
1424
+ measure_view_selector,
1425
+ ],
1426
+ ).then(
1427
+ lambda: "False", [], [is_example]
1428
+ )
1429
+
1430
+ # Real-time parameter update for 3D visualization
1431
+ for param in [conf_thres, frame_filter, show_cam, show_index, ceiling_remove]:
1432
+ param.change(
1433
+ update_visualization,
1434
+ [
1435
+ target_dir_output,
1436
+ conf_thres,
1437
+ frame_filter,
1438
+ show_cam,
1439
+ show_index,
1440
+ ceiling_remove,
1441
+ is_example,
1442
+ ],
1443
+ [reconstruction_output, log_output],
1444
+ )
1445
+
1446
+ # Auto-update gallery on file upload
1447
+ input_images.change(
1448
+ update_gallery_on_upload,
1449
+ [input_images],
1450
+ [reconstruction_output, target_dir_output, image_gallery, log_output],
1451
+ )
1452
+
1453
+ # Metric measure event bindings
1454
+ measure_image.select(
1455
+ measure,
1456
+ [processed_data_state, measure_points_state, measure_view_selector],
1457
+ [measure_image, measure_points_state, measure_text],
1458
+ )
1459
+ # Measure view navigation
1460
+ prev_measure_btn.click(
1461
+ lambda d, s: navigate_measure_view(d, s, -1),
1462
+ [processed_data_state, measure_view_selector],
1463
+ [measure_view_selector, measure_image, measure_points_state],
1464
+ )
1465
+ next_measure_btn.click(
1466
+ lambda d, s: navigate_measure_view(d, s, 1),
1467
+ [processed_data_state, measure_view_selector],
1468
+ [measure_view_selector, measure_image, measure_points_state],
1469
+ )
1470
+ # Update measure view when selector changes
1471
+ measure_view_selector.change(
1472
+ lambda d, s: (
1473
+ update_measure_view(d, int(s.split()[1]) - 1) if s else (None, [])
1474
+ ),
1475
+ [processed_data_state, measure_view_selector],
1476
+ [measure_image, measure_points_state],
1477
+ )
1478
+
1479
+ # Footer acknowledgement
1480
+ gr.HTML(
1481
+ """
1482
+ <hr style="margin-top: 40px; margin-bottom: 20px; border-color: #e2e8f0;">
1483
+ <div style="text-align: center; font-size: 13px; color: #94a3b8; margin-bottom: 20px;">
1484
+ <p style="margin-bottom: 8px; font-weight: 500; color: #64748b;">Acknowledgements</p>
1485
+ <p>Built upon
1486
+ <a href="https://github.com/facebookresearch/vggt" style="color: #6366f1;">VGGT</a> &
1487
+ <a href="https://github.com/facebookresearch/map-anything" style="color: #6366f1;">Map-Anything</a>
1488
+ </p>
1489
+ </div>
1490
+ """
1491
+ )
1492
+
1493
+ # Launch Gradio demo
1494
+ demo.queue(max_size=20).launch(
1495
+ show_error=True,
1496
+ share=args.share,
1497
+ server_name=args.server_name,
1498
+ server_port=args.port,
1499
+ )
argus/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright 2026 Realsee. All rights reserved.
2
+ # Licensed under the Apache License, Version 2.0.
argus/heads/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright 2026 Realsee. All rights reserved.
2
+ # Licensed under the Apache License, Version 2.0.
argus/heads/camera_head.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from argus.layers import Mlp
6
+ from argus.layers.block import Block
7
+ from argus.heads.head_act import activate_pose
8
+
9
+
10
+ class CameraHead(nn.Module):
11
+ """
12
+ CameraHead predicts camera parameters from token representations using iterative refinement.
13
+
14
+ It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ dim_in: int = 2048,
20
+ trunk_depth: int = 4,
21
+ num_heads: int = 16,
22
+ mlp_ratio: int = 4,
23
+ init_values: float = 0.01,
24
+ trans_act: str = "linear",
25
+ quat_act: str = "linear",
26
+ ):
27
+ super().__init__()
28
+
29
+ self.target_dim = 9
30
+ self.trans_act = trans_act
31
+ self.quat_act = quat_act
32
+ self.trunk_depth = trunk_depth
33
+
34
+ # Build the trunk using a sequence of transformer blocks.
35
+ self.trunk = nn.Sequential(
36
+ *[
37
+ Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
38
+ for _ in range(trunk_depth)
39
+ ]
40
+ )
41
+
42
+ # Normalizations for camera token and trunk output.
43
+ self.token_norm = nn.LayerNorm(dim_in)
44
+ self.trunk_norm = nn.LayerNorm(dim_in)
45
+
46
+ # Learnable empty camera pose token.
47
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
48
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
49
+
50
+ # Module for producing modulation parameters: shift, scale, and a gate.
51
+ self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
52
+
53
+ # Adaptive layer normalization without affine parameters.
54
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
55
+ self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
56
+
57
+ # conf branch for T and R
58
+ self.conf_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=2, drop=0)
59
+
60
+ def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
61
+ """
62
+ Forward pass to predict camera parameters.
63
+
64
+ Args:
65
+ aggregated_tokens_list (list): List of token tensors from the network;
66
+ the last tensor is used for prediction.
67
+ num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
68
+
69
+ Returns:
70
+ list: A list of predicted camera encodings (post-activation) from each iteration.
71
+ """
72
+ # Use tokens from the last block for camera prediction.
73
+ tokens = aggregated_tokens_list[-1]
74
+
75
+ # Extract the camera tokens
76
+ pose_tokens = tokens[:, :, 0]
77
+ pose_tokens = self.token_norm(pose_tokens)
78
+
79
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
80
+ return pred_pose_enc_list
81
+
82
+ def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
83
+ """
84
+ Iteratively refine camera pose predictions.
85
+
86
+ Args:
87
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
88
+ num_iterations (int): Number of refinement iterations.
89
+
90
+ Returns:
91
+ list: List of activated camera encodings from each iteration.
92
+ """
93
+ B, S, C = pose_tokens.shape
94
+ pred_pose_enc = None
95
+ pred_pose_enc_conf = None
96
+ pred_pose_enc_list = []
97
+
98
+ for _ in range(num_iterations):
99
+ # Use a learned empty pose for the first iteration.
100
+ if pred_pose_enc is None:
101
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
102
+ else:
103
+ # Detach the previous prediction to avoid backprop through time.
104
+ pred_pose_enc = pred_pose_enc.detach()
105
+ module_input = self.embed_pose(pred_pose_enc)
106
+
107
+ # Generate modulation parameters and split them into shift, scale, and gate components.
108
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
109
+
110
+ # Adaptive layer normalization and modulation.
111
+ pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
112
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
113
+
114
+ pose_tokens_modulated = self.trunk(pose_tokens_modulated)
115
+ # Compute the delta update for the pose encoding.
116
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
117
+ pred_pose_enc_conf_delta = self.conf_branch(self.trunk_norm(pose_tokens_modulated))
118
+
119
+ if pred_pose_enc is None:
120
+ pred_pose_enc = pred_pose_enc_delta
121
+ pred_pose_enc_conf = pred_pose_enc_conf_delta
122
+ else:
123
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
124
+ pred_pose_enc_conf = pred_pose_enc_conf + pred_pose_enc_conf_delta
125
+
126
+ # Apply final activation functions for translation, quaternion
127
+ activated_pose = activate_pose(
128
+ pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act
129
+ )
130
+ activated_conf = 1 + pred_pose_enc_conf.exp()
131
+ activated_pose = torch.cat([activated_pose, activated_conf], dim=-1)
132
+ pred_pose_enc_list.append(activated_pose)
133
+
134
+ return pred_pose_enc_list
135
+
136
+
137
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
138
+ """
139
+ Modulate the input tensor using scaling and shifting parameters.
140
+ """
141
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
142
+ return x * (1 + scale) + shift
argus/heads/dpt_head.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Dict, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from .head_act import activate_head
8
+ from .utils import create_uv_grid, position_grid_to_embed
9
+
10
+
11
+ class DPTHead(nn.Module):
12
+ """
13
+ DPT Head for dense prediction tasks.
14
+
15
+ This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
16
+ (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
17
+ backbone and produces dense predictions by fusing multi-scale features.
18
+
19
+ Args:
20
+ dim_in (int): Input dimension (channels).
21
+ patch_size (int, optional): Patch size. Default is 14.
22
+ output_dim (int, optional): Number of output channels. Default is 4.
23
+ activation (str, optional): Activation type. Default is "inv_log".
24
+ conf_activation (str, optional): Confidence activation type. Default is "expp1".
25
+ features (int, optional): Feature channels for intermediate representations. Default is 256.
26
+ out_channels (List[int], optional): Output channels for each intermediate layer.
27
+ intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
28
+ pos_embed (bool, optional): Whether to use positional embedding. Default is True.
29
+ feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
30
+ down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ dim_in: int,
36
+ patch_size: int = 14,
37
+ output_dim: int = 4,
38
+ activation: str = "inv_log",
39
+ conf_activation: str = "expp1",
40
+ features: int = 256,
41
+ out_channels: List[int] = [256, 512, 1024, 1024],
42
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
43
+ pos_embed: bool = True,
44
+ feature_only: bool = False,
45
+ down_ratio: int = 1,
46
+ ) -> None:
47
+ super(DPTHead, self).__init__()
48
+ self.patch_size = patch_size
49
+ self.activation = activation
50
+ self.conf_activation = conf_activation
51
+ self.pos_embed = pos_embed
52
+ self.feature_only = feature_only
53
+ self.down_ratio = down_ratio
54
+ self.intermediate_layer_idx = intermediate_layer_idx
55
+
56
+ self.norm = nn.LayerNorm(dim_in)
57
+
58
+ # Projection layers for each output channel from tokens.
59
+ self.projects = nn.ModuleList(
60
+ [nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
61
+ )
62
+
63
+ # Resize layers for upsampling feature maps.
64
+ self.resize_layers = nn.ModuleList(
65
+ [
66
+ nn.ConvTranspose2d(
67
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
68
+ ),
69
+ nn.ConvTranspose2d(
70
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
71
+ ),
72
+ nn.Identity(),
73
+ nn.Conv2d(
74
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
75
+ ),
76
+ ]
77
+ )
78
+
79
+ self.scratch = _make_scratch(out_channels, features, expand=False)
80
+
81
+ # Attach additional modules to scratch.
82
+ self.scratch.stem_transpose = None
83
+ self.scratch.refinenet1 = _make_fusion_block(features)
84
+ self.scratch.refinenet2 = _make_fusion_block(features)
85
+ self.scratch.refinenet3 = _make_fusion_block(features)
86
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
87
+
88
+ head_features_1 = features
89
+ head_features_2 = 32
90
+
91
+ if feature_only:
92
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
93
+ else:
94
+ self.scratch.output_conv1 = nn.Conv2d(
95
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
96
+ )
97
+ conv2_in_channels = head_features_1 // 2
98
+
99
+ self.scratch.output_conv2 = nn.Sequential(
100
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
101
+ nn.ReLU(inplace=True),
102
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
103
+ )
104
+
105
+ def forward(
106
+ self,
107
+ aggregated_tokens_list: List[torch.Tensor],
108
+ images: torch.Tensor,
109
+ patch_start_idx: int,
110
+ frames_chunk_size: int = 8,
111
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
112
+ """
113
+ Forward pass through the DPT head, supports processing by chunking frames.
114
+ Args:
115
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
116
+ images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
117
+ patch_start_idx (int): Starting index for patch tokens in the token sequence.
118
+ Used to separate patch tokens from other tokens (e.g., camera or register tokens).
119
+ frames_chunk_size (int, optional): Number of frames to process in each chunk.
120
+ If None or larger than S, all frames are processed at once. Default: 8.
121
+
122
+ Returns:
123
+ Tensor or Tuple[Tensor, Tensor]:
124
+ - If feature_only=True: Feature maps with shape [B, S, C, H, W]
125
+ - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
126
+ """
127
+ B, S, _, H, W = images.shape
128
+
129
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
130
+ if frames_chunk_size is None or frames_chunk_size >= S:
131
+ return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
132
+
133
+ # Otherwise, process frames in chunks to manage memory usage
134
+ assert frames_chunk_size > 0
135
+
136
+ # Process frames in batches
137
+ all_preds = []
138
+ all_conf = []
139
+
140
+ for frames_start_idx in range(0, S, frames_chunk_size):
141
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
142
+
143
+ # Process batch of frames
144
+ if self.feature_only:
145
+ chunk_output = self._forward_impl(
146
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
147
+ )
148
+ all_preds.append(chunk_output)
149
+ else:
150
+ chunk_preds, chunk_conf = self._forward_impl(
151
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
152
+ )
153
+ all_preds.append(chunk_preds)
154
+ all_conf.append(chunk_conf)
155
+
156
+ # Concatenate results along the sequence dimension
157
+ if self.feature_only:
158
+ return torch.cat(all_preds, dim=1)
159
+ else:
160
+ return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
161
+
162
+ def _forward_impl(
163
+ self,
164
+ aggregated_tokens_list: List[torch.Tensor],
165
+ images: torch.Tensor,
166
+ patch_start_idx: int,
167
+ frames_start_idx: int = None,
168
+ frames_end_idx: int = None,
169
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
170
+ """
171
+ Implementation of the forward pass through the DPT head.
172
+
173
+ This method processes a specific chunk of frames from the sequence.
174
+
175
+ Args:
176
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
177
+ images (Tensor): Input images with shape [B, S, 3, H, W].
178
+ patch_start_idx (int): Starting index for patch tokens.
179
+ frames_start_idx (int, optional): Starting index for frames to process.
180
+ frames_end_idx (int, optional): Ending index for frames to process.
181
+
182
+ Returns:
183
+ Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
184
+ """
185
+ if frames_start_idx is not None and frames_end_idx is not None:
186
+ images = images[:, frames_start_idx:frames_end_idx].contiguous()
187
+
188
+ B, S, _, H, W = images.shape
189
+
190
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
191
+
192
+ out = []
193
+ dpt_idx = 0
194
+
195
+ for layer_idx in self.intermediate_layer_idx:
196
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
197
+
198
+ # Select frames if processing a chunk
199
+ if frames_start_idx is not None and frames_end_idx is not None:
200
+ x = x[:, frames_start_idx:frames_end_idx]
201
+
202
+ x = x.reshape(B * S, -1, x.shape[-1])
203
+
204
+ x = self.norm(x)
205
+
206
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
207
+
208
+ x = self.projects[dpt_idx](x)
209
+ if self.pos_embed:
210
+ x = self._apply_pos_embed(x, W, H)
211
+ x = self.resize_layers[dpt_idx](x)
212
+
213
+ out.append(x)
214
+ dpt_idx += 1
215
+
216
+ # Fuse features from multiple layers.
217
+ out = self.scratch_forward(out)
218
+ # Interpolate fused output to match target image resolution.
219
+ out = custom_interpolate(
220
+ out,
221
+ (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
222
+ mode="bilinear",
223
+ align_corners=True,
224
+ )
225
+
226
+ if self.pos_embed:
227
+ out = self._apply_pos_embed(out, W, H)
228
+
229
+ if self.feature_only:
230
+ return out.view(B, S, *out.shape[1:])
231
+
232
+ out = self.scratch.output_conv2(out)
233
+ preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
234
+
235
+ preds = preds.view(B, S, *preds.shape[1:])
236
+ conf = conf.view(B, S, *conf.shape[1:])
237
+ return preds, conf
238
+
239
+ def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
240
+ """
241
+ Apply positional embedding to tensor x.
242
+ """
243
+ patch_w = x.shape[-1]
244
+ patch_h = x.shape[-2]
245
+ pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
246
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
247
+ pos_embed = pos_embed * ratio
248
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
249
+ return x + pos_embed
250
+
251
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
252
+ """
253
+ Forward pass through the fusion blocks.
254
+
255
+ Args:
256
+ features (List[Tensor]): List of feature maps from different layers.
257
+
258
+ Returns:
259
+ Tensor: Fused feature map.
260
+ """
261
+ layer_1, layer_2, layer_3, layer_4 = features
262
+
263
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
264
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
265
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
266
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
267
+
268
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
269
+ del layer_4_rn, layer_4
270
+
271
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
272
+ del layer_3_rn, layer_3
273
+
274
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
275
+ del layer_2_rn, layer_2
276
+
277
+ out = self.scratch.refinenet1(out, layer_1_rn)
278
+ del layer_1_rn, layer_1
279
+
280
+ out = self.scratch.output_conv1(out)
281
+ return out
282
+
283
+
284
+ ################################################################################
285
+ # Modules
286
+ ################################################################################
287
+
288
+
289
+ def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
290
+ return FeatureFusionBlock(
291
+ features,
292
+ nn.ReLU(inplace=True),
293
+ deconv=False,
294
+ bn=False,
295
+ expand=False,
296
+ align_corners=True,
297
+ size=size,
298
+ has_residual=has_residual,
299
+ groups=groups,
300
+ )
301
+
302
+
303
+ def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
304
+ scratch = nn.Module()
305
+ out_shape1 = out_shape
306
+ out_shape2 = out_shape
307
+ out_shape3 = out_shape
308
+ if len(in_shape) >= 4:
309
+ out_shape4 = out_shape
310
+
311
+ if expand:
312
+ out_shape1 = out_shape
313
+ out_shape2 = out_shape * 2
314
+ out_shape3 = out_shape * 4
315
+ if len(in_shape) >= 4:
316
+ out_shape4 = out_shape * 8
317
+
318
+ scratch.layer1_rn = nn.Conv2d(
319
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
320
+ )
321
+ scratch.layer2_rn = nn.Conv2d(
322
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
323
+ )
324
+ scratch.layer3_rn = nn.Conv2d(
325
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
326
+ )
327
+ if len(in_shape) >= 4:
328
+ scratch.layer4_rn = nn.Conv2d(
329
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
330
+ )
331
+ return scratch
332
+
333
+
334
+ class ResidualConvUnit(nn.Module):
335
+ """Residual convolution module."""
336
+
337
+ def __init__(self, features, activation, bn, groups=1):
338
+ """Init.
339
+
340
+ Args:
341
+ features (int): number of features
342
+ """
343
+ super().__init__()
344
+
345
+ self.bn = bn
346
+ self.groups = groups
347
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
348
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
349
+
350
+ self.norm1 = None
351
+ self.norm2 = None
352
+
353
+ self.activation = activation
354
+ self.skip_add = nn.quantized.FloatFunctional()
355
+
356
+ def forward(self, x):
357
+ """Forward pass.
358
+
359
+ Args:
360
+ x (tensor): input
361
+
362
+ Returns:
363
+ tensor: output
364
+ """
365
+
366
+ out = self.activation(x)
367
+ out = self.conv1(out)
368
+ if self.norm1 is not None:
369
+ out = self.norm1(out)
370
+
371
+ out = self.activation(out)
372
+ out = self.conv2(out)
373
+ if self.norm2 is not None:
374
+ out = self.norm2(out)
375
+
376
+ return self.skip_add.add(out, x)
377
+
378
+
379
+ class FeatureFusionBlock(nn.Module):
380
+ """Feature fusion block."""
381
+
382
+ def __init__(
383
+ self,
384
+ features,
385
+ activation,
386
+ deconv=False,
387
+ bn=False,
388
+ expand=False,
389
+ align_corners=True,
390
+ size=None,
391
+ has_residual=True,
392
+ groups=1,
393
+ ):
394
+ """Init.
395
+
396
+ Args:
397
+ features (int): number of features
398
+ """
399
+ super(FeatureFusionBlock, self).__init__()
400
+
401
+ self.deconv = deconv
402
+ self.align_corners = align_corners
403
+ self.groups = groups
404
+ self.expand = expand
405
+ out_features = features
406
+ if self.expand == True:
407
+ out_features = features // 2
408
+
409
+ self.out_conv = nn.Conv2d(
410
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
411
+ )
412
+
413
+ if has_residual:
414
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
415
+
416
+ self.has_residual = has_residual
417
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
418
+
419
+ self.skip_add = nn.quantized.FloatFunctional()
420
+ self.size = size
421
+
422
+ def forward(self, *xs, size=None):
423
+ """Forward pass.
424
+
425
+ Returns:
426
+ tensor: output
427
+ """
428
+ output = xs[0]
429
+
430
+ if self.has_residual:
431
+ res = self.resConfUnit1(xs[1])
432
+ output = self.skip_add.add(output, res)
433
+
434
+ output = self.resConfUnit2(output)
435
+
436
+ if (size is None) and (self.size is None):
437
+ modifier = {"scale_factor": 2}
438
+ elif size is None:
439
+ modifier = {"size": self.size}
440
+ else:
441
+ modifier = {"size": size}
442
+
443
+ output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
444
+ output = self.out_conv(output)
445
+
446
+ return output
447
+
448
+
449
+ def custom_interpolate(
450
+ x: torch.Tensor,
451
+ size: Tuple[int, int] = None,
452
+ scale_factor: float = None,
453
+ mode: str = "bilinear",
454
+ align_corners: bool = True,
455
+ ) -> torch.Tensor:
456
+ """
457
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
458
+ """
459
+ if size is None:
460
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
461
+
462
+ INT_MAX = 1610612736
463
+
464
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
465
+
466
+ if input_elements > INT_MAX:
467
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
468
+ interpolated_chunks = [
469
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
470
+ ]
471
+ x = torch.cat(interpolated_chunks, dim=0)
472
+ return x.contiguous()
473
+ else:
474
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
argus/heads/head_act.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear"):
6
+ """
7
+ Activate pose parameters with specified activation functions.
8
+
9
+ Args:
10
+ pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, xx]
11
+ trans_act: Activation type for translation component
12
+ quat_act: Activation type for quaternion component
13
+
14
+ Returns:
15
+ Activated pose parameters tensor
16
+ """
17
+ T = pred_pose_enc[..., :3]
18
+ quat = pred_pose_enc[..., 3:7]
19
+
20
+ T = base_pose_act(T, trans_act)
21
+ quat = base_pose_act(quat, quat_act)
22
+
23
+ # Discard the remaining parameters
24
+ pred_pose_enc = torch.cat([T, quat], dim=-1)
25
+
26
+ return pred_pose_enc
27
+
28
+
29
+ def base_pose_act(pose_enc, act_type="linear"):
30
+ """
31
+ Apply basic activation function to pose parameters.
32
+
33
+ Args:
34
+ pose_enc: Tensor containing encoded pose parameters
35
+ act_type: Activation type ("linear", "inv_log", "exp", "relu")
36
+
37
+ Returns:
38
+ Activated pose parameters
39
+ """
40
+ if act_type == "linear":
41
+ return pose_enc
42
+ elif act_type == "inv_log":
43
+ return inverse_log_transform(pose_enc)
44
+ elif act_type == "exp":
45
+ return torch.exp(pose_enc)
46
+ elif act_type == "relu":
47
+ return F.relu(pose_enc)
48
+ elif act_type == "expp1":
49
+ return 1 + pose_enc.exp()
50
+ elif act_type == "expp0":
51
+ return pose_enc.exp()
52
+ elif act_type == "sigmoid":
53
+ return torch.sigmoid(pose_enc)
54
+ else:
55
+ raise ValueError(f"Unknown act_type: {act_type}")
56
+
57
+
58
+ def activate_head(out, activation="norm_exp", conf_activation="expp1"):
59
+ """
60
+ Process network output to extract 3D points and confidence values.
61
+
62
+ Args:
63
+ out: Network output tensor (B, C, H, W)
64
+ activation: Activation type for 3D points
65
+ conf_activation: Activation type for confidence values
66
+
67
+ Returns:
68
+ Tuple of (3D points tensor, confidence tensor)
69
+ """
70
+ # Move channels from last dim to the 4th dimension => (B, H, W, C)
71
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
72
+
73
+ # Split into xyz (first C-1 channels) and confidence (last channel)
74
+ xyz = fmap[:, :, :, :-1]
75
+ conf = fmap[:, :, :, -1]
76
+
77
+ if activation == "norm_exp":
78
+ d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
79
+ xyz_normed = xyz / d
80
+ pts3d = xyz_normed * torch.expm1(d)
81
+ elif activation == "norm":
82
+ pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
83
+ elif activation == "exp":
84
+ pts3d = torch.exp(xyz)
85
+ elif activation == "relu":
86
+ pts3d = F.relu(xyz)
87
+ elif activation == "inv_log":
88
+ pts3d = inverse_log_transform(xyz)
89
+ elif activation == "xy_inv_log":
90
+ xy, z = xyz.split([2, 1], dim=-1)
91
+ z = inverse_log_transform(z)
92
+ pts3d = torch.cat([xy * z, z], dim=-1)
93
+ elif activation == "sigmoid":
94
+ pts3d = torch.sigmoid(xyz)
95
+ elif activation == "linear":
96
+ pts3d = xyz
97
+ else:
98
+ raise ValueError(f"Unknown activation: {activation}")
99
+
100
+ if conf_activation == "expp1":
101
+ conf_out = 1 + conf.exp()
102
+ elif conf_activation == "expp0":
103
+ conf_out = conf.exp()
104
+ elif conf_activation == "sigmoid":
105
+ conf_out = torch.sigmoid(conf)
106
+ else:
107
+ raise ValueError(f"Unknown conf_activation: {conf_activation}")
108
+
109
+ return pts3d, conf_out
110
+
111
+
112
+ def inverse_log_transform(y):
113
+ """
114
+ Apply inverse log transform: sign(y) * (exp(|y|) - 1)
115
+
116
+ Args:
117
+ y: Input tensor
118
+
119
+ Returns:
120
+ Transformed tensor
121
+ """
122
+ return torch.sign(y) * (torch.expm1(torch.abs(y)))
argus/heads/utils.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
5
+ """
6
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
7
+
8
+ Args:
9
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
10
+ embed_dim: Output channel dimension for embeddings
11
+
12
+ Returns:
13
+ Tensor of shape (H, W, embed_dim) with positional embeddings
14
+ """
15
+ H, W, grid_dim = pos_grid.shape
16
+ assert grid_dim == 2
17
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
18
+
19
+ # Process x and y coordinates separately
20
+ emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2]
21
+ emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2]
22
+
23
+ # Combine and reshape
24
+ emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
25
+
26
+ return emb.view(H, W, embed_dim) # [H, W, D]
27
+
28
+
29
+ def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
30
+ """
31
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
32
+
33
+ Args:
34
+ - embed_dim: The embedding dimension.
35
+ - pos: The position to generate the embedding from.
36
+
37
+ Returns:
38
+ - emb: The generated 1D positional embedding.
39
+ """
40
+ assert embed_dim % 2 == 0
41
+ device = pos.device
42
+ omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
43
+ omega /= embed_dim / 2.0
44
+ omega = 1.0 / omega_0**omega # (D/2,)
45
+
46
+ pos = pos.reshape(-1) # (M,)
47
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
48
+
49
+ emb_sin = torch.sin(out) # (M, D/2)
50
+ emb_cos = torch.cos(out) # (M, D/2)
51
+
52
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
53
+ return emb.float()
54
+
55
+
56
+ # Inspired by https://github.com/microsoft/moge
57
+
58
+
59
+ def create_uv_grid(
60
+ width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
61
+ ) -> torch.Tensor:
62
+ """
63
+ Create a normalized UV grid of shape (width, height, 2).
64
+
65
+ The grid spans horizontally and vertically according to an aspect ratio,
66
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
67
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
68
+
69
+ Args:
70
+ width (int): Number of points horizontally.
71
+ height (int): Number of points vertically.
72
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
73
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
74
+ device (torch.device, optional): Device on which the tensor is created.
75
+
76
+ Returns:
77
+ torch.Tensor: A (width, height, 2) tensor of UV coordinates.
78
+ """
79
+ # Derive aspect ratio if not explicitly provided
80
+ if aspect_ratio is None:
81
+ aspect_ratio = float(width) / float(height)
82
+
83
+ # Compute normalized spans for X and Y
84
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
85
+ span_x = aspect_ratio / diag_factor
86
+ span_y = 1.0 / diag_factor
87
+
88
+ # Establish the linspace boundaries
89
+ left_x = -span_x * (width - 1) / width
90
+ right_x = span_x * (width - 1) / width
91
+ top_y = -span_y * (height - 1) / height
92
+ bottom_y = span_y * (height - 1) / height
93
+
94
+ # Generate 1D coordinates
95
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
96
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
97
+
98
+ # Create 2D meshgrid (width x height) and stack into UV
99
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
100
+ uv_grid = torch.stack((uu, vv), dim=-1)
101
+
102
+ return uv_grid
103
+
104
+
105
+
106
+ def reorder_by_reference(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
107
+ """Reorder tensor views to place the selected reference view at the first position (index 0),
108
+ while keeping the remaining views in their original order (excluding the reference view).
109
+
110
+ Args:
111
+ x: Input tensor with shape (B, S, ...) where B = batch size, S = number of views,
112
+ and trailing dimensions can be arbitrary (e.g., N, C for patch tokens).
113
+ b_idx: 1D tensor of shape (B,) containing the index of the reference view for each batch element,
114
+ each value must be in the range [0, S-1].
115
+
116
+ Returns:
117
+ Reordered tensor with the same shape as input, where the reference view is at position 0
118
+ and other views retain their original order (skipping the reference view).
119
+
120
+ Example:
121
+ If B=1, S=5, b_idx=[2], input view order is [0,1,2,3,4],
122
+ output order becomes [2,0,1,3,4].
123
+ """
124
+ # Extract batch size (B) and number of views (S) from input shape
125
+ B, S = x.shape[0], x.shape[1]
126
+
127
+ # No reordering needed if only one view exists
128
+ if S <= 1:
129
+ return x
130
+
131
+ # Generate base index matrix (B, S): each row is [0, 1, ..., S-1] (same across batches)
132
+ idx = torch.arange(S, device=x.device).expand(B, -1)
133
+
134
+ # Create mask to exclude reference view indices (True for non-reference positions)
135
+ mask = idx != b_idx.unsqueeze(1)
136
+
137
+ # Build reorder indices: [reference_idx] + [all non-reference indices in original order]
138
+ # Reshape non-reference indices to (B, S-1) to match batch dimension, then concatenate
139
+ reorder_idx = torch.cat([b_idx.unsqueeze(1), idx[mask].reshape(B, S-1)], dim=1)
140
+
141
+ # Advanced indexing to reorder: batch indices (B,1) paired with reorder indices (B,S)
142
+ return x[torch.arange(B).unsqueeze(1), reorder_idx]
argus/layers/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
argus/layers/attention.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+
18
+ XFORMERS_AVAILABLE = False
19
+
20
+
21
+ class Attention(nn.Module):
22
+ def __init__(
23
+ self,
24
+ dim: int,
25
+ num_heads: int = 8,
26
+ qkv_bias: bool = True,
27
+ proj_bias: bool = True,
28
+ attn_drop: float = 0.0,
29
+ proj_drop: float = 0.0,
30
+ norm_layer: nn.Module = nn.LayerNorm,
31
+ qk_norm: bool = False,
32
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
33
+ rope=None,
34
+ ) -> None:
35
+ super().__init__()
36
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
37
+ self.num_heads = num_heads
38
+ self.head_dim = dim // num_heads
39
+ self.scale = self.head_dim**-0.5
40
+ self.fused_attn = fused_attn
41
+
42
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
43
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
44
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
45
+ self.attn_drop = nn.Dropout(attn_drop)
46
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
47
+ self.proj_drop = nn.Dropout(proj_drop)
48
+ self.rope = rope
49
+
50
+ def forward(self, x: Tensor, pos=None) -> Tensor:
51
+ B, N, C = x.shape
52
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
53
+ q, k, v = qkv.unbind(0)
54
+ q, k = self.q_norm(q), self.k_norm(k)
55
+
56
+ if self.rope is not None:
57
+ q = self.rope(q, pos)
58
+ k = self.rope(k, pos)
59
+
60
+ if self.fused_attn:
61
+ x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0)
62
+ else:
63
+ q = q * self.scale
64
+ attn = q @ k.transpose(-2, -1)
65
+ attn = attn.softmax(dim=-1)
66
+ attn = self.attn_drop(attn)
67
+ x = attn @ v
68
+
69
+ x = x.transpose(1, 2).reshape(B, N, C)
70
+ x = self.proj(x)
71
+ x = self.proj_drop(x)
72
+ return x
73
+
74
+
75
+ class MemEffAttention(Attention):
76
+ def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
77
+ assert pos is None
78
+ if not XFORMERS_AVAILABLE:
79
+ if attn_bias is not None:
80
+ raise AssertionError("xFormers is required for using nested tensors")
81
+ return super().forward(x)
82
+
83
+ B, N, C = x.shape
84
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
85
+
86
+ q, k, v = unbind(qkv, 2)
87
+
88
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
89
+ x = x.reshape([B, N, C])
90
+
91
+ x = self.proj(x)
92
+ x = self.proj_drop(x)
93
+ return x
argus/layers/block.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+
24
+ XFORMERS_AVAILABLE = False
25
+
26
+
27
+ class Block(nn.Module):
28
+ def __init__(
29
+ self,
30
+ dim: int,
31
+ num_heads: int,
32
+ mlp_ratio: float = 4.0,
33
+ qkv_bias: bool = True,
34
+ proj_bias: bool = True,
35
+ ffn_bias: bool = True,
36
+ drop: float = 0.0,
37
+ attn_drop: float = 0.0,
38
+ init_values=None,
39
+ drop_path: float = 0.0,
40
+ act_layer: Callable[..., nn.Module] = nn.GELU,
41
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
42
+ attn_class: Callable[..., nn.Module] = Attention,
43
+ ffn_layer: Callable[..., nn.Module] = Mlp,
44
+ qk_norm: bool = False,
45
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
46
+ rope=None,
47
+ ) -> None:
48
+ super().__init__()
49
+
50
+ self.norm1 = norm_layer(dim)
51
+
52
+ self.attn = attn_class(
53
+ dim,
54
+ num_heads=num_heads,
55
+ qkv_bias=qkv_bias,
56
+ proj_bias=proj_bias,
57
+ attn_drop=attn_drop,
58
+ proj_drop=drop,
59
+ qk_norm=qk_norm,
60
+ fused_attn=fused_attn,
61
+ rope=rope,
62
+ )
63
+
64
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
65
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
66
+
67
+ self.norm2 = norm_layer(dim)
68
+ mlp_hidden_dim = int(dim * mlp_ratio)
69
+ self.mlp = ffn_layer(
70
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias
71
+ )
72
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
73
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
74
+
75
+ self.sample_drop_ratio = drop_path
76
+
77
+ def forward(self, x: Tensor, pos=None) -> Tensor:
78
+ def attn_residual_func(x: Tensor, pos=None) -> Tensor:
79
+ return self.ls1(self.attn(self.norm1(x), pos=pos))
80
+
81
+ def ffn_residual_func(x: Tensor) -> Tensor:
82
+ return self.ls2(self.mlp(self.norm2(x)))
83
+
84
+ if self.training and self.sample_drop_ratio > 0.1:
85
+ # the overhead is compensated only for a drop path rate larger than 0.1
86
+ x = drop_add_residual_stochastic_depth(
87
+ x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio
88
+ )
89
+ x = drop_add_residual_stochastic_depth(
90
+ x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio
91
+ )
92
+ elif self.training and self.sample_drop_ratio > 0.0:
93
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
94
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
95
+ else:
96
+ x = x + attn_residual_func(x, pos=pos)
97
+ x = x + ffn_residual_func(x)
98
+ return x
99
+
100
+
101
+ def drop_add_residual_stochastic_depth(
102
+ x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None
103
+ ) -> Tensor:
104
+ # 1) extract subset using permutation
105
+ b, n, d = x.shape
106
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
107
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
108
+ x_subset = x[brange]
109
+
110
+ # 2) apply residual_func to get residual
111
+ if pos is not None:
112
+ # if necessary, apply rope to the subset
113
+ pos = pos[brange]
114
+ residual = residual_func(x_subset, pos=pos)
115
+ else:
116
+ residual = residual_func(x_subset)
117
+
118
+ x_flat = x.flatten(1)
119
+ residual = residual.flatten(1)
120
+
121
+ residual_scale_factor = b / sample_subset_size
122
+
123
+ # 3) add the residual
124
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
125
+ return x_plus_residual.view_as(x)
126
+
127
+
128
+ def get_branges_scales(x, sample_drop_ratio=0.0):
129
+ b, n, d = x.shape
130
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
131
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
132
+ residual_scale_factor = b / sample_subset_size
133
+ return brange, residual_scale_factor
134
+
135
+
136
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
137
+ if scaling_vector is None:
138
+ x_flat = x.flatten(1)
139
+ residual = residual.flatten(1)
140
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
141
+ else:
142
+ x_plus_residual = scaled_index_add(
143
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
144
+ )
145
+ return x_plus_residual
146
+
147
+
148
+ attn_bias_cache: Dict[Tuple, Any] = {}
149
+
150
+
151
+ def get_attn_bias_and_cat(x_list, branges=None):
152
+ """
153
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
154
+ """
155
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
156
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
157
+ if all_shapes not in attn_bias_cache.keys():
158
+ seqlens = []
159
+ for b, x in zip(batch_sizes, x_list):
160
+ for _ in range(b):
161
+ seqlens.append(x.shape[1])
162
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
163
+ attn_bias._batch_sizes = batch_sizes
164
+ attn_bias_cache[all_shapes] = attn_bias
165
+
166
+ if branges is not None:
167
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
168
+ else:
169
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
170
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
171
+
172
+ return attn_bias_cache[all_shapes], cat_tensors
173
+
174
+
175
+ def drop_add_residual_stochastic_depth_list(
176
+ x_list: List[Tensor],
177
+ residual_func: Callable[[Tensor, Any], Tensor],
178
+ sample_drop_ratio: float = 0.0,
179
+ scaling_vector=None,
180
+ ) -> Tensor:
181
+ # 1) generate random set of indices for dropping samples in the batch
182
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
183
+ branges = [s[0] for s in branges_scales]
184
+ residual_scale_factors = [s[1] for s in branges_scales]
185
+
186
+ # 2) get attention bias and index+concat the tensors
187
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
188
+
189
+ # 3) apply residual_func to get residual, and split the result
190
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
191
+
192
+ outputs = []
193
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
194
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
195
+ return outputs
196
+
197
+
198
+ class NestedTensorBlock(Block):
199
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
200
+ """
201
+ x_list contains a list of tensors to nest together and run
202
+ """
203
+ assert isinstance(self.attn, MemEffAttention)
204
+
205
+ if self.training and self.sample_drop_ratio > 0.0:
206
+
207
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
208
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
209
+
210
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
211
+ return self.mlp(self.norm2(x))
212
+
213
+ x_list = drop_add_residual_stochastic_depth_list(
214
+ x_list,
215
+ residual_func=attn_residual_func,
216
+ sample_drop_ratio=self.sample_drop_ratio,
217
+ scaling_vector=(self.ls1.gamma if isinstance(self.ls1, LayerScale) else None),
218
+ )
219
+ x_list = drop_add_residual_stochastic_depth_list(
220
+ x_list,
221
+ residual_func=ffn_residual_func,
222
+ sample_drop_ratio=self.sample_drop_ratio,
223
+ scaling_vector=(self.ls2.gamma if isinstance(self.ls1, LayerScale) else None),
224
+ )
225
+ return x_list
226
+ else:
227
+
228
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
229
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
230
+
231
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
232
+ return self.ls2(self.mlp(self.norm2(x)))
233
+
234
+ attn_bias, x = get_attn_bias_and_cat(x_list)
235
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
236
+ x = x + ffn_residual_func(x)
237
+ return attn_bias.split(x)
238
+
239
+ def forward(self, x_or_x_list):
240
+ if isinstance(x_or_x_list, Tensor):
241
+ return super().forward(x_or_x_list)
242
+ elif isinstance(x_or_x_list, list):
243
+ if not XFORMERS_AVAILABLE:
244
+ raise AssertionError("xFormers is required for using nested tensors")
245
+ return self.forward_nested(x_or_x_list)
246
+ else:
247
+ raise AssertionError
argus/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15
+ if drop_prob == 0.0 or not training:
16
+ return x
17
+ keep_prob = 1 - drop_prob
18
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20
+ if keep_prob > 0.0:
21
+ random_tensor.div_(keep_prob)
22
+ output = x * random_tensor
23
+ return output
24
+
25
+
26
+ class DropPath(nn.Module):
27
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
28
+
29
+ def __init__(self, drop_prob=None):
30
+ super(DropPath, self).__init__()
31
+ self.drop_prob = drop_prob
32
+
33
+ def forward(self, x):
34
+ return drop_path(x, self.drop_prob, self.training)
argus/layers/layer_scale.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
+
8
+ from typing import Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ from torch import nn
13
+
14
+
15
+ class LayerScale(nn.Module):
16
+ def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None:
17
+ super().__init__()
18
+ self.inplace = inplace
19
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
20
+
21
+ def forward(self, x: Tensor) -> Tensor:
22
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
argus/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
+ class Mlp(nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_features: int,
20
+ hidden_features: Optional[int] = None,
21
+ out_features: Optional[int] = None,
22
+ act_layer: Callable[..., nn.Module] = nn.GELU,
23
+ drop: float = 0.0,
24
+ bias: bool = True,
25
+ ) -> None:
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30
+ self.act = act_layer()
31
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32
+ self.drop = nn.Dropout(drop)
33
+
34
+ def forward(self, x: Tensor) -> Tensor:
35
+ x = self.fc1(x)
36
+ x = self.act(x)
37
+ x = self.drop(x)
38
+ x = self.fc2(x)
39
+ x = self.drop(x)
40
+ return x
argus/layers/patch_embed.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
+ def make_2tuple(x):
17
+ if isinstance(x, tuple):
18
+ assert len(x) == 2
19
+ return x
20
+
21
+ assert isinstance(x, int)
22
+ return (x, x)
23
+
24
+
25
+ class PatchEmbed(nn.Module):
26
+ """
27
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
28
+
29
+ Args:
30
+ img_size: Image size.
31
+ patch_size: Patch token size.
32
+ in_chans: Number of input image channels.
33
+ embed_dim: Number of linear projection output channels.
34
+ norm_layer: Normalization layer.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ img_size: Union[int, Tuple[int, int]] = 224,
40
+ patch_size: Union[int, Tuple[int, int]] = 16,
41
+ in_chans: int = 3,
42
+ embed_dim: int = 768,
43
+ norm_layer: Optional[Callable] = None,
44
+ flatten_embedding: bool = True,
45
+ ) -> None:
46
+ super().__init__()
47
+
48
+ image_HW = make_2tuple(img_size)
49
+ patch_HW = make_2tuple(patch_size)
50
+ patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
51
+
52
+ self.img_size = image_HW
53
+ self.patch_size = patch_HW
54
+ self.patches_resolution = patch_grid_size
55
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
56
+
57
+ self.in_chans = in_chans
58
+ self.embed_dim = embed_dim
59
+
60
+ self.flatten_embedding = flatten_embedding
61
+
62
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
63
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
64
+
65
+ def forward(self, x: Tensor) -> Tensor:
66
+ _, _, H, W = x.shape
67
+ patch_H, patch_W = self.patch_size
68
+
69
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
70
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
71
+
72
+ x = self.proj(x) # B C H W
73
+ H, W = x.size(2), x.size(3)
74
+ x = x.flatten(2).transpose(1, 2) # B HW C
75
+ x = self.norm(x)
76
+ if not self.flatten_embedding:
77
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
78
+ return x
79
+
80
+ def flops(self) -> float:
81
+ Ho, Wo = self.patches_resolution
82
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
83
+ if self.norm is not None:
84
+ flops += Ho * Wo * self.embed_dim
85
+ return flops
argus/layers/rope.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ # Implementation of 2D Rotary Position Embeddings (RoPE).
8
+
9
+ # This module provides a clean implementation of 2D Rotary Position Embeddings,
10
+ # which extends the original RoPE concept to handle 2D spatial positions.
11
+
12
+ # Inspired by:
13
+ # https://github.com/meta-llama/codellama/blob/main/llama/model.py
14
+ # https://github.com/naver-ai/rope-vit
15
+
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from typing import Dict, Tuple
22
+
23
+
24
+ class PositionGetter:
25
+ """Generates and caches 2D spatial positions for patches in a grid.
26
+
27
+ This class efficiently manages the generation of spatial coordinates for patches
28
+ in a 2D grid, caching results to avoid redundant computations.
29
+
30
+ Attributes:
31
+ position_cache: Dictionary storing precomputed position tensors for different
32
+ grid dimensions.
33
+ """
34
+
35
+ def __init__(self):
36
+ """Initializes the position generator with an empty cache."""
37
+ self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
38
+
39
+ def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
40
+ """Generates spatial positions for a batch of patches.
41
+
42
+ Args:
43
+ batch_size: Number of samples in the batch.
44
+ height: Height of the grid in patches.
45
+ width: Width of the grid in patches.
46
+ device: Target device for the position tensor.
47
+
48
+ Returns:
49
+ Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
50
+ for each position in the grid, repeated for each batch item.
51
+ """
52
+ if (height, width) not in self.position_cache:
53
+ y_coords = torch.arange(height, device=device)
54
+ x_coords = torch.arange(width, device=device)
55
+ positions = torch.cartesian_prod(y_coords, x_coords)
56
+ self.position_cache[height, width] = positions
57
+
58
+ cached_positions = self.position_cache[height, width]
59
+ return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
60
+
61
+
62
+ class RotaryPositionEmbedding2D(nn.Module):
63
+ """2D Rotary Position Embedding implementation.
64
+
65
+ This module applies rotary position embeddings to input tokens based on their
66
+ 2D spatial positions. It handles the position-dependent rotation of features
67
+ separately for vertical and horizontal dimensions.
68
+
69
+ Args:
70
+ frequency: Base frequency for the position embeddings. Default: 100.0
71
+ scaling_factor: Scaling factor for frequency computation. Default: 1.0
72
+
73
+ Attributes:
74
+ base_frequency: Base frequency for computing position embeddings.
75
+ scaling_factor: Factor to scale the computed frequencies.
76
+ frequency_cache: Cache for storing precomputed frequency components.
77
+ """
78
+
79
+ def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
80
+ """Initializes the 2D RoPE module."""
81
+ super().__init__()
82
+ self.base_frequency = frequency
83
+ self.scaling_factor = scaling_factor
84
+ self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
85
+
86
+ def _compute_frequency_components(
87
+ self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
88
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
89
+ """Computes frequency components for rotary embeddings.
90
+
91
+ Args:
92
+ dim: Feature dimension (must be even).
93
+ seq_len: Maximum sequence length.
94
+ device: Target device for computations.
95
+ dtype: Data type for the computed tensors.
96
+
97
+ Returns:
98
+ Tuple of (cosine, sine) tensors for frequency components.
99
+ """
100
+ cache_key = (dim, seq_len, device, dtype)
101
+ if cache_key not in self.frequency_cache:
102
+ # Compute frequency bands
103
+ exponents = torch.arange(0, dim, 2, device=device).float() / dim
104
+ inv_freq = 1.0 / (self.base_frequency**exponents)
105
+
106
+ # Generate position-dependent frequencies
107
+ positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
108
+ angles = torch.einsum("i,j->ij", positions, inv_freq)
109
+
110
+ # Compute and cache frequency components
111
+ angles = angles.to(dtype)
112
+ angles = torch.cat((angles, angles), dim=-1)
113
+ cos_components = angles.cos().to(dtype)
114
+ sin_components = angles.sin().to(dtype)
115
+ self.frequency_cache[cache_key] = (cos_components, sin_components)
116
+
117
+ return self.frequency_cache[cache_key]
118
+
119
+ @staticmethod
120
+ def _rotate_features(x: torch.Tensor) -> torch.Tensor:
121
+ """Performs feature rotation by splitting and recombining feature dimensions.
122
+
123
+ Args:
124
+ x: Input tensor to rotate.
125
+
126
+ Returns:
127
+ Rotated feature tensor.
128
+ """
129
+ feature_dim = x.shape[-1]
130
+ x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
131
+ return torch.cat((-x2, x1), dim=-1)
132
+
133
+ def _apply_1d_rope(
134
+ self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
135
+ ) -> torch.Tensor:
136
+ """Applies 1D rotary position embeddings along one dimension.
137
+
138
+ Args:
139
+ tokens: Input token features.
140
+ positions: Position indices.
141
+ cos_comp: Cosine components for rotation.
142
+ sin_comp: Sine components for rotation.
143
+
144
+ Returns:
145
+ Tokens with applied rotary position embeddings.
146
+ """
147
+ # Embed positions with frequency components
148
+ cos = F.embedding(positions, cos_comp)[:, None, :, :]
149
+ sin = F.embedding(positions, sin_comp)[:, None, :, :]
150
+
151
+ # Apply rotation
152
+ return (tokens * cos) + (self._rotate_features(tokens) * sin)
153
+
154
+ def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
155
+ """Applies 2D rotary position embeddings to input tokens.
156
+
157
+ Args:
158
+ tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
159
+ The feature dimension (dim) must be divisible by 4.
160
+ positions: Position tensor of shape (batch_size, n_tokens, 2) containing
161
+ the y and x coordinates for each token.
162
+
163
+ Returns:
164
+ Tensor of same shape as input with applied 2D rotary position embeddings.
165
+
166
+ Raises:
167
+ AssertionError: If input dimensions are invalid or positions are malformed.
168
+ """
169
+ # Validate inputs
170
+ assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
171
+ assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
172
+
173
+ # Compute feature dimension for each spatial direction
174
+ feature_dim = tokens.size(-1) // 2
175
+
176
+ # Get frequency components
177
+ max_position = int(positions.max()) + 1
178
+ cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
179
+
180
+ # Split features for vertical and horizontal processing
181
+ vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
182
+
183
+ # Apply RoPE separately for each dimension
184
+ vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
185
+ horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
186
+
187
+ # Combine processed features
188
+ return torch.cat((vertical_features, horizontal_features), dim=-1)
argus/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from typing import Callable, Optional
8
+ import warnings
9
+
10
+ from torch import Tensor, nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class SwiGLUFFN(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_features: int,
18
+ hidden_features: Optional[int] = None,
19
+ out_features: Optional[int] = None,
20
+ act_layer: Callable[..., nn.Module] = None,
21
+ drop: float = 0.0,
22
+ bias: bool = True,
23
+ ) -> None:
24
+ super().__init__()
25
+ out_features = out_features or in_features
26
+ hidden_features = hidden_features or in_features
27
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
28
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
29
+
30
+ def forward(self, x: Tensor) -> Tensor:
31
+ x12 = self.w12(x)
32
+ x1, x2 = x12.chunk(2, dim=-1)
33
+ hidden = F.silu(x1) * x2
34
+ return self.w3(hidden)
35
+
36
+
37
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
38
+ # try:
39
+ # if XFORMERS_ENABLED:
40
+ # from xformers.ops import SwiGLU
41
+
42
+ # XFORMERS_AVAILABLE = True
43
+ # warnings.warn("xFormers is available (SwiGLU)")
44
+ # else:
45
+ # warnings.warn("xFormers is disabled (SwiGLU)")
46
+ # raise ImportError
47
+ # except ImportError:
48
+ SwiGLU = SwiGLUFFN
49
+ XFORMERS_AVAILABLE = False
50
+
51
+ # warnings.warn("xFormers is not available (SwiGLU)")
52
+
53
+
54
+ class SwiGLUFFNFused(SwiGLU):
55
+ def __init__(
56
+ self,
57
+ in_features: int,
58
+ hidden_features: Optional[int] = None,
59
+ out_features: Optional[int] = None,
60
+ act_layer: Callable[..., nn.Module] = None,
61
+ drop: float = 0.0,
62
+ bias: bool = True,
63
+ ) -> None:
64
+ out_features = out_features or in_features
65
+ hidden_features = hidden_features or in_features
66
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
67
+ super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias)
argus/layers/vision_transformer.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.utils.checkpoint import checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+ from .mlp import Mlp
20
+ from .patch_embed import PatchEmbed
21
+ from .swiglu_ffn import SwiGLUFFNFused
22
+ from .attention import MemEffAttention
23
+ from .block import NestedTensorBlock as Block
24
+
25
+ logger = logging.getLogger("dinov2")
26
+
27
+
28
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
29
+ if not depth_first and include_root:
30
+ fn(module=module, name=name)
31
+ for child_name, child_module in module.named_children():
32
+ child_name = ".".join((name, child_name)) if name else child_name
33
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
34
+ if depth_first and include_root:
35
+ fn(module=module, name=name)
36
+ return module
37
+
38
+
39
+ class BlockChunk(nn.ModuleList):
40
+ def forward(self, x):
41
+ for b in self:
42
+ x = b(x)
43
+ return x
44
+
45
+
46
+ class DinoVisionTransformer(nn.Module):
47
+ def __init__(
48
+ self,
49
+ img_size=224,
50
+ patch_size=16,
51
+ in_chans=3,
52
+ embed_dim=768,
53
+ depth=12,
54
+ num_heads=12,
55
+ mlp_ratio=4.0,
56
+ qkv_bias=True,
57
+ ffn_bias=True,
58
+ proj_bias=True,
59
+ drop_path_rate=0.0,
60
+ drop_path_uniform=False,
61
+ init_values=None, # for layerscale: None or 0 => no layerscale
62
+ embed_layer=PatchEmbed,
63
+ act_layer=nn.GELU,
64
+ block_fn=Block,
65
+ ffn_layer="mlp",
66
+ block_chunks=1,
67
+ num_register_tokens=0,
68
+ interpolate_antialias=False,
69
+ interpolate_offset=0.1,
70
+ qk_norm=False,
71
+ ):
72
+ """
73
+ Args:
74
+ img_size (int, tuple): input image size
75
+ patch_size (int, tuple): patch size
76
+ in_chans (int): number of input channels
77
+ embed_dim (int): embedding dimension
78
+ depth (int): depth of transformer
79
+ num_heads (int): number of attention heads
80
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
81
+ qkv_bias (bool): enable bias for qkv if True
82
+ proj_bias (bool): enable bias for proj in attn if True
83
+ ffn_bias (bool): enable bias for ffn if True
84
+ drop_path_rate (float): stochastic depth rate
85
+ drop_path_uniform (bool): apply uniform drop rate across blocks
86
+ weight_init (str): weight init scheme
87
+ init_values (float): layer-scale init values
88
+ embed_layer (nn.Module): patch embedding layer
89
+ act_layer (nn.Module): MLP activation layer
90
+ block_fn (nn.Module): transformer block class
91
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
92
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
93
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
94
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
95
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
96
+ """
97
+ super().__init__()
98
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
99
+
100
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
101
+ self.num_tokens = 1
102
+ self.n_blocks = depth
103
+ self.num_heads = num_heads
104
+ self.patch_size = patch_size
105
+ self.num_register_tokens = num_register_tokens
106
+ self.interpolate_antialias = interpolate_antialias
107
+ self.interpolate_offset = interpolate_offset
108
+ self.use_reentrant = False # hardcoded to False
109
+
110
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
111
+ num_patches = self.patch_embed.num_patches
112
+
113
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
114
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
115
+ assert num_register_tokens >= 0
116
+ self.register_tokens = (
117
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
118
+ )
119
+
120
+ if drop_path_uniform is True:
121
+ dpr = [drop_path_rate] * depth
122
+ else:
123
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
124
+
125
+ if ffn_layer == "mlp":
126
+ logger.info("using MLP layer as FFN")
127
+ ffn_layer = Mlp
128
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
129
+ logger.info("using SwiGLU layer as FFN")
130
+ ffn_layer = SwiGLUFFNFused
131
+ elif ffn_layer == "identity":
132
+ logger.info("using Identity layer as FFN")
133
+
134
+ def f(*args, **kwargs):
135
+ return nn.Identity()
136
+
137
+ ffn_layer = f
138
+ else:
139
+ raise NotImplementedError
140
+
141
+ blocks_list = [
142
+ block_fn(
143
+ dim=embed_dim,
144
+ num_heads=num_heads,
145
+ mlp_ratio=mlp_ratio,
146
+ qkv_bias=qkv_bias,
147
+ proj_bias=proj_bias,
148
+ ffn_bias=ffn_bias,
149
+ drop_path=dpr[i],
150
+ norm_layer=norm_layer,
151
+ act_layer=act_layer,
152
+ ffn_layer=ffn_layer,
153
+ init_values=init_values,
154
+ qk_norm=qk_norm,
155
+ )
156
+ for i in range(depth)
157
+ ]
158
+ if block_chunks > 0:
159
+ self.chunked_blocks = True
160
+ chunked_blocks = []
161
+ chunksize = depth // block_chunks
162
+ for i in range(0, depth, chunksize):
163
+ # this is to keep the block index consistent if we chunk the block list
164
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
165
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
166
+ else:
167
+ self.chunked_blocks = False
168
+ self.blocks = nn.ModuleList(blocks_list)
169
+
170
+ self.norm = norm_layer(embed_dim)
171
+ self.head = nn.Identity()
172
+
173
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
174
+
175
+ self.init_weights()
176
+
177
+ def init_weights(self):
178
+ trunc_normal_(self.pos_embed, std=0.02)
179
+ nn.init.normal_(self.cls_token, std=1e-6)
180
+ if self.register_tokens is not None:
181
+ nn.init.normal_(self.register_tokens, std=1e-6)
182
+ named_apply(init_weights_vit_timm, self)
183
+
184
+ def interpolate_pos_encoding(self, x, w, h):
185
+ previous_dtype = x.dtype
186
+ npatch = x.shape[1] - 1
187
+ N = self.pos_embed.shape[1] - 1
188
+ if npatch == N and w == h:
189
+ return self.pos_embed
190
+ pos_embed = self.pos_embed.float()
191
+ class_pos_embed = pos_embed[:, 0]
192
+ patch_pos_embed = pos_embed[:, 1:]
193
+ dim = x.shape[-1]
194
+ w0 = w // self.patch_size
195
+ h0 = h // self.patch_size
196
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
197
+ assert N == M * M
198
+ kwargs = {}
199
+ if self.interpolate_offset:
200
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
201
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
202
+ sx = float(w0 + self.interpolate_offset) / M
203
+ sy = float(h0 + self.interpolate_offset) / M
204
+ kwargs["scale_factor"] = (sx, sy)
205
+ else:
206
+ # Simply specify an output size instead of a scale factor
207
+ kwargs["size"] = (w0, h0)
208
+ patch_pos_embed = nn.functional.interpolate(
209
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
210
+ mode="bicubic",
211
+ antialias=self.interpolate_antialias,
212
+ **kwargs,
213
+ )
214
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
215
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
216
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
217
+
218
+ def prepare_tokens_with_masks(self, x, masks=None):
219
+ B, nc, w, h = x.shape
220
+ x = self.patch_embed(x)
221
+ if masks is not None:
222
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
223
+
224
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
225
+ x = x + self.interpolate_pos_encoding(x, w, h)
226
+
227
+ if self.register_tokens is not None:
228
+ x = torch.cat((x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1)
229
+
230
+ return x
231
+
232
+ def forward_features_list(self, x_list, masks_list):
233
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
234
+
235
+ for blk in self.blocks:
236
+ if self.training:
237
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
238
+ else:
239
+ x = blk(x)
240
+
241
+ all_x = x
242
+ output = []
243
+ for x, masks in zip(all_x, masks_list):
244
+ x_norm = self.norm(x)
245
+ output.append(
246
+ {
247
+ "x_norm_clstoken": x_norm[:, 0],
248
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
249
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
250
+ "x_prenorm": x,
251
+ "masks": masks,
252
+ }
253
+ )
254
+ return output
255
+
256
+ def forward_features(self, x, masks=None):
257
+ if isinstance(x, list):
258
+ return self.forward_features_list(x, masks)
259
+
260
+ x = self.prepare_tokens_with_masks(x, masks)
261
+
262
+ for blk in self.blocks:
263
+ if self.training:
264
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
265
+ else:
266
+ x = blk(x)
267
+
268
+ x_norm = self.norm(x)
269
+ return {
270
+ "x_norm_clstoken": x_norm[:, 0],
271
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
272
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
273
+ "x_prenorm": x,
274
+ "masks": masks,
275
+ }
276
+
277
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
278
+ x = self.prepare_tokens_with_masks(x)
279
+ # If n is an int, take the n last blocks. If it's a list, take them
280
+ output, total_block_len = [], len(self.blocks)
281
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
282
+ for i, blk in enumerate(self.blocks):
283
+ x = blk(x)
284
+ if i in blocks_to_take:
285
+ output.append(x)
286
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
287
+ return output
288
+
289
+ def _get_intermediate_layers_chunked(self, x, n=1):
290
+ x = self.prepare_tokens_with_masks(x)
291
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
292
+ # If n is an int, take the n last blocks. If it's a list, take them
293
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
294
+ for block_chunk in self.blocks:
295
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
296
+ x = blk(x)
297
+ if i in blocks_to_take:
298
+ output.append(x)
299
+ i += 1
300
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
301
+ return output
302
+
303
+ def get_intermediate_layers(
304
+ self,
305
+ x: torch.Tensor,
306
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
307
+ reshape: bool = False,
308
+ return_class_token: bool = False,
309
+ norm=True,
310
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
311
+ if self.chunked_blocks:
312
+ outputs = self._get_intermediate_layers_chunked(x, n)
313
+ else:
314
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
315
+ if norm:
316
+ outputs = [self.norm(out) for out in outputs]
317
+ class_tokens = [out[:, 0] for out in outputs]
318
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
319
+ if reshape:
320
+ B, _, w, h = x.shape
321
+ outputs = [
322
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
323
+ for out in outputs
324
+ ]
325
+ if return_class_token:
326
+ return tuple(zip(outputs, class_tokens))
327
+ return tuple(outputs)
328
+
329
+ def forward(self, *args, is_training=True, **kwargs):
330
+ ret = self.forward_features(*args, **kwargs)
331
+ if is_training:
332
+ return ret
333
+ else:
334
+ return self.head(ret["x_norm_clstoken"])
335
+
336
+
337
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
338
+ """ViT weight initialization, original timm impl (for reproducibility)"""
339
+ if isinstance(module, nn.Linear):
340
+ trunc_normal_(module.weight, std=0.02)
341
+ if module.bias is not None:
342
+ nn.init.zeros_(module.bias)
343
+
344
+
345
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
346
+ model = DinoVisionTransformer(
347
+ patch_size=patch_size,
348
+ embed_dim=384,
349
+ depth=12,
350
+ num_heads=6,
351
+ mlp_ratio=4,
352
+ block_fn=partial(Block, attn_class=MemEffAttention),
353
+ num_register_tokens=num_register_tokens,
354
+ **kwargs,
355
+ )
356
+ return model
357
+
358
+
359
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
360
+ model = DinoVisionTransformer(
361
+ patch_size=patch_size,
362
+ embed_dim=768,
363
+ depth=12,
364
+ num_heads=12,
365
+ mlp_ratio=4,
366
+ block_fn=partial(Block, attn_class=MemEffAttention),
367
+ num_register_tokens=num_register_tokens,
368
+ **kwargs,
369
+ )
370
+ return model
371
+
372
+
373
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
374
+ model = DinoVisionTransformer(
375
+ patch_size=patch_size,
376
+ embed_dim=1024,
377
+ depth=24,
378
+ num_heads=16,
379
+ mlp_ratio=4,
380
+ block_fn=partial(Block, attn_class=MemEffAttention),
381
+ num_register_tokens=num_register_tokens,
382
+ **kwargs,
383
+ )
384
+ return model
385
+
386
+
387
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
388
+ """
389
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
390
+ """
391
+ model = DinoVisionTransformer(
392
+ patch_size=patch_size,
393
+ embed_dim=1536,
394
+ depth=40,
395
+ num_heads=24,
396
+ mlp_ratio=4,
397
+ block_fn=partial(Block, attn_class=MemEffAttention),
398
+ num_register_tokens=num_register_tokens,
399
+ **kwargs,
400
+ )
401
+ return model
argus/models/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright 2026 Realsee. All rights reserved.
2
+ # Licensed under the Apache License, Version 2.0.
argus/models/aggregator.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.utils.checkpoint import checkpoint
6
+ from typing import Optional, Tuple, Union, List, Dict, Any
7
+ from argus.layers import Mlp
8
+ from argus.layers import PatchEmbed
9
+ from argus.layers.block import Block
10
+ from argus.layers.rope import RotaryPositionEmbedding2D, PositionGetter
11
+ from argus.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
12
+ from argus.heads.utils import reorder_by_reference
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ _RESNET_MEAN = [0.485, 0.456, 0.406]
17
+ _RESNET_STD = [0.229, 0.224, 0.225]
18
+
19
+
20
+ class Aggregator(nn.Module):
21
+ """
22
+ Args:
23
+ img_size (int): Image size in pixels.
24
+ patch_size (int): Size of each patch for PatchEmbed.
25
+ embed_dim (int): Dimension of the token embeddings.
26
+ depth (int): Number of blocks.
27
+ num_heads (int): Number of attention heads.
28
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
29
+ num_register_tokens (int): Number of register tokens.
30
+ block_fn (nn.Module): The block type used for attention (Block by default).
31
+ qkv_bias (bool): Whether to include bias in QKV projections.
32
+ proj_bias (bool): Whether to include bias in the output projection.
33
+ ffn_bias (bool): Whether to include bias in MLP layers.
34
+ patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
35
+ aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
36
+ aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
37
+ qk_norm (bool): Whether to apply QK normalization.
38
+ rope_freq (int): Base frequency for rotary embedding. -1 to disable.
39
+ init_values (float): Init scale for layer scale.
40
+ reorder_by_learning_ref (bool): Whether to reorder features by learning reference view index.
41
+ ref_aa_block_num (int): Number of aa blocks for reference view learning.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ img_size=518,
47
+ patch_size=14,
48
+ embed_dim=1024,
49
+ depth=24,
50
+ num_heads=16,
51
+ mlp_ratio=4.0,
52
+ num_register_tokens=4,
53
+ block_fn=Block,
54
+ qkv_bias=True,
55
+ proj_bias=True,
56
+ ffn_bias=True,
57
+ patch_embed="dinov2_vitl14_reg",
58
+ aa_order=["frame", "global"],
59
+ aa_block_size=1,
60
+ qk_norm=True,
61
+ rope_freq=100,
62
+ init_values=0.01,
63
+ reorder_by_learning_ref=True,
64
+ ref_aa_block_num=2,
65
+ save_inference_memory=True,
66
+ ):
67
+ super().__init__()
68
+
69
+ self.reorder_by_learning_ref = reorder_by_learning_ref
70
+ self.save_inference_memory = save_inference_memory
71
+
72
+ self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)
73
+
74
+ # Initialize rotary position embedding if frequency > 0
75
+ self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
76
+ self.position_getter = PositionGetter() if self.rope is not None else None
77
+
78
+ self.frame_blocks = nn.ModuleList(
79
+ [
80
+ block_fn(
81
+ dim=embed_dim,
82
+ num_heads=num_heads,
83
+ mlp_ratio=mlp_ratio,
84
+ qkv_bias=qkv_bias,
85
+ proj_bias=proj_bias,
86
+ ffn_bias=ffn_bias,
87
+ init_values=init_values,
88
+ qk_norm=qk_norm,
89
+ rope=self.rope,
90
+ )
91
+ for _ in range(depth)
92
+ ]
93
+ )
94
+
95
+ self.global_blocks = nn.ModuleList(
96
+ [
97
+ block_fn(
98
+ dim=embed_dim,
99
+ num_heads=num_heads,
100
+ mlp_ratio=mlp_ratio,
101
+ qkv_bias=qkv_bias,
102
+ proj_bias=proj_bias,
103
+ ffn_bias=ffn_bias,
104
+ init_values=init_values,
105
+ qk_norm=qk_norm,
106
+ rope=self.rope,
107
+ )
108
+ for _ in range(depth)
109
+ ]
110
+ )
111
+
112
+ self.depth = depth
113
+ self.aa_order = aa_order
114
+ self.patch_size = patch_size
115
+ self.aa_block_size = aa_block_size
116
+
117
+ # Validate that depth is divisible by aa_block_size
118
+ if self.depth % self.aa_block_size != 0:
119
+ raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
120
+
121
+ self.aa_block_num = self.depth // self.aa_block_size
122
+
123
+
124
+ # Reference Learning Network
125
+ if self.reorder_by_learning_ref:
126
+ self.ref_aa_block_num = ref_aa_block_num
127
+ self.ref_frame_blocks = nn.ModuleList(
128
+ [
129
+ block_fn(
130
+ dim=embed_dim,
131
+ num_heads=num_heads,
132
+ mlp_ratio=mlp_ratio,
133
+ qkv_bias=qkv_bias,
134
+ proj_bias=proj_bias,
135
+ ffn_bias=ffn_bias,
136
+ init_values=init_values,
137
+ qk_norm=qk_norm,
138
+ rope=self.rope,
139
+ )
140
+ for _ in range(self.ref_aa_block_num)
141
+ ]
142
+ )
143
+
144
+ self.ref_global_blocks = nn.ModuleList(
145
+ [
146
+ block_fn(
147
+ dim=embed_dim,
148
+ num_heads=num_heads,
149
+ mlp_ratio=mlp_ratio,
150
+ qkv_bias=qkv_bias,
151
+ proj_bias=proj_bias,
152
+ ffn_bias=ffn_bias,
153
+ init_values=init_values,
154
+ qk_norm=qk_norm,
155
+ rope=self.rope,
156
+ )
157
+ for _ in range(self.ref_aa_block_num)
158
+ ]
159
+ )
160
+
161
+ # Note: We have two camera tokens, one for the first frame and one for the rest
162
+ # The same applies for register tokens
163
+ self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
164
+ self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))
165
+
166
+ if self.reorder_by_learning_ref:
167
+ # describe the covisibility of the current frame with other frames
168
+ self.covisibility_token = nn.Parameter(torch.randn(1, 1, 1, embed_dim))
169
+
170
+ # The patch tokens start after the camera and register tokens
171
+ self.patch_start_idx = 1 + num_register_tokens
172
+
173
+ # Initialize parameters with small values
174
+ nn.init.normal_(self.camera_token, std=1e-6)
175
+ nn.init.normal_(self.register_token, std=1e-6)
176
+ if self.reorder_by_learning_ref:
177
+ nn.init.normal_(self.covisibility_token, std=1e-6)
178
+
179
+
180
+ # Register normalization constants as buffers
181
+ for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
182
+ self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False)
183
+
184
+ self.use_reentrant = False # hardcoded to False
185
+
186
+ def __build_patch_embed__(
187
+ self,
188
+ patch_embed,
189
+ img_size,
190
+ patch_size,
191
+ num_register_tokens,
192
+ interpolate_antialias=True,
193
+ interpolate_offset=0.0,
194
+ block_chunks=0,
195
+ init_values=1.0,
196
+ embed_dim=1024,
197
+ ):
198
+ """
199
+ Build the patch embed layer. If 'conv', we use a
200
+ simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
201
+ """
202
+
203
+ if "conv" in patch_embed:
204
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
205
+ else:
206
+ vit_models = {
207
+ "dinov2_vitl14_reg": vit_large,
208
+ "dinov2_vitb14_reg": vit_base,
209
+ "dinov2_vits14_reg": vit_small,
210
+ "dinov2_vitg2_reg": vit_giant2,
211
+ }
212
+
213
+ self.patch_embed = vit_models[patch_embed](
214
+ img_size=img_size,
215
+ patch_size=patch_size,
216
+ num_register_tokens=num_register_tokens,
217
+ interpolate_antialias=interpolate_antialias,
218
+ interpolate_offset=interpolate_offset,
219
+ block_chunks=block_chunks,
220
+ init_values=init_values,
221
+ )
222
+
223
+ # Disable gradient updates for mask token
224
+ if hasattr(self.patch_embed, "mask_token"):
225
+ # self.patch_embed.mask_token.requires_grad_(False)
226
+ del self.patch_embed.mask_token
227
+
228
+
229
+
230
+ # covisibility head
231
+ if self.reorder_by_learning_ref:
232
+ self.token_norm = nn.LayerNorm(embed_dim * 2)
233
+ self.covisibility_head = Mlp(in_features=embed_dim * 2, hidden_features=embed_dim * 2 // 2, out_features=1, drop=0)
234
+
235
+ def forward(self, images: torch.Tensor) -> Tuple[List[torch.Tensor], int]:
236
+ """
237
+ Args:
238
+ images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
239
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
240
+
241
+ Returns:
242
+ (list[torch.Tensor], int):
243
+ The list of outputs from the attention blocks,
244
+ and the patch_start_idx indicating where patch tokens begin.
245
+ """
246
+ B, S, C_in, H, W = images.shape
247
+
248
+ if C_in != 3:
249
+ raise ValueError(f"Expected 3 input channels, got {C_in}")
250
+
251
+ # Normalize images and reshape for patch embed
252
+ images = (images - self._resnet_mean) / self._resnet_std
253
+
254
+ # Reshape to [B*S, C, H, W] for patch embedding
255
+ images = images.view(B * S, C_in, H, W)
256
+ patch_tokens = self.patch_embed(images)
257
+
258
+ if isinstance(patch_tokens, dict):
259
+ patch_tokens = patch_tokens["x_norm_patchtokens"]
260
+
261
+ _, P, C = patch_tokens.shape
262
+
263
+ ################# ref learning
264
+ covisibility_scores = None
265
+ ref_idx = None
266
+ if self.reorder_by_learning_ref:
267
+ # expand covisibility token to match batch size and sequence length
268
+ covisibility_token = self.covisibility_token.expand(B, S, 1, C).view(B * S, 1, C).contiguous()
269
+ # Concatenate covisibility token with patch tokens
270
+ covisibility_patch_tokens = torch.cat([covisibility_token, patch_tokens], dim=1) # [BS,1+HW,C]
271
+
272
+ covisibility_pos = None
273
+ if self.rope is not None:
274
+ covisibility_pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
275
+
276
+ # do not use position embedding for special covisibility_token
277
+ # so set pos to 0 for the special tokens
278
+ covisibility_pos = covisibility_pos + 1
279
+ covisibility_pos_special = torch.zeros(B * S, 1, 2).to(images.device).to(covisibility_pos.dtype)
280
+ covisibility_pos = torch.cat([covisibility_pos_special, covisibility_pos], dim=1) # [BS, 1+HW, 2]
281
+
282
+
283
+ # update P because we added special tokens
284
+ _, P_covis, C_covis = covisibility_patch_tokens.shape
285
+
286
+ frame_idx = 0
287
+ global_idx = 0
288
+ output_list = []
289
+
290
+ for ref_block_i in range(self.ref_aa_block_num):
291
+ for attn_type in self.aa_order:
292
+ if attn_type == "frame":
293
+ covisibility_patch_tokens, frame_idx, frame_intermediates = self._ref_process_frame_attention(
294
+ covisibility_patch_tokens, B, S, P_covis, C_covis, frame_idx, pos=covisibility_pos
295
+ )
296
+ elif attn_type == "global":
297
+ covisibility_patch_tokens, global_idx, global_intermediates = self._ref_process_global_attention(
298
+ covisibility_patch_tokens, B, S, P_covis, C_covis, global_idx, pos=covisibility_pos
299
+ )
300
+ else:
301
+ raise ValueError(f"Unknown attention type: {attn_type}")
302
+
303
+ for i in range(len(frame_intermediates)):
304
+ # concat frame and global intermediates, [B x S x P x 2C]
305
+ concat_inter = torch.cat([frame_intermediates[-1], global_intermediates[-1]], dim=-1)
306
+ output_list.append(concat_inter)
307
+
308
+ last_covisibility_patch_tokens = output_list[-1][:,:,0,:] # [B, S, C]
309
+ # normalize
310
+ last_covisibility_patch_tokens = self.token_norm(last_covisibility_patch_tokens)
311
+
312
+ covisibility_scores = self.covisibility_head(last_covisibility_patch_tokens).squeeze(-1) # [B, S]
313
+ # # cos
314
+ # feat_norm = F.normalize(covisibility_features, p=2, dim=-1, eps=1e-8) # [B, S, D]
315
+ # covisibility_scores = feat_norm @ feat_norm.transpose(-1, -2)
316
+
317
+
318
+ ref_idx = covisibility_scores.argmax(-1) # [B, S] -> [B]
319
+ patch_tokens = patch_tokens.view(B,S,P,C)
320
+ patch_tokens = reorder_by_reference(patch_tokens, ref_idx)
321
+ patch_tokens = patch_tokens.view(B*S,P,C).contiguous()
322
+
323
+ ####################
324
+ # Expand camera and register tokens to match batch size and sequence length
325
+ camera_token = slice_expand_and_flatten(self.camera_token, B, S)
326
+ register_token = slice_expand_and_flatten(self.register_token, B, S)
327
+ # Concatenate special tokens with patch tokens
328
+ tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1) # [BS,1+4+HW,C]
329
+
330
+ pos = None
331
+ if self.rope is not None:
332
+ pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
333
+
334
+ if self.patch_start_idx > 0:
335
+ # do not use position embedding for special tokens (camera and register tokens)
336
+ # so set pos to 0 for the special tokens
337
+ pos = pos + 1
338
+ pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
339
+ pos = torch.cat([pos_special, pos], dim=1) # [BS, 1+4+HW, 2]
340
+
341
+
342
+
343
+ # update P because we added special tokens
344
+ _, P, C = tokens.shape
345
+
346
+ frame_idx = 0
347
+ global_idx = 0
348
+ output_list = []
349
+
350
+ for block_i in range(self.aa_block_num):
351
+ for attn_type in self.aa_order:
352
+ if attn_type == "frame":
353
+ tokens, frame_idx, frame_intermediates = self._process_frame_attention(
354
+ tokens, B, S, P, C, frame_idx, pos=pos
355
+ )
356
+ elif attn_type == "global":
357
+ tokens, global_idx, global_intermediates = self._process_global_attention(
358
+ tokens, B, S, P, C, global_idx, pos=pos
359
+ )
360
+ else:
361
+ raise ValueError(f"Unknown attention type: {attn_type}")
362
+
363
+ for i in range(len(frame_intermediates)):
364
+ concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
365
+ if (not self.training ) and (self.save_inference_memory) and (block_i not in [4,11,17,23]):
366
+ # only save the useful indices of intermediates
367
+ output_list.append(torch.tensor(0))
368
+ else:
369
+ # concat frame and global intermediates, [B x S x P x 2C]
370
+ output_list.append(concat_inter)
371
+
372
+ del concat_inter
373
+ del frame_intermediates
374
+ del global_intermediates
375
+ return output_list, self.patch_start_idx, covisibility_scores, ref_idx
376
+
377
+ def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
378
+ """
379
+ Process frame attention blocks. We keep tokens in shape (B*S, P, C).
380
+ """
381
+ # If needed, reshape tokens or positions:
382
+ if tokens.shape != (B * S, P, C):
383
+ tokens = tokens.view(B, S, P, C).view(B * S, P, C)
384
+
385
+ if pos is not None and pos.shape != (B * S, P, 2):
386
+ pos = pos.view(B, S, P, 2).view(B * S, P, 2)
387
+
388
+ intermediates = []
389
+
390
+ # by default, self.aa_block_size=1, which processes one block at a time
391
+ for _ in range(self.aa_block_size):
392
+ if self.training:
393
+ tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant)
394
+ else:
395
+ tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
396
+ frame_idx += 1
397
+ intermediates.append(tokens.view(B, S, P, C))
398
+
399
+ return tokens, frame_idx, intermediates
400
+
401
+ def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
402
+ """
403
+ Process global attention blocks. We keep tokens in shape (B, S*P, C).
404
+ """
405
+ if tokens.shape != (B, S * P, C):
406
+ tokens = tokens.view(B, S, P, C).view(B, S * P, C)
407
+
408
+ if pos is not None and pos.shape != (B, S * P, 2):
409
+ pos = pos.view(B, S, P, 2).view(B, S * P, 2)
410
+
411
+ intermediates = []
412
+
413
+
414
+ # by default, self.aa_block_size=1, which processes one block at a time
415
+ for _ in range(self.aa_block_size):
416
+ if self.training:
417
+ tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant)
418
+ else:
419
+ tokens = self.global_blocks[global_idx](tokens, pos=pos)
420
+ global_idx += 1
421
+
422
+ intermediates.append(tokens.view(B, S, P, C))
423
+
424
+
425
+
426
+ return tokens, global_idx, intermediates
427
+
428
+ def _ref_process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
429
+ """
430
+ Process frame attention blocks. We keep tokens in shape (B*S, P, C).
431
+ """
432
+ # If needed, reshape tokens or positions:
433
+ if tokens.shape != (B * S, P, C):
434
+ tokens = tokens.view(B, S, P, C).view(B * S, P, C)
435
+
436
+ if pos is not None and pos.shape != (B * S, P, 2):
437
+ pos = pos.view(B, S, P, 2).view(B * S, P, 2)
438
+
439
+ intermediates = []
440
+
441
+ # by default, self.aa_block_size=1, which processes one block at a time
442
+ for _ in range(self.aa_block_size):
443
+ if self.training:
444
+ tokens = checkpoint(self.ref_frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant)
445
+ else:
446
+ tokens = self.ref_frame_blocks[frame_idx](tokens, pos=pos)
447
+ frame_idx += 1
448
+ intermediates.append(tokens.view(B, S, P, C))
449
+
450
+ return tokens, frame_idx, intermediates
451
+
452
+ def _ref_process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
453
+ """
454
+ Process global attention blocks. We keep tokens in shape (B, S*P, C).
455
+ """
456
+ if tokens.shape != (B, S * P, C):
457
+ tokens = tokens.view(B, S, P, C).view(B, S * P, C)
458
+
459
+ if pos is not None and pos.shape != (B, S * P, 2):
460
+ pos = pos.view(B, S, P, 2).view(B, S * P, 2)
461
+
462
+ intermediates = []
463
+
464
+
465
+ # by default, self.aa_block_size=1, which processes one block at a time
466
+ for _ in range(self.aa_block_size):
467
+ if self.training:
468
+ tokens = checkpoint(self.ref_global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant)
469
+ else:
470
+ tokens = self.ref_global_blocks[global_idx](tokens, pos=pos)
471
+ global_idx += 1
472
+
473
+ intermediates.append(tokens.view(B, S, P, C))
474
+
475
+
476
+
477
+ return tokens, global_idx, intermediates
478
+
479
+ def slice_expand_and_flatten(token_tensor, B, S):
480
+ """
481
+ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
482
+ 1) Uses the first position (index=0) for the first frame only
483
+ 2) Uses the second position (index=1) for all remaining frames (S-1 frames)
484
+ 3) Expands both to match batch size B
485
+ 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
486
+ followed by (S-1) second-position tokens
487
+ 5) Flattens to (B*S, X, C) for processing
488
+
489
+ Returns:
490
+ torch.Tensor: Processed tokens with shape (B*S, X, C)
491
+ """
492
+
493
+ # Slice out the "query" tokens => shape (1, 1, ...)
494
+ query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
495
+ # Slice out the "other" tokens => shape (1, S-1, ...)
496
+ others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
497
+ # Concatenate => shape (B, S, ...)
498
+ combined = torch.cat([query, others], dim=1)
499
+
500
+ # Finally flatten => shape (B*S, ...)
501
+ combined = combined.view(B * S, *combined.shape[2:])
502
+ return combined
argus/models/argus.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Optional, Dict
4
+ from huggingface_hub import PyTorchModelHubMixin
5
+
6
+ # Import model components
7
+ from argus.models.aggregator import Aggregator
8
+ from argus.heads.camera_head import CameraHead
9
+ from argus.heads.dpt_head import DPTHead
10
+ from argus.heads.utils import reorder_by_reference
11
+
12
+
13
+ class Argus(nn.Module, PyTorchModelHubMixin):
14
+ """
15
+ Argus multi-task vision model for camera pose estimation, depth prediction, and 3D points.
16
+
17
+ Integrates an aggregator backbone with task-specific heads for:
18
+ - Camera pose encoding
19
+ - Depth map prediction
20
+ - 3D camera/rotated/world point prediction
21
+
22
+ Args:
23
+ img_size: Input image size (height/width, assumes square) (default: 518)
24
+ patch_size: Patch size for vision transformer backbone (default: 14)
25
+ embed_dim: Embedding dimension for transformer features (default: 1024)
26
+ enable_camera: Enable camera pose estimation head (default: True)
27
+ enable_depth: Enable depth prediction head (default: True)
28
+ enable_cam_point: Enable camera coordinate 3D point prediction head (default: False)
29
+ enable_rotated_point: Enable rotated 3D point prediction head (default: False)
30
+ enable_point: Enable world coordinate 3D point prediction head (default: False, Please do not set it to True during training)
31
+
32
+ Note:
33
+ All heads share the same aggregated transformer features from the Aggregator backbone.
34
+ Each DPT-based head outputs both predictions and confidence scores.
35
+ """
36
+ def __init__(
37
+ self,
38
+ img_size: int = 518,
39
+ patch_size: int = 14,
40
+ embed_dim: int = 1024,
41
+ enable_camera: bool = True,
42
+ enable_depth: bool = True,
43
+ enable_cam_point: bool = False,
44
+ enable_rotated_point: bool = False,
45
+ enable_point: bool = False,
46
+ reorder_by_learning_ref: bool = True,
47
+ restore_metric_scale: bool = False
48
+ ) -> None:
49
+ super().__init__()
50
+ # For inference
51
+ self.restore_metric_scale = restore_metric_scale
52
+ self.reorder_by_learning_ref = reorder_by_learning_ref
53
+
54
+ # Backbone and geometry transformer
55
+ self.aggregator = Aggregator(
56
+ img_size=img_size,
57
+ patch_size=patch_size,
58
+ embed_dim=embed_dim,
59
+ reorder_by_learning_ref=reorder_by_learning_ref,
60
+ )
61
+
62
+ # Task-specific prediction heads (lazy initialization based on flags)
63
+ self.camera_head: Optional[CameraHead] = CameraHead(dim_in=2 * embed_dim) if enable_camera else None
64
+ self.depth_head: Optional[DPTHead] = DPTHead(
65
+ dim_in=2 * embed_dim,
66
+ output_dim=2,
67
+ activation="exp",
68
+ conf_activation="expp1"
69
+ ) if enable_depth else None
70
+
71
+ # 3D point prediction heads (shared architecture, different output semantics)
72
+ self.cam_point_head: Optional[DPTHead] = DPTHead(
73
+ dim_in=2 * embed_dim,
74
+ output_dim=4,
75
+ activation="inv_log",
76
+ conf_activation="expp1"
77
+ ) if enable_cam_point else None
78
+
79
+ self.rotated_point_head: Optional[DPTHead] = DPTHead(
80
+ dim_in=2 * embed_dim,
81
+ output_dim=4,
82
+ activation="inv_log",
83
+ conf_activation="expp1"
84
+ ) if enable_rotated_point else None
85
+
86
+ self.point_head: Optional[DPTHead] = DPTHead(
87
+ dim_in=2 * embed_dim,
88
+ output_dim=4,
89
+ activation="inv_log",
90
+ conf_activation="expp1"
91
+ ) if enable_point else None
92
+
93
+ def forward(
94
+ self,
95
+ images: torch.Tensor,
96
+ ) -> Dict[str, torch.Tensor]:
97
+ """
98
+ Forward pass of the Argus model.
99
+
100
+ Automatically adds batch dimension if missing and processes multi-task predictions.
101
+
102
+ Args:
103
+ images: Input RGB images with shape:
104
+ - [S, 3, H, W] (sequence without batch) or
105
+ - [B, S, 3, H, W] (batch of sequences)
106
+ Values in range [0, 1], where:
107
+ - B: batch size
108
+ - S: sequence length (number of frames)
109
+ - 3: RGB channels
110
+ - H/W: image height/width (matches img_size)
111
+
112
+ Returns:
113
+ Dictionary of model predictions with task-specific outputs:
114
+ Common outputs:
115
+ - covisibility_scores: Covisibility scores from aggregator (shape varies)
116
+ - ref_idx: Reference frame indices (shape varies)
117
+
118
+ Camera head outputs (if enabled):
119
+ - pose_enc: Final camera pose encoding [B, S, 9]
120
+ - pose_enc_list: List of pose encodings from all iterations [List[torch.Tensor]]
121
+
122
+ Depth head outputs (if enabled):
123
+ - depth: Predicted depth maps [B, S, H, W, 1]
124
+ - depth_conf: Depth prediction confidence [B, S, H, W]
125
+
126
+ Camera point head outputs (if enabled):
127
+ - cam_points: 3D camera coordinates per pixel [B, S, H, W, 3]
128
+ - cam_points_conf: Camera point confidence [B, S, H, W]
129
+
130
+ Rotated point head outputs (if enabled):
131
+ - rotated_points: Rotated 3D coordinates per pixel [B, S, H, W, 3]
132
+ - rotated_points_conf: Rotated point confidence [B, S, H, W]
133
+
134
+ World point head outputs (if enabled):
135
+ - world_points: 3D world coordinates per pixel [B, S, H, W, 3]
136
+ - world_points_conf: World point confidence [B, S, H, W]
137
+
138
+ Inference-only outputs (not training):
139
+ - images: Original input images (for visualization) [B, S, 3, H, W]
140
+ """
141
+ # Add batch dimension if missing (handle [S,3,H,W] -> [1,S,3,H,W])
142
+ if len(images.shape) == 4:
143
+ images = images.unsqueeze(0)
144
+
145
+ # Extract aggregated features from backbone
146
+ (
147
+ aggregated_tokens_list, # List of aggregated transformer tokens across iterations
148
+ patch_start_idx, # Patch start indices for feature reconstruction
149
+ covisibility_scores, # Covisibility scores between frames
150
+ ref_idx # Reference frame indices
151
+ ) = self.aggregator(images)
152
+
153
+ # Initialize prediction dictionary
154
+ predictions: Dict[str, torch.Tensor] = {}
155
+
156
+ # Disable mixed precision for precise prediction calculations
157
+ with torch.amp.autocast("cuda", enabled=False):
158
+ # Add aggregator outputs to predictions
159
+ if covisibility_scores is not None:
160
+ predictions["covisibility_scores"] = covisibility_scores
161
+ if ref_idx is not None:
162
+ predictions["ref_idx"] = ref_idx
163
+
164
+ # Camera pose prediction (if enabled)
165
+ if self.camera_head is not None:
166
+ pose_enc_list = self.camera_head(aggregated_tokens_list)
167
+ predictions["pose_enc"] = pose_enc_list[-1] # Use final iteration encoding
168
+ predictions["pose_enc_list"] = pose_enc_list # Mutil-layer supervision
169
+
170
+ # Depth prediction (if enabled)
171
+ if self.depth_head is not None:
172
+ depth, depth_conf = self.depth_head(
173
+ aggregated_tokens_list,
174
+ images=images,
175
+ patch_start_idx=patch_start_idx
176
+ )
177
+ predictions["depth"] = depth
178
+ predictions["depth_conf"] = depth_conf
179
+
180
+ # Camera 3D point prediction (if enabled)
181
+ if self.cam_point_head is not None:
182
+ cam_pts3d, cam_pts3d_conf = self.cam_point_head(
183
+ aggregated_tokens_list,
184
+ images=images,
185
+ patch_start_idx=patch_start_idx
186
+ )
187
+ predictions["cam_points"] = cam_pts3d
188
+ predictions["cam_points_conf"] = cam_pts3d_conf
189
+
190
+ # Rotated 3D point prediction (if enabled)
191
+ if self.rotated_point_head is not None:
192
+ rotated_pts3d, rotated_pts3d_conf = self.rotated_point_head(
193
+ aggregated_tokens_list,
194
+ images=images,
195
+ patch_start_idx=patch_start_idx
196
+ )
197
+ predictions["rotated_points"] = rotated_pts3d
198
+ predictions["rotated_points_conf"] = rotated_pts3d_conf
199
+
200
+ # World 3D point prediction (if enabled)
201
+ if self.point_head is not None:
202
+ world_pts3d, world_pts3d_conf = self.point_head(
203
+ aggregated_tokens_list,
204
+ images=images,
205
+ patch_start_idx=patch_start_idx
206
+ )
207
+ predictions["world_points"] = world_pts3d
208
+ predictions["world_points_conf"] = world_pts3d_conf
209
+
210
+
211
+ # Store input images for visualization during inference (skip in training)
212
+ if not self.training:
213
+ predictions["images"] = images
214
+ if "ref_idx" in predictions:
215
+ ref_idx = predictions["ref_idx"].detach()
216
+ # Reorder all spatial/temporal data (exclude adjacency matrix and IDs)
217
+ predictions["images"] = reorder_by_reference(predictions["images"], ref_idx)
218
+
219
+ if self.restore_metric_scale:
220
+ # Restore metric scale
221
+ abs_scale = 10.0
222
+ if self.camera_head is not None:
223
+ predictions["pose_enc"][...,:3] *= abs_scale
224
+ if self.depth_head is not None:
225
+ predictions["depth"] *= abs_scale
226
+ if self.cam_point_head is not None:
227
+ predictions["cam_points"] *= abs_scale
228
+ if self.rotated_point_head is not None:
229
+ predictions["rotated_points"] *= abs_scale
230
+ if self.point_head is not None:
231
+ predictions["world_points"] *= abs_scale
232
+
233
+
234
+ return predictions
argus/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright 2026 Realsee. All rights reserved.
2
+ # Licensed under the Apache License, Version 2.0.
argus/utils/data_io.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 Realsee. All rights reserved.
2
+ # Licensed under the Apache License, Version 2.0.
3
+
4
+ """
5
+ Shared I/O and preprocessing utilities for panoramic image data.
6
+
7
+ These functions are used by both evaluation and training pipelines.
8
+ """
9
+
10
+ import os
11
+
12
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
13
+ import cv2
14
+ import numpy as np
15
+
16
+
17
+ def read_image_cv2_360(path: str, rgb: bool = True, shape=(560, 280)) -> np.ndarray:
18
+ """Read and resize a 360 panorama image.
19
+
20
+ Args:
21
+ path: Path to the image file.
22
+ rgb: If True, convert BGR to RGB (default: True).
23
+ shape: Target (width, height) tuple.
24
+
25
+ Returns:
26
+ Image as numpy array with shape (H, W, 3).
27
+ """
28
+ img = cv2.imread(path)
29
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
30
+ if img.shape[1] != shape[0]:
31
+ img = cv2.resize(img, shape, interpolation=cv2.INTER_AREA)
32
+ return img
33
+
34
+
35
+ def read_depth_360(path: str, depth_scale=5000.0, shape=(560, 280)) -> np.ndarray:
36
+ """Read and normalize a 360 depth map.
37
+
38
+ Args:
39
+ path: Path to the depth image file.
40
+ depth_scale: Scale factor to convert raw depth to meters.
41
+ shape: Target (width, height) tuple.
42
+
43
+ Returns:
44
+ Depth map as float32 numpy array with shape (H, W).
45
+ """
46
+ d = cv2.imread(path, cv2.IMREAD_UNCHANGED)
47
+ if d.shape[1] != shape[0]:
48
+ d = cv2.resize(d, shape, interpolation=cv2.INTER_NEAREST)
49
+ d = d.astype(np.float32) / depth_scale
50
+ return d
51
+
52
+
53
+ def random_rotate_theta(W=560, max_shift_percent=0.5):
54
+ """Generate a random rotation angle for panorama augmentation.
55
+
56
+ Args:
57
+ W: Panorama width in pixels.
58
+ max_shift_percent: Maximum horizontal shift as fraction of width.
59
+
60
+ Returns:
61
+ Rotation angle in radians.
62
+ """
63
+ max_shift = int(W * max_shift_percent)
64
+ shift_pixels = np.random.randint(-max_shift, max_shift + 1)
65
+ theta = (shift_pixels * 2 * np.pi) / W
66
+ return theta
67
+
68
+
69
+ def rotate_y(theta):
70
+ """Create a 3x3 rotation matrix around the Y-axis.
71
+
72
+ Args:
73
+ theta: Rotation angle in radians.
74
+
75
+ Returns:
76
+ 3x3 rotation matrix as float64 numpy array.
77
+ """
78
+ cos_theta = np.cos(theta)
79
+ sin_theta = np.sin(theta)
80
+ return np.array(
81
+ [[cos_theta, 0, -sin_theta], [0, 1, 0], [sin_theta, 0, cos_theta]],
82
+ dtype=np.float64,
83
+ )
84
+
85
+
86
+ def pano_depth_to_points(depth_map, pano_shape=(560, 280), crop=True, crop_ratio=0.15):
87
+ """Convert a panorama depth map to 3D point cloud.
88
+
89
+ Args:
90
+ depth_map: 2D depth map (H, W) or flattened array.
91
+ pano_shape: Original panorama (width, height) tuple.
92
+ crop: Whether the depth map has been vertically cropped.
93
+ crop_ratio: Crop ratio applied to top and bottom.
94
+
95
+ Returns:
96
+ Point cloud as numpy array with shape (N, 3).
97
+ """
98
+ w, h = pano_shape
99
+
100
+ if not crop:
101
+ px = np.tile(np.arange(w), int(h))
102
+ py = np.arange(0, int(h)).repeat(w)
103
+ else:
104
+ px = np.tile(np.arange(w), int(h * (1 - 2 * crop_ratio)))
105
+ py = np.arange(int(crop_ratio * h), int((1 - crop_ratio) * h)).repeat(w)
106
+
107
+ dist = depth_map.reshape(-1)
108
+
109
+ lat = (py / h - 0.5) * np.pi
110
+ long = (px / w - 0.5) * np.pi * 2.0
111
+
112
+ y = dist * np.sin(lat)
113
+ tmp = dist * np.cos(lat)
114
+ x = tmp * np.sin(long)
115
+ z = tmp * np.cos(long)
116
+
117
+ point_map = np.concatenate([i.reshape(-1, 1) for i in (x, y, z)], axis=-1)
118
+
119
+ return point_map # (h*w, 3)
120
+
121
+
122
+ def crop_panorama(pano, crop_ratio=0.15):
123
+ """Crop the top and bottom of a panorama by a given ratio.
124
+
125
+ Args:
126
+ pano: Input panorama array with shape (H, W, ...).
127
+ crop_ratio: Fraction to crop from top and bottom.
128
+
129
+ Returns:
130
+ Cropped panorama.
131
+ """
132
+ H, W = pano.shape[:2]
133
+ crop_H_top = int(crop_ratio * H)
134
+ crop_H_bottom = H - int(crop_ratio * H)
135
+ crop_pano = pano[crop_H_top:crop_H_bottom, ...]
136
+ return crop_pano
137
+
138
+
139
+ def rotate_panorama(panorama, theta):
140
+ """Horizontally rotate a panorama by shifting pixels.
141
+
142
+ Args:
143
+ panorama: Input panorama array with shape (H, W, ...).
144
+ theta: Rotation angle in radians.
145
+
146
+ Returns:
147
+ Shifted panorama.
148
+ """
149
+ H, W = panorama.shape[:2]
150
+ shift_pixels = int((theta * W) / (2 * np.pi))
151
+ shifted = np.roll(panorama, shift_pixels, axis=1)
152
+ return shifted
argus/utils/geometry.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def closed_form_inverse_se3(se3, R=None, T=None):
6
+ """
7
+ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
8
+
9
+ If `R` and `T` are provided, they must correspond to the rotation and translation
10
+ components of `se3`. Otherwise, they will be extracted from `se3`.
11
+
12
+ Args:
13
+ se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
14
+ R (optional): Nx3x3 array or tensor of rotation matrices.
15
+ T (optional): Nx3x1 array or tensor of translation vectors.
16
+
17
+ Returns:
18
+ Inverted SE3 matrices with the same type and device as `se3`.
19
+
20
+ Shapes:
21
+ se3: (N, 4, 4)
22
+ R: (N, 3, 3)
23
+ T: (N, 3, 1)
24
+ """
25
+ # Check if se3 is a numpy array or a torch tensor
26
+ is_numpy = isinstance(se3, np.ndarray)
27
+
28
+ # Validate shapes
29
+ if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
30
+ raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
31
+
32
+ # Extract R and T if not provided
33
+ if R is None:
34
+ R = se3[:, :3, :3] # (N,3,3)
35
+ if T is None:
36
+ T = se3[:, :3, 3:] # (N,3,1)
37
+
38
+ # Transpose R
39
+ if is_numpy:
40
+ # Compute the transpose of the rotation for NumPy
41
+ R_transposed = np.transpose(R, (0, 2, 1))
42
+ # -R^T t for NumPy
43
+ top_right = -np.matmul(R_transposed, T)
44
+ inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
45
+ else:
46
+ R_transposed = R.transpose(1, 2) # (N,3,3)
47
+ top_right = -torch.bmm(R_transposed, T) # (N,3,1)
48
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
49
+ inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
50
+
51
+ inverted_matrix[:, :3, :3] = R_transposed
52
+ inverted_matrix[:, :3, 3:] = top_right
53
+
54
+ return inverted_matrix
55
+
56
+
57
+ def pano_depth_to_points(depth_map, original_pano_shape=(560, 280), crop_ratio=0.15):
58
+ """
59
+ Convert batched cropped panoramic depth maps to 3D point clouds (PyTorch implementation).
60
+ Assumption: Input depth maps are already cropped by crop_ratio on top and bottom.
61
+
62
+ Args:
63
+ depth_map (torch.Tensor): Input cropped depth map, shape [B, S, H_crop, W, 1]
64
+ original_pano_shape (tuple): Original uncropped panorama size (W_ori, H_ori), default (560, 280)
65
+ crop_ratio (float): Crop ratio of original panorama (top and bottom respectively), default 0.15
66
+
67
+ Returns:
68
+ torch.Tensor: 3D point cloud with shape [B, S, H_crop, W, 3]
69
+ """
70
+ # Validate input shape
71
+ assert depth_map.dim() == 5 and depth_map.shape[-1] == 1, \
72
+ f"Input must be [B, S, H_crop, W, 1], got {depth_map.shape}"
73
+
74
+ B, S, H_crop, W, _ = depth_map.shape
75
+ W_ori, H_ori = original_pano_shape
76
+ device = depth_map.device # Align tensor device automatically
77
+
78
+ # Generate pixel grid coordinates (H_crop, W)
79
+ px_grid, py_grid = torch.meshgrid(
80
+ torch.arange(W, device=device),
81
+ torch.arange(H_crop, device=device),
82
+ indexing='xy' # Consistent with numpy's meshgrid
83
+ )
84
+
85
+ # Restore to original panorama y-coordinates (compensate for cropping)
86
+ crop_top = int(crop_ratio * H_ori)
87
+ py_ori = py_grid + crop_top
88
+
89
+ # Compute spherical coordinates (lat: latitude, long: longitude)
90
+ lat = (py_ori / H_ori - 0.5) * torch.pi
91
+ long = (px_grid / W_ori - 0.5) * 2 * torch.pi
92
+
93
+ # Remove channel dim and compute 3D Cartesian coordinates
94
+ dist = depth_map.squeeze(-1) # [B, S, H_crop, W]
95
+ y = dist * torch.sin(lat)
96
+ tmp = dist * torch.cos(lat)
97
+ x = tmp * torch.sin(long)
98
+ z = tmp * torch.cos(long)
99
+
100
+ # Concatenate to form 3D point cloud
101
+ point_cloud = torch.stack([x, y, z], dim=-1)
102
+
103
+ return point_cloud
104
+
105
+
106
+ def points_to_pano_depth(points):
107
+ """
108
+ Convert 3D point cloud back to ray panoramic depth map.
109
+ Ignore the error in direction.
110
+
111
+ Args:
112
+ points (torch.Tensor): Input 3D point cloud, shape [B, S, H, W, 3]
113
+
114
+ Returns:
115
+ torch.Tensor: panoramic depth map, shape [B, S, H, W, 1]
116
+ """
117
+ # Validate input shape and fill mode
118
+ assert points.dim() == 5 and points.shape[-1] == 3, \
119
+ f"Input point cloud must be [B, S, H, W, 3], got {points.shape}"
120
+
121
+ # Compute radial depth (dist = sqrt(x² + y² + z²))
122
+ dist = torch.norm(points, dim=-1, keepdim=True) # [B, S, H, W, 1]
123
+
124
+ return dist
125
+
126
+
127
+ def camera_points_to_rotated_points(cam_points, R):
128
+ """
129
+ Rotate batched panoramic camera point clouds with corresponding rotation matrices.
130
+
131
+ Args:
132
+ cam_points (torch.Tensor): Input camera 3D point cloud, shape [B, S, H, W, 3]
133
+ R (torch.Tensor): Corresponding rotation matrices, shape [B, S, 3, 3]
134
+
135
+ Returns:
136
+ torch.Tensor: Rotated 3D point cloud, shape [B, S, H, W, 3] (same as input cam_points)
137
+ """
138
+ # Validate input shapes and dimensions matching
139
+ assert cam_points.dim() == 5 and cam_points.shape[-1] == 3, \
140
+ f"Camera points must be [B, S, H, W, 3], got {cam_points.shape}"
141
+ assert R.dim() == 4 and R.shape[2:] == (3, 3), \
142
+ f"Rotation matrices R must be [B, S, 3, 3], got {R.shape}"
143
+ assert cam_points.shape[:2] == R.shape[:2], \
144
+ f"Batch/Sequence dim mismatch: cam_points {cam_points.shape[:2]} vs R {R.shape[:2]}"
145
+
146
+ # Expand dimensions for broadcasting (align spatial dimensions H, W)
147
+ cam_points_expanded = cam_points.unsqueeze(-1) # [B, S, H, W, 3, 1]
148
+ R_expanded = R.unsqueeze(2).unsqueeze(2) # [B, S, 1, 1, 3, 3]
149
+
150
+ # Batch matrix multiplication: R @ p (rotation operation)
151
+ rotated_points_expanded = torch.matmul(R_expanded, cam_points_expanded)
152
+
153
+ # Squeeze redundant dimension to recover original shape
154
+ rotated_points = rotated_points_expanded.squeeze(-1)
155
+
156
+ return rotated_points
157
+
158
+
159
+ def rotated_points_to_world_points(rotated_points, t):
160
+ """
161
+ Transform rotated camera points to world coordinates by adding translation vector.
162
+
163
+ Args:
164
+ rotated_points (torch.Tensor): Rotated 3D point cloud, shape [B, S, H, W, 3]
165
+ t (torch.Tensor): Translation vector, shape [B, S, 3] (per batch-sequence translation)
166
+
167
+ Returns:
168
+ torch.Tensor: World-coordinate 3D point cloud, shape [B, S, H, W, 3] (same as input)
169
+ """
170
+ # Validate input shapes and dimension matching
171
+ assert rotated_points.dim() == 5 and rotated_points.shape[-1] == 3, \
172
+ f"Rotated points must be [B, S, H, W, 3], got {rotated_points.shape}"
173
+ assert t.dim() == 3 and t.shape[-1] == 3, \
174
+ f"Translation t must be [B, S, 3], got {t.shape}"
175
+ assert rotated_points.shape[:2] == t.shape[:2], \
176
+ f"Batch/Sequence dim mismatch: rotated_points {rotated_points.shape[:2]} vs t {t.shape[:2]}"
177
+
178
+ # Expand translation dimensions for broadcasting with spatial dimensions (H, W)
179
+ # t: [B, S, 3] -> [B, S, 1, 1, 3] (broadcast to H and W)
180
+ t_expanded = t.unsqueeze(2).unsqueeze(2)
181
+
182
+ # Add translation (broadcasting automatically applies t to all H×W points per B-S pair)
183
+ world_points = rotated_points + t_expanded
184
+
185
+ return world_points
186
+
187
+
188
+
189
+ def unproject_depth_to_world_points(depth, extrinsic, size=560):
190
+ '''
191
+ Args:
192
+ depth: [S, H, W, 1]
193
+ extrinsic: [S, 4, 4]
194
+ Returns:
195
+ world_points: [S, H, W, 3]
196
+ '''
197
+ camera_points = pano_depth_to_points(depth, original_pano_shape=(size, size//2))
198
+ rotated_points = camera_points_to_rotated_points(camera_points, extrinsic[:, :, :3, :3])
199
+ world_points = rotated_points_to_world_points(rotated_points, extrinsic[:, :, :3, 3])
200
+
201
+ return world_points
argus/utils/normalization.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Optional, Tuple
3
+ from argus.utils.geometry import closed_form_inverse_se3
4
+
5
+ def cal_scale_by_points(points: torch.Tensor, point_masks: torch.Tensor) -> torch.Tensor:
6
+ # Calculate average distance of valid 3D points (batch-wise)
7
+ dist = points.norm(dim=-1)
8
+ dist_sum = (dist * point_masks).sum(dim=[1, 2, 3]) # Shape: [B,]
9
+ valid_count = point_masks.sum(dim=[1, 2, 3])
10
+ avg_scale = (dist_sum / (valid_count + 1e-3)).clamp(min=1e-6, max=1e6)
11
+ return avg_scale
12
+
13
+ def normalize_camera_extrinsics_and_points_batch(
14
+ extrinsics: torch.Tensor,
15
+ cam_points: torch.Tensor,
16
+ depths: torch.Tensor,
17
+ point_masks: torch.Tensor,
18
+ scale_mode: str = "none",
19
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
20
+ # Basic input validation
21
+ assert extrinsics.ndim == 4 and extrinsics.shape[2:] == (4, 4), \
22
+ f"Extrinsics must be (B, S, 4, 4), got {extrinsics.shape}"
23
+ B, S = extrinsics.shape[:2]
24
+ device = extrinsics.device
25
+
26
+ # Step 1: Transform all extrinsics to reference frame (1st frame of each batch)
27
+ ref_extrinsics = extrinsics[:,0,:,:] # (B, 4, 4)
28
+ ref_extr_inv = closed_form_inverse_se3(ref_extrinsics)
29
+ new_extrinsics = torch.matmul(ref_extr_inv.unsqueeze(1), extrinsics) # (B, S, 4, 4) world coordinate
30
+
31
+ # Step 2: Clone tensors to avoid in-place modification
32
+ new_depths = depths.clone()
33
+ new_cam_points = cam_points.clone()
34
+
35
+ # Step 3: Compute rotated/world points from new extrinsics
36
+ R_new = new_extrinsics[:, :, :3, :3] # (B, S, 3, 3)
37
+ t_new = new_extrinsics[:, :, :3, 3] # (B, S, 3)
38
+ new_rotated_points = torch.matmul(R_new.unsqueeze(2).unsqueeze(3), new_cam_points.unsqueeze(-1)).squeeze(-1) # (B,S,1,1,3,3) × (B,S,H,W,3,1) -> (B,S,H,W,3)
39
+ new_world_points = new_rotated_points + t_new.unsqueeze(2).unsqueeze(3)
40
+
41
+ # Step 4: Apply scene scaling
42
+ if scale_mode == "avg_dist":
43
+ avg_scale = cal_scale_by_points(new_world_points, point_masks) # (B,)
44
+ # Reshape scale for broadcasting with different tensor shapes
45
+ scale_3d = avg_scale.view(-1, 1, 1) # For extrinsics (B, S, 4, 4)
46
+ scale_4d = avg_scale.view(-1, 1, 1, 1) # For depths (B, S, H, W)
47
+ scale_5d = avg_scale.view(-1, 1, 1, 1, 1) # For 3D points (B, S, H, W, 3)
48
+ new_extrinsics[:, :, :3, 3] /= scale_3d
49
+ new_depths /= scale_4d
50
+ new_cam_points /= scale_5d
51
+ new_rotated_points /= scale_5d
52
+ new_world_points /= scale_5d
53
+ elif scale_mode == "abs":
54
+ metric_scale = 10.0
55
+ new_extrinsics[:, :, :3, 3] /= metric_scale
56
+ new_depths /= metric_scale
57
+ new_cam_points /= metric_scale
58
+ new_rotated_points /= metric_scale
59
+ new_world_points /= metric_scale
60
+ elif scale_mode == "none":
61
+ pass
62
+ else:
63
+ raise ValueError(f"Unknown scale_mode: {scale_mode}")
64
+
65
+ return new_extrinsics, new_cam_points, new_rotated_points, new_world_points, new_depths
argus/utils/pose_enc.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Tuple, Union
3
+ from .rotation import quat_to_mat, mat_to_quat
4
+
5
+
6
+ def extri_to_pose_encoding360(
7
+ extrinsics: torch.Tensor,
8
+ pose_encoding_type: Union[str, "absT_quaR"] = "absT_quaR"
9
+ ) -> torch.Tensor:
10
+ """
11
+ Convert camera extrinsic parameters to a compact pose encoding (absolute translation + quaternion rotation).
12
+
13
+ Transforms OpenCV-style camera extrinsics (3x4 [R|t] matrix) into a flattened encoding format
14
+ suitable for machine learning tasks like pose prediction or representation learning.
15
+
16
+ Args:
17
+ extrinsics: Camera extrinsic matrices with shape [B, S, 3, 4] or [B, S, 4, 4]
18
+ - B: Batch size
19
+ - S: Sequence length (number of frames)
20
+ - 3x4/4x4: Extrinsic matrix in OpenCV coordinate system (x-right, y-down, z-forward)
21
+ representing the transformation from world to camera space ([R|t] where R=3x3 rotation, t=3x1 translation)
22
+ pose_encoding_type: Type of pose encoding format (only "absT_quaR" supported):
23
+ - "absT_quaR": Absolute translation (3D) + quaternion rotation (4D)
24
+
25
+ Returns:
26
+ Encoded pose tensor with shape [B, S, 7]
27
+ - [:3]: Absolute translation vector (T) in world coordinates
28
+ - [3:7]: Rotation represented as unit quaternion (quat)
29
+ """
30
+ # Extract rotation matrix (R) and translation vector (T) from extrinsics
31
+ # Handle both 3x4 and 4x4 extrinsic matrix inputs
32
+ R = extrinsics[:, :, :3, :3] # [B, S, 3, 3] - rotation matrix
33
+ T = extrinsics[:, :, :3, 3] # [B, S, 3] - translation vector
34
+
35
+ if pose_encoding_type == "absT_quaR":
36
+ # Convert rotation matrix to quaternion (4D)
37
+ quat = mat_to_quat(R)
38
+
39
+ # Concatenate translation and quaternion to form compact pose encoding
40
+ pose_encoding = torch.cat([T, quat], dim=-1).float()
41
+ else:
42
+ raise NotImplementedError(f"Pose encoding type '{pose_encoding_type}' not supported. Only 'absT_quaR' is implemented.")
43
+
44
+ return pose_encoding
45
+
46
+
47
+ def pose_encoding_to_extri360(
48
+ pose_encoding: torch.Tensor,
49
+ pose_encoding_type: Union[str, "absT_quaR"] = "absT_quaR"
50
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
51
+ """
52
+ Convert compact pose encoding back to full camera extrinsic parameters (inverse of extri_to_pose_encoding360).
53
+
54
+ Reconstructs the 4x4 homogeneous extrinsic matrix from the flattened pose encoding,
55
+ including extraction of confidence scores from the encoding's extra dimensions.
56
+
57
+ Args:
58
+ pose_encoding: Encoded pose tensor with shape [B, S, 9]
59
+ - B: Batch size
60
+ - S: Sequence length (number of frames)
61
+ - [:3]: Absolute translation vector (T)
62
+ - [3:7]: Rotation quaternion (quat)
63
+ - [-2:]: Confidence scores for translation and rotation
64
+ pose_encoding_type: Type of pose encoding format (only "absT_quaR" supported):
65
+ - "absT_quaR": Absolute translation (3D) + quaternion rotation (4D)
66
+
67
+ Returns:
68
+ Tuple containing:
69
+ 1. extrinsics: Reconstructed camera extrinsic matrices with shape [B, S, 4, 4]
70
+ (homogeneous matrix in OpenCV coordinate system: [R|t; 0 0 0 1])
71
+ 2. conf: Confidence scores with shape [B, S, 2]
72
+ - [:, :, 0]: Translation confidence
73
+ - [:, :, 1]: Rotation confidence
74
+
75
+ Raises:
76
+ NotImplementedError: If unsupported pose encoding type is provided
77
+ """
78
+ if pose_encoding_type == "absT_quaR":
79
+ # Extract translation (T) and rotation quaternion (quat) from pose encoding
80
+ T = pose_encoding[..., :3] # [B, S, 3] - translation vector
81
+ quat = pose_encoding[..., 3:7] # [B, S, 4] - rotation quaternion
82
+
83
+ # Convert quaternion back to rotation matrix (3x3)
84
+ R = quat_to_mat(quat) # [B, S, 3, 3]
85
+
86
+ # Reconstruct 3x4 [R|t] matrix (rotation + translation)
87
+ extri_3x4 = torch.cat([R, T[..., None]], dim=-1) # [B, S, 3, 4]
88
+
89
+ # Add homogeneous row [0, 0, 0, 1] to form 4x4 extrinsic matrix
90
+ batch_size, seq_len = extri_3x4.shape[:2]
91
+ homogenous_row = torch.tensor(
92
+ [0, 0, 0, 1],
93
+ device=extri_3x4.device,
94
+ dtype=extri_3x4.dtype
95
+ ).expand(batch_size, seq_len, 1, 4) # [B, S, 1, 4]
96
+
97
+ # Combine to form 4x4 homogeneous extrinsic matrix
98
+ extrinsics = torch.cat((extri_3x4, homogenous_row), dim=2) # [B, S, 4, 4]
99
+
100
+ # Extract confidence scores (last two dimensions of pose encoding)
101
+ conf = pose_encoding[..., -2:] # [B, S, 2]
102
+
103
+ return extrinsics, conf
104
+
105
+ raise NotImplementedError(f"Pose encoding type '{pose_encoding_type}' not supported. Only 'absT_quaR' is implemented.")
argus/utils/rotation.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
7
+ """
8
+ Quaternion Order: XYZW or say ijkr, scalar-last
9
+
10
+ Convert rotations given as quaternions to rotation matrices.
11
+ Args:
12
+ quaternions: quaternions with real part last,
13
+ as tensor of shape (..., 4).
14
+
15
+ Returns:
16
+ Rotation matrices as tensor of shape (..., 3, 3).
17
+ """
18
+ # Normalize quaternions to unit length
19
+ quaternions = F.normalize(quaternions, dim=-1)
20
+
21
+ i, j, k, r = torch.unbind(quaternions, -1)
22
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
23
+
24
+ o = torch.stack(
25
+ (
26
+ 1 - two_s * (j * j + k * k),
27
+ two_s * (i * j - k * r),
28
+ two_s * (i * k + j * r),
29
+ two_s * (i * j + k * r),
30
+ 1 - two_s * (i * i + k * k),
31
+ two_s * (j * k - i * r),
32
+ two_s * (i * k - j * r),
33
+ two_s * (j * k + i * r),
34
+ 1 - two_s * (i * i + j * j),
35
+ ),
36
+ -1,
37
+ )
38
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
39
+
40
+
41
+ def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
42
+ """
43
+ Convert rotations given as rotation matrices to quaternions.
44
+
45
+ Args:
46
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
47
+
48
+ Returns:
49
+ quaternions with real part last, as tensor of shape (..., 4).
50
+ Quaternion Order: XYZW or say ijkr, scalar-last
51
+ """
52
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
53
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
54
+
55
+ batch_dim = matrix.shape[:-2]
56
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
57
+
58
+ q_abs = _sqrt_positive_part(
59
+ torch.stack(
60
+ [1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1
61
+ )
62
+ )
63
+
64
+ # we produce the desired quaternion multiplied by each of r, i, j, k
65
+ quat_by_rijk = torch.stack(
66
+ [
67
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
68
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
69
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
70
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
71
+ ],
72
+ dim=-2,
73
+ )
74
+
75
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
76
+ # the candidate won't be picked.
77
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
78
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
79
+
80
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
81
+ # forall i; we pick the best-conditioned one (with the largest denominator)
82
+ out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
83
+
84
+ # Convert from rijk to ijkr
85
+ out = out[..., [1, 2, 3, 0]]
86
+
87
+ out = standardize_quaternion(out)
88
+
89
+ return out
90
+
91
+
92
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
93
+ """
94
+ Returns torch.sqrt(torch.max(0, x))
95
+ but with a zero subgradient where x is 0.
96
+ """
97
+ ret = torch.zeros_like(x)
98
+ positive_mask = x > 0
99
+ if torch.is_grad_enabled():
100
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
101
+ else:
102
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
103
+ return ret
104
+
105
+
106
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
107
+ """
108
+ Convert a unit quaternion to a standard form: one in which the real
109
+ part is non negative.
110
+
111
+ Args:
112
+ quaternions: Quaternions with real part last,
113
+ as tensor of shape (..., 4).
114
+
115
+ Returns:
116
+ Standardized quaternions as tensor of shape (..., 4).
117
+ """
118
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
assets/argus_logo.png ADDED

Git LFS Details

  • SHA256: 84ce97bea0cd15bca4e4655fa32b2e0a37b82c904963e3a79afd1f521ab89173
  • Pointer size: 131 Bytes
  • Size of remote file: 108 kB
examples/far_4/0.jpg ADDED

Git LFS Details

  • SHA256: 20ef38597f7839149d86ed294f7c878db26e9c895a36f40c48777e7d6c6d367f
  • Pointer size: 131 Bytes
  • Size of remote file: 102 kB
examples/far_4/1.jpg ADDED

Git LFS Details

  • SHA256: 2283054165720ace4163571dd8360d1f7df9dd94ca58e9b6f5c3d81f1c3e4220
  • Pointer size: 130 Bytes
  • Size of remote file: 90.7 kB
examples/far_4/2.jpg ADDED

Git LFS Details

  • SHA256: 8408263518733a6ee177771601977b8b99ebbb802155abf093bbfe3b033a2f69
  • Pointer size: 130 Bytes
  • Size of remote file: 89.9 kB
examples/far_4/3.jpg ADDED

Git LFS Details

  • SHA256: 84694d2d5d75f82f5e9ecdc523d932e77179989990fc49f926021d26aea2877e
  • Pointer size: 130 Bytes
  • Size of remote file: 87.2 kB
examples/scene_00008/1757748389.jpg ADDED

Git LFS Details

  • SHA256: 30ff019ea8b37e6d7337b9b524bc91194ab8ff4e0ac9cffb33d985674b160740
  • Pointer size: 130 Bytes
  • Size of remote file: 76.8 kB
examples/scene_00008/1757748429.jpg ADDED

Git LFS Details

  • SHA256: bd4f6ee92bd6f883078e7002b8d72e21100134ffa4bed9feb1b039cdd0002c19
  • Pointer size: 130 Bytes
  • Size of remote file: 74.6 kB
examples/scene_00008/1757748477.jpg ADDED

Git LFS Details

  • SHA256: f4912e213784d15c4415216b5a13288722e5ef509a349084f170fa3920dede8b
  • Pointer size: 130 Bytes
  • Size of remote file: 84.4 kB
examples/scene_00008/1757748528.jpg ADDED

Git LFS Details

  • SHA256: 47c0fecc84e80012faee6e9ec6d1ce8b7841934f8ddcdcbb8d1bed60b2d0fd7a
  • Pointer size: 130 Bytes
  • Size of remote file: 96.3 kB
examples/scene_00008/1757748562.jpg ADDED

Git LFS Details

  • SHA256: 1c04a1704951d53733755b4c3444694325cd8576c87501bbd7aa679a8a61d5fc
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB
examples/scene_00008/1757748600.jpg ADDED

Git LFS Details

  • SHA256: 3db952b9929da9f0f608d4509de43c3aa10dd7e5882c8f64a091c4d553b8be5f
  • Pointer size: 130 Bytes
  • Size of remote file: 85.7 kB
examples/scene_00008/1757748638.jpg ADDED

Git LFS Details

  • SHA256: 3ae1ddc73cc7cc85534b8ef0156baab344ee254e9bc7f02955ffc5c314a5e09c
  • Pointer size: 130 Bytes
  • Size of remote file: 75.3 kB
examples/scene_00008/1757748685.jpg ADDED

Git LFS Details

  • SHA256: 2d879dd182f6b747cbff39bf1104dcd194880212c356f421db3fca2102621571
  • Pointer size: 131 Bytes
  • Size of remote file: 120 kB
examples/scene_00008/1757748728.jpg ADDED

Git LFS Details

  • SHA256: e9ec8b77c76757f695aa7f6bfaf5a1dbecd328c4b231a1daa0fbad77dc596115
  • Pointer size: 130 Bytes
  • Size of remote file: 75.9 kB
examples/scene_00008/1757748770.jpg ADDED

Git LFS Details

  • SHA256: 83535100f57b9ea09ee6b0e47da2ed827957dd2027c531727940a136cbb34719
  • Pointer size: 130 Bytes
  • Size of remote file: 91.9 kB
examples/scene_00008/1757748817.jpg ADDED

Git LFS Details

  • SHA256: 6d5bd095592ee99cbe7f2646a60d86f65a1e57767b479f81981d853f816d0939
  • Pointer size: 130 Bytes
  • Size of remote file: 87.5 kB
examples/scene_00008/1757748866.jpg ADDED

Git LFS Details

  • SHA256: 238d31b343b59023773ffd19a4f64e3663ceab052ee800c2b78548bd3484b9c4
  • Pointer size: 130 Bytes
  • Size of remote file: 76.3 kB
examples/scene_00008/1757748907.jpg ADDED

Git LFS Details

  • SHA256: cd3e611f2473a49d071e51d75302c419ef63c48d15339adb2988f7ebd581f11d
  • Pointer size: 130 Bytes
  • Size of remote file: 80.4 kB
examples/scene_00008/1757748959.jpg ADDED

Git LFS Details

  • SHA256: b157fccda7e9b51e908ce990976b38374b2baa01de4d5437fecbc49f1751b067
  • Pointer size: 130 Bytes
  • Size of remote file: 95.2 kB
examples/scene_00008/1757749004.jpg ADDED

Git LFS Details

  • SHA256: 63f3042464745c37e7eb327308c93174714dfd99f40d644e716b04ee4dac5558
  • Pointer size: 130 Bytes
  • Size of remote file: 90.8 kB
examples/scene_00008/1757749043.jpg ADDED

Git LFS Details

  • SHA256: 3b2dbc11fc28999b640ecce24fa80e18fcc1758a02bf7240358eec2adcbb9096
  • Pointer size: 130 Bytes
  • Size of remote file: 63.8 kB
examples/scene_00008/1757749091.jpg ADDED

Git LFS Details

  • SHA256: 884e574ce5de6f976d17028e5b23798003d746312f504882555bc007ff818bf0
  • Pointer size: 130 Bytes
  • Size of remote file: 75.3 kB
examples/scene_00008/1757749140.jpg ADDED

Git LFS Details

  • SHA256: 198106f9757833a60a0fdb7d2894412c82749edfdcfa8c14a68d3817846d427c
  • Pointer size: 131 Bytes
  • Size of remote file: 102 kB