SpaCeFormer / demo /inference.py
chrischoy's picture
Merge SpaceFormer demo (viser + CLI + Gradio) under demo/
a8e8155 verified
Raw
History Blame Contribute Delete
3.95 kB
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""SpaceFormer single-scene open-vocabulary 3D instance segmentation (CLI inference).

Self-contained command-line entry point for the released SpaceFormer. The model
and inference pipeline come from the installed ``warpconvnet`` library
(``warpconvnet.models.spaceformer``); this script only wires up checkpoint
resolution (local path or HuggingFace auto-download), runs one scene, and
prints/saves the predicted instances.

Usage:

    # weights from a local checkpoint
    python inference.py --ckpt /path/to/spaceformer_512_siglip2_ssccc.ckpt \
        --scene /path/to/scene_dir

    # or auto-download from a HuggingFace model repo
    HF_REPO_ID=chrischoy/SpaCeFormer python inference.py \
        --scene my_scene.ply --class-names "office chair" "desk" "monitor" "other"

Requires WarpConvNet (with its compiled extension) + transformers in the environment.
"""
import argparse
import os
import torch
from warpconvnet.models.spaceformer import (
build_spaceformer,
load_spaceformer_checkpoint,
)
from labels import CLASS_LABELS_200, DEFAULT_CLASS_NAMES
from pipeline import (
build_text_embeddings,
load_scene,
make_batch,
predict_instances,
print_summary,
save_results,
)
# Checkpoint filename requested from the HuggingFace repo when downloading;
# override the default by setting the HF_FILENAME environment variable.
HF_FILENAME = os.environ.get("HF_FILENAME", "spaceformer_512_siglip2_ssccc.ckpt")
def resolve_ckpt(ckpt_arg: str | None) -> str:
"""Return a local checkpoint path: explicit --ckpt, $SPACEFORMER_CKPT, or HF download."""
if ckpt_arg:
return ckpt_arg
explicit = os.environ.get("SPACEFORMER_CKPT")
if explicit:
return explicit
repo_id = os.environ.get("HF_REPO_ID")
if not repo_id:
raise SystemExit(
"No checkpoint: pass --ckpt, or set SPACEFORMER_CKPT (local path) or "
"HF_REPO_ID (HuggingFace model repo to download from)."
)
from huggingface_hub import hf_hub_download
return hf_hub_download(repo_id=repo_id, filename=HF_FILENAME)
def main() -> None:
    """Parse CLI arguments, run SpaceFormer on one scene, and report the result.

    The steps are: choose the class vocabulary, resolve the checkpoint path,
    build the network and load its weights, load the scene and wrap it in a
    batch, build text embeddings for the class names, predict instances, print
    a summary, and optionally save the results to ``--save``.

    The checkpoint is resolved *before* the network is built so that a missing
    checkpoint configuration aborts immediately instead of after model
    construction.

    Raises:
        SystemExit: From ``argparse`` on invalid arguments, or from
            ``resolve_ckpt`` when no checkpoint source is configured.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument(
        "--ckpt",
        default=None,
        help="checkpoint path (else $SPACEFORMER_CKPT, else HF download via $HF_REPO_ID)",
    )
    ap.add_argument(
        "--scene",
        required=True,
        help="scene dir (coord.npy+color.npy) or .npy/.npz/.ply file",
    )
    ap.add_argument(
        "--class-names",
        nargs="+",
        default=None,
        help="open-vocab class names (default: a short built-in list)",
    )
    ap.add_argument(
        "--use-scannet200",
        action="store_true",
        help="use the full ScanNet200 label set as class names",
    )
    ap.add_argument(
        "--iou-head",
        action="store_true",
        help="build the learned IoU head (only for a checkpoint that has one)",
    )
    ap.add_argument("--save", default=None, help="optional output .npz path")
    ap.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="torch device to run on (default: cuda if available, else cpu)",
    )
    args = ap.parse_args()
    device = torch.device(args.device)

    # --use-scannet200 takes precedence over --class-names when both are given.
    if args.use_scannet200:
        class_names = list(CLASS_LABELS_200)
    else:
        class_names = args.class_names or list(DEFAULT_CLASS_NAMES)
    print(f"[vocab] {len(class_names)} classes")

    # Fail fast: resolve (or download) the checkpoint before building the
    # model, so a missing checkpoint does not cost a model construction first.
    ckpt_path = resolve_ckpt(args.ckpt)

    net = build_spaceformer(use_iou_head=args.iou_head, device=device)
    missing, unexpected = load_spaceformer_checkpoint(net, ckpt_path)
    print(f"[weights] {len(missing)} missing, {len(unexpected)} unexpected")

    coord_np, color_np = load_scene(args.scene)
    batch = make_batch(coord_np, color_np, device)
    text_eval = build_text_embeddings(class_names, device)

    results = predict_instances(net, batch, text_eval, class_names)
    print_summary(results)
    if args.save:
        save_results(results, coord_np, args.save)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()