Merge SpaceFormer demo (viser + CLI + Gradio) under demo/
- demo/README.md +112 -0
- demo/app.py +183 -0
- demo/clip_eval.py +91 -0
- demo/demo_viser.py +341 -0
- demo/inference.py +100 -0
- demo/labels.py +245 -0
- demo/pipeline.py +203 -0
- demo/postprocessing.py +577 -0
- demo/requirements.txt +15 -0
- demo/text_encoder.py +218 -0
demo/README.md
ADDED
@@ -0,0 +1,112 @@
---
title: SpaceFormer Open-Vocab 3D Instance Segmentation
emoji: 🧩
colorFrom: indigo
colorTo: green
sdk: gradio
sdk_version: 4.44.0
app_file: app.py
pinned: false
license: apache-2.0
tags:
  - 3d
  - point-cloud
  - instance-segmentation
  - open-vocabulary
---

# SpaceFormer: Open-Vocabulary 3D Instance Segmentation (demo)

Proposal-free **open-vocabulary 3D instance segmentation**. A Mask2Former-style query
decoder (learned queries + RoPE) sits on top of the WarpConvNet `SpaCeFormer` backbone: one
forward pass over an RGB point cloud produces query masks and per-query CLIP features,
which are labeled against text embeddings of **arbitrary** class names (SigLIP2, with
prompt ensembling). The vocabulary is chosen at inference time.
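
As a rough illustration of the prompt-ensembling step mentioned above, the sketch below renders each class name under several templates, L2-normalizes each embedding, averages across templates, and re-normalizes. `toy_encode` and the two templates are hypothetical stand-ins (the demo uses a SigLIP2 text encoder and the templates in `labels.py`), and the special handling of `"other"` in `clip_eval.py` is left out; only the averaging logic is meant to mirror the demo.

```python
import hashlib
import math


def l2_normalize(vec):
    """Scale a vector to unit length (left unchanged if it is all zeros)."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec] if norm > 0 else list(vec)


def toy_encode(text, dim=8):
    """Hypothetical stand-in for a text encoder: deterministic pseudo-embedding."""
    digest = hashlib.sha256(text.encode("utf-8")).digest()
    return [b / 255.0 - 0.5 for b in digest[:dim]]


def ensemble_text_embeddings(class_names, templates, encode=toy_encode):
    """Return one unit-length embedding per class, averaged over all templates."""
    ensembled = []
    for name in class_names:
        per_template = [l2_normalize(encode(t.format(name))) for t in templates]
        mean = [sum(col) / len(per_template) for col in zip(*per_template)]
        ensembled.append(l2_normalize(mean))
    return ensembled


# Illustrative templates only; the real list lives in labels.PROMPT_TEMPLATES.
TEMPLATES = ["a photo of a {}", "a {} in a scene"]
embeddings = ensemble_text_embeddings(["chair", "table"], TEMPLATES)
```

Averaging unit vectors and re-normalizing keeps every class embedding on the unit sphere, so a plain dot product with a query feature stays comparable across classes.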

Released checkpoint:

| Benchmark | mAP |
|---|---|
| ScanNet200 | 0.1265 |
| ScanNet++ | 0.2217 |
| Replica | 0.2644 |

This repo is the **demo / inference layer**. The model itself lives in WarpConvNet
(`warpconvnet.models.spaceformer`); this repo only adds the Gradio UI (`app.py`), a
local viser demo (`demo_viser.py`), and a CLI inference entry point (`inference.py`).

## Requirements

```bash
pip install -r requirements.txt
```

> **WarpConvNet must be installed with its compiled extension** (a pre-built wheel, or
> build from source). It is intentionally not pinned in `requirements.txt` because it is
> environment-specific. `transformers` pulls the SigLIP2 text encoder
> (`google/siglip2-so400m-patch14-224`) on first use.

## Live demo (Gradio / HuggingFace Space)

```bash
HF_REPO_ID=chrischoy/SpaCeFormer python app.py
# or a local checkpoint:
SPACEFORMER_CKPT=/path/to/spaceformer_512_siglip2_ssccc.ckpt python app.py
```

Upload a point cloud, type comma-separated class names, and get an interactive 3D view
colored by predicted instance plus a ranked table. To run it as a **HuggingFace Space**, create a
**GPU** Gradio Space, install WarpConvNet + `requirements.txt` in the image, and set the
Space variable `HF_REPO_ID` (and optionally `HF_FILENAME`, which defaults to
`spaceformer_512_siglip2_ssccc.ckpt`).
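
The checkpoint lookup behind these variables can be sketched as follows. The precedence (`SPACEFORMER_CKPT` first, then `HF_REPO_ID` plus `HF_FILENAME`) follows `app.py`; the helper name `describe_checkpoint_source` is hypothetical, and it only reports what would be loaded rather than downloading anything.

```python
import os

DEFAULT_FILENAME = "spaceformer_512_siglip2_ssccc.ckpt"


def describe_checkpoint_source(env=None):
    """Return ("local", path) or ("hub", repo_id, filename) for the given environment."""
    env = os.environ if env is None else env
    explicit = env.get("SPACEFORMER_CKPT")
    if explicit:
        # An explicit local path overrides any Hub settings.
        return ("local", explicit)
    repo_id = env.get("HF_REPO_ID", "")
    if not repo_id:
        raise RuntimeError("Set SPACEFORMER_CKPT or HF_REPO_ID.")
    return ("hub", repo_id, env.get("HF_FILENAME", DEFAULT_FILENAME))


print(describe_checkpoint_source({"HF_REPO_ID": "chrischoy/SpaCeFormer"}))
```

Passing a plain dict instead of `os.environ` makes the precedence easy to check without touching the real environment.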

## Local demo (viser)

An interactive, self-contained local demo that takes **text class names**, runs
segmentation, and visualizes the result in the browser with
[viser](https://viser.studio). Each predicted instance gets a distinct color,
unassigned points stay grey, and a GUI panel lists the top instances.

```bash
# auto-download the checkpoint + use a bundled sample point cloud
python demo_viser.py --port 8080

# your own cloud + vocabulary, local checkpoint
python demo_viser.py --ckpt /path/to/spaceformer_512_siglip2_ssccc.ckpt \
    --ply my_scene.ply --class-names chair table monitor wall floor

# full ScanNet200 label set
python demo_viser.py --ply my_scene.ply --use-scannet200
```

Then open the printed URL (default `http://localhost:8080`) in a browser.
With no `--ply`, the demo uses an open3d bundled sample cloud (or, if that is unavailable,
a synthesized random RGB cloud). A generic cloud won't segment meaningfully; it only
demonstrates that the pipeline + viewer run end to end. The demo colors the
model's **output** points (`out["backbone_pc"].coordinates`), which are what the
predicted masks index into after the model's internal voxelization, not the raw
`.ply` points, whose count may differ.
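
To see why the output point count can differ from the input, here is a toy voxel-grid deduplication at the 2 cm cell size this README quotes for the model's internal voxelization. This is a simplified assumption for illustration (keep the first point per occupied cell); WarpConvNet's internal voxelization may differ in detail, and `voxel_downsample` is a hypothetical helper.

```python
import math


def voxel_downsample(points, voxel_size=0.02):
    """Keep the first point that falls in each occupied voxel cell."""
    seen = {}
    for x, y, z in points:
        key = (math.floor(x / voxel_size),
               math.floor(y / voxel_size),
               math.floor(z / voxel_size))
        seen.setdefault(key, (x, y, z))
    return list(seen.values())


raw = [(0.001, 0.001, 0.001), (0.015, 0.004, 0.009),  # both in the same 2 cm cell
       (0.51, 0.51, 0.51), (1.25, 0.51, 0.03)]
kept = voxel_downsample(raw)
print(len(raw), "->", len(kept))
```

A boolean mask computed over `kept` has a different length than `raw`, which is why the viewer must color the model's output coordinates rather than the input ones.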

## CLI inference

```bash
# local checkpoint
python inference.py --ckpt /path/to/spaceformer_512_siglip2_ssccc.ckpt \
    --scene /path/to/scene_dir        # dir with coord.npy + color.npy

# or auto-download from a HuggingFace model repo
HF_REPO_ID=chrischoy/SpaCeFormer python inference.py \
    --scene my_scene.ply --class-names "office chair" "desk" "monitor" "other"

# full ScanNet200 label set
python inference.py --ckpt <ckpt> --scene <scene> --use-scannet200
```

`--scene` accepts a directory with `coord.npy` (`[N,3]` float, meters) + `color.npy` (`[N,3]`,
0–255), a `.npz` with `{coord, color}`, an `[N,6]` `.npy` (xyz, rgb), or a `.ply`. Coordinates
stay in **meters**; the model voxelizes internally at 2 cm. Output: a ranked list of
`{label, score, #points}`, where `score = objectness · mask_quality · class_prob`.
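
The ranking score above can be illustrated with a small worked example. The numbers are made up and `rank_instances` is a hypothetical helper; in the demo, the objectness and mask-quality terms come out of the post-processing step, and the class probability is the softmax maximum over the per-query class logits.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def rank_instances(candidates, class_names):
    """Score each candidate as objectness * mask_quality * class_prob, best first."""
    ranked = []
    for cand in candidates:
        probs = softmax(cand["class_logits"])
        class_id = max(range(len(probs)), key=probs.__getitem__)
        score = cand["objectness"] * cand["mask_quality"] * probs[class_id]
        ranked.append({"label": class_names[class_id], "score": score,
                       "num_points": cand["num_points"]})
    return sorted(ranked, key=lambda r: r["score"], reverse=True)


names = ["chair", "table", "other"]
candidates = [
    {"objectness": 0.6, "mask_quality": 0.9, "class_logits": [0.0, 3.0, 0.5], "num_points": 3400},
    {"objectness": 0.9, "mask_quality": 0.8, "class_logits": [2.0, 0.1, 0.0], "num_points": 1200},
]
ranked = rank_instances(candidates, names)
for row in ranked:
    print(f"{row['label']:>6}  {row['score']:.3f}  {row['num_points']}")
```

Because the three factors are multiplied, a confident class match cannot rescue a candidate with weak objectness or a poor mask, and vice versa.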

## License

Apache-2.0, matching the WarpConvNet `space_former.py` SPDX header.
demo/app.py
ADDED
@@ -0,0 +1,183 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""HuggingFace Space: SpaceFormer open-vocabulary 3D instance segmentation release.

Presentation/deployment layer. All model + inference logic is imported from the
installed ``warpconvnet`` library (``warpconvnet.models.spaceformer``); this file
only adds the Gradio UI, the 3D Plotly viewer, and checkpoint download.

Upload an RGB point cloud (.ply / [N,6] .npy / .npz), type comma-separated class
names, and get an interactive 3D view colored by predicted instance + a ranked
table of {label, score, #points}.

WarpConvNet (with its compiled extension) and transformers must be installed in
the Space image. Configure the checkpoint via Space variables:

    HF_REPO_ID        model repo holding the checkpoint (e.g. chrischoy/SpaCeFormer)
    HF_FILENAME       checkpoint filename (default: spaceformer_512_siglip2_ssccc.ckpt)
    SPACEFORMER_CKPT  explicit local checkpoint path (overrides the HF download)
"""

import os

import numpy as np
import torch

from warpconvnet.models.spaceformer import (
    build_spaceformer,
    load_spaceformer_checkpoint,
)
from labels import DEFAULT_CLASS_NAMES, PROMPT_TEMPLATES
from pipeline import (
    SIGLIP_MODEL_ID,
    load_scene,
    make_batch,
    predict_instances,
)

HF_REPO_ID = os.environ.get("HF_REPO_ID", "")
HF_FILENAME = os.environ.get("HF_FILENAME", "spaceformer_512_siglip2_ssccc.ckpt")

_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_STATE = {"net": None, "clip": None}  # lazy singletons, kept resident across requests


def _resolve_ckpt() -> str:
    explicit = os.environ.get("SPACEFORMER_CKPT")
    if explicit:
        return explicit
    if not HF_REPO_ID:
        raise RuntimeError(
            "Set SPACEFORMER_CKPT to a local checkpoint, or HF_REPO_ID to a "
            "HuggingFace model repo to auto-download from."
        )
    from huggingface_hub import hf_hub_download
    return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME)


def _get_model():
    if _STATE["net"] is None:
        net = build_spaceformer(device=_DEVICE)
        load_spaceformer_checkpoint(net, _resolve_ckpt())
        _STATE["net"] = net
    return _STATE["net"]


def _get_clip_encoder():
    if _STATE["clip"] is None:
        from text_encoder import get_text_encoder
        _STATE["clip"] = get_text_encoder(
            model_type="siglip2", model_id=SIGLIP_MODEL_ID, device=str(_DEVICE)
        )
    return _STATE["clip"]


def _text_eval(class_names):
    from clip_eval import CLIPAlignmentEval
    evaluator = CLIPAlignmentEval(normalize_input=False)
    evaluator.prepare_target_embedding(
        class_names=list(class_names),
        clip_encoder=_get_clip_encoder(),
        device=_DEVICE,
        prompt_templates=list(PROMPT_TEMPLATES),
    )
    return evaluator


def _palette(n):
    rng = np.random.default_rng(0)
    return [tuple(int(x) for x in rng.integers(40, 230, size=3)) for _ in range(max(n, 1))]


def _plot(coord_np, results, top_k, score_thresh):
    """Plotly 3D scatter colored by instance (top_k by score)."""
    import plotly.graph_objects as go

    kept = [r for r in results if r["score"] >= score_thresh][:top_k]
    rgb = np.full((coord_np.shape[0], 3), 160, dtype=np.uint8)  # grey background
    palette = _palette(len(kept))
    for i, r in enumerate(kept):
        rgb[r["mask"]] = palette[i]
    colors = [f"rgb({c[0]},{c[1]},{c[2]})" for c in rgb]

    # Subsample for browser responsiveness.
    n = coord_np.shape[0]
    if n > 120_000:
        idx = np.random.default_rng(0).choice(n, 120_000, replace=False)
    else:
        idx = np.arange(n)

    fig = go.Figure(
        data=[go.Scatter3d(
            x=coord_np[idx, 0], y=coord_np[idx, 1], z=coord_np[idx, 2],
            mode="markers",
            marker=dict(size=1.5, color=[colors[i] for i in idx]),
        )]
    )
    fig.update_layout(
        scene=dict(aspectmode="data"),
        margin=dict(l=0, r=0, t=0, b=0),
        showlegend=False,
    )
    return fig


def segment(scene_file, class_text, top_k, score_thresh):
    """Gradio callback: file + class names -> (3D figure, results table)."""
    if scene_file is None:
        return None, [["(upload a point cloud first)", "", ""]]

    class_names = [c.strip() for c in class_text.split(",") if c.strip()] \
        or list(DEFAULT_CLASS_NAMES)

    path = scene_file.name if hasattr(scene_file, "name") else scene_file
    coord_np, color_np = load_scene(path)
    batch = make_batch(coord_np, color_np, _DEVICE)

    net = _get_model()
    results = predict_instances(net, batch, _text_eval(class_names), class_names)

    fig = _plot(coord_np, results, int(top_k), float(score_thresh))
    # Apply the same score threshold + top-k cut as the plot so both views agree.
    kept = [r for r in results if r["score"] >= float(score_thresh)][: int(top_k)]
    table = [
        [r["label"], f"{r['score']:.3f}", int(r["mask"].sum())]
        for r in kept
    ]
    if not table:
        table = [["(no instances above threshold)", "", ""]]
    return fig, table


def build_interface():
    import gradio as gr

    with gr.Blocks(title="SpaceFormer — Open-Vocab 3D Instance Segmentation") as demo:
        gr.Markdown(
            "# SpaceFormer\n"
            "Proposal-free **open-vocabulary 3D instance segmentation**. Upload an "
            "RGB point cloud, type any class names, and get instance masks labeled "
            "against your vocabulary (SigLIP2 text + prompt ensembling).\n\n"
            "Released checkpoint: ScanNet200 **0.1265** / ScanNet++ 0.2217 / Replica 0.2644."
        )
        with gr.Row():
            with gr.Column(scale=1):
                scene_file = gr.File(label="Point cloud (.ply / .npy[N,6] / .npz)")
                class_text = gr.Textbox(
                    label="Class names (comma-separated)",
                    value=", ".join(DEFAULT_CLASS_NAMES),
                )
                top_k = gr.Slider(1, 100, value=30, step=1, label="Max instances shown")
                score_thresh = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="Score threshold")
                run = gr.Button("Segment", variant="primary")
            with gr.Column(scale=2):
                plot = gr.Plot(label="Predicted instances (colored)")
                table = gr.Dataframe(
                    headers=["label", "score", "#points"],
                    label="Instances",
                    wrap=True,
                )
        run.click(segment, [scene_file, class_text, top_k, score_thresh], [plot, table])
    return demo


if __name__ == "__main__":
    build_interface().launch(server_name="0.0.0.0")
demo/clip_eval.py
ADDED
@@ -0,0 +1,91 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Open-vocabulary CLIP alignment for inference.

Self-contained extract of the training repo's ``CLIPAlignmentEval``: encode class
names into a text-embedding matrix (optionally with prompt ensembling) and score
per-query CLIP features against it via cosine similarity.
"""

import logging
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

log = logging.getLogger(__name__)


class CLIPAlignmentEval(nn.Module):
    """Cosine-similarity classifier between per-query CLIP features and text embeddings.

    Args:
        normalize_input: L2-normalize the query features before the cosine product.
            For SpaceFormer set this to ``False``: the clip head output is
            compared directly (matches the official eval recipe).
    """

    def __init__(self, normalize_input: bool = False):
        super().__init__()
        self.normalize_input = normalize_input
        self.emb_target: Optional[torch.Tensor] = None  # [C, D] L2-normalized

    def set_target_embedding(self, text_embeddings: torch.Tensor) -> None:
        self.emb_target = text_embeddings.float()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.normalize_input:
            return F.normalize(x, p=2, dim=1)
        return x

    def predict(self, x: torch.Tensor, return_logit: bool = False) -> torch.Tensor:
        """Score features ``x`` [Q, D] against the text embeddings -> [Q, C]."""
        assert self.emb_target is not None, "call prepare_target_embedding() first"
        pred = self.forward(x)
        logit = torch.matmul(pred, self.emb_target.t().to(pred.dtype))
        if return_logit:
            return logit
        return logit.argmax(dim=1)

    @torch.inference_mode()
    def prepare_target_embedding(
        self,
        class_names: List[str],
        clip_encoder: nn.Module,
        device: torch.device,
        use_prompt: bool = False,
        prompt_template: Optional[str] = None,
        prompt_templates: Optional[List[str]] = None,
    ) -> None:
        """Encode ``class_names`` into the [C, D] target matrix.

        Three mutually exclusive prompting modes (first non-empty wins):
          - ``prompt_templates``: prompt ensembling. Render each class under every
            ``"... {} ..."`` template, per-row L2-normalize, mean, re-normalize.
            This is the recommended mode, a free win at eval time.
          - ``prompt_template``: a single ``"... {c} ..."`` format string.
          - ``use_prompt``: the OpenScene default ``"a {} in a scene"``.
        If none is set, the bare class names are encoded.
        In the ``prompt_templates`` and ``use_prompt`` modes, any class name
        containing ``"other"`` is encoded as the bare token ``"other"`` (no
        template) so the background/void class stays neutral.
        """
        log.info("Preparing CLIP target embedding for %d classes", len(class_names))
        if prompt_templates:
            log.info("Prompt ensembling over %d templates", len(prompt_templates))
            ensembled = None
            for template in prompt_templates:
                rendered = [
                    template.format(c) if "other" not in c else "other" for c in class_names
                ]
                emb = F.normalize(clip_encoder(rendered, normalize=True).float(), p=2, dim=-1)
                ensembled = emb if ensembled is None else ensembled + emb
            text_embedding = F.normalize(ensembled / float(len(prompt_templates)), p=2, dim=-1)
        elif prompt_template is not None:
            rendered = [prompt_template.format(c=c) for c in class_names]
            text_embedding = clip_encoder(rendered, normalize=True)
        elif use_prompt:
            rendered = [f"a {c} in a scene" if "other" not in c else "other" for c in class_names]
            text_embedding = clip_encoder(rendered, normalize=True)
        else:
            text_embedding = clip_encoder(class_names, normalize=True)
        self.set_target_embedding(text_embedding.to(device))
demo/demo_viser.py
ADDED
@@ -0,0 +1,341 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Local viser demo: SpaceFormer open-vocabulary 3D instance segmentation.

Takes TEXT class names, runs one forward pass of the released SpaceFormer over a
point cloud, and visualizes the predicted instances in a browser with
`viser <https://viser.studio>`_ (each kept instance a distinct color; unassigned
points grey). Model + pipeline come from the installed ``warpconvnet`` library
and the sibling ``pipeline.py`` / ``postprocessing.py`` in this repo; this file
only adds the forward + viser visualization glue.

    # local checkpoint, custom vocabulary
    python demo_viser.py --ckpt /path/to/spaceformer_512_siglip2_ssccc.ckpt \
        --ply my_scene.ply --class-names chair table monitor wall floor

    # auto-download the checkpoint from HuggingFace and use a bundled sample cloud
    python demo_viser.py --port 8080        # falls back to chrischoy/SpaCeFormer

Then open the printed URL (default http://localhost:8080) in a browser.

CRITICAL coord/mask alignment (see space_former_seg.py):
    The model voxelizes internally (PointToSparseWrapper), so its output masks are
    over the model's OUTPUT points, whose count may NOT equal the raw .ply point
    count. In eval mode the forward returns ``out["backbone_pc"]`` (a warpconvnet
    ``Points``) whose ``.coordinates`` correspond 1:1 with the per-point mask rows
    in ``out["mask"][0]``. We therefore run the forward pass HERE (mirroring
    pipeline.predict_instances) so we can pull the backbone_pc coordinates and the
    post-processed masks together, and color THOSE coordinates, never the raw
    input coords, which may differ in length/order.

NOTE: WarpConvNet (with its compiled ``_C`` extension) + transformers + a
checkpoint are required to actually run. ``--help`` and import of this file work
without viser/WarpConvNet (both are imported lazily inside ``main``).
"""

import argparse
import os
import tempfile
import time

import numpy as np
import torch

# Reused, task-specific glue from this repo (scene I/O, eval transforms, text
# embeddings). We deliberately do NOT reuse predict_instances() wholesale
# because we also need the aligned backbone_pc coordinates, so we inline the
# forward below and call apply_post_processing() directly.
from labels import CLASS_LABELS_200, DEFAULT_CLASS_NAMES
from pipeline import (
    POST_PROCESSING_CFG,
    build_text_embeddings,
    load_scene,
    make_batch,
)
from postprocessing import apply_post_processing

HF_REPO_ID = os.environ.get("HF_REPO_ID", "chrischoy/SpaCeFormer")
HF_FILENAME = os.environ.get("HF_FILENAME", "spaceformer_512_siglip2_ssccc.ckpt")


# --------------------------------------------------------------------------- #
# Checkpoint + sample-scene resolution
# --------------------------------------------------------------------------- #
def resolve_ckpt(ckpt_arg):
    """Return a local checkpoint path: --ckpt, $SPACEFORMER_CKPT, or HF download."""
    if ckpt_arg:
        return ckpt_arg
    explicit = os.environ.get("SPACEFORMER_CKPT")
    if explicit:
        return explicit
    from huggingface_hub import hf_hub_download

    print(f"[ckpt] downloading {HF_FILENAME} from HuggingFace repo {HF_REPO_ID} ...")
    return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME)


def resolve_sample_ply(ply_arg):
    """Return a path to a .ply. If --ply is unset, use a zero-config sample.

    Preference order (no fragile external URLs):
      1) open3d's bundled ``PLYPointCloud`` sample (a small RGB point cloud), or
      2) a synthesized random RGB point cloud written to a temp .ply.

    A random/sample cloud will NOT segment into meaningful instances; it only
    demonstrates that the pipeline + visualization run end to end.
    """
    if ply_arg:
        return ply_arg

    # 1) open3d bundled sample (downloaded/cached by open3d itself, offline after).
    try:
        import open3d as o3d

        sample_path = o3d.data.PLYPointCloud().path
        if os.path.isfile(sample_path):
            print(f"[sample] using open3d bundled PLYPointCloud sample: {sample_path}")
            print("[sample] NOTE: a generic sample cloud won't segment meaningfully; "
                  "it's only to demo the pipeline + viz.")
            return sample_path
    except Exception as exc:  # noqa: BLE001 - any open3d issue -> fall back to synthetic
        print(f"[sample] open3d sample unavailable ({exc}); synthesizing a random cloud.")

    # 2) Synthesize a small random RGB point cloud and write a temp .ply.
    rng = np.random.default_rng(0)
    n = 20_000
    coord = rng.uniform(-2.0, 2.0, size=(n, 3)).astype(np.float32)  # meters
    color = rng.integers(0, 256, size=(n, 3)).astype(np.uint8)
    tmp = tempfile.NamedTemporaryFile(suffix=".ply", delete=False)
    tmp.close()
    _write_ply(tmp.name, coord, color)
    print(f"[sample] wrote synthetic random RGB cloud ({n} pts) to {tmp.name}")
    print("[sample] NOTE: a random cloud won't segment meaningfully; it's only to "
          "demo the pipeline + viz.")
    return tmp.name


def _write_ply(path, coord, color):
    """Write an RGB .ply (ASCII via plyfile if available, otherwise via open3d)."""
    try:
        from plyfile import PlyData, PlyElement

        verts = np.empty(
            coord.shape[0],
            dtype=[("x", "f4"), ("y", "f4"), ("z", "f4"),
                   ("red", "u1"), ("green", "u1"), ("blue", "u1")],
        )
        verts["x"], verts["y"], verts["z"] = coord[:, 0], coord[:, 1], coord[:, 2]
        verts["red"], verts["green"], verts["blue"] = color[:, 0], color[:, 1], color[:, 2]
        PlyData([PlyElement.describe(verts, "vertex")], text=True).write(path)
        return
    except ImportError:
        pass
    # Fallback when plyfile is not installed.
    import open3d as o3d

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(coord.astype(np.float64))
    pcd.colors = o3d.utility.Vector3dVector(color.astype(np.float64) / 255.0)
    o3d.io.write_point_cloud(path, pcd)


# --------------------------------------------------------------------------- #
# Forward + post-processing (aligned to backbone_pc coordinates)
# --------------------------------------------------------------------------- #
@torch.inference_mode()
def segment_aligned(net, batch, text_eval, class_names):
    """One forward pass -> post-processed instances aligned to backbone_pc coords.

    Mirrors ``pipeline.predict_instances`` but ALSO returns the model's output
    coordinates (``out["backbone_pc"].coordinates``), which are the coordinates
    the returned masks index into (see module docstring / space_former_seg.py).

    Returns ``(coords[M,3] float32, results)`` where each result is
    ``{mask: bool[M], label, label_id, score}`` and ``M`` is the backbone point
    count (== number of rows in ``out["mask"][0]``, i.e. the number of columns
    after the transpose below), NOT the raw .ply count.
    """
    out = net(batch)

    # backbone_pc is present only in eval; its coordinates align 1:1 with the
    # mask rows. batch size is 1 here, so every point belongs to sample 0.
    backbone_pc = out["backbone_pc"]
    coords = backbone_pc.coordinates.detach().cpu().numpy().astype(np.float32)  # [M, 3]

    binary_logits = out["logit"][0]  # [Q, 2] objectness over {fg, bg}
    mask_logits = out["mask"][0].T  # [M, Q] -> [Q, M]
    clip_feats = out["clip_feat"][0]  # [Q, D]
    pred_iou = out["pred_iou"][0] if "pred_iou" in out else None

    # Sanity: mask columns must match the backbone point count we will color.
    assert mask_logits.shape[1] == coords.shape[0], (
        f"mask columns {mask_logits.shape[1]} != backbone points {coords.shape[0]}; "
        "coord/mask alignment broken"
    )

    class_logits = text_eval.predict(clip_feats, return_logit=True)  # [Q, C]

    masks, scores, _classes, indices = apply_post_processing(
        mask_logits,
        binary_logits,
        mask_threshold=0.0,
        point_coords=None,
        pp_cfg=POST_PROCESSING_CFG,
        pred_iou=pred_iou,
    )

    results = []
    if len(indices) > 0:
        probs = torch.softmax(class_logits[indices], dim=-1)  # [K, C]
        class_probs, class_ids = probs.max(dim=1)
        final_scores = scores * class_probs
        for k in range(len(indices)):
            results.append(
                {
                    "mask": masks[k].cpu().numpy().astype(bool),  # bool[M], aligned to coords
                    "label": class_names[int(class_ids[k])],
                    "label_id": int(class_ids[k]),
                    "score": float(final_scores[k]),
                }
            )
        results.sort(key=lambda r: r["score"], reverse=True)
    return coords, results
| 201 |
+
|
| 202 |
+
|
# --------------------------------------------------------------------------- #
# Coloring + viser visualization
# --------------------------------------------------------------------------- #
def _instance_colors(coords, results, top_k, score_thresh):
    """Grey base cloud + a distinct random color per kept (top-k, thresholded) instance.

    Returns ``(rgb[M,3] uint8, kept)`` where ``kept`` is the list of instances
    actually colored (already sorted by score, high first).
    """
    rgb = np.full((coords.shape[0], 3), 160, dtype=np.uint8)  # grey background
    kept = [r for r in results if r["score"] >= score_thresh][:top_k]
    rng = np.random.default_rng(0)
    for r in kept:
        color = rng.integers(40, 230, size=3).astype(np.uint8)
        r["color"] = tuple(int(c) for c in color)  # stash for GUI legend
        rgb[r["mask"]] = color
    return rgb, kept


def visualize(coords, results, port, top_k, score_thresh, point_size):
    """Start a viser server, add the colored point cloud + a GUI legend, block."""
    import viser  # imported lazily so --help works without viser installed

    rgb, kept = _instance_colors(coords, results, top_k, score_thresh)

    server = viser.ViserServer(port=port)

    # Point cloud: positions from backbone_pc coords, colors per instance.
    server.scene.add_point_cloud(
        name="/scene",
        points=coords.astype(np.float32),
        colors=rgb,  # uint8 [M, 3]
        point_size=point_size,
    )

    # A small GUI panel listing the kept instances (label + score + #points).
    with server.gui.add_folder(f"Top {len(kept)} instances"):
        if not kept:
            server.gui.add_markdown("_(no instances above the score threshold)_")
        for i, r in enumerate(kept):
            c = r["color"]
            swatch = f'<span style="color:rgb({c[0]},{c[1]},{c[2]})">■</span>'
            server.gui.add_markdown(
                f"{swatch} **{r['label']}** — score {r['score']:.3f}, "
                f"{int(r['mask'].sum())} pts"
            )

    url = f"http://localhost:{port}"
    print(f"\n[viser] serving at {url} (open this URL in your browser)")
    print("[viser] press Ctrl-C to stop.")
    try:
        while True:
            time.sleep(2.0)
    except KeyboardInterrupt:
        print("\n[viser] shutting down.")


def print_summary(results, top_k):
    """Ranked text summary of instances (label, score, #points) to stdout."""
    print(f"\n=== {len(results)} predicted instances ===")
    header = f"{'#':>3} {'score':>6} {'points':>8} label"
    print(header)
    print("-" * len(header))
    for i, r in enumerate(results[:top_k]):
        print(f"{i:>3} {r['score']:>6.3f} {int(r['mask'].sum()):>8} {r['label']}")
    if len(results) > top_k:
        print(f"... ({len(results) - top_k} more)")
# --------------------------------------------------------------------------- #
# Entry point (linear top-to-bottom)
# --------------------------------------------------------------------------- #
def main():
    ap = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    ap.add_argument("--ply", default=None,
                    help="input point cloud .ply (default: an open3d bundled sample "
                         "or a synthesized random cloud)")
    ap.add_argument("--class-names", nargs="+", default=None,
                    help="open-vocab class names (default: a short built-in list)")
    ap.add_argument("--use-scannet200", action="store_true",
                    help="use the full ScanNet200 label set as class names")
    ap.add_argument("--ckpt", default=None,
                    help="checkpoint path (else $SPACEFORMER_CKPT, else HF download)")
    ap.add_argument("--iou-head", action="store_true",
                    help="build the learned IoU head (only for a checkpoint that has one)")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--port", type=int, default=8080, help="viser server port")
    ap.add_argument("--score-thresh", type=float, default=0.0,
                    help="only color instances with score >= this")
    ap.add_argument("--top-k", type=int, default=30,
                    help="max number of instances to color / list")
    ap.add_argument("--point-size", type=float, default=0.02,
                    help="viser point size (world units / meters)")
    args = ap.parse_args()

    device = torch.device(args.device)

    # Vocabulary (text class names, i.e. the open-vocabulary input).
    if args.use_scannet200:
        class_names = list(CLASS_LABELS_200)
    else:
        class_names = args.class_names or list(DEFAULT_CLASS_NAMES)
    print(f"[vocab] {len(class_names)} classes: {', '.join(class_names[:8])}"
          f"{' ...' if len(class_names) > 8 else ''}")

    # 1) Resolve/load the scene (.ply -> coord[N,3] meters, color[N,3] 0-255).
    ply_path = resolve_sample_ply(args.ply)
    coord_np, color_np = load_scene(ply_path)

    # 2) Build the eval batch (CenterShift + NormalizeColor + offsets).
    batch = make_batch(coord_np, color_np, device)

    # 3) Build model + load released weights (WarpConvNet required here).
    from warpconvnet.models.spaceformer import (
        build_spaceformer,
        load_spaceformer_checkpoint,
    )

    net = build_spaceformer(use_iou_head=args.iou_head, device=device)
    missing, unexpected = load_spaceformer_checkpoint(net, resolve_ckpt(args.ckpt))
    print(f"[weights] {len(missing)} missing, {len(unexpected)} unexpected")

    # 4) Encode the text class names (SigLIP2 + prompt ensembling).
    text_eval = build_text_embeddings(class_names, device)

    # 5) Forward + post-process -> masks + labels + scores aligned to backbone coords.
    coords, results = segment_aligned(net, batch, text_eval, class_names)
    print(f"[model] backbone output points: {coords.shape[0]} "
          f"(raw input was {coord_np.shape[0]})")

    # 6) Text summary + interactive viser visualization.
    print_summary(results, args.top_k)
    visualize(coords, results, args.port, args.top_k, args.score_thresh, args.point_size)


if __name__ == "__main__":
    main()
|
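The scoring at the end of `segment_aligned` in demo/demo_viser.py multiplies each kept mask's post-processing score by its top softmax class probability and then sorts by the product. A minimal pure-Python sketch of that fusion on made-up numbers follows; the helpers `softmax` and `fuse_scores` and all values are hypothetical illustrations, not part of the demo, which does this with torch tensors.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def fuse_scores(mask_scores, class_logits, class_names):
    """Multiply each mask score by its top class probability; sort high first."""
    results = []
    for score, logits in zip(mask_scores, class_logits):
        probs = softmax(logits)
        label_id = max(range(len(probs)), key=probs.__getitem__)
        results.append(
            {
                "label": class_names[label_id],
                "label_id": label_id,
                "score": score * probs[label_id],
            }
        )
    results.sort(key=lambda r: r["score"], reverse=True)
    return results


# Made-up inputs: two kept masks, three candidate class names.
names = ["chair", "table", "lamp"]
ranked = fuse_scores(
    mask_scores=[0.9, 0.6],
    class_logits=[[2.0, 0.0, 0.0], [0.0, 0.0, 5.0]],
    class_names=names,
)
for r in ranked:
    print(f"{r['label']:>6}  {r['score']:.3f}")
```

Because the class probability is at most 1, the fused score can never exceed the mask score, so a confident mask with an ambiguous label ranks below an equally confident mask with a clear label.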
demo/inference.py
ADDED
|
@@ -0,0 +1,100 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""SpaceFormer single-scene open-vocabulary 3D instance segmentation (CLI inference).

Self-contained command-line entry point for the released
SpaceFormer. The model + inference pipeline come from the installed ``warpconvnet``
library (``warpconvnet.models.spaceformer``); this script only wires up checkpoint
resolution (local path or HuggingFace auto-download), runs one scene, and
prints/saves the predicted instances.

    # weights from a local checkpoint
    python inference.py --ckpt /path/to/spaceformer_512_siglip2_ssccc.ckpt \
        --scene /path/to/scene_dir

    # or auto-download from a HuggingFace model repo
    HF_REPO_ID=chrischoy/SpaCeFormer python inference.py \
        --scene my_scene.ply --class-names "office chair" "desk" "monitor" "other"

Requires WarpConvNet (with its compiled extension) + transformers in the environment.
"""

import argparse
import os

import torch

from warpconvnet.models.spaceformer import (
    build_spaceformer,
    load_spaceformer_checkpoint,
)
from labels import CLASS_LABELS_200, DEFAULT_CLASS_NAMES
from pipeline import (
    build_text_embeddings,
    load_scene,
    make_batch,
    predict_instances,
    print_summary,
    save_results,
)

HF_FILENAME = os.environ.get("HF_FILENAME", "spaceformer_512_siglip2_ssccc.ckpt")


def resolve_ckpt(ckpt_arg: str | None) -> str:
    """Return a local checkpoint path: explicit --ckpt, $SPACEFORMER_CKPT, or HF download."""
    if ckpt_arg:
        return ckpt_arg
    explicit = os.environ.get("SPACEFORMER_CKPT")
    if explicit:
        return explicit
    repo_id = os.environ.get("HF_REPO_ID")
    if not repo_id:
        raise SystemExit(
            "No checkpoint: pass --ckpt, or set SPACEFORMER_CKPT (local path) or "
            "HF_REPO_ID (HuggingFace model repo to download from)."
        )
    from huggingface_hub import hf_hub_download
    return hf_hub_download(repo_id=repo_id, filename=HF_FILENAME)
def main() -> None:
    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--ckpt", default=None,
                    help="checkpoint path (else $SPACEFORMER_CKPT, else HF download via $HF_REPO_ID)")
    ap.add_argument("--scene", required=True,
                    help="scene dir (coord.npy+color.npy) or .npy/.npz/.ply file")
    ap.add_argument("--class-names", nargs="+", default=None,
                    help="open-vocab class names (default: a short built-in list)")
    ap.add_argument("--use-scannet200", action="store_true",
                    help="use the full ScanNet200 label set as class names")
    ap.add_argument("--iou-head", action="store_true",
                    help="build the learned IoU head (only for a checkpoint that has one)")
    ap.add_argument("--save", default=None, help="optional output .npz path")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = ap.parse_args()

    device = torch.device(args.device)
    if args.use_scannet200:
        class_names = list(CLASS_LABELS_200)
    else:
        class_names = args.class_names or list(DEFAULT_CLASS_NAMES)
    print(f"[vocab] {len(class_names)} classes")

    net = build_spaceformer(use_iou_head=args.iou_head, device=device)
    missing, unexpected = load_spaceformer_checkpoint(net, resolve_ckpt(args.ckpt))
    print(f"[weights] {len(missing)} missing, {len(unexpected)} unexpected")

    coord_np, color_np = load_scene(args.scene)
    batch = make_batch(coord_np, color_np, device)
    text_eval = build_text_embeddings(class_names, device)
    results = predict_instances(net, batch, text_eval, class_names)

    print_summary(results)
    if args.save:
        save_results(results, coord_np, args.save)


if __name__ == "__main__":
    main()
|
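The precedence that `resolve_ckpt` in demo/inference.py applies (explicit `--ckpt`, then `$SPACEFORMER_CKPT`, then a Hub download keyed by `$HF_REPO_ID` and `$HF_FILENAME`) can be exercised without touching the network. The sketch below is an illustration only: `pick_checkpoint_source` is a hypothetical helper that takes the environment as a plain dict and reports which source would win, and the paths are placeholders.

```python
DEFAULT_FILENAME = "spaceformer_512_siglip2_ssccc.ckpt"


def pick_checkpoint_source(ckpt_arg, env):
    """Return (kind, value) following --ckpt > SPACEFORMER_CKPT > HF_REPO_ID."""
    if ckpt_arg:
        return ("arg", ckpt_arg)
    if env.get("SPACEFORMER_CKPT"):
        return ("env", env["SPACEFORMER_CKPT"])
    if env.get("HF_REPO_ID"):
        filename = env.get("HF_FILENAME", DEFAULT_FILENAME)
        # The real script would call hf_hub_download here; this sketch only
        # reports what would be requested.
        return ("hub", (env["HF_REPO_ID"], filename))
    raise SystemExit("No checkpoint: pass --ckpt, or set SPACEFORMER_CKPT or HF_REPO_ID.")


# Placeholder values, purely for illustration.
env = {"SPACEFORMER_CKPT": "/tmp/local.ckpt", "HF_REPO_ID": "chrischoy/SpaCeFormer"}
print(pick_checkpoint_source("/tmp/cli.ckpt", env))
print(pick_checkpoint_source(None, env))
print(pick_checkpoint_source(None, {"HF_REPO_ID": "chrischoy/SpaCeFormer"}))
```

Passing the environment in as an argument, rather than reading `os.environ` inside the function, is a choice made here so the precedence can be checked in isolation.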
demo/labels.py
ADDED
|
@@ -0,0 +1,245 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Label sets and prompt templates for open-vocabulary SpaceFormer evaluation."""

# Prompt-ensembling templates (LEVER 1). Each must contain a single positional
# ``{}`` placeholder. Averaging class embeddings across these is a confirmed
# eval-time accuracy win for the released checkpoint.
PROMPT_TEMPLATES = (
    "a {} in a scene",
    "a photo of a {} in a scene",
    "a {}",
    "a photo of a {}",
    "there is a {} in the scene",
    "a 3d model of a {}",
)

# A short, readable default vocabulary for the demo. Pass your own class names to
# override, or use CLASS_LABELS_200 for the full ScanNet200 benchmark label set.
DEFAULT_CLASS_NAMES = (
    "wall",
    "floor",
    "ceiling",
    "chair",
    "table",
    "desk",
    "couch",
    "bed",
    "cabinet",
    "shelf",
    "door",
    "window",
    "monitor",
    "keyboard",
    "lamp",
    "picture",
    "whiteboard",
    "trash can",
    "backpack",
    "plant",
    "other",
)
# The 200 ScanNet200 instance/semantic class names (benchmark order).
CLASS_LABELS_200 = (
    "wall",
    "chair",
    "floor",
    "table",
    "door",
    "couch",
    "cabinet",
    "shelf",
    "desk",
    "office chair",
    "bed",
    "pillow",
    "sink",
    "picture",
    "window",
    "toilet",
    "bookshelf",
    "monitor",
    "curtain",
    "book",
    "armchair",
    "coffee table",
    "box",
    "refrigerator",
    "lamp",
    "kitchen cabinet",
    "towel",
    "clothes",
    "tv",
    "nightstand",
    "counter",
    "dresser",
    "stool",
    "cushion",
    "plant",
    "ceiling",
    "bathtub",
    "end table",
    "dining table",
    "keyboard",
    "bag",
    "backpack",
    "toilet paper",
    "printer",
    "tv stand",
    "whiteboard",
    "blanket",
    "shower curtain",
    "trash can",
    "closet",
    "stairs",
    "microwave",
    "stove",
    "shoe",
    "computer tower",
    "bottle",
    "bin",
    "ottoman",
    "bench",
    "board",
    "washing machine",
    "mirror",
    "copier",
    "basket",
    "sofa chair",
    "file cabinet",
    "fan",
    "laptop",
    "shower",
    "paper",
    "person",
    "paper towel dispenser",
    "oven",
    "blinds",
    "rack",
    "plate",
    "blackboard",
    "piano",
    "suitcase",
    "rail",
    "radiator",
    "recycling bin",
    "container",
    "wardrobe",
    "soap dispenser",
    "telephone",
    "bucket",
    "clock",
    "stand",
    "light",
    "laundry basket",
    "pipe",
    "clothes dryer",
    "guitar",
    "toilet paper holder",
    "seat",
    "speaker",
    "column",
    "bicycle",
    "ladder",
    "bathroom stall",
    "shower wall",
    "cup",
    "jacket",
    "storage bin",
    "coffee maker",
    "dishwasher",
    "paper towel roll",
    "machine",
    "mat",
    "windowsill",
    "bar",
    "toaster",
    "bulletin board",
    "ironing board",
    "fireplace",
    "soap dish",
    "kitchen counter",
    "doorframe",
    "toilet paper dispenser",
    "mini fridge",
    "fire extinguisher",
    "ball",
    "hat",
    "shower curtain rod",
    "water cooler",
    "paper cutter",
    "tray",
    "shower door",
    "pillar",
    "ledge",
    "toaster oven",
    "mouse",
    "toilet seat cover dispenser",
    "furniture",
    "cart",
    "storage container",
    "scale",
    "tissue box",
    "light switch",
    "crate",
    "power outlet",
    "decoration",
    "sign",
    "projector",
    "closet door",
    "vacuum cleaner",
    "candle",
    "plunger",
    "stuffed animal",
    "headphones",
    "dish rack",
    "broom",
    "guitar case",
    "range hood",
    "dustpan",
    "hair dryer",
    "water bottle",
    "handicap bar",
    "purse",
    "vent",
    "shower floor",
    "water pitcher",
    "mailbox",
    "bowl",
    "paper bag",
    "alarm clock",
    "music stand",
    "projector screen",
    "divider",
    "laundry detergent",
    "bathroom counter",
    "object",
    "bathroom vanity",
    "closet wall",
    "laundry hamper",
    "bathroom stall door",
    "ceiling light",
    "trash bin",
    "dumbbell",
    "stair rail",
    "tube",
    "bathroom cabinet",
    "cd case",
    "closet rod",
    "coffee kettle",
    "structure",
    "shower head",
    "keyboard piano",
    "case of water bottles",
    "coat rack",
    "storage organizer",
    "folded chair",
    "fire alarm",
    "power strip",
    "calendar",
    "poster",
    "potted plant",
    "luggage",
    "mattress",
)
|
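The comment on `PROMPT_TEMPLATES` in demo/labels.py describes prompt ensembling: each class name is substituted into every template and the resulting text embeddings are averaged. The sketch below illustrates that idea with a stand-in encoder. `toy_embed` is a hypothetical, non-semantic substitute for SigLIP2, and the order of normalizing and averaging is an assumption; the demo's `CLIPAlignmentEval` (not shown in this part of the diff) may differ.

```python
import math

# A subset of the templates listed in labels.py.
TEMPLATES = ("a {} in a scene", "a photo of a {}", "a 3d model of a {}")


def toy_embed(text, dim=8):
    """Stand-in for a text encoder: deterministic, NOT semantically meaningful."""
    vec = [0.0] * dim
    for i, ch in enumerate(text):
        vec[(i + ord(ch)) % dim] += 1.0
    return vec


def l2_normalize(vec):
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def ensemble_embedding(class_name, templates=TEMPLATES):
    """Embed one class under every template, then average and renormalize."""
    prompts = [t.format(class_name) for t in templates]
    embs = [l2_normalize(toy_embed(p)) for p in prompts]
    mean = [sum(col) / len(embs) for col in zip(*embs)]
    return prompts, l2_normalize(mean)


prompts, emb = ensemble_embedding("office chair")
print(prompts)
print(len(emb), round(sum(x * x for x in emb), 6))
```

Each template needs exactly one positional `{}` for `str.format` to work this way, which matches the constraint stated in the comment above the template tuple.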
demo/pipeline.py
ADDED
|
@@ -0,0 +1,203 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Open-vocabulary inference pipeline for the SpaceFormer release.

Task-specific glue that turns the raw model outputs (logit / mask / clip_feat,
from ``warpconvnet.models.spaceformer.SpaCeFormerInstSeg``) into labeled instances:
scene I/O + minimal eval transforms, SigLIP2 text embedding (prompt-ensembled),
and SAM2-style mask post-processing. Kept in the release repo (not WarpConvNet),
since labeling/post-processing is downstream of the model.
"""

import os

import numpy as np
import torch

from postprocessing import apply_post_processing
from labels import PROMPT_TEMPLATES

# SigLIP2 text encoder used at training time. text_dim 1152 == model clip-head dim.
SIGLIP_MODEL_ID = "google/siglip2-so400m-patch14-224"

# Voxel size the model was trained/exported at (the model voxelizes internally;
# this only labels the data_dict for parity).
GRID_SIZE = 0.02

# Release eval post-processing: NMS on, min 20 pts/mask, DBSCAN/stability off.
POST_PROCESSING_CFG = {
    "use_dbscan": False,
    "use_stability_score": False,
    "use_nms": True,
    "nms_thresh": 0.7,
    "min_mask_points": 20,
    "objectness_thresh": 0.0,
}
# --------------------------------------------------------------------------- #
# Scene loading + minimal eval transforms
# --------------------------------------------------------------------------- #
def load_scene(scene_path: str):
    """Load one scene as (coord[N,3] float meters, color[N,3] float 0-255)."""
    if os.path.isdir(scene_path):
        coord = np.load(os.path.join(scene_path, "coord.npy"))
        color = np.load(os.path.join(scene_path, "color.npy"))
    elif scene_path.endswith(".npz"):
        z = np.load(scene_path)
        coord = z[_first_key(z, ("coord", "coords", "xyz", "points"))]
        color = z[_first_key(z, ("color", "colors", "rgb"))]
    elif scene_path.endswith(".npy"):
        arr = np.load(scene_path)
        assert arr.ndim == 2 and arr.shape[1] >= 6, "expected [N,6] xyz+rgb .npy"
        coord, color = arr[:, :3], arr[:, 3:6]
    elif scene_path.endswith(".ply"):
        coord, color = _load_ply(scene_path)
    else:
        raise ValueError(f"unsupported scene format: {scene_path}")

    coord = np.ascontiguousarray(coord, dtype=np.float32)
    color = np.ascontiguousarray(color)
    if color.dtype != np.uint8 and color.max() <= 1.0 + 1e-6:
        color = (color * 255.0).round()
    color = color.astype(np.float32)
    print(f"[scene] {scene_path}: {coord.shape[0]} points")
    return coord, color


def _first_key(z, candidates):
    for c in candidates:
        if c in z:
            return c
    raise KeyError(f"none of {candidates} in {list(z.keys())}")


def _load_ply(path: str):
    try:
        from plyfile import PlyData

        v = PlyData.read(path)["vertex"]
        coord = np.stack([v["x"], v["y"], v["z"]], axis=1)
        if "red" in v.data.dtype.names:
            color = np.stack([v["red"], v["green"], v["blue"]], axis=1)
        else:
            color = np.full_like(coord, 127.5)
        return coord, color
    except ImportError:
        import open3d as o3d

        pcd = o3d.io.read_point_cloud(path)
        coord = np.asarray(pcd.points)
        color = np.asarray(pcd.colors) * 255.0 if pcd.has_colors() else np.full_like(coord, 127.5)
        return coord, color


def make_batch(coord_np, color_np, device):
    """Apply the minimal eval transforms and build the single-sample data dict.

    CenterShift(apply_z) + NormalizeColor(0-255 -> [-1,1]) + offset [0, N].
    Coords stay in meters; the model voxelizes internally at 2 cm.
    """
    coord = torch.from_numpy(coord_np).float()
    color = torch.from_numpy(color_np).float()

    cmin = coord.min(dim=0).values
    cmax = coord.max(dim=0).values
    shift = torch.tensor([(cmin[0] + cmax[0]) / 2, (cmin[1] + cmax[1]) / 2, cmin[2]])
    coord = coord - shift
    feat = color / 127.5 - 1.0
    n = coord.shape[0]
    offset = torch.tensor([0, n], dtype=torch.int32)
    return {
        "coord": coord.to(device),
        "feat": feat.to(device),
        "offset": offset.to(device),
        "grid_size": GRID_SIZE,
    }
# --------------------------------------------------------------------------- #
# Text embeddings (prompt-ensembled) + prediction
# --------------------------------------------------------------------------- #
def build_text_embeddings(class_names, device):
    """Encode class names with SigLIP2 under multiple templates and average."""
    from text_encoder import get_text_encoder
    from clip_eval import CLIPAlignmentEval

    clip_encoder = get_text_encoder(model_type="siglip2", model_id=SIGLIP_MODEL_ID, device=str(device))
    evaluator = CLIPAlignmentEval(normalize_input=False)  # matches official eval
    evaluator.prepare_target_embedding(
        class_names=list(class_names),
        clip_encoder=clip_encoder,
        device=device,
        prompt_templates=list(PROMPT_TEMPLATES),  # prompt ensembling ON
    )
    return evaluator


@torch.inference_mode()
def predict_instances(net, batch, text_eval, class_names):
    """Single forward pass -> post-processing -> open-vocab labels."""
    out = net(batch)
    binary_logits = out["logit"][0]  # [Q, 2] objectness over {fg, bg}
    mask_logits = out["mask"][0].T  # [N, Q] -> [Q, N]
    clip_feats = out["clip_feat"][0]  # [Q, D]
    pred_iou = out["pred_iou"][0] if "pred_iou" in out else None

    class_logits = text_eval.predict(clip_feats, return_logit=True)  # [Q, C]

    masks, scores, _classes, indices = apply_post_processing(
        mask_logits,
        binary_logits,
        mask_threshold=0.0,
        point_coords=None,
        pp_cfg=POST_PROCESSING_CFG,
        pred_iou=pred_iou,
    )

    results = []
    if len(indices) > 0:
        probs = torch.softmax(class_logits[indices], dim=-1)  # [K, C]
        class_probs, class_ids = probs.max(dim=1)
        final_scores = scores * class_probs
        for k in range(len(indices)):
            results.append(
                {
                    "mask": masks[k].cpu().numpy().astype(bool),
                    "label": class_names[int(class_ids[k])],
                    "label_id": int(class_ids[k]),
                    "score": float(final_scores[k]),
                }
            )
    results.sort(key=lambda r: r["score"], reverse=True)
    return results
# --------------------------------------------------------------------------- #
# Output helpers
# --------------------------------------------------------------------------- #
def print_summary(results, top_k=20):
    print(f"\n=== {len(results)} predicted instances ===")
    header = f"{'#':>3} {'score':>6} {'points':>8} label"
    print(header)
    print("-" * len(header))
    for i, r in enumerate(results[:top_k]):
        print(f"{i:>3} {r['score']:>6.3f} {int(r['mask'].sum()):>8} {r['label']}")
    if len(results) > top_k:
        print(f"... ({len(results) - top_k} more)")


def save_results(results, coord_np, out_path):
    # Caveat (see demo_viser.segment_aligned): masks index the model's output points,
    # whose count can differ from the raw input, so ``coord`` may not align 1:1.
    if not results:
        # Write the same keys as the non-empty case so loaders need no special-casing.
        np.savez(
            out_path,
            coord=coord_np,
            masks=np.zeros((0, coord_np.shape[0]), dtype=bool),
            labels=np.array([], dtype=str),
            label_ids=np.array([], dtype=np.int64),
            scores=np.array([], dtype=np.float64),
        )
        return
    np.savez(
        out_path,
        coord=coord_np,
        masks=np.stack([r["mask"] for r in results]),
        labels=np.array([r["label"] for r in results]),
        label_ids=np.array([r["label_id"] for r in results]),
        scores=np.array([r["score"] for r in results]),
    )
    print(f"[save] wrote {len(results)} instances to {out_path}")
|
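The two transforms in `make_batch` in demo/pipeline.py are small enough to check by hand: x and y are shifted to the bounding-box center, z is shifted so the lowest point sits at 0, and 0-255 colors are mapped to [-1, 1]. A pure-Python sketch on two made-up points follows; the helper names `center_shift` and `normalize_color` are hypothetical, and the demo performs the same arithmetic on torch tensors.

```python
def center_shift(points):
    """Shift x/y to the bounding-box center and z so the lowest point is at 0."""
    xs, ys, zs = zip(*points)
    shift = ((min(xs) + max(xs)) / 2, (min(ys) + max(ys)) / 2, min(zs))
    return [tuple(p[i] - shift[i] for i in range(3)) for p in points]


def normalize_color(colors):
    """Map 0-255 RGB values to the range [-1, 1]."""
    return [tuple(c / 127.5 - 1.0 for c in rgb) for rgb in colors]


# Made-up scene: two points with RGB colors.
points = [(1.0, 2.0, 0.5), (3.0, 6.0, 2.5)]
colors = [(0, 127.5, 255), (255, 255, 0)]

shifted = center_shift(points)
feats = normalize_color(colors)
offset = [0, len(points)]  # single-sample batch: one segment covering all points

print(shifted)  # [(-1.0, -2.0, 0.0), (1.0, 2.0, 2.0)]
print(feats)    # [(-1.0, 0.0, 1.0), (1.0, 1.0, -1.0)]
print(offset)   # [0, 2]
```

Anchoring z at the minimum rather than the midpoint keeps the floor of an indoor scan near height 0, which is presumably what the `apply_z` note in the `make_batch` docstring refers to.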
demo/postprocessing.py
ADDED
|
@@ -0,0 +1,577 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
SAM2-style post-processing utilities for mask segmentation.

This module provides shared post-processing functions used by both the
MaskLanguageLitModule (validation/testing) and the demo script.
"""

from typing import Tuple, Optional, Dict
import time
import numpy as np

import torch

try:
    from cuml.cluster import DBSCAN
except ImportError:
    DBSCAN = None


def calculate_stability_score(
    masks: torch.Tensor,
    mask_threshold: float = 0.0,
    threshold_offset: float = 1.0,
) -> torch.Tensor:
    """
    Computes the stability score for a set of masks.

    The stability score is the IoU between the binary masks obtained by
    thresholding at (mask_threshold + threshold_offset) and
    (mask_threshold - threshold_offset).

    High stability means sharp mask boundaries: the mask barely changes
    when the threshold moves.

    Args:
        masks: [Q, N] mask logits
        mask_threshold: Base threshold (usually 0.0 for logits)
        threshold_offset: Offset to apply for high/low thresholds

    Returns:
        stability_score: [Q] stability score per mask
    """
    high_thresh_mask = masks > (mask_threshold + threshold_offset)
    low_thresh_mask = masks > (mask_threshold - threshold_offset)

    # The high-threshold mask is a subset of the low-threshold mask, so its
    # size is the intersection and the low-threshold size is the union.
    intersection = high_thresh_mask.float().sum(-1)
    union = low_thresh_mask.float().sum(-1)

    stability_score = intersection / (union + 1e-6)
    return stability_score
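To make the stability score concrete, here is a small pure-Python sketch of the same ratio for a single mask. It is an illustration only: `stability_score` and the two logit lists are made up for this example and are not part of the module, which operates on `[Q, N]` torch tensors.

```python
def stability_score(logits, mask_threshold=0.0, threshold_offset=1.0):
    """Ratio of points above (t + offset) to points above (t - offset)."""
    high = sum(1 for x in logits if x > mask_threshold + threshold_offset)
    low = sum(1 for x in logits if x > mask_threshold - threshold_offset)
    return high / (low + 1e-6)


# Confident mask: every logit is far from the threshold, so shifting the
# threshold by +/-1 does not change which points are selected.
sharp = [8.0, 7.5, 9.0, -8.0, -9.0, -7.0]

# Uncertain mask: most logits sit within +/-1 of the threshold.
soft = [0.5, 0.8, 1.5, -0.5, -0.8, -3.0]

sharp_score = stability_score(sharp)  # 3 points above +1, 3 above -1
soft_score = stability_score(soft)    # 1 point above +1, 5 above -1
```

With the default `stability_score_thresh` of 0.9 described in this file, the first mask would pass the filter and the second would be rejected.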


def apply_nms(
    masks_binary: torch.Tensor,
    scores: torch.Tensor,
    nms_thresh: float = 0.7,
) -> torch.Tensor:
    """
    Applies greedy NMS on masks using pairwise IoU.

    Args:
        masks_binary: [Q, N] binary masks (booleans or 0/1 floats)
        scores: [Q] mask scores for ranking
        nms_thresh: IoU threshold for suppression

    Returns:
        keep_indices: Tensor of indices to keep after NMS
    """
    # Sort by score descending
    order = torch.argsort(scores, descending=True)
    masks_binary = masks_binary.bool()

    keep = []
    indices = order

    while indices.numel() > 0:
        current = indices[0]
        keep.append(current.item())

        if indices.numel() == 1:
            break

        # Compare current mask with rest
        current_mask = masks_binary[current].unsqueeze(0)  # [1, N]
        rest_indices = indices[1:]
        rest_masks = masks_binary[rest_indices]  # [K, N]

        intersection = (current_mask & rest_masks).float().sum(dim=1)
        union = (current_mask | rest_masks).float().sum(dim=1)
        iou = intersection / (union + 1e-6)

        # Keep masks with IoU < thresh
        mask_keep = iou < nms_thresh
        indices = rest_indices[mask_keep]

    return torch.tensor(keep, device=masks_binary.device, dtype=torch.long)
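The greedy loop in `apply_nms` can be followed on a toy input. The sketch below re-states it in plain Python with masks as sets of point indices. `greedy_mask_nms` and the sample masks are hypothetical, and the empty-union guard stands in for the `1e-6` term used above.

```python
def greedy_mask_nms(masks, scores, iou_thresh=0.7):
    """masks: list of sets of point indices. Returns kept indices, best score first."""
    order = sorted(range(len(masks)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        current = order.pop(0)  # highest-scoring remaining mask
        keep.append(current)
        survivors = []
        for j in order:
            inter = len(masks[current] & masks[j])
            union = len(masks[current] | masks[j])
            iou = inter / union if union else 0.0
            if iou < iou_thresh:  # same rule as `iou < nms_thresh`
                survivors.append(j)
        order = survivors
    return keep


masks = [{0, 1, 2, 3}, {0, 1, 2, 3, 4}, {7, 8, 9}]
scores = [0.9, 0.8, 0.7]
kept = greedy_mask_nms(masks, scores)
```

Mask 1 overlaps mask 0 with IoU 4/5 = 0.8, which is not below the 0.7 threshold, so it is suppressed; mask 2 is disjoint and survives.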


def apply_dbscan_clustering(
    current_masks: torch.Tensor,
    point_coords: torch.Tensor,
    current_scores: torch.Tensor,
    current_classes: torch.Tensor,
    eps: float = 0.95,
    min_samples: int = 1,
    backend: str = "auto",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Applies DBSCAN to each mask to split spatially disconnected components.

    Args:
        current_masks: [Q, N] boolean masks
        point_coords: [N, 3] point coordinates
        current_scores: [Q] scores
        current_classes: [Q] classes
        eps: DBSCAN eps parameter
        min_samples: DBSCAN min_samples parameter
        backend: "auto", "cuml", or "cpu"

    Returns:
        new_masks: [Q', N] expanded boolean masks
        new_scores: [Q'] expanded scores
        new_classes: [Q'] expanded classes
        new_indices: [Q'] indices mapping to original queries
    """
    # 0. Size check (Performance optimization) - REMOVED GLOBAL CHECK
    # if point_coords.shape[0] > 100000:
    #     print(f"DBSCAN: Skipping due to large point cloud ({point_coords.shape[0]} points > 100k)")
    #     return current_masks, current_scores, current_classes

    # 1. Determine Backend
    use_cuml = False
    if backend == "auto":
        use_cuml = DBSCAN is not None
    elif backend == "cuml":
        if DBSCAN is None:
            print("Warning: backend='cuml' requested but cuML not found. Falling back to CPU.")
            use_cuml = False
        else:
            use_cuml = True
    elif backend == "cpu":
        use_cuml = False

    device = current_masks.device
    num_queries = current_masks.shape[0]

    # Initialize lists to hold the new split masks
    new_masks_list = []
    # We'll store indices pointing to original scores/classes to avoid duplicating them early
    new_indices_list = []

    # 2. Execution Path
    if use_cuml:
        # --- cuML (GPU) Path ---
        # print(f"DBSCAN (cuML): Processing {point_coords.shape[0]} points")

        # Ensure data is on GPU and valid types
        # cuML DBSCAN expects input of shape (n_samples, n_features)
        # We process each mask independently.

        # Optimization: To avoid loop overhead, we could try to batch, but DBSCAN isn't batched.
        # We iterate over queries.

        for i in range(num_queries):
            mask = current_masks[i]

            # Skip empty masks
            if not mask.any():
                continue

            # Filter points for this mask
            # mask is [N], point_coords is [N, 3]
            # Slicing creates a new tensor on GPU
            points = point_coords[mask]

            # Check per-mask size limit
            if points.shape[0] > 100000:
                # Skip DBSCAN for this mask, keep original
                print(
                    f"DBSCAN (cuML): Skipping mask {i} due to large point cloud ({points.shape[0]} points > 100k)"
                )
                new_masks_list.append(mask)
                new_indices_list.append(i)
                continue

            if points.shape[0] < min_samples:
                # Keep original
                print(
                    f"DBSCAN (cuML): Skipping mask {i} due to small point cloud ({points.shape[0]} points < {min_samples})"
                )
                new_masks_list.append(mask)
                new_indices_list.append(i)
                continue

            try:
                # Run cuML DBSCAN
                # dbscan = DBSCAN(eps=eps, min_samples=min_samples)
                # labels = dbscan.fit_predict(points)
                # fit_predict returns a cudf Series or cupy array depending on input?
                # If input is torch tensor, cuML >= 23.04 supports __cuda_array_interface__
                # It usually returns a cupy array or similar.

                # Check if we need to convert to cupy explicitly if torch support is iffy in installed version
                # But modern cuML supports torch tensors.

                start_time = time.time()
                clusterer = DBSCAN(eps=eps, min_samples=min_samples)
                labels = clusterer.fit_predict(points)
                db_time = time.time() - start_time

                # Labels is likely a cupy array or similar on GPU
                # Convert to torch for easier handling
                if hasattr(labels, "to_dlpack"):
                    from torch.utils.dlpack import from_dlpack

                    labels = from_dlpack(labels.to_dlpack())
                elif hasattr(labels, "__cuda_array_interface__"):
                    labels = torch.as_tensor(labels, device=device)

                unique_labels = torch.unique(labels)

                # Count valid clusters (excluding noise -1)
                valid_clusters = unique_labels[unique_labels != -1]

                if len(valid_clusters) == 0:
                    # Every point was labeled noise. It is arguable whether a
                    # mask that was a valid object should be discarded here.
                    # NOTE: as written, no fallback is appended, so such a mask
                    # is dropped, which matches the CPU path.
                    pass

                found_cluster = False

                # Reconstruct masks
                # We need global indices of the points
                mask_indices = torch.nonzero(mask, as_tuple=True)[0]

                for label in valid_clusters:
                    found_cluster = True
                    # Create new boolean mask
                    # 1. Start with zeros
                    new_mask = torch.zeros_like(mask)
                    # 2. Get local indices where label matches
                    local_indices = (labels == label).nonzero(as_tuple=True)[0]
                    # 3. Map to global indices
                    global_indices = mask_indices[local_indices]
                    # 4. Set True
                    new_mask[global_indices] = True

                    new_masks_list.append(new_mask)
                    new_indices_list.append(i)

                if not found_cluster:
                    # Treat as noise/failure to cluster, keep original?
                    if len(new_masks_list) == 0 or new_indices_list[-1] != i:
                        # Nothing has been added for this query `i`
                        # (checking new_indices_list[-1] is valid only if the list is
                        # not empty, since earlier queries may have added splits).
                        # No fallback is appended here, so the mask is dropped.
                        pass

            except Exception as e:
                print(f"DBSCAN (cuML) Error Query {i}: {e}")
                # Fallback: keep original
                new_masks_list.append(mask)
                new_indices_list.append(i)

    else:
        # --- CPU Path ---
        # print(f"DBSCAN (CPU): Processing {point_coords.shape[0]} points")

        # Move inputs to CPU
        masks_cpu = current_masks.detach().cpu().numpy()
        coords_cpu = point_coords.detach().cpu().numpy()

        try:
            from sklearn.cluster import DBSCAN as SklearnDBSCAN
        except ImportError:
            print("Scikit-learn not found. Returning original masks.")
            return (
                current_masks,
                current_scores,
                current_classes,
                torch.arange(num_queries, device=device),
            )

        for i in range(num_queries):
            mask = masks_cpu[i]

            if not mask.any():
                continue

            points = coords_cpu[mask]

            # Check per-mask size limit
            if points.shape[0] > 100000:
                # Skip DBSCAN for this mask, keep original
                print(
                    f"DBSCAN (CPU): Skipping mask {i} due to large point cloud ({points.shape[0]} points > 100k)"
                )
                new_masks_list.append(current_masks[i])
                new_indices_list.append(i)
                continue

            if points.shape[0] < min_samples:
                # Keep original
                print(
                    f"DBSCAN (CPU): Skipping mask {i} due to small point cloud ({points.shape[0]} points < {min_samples})"
                )
                new_masks_list.append(current_masks[i])
                new_indices_list.append(i)
                continue

            try:
                # Ensure float32 for sklearn
                start_time = time.time()
                clusterer = SklearnDBSCAN(eps=eps, min_samples=min_samples)
                labels = clusterer.fit_predict(points.astype(np.float32))
                db_time = time.time() - start_time
                unique_labels = np.unique(labels)
                print(
                    f"DBSCAN (CPU): Processing {points.shape[0]} points took {db_time:.4f} seconds, found {len(unique_labels)} clusters"
                )
                found_cluster = False

                # We need indices to reconstruct mask on GPU/CPU
                # Since we are returning torch tensors on `device`, let's construct list of tensors
                # It is faster to construct on CPU then move or construct on GPU?
                # Constructing on GPU inside loop might be slow due to kernel launches.
                # Each new mask is built on CPU and then moved to `device`, so the
                # list holds the same tensor type as in the cuML path.

                mask_indices_cpu = np.nonzero(mask)[0]

                for label in unique_labels:
                    if label == -1:
                        continue
                    found_cluster = True

                    # Construct new mask
                    # It's easier to create on CPU then convert
                    new_mask_cpu = np.zeros_like(mask)  # bool/uint8

                    local_mask = labels == label
                    active_indices = mask_indices_cpu[local_mask]
                    new_mask_cpu[active_indices] = 1  # True

                    # Convert to tensor on device
                    new_masks_list.append(
                        torch.from_numpy(new_mask_cpu).to(device, dtype=torch.bool)
                    )
                    new_indices_list.append(i)

                if not found_cluster:
                    # All points were labeled noise. The mask is currently dropped
                    # (the original code did `pass`), although it could arguably be
                    # kept if it was a valid object that just did not cluster well.
                    pass

            except Exception as e:
                print(f"DBSCAN (CPU) Error Query {i}: {e}")
                new_masks_list.append(current_masks[i])
                new_indices_list.append(i)

    # 3. Assemble Results
    if len(new_masks_list) == 0:
        return (
            torch.zeros((0, current_masks.shape[1]), device=device, dtype=torch.bool),
            torch.zeros((0,), device=device, dtype=current_scores.dtype),
            torch.zeros((0,), device=device, dtype=current_classes.dtype),
            torch.zeros((0,), device=device, dtype=torch.long),
        )

    final_masks = torch.stack(new_masks_list)

    # Gather scores and classes using indices
    indices_tensor = torch.tensor(new_indices_list, device=device, dtype=torch.long)
    final_scores = current_scores[indices_tensor]
    final_classes = current_classes[indices_tensor]

    return final_masks, final_scores, final_classes, indices_tensor
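With the default `min_samples=1`, every point counts as a core point in scikit-learn's DBSCAN (the count includes the point itself), so the clusters reduce to connected components of the graph that links points no more than `eps` apart. The sketch below shows that splitting step in pure Python with a union-find. It is a deliberately simple O(n²) stand-in for illustration, not the cuML or scikit-learn call used above, and `split_mask_by_distance` and the sample coordinates are hypothetical.

```python
import math


def split_mask_by_distance(mask, coords, eps=0.95):
    """Split one boolean mask into connected components of the eps-graph.

    mask: list of bool, length N. coords: list of (x, y, z), length N.
    Returns a list of boolean masks, one per component.
    """
    idx = [i for i, inside in enumerate(mask) if inside]
    parent = {i: i for i in idx}

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    # Union every pair of masked points that lie within eps of each other.
    for pos, a in enumerate(idx):
        for b in idx[pos + 1:]:
            if math.dist(coords[a], coords[b]) <= eps:
                parent[find(a)] = find(b)

    components = {}
    for i in idx:
        components.setdefault(find(i), []).append(i)

    split = []
    for members in components.values():
        new_mask = [False] * len(mask)
        for i in members:
            new_mask[i] = True
        split.append(new_mask)
    return split


coords = [(0.0, 0.0, 0.0), (0.5, 0.0, 0.0), (1.0, 0.0, 0.0),
          (5.0, 0.0, 0.0), (5.4, 0.0, 0.0), (9.0, 0.0, 0.0)]
mask = [True, True, True, True, True, False]
parts = split_mask_by_distance(mask, coords)
```

The first three points chain together through steps of 0.5, the next two form a second group about 4 units away, and the last point is outside the mask, so one input mask becomes two output masks that share the parent query's score and class.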


def apply_post_processing(
    pred_masks: torch.Tensor,
    pred_logits: torch.Tensor,
    mask_threshold: float = 0.0,
    point_coords: Optional[torch.Tensor] = None,
    pp_cfg: Optional[Dict] = None,
    pred_iou: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Applies configured post-processing filters.

    Args:
        pred_masks: [Q, N] mask logits
        pred_logits: [Q, 2] class logits (objectness is class 0)
        mask_threshold: Threshold for mask binarization (usually 0.0 for logits)
        point_coords: Optional [N, 3] point coordinates; DBSCAN splitting runs
            only when they are provided.
        pred_iou: Optional [Q] learned IoU logits from SpaceFormer's IoU head.
            When provided, `sigmoid(pred_iou)` replaces the hand-coded
            `mask_quality = (sigmoid(masks) * binary).sum / binary.sum` proxy in
            the score = obj * quality formula. DBSCAN expansion copies the same
            scalar to every component of an expanded query.
        pp_cfg: Post-processing configuration dict with keys:
            - objectness_thresh: float (default 0.0, disabled)
            - min_mask_points: int (default 0, disabled)
            - use_stability_score: bool (default False)
            - stability_score_thresh: float (default 0.9)
            - stability_score_offset: float (default 1.0)
            - use_nms: bool (default False)
            - nms_thresh: float (default 0.7)
            - use_dbscan: bool (default False)
            - dbscan_eps: float (default 0.95)
            - dbscan_min_points: int (default 1)
            - dbscan_backend: str (default "auto")

    Returns:
        final_masks: [Q', N] final binary masks
        final_scores: [Q'] final scores
        final_classes: [Q'] final classes
        final_indices: [Q'] indices mapping to original queries
    """
    if pp_cfg is None:
        pp_cfg = {}

    # Basic preparation
    masks_binary = pred_masks > mask_threshold

    # 0. Min Point Count Filtering (FIRST STEP - early rejection)
    # Filter out small masks before expensive operations like DBSCAN
    keep = torch.arange(pred_masks.shape[0], device=pred_masks.device)

    if pp_cfg.get("min_mask_points", 0) > 0:
        counts = masks_binary.float().sum(1)
        keep_size = counts >= pp_cfg["min_mask_points"]
        keep = keep[keep_size]

        if len(keep) == 0:
            return (
                torch.zeros((0, pred_masks.shape[1]), device=pred_masks.device, dtype=torch.bool),
                torch.zeros((0,), device=pred_masks.device, dtype=pred_masks.dtype),
                torch.zeros((0,), device=pred_masks.device, dtype=torch.long),
                torch.zeros((0,), device=pred_masks.device, dtype=torch.long),
            )

        # Filter all inputs
        masks_binary = masks_binary[keep]
        pred_masks = pred_masks[keep]
        pred_logits = pred_logits[keep]
        if pred_iou is not None:
            pred_iou = pred_iou[keep]

    # 1. DBSCAN Expansion
    # If DBSCAN is used, we expand masks immediately.
    # We maintain a mapping to original logits to allow stability calculation later.

    current_masks = masks_binary
    current_logits = pred_masks
    current_pred_logits = pred_logits

    # Track indices (now relative to filtered set if min_mask_points was applied)
    current_indices = keep.clone()

    # Objectness component: probability of class 0 (see docstring)
    obj_probs = pred_logits.softmax(dim=-1)[:, 0]

    # Mask quality component (IoU proxy): learned if pred_iou is provided
    # (P3-SAM-style IoU head), otherwise the hand-coded sigmoid-mean proxy.
    if pred_iou is not None:
        mask_quality = pred_iou.sigmoid()
    else:
        masks_sigmoid = pred_masks.sigmoid()
        mask_quality = (masks_sigmoid * masks_binary.float()).sum(1) / (
            masks_binary.float().sum(1) + 1e-6
        )
    scores = obj_probs * mask_quality
    classes = torch.zeros_like(scores, dtype=torch.long)  # class 0

    if pp_cfg.get("use_dbscan", False) and point_coords is not None:
        current_masks, scores, classes, dbscan_indices = apply_dbscan_clustering(
            current_masks,
            point_coords,
            scores,
            classes,
            eps=pp_cfg.get("dbscan_eps", 0.95),
            min_samples=pp_cfg.get("dbscan_min_points", 1),
            backend=pp_cfg.get("dbscan_backend", "auto"),
        )

        # We need to map them back to original query indices
        current_indices = keep[dbscan_indices]

        # Expand logits and other properties to match split masks
        # Use dbscan_indices (relative to current filtered set) for indexing current tensors
        current_logits = current_logits[dbscan_indices]
        current_pred_logits = current_pred_logits[dbscan_indices]
        obj_probs = obj_probs[dbscan_indices]

        # MASK THE LOGITS (Stability Fix)
        # Key step: constrain the logits to the new binary mask shape
        # so stability score is calculated on the component, not the whole original mask.
        # We use a large negative value for background.
        current_logits = torch.where(current_masks, current_logits, -100.0)

        # Recalculate mask quality for the NEW masks. With learned IoU we copy
        # the parent query's scalar to every expanded component (no per-component
        # IoU prediction is available); without it, recompute the sigmoid-mean
        # proxy from the masked logits.
        if pred_iou is not None:
            mask_quality = pred_iou[dbscan_indices].sigmoid()
        else:
            masks_sigmoid = current_logits.sigmoid()
            mask_quality = (masks_sigmoid * current_masks.float()).sum(1) / (
                current_masks.float().sum(1) + 1e-6
            )
        # Recalculate scores (Obj * Quality)
        scores = obj_probs * mask_quality

    # Now we have `current_masks` (binary) and `current_logits` (masked logits).
    # All subsequent steps operate on these.

    # 2. Objectness Filtering
    keep = torch.arange(current_masks.shape[0], device=current_masks.device)

    if pp_cfg.get("objectness_thresh", 0.0) > 0:
        # obj_probs is aligned with current set
        keep_obj = obj_probs > pp_cfg["objectness_thresh"]
        keep = keep[keep_obj[keep]]

    if len(keep) == 0:
        return (
            torch.zeros((0, pred_masks.shape[1]), device=pred_masks.device, dtype=torch.bool),
            torch.zeros((0,), device=pred_masks.device, dtype=scores.dtype),
            torch.zeros((0,), device=pred_masks.device, dtype=classes.dtype),
            torch.zeros((0,), device=pred_masks.device, dtype=torch.long),
        )

    # 3. Stability Score
    if pp_cfg.get("use_stability_score", False):
        active_logits = current_logits[keep]
        stability = calculate_stability_score(
            active_logits,
            mask_threshold,
            pp_cfg.get("stability_score_offset", 1.0),
        )
        keep_stable = stability >= pp_cfg.get("stability_score_thresh", 0.9)
        keep = keep[keep_stable]

    if len(keep) == 0:
        return (
            torch.zeros((0, pred_masks.shape[1]), device=pred_masks.device, dtype=torch.bool),
            torch.zeros((0,), device=pred_masks.device, dtype=scores.dtype),
            torch.zeros((0,), device=pred_masks.device, dtype=classes.dtype),
            torch.zeros((0,), device=pred_masks.device, dtype=torch.long),
        )

    # 4. NMS
    if pp_cfg.get("use_nms", False):
        active_masks = current_masks[keep]
        active_scores = scores[keep]

        keep_nms = apply_nms(active_masks, active_scores, pp_cfg.get("nms_thresh", 0.7))
        keep = keep[keep_nms]

    # Final gather
    final_masks = current_masks[keep]
    final_scores = scores[keep]
    final_classes = classes[keep]
    final_indices = current_indices[keep]

    return final_masks, final_scores, final_classes, final_indices
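The per-query score used in `apply_post_processing` is the product of an objectness probability and a mask-quality term. The following pure-Python sketch works through that arithmetic for a single query. `query_score` and the sample logits are hypothetical; the module itself computes the same quantities on batched torch tensors.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def query_score(class_logits, mask_logits, mask_threshold=0.0, iou_logit=None):
    """score = objectness * mask quality for one query."""
    # Objectness: softmax probability of class 0 over the class logits.
    shift = max(class_logits)
    exps = [math.exp(z - shift) for z in class_logits]
    objectness = exps[0] / sum(exps)

    if iou_logit is not None:
        # Learned quality: sigmoid of the IoU-head logit.
        quality = sigmoid(iou_logit)
    else:
        # Proxy quality: mean sigmoid over the points inside the binary mask.
        inside = [z for z in mask_logits if z > mask_threshold]
        quality = sum(sigmoid(z) for z in inside) / (len(inside) + 1e-6)
    return objectness * quality


class_logits = [2.0, 0.0]       # objectness = e^2 / (e^2 + 1), about 0.881
mask_logits = [4.0, 3.0, -5.0]  # two points inside the mask

score_proxy = query_score(class_logits, mask_logits)
score_learned = query_score(class_logits, mask_logits, iou_logit=0.0)
```

With the proxy, the quality is the mean of `sigmoid(4.0)` and `sigmoid(3.0)`, about 0.967, giving a score near 0.852. With an IoU logit of 0.0 the quality is exactly 0.5, so the score is half the objectness.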
demo/requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
# Demo/inference dependencies.
# WarpConvNet (with its compiled _C extension) must be installed separately, either
# as a pre-built wheel or built from source; it is environment-specific so not pinned here.
torch
einops
transformers
numpy
huggingface_hub
gradio
plotly
# optional point-cloud input formats
plyfile
# local viser demo (demo_viser.py): interactive 3D viewer + sample .ply I/O
viser
open3d
demo/text_encoder.py
ADDED
|
@@ -0,0 +1,218 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import List, Tuple

import hashlib
import os
import abc
import numpy as np

import torch
from transformers import AutoTokenizer, AutoModel

# Assume that models are already cached
os.environ["HF_HUB_OFFLINE"] = "1"


# Use a deterministic hash function for strings
def string_hash(s: str) -> int:
    return int(hashlib.md5(s.encode()).hexdigest(), 16)


class CLIPTextEncoderInterace(abc.ABC):
    model: torch.nn.Module
    CHANNEL_DIM: int

    def __post_init__(self):
        self.freeze_encoder()

    def freeze_encoder(self):
        for params in self.model.parameters():
            params.requires_grad = False

    @abc.abstractmethod
    def __call__(self, list_of_texts: List[str], normalize: bool = True) -> torch.Tensor:
        raise NotImplementedError

    @torch.inference_mode()
    def get_unique_text_embedding(
        self,
        list_of_texts: List[str] | List[List[str]],
        normalize: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Get unique embeddings for a list of texts.

        Args:
            list_of_texts: List[str] | List[List[str]]
                List of texts or list of list of texts to get unique embeddings for.
                Total number of texts is N.

        Returns:
            embeddings: torch.Tensor, shape (M, D)
                Unique embeddings for the list of texts.
            from_unique_indices: torch.Tensor, shape (N,)
                For each text in the flattened list, the row of `embeddings` that
                holds its embedding, so `embeddings[from_unique_indices]` has shape (N, D).
            to_unique_indices: torch.Tensor, shape (M,)
                Index in the flattened list of the first occurrence of each unique text.
        """
        # Flatten the list of texts
        if isinstance(list_of_texts, list) and isinstance(list_of_texts[0], list):
            # list of lists
            list_of_texts = [item for sublist in list_of_texts for item in sublist]

        # cchoy: Get unique texts using an MD5-based hash. Python's built-in str hash is
        # salted per process (PYTHONHASHSEED), so it is not deterministic across runs.
        flat_caption_hash = [string_hash(caption) for caption in list_of_texts]
        _, to_unique_indices, from_unique_indices = np.unique(
            flat_caption_hash, return_index=True, return_inverse=True
        )

        # Get unique texts
        unique_texts = [list_of_texts[i] for i in to_unique_indices]

        # Get embeddings
        embeddings = self(unique_texts, normalize=normalize)

        # Return embeddings and indices
        return embeddings, torch.tensor(from_unique_indices), torch.tensor(to_unique_indices)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_text_encoder(
    model_type: str,
    device: str,
    **kwargs,
) -> CLIPTextEncoderInterace:
    if model_type == "siglip2":
        return Siglip2TextEncoder(device=device, **kwargs)
    elif model_type == "openclip":  # Recap CLIP is also openclip
        return OpenCLIPTextEncoder(device=device, **kwargs)
    else:
        raise ValueError(f"Model type {model_type} not supported")


class OpenCLIPTextEncoder(CLIPTextEncoderInterace):
    CHANNEL_DIM = None

    def __init__(
        self,
        model_id: str,
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.bfloat16,
        context_length: int | None = None,
        **kwargs,
    ):
        # open_clip is an optional dependency, so import it lazily here
        try:
            from open_clip import create_model_from_pretrained, get_tokenizer
        except ImportError as e:
            raise ImportError(
                "open_clip is not installed. Please install it with `pip install open_clip_torch`"
            ) from e
        self.prepare_data(model_id)

        self.tokenizer = get_tokenizer(model_id)
        # Map the torch dtype to the precision string expected by open_clip
        precision = {
            torch.float32: "fp32",
            torch.float16: "fp16",
            torch.bfloat16: "bf16",
        }[torch_dtype]
        self.model, _ = create_model_from_pretrained(
            model_id,
            device=device,
            precision=precision,
        )
        self.device = device

        # Set context_length: use provided value, or infer from model, or use default
        if context_length is not None:
            self.context_length = context_length
        elif hasattr(self.model, "context_length"):
            self.context_length = self.model.context_length
        elif hasattr(self.model, "text") and hasattr(self.model.text, "context_length"):
            self.context_length = self.model.text.context_length
        else:
            # Default to 77 for standard CLIP models
            self.context_length = 77
            print(
                f"Warning: Could not infer context_length from model, using default: {self.context_length}"
            )

    def prepare_data(self, model_id: str):
        from open_clip.factory import download_pretrained_from_hf

        # Remove hf-hub: prefix if it exists
        model_id = model_id[len("hf-hub:") :] if model_id.startswith("hf-hub:") else model_id
        ckpt_path = download_pretrained_from_hf(
            model_id, cache_dir=os.environ.get("HF_HUB_CACHE", os.path.expanduser("~/.cache/"))
        )
        return ckpt_path

    @torch.inference_mode()
    @torch.amp.autocast(enabled=True, device_type="cuda")
    def __call__(self, list_of_texts: List[str], normalize: bool = True) -> torch.Tensor:
        text_tokens = self.tokenizer(list_of_texts, context_length=self.context_length).to(
            self.device
        )
        embeddings = self.model.encode_text(text_tokens)
        if normalize:
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        return embeddings


class Siglip2TextEncoder(CLIPTextEncoderInterace):
    CHANNEL_DIM = 1152

    def __init__(
        self,
        model_id: str = "google/siglip2-so400m-patch16-384",
        device: str = "cuda",
        attn_implementation: str = "flash_attention_2",
        torch_dtype: torch.dtype = torch.bfloat16,
        **kwargs,
    ):
        # Disable tokenizer parallelism
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        # Try loading from local cache first to avoid 429 errors
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
            self.model = AutoModel.from_pretrained(
                model_id,
                attn_implementation=attn_implementation,
                torch_dtype=torch_dtype,
                device_map=device,
                local_files_only=True,
            )
            print(f"Successfully loaded {model_id} from local cache.")
        except OSError:
            print(
                f"Model {model_id} not found locally. Downloading/Updating from Hugging Face Hub..."
            )
            # Fall back to downloading if the model is not found locally.
            # This might still hit 429 if many ranks try it, but it's the standard fallback.
            # If this persists in a multi-node setup, ideally download on rank 0 only.
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
            self.model = AutoModel.from_pretrained(
                model_id,
                attn_implementation=attn_implementation,
                torch_dtype=torch_dtype,
                device_map=device,
            )

        self.model.vision_model = None  # Remove vision model; only text features are used
        self.device = device

    @torch.inference_mode()
    @torch.amp.autocast(enabled=True, device_type="cuda")
    def __call__(self, list_of_texts: List[str], normalize: bool = True) -> torch.Tensor:
        # Pad/truncate to a fixed length of 64 tokens; see
        # https://huggingface.co/docs/transformers/main/model_doc/siglip2
        text_inputs = self.tokenizer(
            list_of_texts,
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt",
        ).to(self.device)
        outputs = self.model.get_text_features(**text_inputs)
        # In newer transformers, get_text_features may return a
        # BaseModelOutputWithPooling instead of a plain tensor.
        if not isinstance(outputs, torch.Tensor):
            outputs = outputs.pooler_output
        if normalize:
            outputs = torch.nn.functional.normalize(outputs, dim=-1)
        return outputs
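The deduplication in `get_unique_text_embedding` relies on the `return_index` / `return_inverse` outputs of `np.unique`. The sketch below illustrates that round trip without loading any model. It is only an illustration: `stable_hash` is a hypothetical stand-in for the file's `string_hash` helper (whose body is not shown here), and `fake_embeddings` stands in for the encoder output.

```python
import hashlib

import numpy as np


def stable_hash(text: str) -> int:
    """Hypothetical stand-in for `string_hash`: depends only on the string contents."""
    digest = hashlib.sha256(text.encode("utf-8")).digest()
    # Keep 8 bytes and mask to 63 bits so the value fits in a signed int64 array.
    return int.from_bytes(digest[:8], "little") & ((1 << 63) - 1)


# Nested captions, as accepted by get_unique_text_embedding (N = 5 after flattening).
captions = [["a chair", "a table"], ["a chair"], ["a sofa", "a table"]]
flat = [c for group in captions for c in group]

hashes = np.array([stable_hash(c) for c in flat], dtype=np.int64)
_, to_unique, from_unique = np.unique(hashes, return_index=True, return_inverse=True)

# One representative text per distinct hash (M = 3); order follows the sorted hashes.
unique_texts = [flat[i] for i in to_unique]

# Stand-in for the encoder output: one row per unique text, shape (M, 1).
fake_embeddings = np.arange(len(unique_texts), dtype=np.float32)[:, None]

# Expand back to one row per original caption, shape (N, 1).
per_caption = fake_embeddings[from_unique]

# Indexing the unique texts with from_unique reproduces the original flat list.
restored = [unique_texts[i] for i in from_unique]
```

In the real method the same indexing, `embeddings[from_unique_indices]`, would recover one embedding per input caption while the encoder only runs on the M unique strings.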
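The local-cache-first loading in `Siglip2TextEncoder.__init__` repeats the same `from_pretrained` call twice. As a rough sketch of the pattern in isolation, the helper below (a hypothetical `load_local_first`, not part of this commit) wraps any loader that accepts a `local_files_only` keyword, which `AutoModel.from_pretrained` and `AutoTokenizer.from_pretrained` do. A toy loader is used here so the example runs without network access.

```python
from typing import Any, Callable, List


def load_local_first(loader: Callable[..., Any], model_id: str, **kwargs: Any) -> Any:
    """Try the local cache first; fall back to a regular load on OSError."""
    try:
        return loader(model_id, local_files_only=True, **kwargs)
    except OSError:
        # Cache miss: retry and let the loader download the files.
        return loader(model_id, **kwargs)


# Toy loader (hypothetical) that mimics a cache miss followed by a successful download.
calls: List[bool] = []


def fake_loader(model_id: str, local_files_only: bool = False) -> str:
    calls.append(local_files_only)
    if local_files_only:
        raise OSError(f"{model_id} not found in local cache")
    return f"loaded:{model_id}"


result = load_local_first(fake_loader, "org/some-model")
```

With transformers installed, the same helper could be called as `load_local_first(AutoTokenizer.from_pretrained, model_id)`. Whether that is preferable to the explicit try/except in the class is a style choice.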