# Krea_2_demo / app.py
# Hugging Face Space "Krea_2_demo" by yingzhac, commit 74862c5 (verified),
# "Create Krea 2 Turbo demo", 6.96 kB.
import gradio as gr
import spaces
import torch
from diffusers import Krea2Pipeline
# Hugging Face Hub repository id of the checkpoint this demo serves.
MODEL_ID = "krea/Krea-2-Turbo"
# Precision the pipeline weights are loaded in.
DTYPE = torch.bfloat16
# Upper bound for seeds (2**31 - 1); generation seeds are reduced modulo it.
MAX_SEED = (1 << 31) - 1
# Lazily-created pipeline singleton, populated by get_pipe().
pipe = None
def get_pipe():
    """Return the shared Krea 2 pipeline, creating it on CUDA the first time.

    The loaded pipeline is cached in the module-level ``pipe`` global, so
    subsequent calls return the same object without reloading.

    Raises:
        RuntimeError: if the pipeline has not been loaded yet and no CUDA
            device is available.
    """
    global pipe
    if pipe is not None:
        return pipe
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. Set this Space hardware to ZeroGPU before generating.")
    print(f"Loading {MODEL_ID} pipeline...")
    loaded = Krea2Pipeline.from_pretrained(MODEL_ID, torch_dtype=DTYPE)
    pipe = loaded.to("cuda")
    print("Pipeline loaded!")
    return pipe
# Warm the pipeline at import time when a GPU is visible. Without one, the UI
# still loads and generate_image reports the missing CUDA device when called.
if not torch.cuda.is_available():
    print("CUDA is not available at startup. The UI will load, but generation requires ZeroGPU hardware.")
else:
    get_pipe()
def gpu_duration(
    prompt,
    negative_prompt,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    randomize_seed,
    progress=None,
):
    """Estimate the GPU duration to request for one generate_image call.

    This is passed as ``spaces.GPU(duration=...)`` and therefore accepts the
    same arguments as generate_image. Only width, height and step count
    affect the estimate. The estimate is linear in steps and in megapixels
    (floored at 1 MP), plus a fixed overhead of 180, and is capped at 300.
    """
    pixel_count = int(width) * int(height)
    megapixels = pixel_count / (1024 * 1024)
    if megapixels < 1.0:
        # Images smaller than 1 MP are budgeted as a full megapixel.
        megapixels = 1.0
    estimate = int(int(num_inference_steps) * 8 * megapixels * 4 + 180)
    return 300 if estimate >= 300 else estimate
def _derive_seeds(base_seed, count=4):
    """Return ``count`` distinct per-image seeds derived from ``base_seed``.

    Seeds are the multiples ``base, 2*base, 3*base, ...`` reduced modulo
    MAX_SEED. MAX_SEED (2**31 - 1) is prime, so these multiples are pairwise
    distinct for every nonzero base. A base of 0 (after reduction) would
    make every multiple 0 and yield identical images, so consecutive seeds
    ``0, 1, 2, ...`` are returned instead.
    """
    base = int(base_seed) % MAX_SEED
    if base == 0:
        return list(range(count))
    return [(base * multiplier) % MAX_SEED for multiplier in range(1, count + 1)]


@spaces.GPU(duration=gpu_duration, size="xlarge")
def generate_image(
    prompt,
    negative_prompt,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    randomize_seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate 4 Krea 2 Turbo images with seeds: seed, 2x, 3x, 4x.

    Seeds are reduced modulo MAX_SEED. If the base seed reduces to 0, the
    seeds 0, 1, 2, 3 are used instead so the four images differ.

    Args:
        prompt: Text description of the image to generate.
        negative_prompt: Text to steer away from; only used when
            guidance_scale is above 0.
        height: Output image height in pixels.
        width: Output image width in pixels.
        num_inference_steps: Number of denoising steps (Turbo is designed for 8).
        guidance_scale: CFG guidance scale; the negative prompt is ignored at 0.
        seed: Base random seed. If missing, a random base seed is drawn.
        randomize_seed: If True, ignore ``seed`` and draw a random base seed.

    Returns:
        A tuple ``(images, seeds_text)``: the list of generated images and a
        comma-separated string of the seeds that were used.

    Raises:
        gr.Error: if the prompt is empty, CUDA is unavailable, or the
            pipeline raises a RuntimeError during generation.
    """
    if not prompt or not str(prompt).strip():
        raise gr.Error("Enter a prompt to generate images.")
    if not torch.cuda.is_available():
        raise gr.Error("CUDA is not available. Set this Space hardware to ZeroGPU before generating.")
    pipe = get_pipe()

    # A missing seed (e.g. a cleared Number field or an API call passing null)
    # would crash on int(None); draw a random base seed instead.
    if randomize_seed or seed is None:
        seed = torch.randint(0, MAX_SEED, (1,)).item()
    seeds = _derive_seeds(seed)

    clean_prompt = str(prompt).strip()
    guidance = float(guidance_scale)
    # The UI documents that the negative prompt is ignored at CFG 0, so only
    # forward a non-blank negative prompt when guidance is enabled.
    neg_prompt = None
    if guidance > 0 and isinstance(negative_prompt, str) and negative_prompt.strip():
        neg_prompt = negative_prompt

    # Convert UI values once instead of on every loop iteration.
    height_px = int(height)
    width_px = int(width)
    steps = int(num_inference_steps)

    images = []
    try:
        for current_seed in seeds:
            generator = torch.Generator("cuda").manual_seed(int(current_seed))
            image = pipe(
                prompt=clean_prompt,
                negative_prompt=neg_prompt,
                height=height_px,
                width=width_px,
                num_inference_steps=steps,
                guidance_scale=guidance,
                generator=generator,
            ).images[0]
            images.append(image)
    except RuntimeError as exc:
        # torch.cuda.OutOfMemoryError subclasses RuntimeError, so OOMs land
        # here; release cached GPU memory before reporting the failure.
        torch.cuda.empty_cache()
        # The image count is fixed at 4, so only suggest knobs the UI exposes.
        raise gr.Error(
            f"Generation failed at {width_px}x{height_px}. Try 1024x1024 or fewer steps."
        ) from exc
    return images, ", ".join(str(s) for s in seeds)
# Example rows for gr.Examples; each row fills only the Prompt textbox.
examples = [
    [example_prompt]
    for example_prompt in (
        "A russet harvest mouse clinging to a branch, macro photograph, "
        "shallow depth of field, creamy green bokeh, soft natural light",
        "A quiet city bookstore at night, rain on the windows, "
        'a neon sign that reads "open late", cinematic warm lighting, detailed reflections',
        "A fashion editorial portrait of a model wearing sculptural silver fabric, "
        "clean studio backdrop, softbox lighting, high-end magazine photography",
        "A whimsical hand-drawn village built inside a giant teacup, "
        "watercolor texture, cozy evening light",
    )
]
with gr.Blocks(title="Krea 2 Turbo Demo") as demo:
    # Page header.
    gr.Markdown(
        "\n"
        "# 🎨 Krea 2 Turbo Demo\n"
        "Generate images with [krea/Krea-2-Turbo](https://huggingface.co/krea/Krea-2-Turbo).\n"
        "Turbo is the fast distilled Krea 2 checkpoint; the recommended default is 8 steps with CFG 0.0.\n"
    )

    with gr.Row():
        # Left column: prompt text and generation settings.
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                lines=4,
                label="Prompt",
                placeholder='Describe the image in natural language. Wrap rendered text in quotes, e.g. a sign that reads "open late".',
            )
            negative_prompt = gr.Textbox(
                lines=3,
                label="Negative Prompt",
                placeholder="Only used when CFG Guidance Scale is above 0.",
            )

            # Height and width share the same range, default and step.
            _size_slider_kwargs = {"minimum": 512, "maximum": 1536, "value": 1024, "step": 16}
            with gr.Row():
                height = gr.Slider(label="Height", **_size_slider_kwargs)
                width = gr.Slider(label="Width", **_size_slider_kwargs)

            with gr.Row():
                num_inference_steps = gr.Slider(
                    label="Inference Steps",
                    info="Krea 2 Turbo is designed for 8 steps.",
                    minimum=1,
                    maximum=28,
                    step=1,
                    value=8,
                )
                guidance_scale = gr.Slider(
                    label="CFG Guidance Scale",
                    info="Turbo default is 0.0. Negative prompt is ignored when CFG is 0.",
                    minimum=0.0,
                    maximum=5.0,
                    step=0.1,
                    value=0.0,
                )

            with gr.Row():
                seed = gr.Number(label="Seed", precision=0, value=42)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)

            generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")

        # Right column: generated images and the seeds that produced them.
        with gr.Column(scale=1):
            output_images = gr.Gallery(label="Generated Images", columns=2, rows=2, preview=True)
            used_seeds = gr.Textbox(label="Seeds Used (base, 2x, 3x, 4x)", interactive=False)

    gr.Markdown("### 💡 Example Prompts")
    gr.Examples(examples=examples, inputs=[prompt], cache_examples=False)

    gr.Markdown(
        "Model by [Krea](https://huggingface.co/krea). "
        "This Space follows the Krea 2 Community License and uses the Turbo checkpoint for demo inference."
    )

    inputs = [prompt, negative_prompt, height, width, num_inference_steps, guidance_scale, seed, randomize_seed]
    outputs = [output_images, used_seeds]
    # Both the button and submitting the prompt box run the same generator,
    # each under its own API endpoint name (button first, then submit).
    for _listener, _endpoint in (
        (generate_btn.click, "generate_image"),
        (prompt.submit, "generate_image_submit"),
    ):
        _listener(fn=generate_image, inputs=inputs, outputs=outputs, api_name=_endpoint)
if __name__ == "__main__":
    # Launch options: MCP server enabled and errors shown to the user.
    _launch_kwargs = {"mcp_server": True, "show_error": True}
    demo.launch(**_launch_kwargs)