chrischoy committed · Commit a8e8155 · verified · 1 Parent(s): 484c04d

Merge SpaceFormer demo (viser + CLI + Gradio) under demo/

demo/README.md ADDED
@@ -0,0 +1,112 @@
1
+ ---
2
+ title: SpaceFormer Open-Vocab 3D Instance Segmentation
3
+ emoji: 🧩
4
+ colorFrom: indigo
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 4.44.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ tags:
12
+ - 3d
13
+ - point-cloud
14
+ - instance-segmentation
15
+ - open-vocabulary
16
+ ---
17
+
18
+ # SpaceFormer — Open-Vocabulary 3D Instance Segmentation (demo)
19
+
20
+ Proposal-free **open-vocabulary 3D instance segmentation**. A Mask2Former-style query
21
+ decoder (learned queries + RoPE) on top of the WarpConvNet `SpaCeFormer` backbone: one
22
+ forward pass over an RGB point cloud produces query masks + per-query CLIP features,
23
+ which are labeled against text embeddings of **arbitrary** class names (SigLIP2, with
24
+ prompt ensembling) — the vocabulary is chosen at inference time.
25
+
26
+ Released checkpoint:
27
+
28
+ | Benchmark | mAP |
29
+ |---|---|
30
+ | ScanNet200 | 0.1265 |
31
+ | ScanNet++ | 0.2217 |
32
+ | Replica | 0.2644 |
33
+
34
+ This repo is the **demo / inference layer**. The model itself lives in WarpConvNet
35
+ (`warpconvnet.models.spaceformer`); this repo only adds the Gradio UI (`app.py`) and a
36
+ CLI inference entry point (`inference.py`).
37
+
38
+ ## Requirements
39
+
40
+ ```bash
41
+ pip install -r requirements.txt
42
+ ```
43
+
44
+ > **WarpConvNet must be installed with its compiled extension** (a pre-built wheel, or
45
+ > build from source). It is intentionally not pinned in `requirements.txt` because it is
46
+ > environment-specific. `transformers` pulls the SigLIP2 text encoder
47
+ > (`google/siglip2-so400m-patch14-224`) on first use.
48
+
49
+ ## Live demo (Gradio / HuggingFace Space)
50
+
51
+ ```bash
52
+ HF_REPO_ID=chrischoy/SpaCeFormer python app.py
53
+ # or a local checkpoint:
54
+ SPACEFORMER_CKPT=/path/to/spaceformer_512_siglip2_ssccc.ckpt python app.py
55
+ ```
56
+
57
+ Upload a point cloud, type comma-separated class names, get an interactive 3D view
58
+ colored by predicted instance + a ranked table. As a **HuggingFace Space**: create a
59
+ **GPU** Gradio Space, install WarpConvNet + `requirements.txt` in the image, and set the
60
+ Space variables `HF_REPO_ID` (and optional `HF_FILENAME`, default
61
+ `spaceformer_512_siglip2_ssccc.ckpt`).
62
+
63
+ ## Local demo (viser)
64
+
65
+ An interactive, self-contained local demo that takes **text class names**, runs
66
+ segmentation, and visualizes the result in the browser with
67
+ [viser](https://viser.studio) — each predicted instance gets a distinct color,
68
+ unassigned points stay grey, and a GUI panel lists the top instances.
69
+
70
+ ```bash
71
+ # auto-download the checkpoint + use a bundled sample point cloud
72
+ python demo_viser.py --port 8080
73
+
74
+ # your own cloud + vocabulary, local checkpoint
75
+ python demo_viser.py --ckpt /path/to/spaceformer_512_siglip2_ssccc.ckpt \
76
+ --ply my_scene.ply --class-names chair table monitor wall floor
77
+
78
+ # full ScanNet200 label set
79
+ python demo_viser.py --ply my_scene.ply --use-scannet200
80
+ ```
81
+
82
+ Then open the printed URL (default `http://localhost:8080`) in a browser.
83
+ With no `--ply`, the demo uses an open3d bundled sample cloud (or a synthesized
84
+ random RGB cloud) — a generic cloud won't segment meaningfully; it only
85
+ demonstrates that the pipeline + viewer run end to end. The demo colors the
86
+ model's **output** points (`out["backbone_pc"].coordinates`), which are what the
87
+ predicted masks index into after the model's internal voxelization — not the raw
88
+ `.ply` points, whose count may differ.
89
+
90
+ ## CLI inference
91
+
92
+ ```bash
93
+ # local checkpoint
94
+ python inference.py --ckpt /path/to/spaceformer_512_siglip2_ssccc.ckpt \
95
+ --scene /path/to/scene_dir # dir with coord.npy + color.npy
96
+
97
+ # or auto-download from a HuggingFace model repo
98
+ HF_REPO_ID=chrischoy/SpaCeFormer python inference.py \
99
+ --scene my_scene.ply --class-names "office chair" "desk" "monitor" "other"
100
+
101
+ # full ScanNet200 label set
102
+ python inference.py --ckpt <ckpt> --scene <scene> --use-scannet200
103
+ ```
104
+
105
+ `--scene` accepts a directory with `coord.npy` (`[N,3]` float meters) + `color.npy` (`[N,3]`
106
+ 0–255), a `.npz` `{coord,color}`, an `[N,6]` `.npy` (xyz,rgb), or a `.ply`. Coordinates
107
+ stay in **meters** — the model voxelizes internally at 2 cm. Output: a ranked list of
108
+ `{label, score, #points}`; `score = objectness · mask_quality · class_prob`.
109
+
110
+ ## License
111
+
112
+ Apache-2.0, matching the WarpConvNet `space_former.py` SPDX header.
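The README's CLI section gives the ranking score as `score = objectness · mask_quality · class_prob`. The sketch below works that product through for a single query. It rests on assumptions: objectness is taken to be the softmax foreground probability of a two-way `{fg, bg}` logit with foreground at index 0, and `class_prob` is the largest softmax probability over the class logits (as `demo_viser.py` does). `mask_quality` is passed in as a plain number, because its definition lives in `postprocessing.py`, which this commit view does not show. The helper name `instance_score` is hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def instance_score(binary_logit, mask_quality, class_logits):
    """Return (score, class_id) for one query. Hypothetical helper."""
    # Assumption: index 0 of the two-way logit is the foreground class.
    objectness = softmax(binary_logit)[0]
    class_probs = softmax(class_logits)
    # Pick the most probable class and use its probability in the product.
    class_id = max(range(len(class_probs)), key=class_probs.__getitem__)
    return objectness * mask_quality * class_probs[class_id], class_id


score, class_id = instance_score([2.0, -2.0], 0.8, [0.1, 3.0, 0.2])
print(class_id, round(score, 3))
```

Because every factor lies in `[0, 1]`, the product can only shrink: a confident mask with an ambiguous label ranks below an equally confident mask with a clear label.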
demo/app.py ADDED
@@ -0,0 +1,183 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """HuggingFace Space: SpaceFormer open-vocabulary 3D instance segmentation release.
4
+
5
+ Presentation/deployment layer. All model + inference logic is imported from the
6
+ installed ``warpconvnet`` library (``warpconvnet.models.spaceformer``); this file
7
+ only adds the Gradio UI, the 3D Plotly viewer, and checkpoint download.
8
+
9
+ Upload an RGB point cloud (.ply / [N,6] .npy / .npz), type comma-separated class
10
+ names, and get an interactive 3D view colored by predicted instance + a ranked
11
+ table of {label, score, #points}.
12
+
13
+ WarpConvNet (with its compiled extension) and transformers must be installed in
14
+ the Space image. Configure the checkpoint via Space variables:
15
+
16
+ HF_REPO_ID model repo holding the checkpoint (e.g. chrischoy/SpaCeFormer)
17
+ HF_FILENAME checkpoint filename (default: spaceformer_512_siglip2_ssccc.ckpt)
18
+ SPACEFORMER_CKPT explicit local checkpoint path (overrides the HF download)
19
+ """
20
+
21
+ import os
22
+
23
+ import numpy as np
24
+ import torch
25
+
26
+ from warpconvnet.models.spaceformer import (
27
+ build_spaceformer,
28
+ load_spaceformer_checkpoint,
29
+ )
30
+ from labels import DEFAULT_CLASS_NAMES, PROMPT_TEMPLATES
31
+ from pipeline import (
32
+ SIGLIP_MODEL_ID,
33
+ load_scene,
34
+ make_batch,
35
+ predict_instances,
36
+ )
37
+
38
+ HF_REPO_ID = os.environ.get("HF_REPO_ID", "")
39
+ HF_FILENAME = os.environ.get("HF_FILENAME", "spaceformer_512_siglip2_ssccc.ckpt")
40
+
41
+ _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
+ _STATE = {"net": None, "clip": None} # lazy singletons, kept resident across requests
43
+
44
+
45
+ def _resolve_ckpt() -> str:
46
+ explicit = os.environ.get("SPACEFORMER_CKPT")
47
+ if explicit:
48
+ return explicit
49
+ if not HF_REPO_ID:
50
+ raise RuntimeError(
51
+ "Set SPACEFORMER_CKPT to a local checkpoint, or HF_REPO_ID to a "
52
+ "HuggingFace model repo to auto-download from."
53
+ )
54
+ from huggingface_hub import hf_hub_download
55
+ return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME)
56
+
57
+
58
+ def _get_model():
59
+ if _STATE["net"] is None:
60
+ net = build_spaceformer(device=_DEVICE)
61
+ load_spaceformer_checkpoint(net, _resolve_ckpt())
62
+ _STATE["net"] = net
63
+ return _STATE["net"]
64
+
65
+
66
+ def _get_clip_encoder():
67
+ if _STATE["clip"] is None:
68
+ from text_encoder import get_text_encoder
69
+ _STATE["clip"] = get_text_encoder(
70
+ model_type="siglip2", model_id=SIGLIP_MODEL_ID, device=str(_DEVICE)
71
+ )
72
+ return _STATE["clip"]
73
+
74
+
75
+ def _text_eval(class_names):
76
+ from clip_eval import CLIPAlignmentEval
77
+ evaluator = CLIPAlignmentEval(normalize_input=False)
78
+ evaluator.prepare_target_embedding(
79
+ class_names=list(class_names),
80
+ clip_encoder=_get_clip_encoder(),
81
+ device=_DEVICE,
82
+ prompt_templates=list(PROMPT_TEMPLATES),
83
+ )
84
+ return evaluator
85
+
86
+
87
+ def _palette(n):
88
+ rng = np.random.default_rng(0)
89
+ return [tuple(int(x) for x in rng.integers(40, 230, size=3)) for _ in range(max(n, 1))]
90
+
91
+
92
+ def _plot(coord_np, results, top_k, score_thresh):
93
+ """Plotly 3D scatter colored by instance (top_k by score)."""
94
+ import plotly.graph_objects as go
95
+
96
+ kept = [r for r in results if r["score"] >= score_thresh][:top_k]
97
+ rgb = np.full((coord_np.shape[0], 3), 160, dtype=np.uint8) # grey background
98
+ palette = _palette(len(kept))
99
+ for i, r in enumerate(kept):
100
+ rgb[r["mask"]] = palette[i]
101
+ colors = [f"rgb({c[0]},{c[1]},{c[2]})" for c in rgb]
102
+
103
+ # Subsample for browser responsiveness.
104
+ n = coord_np.shape[0]
105
+ if n > 120_000:
106
+ idx = np.random.default_rng(0).choice(n, 120_000, replace=False)
107
+ else:
108
+ idx = np.arange(n)
109
+
110
+ fig = go.Figure(
111
+ data=[go.Scatter3d(
112
+ x=coord_np[idx, 0], y=coord_np[idx, 1], z=coord_np[idx, 2],
113
+ mode="markers",
114
+ marker=dict(size=1.5, color=[colors[i] for i in idx]),
115
+ )]
116
+ )
117
+ fig.update_layout(
118
+ scene=dict(aspectmode="data"),
119
+ margin=dict(l=0, r=0, t=0, b=0),
120
+ showlegend=False,
121
+ )
122
+ return fig
123
+
124
+
125
+ def segment(scene_file, class_text, top_k, score_thresh):
126
+ """Gradio callback: file + class names -> (3D figure, results table)."""
127
+ if scene_file is None:
128
+ return None, [["(upload a point cloud first)", "", ""]]
129
+
130
+ class_names = [c.strip() for c in class_text.split(",") if c.strip()] \
131
+ or list(DEFAULT_CLASS_NAMES)
132
+
133
+ path = scene_file.name if hasattr(scene_file, "name") else scene_file
134
+ coord_np, color_np = load_scene(path)
135
+ batch = make_batch(coord_np, color_np, _DEVICE)
136
+
137
+ net = _get_model()
138
+ results = predict_instances(net, batch, _text_eval(class_names), class_names)
139
+
140
+ fig = _plot(coord_np, results, int(top_k), float(score_thresh))
141
+ table = [
142
+ [r["label"], f"{r['score']:.3f}", int(r["mask"].sum())]
143
+ for r in [x for x in results if x["score"] >= float(score_thresh)][: int(top_k)]
144
+ ]
145
+ if not table:
146
+ table = [["(no instances above threshold)", "", ""]]
147
+ return fig, table
148
+
149
+
150
+ def build_interface():
151
+ import gradio as gr
152
+
153
+ with gr.Blocks(title="SpaceFormer — Open-Vocab 3D Instance Segmentation") as demo:
154
+ gr.Markdown(
155
+ "# SpaceFormer\n"
156
+ "Proposal-free **open-vocabulary 3D instance segmentation**. Upload an "
157
+ "RGB point cloud, type any class names, and get instance masks labeled "
158
+ "against your vocabulary (SigLIP2 text + prompt ensembling).\n\n"
159
+ "Released checkpoint mAP: ScanNet200 **0.1265** / ScanNet++ 0.2217 / Replica 0.2644."
160
+ )
161
+ with gr.Row():
162
+ with gr.Column(scale=1):
163
+ scene_file = gr.File(label="Point cloud (.ply / .npy[N,6] / .npz)")
164
+ class_text = gr.Textbox(
165
+ label="Class names (comma-separated)",
166
+ value=", ".join(DEFAULT_CLASS_NAMES),
167
+ )
168
+ top_k = gr.Slider(1, 100, value=30, step=1, label="Max instances shown")
169
+ score_thresh = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="Score threshold")
170
+ run = gr.Button("Segment", variant="primary")
171
+ with gr.Column(scale=2):
172
+ plot = gr.Plot(label="Predicted instances (colored)")
173
+ table = gr.Dataframe(
174
+ headers=["label", "score", "#points"],
175
+ label="Instances",
176
+ wrap=True,
177
+ )
178
+ run.click(segment, [scene_file, class_text, top_k, score_thresh], [plot, table])
179
+ return demo
180
+
181
+
182
+ if __name__ == "__main__":
183
+ build_interface().launch(server_name="0.0.0.0")
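The checkpoint lookup in `app.py` follows a fixed precedence: `SPACEFORMER_CKPT` wins, otherwise `HF_REPO_ID` must be set, and `HF_FILENAME` falls back to the released filename. The sketch below restates that order as a pure function over a mapping, so it can be checked without touching the network. The name `plan_checkpoint` and its tuple return values are hypothetical and are not part of the repo; the real code calls `hf_hub_download` in the second case.

```python
DEFAULT_FILENAME = "spaceformer_512_siglip2_ssccc.ckpt"


def plan_checkpoint(env):
    """Decide where the checkpoint comes from, without downloading anything.

    `env` is any mapping of environment-variable names to values
    (for example `os.environ` or a plain dict in a test).
    """
    explicit = env.get("SPACEFORMER_CKPT")
    if explicit:
        # An explicit local path overrides any Hub configuration.
        return ("local", explicit)
    repo_id = env.get("HF_REPO_ID", "")
    if not repo_id:
        raise RuntimeError("Set SPACEFORMER_CKPT or HF_REPO_ID.")
    return ("hub", repo_id, env.get("HF_FILENAME", DEFAULT_FILENAME))


print(plan_checkpoint({"HF_REPO_ID": "chrischoy/SpaCeFormer"}))
```

Keeping the decision separate from the download makes the precedence easy to unit-test with plain dictionaries.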
demo/clip_eval.py ADDED
@@ -0,0 +1,91 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Open-vocabulary CLIP alignment for inference.
4
+
5
+ Self-contained extract of the training repo's ``CLIPAlignmentEval``: encode class
6
+ names into a text-embedding matrix (optionally with prompt ensembling) and score
7
+ per-query CLIP features against it via cosine similarity.
8
+ """
9
+
10
+ import logging
11
+ from typing import List, Optional
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+ log = logging.getLogger(__name__)
18
+
19
+
20
+ class CLIPAlignmentEval(nn.Module):
21
+ """Cosine-similarity classifier between per-query CLIP features and text embeddings.
22
+
23
+ Args:
24
+ normalize_input: L2-normalize the query features before the cosine product.
25
+ For SpaceFormer set this to ``False`` — the clip head output is already
26
+ compared directly (matches the official eval recipe).
27
+ """
28
+
29
+ def __init__(self, normalize_input: bool = False):
30
+ super().__init__()
31
+ self.normalize_input = normalize_input
32
+ self.emb_target: Optional[torch.Tensor] = None # [C, D] L2-normalized
33
+
34
+ def set_target_embedding(self, text_embeddings: torch.Tensor) -> None:
35
+ self.emb_target = text_embeddings.float()
36
+
37
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
38
+ if self.normalize_input:
39
+ return F.normalize(x, p=2, dim=1)
40
+ return x
41
+
42
+ def predict(self, x: torch.Tensor, return_logit: bool = False) -> torch.Tensor:
43
+ """Score features ``x`` [Q, D] against the text embeddings -> [Q, C]."""
44
+ assert self.emb_target is not None, "call prepare_target_embedding() first"
45
+ pred = self.forward(x)
46
+ logit = torch.matmul(pred, self.emb_target.t().to(pred.dtype))
47
+ if return_logit:
48
+ return logit
49
+ return logit.argmax(dim=1)
50
+
51
+ @torch.inference_mode()
52
+ def prepare_target_embedding(
53
+ self,
54
+ class_names: List[str],
55
+ clip_encoder: nn.Module,
56
+ device: torch.device,
57
+ use_prompt: bool = False,
58
+ prompt_template: Optional[str] = None,
59
+ prompt_templates: Optional[List[str]] = None,
60
+ ) -> None:
61
+ """Encode ``class_names`` into the [C, D] target matrix.
62
+
63
+ Three mutually exclusive prompting modes (first non-empty wins):
64
+ - ``prompt_templates``: prompt ensembling — render each class under every
65
+ ``"... {} ..."`` template, per-row L2-normalize, mean, re-normalize.
66
+ This is the recommended eval-time free win.
67
+ - ``prompt_template``: a single ``"... {c} ..."`` format string.
68
+ - ``use_prompt``: the OpenScene default ``"a {} in a scene"``.
69
+ In the ``prompt_templates`` and ``use_prompt`` modes, any class name containing ``"other"`` is encoded as the bare string ``"other"`` (no template) so the
70
+ background/void class stays neutral.
71
+ """
72
+ log.info("Preparing CLIP target embedding for %d classes", len(class_names))
73
+ if prompt_templates:
74
+ log.info("Prompt ensembling over %d templates", len(prompt_templates))
75
+ ensembled = None
76
+ for template in prompt_templates:
77
+ rendered = [
78
+ template.format(c) if "other" not in c else "other" for c in class_names
79
+ ]
80
+ emb = F.normalize(clip_encoder(rendered, normalize=True).float(), p=2, dim=-1)
81
+ ensembled = emb if ensembled is None else ensembled + emb
82
+ text_embedding = F.normalize(ensembled / float(len(prompt_templates)), p=2, dim=-1)
83
+ elif prompt_template is not None:
84
+ rendered = [prompt_template.format(c=c) for c in class_names]
85
+ text_embedding = clip_encoder(rendered, normalize=True)
86
+ elif use_prompt:
87
+ rendered = [f"a {c} in a scene" if "other" not in c else "other" for c in class_names]
88
+ text_embedding = clip_encoder(rendered, normalize=True)
89
+ else:
90
+ text_embedding = clip_encoder(class_names, normalize=True)
91
+ self.set_target_embedding(text_embedding.to(device))
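The prompt-ensembling branch of `prepare_target_embedding` can be sketched without PyTorch: render each class under every template, L2-normalize each embedding, average, and re-normalize. In the sketch below, `toy_encoder` is a hypothetical stand-in for the SigLIP2 text encoder (any deterministic text-to-vector map would do), and `ensemble_embeddings` is an illustrative name, not a function in this repo. It keeps the same substring rule as the code above, so a name such as `otherfurniture` is encoded as the bare string `other`.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]


def toy_encoder(text):
    """Hypothetical stand-in for a text encoder: text -> 3-dim vector."""
    vowels = sum(ch in "aeiou" for ch in text)
    return [float(len(text)), float(vowels), float(text.count(" ") + 1)]


def ensemble_embeddings(class_names, templates, encode=toy_encoder):
    """Return one unit-norm embedding per class, averaged over templates."""
    rows = []
    for name in class_names:
        total = None
        for template in templates:
            # Names containing "other" skip the template entirely.
            text = "other" if "other" in name else template.format(name)
            emb = l2_normalize(encode(text))
            total = emb if total is None else [a + b for a, b in zip(total, emb)]
        mean = [v / len(templates) for v in total]
        rows.append(l2_normalize(mean))  # re-normalize after averaging
    return rows


templates = ["a {} in a scene", "a photo of a {}"]
rows = ensemble_embeddings(["chair", "otherfurniture"], templates)
print(len(rows), len(rows[0]))
```

The final re-normalization matters because the mean of unit vectors is generally shorter than unit length, which would otherwise bias the dot-product scores by class.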
demo/demo_viser.py ADDED
@@ -0,0 +1,341 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Local viser demo: SpaceFormer open-vocabulary 3D instance segmentation.
4
+
5
+ Takes TEXT class names, runs one forward pass of the released SpaceFormer over a
6
+ point cloud, and visualizes the predicted instances in a browser with
7
+ `viser <https://viser.studio>`_ (each kept instance a distinct color; unassigned
8
+ points grey). Model + pipeline come from the installed ``warpconvnet`` library
9
+ and the sibling ``pipeline.py`` / ``postprocessing.py`` in this repo — this file
10
+ only adds the forward + viser visualization glue.
11
+
12
+ # local checkpoint, custom vocabulary
13
+ python demo_viser.py --ckpt /path/to/spaceformer_512_siglip2_ssccc.ckpt \
14
+ --ply my_scene.ply --class-names chair table monitor wall floor
15
+
16
+ # auto-download the checkpoint from HuggingFace and use a bundled sample cloud
17
+ python demo_viser.py --port 8080 # falls back to chrischoy/SpaCeFormer
18
+
19
+ Then open the printed URL (default http://localhost:8080) in a browser.
20
+
21
+ CRITICAL coord/mask alignment (see space_former_seg.py):
22
+ The model voxelizes internally (PointToSparseWrapper), so its output masks are
23
+ over the model's OUTPUT points, whose count may NOT equal the raw .ply point
24
+ count. In eval mode the forward returns ``out["backbone_pc"]`` (a warpconvnet
25
+ ``Points``) whose ``.coordinates`` correspond 1:1 with the per-point mask rows
26
+ in ``out["mask"][0]``. We therefore run the forward pass HERE (mirroring
27
+ pipeline.predict_instances) so we can pull the backbone_pc coordinates and the
28
+ post-processed masks together, and color THOSE coordinates — never the raw
29
+ input coords, which may differ in length/order.
30
+
31
+ NOTE: WarpConvNet (with its compiled ``_C`` extension) + transformers + a
32
+ checkpoint are required to actually run. ``--help`` and import of this file work
33
+ without viser/WarpConvNet (both are imported lazily inside ``main``).
34
+ """
35
+
36
+ import argparse
37
+ import os
38
+ import tempfile
39
+ import time
40
+
41
+ import numpy as np
42
+ import torch
43
+
44
+ # Reused, task-specific glue from this repo (scene I/O, eval transforms, text
45
+ # embeddings) — we deliberately do NOT reuse predict_instances() wholesale
46
+ # because we also need the aligned backbone_pc coordinates, so we inline the
47
+ # forward below and call apply_post_processing() directly.
48
+ from labels import CLASS_LABELS_200, DEFAULT_CLASS_NAMES
49
+ from pipeline import (
50
+ POST_PROCESSING_CFG,
51
+ build_text_embeddings,
52
+ load_scene,
53
+ make_batch,
54
+ )
55
+ from postprocessing import apply_post_processing
56
+
57
+ HF_REPO_ID = os.environ.get("HF_REPO_ID", "chrischoy/SpaCeFormer")
58
+ HF_FILENAME = os.environ.get("HF_FILENAME", "spaceformer_512_siglip2_ssccc.ckpt")
59
+
60
+
61
+ # --------------------------------------------------------------------------- #
62
+ # Checkpoint + sample-scene resolution
63
+ # --------------------------------------------------------------------------- #
64
+ def resolve_ckpt(ckpt_arg):
65
+ """Return a local checkpoint path: --ckpt, $SPACEFORMER_CKPT, or HF download."""
66
+ if ckpt_arg:
67
+ return ckpt_arg
68
+ explicit = os.environ.get("SPACEFORMER_CKPT")
69
+ if explicit:
70
+ return explicit
71
+ from huggingface_hub import hf_hub_download
72
+
73
+ print(f"[ckpt] downloading {HF_FILENAME} from HuggingFace repo {HF_REPO_ID} ...")
74
+ return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME)
75
+
76
+
77
+ def resolve_sample_ply(ply_arg):
78
+ """Return a path to a .ply. If --ply is unset, use a zero-config sample.
79
+
80
+ Preference order (no fragile external URLs):
81
+ 1) open3d's bundled ``PLYPointCloud`` sample (a small RGB point cloud), or
82
+ 2) a synthesized random RGB point cloud written to a temp .ply.
83
+
84
+ A random/sample cloud will NOT segment into meaningful instances — it only
85
+ demonstrates that the pipeline + visualization run end to end.
86
+ """
87
+ if ply_arg:
88
+ return ply_arg
89
+
90
+ # 1) open3d bundled sample (downloaded/cached by open3d itself, offline after).
91
+ try:
92
+ import open3d as o3d
93
+
94
+ sample_path = o3d.data.PLYPointCloud().path
95
+ if os.path.isfile(sample_path):
96
+ print(f"[sample] using open3d bundled PLYPointCloud sample: {sample_path}")
97
+ print("[sample] NOTE: a generic sample cloud won't segment meaningfully; "
98
+ "it's only to demo the pipeline + viz.")
99
+ return sample_path
100
+ except Exception as exc: # noqa: BLE001 - any open3d issue -> fall back to synthetic
101
+ print(f"[sample] open3d sample unavailable ({exc}); synthesizing a random cloud.")
102
+
103
+ # 2) Synthesize a small random RGB point cloud and write a temp .ply.
104
+ rng = np.random.default_rng(0)
105
+ n = 20_000
106
+ coord = rng.uniform(-2.0, 2.0, size=(n, 3)).astype(np.float32) # meters
107
+ color = rng.integers(0, 256, size=(n, 3)).astype(np.uint8)
108
+ tmp = tempfile.NamedTemporaryFile(suffix=".ply", delete=False)
109
+ tmp.close()
110
+ _write_ply(tmp.name, coord, color)
111
+ print(f"[sample] wrote synthetic random RGB cloud ({n} pts) to {tmp.name}")
112
+ print("[sample] NOTE: a random cloud won't segment meaningfully; it's only to "
113
+ "demo the pipeline + viz.")
114
+ return tmp.name
115
+
116
+
117
+ def _write_ply(path, coord, color):
118
+ """Write an RGB .ply (ASCII via plyfile if available, else via open3d)."""
119
+ try:
120
+ from plyfile import PlyData, PlyElement
121
+
122
+ verts = np.empty(
123
+ coord.shape[0],
124
+ dtype=[("x", "f4"), ("y", "f4"), ("z", "f4"),
125
+ ("red", "u1"), ("green", "u1"), ("blue", "u1")],
126
+ )
127
+ verts["x"], verts["y"], verts["z"] = coord[:, 0], coord[:, 1], coord[:, 2]
128
+ verts["red"], verts["green"], verts["blue"] = color[:, 0], color[:, 1], color[:, 2]
129
+ PlyData([PlyElement.describe(verts, "vertex")], text=True).write(path)
130
+ return
131
+ except ImportError:
132
+ pass
133
+ import open3d as o3d
134
+
135
+ pcd = o3d.geometry.PointCloud()
136
+ pcd.points = o3d.utility.Vector3dVector(coord.astype(np.float64))
137
+ pcd.colors = o3d.utility.Vector3dVector(color.astype(np.float64) / 255.0)
138
+ o3d.io.write_point_cloud(path, pcd)
139
+
140
+
141
+ # --------------------------------------------------------------------------- #
142
+ # Forward + post-processing (aligned to backbone_pc coordinates)
143
+ # --------------------------------------------------------------------------- #
144
+ @torch.inference_mode()
145
+ def segment_aligned(net, batch, text_eval, class_names):
146
+ """One forward pass -> post-processed instances aligned to backbone_pc coords.
147
+
148
+ Mirrors ``pipeline.predict_instances`` but ALSO returns the model's output
149
+ coordinates (``out["backbone_pc"].coordinates``), which are the coordinates
150
+ the returned masks index into (see module docstring / space_former_seg.py).
151
+
152
+ Returns ``(coords[M,3] float32, results)`` where each result is
153
+ ``{mask: bool[M], label, label_id, score}`` and ``M`` is the backbone point
154
+ count (== number of rows in ``out["mask"][0]``), NOT the raw .ply count.
155
+ """
156
+ out = net(batch)
157
+
158
+ # backbone_pc is present only in eval; its coordinates align 1:1 with the
159
+ # mask rows. batch size is 1 here, so every point belongs to sample 0.
160
+ backbone_pc = out["backbone_pc"]
161
+ coords = backbone_pc.coordinates.detach().cpu().numpy().astype(np.float32) # [M, 3]
162
+
163
+ binary_logits = out["logit"][0] # [Q, 2] objectness over {fg, bg}
164
+ mask_logits = out["mask"][0].T # [M, Q] -> [Q, M]
165
+ clip_feats = out["clip_feat"][0] # [Q, D]
166
+ pred_iou = out["pred_iou"][0] if "pred_iou" in out else None
167
+
168
+ # Sanity: mask columns must match the backbone point count we will color.
169
+ assert mask_logits.shape[1] == coords.shape[0], (
170
+ f"mask columns {mask_logits.shape[1]} != backbone points {coords.shape[0]}; "
171
+ "coord/mask alignment broken"
172
+ )
173
+
174
+ class_logits = text_eval.predict(clip_feats, return_logit=True) # [Q, C]
175
+
176
+ masks, scores, _classes, indices = apply_post_processing(
177
+ mask_logits,
178
+ binary_logits,
179
+ mask_threshold=0.0,
180
+ point_coords=None,
181
+ pp_cfg=POST_PROCESSING_CFG,
182
+ pred_iou=pred_iou,
183
+ )
184
+
185
+ results = []
186
+ if len(indices) > 0:
187
+ probs = torch.softmax(class_logits[indices], dim=-1) # [K, C]
188
+ class_probs, class_ids = probs.max(dim=1)
189
+ final_scores = scores * class_probs
190
+ for k in range(len(indices)):
191
+ results.append(
192
+ {
193
+ "mask": masks[k].cpu().numpy().astype(bool), # bool[M], aligned to coords
194
+ "label": class_names[int(class_ids[k])],
195
+ "label_id": int(class_ids[k]),
196
+ "score": float(final_scores[k]),
197
+ }
198
+ )
199
+ results.sort(key=lambda r: r["score"], reverse=True)
200
+ return coords, results
201
+
202
+
203
+ # --------------------------------------------------------------------------- #
204
+ # Coloring + viser visualization
205
+ # --------------------------------------------------------------------------- #
206
+ def _instance_colors(coords, results, top_k, score_thresh):
207
+ """Grey base cloud + a distinct random color per kept (top-k, thresholded) instance.
208
+
209
+ Returns ``(rgb[M,3] uint8, kept)`` where ``kept`` is the list of instances
210
+ actually colored (already sorted by score, high first).
211
+ """
212
+ rgb = np.full((coords.shape[0], 3), 160, dtype=np.uint8) # grey background
213
+ kept = [r for r in results if r["score"] >= score_thresh][:top_k]
214
+ rng = np.random.default_rng(0)
215
+ for r in kept:
216
+ color = rng.integers(40, 230, size=3).astype(np.uint8)
217
+ r["color"] = tuple(int(c) for c in color) # stash for GUI legend
218
+ rgb[r["mask"]] = color
219
+ return rgb, kept
220
+
221
+
222
+ def visualize(coords, results, port, top_k, score_thresh, point_size):
223
+ """Start a viser server, add the colored point cloud + a GUI legend, block."""
224
+ import viser # imported lazily so --help works without viser installed
225
+
226
+ rgb, kept = _instance_colors(coords, results, top_k, score_thresh)
227
+
228
+ server = viser.ViserServer(port=port)
229
+
230
+ # Point cloud: positions from backbone_pc coords, colors per instance.
231
+ server.scene.add_point_cloud(
232
+ name="/scene",
233
+ points=coords.astype(np.float32),
234
+ colors=rgb, # uint8 [M, 3]
235
+ point_size=point_size,
236
+ )
237
+
238
+ # A small GUI panel listing the kept instances (label + score + #points).
239
+ with server.gui.add_folder(f"Top {len(kept)} instances"):
240
+ if not kept:
241
+ server.gui.add_markdown("_(no instances above the score threshold)_")
242
+ for i, r in enumerate(kept):
243
+ c = r["color"]
244
+ swatch = f'<span style="color:rgb({c[0]},{c[1]},{c[2]})">&#9632;</span>'
245
+ server.gui.add_markdown(
246
+ f"{swatch} **{r['label']}** — score {r['score']:.3f}, "
247
+ f"{int(r['mask'].sum())} pts"
248
+ )
249
+
250
+ url = f"http://localhost:{port}"
251
+ print(f"\n[viser] serving at {url} (open this URL in your browser)")
252
+ print("[viser] press Ctrl-C to stop.")
253
+ try:
254
+ while True:
255
+ time.sleep(2.0)
256
+ except KeyboardInterrupt:
257
+ print("\n[viser] shutting down.")
258
+
259
+
260
+ def print_summary(results, top_k):
261
+ """Ranked text summary of instances (label, score, #points) to stdout."""
262
+ print(f"\n=== {len(results)} predicted instances ===")
263
+ header = f"{'#':>3} {'score':>6} {'points':>8} label"
264
+ print(header)
265
+ print("-" * len(header))
266
+ for i, r in enumerate(results[:top_k]):
267
+ print(f"{i:>3} {r['score']:>6.3f} {int(r['mask'].sum()):>8} {r['label']}")
268
+ if len(results) > top_k:
269
+ print(f"... ({len(results) - top_k} more)")
270
+
271
+
272
+ # --------------------------------------------------------------------------- #
273
+ # Entry point (linear top-to-bottom)
274
+ # --------------------------------------------------------------------------- #
275
+ def main():
276
+ ap = argparse.ArgumentParser(
277
+ description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
278
+ )
279
+ ap.add_argument("--ply", default=None,
280
+ help="input point cloud .ply (default: an open3d bundled sample "
281
+ "or a synthesized random cloud)")
282
+ ap.add_argument("--class-names", nargs="+", default=None,
283
+ help="open-vocab class names (default: a short built-in list)")
284
+ ap.add_argument("--use-scannet200", action="store_true",
285
+ help="use the full ScanNet200 label set as class names")
286
+ ap.add_argument("--ckpt", default=None,
287
+ help="checkpoint path (else $SPACEFORMER_CKPT, else HF download)")
288
+ ap.add_argument("--iou-head", action="store_true",
289
+ help="build the learned IoU head (only for a checkpoint that has one)")
290
+ ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
291
+ ap.add_argument("--port", type=int, default=8080, help="viser server port")
292
+     ap.add_argument("--score-thresh", type=float, default=0.0,
+                     help="only color instances with score >= this")
+     ap.add_argument("--top-k", type=int, default=30,
+                     help="max number of instances to color / list")
+     ap.add_argument("--point-size", type=float, default=0.02,
+                     help="viser point size (world units / meters)")
+     args = ap.parse_args()
+
+     device = torch.device(args.device)
+
+     # Vocabulary: the text class names that form the open-vocabulary input.
+     if args.use_scannet200:
+         class_names = list(CLASS_LABELS_200)
+     else:
+         class_names = args.class_names or list(DEFAULT_CLASS_NAMES)
+     print(f"[vocab] {len(class_names)} classes: {', '.join(class_names[:8])}"
+           f"{' ...' if len(class_names) > 8 else ''}")
+
+     # 1) Resolve/load the scene (.ply -> coord[N,3] meters, color[N,3] 0-255).
+     ply_path = resolve_sample_ply(args.ply)
+     coord_np, color_np = load_scene(ply_path)
+
+     # 2) Build the eval batch (CenterShift + NormalizeColor + offsets).
+     batch = make_batch(coord_np, color_np, device)
+
+     # 3) Build model + load released weights (WarpConvNet required here).
+     from warpconvnet.models.spaceformer import (
+         build_spaceformer,
+         load_spaceformer_checkpoint,
+     )
+
+     net = build_spaceformer(use_iou_head=args.iou_head, device=device)
+     missing, unexpected = load_spaceformer_checkpoint(net, resolve_ckpt(args.ckpt))
+     print(f"[weights] {len(missing)} missing, {len(unexpected)} unexpected")
+
+     # 4) Encode the text class names (SigLIP2 + prompt ensembling).
+     text_eval = build_text_embeddings(class_names, device)
+
+     # 5) Forward + post-process -> masks + labels + scores aligned to backbone coords.
+     coords, results = segment_aligned(net, batch, text_eval, class_names)
+     print(f"[model] backbone output points: {coords.shape[0]} "
+           f"(raw input was {coord_np.shape[0]})")
+
+     # 6) Text summary + interactive viser visualization.
+     print_summary(results, args.top_k)
+     visualize(coords, results, args.port, args.top_k, args.score_thresh, args.point_size)
+
+
+ if __name__ == "__main__":
+     main()
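A minimal sketch of how `--top-k` and `--score-thresh` could interact when choosing which instances to color. It assumes `results` is a list of dicts with `score` and `mask` entries, as `predict_instances` in `pipeline.py` produces. `select_instances`, `paint_points`, and the palette are hypothetical helpers for illustration, not the demo's actual `visualize` code, and letting higher-scored instances win overlapping points is an assumption.

```python
def select_instances(results, top_k=30, score_thresh=0.0):
    """Rank by score, keep at most top_k, then drop anything below the threshold."""
    ranked = sorted(results, key=lambda r: r["score"], reverse=True)
    return [r for r in ranked[:top_k] if r["score"] >= score_thresh]


def paint_points(num_points, selected, palette, background=(128, 128, 128)):
    """Assign one RGB color per point; unselected points keep the background color."""
    colors = [background] * num_points
    # Paint from lowest to highest rank so the best-scored instance wins overlaps.
    for rank in range(len(selected) - 1, -1, -1):
        color = palette[rank % len(palette)]
        for idx, inside in enumerate(selected[rank]["mask"]):
            if inside:
                colors[idx] = color
    return colors


if __name__ == "__main__":
    demo_results = [
        {"score": 0.20, "mask": [True, True, False]},
        {"score": 0.90, "mask": [False, True, False]},
        {"score": 0.05, "mask": [False, False, True]},
    ]
    chosen = select_instances(demo_results, top_k=2, score_thresh=0.1)
    print(paint_points(3, chosen, [(255, 0, 0), (0, 255, 0)]))
```

Because the list is sorted by score, applying the threshold before or after the top-k cut selects the same instances, so the order of the two filters does not matter in this sketch.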
demo/inference.py ADDED
@@ -0,0 +1,100 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ """SpaceFormer single-scene open-vocabulary 3D instance segmentation (CLI inference).
+
+ Command-line entry point for the released SpaceFormer checkpoint. The model comes
+ from the installed ``warpconvnet`` library (``warpconvnet.models.spaceformer``); the
+ labeling and post-processing glue comes from the local ``pipeline`` module. This
+ script only resolves the checkpoint (local path or HuggingFace auto-download), runs
+ one scene, and prints (and optionally saves) the predicted instances.
+
+     # weights from a local checkpoint
+     python inference.py --ckpt /path/to/spaceformer_512_siglip2_ssccc.ckpt \
+         --scene /path/to/scene_dir
+
+     # or auto-download from a HuggingFace model repo
+     HF_REPO_ID=chrischoy/SpaCeFormer python inference.py \
+         --scene my_scene.ply --class-names "office chair" "desk" "monitor" "other"
+
+ Requires WarpConvNet (with its compiled extension) + transformers in the environment.
+ """
+
+ import argparse
+ import os
+
+ import torch
+
+ from warpconvnet.models.spaceformer import (
+     build_spaceformer,
+     load_spaceformer_checkpoint,
+ )
+ from labels import CLASS_LABELS_200, DEFAULT_CLASS_NAMES
+ from pipeline import (
+     build_text_embeddings,
+     load_scene,
+     make_batch,
+     predict_instances,
+     print_summary,
+     save_results,
+ )
+
+ HF_FILENAME = os.environ.get("HF_FILENAME", "spaceformer_512_siglip2_ssccc.ckpt")
42
+
43
+
+ def resolve_ckpt(ckpt_arg: str | None) -> str:
+     """Return a local checkpoint path: explicit --ckpt, $SPACEFORMER_CKPT, or HF download."""
+     if ckpt_arg:
+         return ckpt_arg
+     explicit = os.environ.get("SPACEFORMER_CKPT")
+     if explicit:
+         return explicit
+     repo_id = os.environ.get("HF_REPO_ID")
+     if not repo_id:
+         raise SystemExit(
+             "No checkpoint: pass --ckpt, or set SPACEFORMER_CKPT (local path) or "
+             "HF_REPO_ID (HuggingFace model repo to download from)."
+         )
+     from huggingface_hub import hf_hub_download
+     return hf_hub_download(repo_id=repo_id, filename=HF_FILENAME)
59
+
60
+
+ def main() -> None:
+     ap = argparse.ArgumentParser(description=__doc__,
+                                  formatter_class=argparse.RawDescriptionHelpFormatter)
+     ap.add_argument("--ckpt", default=None,
+                     help="checkpoint path (else $SPACEFORMER_CKPT, else HF download via $HF_REPO_ID)")
+     ap.add_argument("--scene", required=True,
+                     help="scene dir (coord.npy+color.npy) or .npy/.npz/.ply file")
+     ap.add_argument("--class-names", nargs="+", default=None,
+                     help="open-vocab class names (default: a short built-in list)")
+     ap.add_argument("--use-scannet200", action="store_true",
+                     help="use the full ScanNet200 label set as class names")
+     ap.add_argument("--iou-head", action="store_true",
+                     help="build the learned IoU head (only for a checkpoint that has one)")
+     ap.add_argument("--save", default=None, help="optional output .npz path")
+     ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
+     args = ap.parse_args()
+
+     device = torch.device(args.device)
+     if args.use_scannet200:
+         class_names = list(CLASS_LABELS_200)
+     else:
+         class_names = args.class_names or list(DEFAULT_CLASS_NAMES)
+     print(f"[vocab] {len(class_names)} classes")
+
+     net = build_spaceformer(use_iou_head=args.iou_head, device=device)
+     missing, unexpected = load_spaceformer_checkpoint(net, resolve_ckpt(args.ckpt))
+     print(f"[weights] {len(missing)} missing, {len(unexpected)} unexpected")
+
+     coord_np, color_np = load_scene(args.scene)
+     batch = make_batch(coord_np, color_np, device)
+     text_eval = build_text_embeddings(class_names, device)
+     results = predict_instances(net, batch, text_eval, class_names)
+
+     print_summary(results)
+     if args.save:
+         save_results(results, coord_np, args.save)
+
+
+ if __name__ == "__main__":
+     main()
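The checkpoint lookup order in `resolve_ckpt` (explicit `--ckpt`, then `$SPACEFORMER_CKPT`, then a download from `$HF_REPO_ID`) can be exercised without network access with a small stand-in. This is only a sketch: `pick_checkpoint_source` is a hypothetical helper that reports which source would win, takes the environment as a plain dict, and raises `ValueError` where the real script exits.

```python
def pick_checkpoint_source(ckpt_arg, env):
    """Return (source, value) following the same precedence as resolve_ckpt."""
    if ckpt_arg:
        return ("cli", ckpt_arg)
    if env.get("SPACEFORMER_CKPT"):
        return ("env", env["SPACEFORMER_CKPT"])
    if env.get("HF_REPO_ID"):
        # The real script would call hf_hub_download here.
        return ("hub", env["HF_REPO_ID"])
    raise ValueError("no checkpoint source configured")


if __name__ == "__main__":
    env = {"SPACEFORMER_CKPT": "/models/local.ckpt", "HF_REPO_ID": "user/repo"}
    print(pick_checkpoint_source(None, env))
    print(pick_checkpoint_source("/tmp/explicit.ckpt", env))
```

Passing the environment in as a dict keeps the precedence logic testable without mutating `os.environ`.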
demo/labels.py ADDED
@@ -0,0 +1,245 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Label sets and prompt templates for open-vocabulary SpaceFormer evaluation."""
4
+
5
+ # Prompt-ensembling templates (LEVER 1). Each must contain a single positional
6
+ # ``{}`` placeholder. Averaging class embeddings across these is a confirmed
7
+ # eval-time accuracy win for the released checkpoint.
8
+ PROMPT_TEMPLATES = (
9
+ "a {} in a scene",
10
+ "a photo of a {} in a scene",
11
+ "a {}",
12
+ "a photo of a {}",
13
+ "there is a {} in the scene",
14
+ "a 3d model of a {}",
15
+ )
16
+
17
+ # A short, readable default vocabulary for the demo. Pass your own class names to
18
+ # override, or use CLASS_LABELS_200 for the full ScanNet200 benchmark label set.
19
+ DEFAULT_CLASS_NAMES = (
20
+ "wall",
21
+ "floor",
22
+ "ceiling",
23
+ "chair",
24
+ "table",
25
+ "desk",
26
+ "couch",
27
+ "bed",
28
+ "cabinet",
29
+ "shelf",
30
+ "door",
31
+ "window",
32
+ "monitor",
33
+ "keyboard",
34
+ "lamp",
35
+ "picture",
36
+ "whiteboard",
37
+ "trash can",
38
+ "backpack",
39
+ "plant",
40
+ "other",
41
+ )
42
+
43
+ # The 200 ScanNet200 instance/semantic class names (benchmark order).
44
+ CLASS_LABELS_200 = (
45
+ "wall",
46
+ "chair",
47
+ "floor",
48
+ "table",
49
+ "door",
50
+ "couch",
51
+ "cabinet",
52
+ "shelf",
53
+ "desk",
54
+ "office chair",
55
+ "bed",
56
+ "pillow",
57
+ "sink",
58
+ "picture",
59
+ "window",
60
+ "toilet",
61
+ "bookshelf",
62
+ "monitor",
63
+ "curtain",
64
+ "book",
65
+ "armchair",
66
+ "coffee table",
67
+ "box",
68
+ "refrigerator",
69
+ "lamp",
70
+ "kitchen cabinet",
71
+ "towel",
72
+ "clothes",
73
+ "tv",
74
+ "nightstand",
75
+ "counter",
76
+ "dresser",
77
+ "stool",
78
+ "cushion",
79
+ "plant",
80
+ "ceiling",
81
+ "bathtub",
82
+ "end table",
83
+ "dining table",
84
+ "keyboard",
85
+ "bag",
86
+ "backpack",
87
+ "toilet paper",
88
+ "printer",
89
+ "tv stand",
90
+ "whiteboard",
91
+ "blanket",
92
+ "shower curtain",
93
+ "trash can",
94
+ "closet",
95
+ "stairs",
96
+ "microwave",
97
+ "stove",
98
+ "shoe",
99
+ "computer tower",
100
+ "bottle",
101
+ "bin",
102
+ "ottoman",
103
+ "bench",
104
+ "board",
105
+ "washing machine",
106
+ "mirror",
107
+ "copier",
108
+ "basket",
109
+ "sofa chair",
110
+ "file cabinet",
111
+ "fan",
112
+ "laptop",
113
+ "shower",
114
+ "paper",
115
+ "person",
116
+ "paper towel dispenser",
117
+ "oven",
118
+ "blinds",
119
+ "rack",
120
+ "plate",
121
+ "blackboard",
122
+ "piano",
123
+ "suitcase",
124
+ "rail",
125
+ "radiator",
126
+ "recycling bin",
127
+ "container",
128
+ "wardrobe",
129
+ "soap dispenser",
130
+ "telephone",
131
+ "bucket",
132
+ "clock",
133
+ "stand",
134
+ "light",
135
+ "laundry basket",
136
+ "pipe",
137
+ "clothes dryer",
138
+ "guitar",
139
+ "toilet paper holder",
140
+ "seat",
141
+ "speaker",
142
+ "column",
143
+ "bicycle",
144
+ "ladder",
145
+ "bathroom stall",
146
+ "shower wall",
147
+ "cup",
148
+ "jacket",
149
+ "storage bin",
150
+ "coffee maker",
151
+ "dishwasher",
152
+ "paper towel roll",
153
+ "machine",
154
+ "mat",
155
+ "windowsill",
156
+ "bar",
157
+ "toaster",
158
+ "bulletin board",
159
+ "ironing board",
160
+ "fireplace",
161
+ "soap dish",
162
+ "kitchen counter",
163
+ "doorframe",
164
+ "toilet paper dispenser",
165
+ "mini fridge",
166
+ "fire extinguisher",
167
+ "ball",
168
+ "hat",
169
+ "shower curtain rod",
170
+ "water cooler",
171
+ "paper cutter",
172
+ "tray",
173
+ "shower door",
174
+ "pillar",
175
+ "ledge",
176
+ "toaster oven",
177
+ "mouse",
178
+ "toilet seat cover dispenser",
179
+ "furniture",
180
+ "cart",
181
+ "storage container",
182
+ "scale",
183
+ "tissue box",
184
+ "light switch",
185
+ "crate",
186
+ "power outlet",
187
+ "decoration",
188
+ "sign",
189
+ "projector",
190
+ "closet door",
191
+ "vacuum cleaner",
192
+ "candle",
193
+ "plunger",
194
+ "stuffed animal",
195
+ "headphones",
196
+ "dish rack",
197
+ "broom",
198
+ "guitar case",
199
+ "range hood",
200
+ "dustpan",
201
+ "hair dryer",
202
+ "water bottle",
203
+ "handicap bar",
204
+ "purse",
205
+ "vent",
206
+ "shower floor",
207
+ "water pitcher",
208
+ "mailbox",
209
+ "bowl",
210
+ "paper bag",
211
+ "alarm clock",
212
+ "music stand",
213
+ "projector screen",
214
+ "divider",
215
+ "laundry detergent",
216
+ "bathroom counter",
217
+ "object",
218
+ "bathroom vanity",
219
+ "closet wall",
220
+ "laundry hamper",
221
+ "bathroom stall door",
222
+ "ceiling light",
223
+ "trash bin",
224
+ "dumbbell",
225
+ "stair rail",
226
+ "tube",
227
+ "bathroom cabinet",
228
+ "cd case",
229
+ "closet rod",
230
+ "coffee kettle",
231
+ "structure",
232
+ "shower head",
233
+ "keyboard piano",
234
+ "case of water bottles",
235
+ "coat rack",
236
+ "storage organizer",
237
+ "folded chair",
238
+ "fire alarm",
239
+ "power strip",
240
+ "calendar",
241
+ "poster",
242
+ "potted plant",
243
+ "luggage",
244
+ "mattress",
245
+ )
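Prompt ensembling, as described for `PROMPT_TEMPLATES`, fills each template with the class name, encodes every prompt, and averages the embeddings. The sketch below illustrates that idea only: `toy_encode` is a deterministic stand-in for the SigLIP2 text encoder, the template subset is arbitrary, and re-normalizing after the mean is an assumption about the real evaluator rather than something the source confirms.

```python
import hashlib
import math

TEMPLATES = ("a {} in a scene", "a photo of a {}", "a 3d model of a {}")


def toy_encode(text, dim=8):
    """Stand-in text encoder: hash the prompt into a unit-length pseudo-embedding."""
    digest = hashlib.sha256(text.encode("utf-8")).digest()
    vec = [b / 255.0 - 0.5 for b in digest[:dim]]
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0
    return [v / norm for v in vec]


def ensemble_embedding(class_name, templates=TEMPLATES, encode=toy_encode):
    """Average the embeddings of all filled templates, then L2-normalize the mean."""
    prompts = [t.format(class_name) for t in templates]
    embs = [encode(p) for p in prompts]
    dim = len(embs[0])
    mean = [sum(e[d] for e in embs) / len(embs) for d in range(dim)]
    norm = math.sqrt(sum(v * v for v in mean)) or 1.0
    return [v / norm for v in mean]


if __name__ == "__main__":
    vec = ensemble_embedding("office chair")
    print(len(vec), round(sum(x * x for x in vec), 6))
```

Swapping `encode` for a real text encoder keeps the same structure: one embedding per (class, template) pair, reduced to one vector per class.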
demo/pipeline.py ADDED
@@ -0,0 +1,203 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Open-vocabulary inference pipeline for the SpaceFormer release.
4
+
5
+ Task-specific glue that turns the raw model outputs (logit / mask / clip_feat,
6
+ from ``warpconvnet.models.spaceformer.SpaCeFormerInstSeg``) into labeled instances:
7
+ scene I/O + minimal eval transforms, SigLIP2 text embedding (prompt-ensembled),
8
+ and SAM2-style mask post-processing. Kept in the release repo (not WarpConvNet),
9
+ since labeling/post-processing is downstream of the model.
10
+ """
11
+
12
+ import os
13
+
14
+ import numpy as np
15
+ import torch
16
+
17
+ from postprocessing import apply_post_processing
18
+ from labels import PROMPT_TEMPLATES
19
+
20
+ # SigLIP2 text encoder used at training time. text_dim 1152 == model clip-head dim.
21
+ SIGLIP_MODEL_ID = "google/siglip2-so400m-patch14-224"
22
+
23
+ # Voxel size the model was trained/exported at (the model voxelizes internally;
24
+ # this only labels the data_dict for parity).
25
+ GRID_SIZE = 0.02
26
+
27
+ # Release eval post-processing: NMS on, min 20 pts/mask, DBSCAN/stability off.
28
+ POST_PROCESSING_CFG = {
29
+ "use_dbscan": False,
30
+ "use_stability_score": False,
31
+ "use_nms": True,
32
+ "nms_thresh": 0.7,
33
+ "min_mask_points": 20,
34
+ "objectness_thresh": 0.0,
35
+ }
36
+
37
+
38
+ # --------------------------------------------------------------------------- #
39
+ # Scene loading + minimal eval transforms
40
+ # --------------------------------------------------------------------------- #
+ def load_scene(scene_path: str):
+     """Load one scene as (coord[N,3] float meters, color[N,3] float 0-255)."""
+     if os.path.isdir(scene_path):
+         coord = np.load(os.path.join(scene_path, "coord.npy"))
+         color = np.load(os.path.join(scene_path, "color.npy"))
+     elif scene_path.endswith(".npz"):
+         z = np.load(scene_path)
+         coord = z[_first_key(z, ("coord", "coords", "xyz", "points"))]
+         color = z[_first_key(z, ("color", "colors", "rgb"))]
+     elif scene_path.endswith(".npy"):
+         arr = np.load(scene_path)
+         # Raise instead of assert so the check survives `python -O`.
+         if arr.ndim != 2 or arr.shape[1] < 6:
+             raise ValueError("expected [N,6] xyz+rgb .npy")
+         coord, color = arr[:, :3], arr[:, 3:6]
+     elif scene_path.endswith(".ply"):
+         coord, color = _load_ply(scene_path)
+     else:
+         raise ValueError(f"unsupported scene format: {scene_path}")
+
+     coord = np.ascontiguousarray(coord, dtype=np.float32)
+     color = np.ascontiguousarray(color)
+     # Heuristic: non-uint8 colors whose max is <= 1 are treated as [0, 1] floats
+     # and rescaled to 0-255.
+     if color.dtype != np.uint8 and color.max() <= 1.0 + 1e-6:
+         color = (color * 255.0).round()
+     color = color.astype(np.float32)
+     print(f"[scene] {scene_path}: {coord.shape[0]} points")
+     return coord, color
66
+
67
+
+ def _first_key(z, candidates):
+     for c in candidates:
+         if c in z:
+             return c
+     raise KeyError(f"none of {candidates} in {list(z.keys())}")
+
+
+ def _load_ply(path: str):
+     try:
+         from plyfile import PlyData
+
+         v = PlyData.read(path)["vertex"]
+         coord = np.stack([v["x"], v["y"], v["z"]], axis=1)
+         if "red" in v.data.dtype.names:
+             color = np.stack([v["red"], v["green"], v["blue"]], axis=1)
+         else:
+             color = np.full_like(coord, 127.5)
+         return coord, color
+     except ImportError:
+         import open3d as o3d
+
+         pcd = o3d.io.read_point_cloud(path)
+         coord = np.asarray(pcd.points)
+         color = np.asarray(pcd.colors) * 255.0 if pcd.has_colors() else np.full_like(coord, 127.5)
+         return coord, color
93
+
94
+
+ def make_batch(coord_np, color_np, device):
+     """Apply the minimal eval transforms and build the single-sample data dict.
+
+     CenterShift(apply_z) + NormalizeColor(0-255 -> [-1,1]) + offset [0, N].
+     Coords stay in meters; the model voxelizes internally at 2 cm.
+     """
+     coord = torch.from_numpy(coord_np).float()
+     color = torch.from_numpy(color_np).float()
+
+     cmin = coord.min(dim=0).values
+     cmax = coord.max(dim=0).values
+     shift = torch.tensor([(cmin[0] + cmax[0]) / 2, (cmin[1] + cmax[1]) / 2, cmin[2]])
+     coord = coord - shift
+     feat = color / 127.5 - 1.0
+     n = coord.shape[0]
+     offset = torch.tensor([0, n], dtype=torch.int32)
+     return {
+         "coord": coord.to(device),
+         "feat": feat.to(device),
+         "offset": offset.to(device),
+         "grid_size": GRID_SIZE,
+     }
117
+
118
+
119
+ # --------------------------------------------------------------------------- #
120
+ # Text embeddings (prompt-ensembled) + prediction
121
+ # --------------------------------------------------------------------------- #
+ def build_text_embeddings(class_names, device):
+     """Encode class names with SigLIP2 under multiple templates and average."""
+     from text_encoder import get_text_encoder
+     from clip_eval import CLIPAlignmentEval
+
+     clip_encoder = get_text_encoder(model_type="siglip2", model_id=SIGLIP_MODEL_ID, device=str(device))
+     evaluator = CLIPAlignmentEval(normalize_input=False)  # matches official eval
+     evaluator.prepare_target_embedding(
+         class_names=list(class_names),
+         clip_encoder=clip_encoder,
+         device=device,
+         prompt_templates=list(PROMPT_TEMPLATES),  # prompt ensembling ON
+     )
+     return evaluator
136
+
137
+
+ @torch.inference_mode()
+ def predict_instances(net, batch, text_eval, class_names):
+     """Single forward pass -> post-processing -> open-vocab labels."""
+     out = net(batch)
+     binary_logits = out["logit"][0]  # [Q, 2] objectness over {fg, bg}
+     mask_logits = out["mask"][0].T  # [N, Q] -> [Q, N]
+     clip_feats = out["clip_feat"][0]  # [Q, D]
+     pred_iou = out["pred_iou"][0] if "pred_iou" in out else None
+
+     class_logits = text_eval.predict(clip_feats, return_logit=True)  # [Q, C]
+
+     masks, scores, _classes, indices = apply_post_processing(
+         mask_logits,
+         binary_logits,
+         mask_threshold=0.0,
+         point_coords=None,
+         pp_cfg=POST_PROCESSING_CFG,
+         pred_iou=pred_iou,
+     )
+
+     results = []
+     if len(indices) > 0:
+         probs = torch.softmax(class_logits[indices], dim=-1)  # [K, C]
+         class_probs, class_ids = probs.max(dim=1)
+         final_scores = scores * class_probs
+         for k in range(len(indices)):
+             results.append(
+                 {
+                     "mask": masks[k].cpu().numpy().astype(bool),
+                     "label": class_names[int(class_ids[k])],
+                     "label_id": int(class_ids[k]),
+                     "score": float(final_scores[k]),
+                 }
+             )
+     results.sort(key=lambda r: r["score"], reverse=True)
+     return results
174
+
175
+
176
+ # --------------------------------------------------------------------------- #
177
+ # Output helpers
178
+ # --------------------------------------------------------------------------- #
+ def print_summary(results, top_k=20):
+     print(f"\n=== {len(results)} predicted instances ===")
+     header = f"{'#':>3} {'score':>6} {'points':>8} label"
+     print(header)
+     print("-" * len(header))
+     for i, r in enumerate(results[:top_k]):
+         print(f"{i:>3} {r['score']:>6.3f} {int(r['mask'].sum()):>8} {r['label']}")
+     if len(results) > top_k:
+         print(f"... ({len(results) - top_k} more)")
+
+
+ def save_results(results, coord_np, out_path):
+     if not results:
+         # Write the same keys as the non-empty case so readers see one schema.
+         np.savez(
+             out_path,
+             coord=coord_np,
+             masks=np.zeros((0, coord_np.shape[0]), dtype=bool),
+             labels=np.array([], dtype=str),
+             label_ids=np.array([], dtype=np.int64),
+             scores=np.array([], dtype=np.float64),
+         )
+         return
+     np.savez(
+         out_path,
+         coord=coord_np,
+         masks=np.stack([r["mask"] for r in results]),
+         labels=np.array([r["label"] for r in results]),
+         label_ids=np.array([r["label_id"] for r in results]),
+         scores=np.array([r["score"] for r in results]),
+     )
+     print(f"[save] wrote {len(results)} instances to {out_path}")
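The numeric steps in `make_batch` and `predict_instances` are small enough to check by hand. The following dependency-free sketch mirrors them on plain Python lists: the CenterShift and NormalizeColor transforms, and the final score computed as the mask score times the top softmax class probability. The function names are illustrative only; the real pipeline works on torch tensors and gets its class logits from the SigLIP2-based evaluator.

```python
import math


def center_shift(points):
    """Center x/y on the bounding-box midpoint and move the lowest z to 0."""
    xs, ys, zs = zip(*points)
    sx = (min(xs) + max(xs)) / 2
    sy = (min(ys) + max(ys)) / 2
    sz = min(zs)
    return [(x - sx, y - sy, z - sz) for x, y, z in points]


def normalize_color(colors):
    """Map 0-255 RGB values to the [-1, 1] range."""
    return [tuple(c / 127.5 - 1.0 for c in rgb) for rgb in colors]


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def label_and_score(mask_score, class_logits, class_names):
    """Pick the most likely class and scale the mask score by its probability."""
    probs = softmax(class_logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best], mask_score * probs[best]


if __name__ == "__main__":
    print(center_shift([(0.0, 0.0, 1.0), (2.0, 4.0, 3.0)]))
    print(normalize_color([(0, 255, 127.5)]))
    print(label_and_score(0.8, [0.0, 2.0, 0.0], ["wall", "chair", "desk"]))
```

Since the class probability is at most 1, the final score can never exceed the mask score, so a confident mask with an ambiguous label ranks below an equally confident mask with a clear label.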
demo/postprocessing.py ADDED
@@ -0,0 +1,577 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """
4
+ SAM2-style post-processing utilities for mask segmentation.
5
+
6
+ This module provides shared post-processing functions used by both the
7
+ MaskLanguageLitModule (validation/testing) and the demo script.
8
+ """
9
+
+ from typing import Tuple, Optional, Dict
+ import time
+ import numpy as np
+
+ import torch
+
+ try:
+     from cuml.cluster import DBSCAN
+ except ImportError:
+     DBSCAN = None
20
+
21
+
+ def calculate_stability_score(
+     masks: torch.Tensor,
+     mask_threshold: float = 0.0,
+     threshold_offset: float = 1.0,
+ ) -> torch.Tensor:
+     """
+     Computes the stability score for a set of masks.
+
+     The stability score is the IoU between the binary masks obtained by
+     thresholding at (mask_threshold + threshold_offset) and
+     (mask_threshold - threshold_offset).
+
+     High stability means sharp mask boundaries.
+
+     Args:
+         masks: [Q, N] mask logits
+         mask_threshold: Base threshold (usually 0.0 for logits)
+         threshold_offset: Offset to apply for high/low thresholds
+
+     Returns:
+         stability_score: [Q] stability score per mask
+     """
+     high_thresh_mask = masks > (mask_threshold + threshold_offset)
+     low_thresh_mask = masks > (mask_threshold - threshold_offset)
+
+     # For a non-negative offset the high-threshold mask is a subset of the
+     # low-threshold mask, so these point counts are exactly the intersection
+     # and the union of the two masks.
+     intersection = high_thresh_mask.float().sum(-1)
+     union = low_thresh_mask.float().sum(-1)
+
+     stability_score = intersection / (union + 1e-6)
+     return stability_score
52
+
53
+
+ def apply_nms(
+     masks_binary: torch.Tensor,
+     scores: torch.Tensor,
+     nms_thresh: float = 0.7,
+ ) -> torch.Tensor:
+     """
+     Applies greedy NMS on masks using pairwise IoU.
+
+     Args:
+         masks_binary: [Q, N] binary masks (booleans or 0/1 floats)
+         scores: [Q] mask scores for ranking
+         nms_thresh: IoU threshold for suppression
+
+     Returns:
+         keep_indices: Tensor of indices to keep after NMS
+     """
+     # Sort by score descending
+     order = torch.argsort(scores, descending=True)
+     masks_binary = masks_binary.bool()
+
+     keep = []
+     indices = order
+
+     while indices.numel() > 0:
+         current = indices[0]
+         keep.append(current.item())
+
+         if indices.numel() == 1:
+             break
+
+         # Compare current mask with rest
+         current_mask = masks_binary[current].unsqueeze(0)  # [1, N]
+         rest_indices = indices[1:]
+         rest_masks = masks_binary[rest_indices]  # [K, N]
+
+         intersection = (current_mask & rest_masks).float().sum(dim=1)
+         union = (current_mask | rest_masks).float().sum(dim=1)
+         iou = intersection / (union + 1e-6)
+
+         # Keep masks with IoU < thresh
+         mask_keep = iou < nms_thresh
+         indices = rest_indices[mask_keep]
+
+     return torch.tensor(keep, device=masks_binary.device, dtype=torch.long)
98
+
99
+
+ def apply_dbscan_clustering(
+     current_masks: torch.Tensor,
+     point_coords: torch.Tensor,
+     current_scores: torch.Tensor,
+     current_classes: torch.Tensor,
+     eps: float = 0.95,
+     min_samples: int = 1,
+     backend: str = "auto",
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+     """
+     Applies DBSCAN to each mask to split spatially disconnected components.
+
+     Args:
+         current_masks: [Q, N] boolean masks
+         point_coords: [N, 3] point coordinates
+         current_scores: [Q] scores
+         current_classes: [Q] classes
+         eps: DBSCAN eps parameter
+         min_samples: DBSCAN min_samples parameter
+         backend: "auto", "cuml", or "cpu"
+
+     Returns:
+         new_masks: [Q', N] expanded boolean masks
+         new_scores: [Q'] expanded scores
+         new_classes: [Q'] expanded classes
+         new_indices: [Q'] indices mapping to original queries
+     """
+     # There is no global point-count check; the 100k size limit is applied
+     # per mask inside the loops below.
+
+     # 1. Determine Backend
+     use_cuml = False
+     if backend == "auto":
+         use_cuml = DBSCAN is not None
+     elif backend == "cuml":
+         if DBSCAN is None:
+             print("Warning: backend='cuml' requested but cuML not found. Falling back to CPU.")
+             use_cuml = False
+         else:
+             use_cuml = True
+     elif backend == "cpu":
+         use_cuml = False
+
+     device = current_masks.device
+     num_queries = current_masks.shape[0]
+
+     # Initialize lists to hold the new split masks
+     new_masks_list = []
+     # Store indices pointing to the original scores/classes instead of duplicating them early
+     new_indices_list = []
152
+
+     # 2. Execution Path
+     if use_cuml:
+         # --- cuML (GPU) Path ---
+         # cuML DBSCAN expects input of shape (n_samples, n_features) and is not
+         # batched, so each query mask is clustered independently.
+         for i in range(num_queries):
+             mask = current_masks[i]
+
+             # Skip empty masks
+             if not mask.any():
+                 continue
+
+             # Select this mask's points: mask is [N], point_coords is [N, 3].
+             # Slicing creates a new tensor on the GPU.
+             points = point_coords[mask]
+
+             # Check per-mask size limit
+             if points.shape[0] > 100000:
+                 # Skip DBSCAN for this mask, keep original
+                 print(
+                     f"DBSCAN (cuML): Skipping mask {i} due to large point cloud ({points.shape[0]} points > 100k)"
+                 )
+                 new_masks_list.append(mask)
+                 new_indices_list.append(i)
+                 continue
+
+             if points.shape[0] < min_samples:
+                 # Keep original
+                 print(
+                     f"DBSCAN (cuML): Skipping mask {i} due to small point cloud ({points.shape[0]} points < {min_samples})"
+                 )
+                 new_masks_list.append(mask)
+                 new_indices_list.append(i)
+                 continue
+
+             try:
+                 # Torch tensors are passed directly; per the original notes, cuML >= 23.04
+                 # supports __cuda_array_interface__, and the returned labels are likely a
+                 # CuPy array (or similar) on the GPU.
+                 clusterer = DBSCAN(eps=eps, min_samples=min_samples)
+                 labels = clusterer.fit_predict(points)
+
+                 # Convert labels to a torch tensor for easier handling.
+                 if hasattr(labels, "to_dlpack"):
+                     from torch.utils.dlpack import from_dlpack
+
+                     labels = from_dlpack(labels.to_dlpack())
+                 elif hasattr(labels, "__cuda_array_interface__"):
+                     labels = torch.as_tensor(labels, device=device)
+
+                 unique_labels = torch.unique(labels)
+
+                 # Valid clusters exclude the noise label (-1). If DBSCAN marks every
+                 # point as noise, no mask is emitted for this query (same as the CPU path).
+                 valid_clusters = unique_labels[unique_labels != -1]
+
+                 # Global indices of this mask's points, used to rebuild full-size masks.
+                 mask_indices = torch.nonzero(mask, as_tuple=True)[0]
+
+                 for label in valid_clusters:
+                     # Start from zeros, find the local points in this cluster, map
+                     # them to global indices, and set those entries to True.
+                     new_mask = torch.zeros_like(mask)
+                     local_indices = (labels == label).nonzero(as_tuple=True)[0]
+                     global_indices = mask_indices[local_indices]
+                     new_mask[global_indices] = True
+
+                     new_masks_list.append(new_mask)
+                     new_indices_list.append(i)
+
+             except Exception as e:
+                 print(f"DBSCAN (cuML) Error Query {i}: {e}")
+                 # Fallback: keep original
+                 new_masks_list.append(mask)
+                 new_indices_list.append(i)
269
+
+     else:
+         # --- CPU Path ---
+         # Move inputs to CPU
+         masks_cpu = current_masks.detach().cpu().numpy()
+         coords_cpu = point_coords.detach().cpu().numpy()
+
+         try:
+             from sklearn.cluster import DBSCAN as SklearnDBSCAN
+         except ImportError:
+             print("Scikit-learn not found. Returning original masks.")
+             return (
+                 current_masks,
+                 current_scores,
+                 current_classes,
+                 torch.arange(num_queries, device=device),
+             )
+
+         for i in range(num_queries):
+             mask = masks_cpu[i]
+
+             if not mask.any():
+                 continue
+
+             points = coords_cpu[mask]
+
+             # Check per-mask size limit
+             if points.shape[0] > 100000:
+                 # Skip DBSCAN for this mask, keep original
+                 print(
+                     f"DBSCAN (CPU): Skipping mask {i} due to large point cloud ({points.shape[0]} points > 100k)"
+                 )
+                 new_masks_list.append(current_masks[i])
+                 new_indices_list.append(i)
+                 continue
+
+             if points.shape[0] < min_samples:
+                 # Keep original
+                 print(
+                     f"DBSCAN (CPU): Skipping mask {i} due to small point cloud ({points.shape[0]} points < {min_samples})"
+                 )
+                 new_masks_list.append(current_masks[i])
+                 new_indices_list.append(i)
+                 continue
+
+             try:
+                 # Ensure float32 for sklearn
+                 start_time = time.time()
+                 clusterer = SklearnDBSCAN(eps=eps, min_samples=min_samples)
+                 labels = clusterer.fit_predict(points.astype(np.float32))
+                 db_time = time.time() - start_time
+                 unique_labels = np.unique(labels)
+                 print(
+                     f"DBSCAN (CPU): Processing {points.shape[0]} points took {db_time:.4f} seconds, found {len(unique_labels)} clusters"
+                 )
+
+                 # Indices of this mask's points in the full point cloud.
+                 mask_indices_cpu = np.nonzero(mask)[0]
+
+                 for label in unique_labels:
+                     if label == -1:
+                         # Noise points are dropped; a mask that is all noise emits nothing.
+                         continue
+
+                     # Build each split mask on the CPU, then move it to `device` so the
+                     # list holds the same tensor type as the cuML path.
+                     new_mask_cpu = np.zeros_like(mask)  # bool
+
+                     local_mask = labels == label
+                     active_indices = mask_indices_cpu[local_mask]
+                     new_mask_cpu[active_indices] = True
+
+                     new_masks_list.append(
+                         torch.from_numpy(new_mask_cpu).to(device, dtype=torch.bool)
+                     )
+                     new_indices_list.append(i)
+
+             except Exception as e:
+                 print(f"DBSCAN (CPU) Error Query {i}: {e}")
+                 new_masks_list.append(current_masks[i])
+                 new_indices_list.append(i)
367
+
+     # 3. Assemble Results
+     if len(new_masks_list) == 0:
+         return (
+             torch.zeros((0, current_masks.shape[1]), device=device, dtype=torch.bool),
+             torch.zeros((0,), device=device, dtype=current_scores.dtype),
+             torch.zeros((0,), device=device, dtype=current_classes.dtype),
+             torch.zeros((0,), device=device, dtype=torch.long),
+         )
+
+     final_masks = torch.stack(new_masks_list)
+
+     # Gather scores and classes using indices
+     indices_tensor = torch.tensor(new_indices_list, device=device, dtype=torch.long)
+     final_scores = current_scores[indices_tensor]
+     final_classes = current_classes[indices_tensor]
+
+     return final_masks, final_scores, final_classes, indices_tensor
385
+
386
+
387
+ def apply_post_processing(
388
+ pred_masks: torch.Tensor,
389
+ pred_logits: torch.Tensor,
390
+ mask_threshold: float = 0.0,
391
+ point_coords: Optional[torch.Tensor] = None,
392
+ pp_cfg: Optional[Dict] = None,
393
+ pred_iou: Optional[torch.Tensor] = None,
394
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
395
+ """
396
+ Applies configured post-processing filters.
397
+
398
+ Args:
399
+ pred_masks: [Q, N] mask logits
400
+ pred_logits: [Q, 2] class logits (objectness is class 0)
401
+ mask_threshold: Threshold for mask binarization (usually 0.0 for logits)
+ point_coords: Optional per-point coordinates used by DBSCAN; DBSCAN is skipped when None
402
+ pred_iou: Optional [Q] learned IoU logits from SpaceFormer's IoU head.
403
+ When provided, `sigmoid(pred_iou)` replaces the hand-coded
404
+ `mask_quality = (sigmoid(masks) * binary).sum / binary.sum` proxy in
405
+ the score = obj * quality formula. DBSCAN expansion copies the same
406
+ scalar to every component of an expanded query.
407
+ pp_cfg: Post-processing configuration dict with keys:
408
+ - objectness_thresh: float (default 0.0, disabled)
409
+ - min_mask_points: int (default 0, disabled)
410
+ - use_stability_score: bool (default False)
411
+ - stability_score_thresh: float (default 0.9)
412
+ - stability_score_offset: float (default 1.0)
413
+ - stability_score_thresh: float (default 0.9)
414
+ - stability_score_offset: float (default 1.0)
415
+ - use_nms: bool (default False)
416
+ - nms_thresh: float (default 0.7)
417
+ - use_dbscan: bool (default False)
418
+ - dbscan_eps: float (default 0.95)
419
+ - dbscan_min_points: int (default 1)
420
+ - dbscan_backend: str (default "auto")
421
+
422
+ Returns:
423
+ final_masks: [Q', N] final binary masks
424
+ final_scores: [Q'] final scores
425
+ final_classes: [Q'] final classes
426
+ final_indices: [Q'] indices mapping to original queries
427
+ """
428
+ if pp_cfg is None:
429
+ pp_cfg = {}
430
+
431
+ # Basic preparation
432
+ masks_binary = pred_masks > mask_threshold
433
+
434
+ # 0. Min Point Count Filtering (FIRST STEP - early rejection)
435
+ # Filter out small masks before expensive operations like DBSCAN
436
+ keep = torch.arange(pred_masks.shape[0], device=pred_masks.device)
437
+
438
+ if pp_cfg.get("min_mask_points", 0) > 0:
439
+ counts = masks_binary.float().sum(1)
440
+ keep_size = counts >= pp_cfg["min_mask_points"]
441
+ keep = keep[keep_size]
442
+
443
+ if len(keep) == 0:
444
+ return (
445
+ torch.zeros((0, pred_masks.shape[1]), device=pred_masks.device, dtype=torch.bool),
446
+ torch.zeros((0,), device=pred_masks.device, dtype=pred_masks.dtype),
447
+ torch.zeros((0,), device=pred_masks.device, dtype=torch.long),
448
+ torch.zeros((0,), device=pred_masks.device, dtype=torch.long),
449
+ )
450
+
451
+ # Filter all inputs
452
+ masks_binary = masks_binary[keep]
453
+ pred_masks = pred_masks[keep]
454
+ pred_logits = pred_logits[keep]
455
+ if pred_iou is not None:
456
+ pred_iou = pred_iou[keep]
457
+
458
+ # 1. DBSCAN Expansion
459
+ # If DBSCAN is used, we expand masks immediately.
460
+ # We maintain a mapping to original logits to allow stability calculation later.
461
+
462
+ current_masks = masks_binary
463
+ current_logits = pred_masks
464
+ current_pred_logits = pred_logits
465
+
466
+ # Track the original query index of every mask that survived min_mask_points filtering
467
+ current_indices = keep.clone()
468
+
469
+ # Objectness component
470
+ # Class 0 of the [Q, 2] class logits is treated as objectness (see the docstring)
471
+ obj_probs = pred_logits.softmax(dim=-1)[:, 0]
472
+
473
+ # Mask quality component (IoU proxy) — learned if pred_iou is provided
474
+ # (P3-SAM-style IoU head), otherwise the hand-coded sigmoid-mean proxy.
475
+ if pred_iou is not None:
476
+ mask_quality = pred_iou.sigmoid()
477
+ else:
478
+ masks_sigmoid = pred_masks.sigmoid()
479
+ mask_quality = (masks_sigmoid * masks_binary.float()).sum(1) / (
480
+ masks_binary.float().sum(1) + 1e-6
481
+ )
482
+ scores = obj_probs * mask_quality
483
+ classes = torch.zeros_like(scores, dtype=torch.long) # class 0
484
+
485
+ if pp_cfg.get("use_dbscan", False) and point_coords is not None:
486
+ current_masks, scores, classes, dbscan_indices = apply_dbscan_clustering(
487
+ current_masks,
488
+ point_coords,
489
+ scores,
490
+ classes,
491
+ eps=pp_cfg.get("dbscan_eps", 0.95),
492
+ min_samples=pp_cfg.get("dbscan_min_points", 1),
493
+ backend=pp_cfg.get("dbscan_backend", "auto"),
494
+ )
495
+
496
+ # We need to map them back to original query indices
497
+ current_indices = keep[dbscan_indices]
498
+
499
+ # Expand logits and other properties to match split masks
500
+ # Use dbscan_indices (relative to current filtered set) for indexing current tensors
501
+ current_logits = current_logits[dbscan_indices]
502
+ current_pred_logits = current_pred_logits[dbscan_indices]
503
+ obj_probs = obj_probs[dbscan_indices]
504
+
505
+ # MASK THE LOGITS (Stability Fix)
506
+ # Key step: constrain the logits to the new binary mask shape
507
+ # so stability score is calculated on the component, not the whole original mask.
508
+ # We use a large negative value for background.
509
+ current_logits = torch.where(current_masks, current_logits, -100.0)
510
+
511
+ # Recalculate mask quality for the NEW masks. With learned IoU we copy
512
+ # the parent query's scalar to every expanded component (no per-component
513
+ # IoU prediction is available); without it, recompute the sigmoid-mean
514
+ # proxy from the masked logits.
515
+ if pred_iou is not None:
516
+ mask_quality = pred_iou[dbscan_indices].sigmoid()
517
+ else:
518
+ masks_sigmoid = current_logits.sigmoid()
519
+ mask_quality = (masks_sigmoid * current_masks.float()).sum(1) / (
520
+ current_masks.float().sum(1) + 1e-6
521
+ )
522
+ # Recalculate scores (Obj * Quality)
523
+ scores = obj_probs * mask_quality
524
+
525
+ # Now we have `current_masks` (binary) and `current_logits` (masked logits).
526
+ # All subsequent steps operate on these.
527
+
528
+ # 2. Objectness Filtering
529
+ keep = torch.arange(current_masks.shape[0], device=current_masks.device)
530
+
531
+ if pp_cfg.get("objectness_thresh", 0.0) > 0:
532
+ # obj_probs is aligned with current set
533
+ keep_obj = obj_probs > pp_cfg["objectness_thresh"]
534
+ keep = keep[keep_obj[keep]]
535
+
536
+ if len(keep) == 0:
537
+ return (
538
+ torch.zeros((0, pred_masks.shape[1]), device=pred_masks.device, dtype=torch.bool),
539
+ torch.zeros((0,), device=pred_masks.device, dtype=scores.dtype),
540
+ torch.zeros((0,), device=pred_masks.device, dtype=classes.dtype),
541
+ torch.zeros((0,), device=pred_masks.device, dtype=torch.long),
542
+ )
543
+
544
+ # 3. Stability Score
545
+ if pp_cfg.get("use_stability_score", False):
546
+ active_logits = current_logits[keep]
547
+ stability = calculate_stability_score(
548
+ active_logits,
549
+ mask_threshold,
550
+ pp_cfg.get("stability_score_offset", 1.0),
551
+ )
552
+ keep_stable = stability >= pp_cfg.get("stability_score_thresh", 0.9)
553
+ keep = keep[keep_stable]
554
+
555
+ if len(keep) == 0:
556
+ return (
557
+ torch.zeros((0, pred_masks.shape[1]), device=pred_masks.device, dtype=torch.bool),
558
+ torch.zeros((0,), device=pred_masks.device, dtype=scores.dtype),
559
+ torch.zeros((0,), device=pred_masks.device, dtype=classes.dtype),
560
+ torch.zeros((0,), device=pred_masks.device, dtype=torch.long),
561
+ )
562
+
563
+ # 4. NMS
564
+ if pp_cfg.get("use_nms", False):
565
+ active_masks = current_masks[keep]
566
+ active_scores = scores[keep]
567
+
568
+ keep_nms = apply_nms(active_masks, active_scores, pp_cfg.get("nms_thresh", 0.7))
569
+ keep = keep[keep_nms]
570
+
571
+ # Final gather
572
+ final_masks = current_masks[keep]
573
+ final_scores = scores[keep]
574
+ final_classes = classes[keep]
575
+ final_indices = current_indices[keep]
576
+
577
+ return final_masks, final_scores, final_classes, final_indices
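
The scoring rule used by `apply_post_processing` (score = objectness × mask quality) can be illustrated without PyTorch. The following is a minimal, dependency-free sketch for a single query, not the code from this commit: `query_score`, `sigmoid`, and `softmax` are hypothetical helper names, and plain Python lists stand in for tensors.

```python
import math
from typing import List, Optional


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def softmax(logits: List[float]) -> List[float]:
    # Subtract the maximum for numerical stability.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def query_score(
    mask_logits: List[float],
    class_logits: List[float],
    mask_threshold: float = 0.0,
    iou_logit: Optional[float] = None,
) -> float:
    """Score one query as objectness * mask quality (hypothetical helper)."""
    # Objectness: probability of class 0 among the class logits.
    obj_prob = softmax(class_logits)[0]
    if iou_logit is not None:
        # Learned quality: sigmoid of the IoU-head logit.
        quality = sigmoid(iou_logit)
    else:
        # Proxy quality: mean sigmoid over points inside the binary mask.
        inside = [v for v in mask_logits if v > mask_threshold]
        quality = sum(sigmoid(v) for v in inside) / (len(inside) + 1e-6)
    return obj_prob * quality


confident = query_score([10.0, 10.0, -10.0], [0.0, 0.0])
with_iou = query_score([10.0, 10.0, -10.0], [0.0, 0.0], iou_logit=0.0)
empty = query_score([-5.0, -5.0], [0.0, 0.0])
```

An empty mask scores zero under the proxy because the `1e-6` term keeps the denominator positive while the numerator is zero.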
demo/requirements.txt ADDED
@@ -0,0 +1,15 @@
1
+ # Demo/inference dependencies.
2
+ # WarpConvNet (with its compiled _C extension) must be installed separately — a
3
+ # pre-built wheel or built from source; it is environment-specific so not pinned here.
4
+ torch
5
+ einops
6
+ transformers
7
+ numpy
8
+ huggingface_hub
9
+ gradio
10
+ plotly
11
+ # optional point-cloud input formats
12
+ plyfile
13
+ # local viser demo (demo_viser.py): interactive 3D viewer + sample .ply I/O
14
+ viser
15
+ open3d
demo/text_encoder.py ADDED
@@ -0,0 +1,218 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import List, Tuple
4
+
5
+ import hashlib
6
+ import os
7
+ import abc
8
+ import numpy as np
9
+
10
+ import torch
11
+ from transformers import AutoTokenizer, AutoModel
12
+
13
+ # Assume that models are already cached
14
+ os.environ["HF_HUB_OFFLINE"] = "1"
15
+
16
+
17
+ # Use a deterministic hash function for strings
18
+ def string_hash(s: str) -> int:
19
+ return int(hashlib.md5(s.encode()).hexdigest(), 16)
20
+
21
+
22
+ class CLIPTextEncoderInterace(abc.ABC):
23
+ model: torch.nn.Module
24
+ CHANNEL_DIM: int
25
+
26
+ def __post_init__(self):
27
+ self.freeze_encoder()
28
+
29
+ def freeze_encoder(self):
30
+ for params in self.model.parameters():
31
+ params.requires_grad = False
32
+
33
+ @abc.abstractmethod
34
+ def __call__(self, list_of_texts: List[str], normalize: bool = True) -> torch.Tensor:
35
+ raise NotImplementedError
36
+
37
+ @torch.inference_mode()
38
+ def get_unique_text_embedding(
39
+ self,
40
+ list_of_texts: List[str] | List[List[str]],
41
+ normalize: bool = True,
42
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
43
+ """Get unique embeddings for a list of texts.
44
+
45
+ Args:
46
+ list_of_texts: List[str] | List[List[str]]
47
+ List of texts or list of list of texts to get unique embeddings for.
48
+ Total number of texts is N.
49
+
50
+ Returns:
51
+ embeddings: torch.Tensor, shape (M, D)
52
+ Unique embeddings for the list of texts.
53
+ from_unique_indices: torch.Tensor, shape (N,)
54
+ For each input text, the row of `embeddings` that holds its embedding, so `embeddings[from_unique_indices]` restores the original order.
55
+ to_unique_indices: torch.Tensor, shape (M,)
56
+ For each unique text, the index of its first occurrence in the flattened list.
57
+ """
58
+ # Flatten the list of texts
59
+ if isinstance(list_of_texts, list) and isinstance(list_of_texts[0], list):
60
+ # list of lists
61
+ list_of_texts = [item for sublist in list_of_texts for item in sublist]
62
+
63
+ # cchoy: Deduplicate texts by MD5 hash. Python's built-in `hash()` for strings is salted per process (see PYTHONHASHSEED), so it is not deterministic across runs.
64
+ flat_caption_hash = [string_hash(caption) for caption in list_of_texts]
65
+ _, to_unique_indices, from_unique_indices = np.unique(
66
+ flat_caption_hash, return_index=True, return_inverse=True
67
+ )
68
+
69
+ # Get unique texts
70
+ unique_texts = [list_of_texts[i] for i in to_unique_indices]
71
+
72
+ # Get embeddings
73
+ embeddings = self(unique_texts, normalize=normalize)
74
+
75
+ # Return embeddings and indices
76
+ return embeddings, torch.tensor(from_unique_indices), torch.tensor(to_unique_indices)
77
+
78
+
79
+ def get_text_encoder(
80
+ model_type: str,
81
+ device: str,
82
+ **kwargs,
83
+ ) -> CLIPTextEncoderInterace:
84
+ if model_type == "siglip2":
85
+ return Siglip2TextEncoder(device=device, **kwargs)
86
+ elif model_type == "openclip": # Recap CLIP is also openclip
87
+ return OpenCLIPTextEncoder(device=device, **kwargs)
88
+ else:
89
+ raise ValueError(f"Model type {model_type} not supported")
90
+
91
+
92
+ class OpenCLIPTextEncoder(CLIPTextEncoderInterace):
93
+ CHANNEL_DIM = None
94
+
95
+ def __init__(
96
+ self,
97
+ model_id: str,
98
+ device: str = "cuda",
99
+ torch_dtype: torch.dtype = torch.bfloat16,
100
+ context_length: int = None,
101
+ **kwargs,
102
+ ):
103
+ # This is not a required dependency, so we import it here
104
+ try:
105
+ from open_clip import create_model_from_pretrained, get_tokenizer
106
+ except ImportError:
107
+ raise ImportError(
108
+ "open_clip is not installed. Please install it with `pip install open_clip_torch`"
109
+ )
110
+ self.prepare_data(model_id)
111
+
112
+ self.tokenizer = get_tokenizer(model_id)
113
+ precision = {torch.float32: "fp32", torch.float16: "fp16", torch.bfloat16: "bf16"}[torch_dtype]
114
+ self.model, _ = create_model_from_pretrained(
115
+ model_id,
116
+ device=device,
117
+ precision=precision,
118
+ )
119
+ self.device = device
120
+
121
+ # Set context_length: use provided value, or infer from model, or use default
122
+ if context_length is not None:
123
+ self.context_length = context_length
124
+ elif hasattr(self.model, "context_length"):
125
+ self.context_length = self.model.context_length
126
+ elif hasattr(self.model, "text") and hasattr(self.model.text, "context_length"):
127
+ self.context_length = self.model.text.context_length
128
+ else:
129
+ # Default to 77 for standard CLIP models
130
+ self.context_length = 77
131
+ print(
132
+ f"Warning: Could not infer context_length from model, using default: {self.context_length}"
133
+ )
134
+
135
+ def prepare_data(self, model_id: str):
136
+ from open_clip.factory import download_pretrained_from_hf
137
+
138
+ # Remove hf-hub: prefix if it exists
139
+ model_id = model_id[len("hf-hub:") :] if model_id.startswith("hf-hub:") else model_id
140
+ ckpt_path = download_pretrained_from_hf(
141
+ model_id, cache_dir=os.environ.get("HF_HUB_CACHE", os.path.expanduser("~/.cache/"))
142
+ )
143
+ return ckpt_path
144
+
145
+ @torch.inference_mode()
146
+ @torch.amp.autocast(enabled=True, device_type="cuda")
147
+ def __call__(self, list_of_texts: List[str], normalize: bool = True) -> torch.Tensor:
148
+ text_tokens = self.tokenizer(list_of_texts, context_length=self.context_length).to(
149
+ self.device
150
+ )
151
+ embeddings = self.model.encode_text(text_tokens)
152
+ if normalize:
153
+ embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
154
+ return embeddings
155
+
156
+
157
+ class Siglip2TextEncoder(CLIPTextEncoderInterace):
158
+ CHANNEL_DIM = 1152
159
+
160
+ def __init__(
161
+ self,
162
+ model_id: str = "google/siglip2-so400m-patch16-384",
163
+ device: str = "cuda",
164
+ attn_implementation: str = "flash_attention_2",
165
+ torch_dtype: torch.dtype = torch.bfloat16,
166
+ **kwargs,
167
+ ):
168
+ # Disable tokenizer parallelism
169
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
170
+
171
+ # Try loading from local cache first to avoid 429 errors
172
+ try:
173
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
174
+ self.model = AutoModel.from_pretrained(
175
+ model_id,
176
+ attn_implementation=attn_implementation,
177
+ torch_dtype=torch_dtype,
178
+ device_map=device,
179
+ local_files_only=True,
180
+ )
181
+ print(f"Successfully loaded {model_id} from local cache.")
182
+ except OSError:
183
+ print(
184
+ f"Model {model_id} not found locally. Downloading/Updating from Hugging Face Hub..."
185
+ )
186
+ # Fallback to downloading if not found locally
187
+ # This can still hit HTTP 429 when many ranks download at once, but it is the standard fallback.
188
+ # If the rate limiting persists in a multi-node setup, download on rank 0 only.
189
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id)
190
+ self.model = AutoModel.from_pretrained(
191
+ model_id,
192
+ attn_implementation=attn_implementation,
193
+ torch_dtype=torch_dtype,
194
+ device_map=device,
195
+ )
196
+
197
+ self.model.vision_model = None # Remove vision model
198
+ self.device = device
199
+
200
+ @torch.inference_mode()
201
+ @torch.amp.autocast(enabled=True, device_type="cuda")
202
+ def __call__(self, list_of_texts: List[str], normalize: bool = True) -> torch.Tensor:
203
+ # Length is 64 https://huggingface.co/docs/transformers/main/model_doc/siglip2
204
+ text_inputs = self.tokenizer(
205
+ list_of_texts,
206
+ padding="max_length",
207
+ truncation=True,
208
+ max_length=64,
209
+ return_tensors="pt",
210
+ ).to(self.device)
211
+ outputs = self.model.get_text_features(**text_inputs)
212
+ # In newer transformers, get_text_features may return a
213
+ # BaseModelOutputWithPooling instead of a plain tensor.
214
+ if not isinstance(outputs, torch.Tensor):
215
+ outputs = outputs.pooler_output
216
+ if normalize:
217
+ outputs = torch.nn.functional.normalize(outputs, dim=-1)
218
+ return outputs
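
The deduplication in `get_unique_text_embedding` relies on `np.unique` with `return_index=True` and `return_inverse=True`. The sketch below reproduces the same index bookkeeping in plain Python so the two index arrays are easier to follow. It is an illustration under assumptions rather than the code from this commit: `unique_texts_with_indices` is a hypothetical name, and lists stand in for tensors.

```python
import hashlib
from typing import Dict, List, Tuple


def string_hash(s: str) -> int:
    # Deterministic across processes, unlike the built-in hash() for str.
    return int(hashlib.md5(s.encode()).hexdigest(), 16)


def unique_texts_with_indices(
    texts: List[str],
) -> Tuple[List[str], List[int], List[int]]:
    """Return (unique_texts, from_unique, to_unique) (hypothetical helper)."""
    # Remember the first position at which each hash appears.
    first_seen: Dict[int, int] = {}
    for i, text in enumerate(texts):
        first_seen.setdefault(string_hash(text), i)
    # np.unique returns values in sorted order, so sort the hashes too.
    ordered = sorted(first_seen)
    row_of = {h: row for row, h in enumerate(ordered)}
    # to_unique: first occurrence of each unique text in the input list.
    to_unique = [first_seen[h] for h in ordered]
    # from_unique: for each input text, its row among the unique texts.
    from_unique = [row_of[string_hash(text)] for text in texts]
    unique = [texts[i] for i in to_unique]
    return unique, from_unique, to_unique


texts = ["chair", "table", "chair", "lamp"]
unique, from_unique, to_unique = unique_texts_with_indices(texts)
restored = [unique[i] for i in from_unique]
```

Only the unique texts need to be passed through the encoder; indexing the result with `from_unique` recovers one embedding per original text.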