""" Pydantic models for the Dental Aligner Trajectory Planning environment. """ from typing import Dict, List, Optional from pydantic import BaseModel, ConfigDict, Field from openenv.core.env_server.types import Action, Observation, State class ToothTrajectoryStage(BaseModel): """A single stage in the agent's planned trajectory.""" model_config = ConfigDict(extra='forbid', validate_assignment=True) stage_index: int = Field(..., ge=1, le=24, description='Stage 1-24 (0=initial, 25=final are fixed)') poses: List[List[float]] = Field(..., description='28 poses, each [qw,qx,qy,qz,tx,ty,tz]') tooth_ids: List[int] = Field(..., description='28 FDI tooth IDs in same order as poses') class AlignerAction(Action): """Agent action: the full planned trajectory or revised remaining stages.""" model_config = ConfigDict(extra='forbid', validate_assignment=True, arbitrary_types_allowed=True) trajectory: List[ToothTrajectoryStage] = Field( default_factory=list, description='Agent planned stages. Length 24 for full plan, or stages_remaining for revised plan.' ) reasoning: str = Field(default='', description="Agent's planning rationale") confidence: float = Field(default=0.5, ge=0.0, le=1.0) class ToothPoseTableRow(BaseModel): """One row in the tooth pose table — current vs target for one tooth.""" model_config = ConfigDict(extra='forbid', validate_assignment=True) tooth_id: int tooth_type: str current_qw: float current_qx: float current_qy: float current_qz: float current_tx: float current_ty: float current_tz: float target_qw: float target_qx: float target_qy: float target_qz: float target_tx: float target_ty: float target_tz: float remaining_trans_mm: float = Field(..., description='Euclidean distance to target in mm') remaining_rot_deg: float = Field(..., description='Angular distance to target in degrees') class AlignerObservation(Observation): """Observation returned to agent after reset() or step().""" model_config = ConfigDict(extra='forbid', validate_assignment=True, arbitrary_types_allowed=True) # done: bool and reward: float|None are INHERITED from Observation current_stage: int = Field(default=0, description='Current stage index (0=initial)') stages_remaining: int = Field(default=24, description='Number of stages the agent must still plan') task_id: str = Field(default='') task_description: str = Field(default='') tooth_table: List[ToothPoseTableRow] = Field(default_factory=list) tooth_table_text: str = Field(default='', description='Markdown table for text-based agents') arch_graph_json: str = Field(default='', description='JSON adjacency list for GNN agents') baseline_trajectory_json: str = Field(default='', description='Deprecated. SLERP baseline removed from observation per spec 1.4 (no answer leakage to agent).') adversarial_jitter_applied: bool = Field(default=False) jitter_description: str = Field(default='') last_plan_feedback: str = Field(default='') step_number: int = Field(default=0) class AlignerState(State): """Environment state (episode-level tracking).""" # episode_id and step_count are INHERITED; State has extra='allow' task_id: str = Field(default='') difficulty: str = Field(default='easy') current_stage: int = Field(default=0) seed: int = Field(default=0) total_violations: int = Field(default=0) adversarial_perturbations: int = Field(default=0) best_trajectory_score: float = Field(default=0.0) coordinate_frame: str = Field(default='dental_v1') missing_mask: List[bool] = Field(default_factory=lambda: [False] * 28) oracle_trajectory: Optional[List[ToothTrajectoryStage]] = Field( default=None, description='Spec 1.8: clinical-rule oracle staging of (init, target). ' 'Length 24 (intermediate stages only). Read-only. Populated ' 'only when /state is queried with ?include_oracle=true.', ) b2b_head_probs: Optional[Dict[str, float]] = Field( default=None, description='Spec 3.8: per-head probabilities from the learned-occlusion ' 'classifier (Bits2Bites heads). Keys: class_i_healthy, ' 'overbite_normal, overjet_normal, crowding_minimal. Each in [0,1]. ' 'Populated by the most recent stepwise step on the active session.', ) eval_mode: bool = Field( default=False, description='Spec 1.11: True when the active episode was reset with ' '``mode="eval"``; downstream graders use this to record ' 'which split produced a reward.', ) eval_tier: Optional[int] = Field( default=None, description='Spec 1.11: 1 / 2 / 3 when the episode is on a held-out tier; ' 'None for train-mode episodes.', ) prompt_includes_strategy: bool = Field( default=False, description='Spec 1.14: True when the active episode was reset with ' '``include_strategy_in_prompt=True``; lets the BC-SFT data ' 'builder assert the correct prompt template per checkpoint.', ) # --------------------------------------------------------------------------- # Spec 2.8 — ReasoningTrace schema # --------------------------------------------------------------------------- class ToolCallTrace(BaseModel): """A single tool invocation captured during a rollout (spec 2.8).""" model_config = ConfigDict(extra='forbid') name: str arguments: Dict = Field(default_factory=dict) result: Dict = Field(default_factory=dict) duration_ms: int = Field(default=0) class DiagnosisStep(BaseModel): """The pre-rollout diagnosis emitted by the agent (spec 2.8 + 2.1 tools).""" model_config = ConfigDict(extra='forbid') tool_calls: List[ToolCallTrace] = Field(default_factory=list) diagnosis_text: str = Field(default='') ground_truth_class: Optional[str] = Field(default=None) match: bool = Field(default=False) class StrategyStep(BaseModel): """The strategy chosen by the agent before staging (spec 2.8 + 1.2).""" model_config = ConfigDict(extra='forbid') chosen: str = Field(default='') rationale: str = Field(default='') gold_strategy: Optional[str] = Field(default=None) multiplier: float = Field(default=1.0) class StageTrace(BaseModel): """Per-stage record (1..24) including poses, tool calls, and reward break-down.""" model_config = ConfigDict(extra='forbid') stage_index: int tool_calls: List[ToolCallTrace] = Field(default_factory=list) reasoning: str = Field(default='') action_poses: List[List[float]] = Field(default_factory=list) reward_breakdown: Dict = Field(default_factory=dict) collisions: int = Field(default=0) pdl_violations: int = Field(default=0) class ReasoningTrace(BaseModel): """Full rollout trace for spec 2.8's Agent Console. Produced by ``scripts/precompute_traces.py`` (cached path) or live by ``POST /demo/trace`` once an inference backend is wired. """ model_config = ConfigDict(extra='forbid') schema_version: int = Field(default=1) case_id: str case_label: str = Field(default='') model_name: str = Field(default='unknown') diagnosis: Optional[DiagnosisStep] = Field(default=None) strategy: Optional[StrategyStep] = Field(default=None) stages: List[StageTrace] = Field(default_factory=list) final_reward: Optional[float] = Field(default=None) reward_breakdown_terminal: Dict = Field(default_factory=dict) # --------------------------------------------------------------------------- # Stepwise (24-step sequential) models # --------------------------------------------------------------------------- class StepwiseAction(BaseModel): """Agent action for one stage in stepwise mode: 28 tooth poses.""" model_config = ConfigDict(extra='forbid', validate_assignment=True) poses: List[List[float]] = Field( ..., description='28 poses for the next stage, each [qw,qx,qy,qz,tx,ty,tz]' ) memo: str = Field( default='', description='Spec 3.7: optional free-form treatment memo. Graded only ' 'when committed at stage 12 (midpoint); ignored otherwise. ' 'A non-empty graded memo contributes a bonus reward in [0, 0.1] ' 'via reward_breakdown["reward_memo"].', ) class StepwiseObservation(BaseModel): """Observation returned after each step in stepwise mode.""" model_config = ConfigDict(extra='forbid', validate_assignment=True) task_id: str = Field(default='') current_stage: int = Field(default=0, description='Current stage (0-24)') stages_remaining: int = Field(default=24) done: bool = Field(default=False) # Pose data current_config: List[List[float]] = Field( default_factory=list, description='Current 28x7 tooth poses after last committed stage' ) target_config: List[List[float]] = Field( default_factory=list, description='Target 28x7 tooth poses (goal)' ) # Progress tracking per_tooth_progress: List[float] = Field( default_factory=list, description='Per-tooth % of distance covered toward target (0.0-1.0)' ) cumulative_violations: int = Field(default=0, description='Total constraint violations so far') # Rewards step_reward: Optional[float] = Field(default=None, description='Dense reward for last step') terminal_reward: Optional[float] = Field(default=None, description='Final graded score (only when done)') reward_breakdown: Optional[dict] = Field(default=None, description='Per-component reward details') # Context stage_history_summary: str = Field( default='', description='Text summary of movement so far' ) tooth_table_text: str = Field(default='', description='Markdown table of current vs target poses') data_source: str = Field(default='synthetic', description='Dataset source used for this episode') clinical_profile: Optional[dict] = Field( default=None, description='Tsinghua clinical profile (spec 1.2): malocclusion, crowding, ' 'overbite, overjet, difficulty_level, patient_id.', ) case_source: str = Field( default='synthetic', description='Spec 1.7: dataset source for this episode — ' '``synthetic``, ``real_tsinghua``, ``ofj``, or ``bits2bites``.', ) patient_id: Optional[str] = Field( default=None, description='Spec 1.7: patient identifier when the case comes from real ' 'clinical data; None for synthetic cases.', ) coordinate_frame: str = Field( default='dental_v1', description='Spec 1.6: canonical dental frame ID. Always set; ' 'guarantees cross-dataset comparability.', ) missing_mask: List[bool] = Field( default_factory=lambda: [False] * 28, description='Spec 1.6: True for teeth absent in this case ' '(extractions, missing landmarks). Length 28, FDI order. ' 'Reward functions ignore True slots.', ) class ToolCall(BaseModel): """Agent tool call in stepwise mode.""" model_config = ConfigDict(extra='forbid', validate_assignment=True) tool: str = Field( ..., description='Tool name: inspect_tooth, simulate_step, check_collisions, commit_stage, rollback_stage' ) args: dict = Field(default_factory=dict, description='Tool arguments') class ToolResult(BaseModel): """Result of a tool call.""" model_config = ConfigDict(extra='forbid', validate_assignment=True) tool: str success: bool = True result: dict = Field(default_factory=dict) error: Optional[str] = None