WalkingOnSaturn commited on
Commit
9f127fb
·
verified ·
1 Parent(s): 9c9fbf7

Add kimodo_motion_seq endpoint for multi-prompt chained sequences

Browse files
Files changed (1) hide show
  1. server.py +121 -0
server.py CHANGED
@@ -205,6 +205,113 @@ def kimodo_motion(
205
  return {"status": "error", "error": f"{type(e).__name__}: {e}"}
206
 
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  def _historical_extract_soma_skin(progress: gr.Progress = gr.Progress()) -> dict: # noqa: B008
209
  """One-shot dump of kimodo's somaskel77/skin_standard.npz to base64 so the
210
  webapp can ship a real SkinnedMesh. Already run; binaries live at
@@ -445,6 +552,20 @@ with gr.Blocks(title="Genga Kimodo") as demo:
445
  api_name="kimodo_motion",
446
  )
447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
 
449
 
450
  if __name__ == "__main__":
 
205
  return {"status": "error", "error": f"{type(e).__name__}: {e}"}
206
 
207
 
208
+ def kimodo_motion_seq(
209
+ prompts_json: str,
210
+ frames_json: str,
211
+ seed: int,
212
+ cfg: float,
213
+ num_steps: int,
214
+ constraints_json: str,
215
+ transition_frames: int = 20,
216
+ progress: gr.Progress = gr.Progress(), # noqa: B008
217
+ ) -> dict:
218
+ """Multi-prompt sequence variant of kimodo_motion. Generates a single
219
+ motion that transitions through each prompt segment in order.
220
+
221
+ prompts_json: JSON list of strings, e.g. '["walk forward", "wave hello"]'
222
+ frames_json: JSON list of ints (per-segment frame counts), same length.
223
+ transition_frames: how many frames the model uses to blend between segments.
224
+
225
+ Returns the same envelope as kimodo_motion. The total numFrames is
226
+ sum(frames). If a single segment is provided this is equivalent to
227
+ kimodo_motion.
228
+ """
229
+ try:
230
+ prompts = json.loads(prompts_json) if prompts_json else []
231
+ if not isinstance(prompts, list) or not all(isinstance(p, str) and p.strip() for p in prompts):
232
+ return {"status": "error", "error": "prompts_json must be a JSON list of non-empty strings"}
233
+ frames = json.loads(frames_json) if frames_json else []
234
+ if not isinstance(frames, list) or len(frames) != len(prompts) or not all(isinstance(n, int) and 1 <= n <= 300 for n in frames):
235
+ return {"status": "error", "error": "frames_json must be a JSON list of ints (1..300) matching prompts length"}
236
+ total_n = sum(frames)
237
+ if total_n > 600:
238
+ return {"status": "error", "error": f"total frames {total_n} exceeds 600 cap"}
239
+
240
+ try:
241
+ raw = json.loads(constraints_json) if constraints_json else []
242
+ parse_constraints(raw, total_n)
243
+ except (ValueError, json.JSONDecodeError) as e:
244
+ return {"status": "error", "error": f"constraint validation: {e}"}
245
+
246
+ progress(0.02, desc="Loading model...")
247
+ model, skeleton, device = _load_model()
248
+
249
+ from kimodo.constraints import load_constraints_lst
250
+ constraint_lst = load_constraints_lst(raw, skeleton, device=device)
251
+
252
+ if seed is not None and int(seed) >= 0:
253
+ from kimodo.tools import seed_everything
254
+ seed_everything(int(seed))
255
+
256
+ progress(0.10, desc=f"Diffusion ({len(prompts)} segments × {int(num_steps)} steps)...")
257
+ cfg_kwargs = {"cfg_type": "regular", "cfg_weight": float(cfg)}
258
+ output = model(
259
+ [p.strip() for p in prompts],
260
+ list(frames),
261
+ constraint_lst=constraint_lst,
262
+ num_denoising_steps=int(num_steps),
263
+ num_samples=1,
264
+ multi_prompt=True,
265
+ num_transition_frames=int(transition_frames),
266
+ return_numpy=True,
267
+ **cfg_kwargs,
268
+ )
269
+
270
+ progress(0.92, desc="Serializing...")
271
+ if "posed_joints" not in output or "global_rot_mats" not in output:
272
+ return {"status": "error", "error": f"unexpected model output keys: {list(output.keys())}"}
273
+
274
+ posed_joints = output["posed_joints"]
275
+ global_rot_mats = output["global_rot_mats"]
276
+ joints_pos_t = torch.from_numpy(posed_joints[0]).to(device)
277
+ if "local_rot_mats" in output:
278
+ local_rot_mats_77 = torch.from_numpy(output["local_rot_mats"][0]).to(device)
279
+ else:
280
+ from kimodo.skeleton import global_rots_to_local_rots
281
+ joints_rot_t = torch.from_numpy(global_rot_mats[0]).to(device)
282
+ local_rot_mats_77 = global_rots_to_local_rots(joints_rot_t, skeleton.somaskel77)
283
+ local_rot_mats_30 = skeleton.from_SOMASkeleton77(local_rot_mats_77)
284
+ if local_rot_mats_30.ndim == 5 and local_rot_mats_30.shape[0] == 1:
285
+ local_rot_mats_30 = local_rot_mats_30[0]
286
+ local_rot_mats = local_rot_mats_30.detach().cpu().numpy().astype(np.float32)
287
+ root_translation = joints_pos_t[:, 0, :].detach().cpu().numpy().astype(np.float32)
288
+
289
+ T, J = local_rot_mats.shape[0], local_rot_mats.shape[1]
290
+ # Note: the model may return slightly more or fewer frames than total_n
291
+ # depending on transition handling; report whatever it gave us.
292
+ foot_contacts_out = None
293
+ if "foot_contacts" in output:
294
+ fc = output["foot_contacts"]
295
+ if fc.ndim == 3:
296
+ fc = fc[0]
297
+ foot_contacts_out = np.asarray(fc, dtype=np.float32).tolist()
298
+
299
+ progress(1.0, desc="Done")
300
+ return {
301
+ "status": "ok",
302
+ "numFrames": int(T),
303
+ "fps": int(getattr(model, "fps", 30)),
304
+ "rootTranslation": root_translation.tolist(),
305
+ "jointRotMats": local_rot_mats.tolist(),
306
+ "footContacts": foot_contacts_out,
307
+ "summary": " → ".join(p.strip() for p in prompts),
308
+ "segments": [{"prompt": p.strip(), "frames": int(n)} for p, n in zip(prompts, frames)],
309
+ }
310
+ except Exception as e:
311
+ traceback.print_exc()
312
+ return {"status": "error", "error": f"{type(e).__name__}: {e}"}
313
+
314
+
315
  def _historical_extract_soma_skin(progress: gr.Progress = gr.Progress()) -> dict: # noqa: B008
316
  """One-shot dump of kimodo's somaskel77/skin_standard.npz to base64 so the
317
  webapp can ship a real SkinnedMesh. Already run; binaries live at
 
552
  api_name="kimodo_motion",
553
  )
554
 
555
+ # Multi-prompt sequence endpoint — header-only inputs (no UI form widgets;
556
+ # the webapp posts JSON directly to /gradio_api/call/kimodo_motion_seq).
557
+ in_prompts_json = gr.Textbox(label="prompts_json", value='["A person walks forward","A person waves hello"]', visible=False)
558
+ in_frames_json = gr.Textbox(label="frames_json", value="[45,45]", visible=False)
559
+ in_transition = gr.Number(value=20, label="transition_frames", precision=0, visible=False)
560
+ out_seq = gr.JSON(label="seq result", visible=False)
561
+ seq_btn = gr.Button("Generate sequence", visible=False)
562
+ seq_btn.click(
563
+ fn=kimodo_motion_seq,
564
+ inputs=[in_prompts_json, in_frames_json, in_seed, in_cfg, in_steps, in_constraints, in_transition],
565
+ outputs=out_seq,
566
+ api_name="kimodo_motion_seq",
567
+ )
568
+
569
 
570
 
571
  if __name__ == "__main__":