Krea_2_demo / app.py
yingzhac's picture
Add Turbo/Raw model switch (default Turbo)
0c9ebb1 verified
Raw
History Blame Contribute Delete
8.61 kB
import gc

import gradio as gr
import spaces
import torch
from diffusers import Krea2Pipeline
# Checkpoint presets keyed by the label shown on the Model switch. Each preset
# carries the Hub repo id plus the default step count and CFG scale that the
# UI applies when that mode is selected.
MODELS = {
    label: {"id": repo_id, "steps": default_steps, "cfg": default_cfg}
    for label, repo_id, default_steps, default_cfg in (
        ("Turbo", "krea/Krea-2-Turbo", 8, 0.0),
        ("Raw", "krea/Krea-2-Raw", 28, 4.5),
    )
}

# Mode selected when the UI first loads (and preloaded at startup when possible).
DEFAULT_MODE = "Turbo"

# Weight dtype handed to from_pretrained() for every checkpoint.
DTYPE = torch.bfloat16

# Seeds are reduced modulo this value before use.
MAX_SEED = (1 << 31) - 1

# Loaded pipelines, keyed by mode label; filled lazily by get_pipe().
pipes = dict()
def get_pipe(mode):
    """Return the cached pipeline for ``mode``, loading it onto the GPU on first use.

    If loading fails with a ``RuntimeError`` (which includes CUDA out-of-memory)
    while other pipelines are cached, those pipelines are evicted and the load
    is retried once. When nothing else is cached there is nothing to free, so
    the original error is re-raised instead of repeating the same failing load.

    Args:
        mode: Key into ``MODELS`` (for example ``"Turbo"`` or ``"Raw"``).

    Returns:
        The loaded ``Krea2Pipeline`` for ``mode``.

    Raises:
        gr.Error: If ``mode`` is not a key of ``MODELS``.
        RuntimeError: If CUDA is unavailable, or if loading fails and cannot be
            recovered by evicting other pipelines.
    """
    if mode not in MODELS:
        raise gr.Error(f"Unknown model mode: {mode}")
    if mode in pipes:
        return pipes[mode]
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. Set this Space hardware to ZeroGPU before generating.")

    model_id = MODELS[mode]["id"]

    def _load():
        print(f"Loading {model_id} pipeline...")
        loaded = Krea2Pipeline.from_pretrained(model_id, torch_dtype=DTYPE).to("cuda")
        print(f"{mode} pipeline loaded!")
        return loaded

    try:
        pipes[mode] = _load()
    except RuntimeError as exc:  # torch.cuda.OutOfMemoryError is a RuntimeError subclass
        if not pipes:
            # Nothing else is cached, so a retry would just fail the same way.
            raise
        # Not enough VRAM to hold both checkpoints at once: drop the others and retry.
        print(f"Load failed ({exc}); freeing other cached pipelines and retrying.")
    else:
        return pipes[mode]

    # Recovery runs outside the except block on purpose: while the handler is
    # active, the exception's traceback keeps the failed load's frames (and any
    # tensors they allocated) alive, so freed memory would not be reclaimed.
    pipes.clear()
    gc.collect()
    torch.cuda.empty_cache()
    pipes[mode] = _load()
    return pipes[mode]
# Warm the default checkpoint at import time when a GPU is visible; otherwise
# just log a notice so the UI can still come up.
if not torch.cuda.is_available():
    print("CUDA is not available at startup. The UI will load, but generation requires ZeroGPU hardware.")
else:
    get_pipe(DEFAULT_MODE)
def gpu_duration(
    mode,
    prompt,
    negative_prompt,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    randomize_seed,
    progress=None,
):
    """Estimate how many seconds of GPU time a generate_image call should request.

    The signature mirrors generate_image so it can be passed as the
    ``duration`` callback of ``spaces.GPU``; only ``height``, ``width`` and
    ``num_inference_steps`` influence the result.

    Returns:
        An integer number of seconds: ``steps * 8 * megapixels * 4 + 180``,
        truncated and capped at 300, where megapixels is never taken below 1.0.
    """
    pixel_count = int(width) * int(height)
    # Images smaller than one megapixel are budgeted as a full megapixel.
    megapixels = max(1.0, pixel_count / (1024 * 1024))
    step_count = int(num_inference_steps)
    estimate = int(step_count * 8 * megapixels * 4 + 180)
    if estimate > 300:
        return 300
    return estimate
@spaces.GPU(duration=gpu_duration, size="xlarge")
def generate_image(
    mode,
    prompt,
    negative_prompt,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    randomize_seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate 4 Krea 2 images (Turbo or Raw) with seeds: seed, 2x, 3x, 4x.

    Seeds are reduced modulo ``MAX_SEED``. If the multiples would collide
    (a base seed of 0 makes every multiple 0), consecutive seeds starting at
    the base seed are used instead so the four images differ.

    Args:
        mode: Model mode, a key of ``MODELS`` ("Turbo" or "Raw").
        prompt: Text description of the image; must not be blank.
        negative_prompt: Text to steer away from; only forwarded when the
            guidance scale is above 0 and the text is non-blank.
        height: Image height in pixels.
        width: Image width in pixels.
        num_inference_steps: Number of denoising steps.
        guidance_scale: CFG guidance scale.
        seed: Base seed; a missing value (``None``) is replaced by a random one.
        randomize_seed: When true, ignore ``seed`` and draw a random base seed.
        progress: Gradio progress tracker (follows tqdm progress bars).

    Returns:
        A tuple ``(images, seeds_text)``: the list of generated images and the
        seeds that were used, joined with ", ".

    Raises:
        gr.Error: If the prompt is blank, CUDA is unavailable, the mode is
            unknown, or the pipeline raises a ``RuntimeError`` while generating.
    """
    if not prompt or not str(prompt).strip():
        raise gr.Error("Enter a prompt to generate images.")
    if not torch.cuda.is_available():
        raise gr.Error("CUDA is not available. Set this Space hardware to ZeroGPU before generating.")

    pipe = get_pipe(mode)

    # A missing seed (presumably a cleared Seed field or an API call that omits
    # it) would crash int(); fall back to a random seed instead.
    if randomize_seed or seed is None:
        seed = torch.randint(0, MAX_SEED, (1,)).item()
    base_seed = int(seed) % MAX_SEED
    seeds = [(base_seed * i) % MAX_SEED for i in range(1, 5)]
    if len(set(seeds)) < len(seeds):
        # Multiples collided (e.g. base seed 0 gives 0, 0, 0, 0), which would
        # render the same image four times; use consecutive seeds instead.
        seeds = [(base_seed + offset) % MAX_SEED for offset in range(4)]

    guidance = float(guidance_scale)
    neg_prompt = None
    if guidance > 0 and isinstance(negative_prompt, str) and negative_prompt.strip():
        neg_prompt = negative_prompt

    # Convert the loop-invariant inputs once rather than on every image.
    prompt_text = str(prompt).strip()
    image_height = int(height)
    image_width = int(width)
    step_count = int(num_inference_steps)

    images = []
    try:
        for current_seed in seeds:
            generator = torch.Generator("cuda").manual_seed(int(current_seed))
            image = pipe(
                prompt=prompt_text,
                negative_prompt=neg_prompt,
                height=image_height,
                width=image_width,
                num_inference_steps=step_count,
                guidance_scale=guidance,
                generator=generator,
            ).images[0]
            images.append(image)
    except RuntimeError as exc:
        torch.cuda.empty_cache()
        raise gr.Error(
            f"Generation failed at {image_width}x{image_height}. Try 1024x1024, fewer images, or fewer steps."
        ) from exc

    return images, ", ".join(str(s) for s in seeds)
def apply_mode(mode):
    """Return slider updates that restore the chosen checkpoint's default Steps and CFG.

    An unrecognised ``mode`` falls back to the ``DEFAULT_MODE`` preset.

    Returns:
        A pair of ``gr.update`` objects: (steps update, CFG update).
    """
    fallback = MODELS[DEFAULT_MODE]
    preset = MODELS[mode] if mode in MODELS else fallback
    steps_update = gr.update(value=preset["steps"])
    cfg_update = gr.update(value=preset["cfg"])
    return steps_update, cfg_update
# Sample prompts for the gr.Examples widget. Each row is a one-element list
# because the widget is wired to a single input (the prompt box).
examples = [
    [sample_prompt]
    for sample_prompt in (
        (
            "A russet harvest mouse clinging to a branch, macro photograph, shallow depth of field, "
            "creamy green bokeh, soft natural light"
        ),
        (
            'A quiet city bookstore at night, rain on the windows, a neon sign that reads "open late", '
            "cinematic warm lighting, detailed reflections"
        ),
        (
            "A fashion editorial portrait of a model wearing sculptural silver fabric, clean studio backdrop, "
            "softbox lighting, high-end magazine photography"
        ),
        "A whimsical hand-drawn village built inside a giant teacup, watercolor texture, cozy evening light",
    )
]
# UI definition. NOTE(review): the nesting below is reconstructed from the
# order of the context managers; confirm the layout against the running app.
with gr.Blocks(title="Krea 2 Demo") as demo:
    # Page header and short usage notes.
    gr.Markdown(
        "\n"
        "# 🎨 Krea 2 Demo\n"
        "Generate images with Krea 2. Use the **Model** switch to choose\n"
        "[Turbo](https://huggingface.co/krea/Krea-2-Turbo) (fast distilled — 8 steps, CFG 0.0) or\n"
        "[Raw](https://huggingface.co/krea/Krea-2-Raw) (full quality — ~28-52 steps, CFG 4.5).\n"
        "Switching the model resets Steps and CFG to that checkpoint's recommended defaults.\n"
    )

    with gr.Row():
        # Left column: every generation control plus the Generate button.
        with gr.Column(scale=1):
            mode = gr.Radio(
                label="Model",
                choices=["Turbo", "Raw"],
                value=DEFAULT_MODE,
                info="Turbo = fast preview. Raw = full quality, slower.",
            )
            prompt = gr.Textbox(
                label="Prompt",
                lines=4,
                placeholder='Describe the image in natural language. Wrap rendered text in quotes, e.g. a sign that reads "open late".',
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                lines=3,
                placeholder="Only used when CFG Guidance Scale is above 0.",
            )

            # Output resolution.
            with gr.Row():
                height = gr.Slider(label="Height", minimum=512, maximum=1536, step=16, value=1024)
                width = gr.Slider(label="Width", minimum=512, maximum=1536, step=16, value=1024)

            # Sampling controls; both start at the default mode's preset values.
            with gr.Row():
                num_inference_steps = gr.Slider(
                    label="Inference Steps",
                    minimum=1,
                    maximum=60,
                    step=1,
                    value=MODELS[DEFAULT_MODE]["steps"],
                    info="Turbo ~8 steps; Raw ~28-52 steps. Resets with the Model switch.",
                )
                guidance_scale = gr.Slider(
                    label="CFG Guidance Scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=MODELS[DEFAULT_MODE]["cfg"],
                    info="Turbo uses 0.0; Raw uses ~4.5. Negative prompt is ignored when CFG is 0.",
                )

            # Seed controls.
            with gr.Row():
                seed = gr.Number(label="Seed", value=42, precision=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)

            generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

        # Right column: generated images and the seeds that produced them.
        with gr.Column(scale=1):
            output_images = gr.Gallery(label="Generated Images", columns=2, rows=2, preview=True)
            used_seeds = gr.Textbox(label="Seeds Used (base, 2x, 3x, 4x)", interactive=False)

    gr.Markdown("### 💡 Example Prompts")
    gr.Examples(examples=examples, inputs=[prompt], cache_examples=False)

    gr.Markdown(
        "Models by [Krea](https://huggingface.co/krea). "
        "This Space follows the Krea 2 Community License and uses the Turbo and Raw checkpoints for demo inference."
    )

    # Component order here must match generate_image's positional parameters.
    inputs = [
        mode,
        prompt,
        negative_prompt,
        height,
        width,
        num_inference_steps,
        guidance_scale,
        seed,
        randomize_seed,
    ]
    outputs = [output_images, used_seeds]

    # Switching the model resets Steps and CFG to that checkpoint's presets.
    mode.change(fn=apply_mode, inputs=[mode], outputs=[num_inference_steps, guidance_scale])

    # The button and submitting the prompt box run the same handler, each
    # under its own API endpoint name.
    generate_btn.click(fn=generate_image, inputs=inputs, outputs=outputs, api_name="generate_image")
    prompt.submit(fn=generate_image, inputs=inputs, outputs=outputs, api_name="generate_image_submit")
# Start the app only when this file is run as a script; the MCP server is
# enabled and errors are shown in the UI.
if __name__ == "__main__":
    demo.launch(show_error=True, mcp_server=True)