# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Open-vocabulary inference pipeline for the SpaceFormer release.

Task-specific glue that turns the raw model outputs (logit / mask / clip_feat, from
``warpconvnet.models.spaceformer.SpaCeFormerInstSeg``) into labeled instances:
scene I/O + minimal eval transforms, SigLIP2 text embedding (prompt-ensembled),
and SAM2-style mask post-processing. Kept in the release repo (not WarpConvNet),
since labeling/post-processing is downstream of the model.
"""

import os

import numpy as np
import torch

from postprocessing import apply_post_processing
from labels import PROMPT_TEMPLATES

# SigLIP2 text encoder used at training time. text_dim 1152 == model clip-head dim.
SIGLIP_MODEL_ID = "google/siglip2-so400m-patch14-224"

# Voxel size the model was trained/exported at (the model voxelizes internally;
# this only labels the data_dict for parity).
GRID_SIZE = 0.02

# Release eval post-processing: NMS on, min 20 pts/mask, DBSCAN/stability off.
POST_PROCESSING_CFG = { "use_dbscan": False, "use_stability_score": False, "use_nms": True, "nms_thresh": 0.7, "min_mask_points": 20, "objectness_thresh": 0.0, } # --------------------------------------------------------------------------- # # Scene loading + minimal eval transforms # --------------------------------------------------------------------------- # def load_scene(scene_path: str): """Load one scene as (coord[N,3] float meters, color[N,3] float 0-255).""" if os.path.isdir(scene_path): coord = np.load(os.path.join(scene_path, "coord.npy")) color = np.load(os.path.join(scene_path, "color.npy")) elif scene_path.endswith(".npz"): z = np.load(scene_path) coord = z[_first_key(z, ("coord", "coords", "xyz", "points"))] color = z[_first_key(z, ("color", "colors", "rgb"))] elif scene_path.endswith(".npy"): arr = np.load(scene_path) assert arr.ndim == 2 and arr.shape[1] >= 6, "expected [N,6] xyz+rgb .npy" coord, color = arr[:, :3], arr[:, 3:6] elif scene_path.endswith(".ply"): coord, color = _load_ply(scene_path) else: raise ValueError(f"unsupported scene format: {scene_path}") coord = np.ascontiguousarray(coord, dtype=np.float32) color = np.ascontiguousarray(color) if color.dtype != np.uint8 and color.max() <= 1.0 + 1e-6: color = (color * 255.0).round() color = color.astype(np.float32) print(f"[scene] {scene_path}: {coord.shape[0]} points") return coord, color def _first_key(z, candidates): for c in candidates: if c in z: return c raise KeyError(f"none of {candidates} in {list(z.keys())}") def _load_ply(path: str): try: from plyfile import PlyData v = PlyData.read(path)["vertex"] coord = np.stack([v["x"], v["y"], v["z"]], axis=1) if "red" in v.data.dtype.names: color = np.stack([v["red"], v["green"], v["blue"]], axis=1) else: color = np.full_like(coord, 127.5) return coord, color except ImportError: import open3d as o3d pcd = o3d.io.read_point_cloud(path) coord = np.asarray(pcd.points) color = np.asarray(pcd.colors) * 255.0 if pcd.has_colors() else 
np.full_like(coord, 127.5) return coord, color def make_batch(coord_np, color_np, device): """Apply the minimal eval transforms and build the single-sample data dict. CenterShift(apply_z) + NormalizeColor(0-255 -> [-1,1]) + offset [0, N]. Coords stay in meters — the model voxelizes internally at 2 cm. """ coord = torch.from_numpy(coord_np).float() color = torch.from_numpy(color_np).float() cmin = coord.min(dim=0).values cmax = coord.max(dim=0).values shift = torch.tensor([(cmin[0] + cmax[0]) / 2, (cmin[1] + cmax[1]) / 2, cmin[2]]) coord = coord - shift feat = color / 127.5 - 1.0 n = coord.shape[0] offset = torch.tensor([0, n], dtype=torch.int32) return { "coord": coord.to(device), "feat": feat.to(device), "offset": offset.to(device), "grid_size": GRID_SIZE, } # --------------------------------------------------------------------------- # # Text embeddings (prompt-ensembled) + prediction # --------------------------------------------------------------------------- # def build_text_embeddings(class_names, device): """Encode class names with SigLIP2 under multiple templates and average.""" from text_encoder import get_text_encoder from clip_eval import CLIPAlignmentEval clip_encoder = get_text_encoder(model_type="siglip2", model_id=SIGLIP_MODEL_ID, device=str(device)) evaluator = CLIPAlignmentEval(normalize_input=False) # matches official eval evaluator.prepare_target_embedding( class_names=list(class_names), clip_encoder=clip_encoder, device=device, prompt_templates=list(PROMPT_TEMPLATES), # prompt ensembling ON ) return evaluator @torch.inference_mode() def predict_instances(net, batch, text_eval, class_names): """Single forward pass -> post-processing -> open-vocab labels.""" out = net(batch) binary_logits = out["logit"][0] # [Q, 2] objectness over {fg, bg} mask_logits = out["mask"][0].T # [N, Q] -> [Q, N] clip_feats = out["clip_feat"][0] # [Q, D] pred_iou = out["pred_iou"][0] if "pred_iou" in out else None class_logits = text_eval.predict(clip_feats, 
return_logit=True) # [Q, C] masks, scores, _classes, indices = apply_post_processing( mask_logits, binary_logits, mask_threshold=0.0, point_coords=None, pp_cfg=POST_PROCESSING_CFG, pred_iou=pred_iou, ) results = [] if len(indices) > 0: probs = torch.softmax(class_logits[indices], dim=-1) # [K, C] class_probs, class_ids = probs.max(dim=1) final_scores = scores * class_probs for k in range(len(indices)): results.append( { "mask": masks[k].cpu().numpy().astype(bool), "label": class_names[int(class_ids[k])], "label_id": int(class_ids[k]), "score": float(final_scores[k]), } ) results.sort(key=lambda r: r["score"], reverse=True) return results # --------------------------------------------------------------------------- # # Output helpers # --------------------------------------------------------------------------- # def print_summary(results, top_k=20): print(f"\n=== {len(results)} predicted instances ===") header = f"{'#':>3} {'score':>6} {'points':>8} label" print(header) print("-" * len(header)) for i, r in enumerate(results[:top_k]): print(f"{i:>3} {r['score']:>6.3f} {int(r['mask'].sum()):>8} {r['label']}") if len(results) > top_k: print(f"... ({len(results) - top_k} more)") def save_results(results, coord_np, out_path): if not results: np.savez(out_path, masks=np.zeros((0, coord_np.shape[0]), dtype=bool), labels=np.array([]), scores=np.array([])) return np.savez( out_path, coord=coord_np, masks=np.stack([r["mask"] for r in results]), labels=np.array([r["label"] for r in results]), label_ids=np.array([r["label_id"] for r in results]), scores=np.array([r["score"] for r in results]), ) print(f"[save] wrote {len(results)} instances to {out_path}")