yingzhac commited on
Commit
0c9ebb1
·
verified ·
1 Parent(s): 74862c5

Add Turbo/Raw model switch (default Turbo)

Browse files
Files changed (2) hide show
  1. README.md +4 -3
  2. app.py +79 -27
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Krea 2 Turbo Demo
3
  emoji: 🎨
4
  colorFrom: indigo
5
  colorTo: purple
@@ -9,9 +9,10 @@ app_file: app.py
9
  python_version: "3.12"
10
  startup_duration_timeout: 1h
11
  pinned: true
12
- short_description: Krea 2 Turbo text-to-image demo with the Z-Image UI layout
13
  models:
14
  - krea/Krea-2-Turbo
 
15
  ---
16
 
17
- Krea 2 Turbo text-to-image demo using the same compact control layout as the Z-Image demo.
 
1
  ---
2
+ title: Krea 2 Demo
3
  emoji: 🎨
4
  colorFrom: indigo
5
  colorTo: purple
 
9
  python_version: "3.12"
10
  startup_duration_timeout: 1h
11
  pinned: true
12
+ short_description: Krea 2 demo with a Turbo/Raw model switch
13
  models:
14
  - krea/Krea-2-Turbo
15
+ - krea/Krea-2-Raw
16
  ---
17
 
18
+ Krea 2 text-to-image demo with a Turbo/Raw model switch, using the same compact control layout as the Z-Image demo.
app.py CHANGED
@@ -4,31 +4,61 @@ import torch
4
  from diffusers import Krea2Pipeline
5
 
6
 
7
- MODEL_ID = "krea/Krea-2-Turbo"
 
 
 
 
 
 
 
 
 
 
 
 
8
  DTYPE = torch.bfloat16
9
  MAX_SEED = 2**31 - 1
10
 
11
- pipe = None
12
 
13
 
14
- def get_pipe():
15
- global pipe
16
- if pipe is None:
17
- if not torch.cuda.is_available():
18
- raise RuntimeError("CUDA is not available. Set this Space hardware to ZeroGPU before generating.")
19
- print(f"Loading {MODEL_ID} pipeline...")
20
- pipe = Krea2Pipeline.from_pretrained(MODEL_ID, torch_dtype=DTYPE).to("cuda")
21
- print("Pipeline loaded!")
22
- return pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  if torch.cuda.is_available():
26
- get_pipe()
27
  else:
28
  print("CUDA is not available at startup. The UI will load, but generation requires ZeroGPU hardware.")
29
 
30
 
31
  def gpu_duration(
 
32
  prompt,
33
  negative_prompt,
34
  height,
@@ -45,6 +75,7 @@ def gpu_duration(
45
 
46
  @spaces.GPU(duration=gpu_duration, size="xlarge")
47
  def generate_image(
 
48
  prompt,
49
  negative_prompt,
50
  height,
@@ -55,13 +86,13 @@ def generate_image(
55
  randomize_seed,
56
  progress=gr.Progress(track_tqdm=True),
57
  ):
58
- """Generate 4 Krea 2 Turbo images with seeds: seed, 2x, 3x, 4x."""
59
  if not prompt or not str(prompt).strip():
60
  raise gr.Error("Enter a prompt to generate images.")
61
  if not torch.cuda.is_available():
62
  raise gr.Error("CUDA is not available. Set this Space hardware to ZeroGPU before generating.")
63
 
64
- pipe = get_pipe()
65
  if randomize_seed:
66
  seed = torch.randint(0, MAX_SEED, (1,)).item()
67
 
@@ -96,6 +127,12 @@ def generate_image(
96
  return images, ", ".join(str(s) for s in seeds)
97
 
98
 
 
 
 
 
 
 
99
  examples = [
100
  [
101
  "A russet harvest mouse clinging to a branch, macro photograph, shallow depth of field, "
@@ -115,18 +152,27 @@ examples = [
115
  ]
116
 
117
 
118
- with gr.Blocks(title="Krea 2 Turbo Demo") as demo:
119
  gr.Markdown(
120
  """
121
- # 🎨 Krea 2 Turbo Demo
122
 
123
- Generate images with [krea/Krea-2-Turbo](https://huggingface.co/krea/Krea-2-Turbo).
124
- Turbo is the fast distilled Krea 2 checkpoint; the recommended default is 8 steps with CFG 0.0.
 
 
125
  """
126
  )
127
 
128
  with gr.Row():
129
  with gr.Column(scale=1):
 
 
 
 
 
 
 
130
  prompt = gr.Textbox(
131
  label="Prompt",
132
  placeholder='Describe the image in natural language. Wrap rendered text in quotes, e.g. a sign that reads "open late".',
@@ -158,20 +204,20 @@ with gr.Blocks(title="Krea 2 Turbo Demo") as demo:
158
  with gr.Row():
159
  num_inference_steps = gr.Slider(
160
  minimum=1,
161
- maximum=28,
162
- value=8,
163
  step=1,
164
  label="Inference Steps",
165
- info="Krea 2 Turbo is designed for 8 steps.",
166
  )
167
 
168
  guidance_scale = gr.Slider(
169
  minimum=0.0,
170
- maximum=5.0,
171
- value=0.0,
172
  step=0.1,
173
  label="CFG Guidance Scale",
174
- info="Turbo default is 0.0. Negative prompt is ignored when CFG is 0.",
175
  )
176
 
177
  with gr.Row():
@@ -207,13 +253,19 @@ with gr.Blocks(title="Krea 2 Turbo Demo") as demo:
207
  )
208
 
209
  gr.Markdown(
210
- "Model by [Krea](https://huggingface.co/krea). "
211
- "This Space follows the Krea 2 Community License and uses the Turbo checkpoint for demo inference."
212
  )
213
 
214
- inputs = [prompt, negative_prompt, height, width, num_inference_steps, guidance_scale, seed, randomize_seed]
215
  outputs = [output_images, used_seeds]
216
 
 
 
 
 
 
 
217
  generate_btn.click(
218
  fn=generate_image,
219
  inputs=inputs,
 
4
  from diffusers import Krea2Pipeline
5
 
6
 
7
+ MODELS = {
8
+ "Turbo": {
9
+ "id": "krea/Krea-2-Turbo",
10
+ "steps": 8,
11
+ "cfg": 0.0,
12
+ },
13
+ "Raw": {
14
+ "id": "krea/Krea-2-Raw",
15
+ "steps": 28,
16
+ "cfg": 4.5,
17
+ },
18
+ }
19
+ DEFAULT_MODE = "Turbo"
20
  DTYPE = torch.bfloat16
21
  MAX_SEED = 2**31 - 1
22
 
23
+ pipes = {}
24
 
25
 
26
+ def get_pipe(mode):
27
+ if mode not in MODELS:
28
+ raise gr.Error(f"Unknown model mode: {mode}")
29
+ if mode in pipes:
30
+ return pipes[mode]
31
+ if not torch.cuda.is_available():
32
+ raise RuntimeError("CUDA is not available. Set this Space hardware to ZeroGPU before generating.")
33
+
34
+ model_id = MODELS[mode]["id"]
35
+
36
+ def _load():
37
+ print(f"Loading {model_id} pipeline...")
38
+ loaded = Krea2Pipeline.from_pretrained(model_id, torch_dtype=DTYPE).to("cuda")
39
+ print(f"{mode} pipeline loaded!")
40
+ return loaded
41
+
42
+ try:
43
+ pipes[mode] = _load()
44
+ except (torch.cuda.OutOfMemoryError, RuntimeError) as exc:
45
+ # Not enough VRAM to hold both checkpoints at once: drop the others and retry.
46
+ print(f"Load failed ({exc}); freeing other cached pipelines and retrying.")
47
+ for other in list(pipes):
48
+ del pipes[other]
49
+ torch.cuda.empty_cache()
50
+ pipes[mode] = _load()
51
+ return pipes[mode]
52
 
53
 
54
  if torch.cuda.is_available():
55
+ get_pipe(DEFAULT_MODE)
56
  else:
57
  print("CUDA is not available at startup. The UI will load, but generation requires ZeroGPU hardware.")
58
 
59
 
60
  def gpu_duration(
61
+ mode,
62
  prompt,
63
  negative_prompt,
64
  height,
 
75
 
76
  @spaces.GPU(duration=gpu_duration, size="xlarge")
77
  def generate_image(
78
+ mode,
79
  prompt,
80
  negative_prompt,
81
  height,
 
86
  randomize_seed,
87
  progress=gr.Progress(track_tqdm=True),
88
  ):
89
+ """Generate 4 Krea 2 images (Turbo or Raw) with seeds: seed, 2x, 3x, 4x."""
90
  if not prompt or not str(prompt).strip():
91
  raise gr.Error("Enter a prompt to generate images.")
92
  if not torch.cuda.is_available():
93
  raise gr.Error("CUDA is not available. Set this Space hardware to ZeroGPU before generating.")
94
 
95
+ pipe = get_pipe(mode)
96
  if randomize_seed:
97
  seed = torch.randint(0, MAX_SEED, (1,)).item()
98
 
 
127
  return images, ", ".join(str(s) for s in seeds)
128
 
129
 
130
+ def apply_mode(mode):
131
+ """Reset Steps and CFG to the selected checkpoint's recommended defaults."""
132
+ cfg = MODELS.get(mode, MODELS[DEFAULT_MODE])
133
+ return gr.update(value=cfg["steps"]), gr.update(value=cfg["cfg"])
134
+
135
+
136
  examples = [
137
  [
138
  "A russet harvest mouse clinging to a branch, macro photograph, shallow depth of field, "
 
152
  ]
153
 
154
 
155
+ with gr.Blocks(title="Krea 2 Demo") as demo:
156
  gr.Markdown(
157
  """
158
+ # 🎨 Krea 2 Demo
159
 
160
+ Generate images with Krea 2. Use the **Model** switch to choose
161
+ [Turbo](https://huggingface.co/krea/Krea-2-Turbo) (fast distilled 8 steps, CFG 0.0) or
162
+ [Raw](https://huggingface.co/krea/Krea-2-Raw) (full quality — ~28-52 steps, CFG 4.5).
163
+ Switching the model resets Steps and CFG to that checkpoint's recommended defaults.
164
  """
165
  )
166
 
167
  with gr.Row():
168
  with gr.Column(scale=1):
169
+ mode = gr.Radio(
170
+ choices=["Turbo", "Raw"],
171
+ value=DEFAULT_MODE,
172
+ label="Model",
173
+ info="Turbo = fast preview. Raw = full quality, slower.",
174
+ )
175
+
176
  prompt = gr.Textbox(
177
  label="Prompt",
178
  placeholder='Describe the image in natural language. Wrap rendered text in quotes, e.g. a sign that reads "open late".',
 
204
  with gr.Row():
205
  num_inference_steps = gr.Slider(
206
  minimum=1,
207
+ maximum=60,
208
+ value=MODELS[DEFAULT_MODE]["steps"],
209
  step=1,
210
  label="Inference Steps",
211
+ info="Turbo ~8 steps; Raw ~28-52 steps. Resets with the Model switch.",
212
  )
213
 
214
  guidance_scale = gr.Slider(
215
  minimum=0.0,
216
+ maximum=10.0,
217
+ value=MODELS[DEFAULT_MODE]["cfg"],
218
  step=0.1,
219
  label="CFG Guidance Scale",
220
+ info="Turbo uses 0.0; Raw uses ~4.5. Negative prompt is ignored when CFG is 0.",
221
  )
222
 
223
  with gr.Row():
 
253
  )
254
 
255
  gr.Markdown(
256
+ "Models by [Krea](https://huggingface.co/krea). "
257
+ "This Space follows the Krea 2 Community License and uses the Turbo and Raw checkpoints for demo inference."
258
  )
259
 
260
+ inputs = [mode, prompt, negative_prompt, height, width, num_inference_steps, guidance_scale, seed, randomize_seed]
261
  outputs = [output_images, used_seeds]
262
 
263
+ mode.change(
264
+ fn=apply_mode,
265
+ inputs=[mode],
266
+ outputs=[num_inference_steps, guidance_scale],
267
+ )
268
+
269
  generate_btn.click(
270
  fn=generate_image,
271
  inputs=inputs,