# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Local viser demo: SpaceFormer open-vocabulary 3D instance segmentation.

Takes TEXT class names, runs one forward pass of the released SpaceFormer over a
point cloud, and visualizes the predicted instances in a browser with
`viser <https://viser.studio>`_ (each kept instance a distinct color; unassigned
points grey).

Model + pipeline come from the installed ``warpconvnet`` library and the sibling
``pipeline.py`` / ``postprocessing.py`` in this repo — this file only adds the
forward + viser visualization glue.

Usage::

    # local checkpoint, custom vocabulary
    python demo_viser.py --ckpt /path/to/spaceformer_512_siglip2_ssccc.ckpt \
        --ply my_scene.ply --class-names chair table monitor wall floor

    # auto-download the checkpoint from HuggingFace and use a bundled sample cloud
    python demo_viser.py --port 8080   # falls back to chrischoy/SpaCeFormer

Then open the printed URL (default http://localhost:8080) in a browser.

CRITICAL coord/mask alignment (see space_former_seg.py):
    The model voxelizes internally (PointToSparseWrapper), so its output masks are
    over the model's OUTPUT points, whose count may NOT equal the raw .ply point
    count. In eval mode the forward returns ``out["backbone_pc"]`` (a warpconvnet
    ``Points``) whose ``.coordinates`` correspond 1:1 with the per-point mask rows
    in ``out["mask"][0]``. We therefore run the forward pass HERE (mirroring
    pipeline.predict_instances) so we can pull the backbone_pc coordinates and the
    post-processed masks together, and color THOSE coordinates — never the raw
    input coords, which may differ in length/order.

NOTE: WarpConvNet (with its compiled ``_C`` extension) + transformers + a
checkpoint are required to actually run. ``--help`` and import of this file work
without viser/WarpConvNet (they are imported lazily, inside ``visualize`` and
``main`` respectively).
"""

import argparse
import os
import tempfile
import time

import numpy as np
import torch

# Reused, task-specific glue from this repo (scene I/O, eval transforms, text
# embeddings) — we deliberately do NOT reuse predict_instances() wholesale
# because we also need the aligned backbone_pc coordinates, so we inline the
# forward below and call apply_post_processing() directly.
from labels import CLASS_LABELS_200, DEFAULT_CLASS_NAMES
from pipeline import (
    POST_PROCESSING_CFG,
    build_text_embeddings,
    load_scene,
    make_batch,
)
from postprocessing import apply_post_processing

HF_REPO_ID = os.environ.get("HF_REPO_ID", "chrischoy/SpaCeFormer")
HF_FILENAME = os.environ.get("HF_FILENAME", "spaceformer_512_siglip2_ssccc.ckpt")


# --------------------------------------------------------------------------- #
# Checkpoint + sample-scene resolution
# --------------------------------------------------------------------------- #
def resolve_ckpt(ckpt_arg):
    """Return a local checkpoint path: --ckpt, $SPACEFORMER_CKPT, or HF download.

    The first non-empty source wins. ``huggingface_hub`` is imported only when a
    download is actually needed.
    """
    if ckpt_arg:
        return ckpt_arg
    explicit = os.environ.get("SPACEFORMER_CKPT")
    if explicit:
        return explicit
    from huggingface_hub import hf_hub_download

    print(f"[ckpt] downloading {HF_FILENAME} from HuggingFace repo {HF_REPO_ID} ...")
    return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME)


def resolve_sample_ply(ply_arg):
    """Return a path to a .ply. If --ply is unset, use a zero-config sample.

    Preference order (no fragile external URLs):
      1) open3d's bundled ``PLYPointCloud`` sample (a small RGB point cloud), or
      2) a synthesized random RGB point cloud written to a temp .ply.

    A random/sample cloud will NOT segment into meaningful instances — it only
    demonstrates that the pipeline + visualization run end to end.
    """
    if ply_arg:
        return ply_arg

    # 1) open3d bundled sample (downloaded/cached by open3d itself, offline after).
    try:
        import open3d as o3d

        sample_path = o3d.data.PLYPointCloud().path
        if os.path.isfile(sample_path):
            print(f"[sample] using open3d bundled PLYPointCloud sample: {sample_path}")
            print("[sample] NOTE: a generic sample cloud won't segment meaningfully; "
                  "it's only to demo the pipeline + viz.")
            return sample_path
    except Exception as exc:  # noqa: BLE001 - any open3d issue -> fall back to synthetic
        print(f"[sample] open3d sample unavailable ({exc}); synthesizing a random cloud.")

    # 2) Synthesize a small random RGB point cloud and write a temp .ply.
    rng = np.random.default_rng(0)
    n = 20_000
    coord = rng.uniform(-2.0, 2.0, size=(n, 3)).astype(np.float32)  # meters
    color = rng.integers(0, 256, size=(n, 3)).astype(np.uint8)
    # delete=False: the file must outlive this function so the caller can load it.
    tmp = tempfile.NamedTemporaryFile(suffix=".ply", delete=False)
    tmp.close()
    _write_ply(tmp.name, coord, color)
    print(f"[sample] wrote synthetic random RGB cloud ({n} pts) to {tmp.name}")
    print("[sample] NOTE: a random cloud won't segment meaningfully; it's only to "
          "demo the pipeline + viz.")
    return tmp.name


def _write_ply_ascii(path, coord, color):
    """Write an ASCII RGB .ply using only the standard library + numpy.

    ``coord`` is indexed as ``[N, 3]`` floats and ``color`` as ``[N, 3]``
    integers in 0-255; one ``x y z r g b`` line is emitted per point.
    """
    header = "\n".join(
        [
            "ply",
            "format ascii 1.0",
            f"element vertex {coord.shape[0]}",
            "property float x",
            "property float y",
            "property float z",
            "property uchar red",
            "property uchar green",
            "property uchar blue",
            "end_header",
        ]
    )
    with open(path, "w", encoding="ascii") as fh:
        fh.write(header + "\n")
        # .9g is enough digits to round-trip a float32 value through text.
        fh.writelines(
            f"{x:.9g} {y:.9g} {z:.9g} {int(r)} {int(g)} {int(b)}\n"
            for (x, y, z), (r, g, b) in zip(coord.tolist(), color.tolist())
        )


def _write_ply(path, coord, color):
    """Write an RGB .ply (plyfile if available, else open3d, else a raw ASCII file).

    The raw fallback matters because this helper is reached precisely when the
    open3d sample was unavailable; without it a missing ``plyfile`` would crash
    the zero-config demo path.
    """
    try:
        from plyfile import PlyData, PlyElement
    except ImportError:
        pass
    else:
        verts = np.empty(
            coord.shape[0],
            dtype=[("x", "f4"), ("y", "f4"), ("z", "f4"),
                   ("red", "u1"), ("green", "u1"), ("blue", "u1")],
        )
        verts["x"], verts["y"], verts["z"] = coord[:, 0], coord[:, 1], coord[:, 2]
        verts["red"], verts["green"], verts["blue"] = color[:, 0], color[:, 1], color[:, 2]
        PlyData([PlyElement.describe(verts, "vertex")], text=True).write(path)
        return

    try:
        import open3d as o3d
    except (ImportError, OSError):
        # Not installed, or installed but its native libraries fail to load.
        pass
    else:
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(coord.astype(np.float64))
        pcd.colors = o3d.utility.Vector3dVector(color.astype(np.float64) / 255.0)
        o3d.io.write_point_cloud(path, pcd)
        return

    _write_ply_ascii(path, coord, color)


# --------------------------------------------------------------------------- #
# Forward + post-processing (aligned to backbone_pc coordinates)
# --------------------------------------------------------------------------- #
@torch.inference_mode()
def segment_aligned(net, batch, text_eval, class_names):
    """One forward pass -> post-processed instances aligned to backbone_pc coords.

    Mirrors ``pipeline.predict_instances`` but ALSO returns the model's output
    coordinates (``out["backbone_pc"].coordinates``), which are the coordinates
    the returned masks index into (see module docstring / space_former_seg.py).

    Returns ``(coords[M,3] float32, results)`` where each result is
    ``{mask: bool[M], label, label_id, score}`` and ``M`` is the backbone point
    count (== number of rows in ``out["mask"][0]``, i.e. columns of the
    transposed mask logits used below), NOT the raw .ply count. ``results`` is
    sorted by score, highest first.

    Raises ``RuntimeError`` if the mask width and the backbone point count
    disagree, since coloring would then be misaligned.
    """
    out = net(batch)

    # backbone_pc is present only in eval; its coordinates align 1:1 with the
    # mask rows. batch size is 1 here, so every point belongs to sample 0.
    backbone_pc = out["backbone_pc"]
    coords = backbone_pc.coordinates.detach().cpu().numpy().astype(np.float32)  # [M, 3]

    binary_logits = out["logit"][0]  # [Q, 2] objectness over {fg, bg}
    mask_logits = out["mask"][0].T  # [M, Q] -> [Q, M]
    clip_feats = out["clip_feat"][0]  # [Q, D]
    pred_iou = out["pred_iou"][0] if "pred_iou" in out else None

    # Sanity: mask columns must match the backbone point count we will color.
    # An explicit raise (not ``assert``) so the check survives ``python -O``.
    if mask_logits.shape[1] != coords.shape[0]:
        raise RuntimeError(
            f"mask columns {mask_logits.shape[1]} != backbone points {coords.shape[0]}; "
            "coord/mask alignment broken"
        )

    class_logits = text_eval.predict(clip_feats, return_logit=True)  # [Q, C]

    masks, scores, _classes, indices = apply_post_processing(
        mask_logits,
        binary_logits,
        mask_threshold=0.0,
        point_coords=None,
        pp_cfg=POST_PROCESSING_CFG,
        pred_iou=pred_iou,
    )

    results = []
    if len(indices) > 0:
        probs = torch.softmax(class_logits[indices], dim=-1)  # [K, C]
        class_probs, class_ids = probs.max(dim=1)
        # Final score = post-processing score weighted by the class confidence.
        final_scores = scores * class_probs
        for k in range(len(indices)):
            label_id = int(class_ids[k])
            results.append(
                {
                    "mask": masks[k].cpu().numpy().astype(bool),  # bool[M], aligned to coords
                    "label": class_names[label_id],
                    "label_id": label_id,
                    "score": float(final_scores[k]),
                }
            )
    results.sort(key=lambda r: r["score"], reverse=True)
    return coords, results


# --------------------------------------------------------------------------- #
# Coloring + viser visualization
# --------------------------------------------------------------------------- #
def _instance_colors(coords, results, top_k, score_thresh):
    """Grey base cloud + a distinct random color per kept (top-k, thresholded) instance.

    Returns ``(rgb[M,3] uint8, kept)`` where ``kept`` is the list of instances
    actually colored (already sorted by score, high first). Each kept result
    dict gains a ``"color"`` entry (an ``(r, g, b)`` int tuple) for the legend.
    Later instances in ``kept`` overwrite earlier ones where masks overlap.
    """
    rgb = np.full((coords.shape[0], 3), 160, dtype=np.uint8)  # grey background
    kept = [r for r in results if r["score"] >= score_thresh][:top_k]
    rng = np.random.default_rng(0)  # fixed seed -> same palette on every run
    for r in kept:
        color = rng.integers(40, 230, size=3).astype(np.uint8)
        r["color"] = tuple(int(c) for c in color)  # stash for GUI legend
        rgb[r["mask"]] = color
    return rgb, kept


def visualize(coords, results, port, top_k, score_thresh, point_size):
    """Start a viser server, add the colored point cloud + a GUI legend, block.

    Blocks until interrupted with Ctrl-C.
    """
    import viser  # imported lazily so --help works without viser installed

    rgb, kept = _instance_colors(coords, results, top_k, score_thresh)

    server = viser.ViserServer(port=port)

    # Point cloud: positions from backbone_pc coords, colors per instance.
    server.scene.add_point_cloud(
        name="/scene",
        points=coords.astype(np.float32),
        colors=rgb,  # uint8 [M, 3]
        point_size=point_size,
    )

    # A small GUI panel listing the kept instances (color + label + score + #points).
    with server.gui.add_folder(f"Top {len(kept)} instances"):
        if not kept:
            server.gui.add_markdown("_(no instances above the score threshold)_")
        for r in kept:
            red, green, blue = r["color"]
            # Plain-markdown color tag so each legend row can be matched to the
            # instance color in the scene.
            swatch = f"`rgb({red}, {green}, {blue})`"
            server.gui.add_markdown(
                f"{swatch} **{r['label']}** — score {r['score']:.3f}, "
                f"{int(r['mask'].sum())} pts"
            )

    url = f"http://localhost:{port}"
    print(f"\n[viser] serving at {url} (open this URL in your browser)")
    print("[viser] press Ctrl-C to stop.")
    try:
        # Keep the main thread alive while the viser server runs.
        while True:
            time.sleep(2.0)
    except KeyboardInterrupt:
        print("\n[viser] shutting down.")


def print_summary(results, top_k):
    """Ranked text summary of instances (label, score, #points) to stdout.

    Only the first ``top_k`` entries of ``results`` are listed; the remainder is
    reported as a count.
    """
    print(f"\n=== {len(results)} predicted instances ===")
    header = f"{'#':>3} {'score':>6} {'points':>8} label"
    print(header)
    print("-" * len(header))
    for i, r in enumerate(results[:top_k]):
        print(f"{i:>3} {r['score']:>6.3f} {int(r['mask'].sum()):>8} {r['label']}")
    if len(results) > top_k:
        print(f"... ({len(results) - top_k} more)")


# --------------------------------------------------------------------------- #
# Entry point (linear top-to-bottom)
# --------------------------------------------------------------------------- #
def main():
    """Parse CLI args, run one segmentation pass, then serve the viser view."""
    ap = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    ap.add_argument("--ply", default=None,
                    help="input point cloud .ply (default: an open3d bundled sample "
                         "or a synthesized random cloud)")
    ap.add_argument("--class-names", nargs="+", default=None,
                    help="open-vocab class names (default: a short built-in list)")
    ap.add_argument("--use-scannet200", action="store_true",
                    help="use the full ScanNet200 label set as class names")
    ap.add_argument("--ckpt", default=None,
                    help="checkpoint path (else $SPACEFORMER_CKPT, else HF download)")
    ap.add_argument("--iou-head", action="store_true",
                    help="build the learned IoU head (only for a checkpoint that has one)")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--port", type=int, default=8080, help="viser server port")
    ap.add_argument("--score-thresh", type=float, default=0.0,
                    help="only color instances with score >= this")
    ap.add_argument("--top-k", type=int, default=30,
                    help="max number of instances to color / list")
    ap.add_argument("--point-size", type=float, default=0.02,
                    help="viser point size (world units / meters)")
    args = ap.parse_args()

    device = torch.device(args.device)

    # Vocabulary (text class names — the open-vocabulary input).
    if args.use_scannet200:
        class_names = list(CLASS_LABELS_200)
    else:
        class_names = args.class_names or list(DEFAULT_CLASS_NAMES)
    print(f"[vocab] {len(class_names)} classes: {', '.join(class_names[:8])}"
          f"{' ...' if len(class_names) > 8 else ''}")

    # 1) Resolve/load the scene (.ply -> coord[N,3] meters, color[N,3] 0-255).
    ply_path = resolve_sample_ply(args.ply)
    coord_np, color_np = load_scene(ply_path)

    # 2) Build the eval batch (CenterShift + NormalizeColor + offsets).
    batch = make_batch(coord_np, color_np, device)

    # 3) Build model + load released weights (WarpConvNet required here).
    from warpconvnet.models.spaceformer import (
        build_spaceformer,
        load_spaceformer_checkpoint,
    )

    net = build_spaceformer(use_iou_head=args.iou_head, device=device)
    missing, unexpected = load_spaceformer_checkpoint(net, resolve_ckpt(args.ckpt))
    print(f"[weights] {len(missing)} missing, {len(unexpected)} unexpected")

    # 4) Encode the text class names (SigLIP2 + prompt ensembling).
    text_eval = build_text_embeddings(class_names, device)

    # 5) Forward + post-process -> masks + labels + scores aligned to backbone coords.
    coords, results = segment_aligned(net, batch, text_eval, class_names)
    print(f"[model] backbone output points: {coords.shape[0]} "
          f"(raw input was {coord_np.shape[0]})")

    # 6) Text summary + interactive viser visualization.
    print_summary(results, args.top_k)
    visualize(coords, results, args.port, args.top_k, args.score_thresh, args.point_size)


if __name__ == "__main__":
    main()