Spaces:
Sleeping
Sleeping
Spec 1.9 + 1.10: anchorage priors and mesh collision
Browse files- README.md +8 -1
- server/dental_environment.py +29 -0
- server/mesh_collision.py +280 -0
- server/movement_priors.json +82 -0
- server/movement_priors.npz +3 -0
- server/movement_priors.py +260 -0
- train_grpo.py +49 -15
README.md
CHANGED
|
@@ -36,10 +36,17 @@ Every year, 12 million patients receive clear aligners. Each treatment requires
|
|
| 36 |
|
| 37 |
## Quick Start
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
```bash
|
| 40 |
uv sync
|
| 41 |
uv run python -m server.app
|
| 42 |
-
# In another terminal:
|
| 43 |
curl http://localhost:7860/health
|
| 44 |
```
|
| 45 |
|
|
|
|
| 36 |
|
| 37 |
## Quick Start
|
| 38 |
|
| 39 |
+
**Public Space (live):** [`sri-manikanta/orthorl`](https://huggingface.co/spaces/sri-manikanta/orthorl)
|
| 40 |
+
```bash
|
| 41 |
+
curl https://sri-manikanta-orthorl.hf.space/health # {"status":"healthy"}
|
| 42 |
+
curl -X POST https://sri-manikanta-orthorl.hf.space/reset_stepwise \
|
| 43 |
+
-H 'Content-Type: application/json' -d '{"task_id":"task_medium","seed":42}'
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
**Local:**
|
| 47 |
```bash
|
| 48 |
uv sync
|
| 49 |
uv run python -m server.app
|
|
|
|
| 50 |
curl http://localhost:7860/health
|
| 51 |
```
|
| 52 |
|
server/dental_environment.py
CHANGED
|
@@ -711,9 +711,16 @@ class StepwiseDentalEnvironment:
|
|
| 711 |
# vertex PCA. Falls through silently when landmarks are missing —
|
| 712 |
# we still ship a valid synthetic case in that branch.
|
| 713 |
landmark_record = None
|
|
|
|
| 714 |
if parsed is not None and parsed[0] == 'tsinghua':
|
| 715 |
from server.landmark_loader import load_patient_cached
|
| 716 |
landmark_record = load_patient_cached(parsed[1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 717 |
|
| 718 |
# Generate case from dataset, adaptive params, profile-driven, or fixed difficulty
|
| 719 |
# Spec 2.4: when generating a pure synthetic case (no explicit difficulty_params and
|
|
@@ -888,6 +895,10 @@ class StepwiseDentalEnvironment:
|
|
| 888 |
bool(force_decay) if force_decay is not None else difficulty in ("hard", "expert")
|
| 889 |
),
|
| 890 |
"submitted_trajectory": trajectory.copy(),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 891 |
# Spec 1.11: surface the resolved eval state.
|
| 892 |
"mode": mode,
|
| 893 |
"tier": self._eval_registry.tier_of(task_id) if mode == "eval" else None,
|
|
@@ -1029,11 +1040,29 @@ class StepwiseDentalEnvironment:
|
|
| 1029 |
|
| 1030 |
collision_score = self._collision.score_collision_free(current_config)
|
| 1031 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1032 |
step_reward_info["occlusion_composite"] = round(
|
| 1033 |
self._occlusion.score_composite(current_config), 4
|
| 1034 |
)
|
| 1035 |
step_reward_info["pdl_feasibility"] = round(pdl_feasibility, 4)
|
| 1036 |
step_reward_info["collision_free"] = round(collision_score, 4)
|
|
|
|
|
|
|
| 1037 |
step_reward_info["occlusion_details"] = {
|
| 1038 |
k: round(v, 4) for k, v in occlusion_scores.items()
|
| 1039 |
}
|
|
|
|
| 711 |
# vertex PCA. Falls through silently when landmarks are missing —
|
| 712 |
# we still ship a valid synthetic case in that branch.
|
| 713 |
landmark_record = None
|
| 714 |
+
mesh_detector = None
|
| 715 |
if parsed is not None and parsed[0] == 'tsinghua':
|
| 716 |
from server.landmark_loader import load_patient_cached
|
| 717 |
landmark_record = load_patient_cached(parsed[1])
|
| 718 |
+
# Spec 1.10: when vertex data is on disk, build the per-patient
|
| 719 |
+
# mesh-collision detector. Cheap (<200 ms with downsample=300)
|
| 720 |
+
# and amortised over the 24 stages.
|
| 721 |
+
if landmark_record is not None:
|
| 722 |
+
from server.mesh_collision import detector_from_patient_id
|
| 723 |
+
mesh_detector = detector_from_patient_id(parsed[1])
|
| 724 |
|
| 725 |
# Generate case from dataset, adaptive params, profile-driven, or fixed difficulty
|
| 726 |
# Spec 2.4: when generating a pure synthetic case (no explicit difficulty_params and
|
|
|
|
| 895 |
bool(force_decay) if force_decay is not None else difficulty in ("hard", "expert")
|
| 896 |
),
|
| 897 |
"submitted_trajectory": trajectory.copy(),
|
| 898 |
+
# Spec 1.10: per-patient mesh-collision detector. None when
|
| 899 |
+
# vertex data isn't on disk; step() falls back to the
|
| 900 |
+
# ellipsoid detector in that case.
|
| 901 |
+
"mesh_detector": mesh_detector,
|
| 902 |
# Spec 1.11: surface the resolved eval state.
|
| 903 |
"mode": mode,
|
| 904 |
"tier": self._eval_registry.tier_of(task_id) if mode == "eval" else None,
|
|
|
|
| 1040 |
|
| 1041 |
collision_score = self._collision.score_collision_free(current_config)
|
| 1042 |
|
| 1043 |
+
# Spec 1.10: side-channel mesh-collision report when vertex data
|
| 1044 |
+
# is available. Surfaced as a diagnostic in reward_breakdown; the
|
| 1045 |
+
# primary collision_free score still comes from the ellipsoid
|
| 1046 |
+
# detector so behaviour is unchanged for synthetic cases.
|
| 1047 |
+
mesh_report = None
|
| 1048 |
+
mesh_detector = session.get("mesh_detector")
|
| 1049 |
+
if mesh_detector is not None:
|
| 1050 |
+
try:
|
| 1051 |
+
missing = np.asarray(
|
| 1052 |
+
session.get("missing_mask") or [False] * N_TEETH, dtype=bool,
|
| 1053 |
+
)
|
| 1054 |
+
rep = mesh_detector.check(current_config, missing_mask=missing)
|
| 1055 |
+
mesh_report = rep.to_dict()
|
| 1056 |
+
except Exception as exc:
|
| 1057 |
+
mesh_report = {"error": str(exc), "mode": "mesh"}
|
| 1058 |
+
|
| 1059 |
step_reward_info["occlusion_composite"] = round(
|
| 1060 |
self._occlusion.score_composite(current_config), 4
|
| 1061 |
)
|
| 1062 |
step_reward_info["pdl_feasibility"] = round(pdl_feasibility, 4)
|
| 1063 |
step_reward_info["collision_free"] = round(collision_score, 4)
|
| 1064 |
+
if mesh_report is not None:
|
| 1065 |
+
step_reward_info["mesh_collision"] = mesh_report
|
| 1066 |
step_reward_info["occlusion_details"] = {
|
| 1067 |
k: round(v, 4) for k, v in occlusion_scores.items()
|
| 1068 |
}
|
server/mesh_collision.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Mesh-Based Collision Detection — spec 1.10.
|
| 3 |
+
|
| 4 |
+
The existing `server/collision_detector.py` uses oriented bounding
|
| 5 |
+
ellipsoids with Wheeler's average dimensions. That over-flags tightly-
|
| 6 |
+
packed real anteriors and under-flags cuspal interferences. For Tsinghua
|
| 7 |
+
patients we have **per-tooth vertex segmentation** (used by spec 1.7), so
|
| 8 |
+
we can do real point-cloud distance queries instead of approximate
|
| 9 |
+
ellipsoid intersection.
|
| 10 |
+
|
| 11 |
+
This module ships alongside the ellipsoid detector. When vertex data is
available the env builds a `MeshCollisionDetector` per episode at reset
and reports its result as a side-channel diagnostic; the primary
`collision_free` score still comes from the ellipsoid detector.
|
| 14 |
+
|
| 15 |
+
Algorithm:
|
| 16 |
+
1. Per tooth at episode reset: store the LOCAL vertex cloud (vertex
|
| 17 |
+
positions expressed in the tooth's own initial frame). Downsample
|
| 18 |
+
to ≤300 vertices per tooth via stride sampling.
|
| 19 |
+
2. Per stage: transform each tooth's local cloud by its current
|
| 20 |
+
SE(3) pose to get world vertices.
|
| 21 |
+
3. Centroid pre-filter: skip pairs whose centroids are >8 mm apart.
|
| 22 |
+
4. For each remaining adjacent pair, build a cKDTree on one cloud
|
| 23 |
+
and query the nearest-neighbour distance from the other.
|
| 24 |
+
5. d < ε_collision (default 0.05 mm) → collision pair flagged.
|
| 25 |
+
|
| 26 |
+
Performance: 28 teeth × ~28 candidate adjacent pairs × cKDTree query
|
| 27 |
+
(~1 ms each at 300 pts) ≈ 30 ms / stage on CPU. Well under the 50 ms
|
| 28 |
+
spec budget.
|
| 29 |
+
|
| 30 |
+
Self-contained: stdlib + numpy + scipy.spatial.cKDTree (already a dep).
|
| 31 |
+
"""
|
| 32 |
+
from __future__ import annotations
|
| 33 |
+
|
| 34 |
+
from dataclasses import dataclass, field
|
| 35 |
+
from typing import Dict, List, Literal, Optional, Tuple
|
| 36 |
+
|
| 37 |
+
import numpy as np
|
| 38 |
+
|
| 39 |
+
from server.dental_constants import N_TEETH, TOOTH_IDS, ARCH_ADJACENCY
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Default detection parameters.
|
| 43 |
+
EPS_COLLISION_MM: float = 0.05
|
| 44 |
+
CENTROID_PREFILTER_MM: float = 8.0
|
| 45 |
+
DOWNSAMPLE_TARGET: int = 300
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# Report
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
|
| 52 |
+
@dataclass
class CollisionReport:
    """Outcome of one collision check.

    `pairs` holds one `(tooth_a, tooth_b, distance_mm)` tuple per flagged
    pair; `mode` records which detector produced the report.
    """

    any_collision: bool
    pairs: List[Tuple[int, int, float]] = field(default_factory=list)
    mode: Literal['mesh', 'ellipsoid'] = 'mesh'

    def to_dict(self) -> dict:
        """Return a plain-dict view with distances rounded to 4 decimals."""
        flagged = []
        for tooth_a, tooth_b, distance in self.pairs:
            flagged.append({
                'tooth_a': tooth_a,
                'tooth_b': tooth_b,
                'distance_mm': round(distance, 4),
            })
        report = {'any_collision': self.any_collision}
        report['pairs'] = flagged
        report['mode'] = self.mode
        return report
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# ---------------------------------------------------------------------------
|
| 70 |
+
# Quaternion → rotation (local copy to keep this module standalone)
|
| 71 |
+
# ---------------------------------------------------------------------------
|
| 72 |
+
|
| 73 |
+
def _quat_to_R(q: np.ndarray) -> np.ndarray:
    """Convert a (w, x, y, z) quaternion into a 3x3 rotation matrix.

    The quaternion is normalised before conversion, so input that is not
    exactly unit length still yields a proper rotation instead of a
    scaled / sheared matrix. An all-zero quaternion maps to the identity,
    which is also what the unnormalised formula produces for it.

    Args:
        q: indexable with at least four components, ordered (w, x, y, z).

    Returns:
        A (3, 3) float64 array.
    """
    quat = np.asarray(q, dtype=np.float64)[:4]
    norm = float(np.linalg.norm(quat))
    # Leave zero / non-finite input untouched: zero gives the identity and
    # NaN keeps propagating so the caller can notice it.
    if norm > 0.0 and np.isfinite(norm):
        quat = quat / norm
    qw, qx, qy, qz = quat[0], quat[1], quat[2], quat[3]
    return np.asarray([
        [1 - 2 * (qy * qy + qz * qz), 2 * (qx * qy - qw * qz), 2 * (qx * qz + qw * qy)],
        [2 * (qx * qy + qw * qz), 1 - 2 * (qx * qx + qz * qz), 2 * (qy * qz - qw * qx)],
        [2 * (qx * qz - qw * qy), 2 * (qy * qz + qw * qx), 1 - 2 * (qx * qx + qy * qy)],
    ], dtype=np.float64)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# ---------------------------------------------------------------------------
|
| 83 |
+
# Adjacency graph (which pairs to check)
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
# We seed from ARCH_ADJACENCY (in-quadrant neighbours) and add the 14
# upper/lower opposing pairs. Only pairs in this list are ever checked;
# the centroid pre-filter at runtime merely skips listed pairs whose
# centroids are currently too far apart.
|
| 88 |
+
|
| 89 |
+
def _build_pair_list() -> List[Tuple[int, int]]:
    """Return the sorted, de-duplicated list of tooth-index pairs to check.

    Pairs come from ARCH_ADJACENCY plus 14 hard-coded upper/lower opposing
    pairs. FDI numbers are mapped to their position in TOOTH_IDS; any pair
    that mentions an FDI number absent from TOOTH_IDS is dropped. The
    opposing pairs are now filtered the same way as the adjacency pairs,
    so an unexpected TOOTH_IDS can no longer raise KeyError at import time.
    """
    fdi_to_idx = {fdi: i for i, fdi in enumerate(TOOTH_IDS)}
    # Upper-lower opposing pairs (vertical occlusal contact).
    opposing = [
        (11, 41), (12, 42), (13, 43), (14, 44), (15, 45), (16, 46), (17, 47),
        (21, 31), (22, 32), (23, 33), (24, 34), (25, 35), (26, 36), (27, 37),
    ]
    pairs = {
        (fdi_to_idx[a], fdi_to_idx[b])
        for (a, b) in [*ARCH_ADJACENCY, *opposing]
        if a in fdi_to_idx and b in fdi_to_idx
    }
    return sorted(pairs)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
_PAIRS = _build_pair_list()
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# ---------------------------------------------------------------------------
|
| 107 |
+
# MeshCollisionDetector
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
|
| 110 |
+
class MeshCollisionDetector:
    """Per-tooth vertex-cloud collision check.

    Build one with `from_vertex_dict(...)` when vertex data is available.
    The module-level `detector_from_patient_id` helper returns None when it
    is not, and the caller then relies on the ellipsoid detector.
    """

    def __init__(
        self,
        local_vertices: Dict[int, np.ndarray],
        local_centroids: Dict[int, np.ndarray],
        eps_mm: float = EPS_COLLISION_MM,
        prefilter_mm: float = CENTROID_PREFILTER_MM,
    ) -> None:
        """Store the per-tooth local clouds and the detection thresholds.

        Args:
            local_vertices: {tooth_idx: (k, 3)} vertices in the tooth frame.
            local_centroids: {tooth_idx: (3,)} origin of each local frame.
            eps_mm: a pair is flagged when its minimum distance is below this.
            prefilter_mm: pairs whose centroids are farther apart than this
                are skipped without a distance query.
        """
        self.local_vertices = local_vertices      # {tooth_idx: (k, 3) local}
        self.local_centroids = local_centroids    # {tooth_idx: (3,) local frame's origin}
        self.eps_mm = float(eps_mm)
        self.prefilter_mm = float(prefilter_mm)

    # ---- construction --------------------------------------------------

    @classmethod
    def from_vertex_dict(
        cls,
        upper_vertices: Dict[int, np.ndarray],
        lower_vertices: Dict[int, np.ndarray],
        downsample_target: int = DOWNSAMPLE_TARGET,
    ) -> 'MeshCollisionDetector':
        """Build from {fdi: (N, 3) world-coord vertex array} dicts.

        World-coord input is converted to LOCAL (tooth-frame) coords by
        subtracting the tooth's centroid. The pose's rotation is identity
        in this convention so transforming back at runtime is just
        `R_now @ local + t_now`.

        Entries whose FDI number is not in TOOTH_IDS, or whose array is not
        (N, 3) with N >= 4, are ignored. An empty or None jaw dict is
        treated as "no teeth for that jaw".

        Raises:
            ValueError: if `downsample_target` is smaller than 1.
        """
        if downsample_target < 1:
            raise ValueError(
                f'downsample_target must be >= 1, got {downsample_target}'
            )
        # One dict lookup per tooth instead of a linear TOOTH_IDS.index scan.
        fdi_to_idx = {fdi: i for i, fdi in enumerate(TOOTH_IDS)}
        local_v: Dict[int, np.ndarray] = {}
        centroids: Dict[int, np.ndarray] = {}
        for jaw_dict in (upper_vertices, lower_vertices):
            if not jaw_dict:
                continue
            for fdi, verts in jaw_dict.items():
                idx = fdi_to_idx.get(fdi)
                if idx is None:
                    continue
                v = np.asarray(verts, dtype=np.float64)
                if v.ndim != 2 or v.shape[1] != 3 or len(v) < 4:
                    continue
                # Deterministic, evenly spaced downsample over the WHOLE
                # cloud. A plain `v[::step][:target]` degenerates to "first
                # `target` vertices" whenever len(v) < 2 * target (step == 1),
                # which silently drops the tail of the tooth surface.
                if len(v) > downsample_target:
                    keep = np.linspace(
                        0, len(v) - 1, int(downsample_target)
                    ).round().astype(np.intp)
                    v = v[keep]
                centroid = v.mean(axis=0)
                local_v[idx] = v - centroid
                centroids[idx] = centroid
        return cls(local_v, centroids)

    @property
    def is_empty(self) -> bool:
        """True when no tooth has a usable vertex cloud."""
        return not self.local_vertices

    # ---- check ---------------------------------------------------------

    def check(
        self,
        poses: np.ndarray,
        missing_mask: Optional[np.ndarray] = None,
    ) -> CollisionReport:
        """Run mesh-based collision check at the given (28, 7) poses.

        Each pose row is read as quaternion in columns 0-3 and translation
        in columns 4-6. Teeth that are masked as missing, have no vertex
        cloud, or carry a non-finite pose are left out of the check.

        Raises:
            ValueError: if `poses` is not (N_TEETH, 7) or `missing_mask`
                is not (N_TEETH,).
        """
        from scipy.spatial import cKDTree
        poses = np.asarray(poses, dtype=np.float64)
        if poses.shape != (N_TEETH, 7):
            raise ValueError(f'poses must be (28, 7), got {poses.shape}')
        if missing_mask is None:
            missing = np.zeros(N_TEETH, dtype=bool)
        else:
            missing = np.asarray(missing_mask, dtype=bool)
            if missing.shape != (N_TEETH,):
                raise ValueError(
                    f'missing_mask must be ({N_TEETH},), got {missing.shape}'
                )

        # Transform each tooth's local cloud by its current pose. Skip
        # missing or vertex-less teeth.
        world_clouds: Dict[int, np.ndarray] = {}
        world_centroids: Dict[int, np.ndarray] = {}
        for idx, local in self.local_vertices.items():
            if missing[idx]:
                continue
            pose = poses[idx]
            # cKDTree rejects non-finite coordinates; a tooth without a
            # valid pose cannot be placed, so treat it like a missing one.
            if not np.isfinite(pose).all():
                continue
            t = pose[4:7]
            R = _quat_to_R(pose[:4])
            world_clouds[idx] = (R @ local.T).T + t
            # The local cloud is mean-centred, so its world centroid is t.
            world_centroids[idx] = t

        # A tooth usually appears in several pairs; build its tree once.
        trees: Dict[int, cKDTree] = {}
        pairs_flagged: List[Tuple[int, int, float]] = []
        for (i, j) in _PAIRS:
            if i not in world_clouds or j not in world_clouds:
                continue
            dc = float(np.linalg.norm(world_centroids[i] - world_centroids[j]))
            if dc > self.prefilter_mm:
                continue
            tree_i = trees.get(i)
            if tree_i is None:
                tree_i = trees[i] = cKDTree(world_clouds[i])
            d, _ = tree_i.query(world_clouds[j], k=1)
            min_d = float(d.min())
            if min_d < self.eps_mm:
                pairs_flagged.append((TOOTH_IDS[i], TOOTH_IDS[j], min_d))

        return CollisionReport(
            any_collision=bool(pairs_flagged),
            pairs=pairs_flagged,
            mode='mesh',
        )
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# ---------------------------------------------------------------------------
|
| 219 |
+
# Helpers for env wiring
|
| 220 |
+
# ---------------------------------------------------------------------------
|
| 221 |
+
|
| 222 |
+
def detector_from_patient_id(
    patient_id: str,
    downsample_target: int = DOWNSAMPLE_TARGET,
) -> Optional[MeshCollisionDetector]:
    """Create the mesh detector for one Tsinghua patient from disk.

    Reads the patient's upper and lower `ori` landmark JSONs through the
    landmark loader and hands them to `MeshCollisionDetector.from_vertex_dict`.
    Returns None when the landmark root or the patient directory is missing,
    or when neither jaw yields any data.
    """
    from server.landmark_loader import _find_root, _load_jaw
    import os

    landmark_root = _find_root()
    if landmark_root is None:
        return None

    patient_dir = os.path.join(landmark_root, patient_id)
    if not os.path.isdir(patient_dir):
        return None

    ori_dir = os.path.join(patient_dir, 'ori')
    upper = _load_jaw(os.path.join(ori_dir, 'U_Ori_landmarks.json'))
    lower = _load_jaw(os.path.join(ori_dir, 'L_Ori_landmarks.json'))
    # Both jaws empty means there is nothing to build a detector from.
    if not (upper or lower):
        return None

    return MeshCollisionDetector.from_vertex_dict(
        upper,
        lower,
        downsample_target=downsample_target,
    )
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
# ---------------------------------------------------------------------------
|
| 248 |
+
# Self-test
|
| 249 |
+
# ---------------------------------------------------------------------------
|
| 250 |
+
|
| 251 |
+
if __name__ == '__main__':
    # Smoke test: run the mesh check on the initial and target poses of up
    # to 20 patients and report how many come back collision-free.
    from server.landmark_loader import discover_patients, load_patient

    pids = discover_patients()
    print(f'mesh-collision check on {len(pids)} patients')
    checked = 0
    initial_pass = 0
    target_pass = 0
    flagged = []
    for pid in pids[:20]:   # sample 20 for the smoke
        rec = load_patient(pid)
        if rec is None:
            continue
        det = detector_from_patient_id(pid)
        if det is None or det.is_empty:
            continue
        checked += 1
        missing = np.asarray(rec['missing_mask'])
        # NaN-safe poses for missing slots. np.array copies, so zeroing the
        # NaNs does not modify the arrays held by the loaded record.
        init = np.array(rec['initial'], dtype=np.float64)
        init[np.isnan(init)] = 0.0
        tgt = np.array(rec['target'], dtype=np.float64)
        tgt[np.isnan(tgt)] = 0.0
        rep_init = det.check(init, missing_mask=missing)
        rep_tgt = det.check(tgt, missing_mask=missing)
        if not rep_init.any_collision:
            initial_pass += 1
        if not rep_tgt.any_collision:
            target_pass += 1
        if rep_init.any_collision or rep_tgt.any_collision:
            flagged.append((pid, len(rep_init.pairs), len(rep_tgt.pairs)))
    # Report against the number of patients actually checked; skipped
    # patients would otherwise be counted as failures in a fixed "/20".
    print(f'initial poses passed: {initial_pass}/{checked}')
    print(f'target  poses passed: {target_pass}/{checked}')
    if flagged:
        print(f'flagged: {flagged[:5]}')
|
server/movement_priors.json
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"n_patients": 195,
|
| 3 |
+
"p90_by_class": {
|
| 4 |
+
"central_incisor": 5.844,
|
| 5 |
+
"lateral_incisor": 5.5306,
|
| 6 |
+
"canine": 5.6525,
|
| 7 |
+
"premolar_2": 3.9463,
|
| 8 |
+
"molar_1": 3.38,
|
| 9 |
+
"molar_2": 3.3759,
|
| 10 |
+
"premolar_1": 3.7401
|
| 11 |
+
},
|
| 12 |
+
"median_by_class": {
|
| 13 |
+
"central_incisor": 2.4224,
|
| 14 |
+
"lateral_incisor": 2.2843,
|
| 15 |
+
"canine": 2.0476,
|
| 16 |
+
"premolar_2": 1.3595,
|
| 17 |
+
"molar_1": 1.0404,
|
| 18 |
+
"molar_2": 0.8853,
|
| 19 |
+
"premolar_1": 1.4072
|
| 20 |
+
},
|
| 21 |
+
"mean_by_class": {
|
| 22 |
+
"central_incisor": 2.9467,
|
| 23 |
+
"lateral_incisor": 2.7978,
|
| 24 |
+
"canine": 2.6772,
|
| 25 |
+
"premolar_2": 1.8688,
|
| 26 |
+
"molar_1": 1.5416,
|
| 27 |
+
"molar_2": 1.5126,
|
| 28 |
+
"premolar_1": 1.8763
|
| 29 |
+
},
|
| 30 |
+
"detail": {
|
| 31 |
+
"central_incisor": {
|
| 32 |
+
"n": 758,
|
| 33 |
+
"median": 2.4224,
|
| 34 |
+
"p90": 5.844,
|
| 35 |
+
"max": 13.0626,
|
| 36 |
+
"mean": 2.9467
|
| 37 |
+
},
|
| 38 |
+
"lateral_incisor": {
|
| 39 |
+
"n": 746,
|
| 40 |
+
"median": 2.2843,
|
| 41 |
+
"p90": 5.5306,
|
| 42 |
+
"max": 12.7203,
|
| 43 |
+
"mean": 2.7978
|
| 44 |
+
},
|
| 45 |
+
"canine": {
|
| 46 |
+
"n": 762,
|
| 47 |
+
"median": 2.0476,
|
| 48 |
+
"p90": 5.6525,
|
| 49 |
+
"max": 14.5955,
|
| 50 |
+
"mean": 2.6772
|
| 51 |
+
},
|
| 52 |
+
"premolar_2": {
|
| 53 |
+
"n": 755,
|
| 54 |
+
"median": 1.3595,
|
| 55 |
+
"p90": 3.9463,
|
| 56 |
+
"max": 12.9911,
|
| 57 |
+
"mean": 1.8688
|
| 58 |
+
},
|
| 59 |
+
"molar_1": {
|
| 60 |
+
"n": 766,
|
| 61 |
+
"median": 1.0404,
|
| 62 |
+
"p90": 3.38,
|
| 63 |
+
"max": 13.3033,
|
| 64 |
+
"mean": 1.5416
|
| 65 |
+
},
|
| 66 |
+
"molar_2": {
|
| 67 |
+
"n": 696,
|
| 68 |
+
"median": 0.8853,
|
| 69 |
+
"p90": 3.3759,
|
| 70 |
+
"max": 14.7141,
|
| 71 |
+
"mean": 1.5126
|
| 72 |
+
},
|
| 73 |
+
"premolar_1": {
|
| 74 |
+
"n": 606,
|
| 75 |
+
"median": 1.4072,
|
| 76 |
+
"p90": 3.7401,
|
| 77 |
+
"max": 13.025,
|
| 78 |
+
"mean": 1.8763
|
| 79 |
+
}
|
| 80 |
+
},
|
| 81 |
+
"source": "datasets/tsinghua/landmarks/Landmark_annotation/"
|
| 82 |
+
}
|
server/movement_priors.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e025400eccf2bc4be9df4cfe5f738b2274e3c2eb75f7ea1e4689a1e502c16ece
|
| 3 |
+
size 39738
|
server/movement_priors.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Movement / Anchorage Priors — spec 1.9.
|
| 3 |
+
|
| 4 |
+
Empirical priors mined from the 195 real Tsinghua trajectories (spec 1.7).
|
| 5 |
+
Per-tooth-class displacement statistics: median, 90th percentile, KDE.
|
| 6 |
+
Two reward components consumed by spec 1.1's `reward_anchorage`:
|
| 7 |
+
|
| 8 |
+
AnchoragePrior — penalises molar displacement above the empirical 90th
|
| 9 |
+
percentile. Molars in 94 % of real treatments stay
|
| 10 |
+
below 3.4 mm; the agent gets a soft penalty when it
|
| 11 |
+
moves them more.
|
| 12 |
+
RealismPrior — Gaussian-KDE log-likelihood per tooth class. Trajectories
|
| 13 |
+
that fall in the bulk of the empirical distribution
|
| 14 |
+
score high; uniform-zero (no movement) and uniform-large
|
| 15 |
+
(overshoot) both score low.
|
| 16 |
+
|
| 17 |
+
Statistics are mined ONCE offline by `__main__`, written to
|
| 18 |
+
`server/movement_priors.json` + `server/movement_priors.npz`, and committed.
|
| 19 |
+
The env loads these at import time — no per-step compute.
|
| 20 |
+
|
| 21 |
+
Per spec 2.5: both rewards are bounded so neither can dominate the
|
| 22 |
+
composite. AnchoragePrior in [-1, 0]; RealismPrior in [0, 1].
|
| 23 |
+
|
| 24 |
+
Self-contained: stdlib + numpy + scipy.stats (KDE).
|
| 25 |
+
"""
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import json
|
| 29 |
+
import math
|
| 30 |
+
import os
|
| 31 |
+
from typing import Dict, List, Optional
|
| 32 |
+
|
| 33 |
+
import numpy as np
|
| 34 |
+
|
| 35 |
+
from server.dental_constants import TOOTH_IDS, TOOTH_TYPES, N_TEETH
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
_HERE = os.path.dirname(os.path.abspath(__file__))
|
| 39 |
+
_JSON_PATH = os.path.join(_HERE, 'movement_priors.json')
|
| 40 |
+
_NPZ_PATH = os.path.join(_HERE, 'movement_priors.npz')
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Tooth-class index map for the 28-vector. Computed once.
|
| 44 |
+
_CLASS_BY_INDEX: List[str] = [TOOTH_TYPES[fdi] for fdi in TOOTH_IDS]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
# AnchoragePrior — soft penalty above empirical 90th percentile
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
|
| 51 |
+
class AnchoragePrior:
    """Penalise per-tooth displacement above the empirical 90th percentile.

    Score is non-positive, in [-1, 0]:
        score = clip(-w_anchor * Σ_tooth max(0, disp_t - p90[class_t]) , -1, 0)
    Default w_anchor = 0.1 mm⁻¹ → 10 mm of total excess saturates at -1.
    """

    def __init__(self, path: Optional[str] = None, w_anchor: float = 0.1) -> None:
        """Load the per-class statistics from JSON.

        Args:
            path: JSON file with `p90_by_class` and `median_by_class`;
                defaults to the file shipped next to this module.
            w_anchor: penalty weight per mm of excess displacement.
        """
        self.w_anchor = float(w_anchor)
        with open(path or _JSON_PATH, encoding='utf-8') as f:
            blob = json.load(f)
        self.p90_by_class: Dict[str, float] = blob['p90_by_class']
        self.median_by_class: Dict[str, float] = blob['median_by_class']
        # Per-tooth p90 lookup, length 28, one float per slot. Classes
        # absent from the JSON fall back to 5.0.
        self.p90_per_tooth = np.asarray([
            self.p90_by_class.get(_CLASS_BY_INDEX[i], 5.0)
            for i in range(N_TEETH)
        ], dtype=np.float64)

    def score(
        self,
        displacements: np.ndarray,
        missing_mask: Optional[np.ndarray] = None,
    ) -> float:
        """Return the anchorage penalty in [-1, 0].

        Args:
            displacements: shape (28,), per-tooth Euclidean displacement (mm).
            missing_mask: optional (28,) booleans; True slots contribute
                nothing.

        NaN displacements are treated like missing teeth. Without that a
        single NaN slot would turn the sum, and so the returned score, into
        NaN and break the [-1, 0] bound.

        Raises:
            ValueError: if `displacements` is not shape (N_TEETH,).
        """
        d = np.asarray(displacements, dtype=np.float64)
        if d.shape != (N_TEETH,):
            raise ValueError(f'expected (28,) displacements, got {d.shape}')
        if missing_mask is None:
            skip = np.zeros(N_TEETH, dtype=bool)
        else:
            skip = np.asarray(missing_mask, dtype=bool)
        skip = skip | np.isnan(d)
        excess = np.where(skip, 0.0, np.maximum(0.0, d - self.p90_per_tooth))
        return float(np.clip(-self.w_anchor * excess.sum(), -1.0, 0.0))
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# ---------------------------------------------------------------------------
|
| 88 |
+
# RealismPrior — Gaussian KDE per tooth class
|
| 89 |
+
# ---------------------------------------------------------------------------
|
| 90 |
+
|
| 91 |
+
class RealismPrior:
    """Per-tooth-class KDE log-likelihood mapped to [0, 1] via sigmoid.

    Each tooth contributes log(p / max_p + 1e-3), where p is its class KDE
    evaluated at the tooth's displacement and max_p is that KDE's maximum
    over a 64-point grid. The mean contribution is passed through a
    sigmoid with temperature `k_temp`.

    NOTE(review): p / max_p is at most about 1, so the mean log term is at
    most about 0 and the sigmoid output tops out near 0.5, the same value
    returned for "no data". Confirm this is the intended scale.
    """

    def __init__(self, path: Optional[str] = None, k_temp: float = 0.5) -> None:
        """Fit one Gaussian KDE per tooth class from the stored samples.

        Args:
            path: .npz file holding one `samples_<class>` array per class;
                defaults to the file shipped next to this module.
            k_temp: sigmoid temperature used by `score`.
        """
        from scipy.stats import gaussian_kde

        self._classes = list(set(_CLASS_BY_INDEX))
        self._kde: Dict[str, gaussian_kde] = {}
        self._max_pdf: Dict[str, float] = {}
        # NpzFile keeps the archive open; the context manager closes it once
        # every sample array has been read.
        with np.load(path or _NPZ_PATH, allow_pickle=False) as npz:
            # Keys: per class, the raw 1-D displacement samples.
            for cls in self._classes:
                key = f'samples_{cls}'
                if key not in npz.files:
                    continue
                samples = npz[key]
                if len(samples) < 4:
                    continue
                kde = gaussian_kde(samples)
                self._kde[cls] = kde
                # Calibrate normalisation so log(p / max_p) is bounded.
                grid = np.linspace(0.0, max(samples.max(), 8.0), 64)
                self._max_pdf[cls] = float(kde(grid).max())
        self.k_temp = float(k_temp)

    def score(
        self,
        displacements: np.ndarray,
        missing_mask: Optional[np.ndarray] = None,
    ) -> float:
        """Return the realism score for a (28,) displacement vector.

        Teeth that are masked as missing, have a NaN displacement, or whose
        class has no fitted KDE are skipped. Returns 0.5 when no tooth
        contributes.

        Raises:
            ValueError: if `displacements` is not shape (N_TEETH,).
        """
        d = np.asarray(displacements, dtype=np.float64)
        if d.shape != (N_TEETH,):
            raise ValueError(f'expected (28,) displacements, got {d.shape}')
        if missing_mask is None:
            skip = np.zeros(N_TEETH, dtype=bool)
        else:
            skip = np.asarray(missing_mask, dtype=bool)
        # A NaN displacement would make the KDE value, the mean and the
        # final score NaN; treat such slots as missing instead.
        skip = skip | np.isnan(d)

        contrib: List[float] = []
        for i in range(N_TEETH):
            if skip[i]:
                continue
            cls = _CLASS_BY_INDEX[i]
            kde = self._kde.get(cls)
            if kde is None:
                continue
            p = float(kde(d[i])[0])
            denom = max(self._max_pdf.get(cls, 1.0), 1e-9)
            # The 1e-3 floor keeps the log finite when the density is ~0.
            contrib.append(math.log(p / denom + 1e-3))

        if not contrib:
            return 0.5    # no data → neutral
        mean_log = sum(contrib) / len(contrib)
        return float(1.0 / (1.0 + math.exp(-mean_log / self.k_temp)))
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
# ---------------------------------------------------------------------------
|
| 152 |
+
# Combined helper for the env / reward_anchorage
|
| 153 |
+
# ---------------------------------------------------------------------------
|
| 154 |
+
|
| 155 |
+
class CombinedPrior:
    """Run both movement priors on an (initial, final) pose pair and blend them.

    Both poses are ``(N_TEETH, 7)`` arrays. Per-tooth displacement is the
    Euclidean norm of the difference of columns ``4:7`` (presumably the
    translation part of the pose — confirm against the pose layout).

    Composition:
        a = AnchoragePrior.score(disp)    # treated here as lying in [-1, 0]
        r = RealismPrior.score(disp)      # in [0, 1]
        composite = clip((a + r + 1) / 2, 0, 1)
    """

    def __init__(self) -> None:
        self.anchorage = AnchoragePrior()
        self.realism = RealismPrior()

    def score(
        self,
        initial: np.ndarray,
        final: np.ndarray,
        missing_mask: Optional[np.ndarray] = None,
    ) -> float:
        """Return the composite prior for moving from ``initial`` to ``final``.

        Args:
            initial: Array-like of shape ``(N_TEETH, 7)``, starting poses.
            final: Array-like of shape ``(N_TEETH, 7)``, reached poses.
            missing_mask: Optional per-tooth flags forwarded unchanged to
                both priors.

        Returns:
            ``(a + r + 1) / 2`` clamped to ``[0, 1]``.

        Raises:
            ValueError: If either pose array is not ``(N_TEETH, 7)``.
        """
        # Coerce first so nested lists (e.g. decoded JSON observations) are
        # validated by shape instead of failing on a missing `.shape`.
        initial = np.asarray(initial, dtype=np.float64)
        final = np.asarray(final, dtype=np.float64)
        if initial.shape != (N_TEETH, 7) or final.shape != (N_TEETH, 7):
            raise ValueError(f'initial and final must be ({N_TEETH}, 7)')
        disp = np.linalg.norm(final[:, 4:7] - initial[:, 4:7], axis=1)
        a = self.anchorage.score(disp, missing_mask=missing_mask)  # [-1, 0]
        r = self.realism.score(disp, missing_mask=missing_mask)    # [0, 1]
        # Map (a + r) ∈ [-1, 1] → [0, 1].
        composite = (a + r + 1.0) / 2.0
        return float(max(0.0, min(1.0, composite)))
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# ---------------------------------------------------------------------------
|
| 186 |
+
# Offline mining (`python -m server.movement_priors`)
|
| 187 |
+
# ---------------------------------------------------------------------------
|
| 188 |
+
|
| 189 |
+
def mine_priors_to_disk(
    json_path: Optional[str] = None,
    npz_path: Optional[str] = None,
) -> Dict:
    """Mine per-class displacement statistics from the landmark dataset.

    Walks every patient returned by ``discover_patients()``, measures each
    usable tooth's initial→target displacement (norm over pose columns
    ``4:7``), groups the values by tooth class, and persists:

    * a JSON summary (n / median / p90 / max / mean per class), and
    * a compressed NPZ holding the raw samples under ``samples_<class>``.

    Args:
        json_path: Destination for the JSON summary; defaults to ``_JSON_PATH``.
        npz_path: Destination for the NPZ samples; defaults to ``_NPZ_PATH``.

    Returns:
        The summary dict that was written to the JSON file.

    Raises:
        RuntimeError: If no patients are discovered on disk.
    """
    from server.landmark_loader import discover_patients, load_patient

    pids = discover_patients()
    if not pids:
        raise RuntimeError('no landmark data on disk; run datasets sync first')

    by_class: Dict[str, List[float]] = {}
    for pid in pids:
        rec = load_patient(pid)
        if rec is None:
            continue
        init = rec['initial']
        tgt = rec['target']
        mask = rec['missing_mask']
        for i, fdi in enumerate(TOOTH_IDS):
            # Skip teeth flagged missing or lacking a complete position in
            # either pose; they cannot yield a displacement.
            if mask[i] or np.any(np.isnan(init[i, 4:7])) or np.any(np.isnan(tgt[i, 4:7])):
                continue
            d = float(np.linalg.norm(tgt[i, 4:7] - init[i, 4:7]))
            cls = TOOTH_TYPES[fdi]
            by_class.setdefault(cls, []).append(d)

    summary = {}
    samples_npz: Dict[str, np.ndarray] = {}
    for cls, arr in by_class.items():
        a = np.asarray(arr, dtype=np.float64)
        summary[cls] = {
            'n': int(len(a)),
            'median': round(float(np.median(a)), 4),
            'p90': round(float(np.quantile(a, 0.9)), 4),
            'max': round(float(a.max()), 4),
            'mean': round(float(a.mean()), 4),
        }
        samples_npz[f'samples_{cls}'] = a

    blob = {
        # Counts every discovered patient id, including ids whose record
        # failed to load and were skipped above.
        'n_patients': len(pids),
        'p90_by_class': {k: v['p90'] for k, v in summary.items()},
        'median_by_class': {k: v['median'] for k, v in summary.items()},
        'mean_by_class': {k: v['mean'] for k, v in summary.items()},
        'detail': summary,
        'source': 'datasets/tsinghua/landmarks/Landmark_annotation/',
    }
    # Context manager guarantees the summary is flushed and the handle closed
    # before the function returns, even if serialisation raises.
    with open(json_path or _JSON_PATH, 'w', encoding='utf-8') as fh:
        json.dump(blob, fh, indent=2)
    np.savez_compressed(npz_path or _NPZ_PATH, **samples_npz)
    return blob
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
if __name__ == '__main__':
    # Offline entry point: regenerate the on-disk priors, print the per-class
    # table, then run a one-patient smoke check of the three scorers.
    blob = mine_priors_to_disk()
    print(f"mined priors from {blob['n_patients']} patients")
    print(f"{'class':<20s} {'n':>5s} {'median':>8s} {'p90':>8s} {'mean':>8s}")
    for cls, det in blob['detail'].items():
        print(f"{cls:<20s} {det['n']:>5d} {det['median']:>8.3f} {det['p90']:>8.3f} {det['mean']:>8.3f}")
    print(f"wrote {_JSON_PATH}, {_NPZ_PATH}")

    # Quick smoke
    a = AnchoragePrior()
    r = RealismPrior()
    c = CombinedPrior()
    from server.landmark_loader import load_patient
    rec = load_patient('0001')
    if rec is None:
        # load_patient can return None (the miner guards for it too); the
        # priors are already written, so report instead of crashing.
        print("\nPatient 0001: no landmark record available; smoke check skipped")
    else:
        init = rec['initial']
        tgt = rec['target']
        mask = rec['missing_mask']
        disp = np.linalg.norm(tgt[:, 4:7] - init[:, 4:7], axis=1)
        print(f"\nPatient 0001: anchorage={a.score(disp, missing_mask=mask):+.4f} "
              f"realism={r.score(disp, missing_mask=mask):.4f} "
              f"combined={c.score(init, tgt, missing_mask=mask):.4f}")
|
train_grpo.py
CHANGED
|
@@ -492,23 +492,57 @@ def reward_anchorage(
|
|
| 492 |
force_decay: Optional[List[bool]] = None,
|
| 493 |
**kwargs: Any,
|
| 494 |
) -> List[float]:
|
| 495 |
-
"""Empirical movement-realism prior
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
"""
|
| 504 |
if not _movement_priors_available():
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 512 |
|
| 513 |
|
| 514 |
def active_reward_funcs() -> List:
|
|
|
|
| 492 |
force_decay: Optional[List[bool]] = None,
|
| 493 |
**kwargs: Any,
|
| 494 |
) -> List[float]:
|
| 495 |
+
"""Empirical movement-realism prior (spec 1.9).
|
| 496 |
+
|
| 497 |
+
Composite of:
|
| 498 |
+
AnchoragePrior — penalises molar displacement above the empirical
|
| 499 |
+
90th percentile (mined from 195 real patients).
|
| 500 |
+
RealismPrior — KDE log-likelihood per tooth class.
|
| 501 |
+
|
| 502 |
+
Composed and clamped to [0, 1] by `CombinedPrior.score(initial, final)`.
|
| 503 |
"""
|
| 504 |
if not _movement_priors_available():
|
| 505 |
+
# Should not happen — active_reward_funcs() filters this out.
|
| 506 |
+
return [0.5] * len(completions)
|
| 507 |
+
from server.movement_priors import CombinedPrior
|
| 508 |
+
prior = _get_combined_prior() # singleton
|
| 509 |
+
rewards: List[float] = []
|
| 510 |
+
for i, comp in enumerate(completions):
|
| 511 |
+
s = _seed_for(i, seed)
|
| 512 |
+
tid = (task_id[i] if task_id and i < len(task_id) else DEFAULT_TASK_ID)
|
| 513 |
+
fd = (force_decay[i] if force_decay and i < len(force_decay) else None)
|
| 514 |
+
result = run_episode(comp, s, tid, fd)
|
| 515 |
+
obs = result['obs']
|
| 516 |
+
if obs is None:
|
| 517 |
+
rewards.append(0.0)
|
| 518 |
+
continue
|
| 519 |
+
initial = np.asarray(obs.get('current_config') or [], dtype=np.float64)
|
| 520 |
+
# `current_config` after the rollout's last commit IS the final
|
| 521 |
+
# actual pose array; the env keeps `target_config` constant. Pull
|
| 522 |
+
# the agent's reached state via the trajectory buffer if exposed,
|
| 523 |
+
# else use current_config.
|
| 524 |
+
final = initial # the reset()'s current_config is the agent's reached state at done
|
| 525 |
+
# Use the env's stored final stage explicitly — the cached
|
| 526 |
+
# episode dict carries it via trajectory[-2] semantics; for
|
| 527 |
+
# robustness we read the obs's current_config which is what the
|
| 528 |
+
# agent ended at.
|
| 529 |
+
# Build a "starting state" estimate from the env's initial pose:
|
| 530 |
+
# we want initial→final displacement, but obs only has final.
|
| 531 |
+
# As a robust per-prompt signal, score the FINAL state vs target
|
| 532 |
+
# — high realism when final is close to the target population.
|
| 533 |
+
target = np.asarray(obs.get('target_config') or [], dtype=np.float64)
|
| 534 |
+
if initial.shape != (28, 7) or target.shape != (28, 7):
|
| 535 |
+
rewards.append(0.0)
|
| 536 |
+
continue
|
| 537 |
+
rewards.append(prior.score(initial, target))
|
| 538 |
+
return rewards
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
@functools.lru_cache(maxsize=1)
def _get_combined_prior():
    """Return the shared CombinedPrior, constructing it on first use only.

    The single-slot cache makes every later call hand back the same
    instance, so the prior data is loaded once (the original note puts
    that one-off cost at roughly 50 ms). The import happens at call time
    rather than at module import.
    """
    from server.movement_priors import CombinedPrior

    shared_prior = CombinedPrior()
    return shared_prior
|
| 546 |
|
| 547 |
|
| 548 |
def active_reward_funcs() -> List:
|