mamungtai-sat pormungtai committed on
Commit
84f9e8e
·
1 Parent(s): b7fc923

AsianRealistic: load in fp32 to fix NaN rainbow-noise output (#22)

Browse files

- AsianRealistic: load in fp32 to fix NaN rainbow-noise output (6e5f94eee0d23e5fc79c3417b1fb394a78d2b965)


Co-authored-by: pormungtailaw <pormungtai@users.noreply.huggingface.co>

Files changed (2) hide show
  1. models.json +1 -1
  2. pipeline_manager.py +9 -7
models.json CHANGED
@@ -9,7 +9,7 @@
9
  "type": "checkpoint",
10
  "repo_id": "stablediffusionapi/asianrealisticsdlifechias",
11
  "single_file_url": null,
12
- "vae": "stabilityai/sd-vae-ft-mse",
13
  "preview": "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/9c67da8b-5e57-413c-a5ca-135f5ed9af18/width=450/1743066.jpeg",
14
  "trigger": "",
15
  "clip_skip": 2,
 
9
  "type": "checkpoint",
10
  "repo_id": "stablediffusionapi/asianrealisticsdlifechias",
11
  "single_file_url": null,
12
+ "dtype": "fp32",
13
  "preview": "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/9c67da8b-5e57-413c-a5ca-135f5ed9af18/width=450/1743066.jpeg",
14
  "trigger": "",
15
  "clip_skip": 2,
pipeline_manager.py CHANGED
@@ -201,27 +201,30 @@ def _build_base_pipeline(cfg):
201
  """Construct the txt2img pipeline for a model config (on CPU)."""
202
  base = cfg["base"]
203
  common = dict(token=HF_TOKEN)
 
 
 
204
 
205
  if base == "sd15":
206
  from diffusers import StableDiffusionPipeline
207
  if cfg.get("single_file_url"):
208
  local = _download_url(cfg["single_file_url"])
209
  pipe = StableDiffusionPipeline.from_single_file(
210
- local, torch_dtype=DTYPE_SD, safety_checker=None
211
  )
212
  else:
213
  pipe = StableDiffusionPipeline.from_pretrained(
214
- cfg["repo_id"], torch_dtype=DTYPE_SD, safety_checker=None, **common
215
  )
216
 
217
  elif base == "sdxl":
218
  from diffusers import StableDiffusionXLPipeline
219
  if cfg.get("single_file_url"):
220
  local = _download_url(cfg["single_file_url"])
221
- pipe = StableDiffusionXLPipeline.from_single_file(local, torch_dtype=DTYPE_SD)
222
  else:
223
  pipe = StableDiffusionXLPipeline.from_pretrained(
224
- cfg["repo_id"], torch_dtype=DTYPE_SD, **common
225
  )
226
 
227
  elif base == "flux":
@@ -262,11 +265,10 @@ def _build_base_pipeline(cfg):
262
  except Exception as e: # noqa
263
  print(f"[lora] fuse skipped: {e}")
264
 
265
- # Optional VAE override — some merged checkpoints ship a broken fp16 VAE that
266
- # decodes to rainbow noise; a known-good VAE fixes it.
267
  if cfg.get("vae"):
268
  from diffusers import AutoencoderKL
269
- pipe.vae = AutoencoderKL.from_pretrained(cfg["vae"], torch_dtype=DTYPE_SD)
270
 
271
  # SD1.5 / SDXL community checkpoints are tuned for the Euler Ancestral sampler;
272
  # it matches the look people get in A1111 / ComfyUI far better than the default.
 
201
  """Construct the txt2img pipeline for a model config (on CPU)."""
202
  base = cfg["base"]
203
  common = dict(token=HF_TOKEN)
204
+ # Some checkpoint merges overflow to NaN in fp16 (rainbow-noise output);
205
+ # such models set "dtype": "fp32" in the registry.
206
+ dt = torch.float32 if cfg.get("dtype") == "fp32" else DTYPE_SD
207
 
208
  if base == "sd15":
209
  from diffusers import StableDiffusionPipeline
210
  if cfg.get("single_file_url"):
211
  local = _download_url(cfg["single_file_url"])
212
  pipe = StableDiffusionPipeline.from_single_file(
213
+ local, torch_dtype=dt, safety_checker=None
214
  )
215
  else:
216
  pipe = StableDiffusionPipeline.from_pretrained(
217
+ cfg["repo_id"], torch_dtype=dt, safety_checker=None, **common
218
  )
219
 
220
  elif base == "sdxl":
221
  from diffusers import StableDiffusionXLPipeline
222
  if cfg.get("single_file_url"):
223
  local = _download_url(cfg["single_file_url"])
224
+ pipe = StableDiffusionXLPipeline.from_single_file(local, torch_dtype=dt)
225
  else:
226
  pipe = StableDiffusionXLPipeline.from_pretrained(
227
+ cfg["repo_id"], torch_dtype=dt, **common
228
  )
229
 
230
  elif base == "flux":
 
265
  except Exception as e: # noqa
266
  print(f"[lora] fuse skipped: {e}")
267
 
268
+ # Optional VAE override (known-good VAE for models with a broken one).
 
269
  if cfg.get("vae"):
270
  from diffusers import AutoencoderKL
271
+ pipe.vae = AutoencoderKL.from_pretrained(cfg["vae"], torch_dtype=dt)
272
 
273
  # SD1.5 / SDXL community checkpoints are tuned for the Euler Ancestral sampler;
274
  # it matches the look people get in A1111 / ComfyUI far better than the default.