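# NOTE: the next six lines are Hugging Face Hub file-viewer residue captured with
# the source; they are kept as comments so the module remains importable.
# BiliSakura's picture
# Upload folder using huggingface_hub
# c5fe00d verified
# raw
# history blame
# 17.6 kB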
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.utils.torch_utils import randn_tensor
# Side length in pixels used when the caller passes no `height`/`width`; it is also the
# denominator of the decoder patch-scaling factors set before sampling.
DEFAULT_NATIVE_RESOLUTION = 512
# Usage example for the pipeline. It is a plain module-level string; the visible code
# does not attach it to any docstring.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from pathlib import Path
>>> from diffusers import DiffusionPipeline
>>> import torch
>>> model_dir = Path("./PixNerd-XL-16-512").resolve()
>>> pipe = DiffusionPipeline.from_pretrained(
... str(model_dir),
... local_files_only=True,
... custom_pipeline=str(model_dir / "pipeline.py"),
... trust_remote_code=True,
... torch_dtype=torch.bfloat16,
... )
>>> pipe.to("cuda")
>>> print(pipe.id2label[207])
>>> print(pipe.get_label_ids("golden retriever"))
>>> generator = torch.Generator(device="cuda").manual_seed(42)
>>> # timeshift=3.0 and order=2 are defaults in scheduler/scheduler_config.json
>>> image = pipe(
... class_labels="golden retriever",
... height=512,
... width=512,
... num_inference_steps=25,
... guidance_scale=4.0,
... generator=generator,
... ).images[0]
>>> image.save("demo.png")
```
"""
# Accepted forms of class conditioning: a single class id, a single label string,
# a list of ids and/or label strings, or a tensor of class ids.
ConditioningInput = Union[int, str, List[Union[int, str]], torch.LongTensor]
class PixNerdPipeline(DiffusionPipeline):
r"""
Pipeline for class-conditional PixNerd pixel-space image generation.

The constructor requires a `vae` and a `conditioner`; when either is not passed
explicitly it is looked up as an attribute of `transformer`, and a `ValueError` is
raised if it still cannot be found.

Parameters:
    transformer ([`PixNerdTransformer2DModel`]):
        Class-conditional PixNerd denoiser operating in pixel space.
    scheduler ([`PixNerdFlowMatchScheduler`]):
        Flow-matching scheduler with AdamLM multi-step coefficients.
    vae ([`PixNerdPixelVAE`], *optional*):
        Identity pixel autoencoder. May also be attached to `transformer.vae`.
    conditioner ([`PixNerdLabelConditioner`], *optional*):
        ImageNet class-label conditioner. May also be attached to `transformer.conditioner`.
    id2label (`dict[int, str]`, *optional*):
        ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
        When omitted, the mapping is read from the `id2label` entry of `model_index.json`
        if that file can be located.
"""
# Component order string read by the `DiffusionPipeline` base class
# (presumably the model CPU-offload sequence — confirm against the diffusers docs).
model_cpu_offload_seq = "conditioner->transformer->vae"
# Tensor names declared for step callbacks; the visible `__call__` defines no callback hook.
_callback_tensor_inputs = ["latents"]
# Components that may be omitted from the constructor call (they can come from `transformer`).
_optional_components = ["vae", "conditioner"]
def __init__(
    self,
    transformer,
    scheduler,
    vae=None,
    conditioner=None,
    id2label: Optional[Dict[Union[int, str], str]] = None,
):
    r"""
    Register the pipeline components and build the label lookup tables.

    Args:
        transformer: Denoiser model. Also used as a fallback source for `vae` and `conditioner`.
        scheduler: Scheduler driving the sampling loop.
        vae (*optional*): Pixel autoencoder; defaults to `transformer.vae` when omitted.
        conditioner (*optional*): Label conditioner; defaults to `transformer.conditioner` when omitted.
        id2label (`dict`, *optional*): Class id to label mapping. When omitted, it is read from
            `model_index.json` under the pipeline config's `_name_or_path`, if available.

    Raises:
        ValueError: If `vae` or `conditioner` cannot be resolved.
    """
    super().__init__()

    # Missing components may be carried by the transformer itself.
    vae = getattr(transformer, "vae", None) if vae is None else vae
    conditioner = getattr(transformer, "conditioner", None) if conditioner is None else conditioner
    if any(component is None for component in (vae, conditioner)):
        raise ValueError("Pipeline requires `vae` and `conditioner` either explicitly or from `transformer`.")

    self.register_modules(vae=vae, conditioner=conditioner, transformer=transformer, scheduler=scheduler)
    self.image_processor = VaeImageProcessor(vae_scale_factor=1, do_normalize=False)

    if id2label is None:
        pipeline_config = getattr(self, "config", None)
        id2label = self._read_id2label_from_model_index(getattr(pipeline_config, "_name_or_path", None))

    normalized = self._normalize_id2label(id2label)
    self._id2label = normalized
    self.labels = self._build_label2id(normalized)
    # Remember whether labels are available so lazy loading can retry later.
    self._labels_loaded_from_model_index = bool(normalized)
def _get_device(self) -> torch.device:
    r"""
    Return the device the pipeline should run on.

    Uses `self._execution_device` when that attribute is available. Otherwise returns the
    device of the first parameter found on `transformer`, `vae` or `scheduler` (checked in
    that order, `torch.nn.Module` instances only), falling back to CPU.
    """
    try:
        execution_device = self._execution_device
    except AttributeError:
        pass
    else:
        return execution_device

    for attribute in ("transformer", "vae", "scheduler"):
        component = getattr(self, attribute, None)
        if not isinstance(component, torch.nn.Module):
            continue
        first_parameter = next(component.parameters(), None)
        if first_parameter is not None:
            return first_parameter.device
    return torch.device("cpu")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path=None, *args, **kwargs):
    r"""
    Load the pipeline and attach the class-label tables.

    An `id2label` keyword argument, if given, is removed from `kwargs` before delegating to
    the parent loader and takes precedence over the `id2label` entry of `model_index.json`
    found under `pretrained_model_name_or_path`. When neither yields a non-empty mapping,
    the pipeline is returned as the parent loader built it.
    """
    explicit_mapping = kwargs.pop("id2label", None)
    pipe = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

    mapping = explicit_mapping
    if not mapping:
        mapping = cls._read_id2label_from_model_index(pretrained_model_name_or_path)
    if not mapping:
        return pipe

    pipe._id2label = cls._normalize_id2label(mapping)
    pipe.labels = cls._build_label2id(pipe._id2label)
    pipe._labels_loaded_from_model_index = True
    return pipe
def _ensure_labels_loaded(self) -> None:
    r"""
    Lazily load the label tables if they have not been loaded yet.

    Reads `id2label` from `model_index.json` under `self.config._name_or_path` and, when a
    non-empty mapping is found, stores it together with the derived label-to-id table and
    marks the labels as loaded. Does nothing when labels are already loaded.
    """
    if not self._labels_loaded_from_model_index:
        model_location = getattr(self.config, "_name_or_path", None)
        mapping = self._read_id2label_from_model_index(model_location)
        if mapping:
            self._id2label = mapping
            self.labels = self._build_label2id(self._id2label)
            self._labels_loaded_from_model_index = True
@staticmethod
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
    r"""
    Return a copy of `id2label` whose keys are converted with `int()`.

    `None` or an empty mapping yields an empty dict.
    """
    normalized: Dict[int, str] = {}
    if id2label:
        for raw_key, label in id2label.items():
            normalized[int(raw_key)] = label
    return normalized
@staticmethod
def _read_id2label_from_model_index(variant_path: Optional[Union[str, Path]]) -> Dict[int, str]:
    r"""
    Read the `id2label` mapping from `<variant_path>/model_index.json`.

    Args:
        variant_path (`str` or `Path`, *optional*):
            Directory expected to contain `model_index.json`.

    Returns:
        `dict[int, str]`: The mapping with integer keys, or an empty dict when `variant_path`
        is empty, the file is missing, the JSON document is not an object, or it has no
        dict-valued `id2label` entry.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
        ValueError: If an `id2label` key cannot be converted to `int`.
    """
    if not variant_path:
        return {}
    model_index_path = Path(variant_path).resolve() / "model_index.json"
    # `is_file` (rather than `exists`) also rejects a directory that happens to carry this name.
    if not model_index_path.is_file():
        return {}
    raw = json.loads(model_index_path.read_text(encoding="utf-8"))
    # A JSON document whose top level is a list/number/string has no `.get`; treat it as "no labels".
    id2label = raw.get("id2label") if isinstance(raw, dict) else None
    if not isinstance(id2label, dict):
        return {}
    return {int(key): value for key, value in id2label.items()}
@staticmethod
def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
    r"""
    Invert `id2label` into a synonym-to-class-id table.

    Each value is split on commas; every non-empty, whitespace-stripped piece becomes a key.
    When the same synonym appears under several ids, the id visited last wins. The returned
    dict is ordered by synonym.
    """
    collected: Dict[str, int] = {}
    for class_id, names in id2label.items():
        stripped_names = (piece.strip() for piece in names.split(","))
        collected.update((name, int(class_id)) for name in stripped_names if name)
    return {name: collected[name] for name in sorted(collected)}
@property
def id2label(self) -> Dict[int, str]:
    r"""
    Class id to English label string mapping (labels may hold comma-separated synonyms).

    Accessing the property first triggers the lazy label load.
    """
    self._ensure_labels_loaded()
    mapping = self._id2label
    return mapping
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
    r"""
    Map English label strings to class ids.

    Args:
        label (`str` or `list[str]`):
            One or more label strings. Each must exactly match a key of the synonym table
            built from `id2label`.

    Returns:
        `list[int]`: One class id per requested label, in request order.

    Raises:
        ValueError: If no labels are loaded, or if any requested label is unknown.
    """
    self._ensure_labels_loaded()
    requested = [label] if isinstance(label, str) else label
    if not self.labels:
        raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")

    lookup = self.labels
    unknown = [name for name in requested if name not in lookup]
    if unknown:
        # Show a few valid labels so the caller can see the expected spelling.
        preview = ", ".join(list(lookup)[:8])
        raise ValueError(f"Unknown English label(s): {unknown}. Example valid labels: {preview}, ...")
    return [lookup[name] for name in requested]
def _normalize_class_labels(
    self,
    class_labels: ConditioningInput,
    num_images_per_prompt: int = 1,
) -> List[int]:
    r"""
    Convert any supported conditioning input into a flat list of class ids.

    Args:
        class_labels (`int`, `str`, sequence of `int`/`str`, or `torch.Tensor`):
            Class ids and/or English label strings. Ids and strings may be mixed within one
            sequence; strings are resolved through `get_label_ids`.
        num_images_per_prompt (`int`, defaults to `1`):
            Number of consecutive copies of each id in the result.

    Returns:
        `list[int]`: Class ids in input order, each repeated `num_images_per_prompt` times.

    Raises:
        ValueError: If `num_images_per_prompt < 1`, or if a label string is unknown.
    """
    if num_images_per_prompt < 1:
        raise ValueError("num_images_per_prompt must be >= 1.")

    if torch.is_tensor(class_labels):
        values = class_labels.to(dtype=torch.long).reshape(-1).tolist()
    elif isinstance(class_labels, int):
        values = [class_labels]
    elif isinstance(class_labels, str):
        values = self.get_label_ids(class_labels)
    else:
        entries = list(class_labels)
        # Resolve all label strings in one call so every unknown name is reported together,
        # then merge the ids back in their original positions next to the numeric entries.
        names = [entry for entry in entries if isinstance(entry, str)]
        resolved = iter(self.get_label_ids(names)) if names else iter(())
        values = [next(resolved) if isinstance(entry, str) else int(entry) for entry in entries]

    if num_images_per_prompt == 1:
        return values
    return [value for value in values for _ in range(num_images_per_prompt)]
def _get_patch_size(self) -> int:
    r"""
    Return the transformer patch size as an `int`.

    Prefers `transformer.patch_size`; when that is missing or `None`, reads
    `transformer.config.patch_size`, defaulting to `16`.
    """
    model = self.transformer
    value = getattr(model, "patch_size", None)
    if value is None:
        # The config is only consulted when the model exposes no usable attribute.
        value = getattr(model.config, "patch_size", 16)
    return int(value)
def _get_in_channels(self) -> int:
    r"""
    Return the number of input channels of the transformer as an `int`.

    Prefers `transformer.in_channels`; when that is missing or `None`, reads
    `transformer.config.in_channels`, defaulting to `3`.
    """
    model = self.transformer
    value = getattr(model, "in_channels", None)
    if value is None:
        # The config is only consulted when the model exposes no usable attribute.
        value = getattr(model.config, "in_channels", 3)
    return int(value)
def check_inputs(
    self,
    height: int,
    width: int,
    num_inference_steps: int,
    output_type: str,
) -> None:
    r"""
    Validate the sampling arguments before any tensor is allocated.

    Args:
        height (`int`): Requested image height in pixels.
        width (`int`): Requested image width in pixels.
        num_inference_steps (`int`): Number of denoising steps.
        output_type (`str`): One of `"pil"`, `"np"`, `"pt"`, `"latent"`.

    Raises:
        ValueError: If the step count is below 1, the output type is unsupported, the scheduler
            order is below 1, a spatial size is not positive, or a spatial size is not a
            multiple of the transformer patch size.
    """
    if num_inference_steps < 1:
        raise ValueError("num_inference_steps must be >= 1.")
    if output_type not in {"pil", "np", "pt", "latent"}:
        raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
    order = int(getattr(self.scheduler.config, "order", getattr(self.scheduler, "order", 2)))
    if order < 1:
        raise ValueError("scheduler.config.order must be >= 1.")
    # Zero and negative multiples of the patch size satisfy the modulo test below,
    # so positivity has to be checked separately.
    if height <= 0 or width <= 0:
        raise ValueError("height and width must be positive.")
    patch_size = self._get_patch_size()
    if height % patch_size != 0 or width % patch_size != 0:
        raise ValueError(f"height and width must be divisible by patch_size={patch_size}.")
def encode_condition(
    self,
    class_label_ids: List[int],
    negative_class_label_ids: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Build the conditional and unconditional inputs for classifier-free guidance.

    Args:
        class_label_ids (`list[int]`):
            Class ids to condition on.
        negative_class_label_ids (`list[int]`, *optional*):
            Class ids to steer away from. When omitted, the conditioner's own unconditional
            output for `class_label_ids` is used.

    Returns:
        `tuple`: `(cond, uncond)` as produced by `self.conditioner`.
    """
    metadata = {"device": self._get_device()}
    with torch.no_grad():
        cond, default_uncond = self.conditioner(class_label_ids, metadata)
        if negative_class_label_ids is None:
            return cond, default_uncond
        # The conditioner returns `(conditional, unconditional)` for the ids it is given (see the
        # call above). The negative branch therefore needs the FIRST element: taking the second
        # would return the default unconditional output again and discard the negative ids.
        # NOTE(review): assumes the second output does not depend on the ids — confirm against
        # the conditioner implementation.
        uncond, _ = self.conditioner(negative_class_label_ids, metadata)
    return cond, uncond
def prepare_latents(
    self,
    batch_size: int,
    num_channels: int,
    height: int,
    width: int,
    dtype: torch.dtype,
    device: torch.device,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    r"""
    Return the starting tensor for sampling.

    When `latents` is given it is only moved to `device`/`dtype` (its shape is not checked).
    Otherwise a tensor of shape `(batch_size, num_channels, height, width)` is drawn with
    `randn_tensor` using `generator`.
    """
    if latents is None:
        shape = (batch_size, num_channels, height, width)
        return randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    return latents.to(device=device, dtype=dtype)
@staticmethod
def _fp_to_uint8(image: torch.Tensor) -> torch.Tensor:
    r"""
    Map a floating-point image from `[-1, 1]` to `uint8` values in `[0, 255]`.

    The affine map `(x + 1) * 127.5 + 0.5` is evaluated in `float32` so that half-precision
    inputs (e.g. `bfloat16`) are not rounded a second time before the final cast. Values are
    clamped to `[0, 255]` and then truncated by the integer cast.
    """
    scaled = (image.to(torch.float32) + 1.0) * 127.5 + 0.5
    return scaled.clamp(0, 255).to(torch.uint8)
def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
    r"""
    Convert the final sample into the requested output format.

    Args:
        latents (`torch.Tensor`): Final sample from the denoising loop.
        output_type (`str`, defaults to `"pil"`):
            `"latent"` returns `latents` untouched; `"pt"` returns `self.vae.decode(latents)`;
            `"np"` returns a channels-last `uint8` NumPy array; `"pil"` returns a list of
            PIL images built from that array.

    Raises:
        ValueError: If `output_type` is none of the above (raised after decoding).
    """
    if output_type == "latent":
        return latents

    decoded = self.vae.decode(latents)
    if output_type == "pt":
        return decoded

    quantized = self._fp_to_uint8(decoded)
    # (batch, channels, height, width) -> (batch, height, width, channels)
    frames = quantized.permute(0, 2, 3, 1).cpu().numpy()
    if output_type == "np":
        return frames
    if output_type != "pil":
        raise ValueError(f"Unsupported output_type: {output_type}")

    from PIL import Image

    return [Image.fromarray(frame) for frame in frames]
def _apply_decoder_patch_scaling(self, height: int, width: int) -> None:
    r"""
    Set the decoder patch-scaling factors for the requested resolution.

    The target is `transformer.denoiser` when present, otherwise the transformer itself. If
    the target has a `decoder_patch_scaling_h` attribute, both scaling attributes are set to
    the requested size divided by `DEFAULT_NATIVE_RESOLUTION`; otherwise nothing happens.
    """
    target = getattr(self.transformer, "denoiser", self.transformer)
    if not hasattr(target, "decoder_patch_scaling_h"):
        return
    target.decoder_patch_scaling_h = height / DEFAULT_NATIVE_RESOLUTION
    target.decoder_patch_scaling_w = width / DEFAULT_NATIVE_RESOLUTION
@torch.inference_mode()
def __call__(
    self,
    class_labels: Optional[ConditioningInput] = None,
    negative_class_labels: Optional[ConditioningInput] = None,
    num_images_per_prompt: int = 1,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 25,
    guidance_scale: float = 4.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    output_type: str = "pil",
    return_dict: bool = True,
    prompt: Optional[ConditioningInput] = None,
    negative_prompt: Optional[ConditioningInput] = None,
) -> Union[ImagePipelineOutput, Tuple]:
    r"""
    Generate class-conditional images with PixNerd.

    Args:
        class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`):
            ImageNet class indices or human-readable English label strings.
        negative_class_labels (`int`, `str`, `list[int]`, `list[str]`, or `torch.LongTensor`, *optional*):
            Optional negative class labels for classifier-free guidance. Must resolve to one
            label per generated image, or to a single distinct label, which is then applied
            to the whole batch.
        num_images_per_prompt (`int`, defaults to `1`):
            Number of images to generate per class label.
        height (`int`, *optional*):
            Output image height in pixels. Defaults to `512`.
        width (`int`, *optional*):
            Output image width in pixels. Defaults to `512`.
        num_inference_steps (`int`, defaults to `25`):
            Number of denoising steps.
        guidance_scale (`float`, defaults to `4.0`):
            Classifier-free guidance scale applied by the scheduler.
        generator (`torch.Generator`, *optional*):
            RNG for reproducibility.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy pixel tensor.
        output_type (`str`, defaults to `"pil"`):
            `"pil"`, `"np"`, `"pt"`, or `"latent"`.
        return_dict (`bool`, defaults to `True`):
            Return [`ImagePipelineOutput`] if True.
        prompt (`int`, `str`, `list`, *optional*):
            Deprecated alias for `class_labels`.
        negative_prompt (`int`, `str`, `list`, *optional*):
            Deprecated alias for `negative_class_labels`.

    Returns:
        [`ImagePipelineOutput`] or `tuple`: The generated images, wrapped in
        `ImagePipelineOutput` when `return_dict` is True, otherwise as a one-element tuple.

    Raises:
        ValueError: If no class labels are given, if `check_inputs` rejects the arguments, or
            if the negative labels cannot be matched to the batch.
    """
    # Resolve the deprecated aliases first so the rest of the method sees one name per input.
    if class_labels is None:
        class_labels = prompt
    if negative_class_labels is None:
        negative_class_labels = negative_prompt
    if class_labels is None:
        raise ValueError("`class_labels` (or deprecated `prompt`) must be provided.")

    height = int(height or DEFAULT_NATIVE_RESOLUTION)
    width = int(width or DEFAULT_NATIVE_RESOLUTION)
    self.check_inputs(height, width, num_inference_steps, output_type)
    patch_size = self._get_patch_size()
    height = (height // patch_size) * patch_size
    width = (width // patch_size) * patch_size
    self._apply_decoder_patch_scaling(height, width)

    class_label_ids = self._normalize_class_labels(class_labels, num_images_per_prompt)
    batch_size = len(class_label_ids)

    negative_label_ids = None
    if negative_class_labels is not None:
        negative_label_ids = self._normalize_class_labels(negative_class_labels, num_images_per_prompt)
        if len(negative_label_ids) != batch_size:
            # The guidance halves are concatenated along the batch axis and must line up one to
            # one. A single distinct negative label is unambiguous and is applied to every image.
            if len(set(negative_label_ids)) != 1:
                raise ValueError(
                    "`negative_class_labels` must provide one label per generated image "
                    f"({batch_size}) or a single label, but resolved to {len(negative_label_ids)} labels."
                )
            negative_label_ids = negative_label_ids[:1] * batch_size

    device = self._get_device()
    model_dtype = next(self.transformer.parameters()).dtype
    cond, uncond = self.encode_condition(class_label_ids, negative_label_ids)
    latents = self.prepare_latents(
        batch_size=batch_size,
        num_channels=self._get_in_channels(),
        height=height,
        width=width,
        dtype=model_dtype,
        device=device,
        generator=generator,
        latents=latents,
    )
    self.scheduler.set_timesteps(
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        device=device,
    )

    for timestep in self.progress_bar(self.scheduler.timesteps):
        # Run the unconditional and conditional halves in one forward pass.
        cfg_latents = torch.cat([latents, latents], dim=0)
        cfg_t = timestep.repeat(cfg_latents.shape[0]).to(device=device, dtype=latents.dtype)
        cfg_condition = torch.cat([uncond, cond], dim=0)
        model_output = self.transformer(
            sample=cfg_latents.to(dtype=model_dtype),
            timestep=cfg_t,
            encoder_hidden_states=cfg_condition,
        ).sample
        model_output = self.scheduler.classifier_free_guidance(model_output)
        latents = self.scheduler.step(
            model_output=model_output,
            timestep=timestep,
            sample=latents,
        ).prev_sample

    image = self.decode_latents(latents, output_type=output_type)
    self.maybe_free_model_hooks()
    if not return_dict:
        return (image,)
    return ImagePipelineOutput(images=image)
# Module-level alias: `PixNerdPipelineOutput` is the same class as `ImagePipelineOutput`,
# which is what `PixNerdPipeline.__call__` returns when `return_dict` is True.
PixNerdPipelineOutput = ImagePipelineOutput