# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
r"""SpaceFormer single-scene open-vocabulary 3D instance segmentation (CLI inference).

Self-contained command-line entry point for the released SpaceFormer. The model +
inference pipeline come from the installed ``warpconvnet`` library
(``warpconvnet.models.spaceformer``); this script only wires up checkpoint resolution
(local path or HuggingFace auto-download), runs one scene, and prints/saves the
predicted instances.

    # weights from a local checkpoint
    python inference.py --ckpt /path/to/spaceformer_512_siglip2_ssccc.ckpt \
        --scene /path/to/scene_dir

    # or auto-download from a HuggingFace model repo
    HF_REPO_ID=chrischoy/SpaCeFormer python inference.py \
        --scene my_scene.ply --class-names "office chair" "desk" "monitor" "other"

Requires WarpConvNet (with its compiled extension) + transformers in the environment.
"""
# NOTE: the docstring above is a raw string on purpose. It is reused verbatim as the
# argparse description, and a non-raw literal would swallow the shell ``\`` line
# continuations in the usage examples.

import argparse
import os

import torch
from warpconvnet.models.spaceformer import (
    build_spaceformer,
    load_spaceformer_checkpoint,
)

from labels import CLASS_LABELS_200, DEFAULT_CLASS_NAMES
from pipeline import (
    build_text_embeddings,
    load_scene,
    make_batch,
    predict_instances,
    print_summary,
    save_results,
)

# Name of the checkpoint file requested from the HuggingFace repo; read once at
# import time and overridable through the HF_FILENAME environment variable.
HF_FILENAME = os.environ.get("HF_FILENAME", "spaceformer_512_siglip2_ssccc.ckpt")


def resolve_ckpt(ckpt_arg: str | None) -> str:
    """Return a local checkpoint path: explicit --ckpt, $SPACEFORMER_CKPT, or HF download.

    Sources are tried in that order and the first non-empty one wins. Local paths
    are returned as given; this function does not check that they exist.

    Args:
        ckpt_arg: Value of the ``--ckpt`` option, or ``None`` when it was not passed.

    Returns:
        ``ckpt_arg`` or ``$SPACEFORMER_CKPT`` unchanged, otherwise the value returned
        by ``huggingface_hub.hf_hub_download`` for ``HF_FILENAME`` in ``$HF_REPO_ID``.

    Raises:
        SystemExit: If none of ``--ckpt``, ``$SPACEFORMER_CKPT`` or ``$HF_REPO_ID``
            is set.
    """
    if ckpt_arg:
        return ckpt_arg

    explicit = os.environ.get("SPACEFORMER_CKPT")
    if explicit:
        return explicit

    repo_id = os.environ.get("HF_REPO_ID")
    if not repo_id:
        raise SystemExit(
            "No checkpoint: pass --ckpt, or set SPACEFORMER_CKPT (local path) or "
            "HF_REPO_ID (HuggingFace model repo to download from)."
        )

    # Imported lazily so huggingface_hub is only needed on the download path.
    from huggingface_hub import hf_hub_download

    return hf_hub_download(repo_id=repo_id, filename=HF_FILENAME)


def main() -> None:
    """Parse CLI arguments, run SpaceFormer on one scene, and report the instances.

    Prints the vocabulary size, the checkpoint key mismatch counts and a result
    summary; when ``--save`` is given the results are also passed to
    ``save_results`` together with the scene coordinates.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument(
        "--ckpt",
        default=None,
        help="checkpoint path (else $SPACEFORMER_CKPT, else HF download via $HF_REPO_ID)",
    )
    ap.add_argument(
        "--scene",
        required=True,
        help="scene dir (coord.npy+color.npy) or .npy/.npz/.ply file",
    )
    ap.add_argument(
        "--class-names",
        nargs="+",
        default=None,
        help="open-vocab class names (default: a short built-in list)",
    )
    ap.add_argument(
        "--use-scannet200",
        action="store_true",
        help="use the full ScanNet200 label set as class names",
    )
    ap.add_argument(
        "--iou-head",
        action="store_true",
        help="build the learned IoU head (only for a checkpoint that has one)",
    )
    ap.add_argument("--save", default=None, help="optional output .npz path")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = ap.parse_args()

    device = torch.device(args.device)

    # --use-scannet200 takes precedence over --class-names when both are given.
    if args.use_scannet200:
        class_names = list(CLASS_LABELS_200)
    else:
        class_names = args.class_names or list(DEFAULT_CLASS_NAMES)
    print(f"[vocab] {len(class_names)} classes")

    # Resolve (and, if needed, download) the checkpoint before constructing the
    # model, so a missing checkpoint configuration is reported up front.
    ckpt_path = resolve_ckpt(args.ckpt)

    net = build_spaceformer(use_iou_head=args.iou_head, device=device)
    missing, unexpected = load_spaceformer_checkpoint(net, ckpt_path)
    print(f"[weights] {len(missing)} missing, {len(unexpected)} unexpected")

    coord_np, color_np = load_scene(args.scene)
    batch = make_batch(coord_np, color_np, device)
    text_eval = build_text_embeddings(class_names, device)

    results = predict_instances(net, batch, text_eval, class_names)
    print_summary(results)

    if args.save:
        save_results(results, coord_np, args.save)


if __name__ == "__main__":
    main()