witcherderivia committed on
Commit
1b95aae
·
verified ·
1 Parent(s): 10e235d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -33
app.py CHANGED
@@ -5,16 +5,29 @@ import torch
5
  import spaces
6
 
7
  from PIL import Image
8
- from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
 
 
 
 
9
 
10
  import os
11
 
12
  from huggingface_hub import hf_hub_download
13
 
 
 
 
 
14
 
 
 
15
 
 
 
16
 
17
 
 
18
  pipe = QwenImagePipeline.from_pretrained(
19
  torch_dtype=torch.bfloat16,
20
  device="cuda",
@@ -31,16 +44,32 @@ pipe = QwenImagePipeline.from_pretrained(
31
  processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit-2509",
32
  download_source='huggingface',origin_file_pattern="processor/"),
33
  )
 
34
 
35
 
36
 
37
 
38
  speedup = hf_hub_download(repo_id="witcherderivia/Qwen-Image-Style-Transfer", filename="diffsynth_Qwen-Image-Edit-2509-Lightning-4steps-V1.0-bf16.safetensors")
39
- qwenstyle= hf_hub_download(repo_id="witcherderivia/Qwen-Image-Style-Transfer", filename="diffsynth_Qwen-Image-Edit-2509-Style-Transfer-V1.safetensors")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
- pipe.load_lora(pipe.dit, qwenstyle)
43
- pipe.load_lora(pipe.dit,speedup)
44
 
45
 
46
 
@@ -53,7 +82,9 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
53
  MAX_SEED = np.iinfo(np.int32).max
54
 
55
 
56
- @spaces.GPU
 
 
57
  def infer(
58
  content_ref,
59
  style_ref,
@@ -64,6 +95,7 @@ def infer(
64
  num_inference_steps=4,
65
  minedge=1024,
66
  progress=gr.Progress(track_tqdm=True),
 
67
 
68
  ):
69
 
@@ -71,51 +103,186 @@ def infer(
71
 
72
 
73
 
74
- content_ref=Image.fromarray(content_ref)
75
- style_ref=Image.fromarray(style_ref)
76
 
77
- if randomize_seed:
78
- seed = random.randint(0, MAX_SEED)
 
 
 
79
 
80
 
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
 
84
 
85
 
86
- w,h=content_ref.size
 
 
87
 
88
 
89
 
90
- #minedge=1024
91
- if w>h:
92
- r=w/h
93
- h=minedge
94
- w=int(h*r)-int(h*r)%16
95
-
96
- else:
97
- r=h/w
98
- w=minedge
99
- h=int(w*r)-int(w*r)%16
 
 
 
 
 
 
 
 
 
 
 
100
 
101
 
102
 
103
- print(f"Calling pipeline with prompt: '{prompt}'")
104
 
105
- print(f"Seed: {seed}, Steps: {num_inference_steps}, Guidance: {true_guidance_scale}, Size: {w}x{h}")
106
 
107
- images = [
108
- content_ref.resize((w, h)),
109
- style_ref.resize((minedge, minedge)) ,
110
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
 
113
- # Generate the image
114
 
115
- image = pipe(prompt, edit_image=images, seed=seed, num_inference_steps=num_inference_steps, height=h, width=w,edit_image_auto_resize=False,cfg_scale=true_guidance_scale)#ligtning
116
 
117
 
118
- return image, seed
 
 
119
 
120
  # --- Examples and UI Layout ---
121
  examples = []
@@ -130,7 +297,7 @@ _HEADER_ = '''
130
 
131
 
132
  <p style="font-size: 1rem; margin-bottom: 1.5rem;">Paper: <a href='https://openreview.net/forum?id=Cgb7JpOA5Q&referrer=%5Bthe%20profile%20of%20Shiwen%20Zhang%5D(%2Fprofile%3Fid%3D~Shiwen_Zhang1)' target='_blank'>QwenStyle: Content-Preserving Style Transfer with Qwen-Image-Edit</a> | Codes: <a href='https://github.com/witcherofresearch/Qwen-Image-Style-Transfer' target='_blank'>GitHub</a></p>
133
- <p style="font-size: 1rem; margin-bottom: 1.5rem;">If you encounter an Error with this demo, the most possible reason is ZeroGPU out-of-memory and the solution is to decrease the Min Edge of the generated image from 1024 to a lower value. This is because ZeroGPU has a memory limit of 70GB, while all the examples are tested with 80GB H100 GPUs. </p>
134
  '''
135
 
136
  with gr.Blocks() as demo:
@@ -227,8 +394,9 @@ with gr.Blocks() as demo:
227
  randomize_seed,
228
  true_guidance_scale,
229
  num_inference_steps,
230
- minedge,],
231
- outputs=[result, seed],
 
232
  fn=infer,
233
  cache_examples=False
234
  )
@@ -251,10 +419,14 @@ with gr.Blocks() as demo:
251
  true_guidance_scale,
252
  num_inference_steps,
253
  minedge,
 
254
 
255
  ],
256
- outputs=[result, seed],
257
  )
258
 
 
 
 
259
  if __name__ == "__main__":
260
  demo.launch(server_name='0.0.0.0')
 
5
  import spaces
6
 
7
  from PIL import Image
8
+ #from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
9
+ from pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
10
+ from qwen_vl_utils import process_vision_info
11
+
12
+
13
 
14
  import os
15
 
16
  from huggingface_hub import hf_hub_download
17
 
18
def update_textbox(selected_items):
    """Return the selected items joined into one comma-separated string.

    Args:
        selected_items: Iterable of selected choices. ``None`` or an empty
            iterable yields an empty string. Items that are not strings are
            converted with ``str()`` so joining never raises ``TypeError``.

    Returns:
        The items joined with ``", "``.
    """
    if not selected_items:
        return ""
    return ", ".join(str(item) for item in selected_items)
21
+
22
 
23
# Build the Qwen-Image-Edit-2509 editing pipeline with bfloat16 weights.
pipe = QwenImageEditPlusPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit-2509",
    torch_dtype=torch.bfloat16,
)
print("pipeline loaded")

# Move the whole pipeline to the GPU, then forward disable=None to the
# pipeline's progress-bar configuration.
pipe.to('cuda')
pipe.set_progress_bar_config(disable=None)
28
 
29
 
30
+ '''
31
  pipe = QwenImagePipeline.from_pretrained(
32
  torch_dtype=torch.bfloat16,
33
  device="cuda",
 
44
  processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit-2509",
45
  download_source='huggingface',origin_file_pattern="processor/"),
46
  )
47
+ '''
48
 
49
 
50
 
51
 
52
# Download the two LoRA checkpoints from the Hub (same repository).
# NOTE(review): the Lightning checkpoint still carries the "diffsynth_" prefix
# while the style checkpoint was switched to the "diffusers_" one -- confirm
# that the diffusers LoRA loader accepts this file's key layout.
speedup = hf_hub_download(
    repo_id="witcherderivia/Qwen-Image-Style-Transfer",
    filename="diffsynth_Qwen-Image-Edit-2509-Lightning-4steps-V1.0-bf16.safetensors",
)
qwenstyle = hf_hub_download(
    repo_id="witcherderivia/Qwen-Image-Style-Transfer",
    filename="diffusers_Qwen-Image-Edit-2509-Style-Transfer-V1.safetensors",
)

# Register each checkpoint under its own adapter name: style first, then the
# few-step ("dmd") adapter.
pipe.load_lora_weights(qwenstyle, adapter_name='style')
pipe.load_lora_weights(speedup, adapter_name='dmd')

# Activate both adapters at full weight, fuse them, and then unload the
# LoRA weights from the pipeline.
pipe.set_adapters(["style", "dmd"], adapter_weights=[1.0, 1.0])
pipe.fuse_lora(adapter_names=["style", "dmd"], lora_scale=1.0)
pipe.unload_lora_weights()
70
+
71
 
72
 
 
 
73
 
74
 
75
 
 
82
  MAX_SEED = np.iinfo(np.int32).max
83
 
84
 
85
+ @spaces.GPU(size="xlarge")
86
+
87
+
88
  def infer(
89
  content_ref,
90
  style_ref,
 
95
  num_inference_steps=4,
96
  minedge=1024,
97
  progress=gr.Progress(track_tqdm=True),
98
+ checkbox=[],
99
 
100
  ):
101
 
 
103
 
104
 
105
 
 
 
106
 
107
+ content_text_input='describe main objects (fewer than 3) with separated words, each word is separated by comma, the total number of words is strictly fewer than 3'
108
+ style_text_input='describe only the artistic style, material and stroke in 5 words, not objects.'
109
+ #pipe.text_encoder.eval()
110
+ content_prompt=''
111
+ style_prompt=''
112
 
113
 
114
 
115
+
116
+
117
+
118
+ if content_ref is not None:
119
+ content_ref=Image.fromarray(content_ref)
120
+ content_messages = [
121
+ {
122
+ "role": "user",
123
+ "content": [
124
+ {
125
+ "type": "image",
126
+ "image": content_ref,
127
+ },
128
+ {"type": "text", "text": content_text_input},
129
+ ],
130
+ }
131
+ ]
132
+ content_text = pipe.processor.apply_chat_template(
133
+ content_messages, tokenize=False, add_generation_prompt=True
134
+ )
135
+ image_inputs, video_inputs = process_vision_info(content_messages)
136
+ inputs = pipe.processor(
137
+ text=[content_text],
138
+ images=image_inputs,
139
+ videos=video_inputs,
140
+ padding=True,
141
+ return_tensors="pt",
142
+ )
143
+ inputs = inputs.to(device)
144
+
145
+ # Inference: Generation of the output
146
+ generated_ids = pipe.text_encoder.generate(**inputs, max_new_tokens=1024)
147
+ generated_ids_trimmed = [
148
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
149
+ ]
150
+ content_prompt = pipe.processor.batch_decode(
151
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
152
+ )[0]
153
+ print(f"content_prompt={content_prompt}")
154
+ if style_ref is not None:
155
+ style_ref=Image.fromarray(style_ref)
156
+ style_messages = [
157
+ {
158
+ "role": "user",
159
+ "content": [
160
+ {
161
+ "type": "image",
162
+ "image": style_ref,
163
+ },
164
+ {"type": "text", "text": style_text_input},
165
+ ],
166
+ }
167
+ ]
168
+ style_text = pipe.processor.apply_chat_template(
169
+ style_messages, tokenize=False, add_generation_prompt=True
170
+ )
171
+ image_inputs, video_inputs = process_vision_info(style_messages)
172
+ inputs = pipe.processor(
173
+ text=[style_text],
174
+ images=image_inputs,
175
+ videos=video_inputs,
176
+ padding=True,
177
+ return_tensors="pt",
178
+ )
179
+ inputs = inputs.to(device)
180
+
181
+ # Inference: Generation of the output
182
+ generated_ids = pipe.text_encoder.generate(**inputs, max_new_tokens=1024)
183
+ generated_ids_trimmed = [
184
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
185
+ ]
186
+ style_prompt = pipe.processor.batch_decode(
187
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
188
+ )[0]
189
+ print(f"style_prompt={style_prompt}")
190
+
191
+ if randomize_seed:
192
+ seed = random.randint(0, MAX_SEED)
193
+
194
 
195
 
196
 
197
 
198
+ sw,sh,w,h=0,0,0,0
199
+ if content_ref:
200
+ w,h=content_ref.size
201
 
202
 
203
 
204
+ #minedge=1024
205
+ if w>h:
206
+ r=w/h
207
+ h=minedge
208
+ w=int(h*r)-int(h*r)%16
209
+
210
+ else:
211
+ r=h/w
212
+ w=minedge
213
+ h=int(w*r)-int(w*r)%16
214
+ if style_ref:
215
+ sw,sh=style_ref.size
216
+ if sw>sh:
217
+ r=sw/sh
218
+ sh=minedge
219
+ sw=int(sh*r)-int(sh*r)%16
220
+
221
+ else:
222
+ r=sh/sw
223
+ sw=minedge
224
+ sh=int(sw*r)-int(sw*r)%16
225
 
226
 
227
 
 
228
 
 
229
 
230
+ print(f"Seed: {seed}, Steps: {num_inference_steps}, Guidance: {true_guidance_scale},")
231
+
232
+ if content_ref and style_ref:
233
+ images = [
234
+ content_ref.resize((w, h)),
235
+ style_ref.resize((sw, sh)) ,
236
+ #style_ref.resize((minedge, minedge)) ,
237
+ ]
238
+ elif content_ref:
239
+ images = [
240
+ content_ref.resize((w, h)),
241
+ #style_ref.resize((sw, sh)) ,
242
+ #style_ref.resize((minedge, minedge)) ,
243
+ ]
244
+ elif style_ref:
245
+ images = [
246
+ #content_ref.resize((w, h)),
247
+ style_ref.resize((sw, sh)) ,
248
+ #style_ref.resize((minedge, minedge)) ,
249
+ ]
250
+
251
+ if "infer with content prompt" in checkbox and content_prompt not in prompt:
252
+ prompt=','.join([prompt,content_prompt])
253
+ if "infer with style prompt" in checkbox and style_prompt not in prompt:
254
+ prompt=','.join([prompt,style_prompt])
255
+ if "infer with content prompt" not in checkbox and content_prompt in prompt:
256
+ prompt=prompt.replace(content_prompt.strip(','),'')
257
+ if "infer with style prompt" not in checkbox and style_prompt in prompt:
258
+ prompt=prompt.replace(style_prompt.strip(),'')
259
+ prompt=prompt.strip(',')
260
+ print(f"Calling pipeline with prompt: '{prompt}'")
261
+ inputs = {
262
+ "image": images,
263
+ "prompt": prompt,
264
+ "generator": torch.manual_seed(seed),
265
+ "true_cfg_scale": true_guidance_scale,
266
+ "negative_prompt": " ",
267
+ "num_inference_steps": num_inference_steps,
268
+ "guidance_scale": true_g,
269
+ "num_images_per_prompt": 1,
270
+ "width": w or sw,
271
+ "height": h or sh,
272
+ }
273
+ with torch.inference_mode():
274
+ image = pipe(**inputs)
275
+ image = image.images[0]
276
+
277
 
278
 
 
279
 
280
+
281
 
282
 
283
+
284
+
285
+ return image, seed, content_prompt, style_prompt, prompt
286
 
287
  # --- Examples and UI Layout ---
288
  examples = []
 
297
 
298
 
299
  <p style="font-size: 1rem; margin-bottom: 1.5rem;">Paper: <a href='https://openreview.net/forum?id=Cgb7JpOA5Q&referrer=%5Bthe%20profile%20of%20Shiwen%20Zhang%5D(%2Fprofile%3Fid%3D~Shiwen_Zhang1)' target='_blank'>QwenStyle: Content-Preserving Style Transfer with Qwen-Image-Edit</a> | Codes: <a href='https://github.com/witcherofresearch/Qwen-Image-Style-Transfer' target='_blank'>GitHub</a></p>
300
+ <p style="font-size: 1rem; margin-bottom: 1.5rem;">If you encounter an Error with this demo, the most possible reason is ZeroGPU out-of-memory and the solution is to decrease the Min Edge of the generated image from 1024 to a lower value. </p>
301
  '''
302
 
303
  with gr.Blocks() as demo:
 
394
  randomize_seed,
395
  true_guidance_scale,
396
  num_inference_steps,
397
+ minedge,
398
+ ],
399
outputs=[result, seed, content_prompt, style_prompt, prompt],  # one component per value returned by infer; stray "]" removed
400
  fn=infer,
401
  cache_examples=False
402
  )
 
419
  true_guidance_scale,
420
  num_inference_steps,
421
  minedge,
422
+ checkbox,
423
 
424
  ],
425
+ outputs=[result, seed, content_prompt, style_prompt,prompt],
426
  )
427
 
428
+
429
+
430
+
431
  if __name__ == "__main__":
432
  demo.launch(server_name='0.0.0.0')