Spaces:
Running on Zero
Running on Zero
| """ | |
| pipeline_manager.py | |
| ------------------- | |
| Loads diffusion pipelines from an editable registry (models.json) and runs | |
| generation across multiple base families (SD1.5 / SDXL / FLUX) and multiple | |
| input modes (txt2img / img2img / IP-Adapter / Face identity). | |
| Designed for Hugging Face ZeroGPU: pipelines are built/cached on CPU and moved | |
| to CUDA inside the @spaces.GPU-decorated caller (see app.py). Nothing here calls | |
| .cuda() at import time. | |
| """ | |
| import os | |
| import json | |
| import gc | |
| import hashlib | |
| import urllib.request | |
| from pathlib import Path | |
| import torch | |
| # --------------------------------------------------------------------------- | |
| # Constants / paths | |
| # --------------------------------------------------------------------------- | |
# Directory that contains this module; the registry file sits next to it.
HERE = Path(__file__).parent
REGISTRY_PATH = HERE / "models.json"
# Local cache for downloaded weights; override with the CS_CACHE_DIR env var.
DOWNLOAD_DIR = Path(os.environ.get("CS_CACHE_DIR", "/tmp/cs_models"))
# NOTE: import-time side effect — the cache directory is created on import.
DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
# Optional tokens read from the environment. CIVITAI_TOKEN is "" when unset;
# HF_TOKEN is normalised to None when unset or blank.
CIVITAI_TOKEN = os.environ.get("CIVITAI_TOKEN", "").strip()
HF_TOKEN = os.environ.get("HF_TOKEN", "").strip() or None
# Evaluated once at import time: bfloat16 when CUDA is reported available at
# that moment, float32 otherwise. Used for FLUX.
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# SD1.5 / SDXL are most stable in float16; FLUX prefers bfloat16.
DTYPE_SD = torch.float16
# Also fixed at import time; every later `.to(DEVICE)` uses this value.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Modes supported per base family. Used by the UI to gate options.
SUPPORTED_MODES = {
    "sd15": ["txt2img", "img2img", "ip_adapter", "face_id", "pose"],
    "sdxl": ["txt2img", "img2img", "ip_adapter", "face_id", "pose"],
    "flux": ["txt2img", "img2img"],
}
# Human-readable label for each mode key used in SUPPORTED_MODES.
MODE_LABELS = {
    "txt2img": "Text → Image",
    "img2img": "Image → Image (denoise)",
    "ip_adapter": "IP-Adapter (style / subject)",
    "face_id": "Face identity (FaceID)",
    "pose": "Pose lock (ControlNet OpenPose)",
}
| # --------------------------------------------------------------------------- | |
| # Registry | |
| # --------------------------------------------------------------------------- | |
def load_registry():
    """Load ``models.json`` and return every model entry that is not disabled.

    An entry is kept unless it carries a falsy ``"enabled"`` value. A registry
    without a ``"models"`` key yields an empty list.
    """
    with open(REGISTRY_PATH, "r", encoding="utf-8") as handle:
        registry = json.load(handle)
    enabled = []
    for entry in registry.get("models", []):
        if entry.get("enabled", True):
            enabled.append(entry)
    return enabled
def get_model(models, model_id):
    """Return the first config in ``models`` whose ``"id"`` is ``model_id``, else None."""
    return next((entry for entry in models if entry["id"] == model_id), None)
| # --------------------------------------------------------------------------- | |
| # Thai → English prompt translation (the SD/SDXL/FLUX text encoders are English; | |
| # Thai prompts otherwise produce unrelated images). Runs on the Space, no API key. | |
| # --------------------------------------------------------------------------- | |
# Translation engine key -> Hugging Face model id loaded by _load_translator().
TRANSLATORS = {
    "nllb": "facebook/nllb-200-distilled-600M",
    "typhoon": "scb10x/llama3.2-typhoon2-3b-instruct",
}
# engine key -> (tokenizer, model); filled lazily by _load_translator().
_TRANSLATOR_CACHE = {}
def has_thai(text):
    """Return True if ``text`` contains a character from the Unicode Thai block.

    ``None`` and the empty string are treated as containing no Thai.
    """
    # U+0E00..U+0E7F is the Thai block. The bounds are written as escapes so
    # they survive copy/paste and re-encoding: with empty-string bounds the
    # comparison is never true and Thai prompts are never translated.
    return any("\u0e00" <= ch <= "\u0e7f" for ch in (text or ""))
def _load_translator(engine):
    """Return the cached ``(tokenizer, model)`` pair for ``engine``, loading it once.

    ``engine`` must be a key of TRANSLATORS. "nllb" is loaded as a seq2seq
    model; any other key is loaded as a causal LM. The model is created in
    DTYPE_SD, switched to eval mode and stored in _TRANSLATOR_CACHE.
    """
    try:
        return _TRANSLATOR_CACHE[engine]
    except KeyError:
        pass  # first use of this engine: build it below

    repo = TRANSLATORS[engine]
    # transformers is imported lazily so importing this module stays cheap.
    if engine == "nllb":
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
        model_cls = AutoModelForSeq2SeqLM
    else:  # typhoon (causal LM)
        from transformers import AutoTokenizer, AutoModelForCausalLM
        model_cls = AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained(repo)
    network = model_cls.from_pretrained(repo, torch_dtype=DTYPE_SD)
    network.eval()
    pair = (tokenizer, network)
    _TRANSLATOR_CACHE[engine] = pair
    return pair
def translate_prompt(text, engine):
    """Translate a Thai prompt to English; pass everything else through.

    MUST be called inside the @spaces.GPU context (uses CUDA when available).

    Args:
        text: Prompt text; may be None or empty.
        engine: A key of TRANSLATORS ("nllb" or "typhoon"), or None / "off"
            to disable translation.

    Returns:
        The English prompt. ``text`` is returned unchanged when translation is
        disabled, when ``text`` is empty or contains no Thai characters, when
        the translator raises, or when it produces an empty string.
    """
    if not text or engine in (None, "off") or not has_thai(text):
        return text
    try:
        tok, model = _load_translator(engine)
        model = model.to(DEVICE)
        if engine == "nllb":
            tok.src_lang = "tha_Thai"
            inputs = tok(text, return_tensors="pt", truncation=True,
                         max_length=400).to(DEVICE)
            bos = tok.convert_tokens_to_ids("eng_Latn")
            # Inference only — same guard the typhoon path uses.
            with torch.no_grad():
                out = model.generate(**inputs, forced_bos_token_id=bos,
                                     max_new_tokens=256, num_beams=4)
            result = tok.batch_decode(out, skip_special_tokens=True)[0].strip()
        else:
            # typhoon: ask the LLM to rewrite as a clean English image prompt
            msgs = [
                {"role": "system", "content": "You convert Thai text-to-image prompts "
                 "into a single concise, vivid English prompt for Stable Diffusion. "
                 "Keep the described subject, clothing, pose, and scene. Output ONLY the "
                 "English prompt as a comma-separated phrase — no quotes, no explanation."},
                {"role": "user", "content": text},
            ]
            chat = tok.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
            # The rendered template already contains the special tokens, so the
            # tokenizer must not add them a second time (duplicate BOS).
            inputs = tok(chat, return_tensors="pt",
                         add_special_tokens=False).to(DEVICE)
            eos = tok.eos_token_id
            pad = eos[0] if isinstance(eos, (list, tuple)) else eos
            with torch.no_grad():
                out = model.generate(**inputs, max_new_tokens=256, do_sample=False,
                                     pad_token_id=pad)
            # Keep only the newly generated tokens (drop the echoed prompt).
            gen = out[0][inputs["input_ids"].shape[1]:]
            result = tok.decode(gen, skip_special_tokens=True).strip().strip('"')
        # An empty translation would silently blank the prompt; keep the original.
        return result or text
    except Exception as e:  # noqa
        # Best effort: translation problems must never block generation.
        import traceback as _tb
        print(f"[translate] {engine} failed, using original text: "
              f"{type(e).__name__}: {e}")
        _tb.print_exc()
        return text
| # --------------------------------------------------------------------------- | |
| # Download helpers (Civitai / arbitrary URL → local cache) | |
| # --------------------------------------------------------------------------- | |
def _download_url(url):
    """Download a (Civitai or other) URL to the local cache and return the path.

    The cache file name is derived from the SHA-1 of ``url``, so the same URL
    always maps to the same file. An existing file larger than 1 MB is reused
    without touching the network.

    Args:
        url: Source URL. A falsy value returns None.

    Returns:
        The cached file path as a string, or None when ``url`` is falsy.

    Raises:
        OSError: Propagated from urllib / the filesystem when the download or
            the final rename fails (urllib's URLError is an OSError).
    """
    if not url:
        return None
    fname = hashlib.sha1(url.encode()).hexdigest()[:16] + ".safetensors"
    dest = DOWNLOAD_DIR / fname
    if dest.exists() and dest.stat().st_size > 1_000_000:
        return str(dest)
    dl_url = url
    if "civitai.com" in url and CIVITAI_TOKEN and "token=" not in url:
        sep = "&" if "?" in url else "?"
        dl_url = f"{url}{sep}token={CIVITAI_TOKEN}"
    req = urllib.request.Request(dl_url, headers={"User-Agent": "Mozilla/5.0"})
    # Log the original URL, not dl_url, so the token never reaches the logs.
    print(f"[download] {url} -> {dest}")
    # Stream into a sibling ".part" file and publish it atomically. Writing
    # straight to `dest` would let an interrupted download (> 1 MB) pass the
    # cache check above and be loaded as a truncated checkpoint.
    partial = dest.with_name(dest.name + ".part")
    try:
        with urllib.request.urlopen(req) as resp, open(partial, "wb") as out:
            while True:
                chunk = resp.read(1 << 20)
                if not chunk:
                    break
                out.write(chunk)
        os.replace(partial, dest)
    finally:
        # After a successful replace the .part file is gone; otherwise clean up.
        if partial.exists():
            partial.unlink()
    return str(dest)
| # --------------------------------------------------------------------------- | |
| # Pipeline cache | |
| # --------------------------------------------------------------------------- | |
# Keyed by model id. Stores the base txt2img pipeline (CPU). Adapters are loaded
# on demand and tracked via the `_cs_adapter` attribute on the pipe.
# Filled by get_pipeline() and pruned by _free_cache().
_PIPE_CACHE = {}
_FACE_APP = None  # lazy insightface FaceAnalysis
def _free_cache(keep_id=None):
    """Drop every cached pipeline except ``keep_id``, then reclaim memory.

    After eviction a garbage-collection pass runs and, when CUDA is available,
    the CUDA caching allocator is emptied.
    """
    stale_ids = [model_id for model_id in _PIPE_CACHE if model_id != keep_id]
    for model_id in stale_ids:
        _PIPE_CACHE.pop(model_id)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def _build_base_pipeline(cfg):
    """Build the txt2img pipeline described by a registry entry (no device move).

    Reads from ``cfg``: ``base`` ("sd15" / "sdxl" / "flux"), then either
    ``single_file_url`` (SD1.5 / SDXL only) or ``repo_id``. When ``type`` is
    "lora" it also reads ``lora_repo_id`` (+ ``lora_weight_name``) or
    ``lora_url``, and ``lora_scale`` (default 0.8).

    Raises:
        ValueError: ``base`` is not one of the known families.
    """
    family = cfg["base"]
    hub_kwargs = {"token": HF_TOKEN}

    # diffusers classes are imported per branch so only what is needed loads.
    if family == "sd15":
        from diffusers import StableDiffusionPipeline
        if cfg.get("single_file_url"):
            checkpoint = _download_url(cfg["single_file_url"])
            pipe = StableDiffusionPipeline.from_single_file(
                checkpoint, torch_dtype=DTYPE_SD, safety_checker=None
            )
        else:
            pipe = StableDiffusionPipeline.from_pretrained(
                cfg["repo_id"], torch_dtype=DTYPE_SD, safety_checker=None, **hub_kwargs
            )
    elif family == "sdxl":
        from diffusers import StableDiffusionXLPipeline
        if cfg.get("single_file_url"):
            checkpoint = _download_url(cfg["single_file_url"])
            pipe = StableDiffusionXLPipeline.from_single_file(
                checkpoint, torch_dtype=DTYPE_SD
            )
        else:
            pipe = StableDiffusionXLPipeline.from_pretrained(
                cfg["repo_id"], torch_dtype=DTYPE_SD, **hub_kwargs
            )
    elif family == "flux":
        from diffusers import FluxPipeline
        pipe = FluxPipeline.from_pretrained(
            cfg["repo_id"], torch_dtype=DTYPE, **hub_kwargs
        )
    else:
        raise ValueError(f"Unknown base family: {family}")

    # LoRA entries: load the weights from the hub or a URL, then try to fuse.
    if cfg.get("type") == "lora":
        lora_scale = float(cfg.get("lora_scale", 0.8))
        if cfg.get("lora_repo_id"):
            weight_name = cfg.get("lora_weight_name")
            extra = {"weight_name": weight_name} if weight_name else {}
            pipe.load_lora_weights(cfg["lora_repo_id"], **extra)
        elif cfg.get("lora_url"):
            pipe.load_lora_weights(_download_url(cfg["lora_url"]))
        try:
            pipe.fuse_lora(lora_scale=lora_scale)
        except Exception as e:  # noqa
            # Fusing is best effort; the pipeline is still returned.
            print(f"[lora] fuse skipped: {e}")

    # SD1.5 / SDXL community checkpoints are tuned for the Euler Ancestral sampler;
    # it matches the look people get in A1111 / ComfyUI far better than the default.
    if family in ("sd15", "sdxl"):
        from diffusers import EulerAncestralDiscreteScheduler
        pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
            pipe.scheduler.config
        )

    pipe.set_progress_bar_config(disable=True)
    pipe._cs_adapter = None  # track loaded IP-Adapter / FaceID state
    return pipe
def get_pipeline(cfg):
    """Return the cached base pipeline for ``cfg["id"]``, building it on a miss.

    On a miss every other cached pipeline is evicted first, so at most one
    base pipeline is held at a time.
    """
    model_id = cfg["id"]
    if model_id in _PIPE_CACHE:
        return _PIPE_CACHE[model_id]
    _free_cache(keep_id=None)  # one big model at a time on ZeroGPU
    print(f"[pipeline] building {model_id} ({cfg['base']})")
    built = _build_base_pipeline(cfg)
    _PIPE_CACHE[model_id] = built
    return built
| # --------------------------------------------------------------------------- | |
| # Adapter management (IP-Adapter / FaceID) | |
| # --------------------------------------------------------------------------- | |
# base family -> mode -> arguments for pipe.load_ip_adapter(). _ensure_adapter()
# passes "repo" positionally and every other key as a keyword argument.
# NOTE(review): the face_id entries set image_encoder_folder=None, presumably
# because FaceID is fed precomputed embeddings (see _face_embeds) — confirm.
_IP_ADAPTER_SPECS = {
    "sd15": {
        "ip_adapter": dict(repo="h94/IP-Adapter", subfolder="models",
                           weight_name="ip-adapter-plus_sd15.bin"),
        "face_id": dict(repo="h94/IP-Adapter-FaceID", subfolder=None,
                        weight_name="ip-adapter-faceid_sd15.bin",
                        image_encoder_folder=None),
    },
    "sdxl": {
        "ip_adapter": dict(repo="h94/IP-Adapter", subfolder="sdxl_models",
                           weight_name="ip-adapter-plus_sdxl_vit-h.bin"),
        "face_id": dict(repo="h94/IP-Adapter-FaceID", subfolder=None,
                        weight_name="ip-adapter-faceid_sdxl.bin",
                        image_encoder_folder=None),
    },
}
def _ensure_adapter(pipe, base, mode):
    """Load the right IP-Adapter for ``mode``, unloading any previous one.

    Args:
        pipe: Base pipeline carrying the ``_cs_adapter`` tracking attribute.
        base: Base family key used to look up _IP_ADAPTER_SPECS.
        mode: Generation mode; only "ip_adapter" and "face_id" need an
            adapter, every other value means "no adapter".

    Raises:
        KeyError: ``base`` has no adapter spec for the requested mode.
    """
    want = mode if mode in ("ip_adapter", "face_id") else None
    if pipe._cs_adapter == want:
        return  # already in the requested state
    try:
        pipe.unload_ip_adapter()
    except Exception as e:  # noqa
        # Unloading stays best effort, but a failure is reported rather than
        # silently swallowed: stale adapter weights would be hard to diagnose.
        print(f"[adapter] unload skipped: {type(e).__name__}: {e}")
    # Reset before loading so a failed load never leaves a stale marker behind.
    pipe._cs_adapter = None
    if want is None:
        return
    spec = _IP_ADAPTER_SPECS[base][want]
    kwargs = {k: v for k, v in spec.items() if k != "repo"}
    pipe.load_ip_adapter(spec["repo"], **kwargs)
    pipe._cs_adapter = want
def _get_face_app():
    """Return the shared insightface ``FaceAnalysis`` instance, creating it once.

    The instance uses the "buffalo_l" model pack and is prepared with
    ``ctx_id=0`` and a 640x640 detection size.
    """
    global _FACE_APP
    if _FACE_APP is not None:
        return _FACE_APP
    from insightface.app import FaceAnalysis
    analyser = FaceAnalysis(
        name="buffalo_l",
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    analyser.prepare(ctx_id=0, det_size=(640, 640))
    _FACE_APP = analyser
    return analyser
def _face_embeds(image):
    """Return the FaceID embedding tensor for the largest face in ``image``.

    ``image`` must offer ``.convert("RGB")``. The result stacks a zero tensor
    and the detected face's normalised embedding along dim 0, cast to DTYPE_SD.

    Raises:
        ValueError: No face was detected in the image.
    """
    import numpy as np
    import cv2
    analyser = _get_face_app()
    rgb = np.array(image.convert("RGB"))
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    detections = analyser.get(bgr)
    if not detections:
        raise ValueError("ไม่พบใบหน้าในรูปต้นแบบ / No face detected in the reference image.")

    def _bbox_area(face):
        box = face.bbox
        return (box[2] - box[0]) * (box[3] - box[1])

    # Ascending by bounding-box area, so the last entry is the largest face.
    largest = sorted(detections, key=_bbox_area)[-1]
    positive = torch.from_numpy(largest.normed_embedding)  # [512]
    # diffusers IP-Adapter-FaceID expects [2, 1, 1, 512]: [neg, pos] for CFG.
    positive = positive[None, None, None]  # [1, 1, 1, 512]
    negative = torch.zeros_like(positive)
    return torch.cat([negative, positive], dim=0).to(DTYPE_SD)
# ---------------------------------------------------------------------------
# Generation
# ControlNet (OpenPose) helpers come first — they lock the generated subject to
# an uploaded pose — followed by the generation entry points.
# ---------------------------------------------------------------------------
_CONTROLNET = {}  # base family -> loaded ControlNetModel (see _get_controlnet)
_OPENPOSE = None  # lazily created OpenposeDetector (see _get_openpose)
def _get_controlnet(base):
    """Return the OpenPose ControlNet for ``base``, loading and caching it once.

    Raises:
        ValueError: ``base`` is neither "sd15" nor "sdxl".
    """
    cached = _CONTROLNET.get(base)
    if base in _CONTROLNET:
        return cached
    from diffusers import ControlNetModel
    openpose_repos = {
        "sd15": "lllyasviel/control_v11p_sd15_openpose",
        "sdxl": "xinsir/controlnet-openpose-sdxl-1.0",
    }
    repo = openpose_repos.get(base)
    if repo is None:
        raise ValueError("Pose (ControlNet) รองรับ SD1.5 / SDXL เท่านั้น.")
    controlnet = ControlNetModel.from_pretrained(repo, torch_dtype=DTYPE_SD)
    _CONTROLNET[base] = controlnet
    return controlnet
def _get_openpose():
    """Return the shared ``OpenposeDetector``, creating it on first use."""
    global _OPENPOSE
    if _OPENPOSE is not None:
        return _OPENPOSE
    from controlnet_aux import OpenposeDetector
    detector = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
    _OPENPOSE = detector
    return detector
def _safe_call(pipe_obj, call):
    """Run the pipeline; if clip_skip trips a version incompatibility, retry without it.

    Args:
        pipe_obj: Callable pipeline; its result must expose ``.images``.
        call: Keyword arguments for the pipeline. Never modified.

    Returns:
        The first image of the pipeline output.

    Raises:
        AttributeError, TypeError: Re-raised when ``call`` has no "clip_skip"
            entry to drop, or when the retry fails as well.
    """
    try:
        return pipe_obj(**call).images[0]
    except (AttributeError, TypeError) as e:
        if "clip_skip" not in call:
            raise
        print(f"[clip_skip] disabled for this run due to: {e}")
        # Retry on a filtered copy so the caller's dict is left untouched.
        retry_call = {k: v for k, v in call.items() if k != "clip_skip"}
        return pipe_obj(**retry_call).images[0]
def run_generation(cfg, mode, prompt, negative_prompt, ref_image,
                   steps, guidance, denoise, ip_scale, width, height, seed):
    """Run one generation. MUST be called inside a @spaces.GPU context.

    Args:
        cfg: Registry entry; reads ``base``, ``id`` and the optional
            ``trigger`` and ``clip_skip`` keys.
        mode: One of the keys of MODE_LABELS; must be listed for the base
            family in SUPPORTED_MODES.
        prompt: Positive prompt (prefixed with ``cfg["trigger"]`` when set).
        negative_prompt: Negative prompt; ignored for FLUX.
        ref_image: Reference image (needs ``.convert("RGB")``); required by
            every mode except txt2img.
        steps: Number of inference steps.
        guidance: Guidance scale.
        denoise: img2img strength.
        ip_scale: IP-Adapter / FaceID scale.
        width, height: Output size (img2img derives it from ``ref_image``).
        seed: Integer seed; None or a negative value means unseeded.

    Returns:
        The first image produced by the pipeline.

    Raises:
        ValueError: ``mode`` is not supported for the base family, or the mode
            needs a reference image and ``ref_image`` is None.
    """
    base = cfg["base"]
    if mode not in SUPPORTED_MODES[base]:
        raise ValueError(
            f"โหมด '{MODE_LABELS.get(mode, mode)}' ใช้กับ base {base.upper()} ไม่ได้ "
            f"(รองรับ: {', '.join(MODE_LABELS[m] for m in SUPPORTED_MODES[base])})"
        )

    # Reject a missing reference image BEFORE building the pipeline and moving
    # it to the GPU: a forgotten upload should fail fast, not after loading
    # several GB of weights.
    missing_ref_errors = {
        "img2img": "img2img ต้องอัปโหลดรูปต้นแบบก่อน / Upload a reference image first.",
        "ip_adapter": "IP-Adapter ต้องอัปโหลดรูปต้นแบบก่อน / Upload a reference image first.",
        "face_id": "Face identity ต้องอัปโหลดรูปใบหน้าก่อน / Upload a face image first.",
        "pose": "Pose ต้องอัปโหลดรูปท่าทางก่อน / Upload a pose reference image first.",
    }
    if ref_image is None and mode in missing_ref_errors:
        raise ValueError(missing_ref_errors[mode])

    pipe = get_pipeline(cfg)
    pipe = pipe.to(DEVICE)

    generator = None
    if seed is not None and int(seed) >= 0:
        generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

    full_prompt = prompt
    if cfg.get("trigger"):
        full_prompt = f"{cfg['trigger']}, {prompt}".strip(", ")

    call = dict(
        prompt=full_prompt,
        num_inference_steps=int(steps),
        generator=generator,
        width=int(width),
        height=int(height),
        guidance_scale=float(guidance),
    )
    # FLUX uses `guidance_scale` differently and has no negative prompt.
    if base != "flux":
        call["negative_prompt"] = negative_prompt or None
        # CLIP skip: registry stores the A1111 value (2 = skip one layer); diffusers
        # counts skipped layers, so A1111 "2" == diffusers clip_skip=1.
        a1111_cs = int(cfg.get("clip_skip", 1))
        if a1111_cs > 1:
            call["clip_skip"] = a1111_cs - 1

    # ----- mode wiring -----
    if mode == "txt2img":
        _ensure_adapter(pipe, base, None)
    elif mode == "img2img":
        if base != "flux":
            _ensure_adapter(pipe, base, None)
        from diffusers import AutoPipelineForImage2Image
        i2i = AutoPipelineForImage2Image.from_pipe(pipe).to(DEVICE)
        # img2img takes its size from the input image.
        call.pop("width")
        call.pop("height")
        call["image"] = ref_image.convert("RGB")
        call["strength"] = float(denoise)
        return _safe_call(i2i, call)
    elif mode == "ip_adapter":
        _ensure_adapter(pipe, base, "ip_adapter")
        pipe.set_ip_adapter_scale(float(ip_scale))
        call["ip_adapter_image"] = ref_image.convert("RGB")
    elif mode == "face_id":
        _ensure_adapter(pipe, base, "face_id")
        pipe.set_ip_adapter_scale(float(ip_scale))
        embeds = _face_embeds(ref_image).to(DEVICE)
        call["ip_adapter_image_embeds"] = [embeds]
    elif mode == "pose":
        _ensure_adapter(pipe, base, None)
        detector = _get_openpose()
        # Extract the pose map and match it to the requested output size.
        pose_img = detector(ref_image.convert("RGB")).resize((int(width), int(height)))
        cn = _get_controlnet(base).to(DEVICE)
        if base == "sdxl":
            from diffusers import StableDiffusionXLControlNetPipeline
            cn_pipe = StableDiffusionXLControlNetPipeline.from_pipe(pipe, controlnet=cn).to(DEVICE)
        else:
            from diffusers import StableDiffusionControlNetPipeline
            cn_pipe = StableDiffusionControlNetPipeline.from_pipe(pipe, controlnet=cn).to(DEVICE)
        call["image"] = pose_img
        return _safe_call(cn_pipe, call)

    return _safe_call(pipe, call)