# character-studio / pipeline_manager.py
# Revision 6d8fb5f by mamungtai-sat:
#   "Typhoon as default translator + Pose (ControlNet) for SDXL (#14)"
# File size at this revision: 18.3 kB
"""
pipeline_manager.py
-------------------
Loads diffusion pipelines from an editable registry (models.json) and runs
generation across multiple base families (SD1.5 / SDXL / FLUX) and multiple
input modes (txt2img / img2img / IP-Adapter / Face identity).
Designed for Hugging Face ZeroGPU: pipelines are built/cached on CPU and moved
to CUDA inside the @spaces.GPU-decorated caller (see app.py). Nothing here calls
.cuda() at import time.
"""
import os
import json
import gc
import hashlib
import urllib.request
from pathlib import Path
import torch
# ---------------------------------------------------------------------------
# Constants / paths
# ---------------------------------------------------------------------------
# Directory holding this module; the model registry file sits next to it.
HERE = Path(__file__).parent
REGISTRY_PATH = HERE.joinpath("models.json")

# Local download cache; the location can be overridden with CS_CACHE_DIR.
DOWNLOAD_DIR = Path(os.getenv("CS_CACHE_DIR", "/tmp/cs_models"))
DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)

# Optional credentials taken from the environment. An empty HF_TOKEN becomes
# None; an empty CIVITAI_TOKEN stays an empty string.
CIVITAI_TOKEN = os.getenv("CIVITAI_TOKEN", "").strip()
HF_TOKEN = os.getenv("HF_TOKEN", "").strip() or None

# Default dtype / device depend on whether torch reports CUDA as available.
if torch.cuda.is_available():
    DTYPE, DEVICE = torch.bfloat16, "cuda"
else:
    DTYPE, DEVICE = torch.float32, "cpu"

# SD1.5 / SDXL are most stable in float16; FLUX prefers bfloat16.
DTYPE_SD = torch.float16

# Modes supported per base family. Used by the UI to gate options.
SUPPORTED_MODES = dict(
    sd15=["txt2img", "img2img", "ip_adapter", "face_id", "pose"],
    sdxl=["txt2img", "img2img", "ip_adapter", "face_id", "pose"],
    flux=["txt2img", "img2img"],
)

# Human-readable label for every mode key.
MODE_LABELS = dict(
    txt2img="Text → Image",
    img2img="Image → Image (denoise)",
    ip_adapter="IP-Adapter (style / subject)",
    face_id="Face identity (FaceID)",
    pose="Pose lock (ControlNet OpenPose)",
)
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
def load_registry():
    """Return the enabled model configs from the registry file.

    Reads REGISTRY_PATH as UTF-8 JSON and walks its top-level "models" list
    (an absent key is treated as an empty list). An entry is kept unless its
    "enabled" value is falsy; entries without that key count as enabled.
    """
    with open(REGISTRY_PATH, "r", encoding="utf-8") as handle:
        registry = json.load(handle)
    enabled = []
    for entry in registry.get("models", []):
        if entry.get("enabled", True):
            enabled.append(entry)
    return enabled
def get_model(models, model_id):
    """Return the first config in *models* whose "id" equals *model_id*.

    Entries that have no "id" key are skipped instead of raising KeyError, so
    one malformed entry in the hand-edited registry cannot break lookups for
    every other model.

    Returns:
        The matching config dict, or None when nothing matches.
    """
    for entry in models:
        # Require the key explicitly so a missing id never matches model_id=None.
        if "id" in entry and entry["id"] == model_id:
            return entry
    return None
# ---------------------------------------------------------------------------
# Thai → English prompt translation (the SD/SDXL/FLUX text encoders are English;
# Thai prompts otherwise produce unrelated images). Runs on the Space, no API key.
# ---------------------------------------------------------------------------
# Translation engine key -> model name handed to `from_pretrained`.
TRANSLATORS = dict(
    nllb="facebook/nllb-200-distilled-600M",
    typhoon="scb10x/llama3.2-typhoon2-3b-instruct",
)
# Engine key -> (tokenizer, model) pair; starts empty and is filled on demand.
_TRANSLATOR_CACHE = dict()
def has_thai(text):
    """Return True when *text* contains a character in U+0E00..U+0E7F.

    That range is the Unicode Thai block. The bounds are written as escapes
    rather than raw characters because both endpoints are unassigned code
    points that render as blank or garbled glyphs and are easy to corrupt.
    None and the empty string yield False.
    """
    return any("\u0e00" <= ch <= "\u0e7f" for ch in (text or ""))
def _load_translator(engine):
    """Return the ``(tokenizer, model)`` pair for *engine*, loading it once.

    A pair already present in _TRANSLATOR_CACHE is returned as-is. Otherwise
    the model name is looked up in TRANSLATORS (KeyError for an unknown
    engine), the tokenizer and model are loaded with DTYPE_SD, the model is
    switched to eval mode, and the pair is cached.
    """
    try:
        return _TRANSLATOR_CACHE[engine]
    except KeyError:
        pass

    repo = TRANSLATORS[engine]
    from transformers import AutoTokenizer
    if engine == "nllb":
        # NLLB is a sequence-to-sequence model.
        from transformers import AutoModelForSeq2SeqLM as model_cls
    else:
        # typhoon (causal LM)
        from transformers import AutoModelForCausalLM as model_cls

    tokenizer = AutoTokenizer.from_pretrained(repo)
    network = model_cls.from_pretrained(repo, torch_dtype=DTYPE_SD)
    network.eval()

    pair = (tokenizer, network)
    _TRANSLATOR_CACHE[engine] = pair
    return pair
def _translate_nllb(tok, model, text):
    """Translate Thai *text* to English with an NLLB seq2seq model."""
    tok.src_lang = "tha_Thai"
    inputs = tok(text, return_tensors="pt", truncation=True,
                 max_length=400).to(DEVICE)
    # The target language is selected by forcing its language token first.
    bos = tok.convert_tokens_to_ids("eng_Latn")
    out = model.generate(**inputs, forced_bos_token_id=bos,
                         max_new_tokens=256, num_beams=4)
    return tok.batch_decode(out, skip_special_tokens=True)[0].strip()


def _translate_typhoon(tok, model, text):
    """Ask the Typhoon chat model to rewrite *text* as an English image prompt."""
    msgs = [
        {"role": "system", "content": "You convert Thai text-to-image prompts "
         "into a single concise, vivid English prompt for Stable Diffusion. "
         "Keep the described subject, clothing, pose, and scene. Output ONLY the "
         "English prompt as a comma-separated phrase — no quotes, no explanation."},
        {"role": "user", "content": text},
    ]
    chat = tok.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
    inputs = tok(chat, return_tensors="pt").to(DEVICE)
    # eos_token_id may be a single id or a list/tuple of ids; pad with the first.
    eos = tok.eos_token_id
    pad = eos[0] if isinstance(eos, (list, tuple)) else eos
    out = model.generate(**inputs, max_new_tokens=256, do_sample=False,
                         pad_token_id=pad)
    # Keep only the newly generated tokens, dropping the echoed chat prompt.
    gen = out[0][inputs["input_ids"].shape[1]:]
    return tok.decode(gen, skip_special_tokens=True).strip().strip('"')


def translate_prompt(text, engine):
    """Translate a Thai prompt to English. Pass-through if empty/English/off.

    MUST be called inside the @spaces.GPU context (uses CUDA when available).

    Args:
        text: The user prompt; may be None or empty.
        engine: "nllb", "typhoon", or None / "off" to disable translation.

    Returns:
        The English prompt, or *text* unchanged when translation is disabled,
        *text* contains no Thai characters, the translator raises, or the
        translator produces an empty string.
    """
    if not text or engine in (None, "off") or not has_thai(text):
        return text
    try:
        tok, model = _load_translator(engine)
        model = model.to(DEVICE)
        # Both engines run without autograd bookkeeping.
        with torch.no_grad():
            if engine == "nllb":
                translated = _translate_nllb(tok, model, text)
            else:
                # typhoon: ask the LLM to rewrite as a clean English image prompt
                translated = _translate_typhoon(tok, model, text)
    except Exception as e:  # noqa
        # Best effort: a broken translator must not block image generation.
        import traceback as _tb
        print(f"[translate] {engine} failed, using original text: "
              f"{type(e).__name__}: {e}")
        _tb.print_exc()
        return text
    if not translated:
        # An empty translation would silently generate from a blank prompt.
        print(f"[translate] {engine} returned an empty result, using original text")
        return text
    return translated
# ---------------------------------------------------------------------------
# Download helpers (Civitai / arbitrary URL → local cache)
# ---------------------------------------------------------------------------
def _download_url(url):
    """Download a (Civitai or other) URL to the local cache and return the path.

    The cache file name is derived from the SHA-1 of *url*. An existing file
    larger than 1,000,000 bytes is treated as a cache hit. For civitai.com
    URLs the CIVITAI_TOKEN is appended as a query parameter when one is set
    and the URL carries no token yet.

    The payload is streamed to a ``.part`` file and moved into place only
    after the transfer finished, so an interrupted download can never be
    mistaken for a complete cached file on the next call.

    Returns:
        The local file path as a string, or None when *url* is empty.

    Raises:
        urllib.error.URLError / OSError: if the transfer or a file operation
            fails; the partial file is removed first.
    """
    if not url:
        return None
    fname = hashlib.sha1(url.encode()).hexdigest()[:16] + ".safetensors"
    dest = DOWNLOAD_DIR / fname
    if dest.exists() and dest.stat().st_size > 1_000_000:
        return str(dest)

    dl_url = url
    if "civitai.com" in url and CIVITAI_TOKEN and "token=" not in url:
        sep = "&" if "?" in url else "?"
        dl_url = f"{url}{sep}token={CIVITAI_TOKEN}"
    req = urllib.request.Request(dl_url, headers={"User-Agent": "Mozilla/5.0"})
    # Log the caller's URL rather than dl_url so the appended token is not printed.
    print(f"[download] {url} -> {dest}")

    partial = dest.with_name(dest.name + ".part")
    try:
        # The timeout bounds each blocking socket operation, not the whole
        # transfer, so large files still download while a stalled one aborts.
        with urllib.request.urlopen(req, timeout=120) as resp, \
                open(partial, "wb") as out:
            for chunk in iter(lambda: resp.read(1 << 20), b""):
                out.write(chunk)
        os.replace(partial, dest)
    except BaseException:
        # Drop the incomplete file (also on KeyboardInterrupt), then re-raise.
        if partial.exists():
            partial.unlink()
        raise
    return str(dest)
# ---------------------------------------------------------------------------
# Pipeline cache
# ---------------------------------------------------------------------------
# Model id -> cached base txt2img pipeline. Adapters are loaded on demand; the
# adapter currently attached to a pipeline is recorded in its `_cs_adapter`
# attribute.
_PIPE_CACHE = dict()
# insightface FaceAnalysis instance, created lazily on first use.
_FACE_APP = None
def _free_cache(keep_id=None):
    """Drop every cached pipeline except the one keyed by ``keep_id``.

    After eviction the garbage collector is run and, when torch reports CUDA
    as available, the CUDA allocator cache is emptied.
    """
    stale_ids = [model_id for model_id in _PIPE_CACHE if model_id != keep_id]
    for model_id in stale_ids:
        _PIPE_CACHE.pop(model_id)
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def _apply_lora(pipe, cfg):
    """Load the LoRA described by *cfg* into *pipe* and try to fuse it."""
    scale = float(cfg.get("lora_scale", 0.8))
    if cfg.get("lora_repo_id"):
        # Forward the hub token so private / gated LoRA repos load the same
        # way the base checkpoints do.
        kwargs = {"token": HF_TOKEN}
        if cfg.get("lora_weight_name"):
            kwargs["weight_name"] = cfg["lora_weight_name"]
        pipe.load_lora_weights(cfg["lora_repo_id"], **kwargs)
    elif cfg.get("lora_url"):
        pipe.load_lora_weights(_download_url(cfg["lora_url"]))
    else:
        # Nothing was loaded, so there is nothing to fuse.
        print("[lora] fuse skipped: no lora_repo_id / lora_url configured")
        return
    try:
        pipe.fuse_lora(lora_scale=scale)
    except Exception as e:  # noqa
        # Fusing is best effort; a failure is logged and the build continues.
        print(f"[lora] fuse skipped: {e}")


def _build_base_pipeline(cfg):
    """Construct the txt2img pipeline for a model config (on CPU).

    Args:
        cfg: Registry entry. Reads "base" ("sd15" / "sdxl" / "flux"), then
            "single_file_url" or "repo_id", and for entries with
            ``type == "lora"`` the "lora_*" keys.

    Returns:
        The pipeline with its progress bar disabled and ``_cs_adapter`` set
        to None.

    Raises:
        ValueError: if ``cfg["base"]`` is not a known base family. This is
            checked before anything is downloaded or loaded.
    """
    base = cfg["base"]
    if base not in ("sd15", "sdxl", "flux"):
        raise ValueError(f"Unknown base family: {base}")

    common = dict(token=HF_TOKEN)
    if base == "sd15":
        from diffusers import StableDiffusionPipeline
        if cfg.get("single_file_url"):
            local = _download_url(cfg["single_file_url"])
            pipe = StableDiffusionPipeline.from_single_file(
                local, torch_dtype=DTYPE_SD, safety_checker=None
            )
        else:
            pipe = StableDiffusionPipeline.from_pretrained(
                cfg["repo_id"], torch_dtype=DTYPE_SD, safety_checker=None, **common
            )
    elif base == "sdxl":
        from diffusers import StableDiffusionXLPipeline
        if cfg.get("single_file_url"):
            local = _download_url(cfg["single_file_url"])
            pipe = StableDiffusionXLPipeline.from_single_file(local, torch_dtype=DTYPE_SD)
        else:
            pipe = StableDiffusionXLPipeline.from_pretrained(
                cfg["repo_id"], torch_dtype=DTYPE_SD, **common
            )
    else:  # flux
        from diffusers import FluxPipeline
        pipe = FluxPipeline.from_pretrained(cfg["repo_id"], torch_dtype=DTYPE, **common)

    # Apply LoRA if this entry is a LoRA model.
    if cfg.get("type") == "lora":
        _apply_lora(pipe, cfg)

    # SD1.5 / SDXL community checkpoints are tuned for the Euler Ancestral sampler;
    # it matches the look people get in A1111 / ComfyUI far better than the default.
    if base in ("sd15", "sdxl"):
        from diffusers import EulerAncestralDiscreteScheduler
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

    pipe.set_progress_bar_config(disable=True)
    pipe._cs_adapter = None  # track loaded IP-Adapter / FaceID state
    return pipe
def get_pipeline(cfg):
    """Return the cached base pipeline for ``cfg["id"]``, building it on a miss.

    On a cache miss every other cached pipeline is evicted before the new one
    is built, so at most one pipeline is kept at a time.
    """
    model_id = cfg["id"]
    try:
        return _PIPE_CACHE[model_id]
    except KeyError:
        pass
    _free_cache(keep_id=None)  # one big model at a time on ZeroGPU
    print(f"[pipeline] building {model_id} ({cfg['base']})")
    built = _build_base_pipeline(cfg)
    _PIPE_CACHE[model_id] = built
    return built
# ---------------------------------------------------------------------------
# Adapter management (IP-Adapter / FaceID)
# ---------------------------------------------------------------------------
# base family -> adapter mode -> arguments for `pipe.load_ip_adapter`.
# "repo" is passed positionally; every other key is forwarded as a keyword.
_IP_ADAPTER_SPECS = {
    "sd15": {
        "ip_adapter": {
            "repo": "h94/IP-Adapter",
            "subfolder": "models",
            "weight_name": "ip-adapter-plus_sd15.bin",
        },
        "face_id": {
            "repo": "h94/IP-Adapter-FaceID",
            "subfolder": None,
            "weight_name": "ip-adapter-faceid_sd15.bin",
            "image_encoder_folder": None,
        },
    },
    "sdxl": {
        "ip_adapter": {
            "repo": "h94/IP-Adapter",
            "subfolder": "sdxl_models",
            "weight_name": "ip-adapter-plus_sdxl_vit-h.bin",
        },
        "face_id": {
            "repo": "h94/IP-Adapter-FaceID",
            "subfolder": None,
            "weight_name": "ip-adapter-faceid_sdxl.bin",
            "image_encoder_folder": None,
        },
    },
}
def _ensure_adapter(pipe, base, mode):
    """Load the right IP-Adapter for `mode`, unloading any previous one.

    Args:
        pipe: Pipeline whose `_cs_adapter` attribute records the adapter that
            is currently loaded. A pipeline without the attribute is treated
            as having no adapter.
        base: Base family key used to look up the adapter spec.
        mode: Generation mode; only "ip_adapter" and "face_id" need an
            adapter, every other mode means "no adapter".

    Raises:
        KeyError: if an adapter is wanted but no spec exists for *base*.
    """
    want = mode if mode in ("ip_adapter", "face_id") else None
    if getattr(pipe, "_cs_adapter", None) == want:
        return
    try:
        pipe.unload_ip_adapter()
    except Exception as e:  # noqa
        # Unloading stays best effort, but the failure is logged, not hidden.
        print(f"[adapter] unload skipped: {type(e).__name__}: {e}")
    # Reset first so a failed load below never leaves a stale adapter tag.
    pipe._cs_adapter = None
    if want is None:
        return
    spec = _IP_ADAPTER_SPECS[base][want]
    kwargs = {k: v for k, v in spec.items() if k != "repo"}
    pipe.load_ip_adapter(spec["repo"], **kwargs)
    pipe._cs_adapter = want
def _get_face_app():
    """Return the shared insightface FaceAnalysis instance, creating it once.

    The instance uses the "buffalo_l" model pack, lists the CUDA execution
    provider before the CPU one, and is prepared with ctx_id 0 and a 640x640
    detection size.
    """
    global _FACE_APP
    if _FACE_APP is not None:
        return _FACE_APP
    from insightface.app import FaceAnalysis
    analyzer = FaceAnalysis(
        name="buffalo_l",
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    analyzer.prepare(ctx_id=0, det_size=(640, 640))
    _FACE_APP = analyzer
    return analyzer
def _face_embeds(image):
    """Return FaceID embeddings for the largest face found in *image*.

    The image is converted to RGB, then to BGR, and passed to the face
    analyzer. Of the detected faces, the one with the largest bounding-box
    area supplies the embedding.

    Returns:
        A DTYPE_SD tensor made of an all-zero entry followed by the face
        embedding, concatenated along dim 0.

    Raises:
        ValueError: if no face is detected.
    """
    import numpy as np
    import cv2

    def _bbox_area(face):
        box = face.bbox
        return (box[2] - box[0]) * (box[3] - box[1])

    analyzer = _get_face_app()
    rgb = np.array(image.convert("RGB"))
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    detected = analyzer.get(bgr)
    if not detected:
        raise ValueError("ไม่พบใบหน้าในรูปต้นแบบ / No face detected in the reference image.")

    # Ascending sort by area: the last element is the largest face.
    largest = sorted(detected, key=_bbox_area)[-1]
    # diffusers IP-Adapter-FaceID expects [2, 1, 1, 512]: [neg, pos] for CFG.
    positive = torch.from_numpy(largest.normed_embedding)[None, None, None, :]
    negative = torch.zeros_like(positive)
    return torch.cat([negative, positive], dim=0).to(DTYPE_SD)
# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# ControlNet (OpenPose) — locks the generated subject to an uploaded pose.
# ---------------------------------------------------------------------------
# Base family -> loaded ControlNet model; filled lazily.
_CONTROLNET = dict()
# OpenPose detector instance; created lazily on first use.
_OPENPOSE = None
def _get_controlnet(base):
    """Return the OpenPose ControlNet model for *base*, loading it on first use.

    Loaded models are kept in _CONTROLNET and use DTYPE_SD.

    Raises:
        ValueError: if *base* is neither "sd15" nor "sdxl".
    """
    if base not in _CONTROLNET:
        from diffusers import ControlNetModel
        repo = {
            "sd15": "lllyasviel/control_v11p_sd15_openpose",
            "sdxl": "xinsir/controlnet-openpose-sdxl-1.0",
        }.get(base)
        if repo is None:
            raise ValueError("Pose (ControlNet) รองรับ SD1.5 / SDXL เท่านั้น.")
        _CONTROLNET[base] = ControlNetModel.from_pretrained(repo, torch_dtype=DTYPE_SD)
    return _CONTROLNET[base]
def _get_openpose():
    """Return the shared OpenposeDetector, loading it on first use."""
    global _OPENPOSE
    if _OPENPOSE is not None:
        return _OPENPOSE
    from controlnet_aux import OpenposeDetector
    detector = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
    _OPENPOSE = detector
    return detector
def _safe_call(pipe_obj, call):
    """Run the pipeline; if clip_skip trips a version incompatibility, retry without it.

    Args:
        pipe_obj: Callable pipeline; its result must expose ``.images``.
        call: Keyword arguments for the pipeline. This dict is never modified;
            the retry uses a filtered copy.

    Returns:
        The first image of the pipeline output.

    Raises:
        AttributeError / TypeError: re-raised when the call fails and no
            "clip_skip" argument was present, or when the retry fails too.
    """
    try:
        return pipe_obj(**call).images[0]
    except (AttributeError, TypeError) as e:
        if "clip_skip" not in call:
            raise
        print(f"[clip_skip] disabled for this run due to: {e}")
    # Retry outside the handler so a second failure surfaces on its own.
    retry = {k: v for k, v in call.items() if k != "clip_skip"}
    return pipe_obj(**retry).images[0]
def run_generation(cfg, mode, prompt, negative_prompt, ref_image,
                   steps, guidance, denoise, ip_scale, width, height, seed):
    """Run one generation. MUST be called inside a @spaces.GPU context.

    Args:
        cfg: Registry entry; reads "base", "id", and optionally "trigger"
            and "clip_skip".
        mode: A key of MODE_LABELS; must be listed in SUPPORTED_MODES[base].
        prompt: Positive prompt; ``cfg["trigger"]`` is prepended when set.
        negative_prompt: Negative prompt (not passed for the "flux" base).
        ref_image: Reference image; required by every mode except "txt2img".
        steps: Number of inference steps.
        guidance: Guidance scale.
        denoise: img2img strength.
        ip_scale: IP-Adapter / FaceID scale.
        width, height: Output size. Dropped for img2img; also used to resize
            the pose map in "pose" mode.
        seed: Generation is seeded when this is not None and ``int(seed) >= 0``.

    Returns:
        The first image produced by the pipeline.

    Raises:
        ValueError: for a mode the base family does not support, or for a
            missing reference image. Both are checked before the pipeline is
            built or moved to the device.
    """
    base = cfg["base"]
    if mode not in SUPPORTED_MODES[base]:
        raise ValueError(
            f"โหมด '{MODE_LABELS.get(mode, mode)}' ใช้กับ base {base.upper()} ไม่ได้ "
            f"(รองรับ: {', '.join(MODE_LABELS[m] for m in SUPPORTED_MODES[base])})"
        )

    # Fail fast: do not build / move a multi-GB pipeline only to reject the
    # request for a missing upload.
    missing_ref = {
        "img2img": "img2img ต้องอัปโหลดรูปต้นแบบก่อน / Upload a reference image first.",
        "ip_adapter": "IP-Adapter ต้องอัปโหลดรูปต้นแบบก่อน / Upload a reference image first.",
        "face_id": "Face identity ต้องอัปโหลดรูปใบหน้าก่อน / Upload a face image first.",
        "pose": "Pose ต้องอัปโหลดรูปท่าทางก่อน / Upload a pose reference image first.",
    }
    if ref_image is None and mode in missing_ref:
        raise ValueError(missing_ref[mode])

    pipe = get_pipeline(cfg)
    pipe = pipe.to(DEVICE)

    generator = None
    if seed is not None and int(seed) >= 0:
        generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

    full_prompt = prompt
    if cfg.get("trigger"):
        full_prompt = f"{cfg['trigger']}, {prompt}".strip(", ")

    call = dict(
        prompt=full_prompt,
        num_inference_steps=int(steps),
        generator=generator,
        width=int(width),
        height=int(height),
        guidance_scale=float(guidance),
    )
    # FLUX uses `guidance_scale` differently and has no negative prompt.
    if base != "flux":
        call["negative_prompt"] = negative_prompt or None
        # CLIP skip: registry stores the A1111 value (2 = skip one layer); diffusers
        # counts skipped layers, so A1111 "2" == diffusers clip_skip=1.
        a1111_cs = int(cfg.get("clip_skip", 1))
        if a1111_cs > 1:
            call["clip_skip"] = a1111_cs - 1

    # ----- mode wiring -----
    if mode == "txt2img":
        _ensure_adapter(pipe, base, None)
    elif mode == "img2img":
        if base != "flux":
            _ensure_adapter(pipe, base, None)
        from diffusers import AutoPipelineForImage2Image
        i2i = AutoPipelineForImage2Image.from_pipe(pipe).to(DEVICE)
        # The output size follows the input image, so width/height are dropped.
        call.pop("width")
        call.pop("height")
        call["image"] = ref_image.convert("RGB")
        call["strength"] = float(denoise)
        return _safe_call(i2i, call)
    elif mode == "ip_adapter":
        _ensure_adapter(pipe, base, "ip_adapter")
        pipe.set_ip_adapter_scale(float(ip_scale))
        call["ip_adapter_image"] = ref_image.convert("RGB")
    elif mode == "face_id":
        _ensure_adapter(pipe, base, "face_id")
        pipe.set_ip_adapter_scale(float(ip_scale))
        embeds = _face_embeds(ref_image).to(DEVICE)
        call["ip_adapter_image_embeds"] = [embeds]
    elif mode == "pose":
        _ensure_adapter(pipe, base, None)
        detector = _get_openpose()
        # Extract the pose map and match it to the requested output size.
        pose_img = detector(ref_image.convert("RGB")).resize((int(width), int(height)))
        cn = _get_controlnet(base).to(DEVICE)
        if base == "sdxl":
            from diffusers import StableDiffusionXLControlNetPipeline
            cn_pipe = StableDiffusionXLControlNetPipeline.from_pipe(pipe, controlnet=cn).to(DEVICE)
        else:
            from diffusers import StableDiffusionControlNetPipeline
            cn_pipe = StableDiffusionControlNetPipeline.from_pipe(pipe, controlnet=cn).to(DEVICE)
        call["image"] = pose_img
        return _safe_call(cn_pipe, call)

    return _safe_call(pipe, call)