Spaces:
Runtime error
Runtime error
Add kimodo_motion_seq endpoint for multi-prompt chained sequences
Browse files
server.py
CHANGED
|
@@ -205,6 +205,113 @@ def kimodo_motion(
|
|
| 205 |
return {"status": "error", "error": f"{type(e).__name__}: {e}"}
|
| 206 |
|
| 207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
def _historical_extract_soma_skin(progress: gr.Progress = gr.Progress()) -> dict: # noqa: B008
|
| 209 |
"""One-shot dump of kimodo's somaskel77/skin_standard.npz to base64 so the
|
| 210 |
webapp can ship a real SkinnedMesh. Already run; binaries live at
|
|
@@ -445,6 +552,20 @@ with gr.Blocks(title="Genga Kimodo") as demo:
|
|
| 445 |
api_name="kimodo_motion",
|
| 446 |
)
|
| 447 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 448 |
|
| 449 |
|
| 450 |
if __name__ == "__main__":
|
|
|
|
| 205 |
return {"status": "error", "error": f"{type(e).__name__}: {e}"}
|
| 206 |
|
| 207 |
|
| 208 |
+
def kimodo_motion_seq(
|
| 209 |
+
prompts_json: str,
|
| 210 |
+
frames_json: str,
|
| 211 |
+
seed: int,
|
| 212 |
+
cfg: float,
|
| 213 |
+
num_steps: int,
|
| 214 |
+
constraints_json: str,
|
| 215 |
+
transition_frames: int = 20,
|
| 216 |
+
progress: gr.Progress = gr.Progress(), # noqa: B008
|
| 217 |
+
) -> dict:
|
| 218 |
+
"""Multi-prompt sequence variant of kimodo_motion. Generates a single
|
| 219 |
+
motion that transitions through each prompt segment in order.
|
| 220 |
+
|
| 221 |
+
prompts_json: JSON list of strings, e.g. '["walk forward", "wave hello"]'
|
| 222 |
+
frames_json: JSON list of ints (per-segment frame counts), same length.
|
| 223 |
+
transition_frames: how many frames the model uses to blend between segments.
|
| 224 |
+
|
| 225 |
+
Returns the same envelope as kimodo_motion. The total numFrames is
|
| 226 |
+
sum(frames). If a single segment is provided this is equivalent to
|
| 227 |
+
kimodo_motion.
|
| 228 |
+
"""
|
| 229 |
+
try:
|
| 230 |
+
prompts = json.loads(prompts_json) if prompts_json else []
|
| 231 |
+
if not isinstance(prompts, list) or not all(isinstance(p, str) and p.strip() for p in prompts):
|
| 232 |
+
return {"status": "error", "error": "prompts_json must be a JSON list of non-empty strings"}
|
| 233 |
+
frames = json.loads(frames_json) if frames_json else []
|
| 234 |
+
if not isinstance(frames, list) or len(frames) != len(prompts) or not all(isinstance(n, int) and 1 <= n <= 300 for n in frames):
|
| 235 |
+
return {"status": "error", "error": "frames_json must be a JSON list of ints (1..300) matching prompts length"}
|
| 236 |
+
total_n = sum(frames)
|
| 237 |
+
if total_n > 600:
|
| 238 |
+
return {"status": "error", "error": f"total frames {total_n} exceeds 600 cap"}
|
| 239 |
+
|
| 240 |
+
try:
|
| 241 |
+
raw = json.loads(constraints_json) if constraints_json else []
|
| 242 |
+
parse_constraints(raw, total_n)
|
| 243 |
+
except (ValueError, json.JSONDecodeError) as e:
|
| 244 |
+
return {"status": "error", "error": f"constraint validation: {e}"}
|
| 245 |
+
|
| 246 |
+
progress(0.02, desc="Loading model...")
|
| 247 |
+
model, skeleton, device = _load_model()
|
| 248 |
+
|
| 249 |
+
from kimodo.constraints import load_constraints_lst
|
| 250 |
+
constraint_lst = load_constraints_lst(raw, skeleton, device=device)
|
| 251 |
+
|
| 252 |
+
if seed is not None and int(seed) >= 0:
|
| 253 |
+
from kimodo.tools import seed_everything
|
| 254 |
+
seed_everything(int(seed))
|
| 255 |
+
|
| 256 |
+
progress(0.10, desc=f"Diffusion ({len(prompts)} segments × {int(num_steps)} steps)...")
|
| 257 |
+
cfg_kwargs = {"cfg_type": "regular", "cfg_weight": float(cfg)}
|
| 258 |
+
output = model(
|
| 259 |
+
[p.strip() for p in prompts],
|
| 260 |
+
list(frames),
|
| 261 |
+
constraint_lst=constraint_lst,
|
| 262 |
+
num_denoising_steps=int(num_steps),
|
| 263 |
+
num_samples=1,
|
| 264 |
+
multi_prompt=True,
|
| 265 |
+
num_transition_frames=int(transition_frames),
|
| 266 |
+
return_numpy=True,
|
| 267 |
+
**cfg_kwargs,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
progress(0.92, desc="Serializing...")
|
| 271 |
+
if "posed_joints" not in output or "global_rot_mats" not in output:
|
| 272 |
+
return {"status": "error", "error": f"unexpected model output keys: {list(output.keys())}"}
|
| 273 |
+
|
| 274 |
+
posed_joints = output["posed_joints"]
|
| 275 |
+
global_rot_mats = output["global_rot_mats"]
|
| 276 |
+
joints_pos_t = torch.from_numpy(posed_joints[0]).to(device)
|
| 277 |
+
if "local_rot_mats" in output:
|
| 278 |
+
local_rot_mats_77 = torch.from_numpy(output["local_rot_mats"][0]).to(device)
|
| 279 |
+
else:
|
| 280 |
+
from kimodo.skeleton import global_rots_to_local_rots
|
| 281 |
+
joints_rot_t = torch.from_numpy(global_rot_mats[0]).to(device)
|
| 282 |
+
local_rot_mats_77 = global_rots_to_local_rots(joints_rot_t, skeleton.somaskel77)
|
| 283 |
+
local_rot_mats_30 = skeleton.from_SOMASkeleton77(local_rot_mats_77)
|
| 284 |
+
if local_rot_mats_30.ndim == 5 and local_rot_mats_30.shape[0] == 1:
|
| 285 |
+
local_rot_mats_30 = local_rot_mats_30[0]
|
| 286 |
+
local_rot_mats = local_rot_mats_30.detach().cpu().numpy().astype(np.float32)
|
| 287 |
+
root_translation = joints_pos_t[:, 0, :].detach().cpu().numpy().astype(np.float32)
|
| 288 |
+
|
| 289 |
+
T, J = local_rot_mats.shape[0], local_rot_mats.shape[1]
|
| 290 |
+
# Note: the model may return slightly more or fewer frames than total_n
|
| 291 |
+
# depending on transition handling; report whatever it gave us.
|
| 292 |
+
foot_contacts_out = None
|
| 293 |
+
if "foot_contacts" in output:
|
| 294 |
+
fc = output["foot_contacts"]
|
| 295 |
+
if fc.ndim == 3:
|
| 296 |
+
fc = fc[0]
|
| 297 |
+
foot_contacts_out = np.asarray(fc, dtype=np.float32).tolist()
|
| 298 |
+
|
| 299 |
+
progress(1.0, desc="Done")
|
| 300 |
+
return {
|
| 301 |
+
"status": "ok",
|
| 302 |
+
"numFrames": int(T),
|
| 303 |
+
"fps": int(getattr(model, "fps", 30)),
|
| 304 |
+
"rootTranslation": root_translation.tolist(),
|
| 305 |
+
"jointRotMats": local_rot_mats.tolist(),
|
| 306 |
+
"footContacts": foot_contacts_out,
|
| 307 |
+
"summary": " → ".join(p.strip() for p in prompts),
|
| 308 |
+
"segments": [{"prompt": p.strip(), "frames": int(n)} for p, n in zip(prompts, frames)],
|
| 309 |
+
}
|
| 310 |
+
except Exception as e:
|
| 311 |
+
traceback.print_exc()
|
| 312 |
+
return {"status": "error", "error": f"{type(e).__name__}: {e}"}
|
| 313 |
+
|
| 314 |
+
|
| 315 |
def _historical_extract_soma_skin(progress: gr.Progress = gr.Progress()) -> dict: # noqa: B008
|
| 316 |
"""One-shot dump of kimodo's somaskel77/skin_standard.npz to base64 so the
|
| 317 |
webapp can ship a real SkinnedMesh. Already run; binaries live at
|
|
|
|
| 552 |
api_name="kimodo_motion",
|
| 553 |
)
|
| 554 |
|
| 555 |
+
# Multi-prompt sequence endpoint — header-only inputs (no UI form widgets;
|
| 556 |
+
# the webapp posts JSON directly to /gradio_api/call/kimodo_motion_seq).
|
| 557 |
+
in_prompts_json = gr.Textbox(label="prompts_json", value='["A person walks forward","A person waves hello"]', visible=False)
|
| 558 |
+
in_frames_json = gr.Textbox(label="frames_json", value="[45,45]", visible=False)
|
| 559 |
+
in_transition = gr.Number(value=20, label="transition_frames", precision=0, visible=False)
|
| 560 |
+
out_seq = gr.JSON(label="seq result", visible=False)
|
| 561 |
+
seq_btn = gr.Button("Generate sequence", visible=False)
|
| 562 |
+
seq_btn.click(
|
| 563 |
+
fn=kimodo_motion_seq,
|
| 564 |
+
inputs=[in_prompts_json, in_frames_json, in_seed, in_cfg, in_steps, in_constraints, in_transition],
|
| 565 |
+
outputs=out_seq,
|
| 566 |
+
api_name="kimodo_motion_seq",
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
|
| 570 |
|
| 571 |
if __name__ == "__main__":
|