"""
pipeline_manager.py
-------------------
Loads diffusion pipelines from an editable registry (models.json) and runs
generation across multiple base families (SD1.5 / SDXL / FLUX) and multiple
input modes (txt2img / img2img / IP-Adapter / Face identity / Pose).

Designed for Hugging Face ZeroGPU: pipelines are built/cached on CPU and moved
to CUDA inside the @spaces.GPU-decorated caller (see app.py). Nothing here calls
.cuda() at import time.
"""
import os
import json
import gc
import hashlib
import urllib.request
from pathlib import Path

import torch

# ---------------------------------------------------------------------------
# Constants / paths
# ---------------------------------------------------------------------------
# Directory containing this module; the registry file lives next to it.
HERE = Path(__file__).parent
REGISTRY_PATH = HERE / "models.json"

# Local cache for downloaded weights. Overridable through the CS_CACHE_DIR
# environment variable; the directory is created at import time.
DOWNLOAD_DIR = Path(os.environ.get("CS_CACHE_DIR", "/tmp/cs_models"))
DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)

# API tokens read from the environment. CIVITAI_TOKEN is "" when unset;
# HF_TOKEN is None when unset or blank.
CIVITAI_TOKEN = os.environ.get("CIVITAI_TOKEN", "").strip()
HF_TOKEN = os.environ.get("HF_TOKEN", "").strip() or None

# bfloat16 when CUDA is reported available at import time, float32 otherwise.
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# SD1.5 / SDXL are most stable in float16; FLUX prefers bfloat16.
DTYPE_SD = torch.float16
# Target device string used for `.to(...)` calls and for the RNG generator.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Modes supported per base family. Used by the UI to gate options.
SUPPORTED_MODES = {
    "sd15": ["txt2img", "img2img", "ip_adapter", "face_id", "pose"],
    "sdxl": ["txt2img", "img2img", "ip_adapter", "face_id"],
    "flux": ["txt2img", "img2img"],
}

# Human-readable label for every mode key that appears in SUPPORTED_MODES.
MODE_LABELS = {
    "txt2img": "Text → Image",
    "img2img": "Image → Image (denoise)",
    "ip_adapter": "IP-Adapter (style / subject)",
    "face_id": "Face identity (FaceID)",
    "pose": "Pose lock (ControlNet OpenPose)",
}


# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
def load_registry():
    """Read models.json and return the list of enabled model configs.

    An entry is enabled unless it sets ``"enabled"`` to a falsy value. A
    missing or ``null`` ``"models"`` key yields an empty list instead of
    raising.

    Raises:
        OSError: if the registry file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    with open(REGISTRY_PATH, "r", encoding="utf-8") as f:
        data = json.load(f)
    # `or []` also covers an explicit `"models": null` in a hand-edited file.
    entries = data.get("models") or []
    return [m for m in entries if m.get("enabled", True)]


def get_model(models, model_id):
    """Return the first config in `models` whose "id" equals `model_id`.

    Entries without an "id" key are skipped, so one malformed registry entry
    cannot break the lookup of the valid ones. Returns None when nothing
    matches.
    """
    return next(
        (m for m in models if "id" in m and m["id"] == model_id),
        None,
    )


# ---------------------------------------------------------------------------
# Thai → English prompt translation (the SD/SDXL/FLUX text encoders are English;
# Thai prompts otherwise produce unrelated images). Runs on the Space, no API key.
# ---------------------------------------------------------------------------
# Translation engine key -> Hugging Face model id.
TRANSLATORS = {
    "nllb": "facebook/nllb-200-distilled-600M",
    "typhoon": "scb10x/llama3.2-typhoon2-3b-instruct",
}
# engine key -> (tokenizer, model); filled lazily by _load_translator.
_TRANSLATOR_CACHE = {}


def has_thai(text):
    """Return True if `text` has a character in U+0E00..U+0E7F (Thai block).

    None and "" are treated as containing no Thai.
    """
    return any("\u0e00" <= ch <= "\u0e7f" for ch in (text or ""))


def _load_translator(engine):
    """Return the cached ``(tokenizer, model)`` pair for `engine`.

    The pair is loaded on first use and kept in _TRANSLATOR_CACHE. "nllb"
    loads a seq2seq model; any other key is loaded as a causal LM.

    Raises:
        KeyError: if `engine` is not a key of TRANSLATORS.
    """
    cached = _TRANSLATOR_CACHE.get(engine)
    if cached is not None:
        return cached

    name = TRANSLATORS[engine]
    if engine == "nllb":
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
        tok = AutoTokenizer.from_pretrained(name)
        model = AutoModelForSeq2SeqLM.from_pretrained(name, torch_dtype=DTYPE_SD)
    else:  # typhoon (causal LM)
        from transformers import AutoTokenizer, AutoModelForCausalLM
        tok = AutoTokenizer.from_pretrained(name)
        model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=DTYPE_SD)
    model.eval()
    _TRANSLATOR_CACHE[engine] = (tok, model)
    return tok, model


def translate_prompt(text, engine):
    """Translate a Thai prompt to English. Pass-through if empty/English/off.

    MUST be called inside the @spaces.GPU context (uses CUDA when available).

    Args:
        text: the user prompt; returned unchanged when empty or Thai-free.
        engine: a key of TRANSLATORS, or None / "off" to disable translation.

    Returns:
        The translated prompt, or the original `text` when translation is
        skipped or fails (failures are logged, never raised).
    """
    if not text or engine in (None, "off") or not has_thai(text):
        return text
    try:
        tok, model = _load_translator(engine)
        model = model.to(DEVICE)
        if engine == "nllb":
            tok.src_lang = "tha_Thai"
            inputs = tok(text, return_tensors="pt", truncation=True,
                         max_length=400).to(DEVICE)
            bos = tok.convert_tokens_to_ids("eng_Latn")
            with torch.no_grad():
                out = model.generate(**inputs, forced_bos_token_id=bos,
                                     max_new_tokens=256, num_beams=4)
            return tok.batch_decode(out, skip_special_tokens=True)[0].strip()

        # typhoon: ask the LLM to rewrite as a clean English image prompt
        msgs = [
            {"role": "system", "content": "You convert Thai text-to-image prompts "
             "into a single concise, vivid English prompt for Stable Diffusion. "
             "Keep the described subject, clothing, pose, and scene. Output ONLY the "
             "English prompt as a comma-separated phrase — no quotes, no explanation."},
            {"role": "user", "content": text},
        ]
        chat = tok.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
        inputs = tok(chat, return_tensors="pt").to(DEVICE)
        eos = tok.eos_token_id
        # eos_token_id may be a list; generate() needs a single pad id.
        pad = eos[0] if isinstance(eos, (list, tuple)) else eos
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=256, do_sample=False,
                                 pad_token_id=pad)
        # Drop the echoed prompt tokens; keep only the newly generated tail.
        gen = out[0][inputs["input_ids"].shape[1]:]
        return tok.decode(gen, skip_special_tokens=True).strip().strip('"')
    except Exception as e:  # noqa
        # Deliberate best effort: a translation failure must not block
        # generation, so log it and fall back to the untranslated prompt.
        import traceback as _tb
        print(f"[translate] {engine} failed, using original text: "
              f"{type(e).__name__}: {e}")
        _tb.print_exc()
        return text


# ---------------------------------------------------------------------------
# Download helpers (Civitai / arbitrary URL → local cache)
# ---------------------------------------------------------------------------
def _civitai_auth_url(url, token):
    """Return `url` with ``token=<token>`` appended when its host is Civitai.

    The host is parsed rather than substring-matched, so the token is never
    attached to a URL on another server that merely mentions "civitai.com"
    in its path or query. URLs that already carry ``token=`` and calls with
    an empty token are returned unchanged.
    """
    from urllib.parse import urlsplit

    if not token or "token=" in url:
        return url
    host = (urlsplit(url).hostname or "").lower()
    if host != "civitai.com" and not host.endswith(".civitai.com"):
        return url
    sep = "&" if "?" in url else "?"
    return f"{url}{sep}token={token}"


def _download_url(url):
    """Download a (Civitai or other) URL to the local cache and return the path.

    The cache file name is derived from the SHA-1 of `url`. An existing file
    larger than 1 MB is treated as a cache hit and returned without network
    access.

    Returns:
        The local path as a string, or None when `url` is empty.

    Raises:
        urllib.error.URLError / OSError: if the download or the write fails.
    """
    if not url:
        return None
    fname = hashlib.sha1(url.encode()).hexdigest()[:16] + ".safetensors"
    dest = DOWNLOAD_DIR / fname
    if dest.exists() and dest.stat().st_size > 1_000_000:
        return str(dest)

    dl_url = _civitai_auth_url(url, CIVITAI_TOKEN)
    req = urllib.request.Request(dl_url, headers={"User-Agent": "Mozilla/5.0"})
    # Log the original URL, not dl_url, so the appended token is not printed.
    print(f"[download] {url} -> {dest}")

    # Stream into a sibling ".part" file and rename only after the transfer
    # completes. Writing straight to `dest` would leave a truncated file that
    # the size check above later accepts as a valid cache hit.
    tmp = dest.with_name(dest.name + ".part")
    try:
        with urllib.request.urlopen(req) as resp, open(tmp, "wb") as out:
            for chunk in iter(lambda: resp.read(1 << 20), b""):
                out.write(chunk)
        tmp.replace(dest)
    finally:
        # No-op after a successful rename; removes the partial file on error.
        if tmp.exists():
            tmp.unlink()
    return str(dest)


# ---------------------------------------------------------------------------
# Pipeline cache
# ---------------------------------------------------------------------------
# Keyed by model id. Stores the base txt2img pipeline (CPU). Adapters are loaded
# on demand and tracked via the `_cs_adapter` attribute on the pipe.
_PIPE_CACHE = {}
_FACE_APP = None  # lazy insightface FaceAnalysis


def _free_cache(keep_id=None):
    """Evict every cached pipeline except `keep_id`, then release memory.

    After dropping the references this runs the garbage collector and, when
    CUDA is available, empties the CUDA allocator cache.
    """
    for key in [k for k in _PIPE_CACHE if k != keep_id]:
        del _PIPE_CACHE[key]
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _build_base_pipeline(cfg):
    """Construct the txt2img pipeline for a model config (on CPU).

    Args:
        cfg: registry entry. Reads "base" plus, depending on the family,
            "single_file_url" or "repo_id", and the optional LoRA keys
            ("type", "lora_repo_id", "lora_weight_name", "lora_url",
            "lora_scale").

    Returns:
        The pipeline, with its progress bar disabled and `_cs_adapter` set
        to None.

    Raises:
        ValueError: if cfg["base"] is not "sd15", "sdxl" or "flux".
    """
    base = cfg["base"]
    common = dict(token=HF_TOKEN)

    if base == "sd15":
        from diffusers import StableDiffusionPipeline
        if cfg.get("single_file_url"):
            local = _download_url(cfg["single_file_url"])
            pipe = StableDiffusionPipeline.from_single_file(
                local, torch_dtype=DTYPE_SD, safety_checker=None
            )
        else:
            pipe = StableDiffusionPipeline.from_pretrained(
                cfg["repo_id"], torch_dtype=DTYPE_SD, safety_checker=None, **common
            )
    elif base == "sdxl":
        from diffusers import StableDiffusionXLPipeline
        if cfg.get("single_file_url"):
            local = _download_url(cfg["single_file_url"])
            pipe = StableDiffusionXLPipeline.from_single_file(local, torch_dtype=DTYPE_SD)
        else:
            pipe = StableDiffusionXLPipeline.from_pretrained(
                cfg["repo_id"], torch_dtype=DTYPE_SD, **common
            )
    elif base == "flux":
        from diffusers import FluxPipeline
        pipe = FluxPipeline.from_pretrained(cfg["repo_id"], torch_dtype=DTYPE, **common)
    else:
        raise ValueError(f"Unknown base family: {base}")

    # Apply LoRA if this entry is a LoRA model.
    if cfg.get("type") == "lora":
        scale = float(cfg.get("lora_scale", 0.8))
        if cfg.get("lora_repo_id"):
            kwargs = {}
            if cfg.get("lora_weight_name"):
                kwargs["weight_name"] = cfg["lora_weight_name"]
            pipe.load_lora_weights(cfg["lora_repo_id"], **kwargs)
        elif cfg.get("lora_url"):
            local = _download_url(cfg["lora_url"])
            pipe.load_lora_weights(local)
        try:
            pipe.fuse_lora(lora_scale=scale)
        except Exception as e:  # noqa
            # Best effort: an unfused LoRA still works, so log and continue.
            print(f"[lora] fuse skipped: {e}")

    # SD1.5 / SDXL community checkpoints are tuned for the Euler Ancestral sampler;
    # it matches the look people get in A1111 / ComfyUI far better than the default.
    if base in ("sd15", "sdxl"):
        from diffusers import EulerAncestralDiscreteScheduler
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

    pipe.set_progress_bar_config(disable=True)
    pipe._cs_adapter = None  # track loaded IP-Adapter / FaceID state
    return pipe


def get_pipeline(cfg):
    """Return a cached base pipeline for the model, building it if needed.

    On a cache miss every other cached pipeline is evicted first, so at most
    one pipeline is held at a time.
    """
    mid = cfg["id"]
    if mid not in _PIPE_CACHE:
        _free_cache(keep_id=None)  # one big model at a time on ZeroGPU
        print(f"[pipeline] building {mid} ({cfg['base']})")
        _PIPE_CACHE[mid] = _build_base_pipeline(cfg)
    return _PIPE_CACHE[mid]


# ---------------------------------------------------------------------------
# Adapter management (IP-Adapter / FaceID)
# ---------------------------------------------------------------------------
# base family -> mode -> load_ip_adapter() arguments ("repo" is positional,
# the remaining keys are passed as keyword arguments).
_IP_ADAPTER_SPECS = {
    "sd15": {
        "ip_adapter": dict(repo="h94/IP-Adapter", subfolder="models",
                           weight_name="ip-adapter-plus_sd15.bin"),
        "face_id": dict(repo="h94/IP-Adapter-FaceID", subfolder=None,
                        weight_name="ip-adapter-faceid_sd15.bin",
                        image_encoder_folder=None),
    },
    "sdxl": {
        # The "vit-h" SDXL weights are paired with the ViT-H image encoder
        # stored under models/image_encoder. Without this override diffusers
        # looks for "image_encoder" inside `subfolder` (sdxl_models), which
        # holds a different encoder.
        "ip_adapter": dict(repo="h94/IP-Adapter", subfolder="sdxl_models",
                           weight_name="ip-adapter-plus_sdxl_vit-h.bin",
                           image_encoder_folder="models/image_encoder"),
        "face_id": dict(repo="h94/IP-Adapter-FaceID", subfolder=None,
                        weight_name="ip-adapter-faceid_sdxl.bin",
                        image_encoder_folder=None),
    },
}


def _ensure_adapter(pipe, base, mode):
    """Load the right IP-Adapter for `mode`, unloading any previous one.

    Modes other than "ip_adapter" / "face_id" mean "no adapter". Nothing is
    done when the pipe's tracked adapter (`_cs_adapter`) already matches.

    Raises:
        KeyError: if `base` has no adapter spec for the requested mode.
    """
    want = mode if mode in ("ip_adapter", "face_id") else None
    if pipe._cs_adapter == want:
        return
    try:
        pipe.unload_ip_adapter()
    except Exception as e:  # noqa
        # Best effort: unloading may fail when nothing was loaded; log it
        # rather than hiding it, then carry on with the requested state.
        print(f"[adapter] unload skipped: {type(e).__name__}: {e}")
    pipe._cs_adapter = None
    if want is None:
        return
    spec = _IP_ADAPTER_SPECS[base][want]
    kwargs = {k: v for k, v in spec.items() if k != "repo"}
    pipe.load_ip_adapter(spec["repo"], **kwargs)
    pipe._cs_adapter = want


def _get_face_app():
    """Return the shared insightface FaceAnalysis app, creating it on first use."""
    global _FACE_APP
    if _FACE_APP is None:
        from insightface.app import FaceAnalysis
        app = FaceAnalysis(name="buffalo_l",
                           providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
        app.prepare(ctx_id=0, det_size=(640, 640))
        _FACE_APP = app
    return _FACE_APP


def _face_embeds(image):
    """Return a torch tensor of FaceID embeddings for the largest face.

    Args:
        image: reference image; it is converted to RGB and then to a BGR
            array before detection.

    Returns:
        A tensor of dtype DTYPE_SD holding a zero (negative) embedding
        followed by the detected face's normalised embedding.

    Raises:
        ValueError: if no face is detected in `image`.
    """
    import numpy as np
    import cv2

    app = _get_face_app()
    arr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
    faces = app.get(arr)
    if not faces:
        raise ValueError("ไม่พบใบหน้าในรูปต้นแบบ / No face detected in the reference image.")
    # Pick the face with the largest bounding-box area.
    largest = max(faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]))
    emb = torch.from_numpy(largest.normed_embedding)  # [512]
    # diffusers IP-Adapter-FaceID expects [2, 1, 1, 512]: [neg, pos] for CFG.
    emb = emb.unsqueeze(0).unsqueeze(0).unsqueeze(0)  # [1, 1, 1, 512]
    return torch.cat([torch.zeros_like(emb), emb], dim=0).to(DTYPE_SD)


# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# ControlNet (OpenPose) — locks the generated subject to an uploaded pose.
# ---------------------------------------------------------------------------
_CONTROLNET = {}  # base family -> loaded ControlNetModel
_OPENPOSE = None  # lazy controlnet_aux OpenposeDetector

# Error raised when a mode that needs a reference image is called without one.
_MISSING_REF_MESSAGES = {
    "img2img": "img2img ต้องอัปโหลดรูปต้นแบบก่อน / Upload a reference image first.",
    "ip_adapter": "IP-Adapter ต้องอัปโหลดรูปต้นแบบก่อน / Upload a reference image first.",
    "face_id": "Face identity ต้องอัปโหลดรูปใบหน้าก่อน / Upload a face image first.",
    "pose": "Pose ต้องอัปโหลดรูปท่าทางก่อน / Upload a pose reference image first.",
}


def _get_controlnet(base):
    """Return the cached OpenPose ControlNet for `base`, loading it on first use.

    Raises:
        ValueError: for any base other than "sd15".
    """
    if base in _CONTROLNET:
        return _CONTROLNET[base]
    from diffusers import ControlNetModel
    if base == "sd15":
        repo = "lllyasviel/control_v11p_sd15_openpose"
    else:
        raise ValueError("Pose (ControlNet) รองรับเฉพาะ SD1.5 ตอนนี้ / SD1.5 only for now.")
    cn = ControlNetModel.from_pretrained(repo, torch_dtype=DTYPE_SD)
    _CONTROLNET[base] = cn
    return cn


def _get_openpose():
    """Return the shared OpenposeDetector, creating it on first use."""
    global _OPENPOSE
    if _OPENPOSE is None:
        from controlnet_aux import OpenposeDetector
        _OPENPOSE = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
    return _OPENPOSE


def _safe_call(pipe_obj, call):
    """Run the pipeline; if clip_skip trips a version incompatibility, retry without it.

    Args:
        pipe_obj: callable pipeline; its result must expose `.images`.
        call: keyword arguments for the pipeline. Not modified.

    Returns:
        The first image of the pipeline output.

    Raises:
        AttributeError / TypeError: re-raised unchanged when `call` has no
            "clip_skip" to drop; whatever the retry raises otherwise.
    """
    try:
        return pipe_obj(**call).images[0]
    except (AttributeError, TypeError) as e:
        if "clip_skip" not in call:
            raise
        print(f"[clip_skip] disabled for this run due to: {e}")
        # Retry on a copy so the caller's kwargs dict is left untouched.
        retry = {k: v for k, v in call.items() if k != "clip_skip"}
        return pipe_obj(**retry).images[0]


def run_generation(cfg, mode, prompt, negative_prompt, ref_image,
                   steps, guidance, denoise, ip_scale, width, height, seed):
    """Run one generation. MUST be called inside a @spaces.GPU context.

    Args:
        cfg: registry entry; reads "base", "id" and the optional "trigger"
            and "clip_skip" keys.
        mode: one of the keys listed in SUPPORTED_MODES[cfg["base"]].
        prompt / negative_prompt: text prompts (negative is ignored for FLUX).
        ref_image: reference image; required by every mode except txt2img.
        steps, guidance: sampler step count and guidance scale.
        denoise: img2img strength.
        ip_scale: IP-Adapter / FaceID scale.
        width, height: output size (ignored for img2img; used to resize the
            pose map in pose mode).
        seed: RNG seed; None or a negative value leaves the generator unset.

    Returns:
        The first generated image.

    Raises:
        ValueError: if `mode` is unsupported for the base family, or a
            required reference image is missing.
    """
    base = cfg["base"]
    if mode not in SUPPORTED_MODES[base]:
        raise ValueError(
            f"โหมด '{MODE_LABELS.get(mode, mode)}' ใช้กับ base {base.upper()} ไม่ได้ "
            f"(รองรับ: {', '.join(MODE_LABELS[m] for m in SUPPORTED_MODES[base])})"
        )
    # Fail fast: check the reference image before the (expensive) pipeline
    # build / GPU transfer and before any adapter state is changed.
    if ref_image is None and mode in _MISSING_REF_MESSAGES:
        raise ValueError(_MISSING_REF_MESSAGES[mode])

    pipe = get_pipeline(cfg)
    pipe = pipe.to(DEVICE)

    generator = None
    if seed is not None and int(seed) >= 0:
        generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

    full_prompt = prompt
    if cfg.get("trigger"):
        # `prompt or ''` keeps a missing prompt from becoming the text "None".
        full_prompt = f"{cfg['trigger']}, {prompt or ''}".strip(", ")

    call = dict(
        prompt=full_prompt,
        num_inference_steps=int(steps),
        generator=generator,
        width=int(width),
        height=int(height),
        guidance_scale=float(guidance),
    )
    # FLUX uses `guidance_scale` differently and has no negative prompt.
    if base != "flux":
        call["negative_prompt"] = negative_prompt or None
        # CLIP skip: registry stores the A1111 value (2 = skip one layer); diffusers
        # counts skipped layers, so A1111 "2" == diffusers clip_skip=1.
        # `or 1` also covers an explicit null in the registry.
        a1111_cs = int(cfg.get("clip_skip") or 1)
        if a1111_cs > 1:
            call["clip_skip"] = a1111_cs - 1

    # ----- mode wiring -----
    runner = pipe
    if mode == "txt2img":
        _ensure_adapter(pipe, base, None)

    elif mode == "img2img":
        if base != "flux":
            _ensure_adapter(pipe, base, None)
        from diffusers import AutoPipelineForImage2Image
        runner = AutoPipelineForImage2Image.from_pipe(pipe).to(DEVICE)
        # width/height are not passed to the img2img pipeline.
        del call["width"], call["height"]
        call["image"] = ref_image.convert("RGB")
        call["strength"] = float(denoise)

    elif mode == "ip_adapter":
        _ensure_adapter(pipe, base, "ip_adapter")
        pipe.set_ip_adapter_scale(float(ip_scale))
        call["ip_adapter_image"] = ref_image.convert("RGB")

    elif mode == "face_id":
        _ensure_adapter(pipe, base, "face_id")
        pipe.set_ip_adapter_scale(float(ip_scale))
        embeds = _face_embeds(ref_image).to(DEVICE)
        call["ip_adapter_image_embeds"] = [embeds]

    elif mode == "pose":
        _ensure_adapter(pipe, base, None)
        detector = _get_openpose()
        pose_img = detector(ref_image.convert("RGB")).resize((int(width), int(height)))
        cn = _get_controlnet(base).to(DEVICE)
        from diffusers import StableDiffusionControlNetPipeline
        runner = StableDiffusionControlNetPipeline.from_pipe(pipe, controlnet=cn).to(DEVICE)
        call["image"] = pose_img

    return _safe_call(runner, call)