genga-kimodo / constraints_schema.py
WalkingOnSaturn's picture
initial: Gradio API server (Kimodo-SOMA-RP-v1.1) + constraints schema
c781b57 verified
Raw
History Blame
3.93 kB
"""Pydantic schema for the Kimodo constraint payload accepted by server.py.
Mirrors the JSON shape produced by the NVIDIA authoring demo so we can fall
back on the official kimodo.constraints classes for inference. The webapp
sends a JSON-stringified list of these objects in the Gradio `constraints_json`
arg.
Coordinates: Y-up, meters, character-local. Frame indices are 0-based within
the generated clip. The root is canonicalized to the XZ origin at frame 0.
"""
from __future__ import annotations
from typing import List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, Field, field_validator
# Fixed-length float tuples used for the coordinate fields of the constraint
# models below. Which axes each component maps to is not established in this
# file beyond the module docstring's "Y-up, meters" note.
Vec2 = Tuple[float, float]  # two floats; presumably a ground-plane (XZ) pair — TODO confirm
Vec3 = Tuple[float, float, float]  # three floats
class Root2DConstraint(BaseModel):
    """Constraint payload tagged ``type == "root2d"``.

    Carries the constrained frame indices, a list of 2D root samples and an
    optional list of 2D heading vectors. The lists are presumably aligned
    one-to-one with ``frame_indices``; that alignment is NOT enforced here.
    """

    type: Literal["root2d"]
    frame_indices: List[int]
    smooth_root_2d: List[Vec2]
    global_root_heading: Optional[List[Vec2]] = None

    @field_validator("frame_indices")
    @classmethod
    def _require_frame_indices(cls, frames, _info):
        """Reject an empty ``frame_indices`` list.

        Only local sanity checks are possible inside the model; range checks
        against ``num_frames`` are left to the caller that knows the clip
        length.
        """
        if len(frames) > 0:
            return frames
        raise ValueError("root2d constraint must have at least one frame_index")
class FullBodyConstraint(BaseModel):
    """Constraint payload tagged ``type == "fullbody"``.

    No validators are attached: beyond Pydantic's type coercion, list lengths
    and contents are not checked by this model.
    """

    type: Literal["fullbody"]
    # Frames the constraint applies to (0-based per the module docstring).
    frame_indices: List[int]
    # One 3-float tuple per entry; presumably one per frame index — TODO confirm.
    root_positions: List[Vec3]
    # Nested list of 3-float tuples; presumably indexed [frame][joint]. The
    # rotation parameterization is not visible in this file — verify upstream.
    local_joints_rot: List[List[Vec3]]
    # Optional 2D root samples; omitted/None when not supplied.
    smooth_root_2d: Optional[List[Vec2]] = None
class EndEffectorConstraint(BaseModel):
    """Constraint payload for a hand/foot or a custom set of end-effector joints.

    ``type`` selects one of the four typed variants (``left-hand``,
    ``right-hand``, ``left-foot``, ``right-foot``) or the generic
    ``end-effector`` variant. The generic variant must name its joints via
    ``joint_names``; the typed variants may leave it unset.

    Raises:
        pydantic.ValidationError: if a field has the wrong type, or if
            ``type == "end-effector"`` and ``joint_names`` is missing or empty.
    """

    type: Literal[
        "left-hand",
        "right-hand",
        "left-foot",
        "right-foot",
        "end-effector",
    ]
    frame_indices: List[int]
    root_positions: List[Vec3]
    local_joints_rot: List[List[Vec3]]
    smooth_root_2d: Optional[List[Vec2]] = None
    # Required when type == "end-effector".
    # `validate_default=True` is essential: Pydantic skips field validators
    # for values that come from the default, so without it a payload that
    # simply omits `joint_names` would bypass the check below entirely.
    joint_names: Optional[List[str]] = Field(default=None, validate_default=True)

    @field_validator("joint_names")
    @classmethod
    def _names_required_for_custom(cls, v, info):
        """Require a non-empty ``joint_names`` for the generic variant."""
        # `type` is declared before `joint_names`, so it is already present in
        # info.data when it validated successfully; `.get` covers the case
        # where `type` itself was invalid and is therefore absent.
        if info.data.get("type") == "end-effector" and not v:
            raise ValueError(
                "type='end-effector' requires `joint_names`; use a typed variant "
                "(left-hand, right-hand, left-foot, right-foot) otherwise."
            )
        return v
# Any one of the constraint models defined above; the element type of the
# list returned by parse_constraints().
KimodoConstraint = Union[Root2DConstraint, FullBodyConstraint, EndEffectorConstraint]
def _constraint_model_for(type_name):
    """Return the model class registered for *type_name*, or None if unknown.

    Non-string values are treated as unknown up front: JSON allows any value
    under ``"type"``, and an unhashable one (list/dict) would otherwise make
    the table lookup raise TypeError instead of a reportable ValueError.
    """
    if not isinstance(type_name, str):
        return None
    return {
        "root2d": Root2DConstraint,
        "fullbody": FullBodyConstraint,
        "left-hand": EndEffectorConstraint,
        "right-hand": EndEffectorConstraint,
        "left-foot": EndEffectorConstraint,
        "right-foot": EndEffectorConstraint,
        "end-effector": EndEffectorConstraint,
    }.get(type_name)


def parse_constraints(payload: List[dict], num_frames: int) -> List[KimodoConstraint]:
    """Validate the JSON payload from the webapp and bound-check frame indices.

    Args:
        payload: Decoded JSON value; must be a list of dicts, each carrying a
            ``"type"`` key that names one of the supported constraint kinds.
        num_frames: Length of the generated clip; every frame index must lie
            in ``[0, num_frames - 1]``.

    Returns:
        A list of typed Pydantic objects, in payload order.

    Raises:
        ValueError: on the first violation (wrong container type, missing or
            unknown ``type``, model validation failure, or out-of-range frame
            index). The message is prefixed with the offending
            ``constraints[i]`` so the caller can surface it as-is.
    """
    if not isinstance(payload, list):
        raise ValueError("constraints must be a JSON list")

    out: List[KimodoConstraint] = []
    for i, raw in enumerate(payload):
        if not isinstance(raw, dict) or "type" not in raw:
            raise ValueError(f"constraints[{i}]: must be a dict with a 'type' field")

        t = raw["type"]
        cls = _constraint_model_for(t)
        if cls is None:
            raise ValueError(f"constraints[{i}]: unknown type '{t}'")

        try:
            obj = cls(**raw)
        except Exception as e:
            # Deliberately broad: this is the boundary where any model
            # construction failure is normalised to ValueError; the original
            # exception is kept as the cause.
            raise ValueError(f"constraints[{i}] ({t}): {e}") from e

        for f in obj.frame_indices:
            if f < 0 or f >= num_frames:
                raise ValueError(
                    f"constraints[{i}] ({t}): frame_index {f} is out of range "
                    f"[0, {num_frames - 1}]"
                )
        out.append(obj)
    return out