mamungtai-sat pormungtai commited on
Commit
0e89ded
·
1 Parent(s): c86ac4c

Add Pose lock mode (ControlNet OpenPose) for SD1.5 (#13)

Browse files

- Add Pose lock mode (ControlNet OpenPose) for SD1.5 (337b47f6340c5259fe86f963b35faf621d30f193)


Co-authored-by: pormungtailaw <pormungtai@users.noreply.huggingface.co>

Files changed (2) hide show
  1. pipeline_manager.py +42 -1
  2. requirements.txt +2 -0
pipeline_manager.py CHANGED
@@ -38,7 +38,7 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
38
 
39
  # Modes supported per base family. Used by the UI to gate options.
40
  SUPPORTED_MODES = {
41
- "sd15": ["txt2img", "img2img", "ip_adapter", "face_id"],
42
  "sdxl": ["txt2img", "img2img", "ip_adapter", "face_id"],
43
  "flux": ["txt2img", "img2img"],
44
  }
@@ -48,6 +48,7 @@ MODE_LABELS = {
48
  "img2img": "Image → Image (denoise)",
49
  "ip_adapter": "IP-Adapter (style / subject)",
50
  "face_id": "Face identity (FaceID)",
 
51
  }
52
 
53
  # ---------------------------------------------------------------------------
@@ -328,6 +329,34 @@ def _face_embeds(image):
328
  # ---------------------------------------------------------------------------
329
  # Generation
330
  # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
  def _safe_call(pipe_obj, call):
332
  """Run the pipeline; if clip_skip trips a version incompatibility, retry without it."""
333
  try:
@@ -411,4 +440,16 @@ def run_generation(cfg, mode, prompt, negative_prompt, ref_image,
411
  embeds = _face_embeds(ref_image).to(DEVICE)
412
  call["ip_adapter_image_embeds"] = [embeds]
413
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  return _safe_call(pipe, call)
 
38
 
39
  # Modes supported per base family. Used by the UI to gate options.
40
  SUPPORTED_MODES = {
41
+ "sd15": ["txt2img", "img2img", "ip_adapter", "face_id", "pose"],
42
  "sdxl": ["txt2img", "img2img", "ip_adapter", "face_id"],
43
  "flux": ["txt2img", "img2img"],
44
  }
 
48
  "img2img": "Image → Image (denoise)",
49
  "ip_adapter": "IP-Adapter (style / subject)",
50
  "face_id": "Face identity (FaceID)",
51
+ "pose": "Pose lock (ControlNet OpenPose)",
52
  }
53
 
54
  # ---------------------------------------------------------------------------
 
329
  # ---------------------------------------------------------------------------
330
  # Generation
331
  # ---------------------------------------------------------------------------
332
+ # ---------------------------------------------------------------------------
333
+ # ControlNet (OpenPose) — locks the generated subject to an uploaded pose.
334
+ # ---------------------------------------------------------------------------
335
+ _CONTROLNET = {}
336
+ _OPENPOSE = None
337
+
338
+
339
+ def _get_controlnet(base):
340
+ if base in _CONTROLNET:
341
+ return _CONTROLNET[base]
342
+ from diffusers import ControlNetModel
343
+ if base == "sd15":
344
+ repo = "lllyasviel/control_v11p_sd15_openpose"
345
+ else:
346
+ raise ValueError("Pose (ControlNet) รองรับเฉพาะ SD1.5 ตอนนี้ / SD1.5 only for now.")
347
+ cn = ControlNetModel.from_pretrained(repo, torch_dtype=DTYPE_SD)
348
+ _CONTROLNET[base] = cn
349
+ return cn
350
+
351
+
352
+ def _get_openpose():
353
+ global _OPENPOSE
354
+ if _OPENPOSE is None:
355
+ from controlnet_aux import OpenposeDetector
356
+ _OPENPOSE = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
357
+ return _OPENPOSE
358
+
359
+
360
  def _safe_call(pipe_obj, call):
361
  """Run the pipeline; if clip_skip trips a version incompatibility, retry without it."""
362
  try:
 
440
  embeds = _face_embeds(ref_image).to(DEVICE)
441
  call["ip_adapter_image_embeds"] = [embeds]
442
 
443
+ elif mode == "pose":
444
+ if ref_image is None:
445
+ raise ValueError("Pose ต้องอัปโหลดรูปท่าทางก่อน / Upload a pose reference image first.")
446
+ _ensure_adapter(pipe, base, None)
447
+ detector = _get_openpose()
448
+ pose_img = detector(ref_image.convert("RGB")).resize((int(width), int(height)))
449
+ cn = _get_controlnet(base).to(DEVICE)
450
+ from diffusers import StableDiffusionControlNetPipeline
451
+ cn_pipe = StableDiffusionControlNetPipeline.from_pipe(pipe, controlnet=cn).to(DEVICE)
452
+ call["image"] = pose_img
453
+ return _safe_call(cn_pipe, call)
454
+
455
  return _safe_call(pipe, call)
requirements.txt CHANGED
@@ -19,3 +19,5 @@ opencv-python-headless
19
  # Face-identity mode (IP-Adapter FaceID). Heavy; comment out if you don't use Face mode.
20
  insightface==0.7.3
21
  onnxruntime
 
 
 
19
  # Face-identity mode (IP-Adapter FaceID). Heavy; comment out if you don't use Face mode.
20
  insightface==0.7.3
21
  onnxruntime
22
+ # Pose lock mode (ControlNet OpenPose) — pose detection from the uploaded image.
23
+ controlnet_aux