pormungtai commited on
Commit
e9c180e
·
verified ·
1 Parent(s): 0e89ded

Typhoon as default translator + Pose (ControlNet) for SDXL

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. pipeline_manager.py +14 -8
app.py CHANGED
@@ -202,7 +202,7 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft(primary_hue="blue"),
202
  choices=[("ปิด / Off", "off"),
203
  ("NLLB-200 (เร็ว)", "nllb"),
204
  ("Typhoon 2 (ไทยแน่น)", "typhoon")],
205
- value="nllb",
206
  label="แปลไทย→อังกฤษ / Auto-translate (พิมพ์ไทยได้เลย)",
207
  )
208
 
 
202
  choices=[("ปิด / Off", "off"),
203
  ("NLLB-200 (เร็ว)", "nllb"),
204
  ("Typhoon 2 (ไทยแน่น)", "typhoon")],
205
+ value="typhoon",
206
  label="แปลไทย→อังกฤษ / Auto-translate (พิมพ์ไทยได้เลย)",
207
  )
208
 
pipeline_manager.py CHANGED
@@ -39,7 +39,7 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
39
  # Modes supported per base family. Used by the UI to gate options.
40
  SUPPORTED_MODES = {
41
  "sd15": ["txt2img", "img2img", "ip_adapter", "face_id", "pose"],
42
- "sdxl": ["txt2img", "img2img", "ip_adapter", "face_id"],
43
  "flux": ["txt2img", "img2img"],
44
  }
45
 
@@ -340,11 +340,13 @@ def _get_controlnet(base):
340
  if base in _CONTROLNET:
341
  return _CONTROLNET[base]
342
  from diffusers import ControlNetModel
343
- if base == "sd15":
344
- repo = "lllyasviel/control_v11p_sd15_openpose"
345
- else:
346
- raise ValueError("Pose (ControlNet) รองรับเฉพาะ SD1.5 ตอนนี้ / SD1.5 only for now.")
347
- cn = ControlNetModel.from_pretrained(repo, torch_dtype=DTYPE_SD)
 
 
348
  _CONTROLNET[base] = cn
349
  return cn
350
 
@@ -447,8 +449,12 @@ def run_generation(cfg, mode, prompt, negative_prompt, ref_image,
447
  detector = _get_openpose()
448
  pose_img = detector(ref_image.convert("RGB")).resize((int(width), int(height)))
449
  cn = _get_controlnet(base).to(DEVICE)
450
- from diffusers import StableDiffusionControlNetPipeline
451
- cn_pipe = StableDiffusionControlNetPipeline.from_pipe(pipe, controlnet=cn).to(DEVICE)
 
 
 
 
452
  call["image"] = pose_img
453
  return _safe_call(cn_pipe, call)
454
 
 
39
  # Modes supported per base family. Used by the UI to gate options.
40
  SUPPORTED_MODES = {
41
  "sd15": ["txt2img", "img2img", "ip_adapter", "face_id", "pose"],
42
+ "sdxl": ["txt2img", "img2img", "ip_adapter", "face_id", "pose"],
43
  "flux": ["txt2img", "img2img"],
44
  }
45
 
 
340
  if base in _CONTROLNET:
341
  return _CONTROLNET[base]
342
  from diffusers import ControlNetModel
343
+ repos = {
344
+ "sd15": "lllyasviel/control_v11p_sd15_openpose",
345
+ "sdxl": "xinsir/controlnet-openpose-sdxl-1.0",
346
+ }
347
+ if base not in repos:
348
+ raise ValueError("Pose (ControlNet) รองรับ SD1.5 / SDXL เท่านั้น.")
349
+ cn = ControlNetModel.from_pretrained(repos[base], torch_dtype=DTYPE_SD)
350
  _CONTROLNET[base] = cn
351
  return cn
352
 
 
449
  detector = _get_openpose()
450
  pose_img = detector(ref_image.convert("RGB")).resize((int(width), int(height)))
451
  cn = _get_controlnet(base).to(DEVICE)
452
+ if base == "sdxl":
453
+ from diffusers import StableDiffusionXLControlNetPipeline
454
+ cn_pipe = StableDiffusionXLControlNetPipeline.from_pipe(pipe, controlnet=cn).to(DEVICE)
455
+ else:
456
+ from diffusers import StableDiffusionControlNetPipeline
457
+ cn_pipe = StableDiffusionControlNetPipeline.from_pipe(pipe, controlnet=cn).to(DEVICE)
458
  call["image"] = pose_img
459
  return _safe_call(cn_pipe, call)
460