| |
| |
| """SpaceFormer single-scene open-vocabulary 3D instance segmentation (CLI inference). |
| |
| Self-contained command-line entry point for the released |
| SpaceFormer. The model + inference pipeline come from the installed ``warpconvnet`` |
| library (``warpconvnet.models.spaceformer``); this script only wires up checkpoint |
| resolution (local path or HuggingFace auto-download), runs one scene, and |
| prints/saves the predicted instances. |
| |
| # weights from a local checkpoint |
| python inference.py --ckpt /path/to/spaceformer_512_siglip2_ssccc.ckpt \ |
| --scene /path/to/scene_dir |
| |
| # or auto-download from a HuggingFace model repo |
| HF_REPO_ID=chrischoy/SpaCeFormer python inference.py \ |
| --scene my_scene.ply --class-names "office chair" "desk" "monitor" "other" |
| |
| Requires WarpConvNet (with its compiled extension) + transformers in the environment. |
| """ |
|
|
| import argparse |
| import os |
|
|
| import torch |
|
|
| from warpconvnet.models.spaceformer import ( |
| build_spaceformer, |
| load_spaceformer_checkpoint, |
| ) |
| from labels import CLASS_LABELS_200, DEFAULT_CLASS_NAMES |
| from pipeline import ( |
| build_text_embeddings, |
| load_scene, |
| make_batch, |
| predict_instances, |
| print_summary, |
| save_results, |
| ) |
|
|
# Name of the checkpoint file requested from the HuggingFace repo by
# resolve_ckpt(); set $HF_FILENAME to override the default.
HF_FILENAME = os.getenv("HF_FILENAME", "spaceformer_512_siglip2_ssccc.ckpt")
|
|
|
|
| def resolve_ckpt(ckpt_arg: str | None) -> str: |
| """Return a local checkpoint path: explicit --ckpt, $SPACEFORMER_CKPT, or HF download.""" |
| if ckpt_arg: |
| return ckpt_arg |
| explicit = os.environ.get("SPACEFORMER_CKPT") |
| if explicit: |
| return explicit |
| repo_id = os.environ.get("HF_REPO_ID") |
| if not repo_id: |
| raise SystemExit( |
| "No checkpoint: pass --ckpt, or set SPACEFORMER_CKPT (local path) or " |
| "HF_REPO_ID (HuggingFace model repo to download from)." |
| ) |
| from huggingface_hub import hf_hub_download |
| return hf_hub_download(repo_id=repo_id, filename=HF_FILENAME) |
|
|
|
|
def _parse_args() -> argparse.Namespace:
    """Declare the command-line options and parse the process arguments."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--ckpt",
        default=None,
        help="checkpoint path (else $SPACEFORMER_CKPT, else HF download via $HF_REPO_ID)",
    )
    parser.add_argument(
        "--scene",
        required=True,
        help="scene dir (coord.npy+color.npy) or .npy/.npz/.ply file",
    )
    parser.add_argument(
        "--class-names",
        nargs="+",
        default=None,
        help="open-vocab class names (default: a short built-in list)",
    )
    parser.add_argument(
        "--use-scannet200",
        action="store_true",
        help="use the full ScanNet200 label set as class names",
    )
    parser.add_argument(
        "--iou-head",
        action="store_true",
        help="build the learned IoU head (only for a checkpoint that has one)",
    )
    parser.add_argument("--save", default=None, help="optional output .npz path")
    # Default to the GPU when torch reports one, otherwise fall back to the CPU.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser.add_argument("--device", default=default_device)
    return parser.parse_args()


def main() -> None:
    """Run SpaceFormer on a single scene and report the predicted instances.

    Steps, in order: parse the CLI options, choose the class vocabulary,
    build the network and load its checkpoint, load the scene, run
    ``predict_instances``, print a summary, and write the results when
    ``--save`` is given.
    """
    opts = _parse_args()
    device = torch.device(opts.device)

    # Vocabulary precedence: --use-scannet200, then --class-names, then the
    # built-in default list.
    if opts.use_scannet200:
        vocab = list(CLASS_LABELS_200)
    elif opts.class_names:
        vocab = opts.class_names
    else:
        vocab = list(DEFAULT_CLASS_NAMES)
    print(f"[vocab] {len(vocab)} classes")

    # The checkpoint is resolved only after the network has been built.
    model = build_spaceformer(use_iou_head=opts.iou_head, device=device)
    ckpt_path = resolve_ckpt(opts.ckpt)
    missing, unexpected = load_spaceformer_checkpoint(model, ckpt_path)
    print(f"[weights] {len(missing)} missing, {len(unexpected)} unexpected")

    coord_np, color_np = load_scene(opts.scene)
    batch = make_batch(coord_np, color_np, device)
    text_eval = build_text_embeddings(vocab, device)
    results = predict_instances(model, batch, text_eval, vocab)

    print_summary(results)
    if opts.save:
        save_results(results, coord_np, opts.save)
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|