# character-studio / pipeline_manager.py
# mamungtai-sat's picture
# Add Character Studio app, registry, requirements, docs (#1)
# 5d0bada
# Raw
# History Blame
# 12.1 kB
"""
pipeline_manager.py
-------------------
Loads diffusion pipelines from an editable registry (models.json) and runs
generation across multiple base families (SD1.5 / SDXL / FLUX) and multiple
input modes (txt2img / img2img / IP-Adapter / Face identity).
Designed for Hugging Face ZeroGPU: pipelines are built/cached on CPU and moved
to CUDA inside the @spaces.GPU-decorated caller (see app.py). Nothing here calls
.cuda() at import time.
"""
import os
import json
import gc
import hashlib
import urllib.request
from pathlib import Path
import torch
# ---------------------------------------------------------------------------
# Constants / paths
# ---------------------------------------------------------------------------
# Directory containing this module; the model registry sits next to it.
HERE = Path(__file__).parent
REGISTRY_PATH = HERE.joinpath("models.json")

# Local cache for downloaded checkpoints / LoRAs (override with CS_CACHE_DIR).
# The directory is created at import time.
DOWNLOAD_DIR = Path(os.getenv("CS_CACHE_DIR", "/tmp/cs_models"))
DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)

# Optional credentials. An unset or blank HF_TOKEN becomes None; an unset
# CIVITAI_TOKEN becomes the empty string.
CIVITAI_TOKEN = os.getenv("CIVITAI_TOKEN", "").strip()
HF_TOKEN = os.getenv("HF_TOKEN", "").strip() or None

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
# SD1.5 / SDXL are most stable in float16; FLUX prefers bfloat16.
DTYPE_SD = torch.float16

# Modes supported per base family. Used by the UI to gate options.
SUPPORTED_MODES = {
    "sd15": ["txt2img", "img2img", "ip_adapter", "face_id"],
    "sdxl": ["txt2img", "img2img", "ip_adapter", "face_id"],
    "flux": ["txt2img", "img2img"],
}

# Human-readable label for each mode key.
MODE_LABELS = {
    "txt2img": "Text → Image",
    "img2img": "Image → Image (denoise)",
    "ip_adapter": "IP-Adapter (style / subject)",
    "face_id": "Face identity (FaceID)",
}
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
def load_registry():
    """Parse the registry file and return its enabled model entries.

    The registry is the JSON document at ``REGISTRY_PATH``. Entries under its
    ``"models"`` key are returned in file order; an entry is dropped only when
    its ``"enabled"`` value is falsy (a missing ``"enabled"`` counts as on).
    A missing ``"models"`` key yields an empty list.
    """
    registry = json.loads(Path(REGISTRY_PATH).read_text(encoding="utf-8"))
    return [
        entry
        for entry in registry.get("models", [])
        if entry.get("enabled", True)
    ]
def get_model(models, model_id):
    """Return the first config in ``models`` whose ``"id"`` equals ``model_id``.

    Returns ``None`` when no entry matches. Every entry inspected must have an
    ``"id"`` key; a missing key raises ``KeyError``.
    """
    return next((entry for entry in models if entry["id"] == model_id), None)
# ---------------------------------------------------------------------------
# Download helpers (Civitai / arbitrary URL → local cache)
# ---------------------------------------------------------------------------
def _download_url(url):
"""Download a (Civitai or other) URL to the local cache and return the path."""
if not url:
return None
fname = hashlib.sha1(url.encode()).hexdigest()[:16] + ".safetensors"
dest = DOWNLOAD_DIR / fname
if dest.exists() and dest.stat().st_size > 1_000_000:
return str(dest)
dl_url = url
if "civitai.com" in url and CIVITAI_TOKEN and "token=" not in url:
sep = "&" if "?" in url else "?"
dl_url = f"{url}{sep}token={CIVITAI_TOKEN}"
req = urllib.request.Request(dl_url, headers={"User-Agent": "Mozilla/5.0"})
print(f"[download] {url} -> {dest}")
with urllib.request.urlopen(req) as resp, open(dest, "wb") as out:
while True:
chunk = resp.read(1 << 20)
if not chunk:
break
out.write(chunk)
return str(dest)
# ---------------------------------------------------------------------------
# Pipeline cache
# ---------------------------------------------------------------------------
# Keyed by model id. Stores the base txt2img pipeline as built by
# _build_base_pipeline(). Note that run_generation() later calls .to(DEVICE) on
# the cached object itself, so a cached entry is not guaranteed to stay on CPU.
# Adapters are loaded on demand and tracked via the `_cs_adapter` attribute on
# the pipe rather than in this dict.
_PIPE_CACHE = {}
# Process-wide face analyser; stays None until _get_face_app() first creates it.
_FACE_APP = None # lazy insightface FaceAnalysis
def _free_cache(keep_id=None):
    """Drop every cached pipeline except the one stored under ``keep_id``.

    With the default ``keep_id=None`` the cache is emptied (unless a pipeline
    happens to be keyed by ``None``). Afterwards a garbage-collection pass is
    forced and, when CUDA is available, the CUDA allocator cache is released.
    """
    stale_ids = [model_id for model_id in _PIPE_CACHE if model_id != keep_id]
    for model_id in stale_ids:
        _PIPE_CACHE.pop(model_id)

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def _build_base_pipeline(cfg):
    """Construct the txt2img pipeline for a model config.

    No device move happens in this function; the pipeline is returned wherever
    the diffusers loaders place it.

    Keys read from ``cfg``:
      * ``base`` (required): ``"sd15"``, ``"sdxl"`` or ``"flux"``.
      * ``single_file_url`` (SD1.5 / SDXL only): when truthy, the checkpoint is
        fetched with ``_download_url`` and loaded via ``from_single_file``.
        It is not consulted for FLUX.
      * ``repo_id``: loaded via ``from_pretrained`` (with ``HF_TOKEN``) when no
        single-file URL is used; always used for FLUX.
      * ``type``: when equal to ``"lora"``, LoRA weights are loaded from
        ``lora_repo_id`` (optionally ``lora_weight_name``) or else ``lora_url``,
        then fused at ``lora_scale`` (default 0.8).

    Returns:
        The pipeline, with its progress bar disabled and ``_cs_adapter`` set
        to ``None``.

    Raises:
        ValueError: if ``cfg["base"]`` is not one of the known families.
    """
    base = cfg["base"]
    # The HF token is only forwarded to from_pretrained() calls below.
    common = dict(token=HF_TOKEN)
    # diffusers classes are imported lazily, per family, inside each branch.
    if base == "sd15":
        from diffusers import StableDiffusionPipeline
        if cfg.get("single_file_url"):
            local = _download_url(cfg["single_file_url"])
            pipe = StableDiffusionPipeline.from_single_file(
                local, torch_dtype=DTYPE_SD, safety_checker=None
            )
        else:
            pipe = StableDiffusionPipeline.from_pretrained(
                cfg["repo_id"], torch_dtype=DTYPE_SD, safety_checker=None, **common
            )
    elif base == "sdxl":
        from diffusers import StableDiffusionXLPipeline
        if cfg.get("single_file_url"):
            local = _download_url(cfg["single_file_url"])
            pipe = StableDiffusionXLPipeline.from_single_file(local, torch_dtype=DTYPE_SD)
        else:
            pipe = StableDiffusionXLPipeline.from_pretrained(
                cfg["repo_id"], torch_dtype=DTYPE_SD, **common
            )
    elif base == "flux":
        from diffusers import FluxPipeline
        # FLUX uses DTYPE (bfloat16 on CUDA, float32 otherwise), not DTYPE_SD.
        pipe = FluxPipeline.from_pretrained(cfg["repo_id"], torch_dtype=DTYPE, **common)
    else:
        raise ValueError(f"Unknown base family: {base}")
    # Apply LoRA if this entry is a LoRA model.
    if cfg.get("type") == "lora":
        scale = float(cfg.get("lora_scale", 0.8))
        # A Hub repo takes precedence over a direct download URL.
        if cfg.get("lora_repo_id"):
            kwargs = {}
            if cfg.get("lora_weight_name"):
                kwargs["weight_name"] = cfg["lora_weight_name"]
            pipe.load_lora_weights(cfg["lora_repo_id"], **kwargs)
        elif cfg.get("lora_url"):
            local = _download_url(cfg["lora_url"])
            pipe.load_lora_weights(local)
        # Fusing is best-effort: a failure is printed and the pipeline is still
        # returned. Note this also runs when neither LoRA source was configured.
        try:
            pipe.fuse_lora(lora_scale=scale)
        except Exception as e: # noqa
            print(f"[lora] fuse skipped: {e}")
    pipe.set_progress_bar_config(disable=True)
    pipe._cs_adapter = None # track loaded IP-Adapter / FaceID state
    return pipe
def get_pipeline(cfg):
    """Return the base pipeline for ``cfg``, building and caching it on a miss.

    The cache is keyed by ``cfg["id"]``. On a miss every other cached pipeline
    is evicted first, so at most one pipeline is held at a time.
    """
    model_id = cfg["id"]
    if model_id in _PIPE_CACHE:
        return _PIPE_CACHE[model_id]

    _free_cache(keep_id=None)  # one big model at a time on ZeroGPU
    print(f"[pipeline] building {model_id} ({cfg['base']})")
    built = _build_base_pipeline(cfg)
    _PIPE_CACHE[model_id] = built
    return built
# ---------------------------------------------------------------------------
# Adapter management (IP-Adapter / FaceID)
# ---------------------------------------------------------------------------
_IP_ADAPTER_SPECS = {
"sd15": {
"ip_adapter": dict(repo="h94/IP-Adapter", subfolder="models",
weight_name="ip-adapter-plus_sd15.bin"),
"face_id": dict(repo="h94/IP-Adapter-FaceID", subfolder=None,
weight_name="ip-adapter-faceid_sd15.bin",
image_encoder_folder=None),
},
"sdxl": {
"ip_adapter": dict(repo="h94/IP-Adapter", subfolder="sdxl_models",
weight_name="ip-adapter-plus_sdxl_vit-h.bin"),
"face_id": dict(repo="h94/IP-Adapter-FaceID", subfolder=None,
weight_name="ip-adapter-faceid_sdxl.bin",
image_encoder_folder=None),
},
}
def _ensure_adapter(pipe, base, mode):
"""Load the right IP-Adapter for `mode`, unloading any previous one."""
want = mode if mode in ("ip_adapter", "face_id") else None
if pipe._cs_adapter == want:
return
try:
pipe.unload_ip_adapter()
except Exception:
pass
pipe._cs_adapter = None
if want is None:
return
spec = _IP_ADAPTER_SPECS[base][want]
kwargs = {k: v for k, v in spec.items() if k != "repo"}
pipe.load_ip_adapter(spec["repo"], **kwargs)
pipe._cs_adapter = want
def _get_face_app():
    """Return the shared insightface ``FaceAnalysis`` instance.

    The analyser is created on first use (model pack ``buffalo_l``, CUDA
    execution provider listed before CPU, 640x640 detection size) and stored in
    the module-level ``_FACE_APP`` for later calls.
    """
    global _FACE_APP
    if _FACE_APP is not None:
        return _FACE_APP

    from insightface.app import FaceAnalysis

    analyser = FaceAnalysis(
        name="buffalo_l",
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    analyser.prepare(ctx_id=0, det_size=(640, 640))
    _FACE_APP = analyser
    return _FACE_APP
def _face_embeds(image):
    """Build the FaceID embedding tensor for the largest face in ``image``.

    ``image`` is converted to RGB, then to a BGR array, and passed to the
    shared face analyser. The face with the largest bounding-box area is used.

    Returns:
        A tensor of dtype ``DTYPE_SD`` made of an all-zero block followed by
        the face embedding, concatenated along dim 0.

    Raises:
        ValueError: if the analyser finds no face.
    """
    import numpy as np
    import cv2

    def _bbox_area(face):
        box = face.bbox
        return (box[2] - box[0]) * (box[3] - box[1])

    detector = _get_face_app()
    rgb = np.array(image.convert("RGB"))
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    detected = detector.get(bgr)
    if not detected:
        raise ValueError("ไม่พบใบหน้าในรูปต้นแบบ / No face detected in the reference image.")

    # Ascending by area, so the last element is the largest face.
    by_area = sorted(detected, key=_bbox_area)
    largest = by_area[-1]

    positive = torch.from_numpy(largest.normed_embedding)  # [512]
    # diffusers IP-Adapter-FaceID expects [2, 1, 1, 512]: [neg, pos] for CFG.
    positive = positive[None, None, None, :]  # [1, 1, 1, 512]
    negative = torch.zeros_like(positive)
    stacked = torch.cat([negative, positive], dim=0)
    return stacked.to(DTYPE_SD)
# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------
def run_generation(cfg, mode, prompt, negative_prompt, ref_image,
                   steps, guidance, denoise, ip_scale, width, height, seed):
    """Run one generation. MUST be called inside a @spaces.GPU context.

    Args:
        cfg: Model config; ``base`` and ``id`` are required, ``trigger`` is an
            optional prefix prepended to the prompt.
        mode: One of the keys listed for the base in ``SUPPORTED_MODES``.
        prompt: Positive prompt text.
        negative_prompt: Negative prompt; ignored for FLUX, and an empty value
            is passed on as ``None``.
        ref_image: Reference image; required for every mode except txt2img.
        steps: Number of inference steps.
        guidance: Guidance scale.
        denoise: img2img strength.
        ip_scale: IP-Adapter / FaceID scale.
        width, height: Output size; not passed to the pipeline in img2img mode.
        seed: Non-negative value seeds the generator; ``None`` or a negative
            value leaves generation unseeded.

    Returns:
        The first image produced by the pipeline.

    Raises:
        ValueError: if ``mode`` is not supported for the base family, or a
            required reference image is missing.
    """
    base = cfg["base"]
    if mode not in SUPPORTED_MODES[base]:
        raise ValueError(
            f"โหมด '{MODE_LABELS.get(mode, mode)}' ใช้กับ base {base.upper()} ไม่ได้ "
            f"(รองรับ: {', '.join(MODE_LABELS[m] for m in SUPPORTED_MODES[base])})"
        )

    # Fail fast on a missing reference image. This check is free, whereas
    # get_pipeline() / .to(DEVICE) below may load a whole model and move it to
    # the GPU only for the request to be rejected afterwards.
    missing_ref_errors = {
        "img2img": "img2img ต้องอัปโหลดรูปต้นแบบก่อน / Upload a reference image first.",
        "ip_adapter": "IP-Adapter ต้องอัปโหลดรูปต้นแบบก่อน / Upload a reference image first.",
        "face_id": "Face identity ต้องอัปโหลดรูปใบหน้าก่อน / Upload a face image first.",
    }
    if ref_image is None and mode in missing_ref_errors:
        raise ValueError(missing_ref_errors[mode])

    pipe = get_pipeline(cfg)
    pipe = pipe.to(DEVICE)

    generator = None
    if seed is not None and int(seed) >= 0:
        generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

    full_prompt = prompt
    if cfg.get("trigger"):
        # strip(", ") removes the dangling separator when the prompt is empty.
        full_prompt = f"{cfg['trigger']}, {prompt}".strip(", ")

    call = dict(
        prompt=full_prompt,
        num_inference_steps=int(steps),
        generator=generator,
        width=int(width),
        height=int(height),
    )
    call["guidance_scale"] = float(guidance)
    # FLUX has no negative prompt, so the argument is only passed otherwise.
    if base != "flux":
        call["negative_prompt"] = negative_prompt or None

    # ----- mode wiring -----
    if mode == "txt2img":
        _ensure_adapter(pipe, base, None)
    elif mode == "img2img":
        if base != "flux":
            _ensure_adapter(pipe, base, None)
        from diffusers import AutoPipelineForImage2Image
        i2i = AutoPipelineForImage2Image.from_pipe(pipe).to(DEVICE)
        # Output size is not passed in this mode; the pipeline works from the
        # supplied image instead.
        call.pop("width")
        call.pop("height")
        call["image"] = ref_image.convert("RGB")
        call["strength"] = float(denoise)
        return i2i(**call).images[0]
    elif mode == "ip_adapter":
        _ensure_adapter(pipe, base, "ip_adapter")
        pipe.set_ip_adapter_scale(float(ip_scale))
        call["ip_adapter_image"] = ref_image.convert("RGB")
    elif mode == "face_id":
        _ensure_adapter(pipe, base, "face_id")
        pipe.set_ip_adapter_scale(float(ip_scale))
        embeds = _face_embeds(ref_image).to(DEVICE)
        call["ip_adapter_image_embeds"] = [embeds]

    return pipe(**call).images[0]