# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""HuggingFace Space: SpaceFormer open-vocabulary 3D instance segmentation release.

Presentation/deployment layer. All model + inference logic is imported from the
installed ``warpconvnet`` library (``warpconvnet.models.spaceformer``); this file
only adds the Gradio UI, the 3D Plotly viewer, and checkpoint download.

Upload an RGB point cloud (.ply / [N,6] .npy / .npz), type comma-separated class
names, and get an interactive 3D view colored by predicted instance + a ranked
table of {label, score, #points}.

WarpConvNet (with its compiled extension) and transformers must be installed in
the Space image. Configure the checkpoint via Space variables:
  HF_REPO_ID       model repo holding the checkpoint (e.g. chrischoy/SpaCeFormer)
  HF_FILENAME      checkpoint filename (default: spaceformer_512_siglip2_ssccc.ckpt)
  SPACEFORMER_CKPT explicit local checkpoint path (overrides the HF download)
"""

import os

import numpy as np
import torch

from warpconvnet.models.spaceformer import (
    build_spaceformer,
    load_spaceformer_checkpoint,
)

from labels import DEFAULT_CLASS_NAMES, PROMPT_TEMPLATES
from pipeline import (
    SIGLIP_MODEL_ID,
    load_scene,
    make_batch,
    predict_instances,
)

HF_REPO_ID = os.environ.get("HF_REPO_ID", "")
HF_FILENAME = os.environ.get("HF_FILENAME", "spaceformer_512_siglip2_ssccc.ckpt")

_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_STATE = {"net": None, "clip": None}  # lazy singletons, kept resident across requests

# Upper bound on the number of points sent to the browser by ``_plot``.
_MAX_PLOT_POINTS = 120_000


def _resolve_ckpt() -> str:
    """Return the path of the checkpoint to load.

    ``SPACEFORMER_CKPT`` wins when set; otherwise the file ``HF_FILENAME`` is
    fetched from the ``HF_REPO_ID`` repo via ``hf_hub_download``.

    Raises:
        RuntimeError: if neither ``SPACEFORMER_CKPT`` nor ``HF_REPO_ID`` is set.
    """
    explicit = os.environ.get("SPACEFORMER_CKPT")
    if explicit:
        return explicit
    if not HF_REPO_ID:
        raise RuntimeError(
            "Set SPACEFORMER_CKPT to a local checkpoint, or HF_REPO_ID to a "
            "HuggingFace model repo to auto-download from."
        )
    # Imported lazily so the hub client is only needed when a download happens.
    from huggingface_hub import hf_hub_download

    return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME)


def _get_model():
    """Build the network and load its checkpoint on first use, then reuse it."""
    if _STATE["net"] is None:
        net = build_spaceformer(device=_DEVICE)
        load_spaceformer_checkpoint(net, _resolve_ckpt())
        # Cache only after the checkpoint loaded, so a failed load is retried.
        _STATE["net"] = net
    return _STATE["net"]


def _get_clip_encoder():
    """Create the SigLIP2 text encoder on first use, then reuse it."""
    if _STATE["clip"] is None:
        from text_encoder import get_text_encoder

        _STATE["clip"] = get_text_encoder(
            model_type="siglip2", model_id=SIGLIP_MODEL_ID, device=str(_DEVICE)
        )
    return _STATE["clip"]


def _text_eval(class_names):
    """Return a ``CLIPAlignmentEval`` prepared for ``class_names``.

    A new evaluator is built on every call; only the text encoder is cached.
    """
    from clip_eval import CLIPAlignmentEval

    evaluator = CLIPAlignmentEval(normalize_input=False)
    evaluator.prepare_target_embedding(
        class_names=list(class_names),
        clip_encoder=_get_clip_encoder(),
        device=_DEVICE,
        prompt_templates=list(PROMPT_TEMPLATES),
    )
    return evaluator


def _palette(n):
    """Return ``max(n, 1)`` RGB tuples drawn from a fixed-seed generator.

    The seed is constant, so the i-th color is the same on every call.
    """
    rng = np.random.default_rng(0)
    return [tuple(int(x) for x in rng.integers(40, 230, size=3)) for _ in range(max(n, 1))]


def _select_instances(results, top_k, score_thresh):
    """Return the instances to display: score >= ``score_thresh``, first ``top_k``.

    The input order of ``results`` is preserved (presumably ranked by score by
    ``predict_instances`` -- TODO confirm). Shared by the plot and the table so
    that both always show the same set of instances.
    """
    return [r for r in results if r["score"] >= score_thresh][:top_k]


def _plot(coord_np, results, top_k, score_thresh):
    """Plotly 3D scatter colored by instance (top_k by score).

    Points not covered by a displayed instance are drawn grey. Clouds larger
    than ``_MAX_PLOT_POINTS`` are randomly subsampled with a fixed seed.
    """
    import plotly.graph_objects as go

    kept = _select_instances(results, top_k, score_thresh)
    n = coord_np.shape[0]
    rgb = np.full((n, 3), 160, dtype=np.uint8)  # grey background
    # Later instances overwrite earlier ones where masks overlap.
    for color, r in zip(_palette(len(kept)), kept):
        rgb[r["mask"]] = color

    # Subsample for browser responsiveness.
    if n > _MAX_PLOT_POINTS:
        idx = np.random.default_rng(0).choice(n, _MAX_PLOT_POINTS, replace=False)
    else:
        idx = np.arange(n)

    # Format color strings only for the points that are actually plotted,
    # instead of for the whole cloud before subsampling.
    colors = [f"rgb({r_},{g_},{b_})" for r_, g_, b_ in rgb[idx].tolist()]

    fig = go.Figure(
        data=[go.Scatter3d(
            x=coord_np[idx, 0], y=coord_np[idx, 1], z=coord_np[idx, 2],
            mode="markers",
            marker=dict(size=1.5, color=colors),
        )]
    )
    fig.update_layout(
        scene=dict(aspectmode="data"),
        margin=dict(l=0, r=0, t=0, b=0),
        showlegend=False,
    )
    return fig


def segment(scene_file, class_text, top_k, score_thresh):
    """Gradio callback: file + class names -> (3D figure, results table).

    Args:
        scene_file: uploaded file (object with ``.name``) or a path; ``None``
            when nothing was uploaded.
        class_text: comma-separated class names; blank entries are dropped and
            ``DEFAULT_CLASS_NAMES`` is used when nothing remains.
        top_k: maximum number of instances shown.
        score_thresh: minimum score an instance needs to be shown.

    Returns:
        ``(figure, rows)`` where each row is ``[label, score, #points]``. The
        rows list exactly the instances colored in the figure.
    """
    if scene_file is None:
        return None, [["(upload a point cloud first)", "", ""]]

    top_k = int(top_k)
    score_thresh = float(score_thresh)

    # ``or ""`` guards against a cleared textbox being delivered as None.
    class_names = [c.strip() for c in (class_text or "").split(",") if c.strip()] \
        or list(DEFAULT_CLASS_NAMES)

    path = scene_file.name if hasattr(scene_file, "name") else scene_file
    coord_np, color_np = load_scene(path)
    batch = make_batch(coord_np, color_np, _DEVICE)

    net = _get_model()
    results = predict_instances(net, batch, _text_eval(class_names), class_names)

    fig = _plot(coord_np, results, top_k, score_thresh)
    # Apply the same threshold + top_k selection as the plot; previously the
    # table ignored the threshold and could list instances that were not drawn.
    table = [
        [r["label"], f"{r['score']:.3f}", int(r["mask"].sum())]
        for r in _select_instances(results, top_k, score_thresh)
    ]
    if not table:
        table = [["(no instances above threshold)", "", ""]]
    return fig, table


def build_interface():
    """Assemble the Gradio Blocks UI and wire the button to ``segment``."""
    import gradio as gr

    with gr.Blocks(title="SpaceFormer — Open-Vocab 3D Instance Segmentation") as demo:
        gr.Markdown(
            "# SpaceFormer\n"
            "Proposal-free **open-vocabulary 3D instance segmentation**. Upload an "
            "RGB point cloud, type any class names, and get instance masks labeled "
            "against your vocabulary (SigLIP2 text + prompt ensembling).\n\n"
            "Released checkpoint: ScanNet200 **0.1265** / ScanNet++ 0.2217 / Replica 0.2644."
        )
        with gr.Row():
            with gr.Column(scale=1):
                scene_file = gr.File(label="Point cloud (.ply / .npy[N,6] / .npz)")
                class_text = gr.Textbox(
                    label="Class names (comma-separated)",
                    value=", ".join(DEFAULT_CLASS_NAMES),
                )
                top_k = gr.Slider(1, 100, value=30, step=1, label="Max instances shown")
                score_thresh = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="Score threshold")
                run = gr.Button("Segment", variant="primary")
            with gr.Column(scale=2):
                plot = gr.Plot(label="Predicted instances (colored)")
                table = gr.Dataframe(
                    headers=["label", "score", "#points"],
                    label="Instances",
                    wrap=True,
                )
        run.click(segment, [scene_file, class_text, top_k, score_thresh], [plot, table])
    return demo


if __name__ == "__main__":
    build_interface().launch(server_name="0.0.0.0")