pormungtai's picture
Typhoon as default translator + Pose (ControlNet) for SDXL
e9c180e verified
Raw
History Blame
10.5 kB
"""
Character Studio — a ZeroGPU Hugging Face Space.
A multi-model character generator: pick a model from an editable registry,
type a prompt, optionally drop a reference image, and generate. Supports
SD1.5 / SDXL / FLUX bases and txt2img / img2img / IP-Adapter / FaceID modes.
Add or remove models by editing models.json (no code change needed), then
click "🔄 Reload models" or restart the Space.
"""
import random
import traceback
import spaces # must be imported before torch on ZeroGPU
import gradio as gr
# --- Workaround for a gradio_client bug: "argument of type 'bool' is not iterable".
# Some component api_info schemas carry a boolean `additionalProperties`, which the
# schema-to-type walker chokes on, crashing get_api_info() and the whole launch.
# Short-circuit any non-dict / boolean schema to "Any". Each hook is only patched
# when the installed gradio_client actually exposes it, so a version that lacks
# (or renamed) these private helpers imports cleanly instead of crashing here. ---
import gradio_client.utils as _gcu

_orig_get_type = getattr(_gcu, "get_type", None)


def _safe_get_type(schema):
    """Return "Any" for non-dict schemas; otherwise defer to the original get_type."""
    if not isinstance(schema, dict):
        return "Any"
    return _orig_get_type(schema)


if _orig_get_type is not None:
    _gcu.get_type = _safe_get_type

_orig_j2pt = getattr(_gcu, "_json_schema_to_python_type", None)


def _safe_j2pt(schema, *args, **kwargs):
    """Return "Any" for boolean schemas (e.g. ``additionalProperties: true``);
    otherwise defer to the original ``_json_schema_to_python_type``."""
    if isinstance(schema, bool):
        return "Any"
    return _orig_j2pt(schema, *args, **kwargs)


if _orig_j2pt is not None:
    _gcu._json_schema_to_python_type = _safe_j2pt
import pipeline_manager as pm
# Upper bound (inclusive) for randomly drawn seeds.
MAX_SEED = (1 << 31) - 1
# ---------------------------------------------------------------------------
# Registry helpers
# ---------------------------------------------------------------------------
def load_models():
    """Return the model registry as produced by ``pm.load_registry()``.

    generate() and the UI callbacks call this again on each use, so they see
    the registry's current contents rather than the startup snapshot.
    """
    registry = pm.load_registry()
    return registry


# Snapshot used to build the initial UI; rebound by reload_registry().
MODELS = load_models()
def model_choices(models):
    """Build ``(label, id)`` pairs, in registry order, for a Gradio choices list."""
    pairs = []
    for entry in models:
        pairs.append((entry["label"], entry["id"]))
    return pairs
# Placeholder shown in the picker when a model has no `preview` image.
_FALLBACK_PREVIEW = "https://placehold.co/400x520/1e293b/93c5fd/png?text=Model"


def gallery_items(models):
    """Return ``(image_url, caption)`` tuples for the model-picker gallery.

    A missing or empty ``preview`` is replaced by ``_FALLBACK_PREVIEW``; the
    caption is the model's ``label``.
    """
    items = []
    for entry in models:
        preview = entry.get("preview")
        if not preview:
            preview = _FALLBACK_PREVIEW
        items.append((preview, entry["label"]))
    return items
def model_ids(models):
    """Return the model ids in registry order (the same order as the gallery)."""
    ids = []
    for entry in models:
        ids.append(entry["id"])
    return ids
def modes_for(models, model_id):
    """Return ``(label, mode_key)`` radio choices for ``model_id``.

    The supported modes come from ``pm.SUPPORTED_MODES`` keyed by the model's
    ``base``, and their labels from ``pm.MODE_LABELS``. An id that is not in
    ``models`` falls back to plain text-to-image.
    """
    cfg = pm.get_model(models, model_id)
    if cfg:
        supported = pm.SUPPORTED_MODES[cfg["base"]]
        return [(pm.MODE_LABELS[key], key) for key in supported]
    return [("Text → Image", "txt2img")]
# ---------------------------------------------------------------------------
# GPU generation
# ---------------------------------------------------------------------------
@spaces.GPU(duration=120)
def generate(model_id, mode, prompt, negative_prompt, ref_image,
             steps, guidance, denoise, ip_scale, width, height, seed, randomize,
             translator):
    """Optionally translate the prompts, then run one generation on the GPU.

    Args:
        model_id: Registry id of the model. It is resolved against a freshly
            loaded registry on every call.
        mode: Mode key (e.g. "txt2img"), passed through to ``pm.run_generation``.
        prompt, negative_prompt: User text, run through ``pm.translate_prompt``
            with the selected ``translator`` before use.
        ref_image: Optional reference image from the UI.
        steps, guidance, denoise, ip_scale, width, height: Forwarded unchanged
            to ``pm.run_generation``.
        seed: Requested seed. A random one is drawn when ``randomize`` is set
            or when the seed is missing or negative.
        randomize: Force a fresh random seed.
        translator: Translator key ("off", "nllb" or "typhoon" in this UI).

    Returns:
        ``(image, seed_used, status_markdown)``.

    Raises:
        gr.Error: the model is not found, or translation/generation fails.
    """
    models = load_models()
    cfg = pm.get_model(models, model_id)
    if cfg is None:
        raise gr.Error("ไม่พบโมเดลที่เลือก โปรด Reload models / Selected model not found.")
    if randomize or seed is None or int(seed) < 0:
        seed = random.randint(0, MAX_SEED)
    else:
        # Normalize to int. It is passed to run_generation and echoed back to
        # the seed box, and API callers may send a float such as 42.0.
        seed = int(seed)
    orig_prompt = prompt
    try:
        # Thai → English so the (English) text encoders understand the prompt.
        # This runs inside the try so a translator failure reaches the user as
        # a gr.Error (with a logged traceback), like a generation failure.
        prompt = pm.translate_prompt(prompt, translator)
        negative_prompt = pm.translate_prompt(negative_prompt, translator)
        img = pm.run_generation(
            cfg=cfg, mode=mode, prompt=prompt, negative_prompt=negative_prompt,
            ref_image=ref_image, steps=steps, guidance=guidance, denoise=denoise,
            ip_scale=ip_scale, width=width, height=height, seed=seed,
        )
    except gr.Error:
        # Already user-facing. Re-wrapping via str(e) would show the message
        # quoted, because gr.Error.__str__ returns repr(message).
        raise
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(str(e)) from e
    note = ""
    if prompt != orig_prompt:
        note = f" · 🌐 {translator}: _{prompt[:120]}_"
    status = f"✅ {cfg['label']} · {pm.MODE_LABELS.get(mode, mode)} · seed {seed}{note}"
    return img, seed, status
# ---------------------------------------------------------------------------
# UI callbacks
# ---------------------------------------------------------------------------
def _apply_model(models, model_id):
    """Compute the UI updates that follow selecting ``model_id``.

    Returns an 8-tuple, in this order: mode radio, prompt (placeholder only),
    negative prompt, steps, guidance, width, height, selected-label markdown.
    For an unknown id the seven controls get a bare ``gr.update()`` (left
    unchanged) and the label shows a not-found message.
    """
    cfg = pm.get_model(models, model_id)
    if not cfg:
        untouched = tuple(gr.update() for _ in range(7))
        return untouched + ("ไม่พบโมเดล / model not found",)
    mode_choices = modes_for(models, model_id)
    # (registry key, fallback) for each control whose *value* is replaced.
    value_defaults = (
        ("negative_prompt", ""),      # negative prompt textbox
        ("default_steps", 28),        # steps slider
        ("default_guidance", 6.0),    # guidance slider
        ("default_width", 768),       # width slider
        ("default_height", 768),      # height slider
    )
    value_updates = [gr.update(value=cfg.get(key, fallback))
                     for key, fallback in value_defaults]
    return (
        gr.update(choices=mode_choices, value=mode_choices[0][1]),
        gr.update(placeholder=cfg.get("recommended_prompt", "")),
        *value_updates,
        f"**เลือก / Selected:** {cfg['label']}",
    )
def on_gallery_select(ids, evt: gr.SelectData):
    """Handle a click on a model card.

    ``evt.index`` is the clicked card's position and is mapped to a model id
    through ``ids``. An out-of-range index (or an empty ``ids``) yields ``None``.
    Returns ``(model_id, *_apply_model(...))`` for the wired outputs.
    """
    models = load_models()
    index = evt.index
    picked = None
    if ids and 0 <= index < len(ids):
        picked = ids[index]
    return (picked,) + tuple(_apply_model(models, picked))
def reload_registry():
    """Re-read the registry, rebind the module-level ``MODELS``, refresh the picker.

    Returns updates for (gallery, ids state, selected-id state, status
    markdown). The selection resets to the first model, or ``None`` when the
    registry is empty.
    """
    global MODELS
    MODELS = load_models()
    first_id = None
    if MODELS:
        first_id = MODELS[0]["id"]
    gallery_update = gr.update(value=gallery_items(MODELS))
    return gallery_update, model_ids(MODELS), first_id, f"🔄 โหลดแล้ว {len(MODELS)} โมเดล"
# ---------------------------------------------------------------------------
# Layout (mirrors the FLUX LoRA DLC reference UI)
# ---------------------------------------------------------------------------
CSS = """
#gen-btn {height: 100%; font-size: 1.3rem; font-weight: 700;}
.card {border-radius: 14px;}
footer {visibility: hidden;}
"""


def _apply_selected(model_id):
    """Re-read the registry and return ``_apply_model`` updates for ``model_id``.

    Used after a registry reload and on page load, so the model-dependent
    controls always agree with the ``selected_id`` state.
    """
    return _apply_model(load_models(), model_id)


# Mode choices for the model that starts out selected (the registry's first).
_INITIAL_MODES = modes_for(MODELS, MODELS[0]["id"]) if MODELS else []

with gr.Blocks(css=CSS, theme=gr.themes.Soft(primary_hue="blue"),
               title="Character Studio") as demo:
    gr.Markdown("## 🎭 Character Studio — multi-model character generator (ZeroGPU)")
    with gr.Row():
        prompt = gr.Textbox(
            label="Edit Prompt", lines=2, scale=4,
            placeholder="✦ เลือกโมเดลแล้วพิมพ์ prompt / Choose a model and type the prompt",
        )
        gen_btn = gr.Button("Generate", variant="primary", scale=1, elem_id="gen-btn")
    # State: ordered list of model ids (matches gallery order) + current selection.
    ids_state = gr.State(model_ids(MODELS))
    selected_id = gr.State(MODELS[0]["id"] if MODELS else None)
    with gr.Row(equal_height=False):
        # ---- left: model picker ----
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 🧩 เลือกโมเดล / Models")
                model_gallery = gr.Gallery(
                    value=gallery_items(MODELS),
                    label=None, show_label=False, columns=2, height="auto",
                    object_fit="cover", allow_preview=False, container=False,
                    elem_classes="card",
                )
            selected_md = gr.Markdown(
                f"**เลือก / Selected:** {MODELS[0]['label']}" if MODELS else ""
            )
            reload_btn = gr.Button("🔄 Reload models", size="sm")
            reload_status = gr.Markdown("")
            mode_radio = gr.Radio(
                choices=_INITIAL_MODES,
                # Start on a mode the pre-selected model actually supports; a
                # hard-coded "txt2img" may not be among its choices.
                value=_INITIAL_MODES[0][1] if _INITIAL_MODES else "txt2img",
                label="โหมดรูปต้นแบบ / Input mode",
            )
            translator = gr.Radio(
                choices=[("ปิด / Off", "off"),
                         ("NLLB-200 (เร็ว)", "nllb"),
                         ("Typhoon 2 (ไทยแน่น)", "typhoon")],
                value="typhoon",
                label="แปลไทย→อังกฤษ / Auto-translate (พิมพ์ไทยได้เลย)",
            )
        # ---- right: output ----
        with gr.Column(scale=1):
            output = gr.Image(label="Generated Image", height=560, elem_classes="card")
            status = gr.Markdown("")
    # ---- advanced ----
    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            with gr.Column():
                ref_image = gr.Image(label="Input image (รูปต้นแบบ)", type="pil", height=240)
                ip_scale = gr.Slider(0.0, 1.5, value=0.7, step=0.05,
                                     label="Reference strength (IP-Adapter / FaceID)")
                denoise = gr.Slider(0.1, 1.0, value=0.65, step=0.01,
                                    label="Denoise strength (img2img · ต่ำ = อิงรูปมาก)")
            with gr.Column():
                negative_prompt = gr.Textbox(label="Negative prompt", lines=2)
                with gr.Row():
                    steps = gr.Slider(1, 50, value=28, step=1, label="Steps")
                    guidance = gr.Slider(0.0, 15.0, value=6.5, step=0.1, label="Guidance (CFG)")
                with gr.Row():
                    width = gr.Slider(384, 1280, value=512, step=64, label="Width")
                    height = gr.Slider(384, 1280, value=768, step=64, label="Height")
                with gr.Row():
                    seed = gr.Number(value=-1, label="Seed (-1 = random)", precision=0)
                    randomize = gr.Checkbox(value=True, label="Randomize seed")
    # ---- wiring ----
    # Controls driven by the selected model, in the order _apply_model returns.
    model_outputs = [mode_radio, prompt, negative_prompt,
                     steps, guidance, width, height, selected_md]
    model_gallery.select(
        on_gallery_select, inputs=[ids_state],
        outputs=[selected_id, *model_outputs],
    )
    # reload_registry resets selected_id to the first model but does not touch
    # the dependent controls. Chain a refresh so the label, mode choices and
    # defaults match the new selection instead of the previous model.
    reload_btn.click(
        reload_registry,
        outputs=[model_gallery, ids_state, selected_id, reload_status],
    ).then(_apply_selected, inputs=[selected_id], outputs=model_outputs)
    # The first model is pre-selected; apply its defaults on page load so an
    # immediate Generate uses that model's settings, not generic slider values.
    demo.load(_apply_selected, inputs=[selected_id], outputs=model_outputs)
    gen_inputs = [selected_id, mode_radio, prompt, negative_prompt, ref_image,
                  steps, guidance, denoise, ip_scale, width, height, seed, randomize,
                  translator]
    gen_outputs = [output, seed, status]
    gen_btn.click(generate, inputs=gen_inputs, outputs=gen_outputs)
    prompt.submit(generate, inputs=gen_inputs, outputs=gen_outputs)
if __name__ == "__main__":
    # Cap pending requests at 12. allowed_paths lets Gradio serve the local
    # model preview thumbnails from the "previews" directory.
    demo.queue(max_size=12)
    demo.launch(allowed_paths=["previews"])