BiliSakura's picture
Upload folder using huggingface_hub
561f88a verified
raw
history blame
18.6 kB
"""Hub custom pipeline: FiTPipeline.
Load with native Hugging Face diffusers and trust_remote_code=True.
"""
import importlib
import inspect
import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import diffusers.schedulers as diffusers_schedulers
import torch
from huggingface_hub import snapshot_download
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils.torch_utils import randn_tensor
# Local component classes are loaded dynamically in from_pretrained.
DEFAULT_NATIVE_RESOLUTION = 256
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from pathlib import Path
>>> import torch
>>> from diffusers import DiffusionPipeline, DDIMScheduler
>>> model_dir = Path("./FiTv1-XL-2-256").resolve()
>>> pipe = DiffusionPipeline.from_pretrained(
... str(model_dir),
... local_files_only=True,
... custom_pipeline=str(model_dir / "pipeline.py"),
... trust_remote_code=True,
... torch_dtype=torch.float32,
... )
>>> pipe.to("cuda")
>>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
>>> print(pipe.id2label[207])
>>> print(pipe.get_label_ids("golden retriever"))
>>> generator = torch.Generator(device="cuda").manual_seed(42)
>>> image = pipe(
... class_labels="golden retriever",
... height=256,
... width=256,
... num_inference_steps=250,
... guidance_scale=1.5,
... generator=generator,
... ).images[0]
>>> image.save("demo.png")
```
"""
class FiTPipeline(DiffusionPipeline):
r"""
Pipeline for class-conditional image generation with FiTv1 (DDPM sampling).
"""
model_cpu_offload_seq = "transformer->vae"
_optional_components = ["vae"]
def __init__(
self,
transformer: Any,
scheduler: KarrasDiffusionSchedulers,
vae: Any = None,
id2label: Optional[Dict[Union[int, str], str]] = None,
null_class_id: Optional[int] = None,
):
super().__init__()
self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
self.image_processor = VaeImageProcessor()
if null_class_id is None:
null_class_id = int(getattr(self.transformer.config, "num_classes", 1000))
self.register_to_config(null_class_id=int(null_class_id))
self._id2label = self._normalize_id2label(id2label)
self.labels = self._build_label2id(self._id2label)
self._labels_loaded_from_model_index = bool(self._id2label)
@property
def vae_scale_factor(self) -> int:
if self.vae is None:
return 8
block_out_channels = getattr(self.vae.config, "block_out_channels", None)
if block_out_channels:
return int(2 ** (len(block_out_channels) - 1))
return 8
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
"""Load a self-contained variant folder locally or from the Hub."""
repo_root = Path(__file__).resolve().parent
if pretrained_model_name_or_path in (None, "", "."):
variant = repo_root
elif (
isinstance(pretrained_model_name_or_path, str)
and "/" in pretrained_model_name_or_path
and not Path(pretrained_model_name_or_path).exists()
):
hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
if subfolder:
hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
else:
variant = Path(pretrained_model_name_or_path)
if not variant.is_absolute():
candidate = (Path.cwd() / variant).resolve()
variant = candidate if candidate.exists() else (repo_root / variant).resolve()
if subfolder:
variant = variant / subfolder
id2label_override = kwargs.pop("id2label", None)
null_class_id_override = kwargs.pop("null_class_id", None)
model_kwargs = dict(kwargs)
inserted: List[str] = []
def _load_component(folder: str, module_name: str, class_name: str):
comp_dir = variant / folder
module_path = comp_dir / f"{module_name}.py"
has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
if not module_path.exists() or not has_weights:
return None
comp_path = str(comp_dir)
if comp_path not in sys.path:
sys.path.insert(0, comp_path)
inserted.append(comp_path)
module = importlib.import_module(module_name)
component_cls = getattr(module, class_name)
return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
try:
transformer = _load_component("transformer", "fit_transformer_2d", "FiTTransformer2DModel")
if transformer is None:
raise ValueError(f"No loadable transformer found under {variant}")
scheduler = cls._load_scheduler_from_variant(variant, model_kwargs)
vae = None
vae_dir = variant / "vae"
if vae_dir.exists() and (vae_dir / "config.json").exists():
vae = AutoencoderKL.from_pretrained(str(vae_dir), **model_kwargs)
id2label = id2label_override or cls._read_id2label_from_model_index(str(variant))
null_class_id = null_class_id_override if null_class_id_override is not None else cls._read_null_class_id(
str(variant)
)
pipe = cls(
transformer=transformer,
scheduler=scheduler,
vae=vae,
id2label=id2label,
null_class_id=null_class_id,
)
if hasattr(pipe, "register_to_config"):
pipe.register_to_config(_name_or_path=str(variant))
return pipe
finally:
for comp_path in inserted:
if comp_path in sys.path:
sys.path.remove(comp_path)
@classmethod
def _load_scheduler_from_variant(cls, variant: Path, model_kwargs: Dict[str, object]) -> KarrasDiffusionSchedulers:
scheduler_dir = variant / "scheduler"
config_path = scheduler_dir / "scheduler_config.json"
if not config_path.exists():
raise ValueError(f"No scheduler config found under {scheduler_dir}")
scheduler_entry = None
model_index_path = variant / "model_index.json"
if model_index_path.exists():
scheduler_entry = json.loads(model_index_path.read_text(encoding="utf-8")).get("scheduler")
if scheduler_entry is None:
class_name = json.loads(config_path.read_text(encoding="utf-8")).get("_class_name")
if not class_name:
raise ValueError(f"Missing `_class_name` in {config_path}")
scheduler_entry = ["diffusers", class_name]
if not isinstance(scheduler_entry, list) or len(scheduler_entry) != 2:
raise ValueError(f"Invalid scheduler entry in model_index.json: {scheduler_entry}")
library_name, class_name = scheduler_entry
if library_name != "diffusers":
raise ValueError(f"Unsupported scheduler library: {library_name}")
scheduler_cls = getattr(diffusers_schedulers, class_name)
return scheduler_cls.from_pretrained(str(scheduler_dir), **model_kwargs)
@staticmethod
def _prepare_model_output_for_scheduler(
model_out: torch.Tensor,
latent_channels: int,
scheduler: KarrasDiffusionSchedulers,
) -> torch.Tensor:
if model_out.shape[1] != latent_channels * 2:
return model_out
variance_type = getattr(scheduler.config, "variance_type", None)
if scheduler.__class__.__name__ == "DDPMScheduler" and variance_type in ("learned", "learned_range"):
return model_out
model_output, _ = torch.split(model_out, latent_channels, dim=1)
return model_output
@staticmethod
def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
if not id2label:
return {}
return {int(key): value for key, value in id2label.items()}
@staticmethod
def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
if not variant_path:
return {}
model_index_path = Path(variant_path).resolve() / "model_index.json"
if not model_index_path.exists():
return {}
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
id2label = raw.get("id2label")
if not isinstance(id2label, dict):
return {}
return {int(key): value for key, value in id2label.items()}
@staticmethod
def _read_null_class_id(variant_path: Optional[str]) -> Optional[int]:
if not variant_path:
return None
model_index_path = Path(variant_path).resolve() / "model_index.json"
if not model_index_path.exists():
return None
raw = json.loads(model_index_path.read_text(encoding="utf-8"))
if "null_class_id" in raw:
return int(raw["null_class_id"])
return None
@staticmethod
def _build_label2id(id2label: Dict[int, str]) -> Dict[str, int]:
label2id: Dict[str, int] = {}
for class_id, value in id2label.items():
for synonym in value.split(","):
synonym = synonym.strip()
if synonym:
label2id[synonym] = int(class_id)
return dict(sorted(label2id.items()))
@property
def id2label(self) -> Dict[int, str]:
self._ensure_labels_loaded()
return self._id2label
def _ensure_labels_loaded(self) -> None:
if self._labels_loaded_from_model_index:
return
loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
if loaded:
self._id2label = loaded
self.labels = self._build_label2id(self._id2label)
self._labels_loaded_from_model_index = True
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
labels = [label] if isinstance(label, str) else label
self._ensure_labels_loaded()
if not self.labels:
raise ValueError("No id2label mapping is available in this checkpoint.")
missing = [item for item in labels if item not in self.labels]
if missing:
preview = ", ".join(list(self.labels.keys())[:8])
raise ValueError(f"Unknown labels: {missing}. Example valid labels: {preview}, ...")
return [self.labels[item] for item in labels]
def _normalize_class_labels(
self,
class_labels: Union[int, str, List[Union[int, str]], torch.Tensor],
) -> List[int]:
if isinstance(class_labels, torch.Tensor):
class_labels = class_labels.detach().cpu().tolist()
if isinstance(class_labels, int):
return [class_labels]
if isinstance(class_labels, str):
return self.get_label_ids(class_labels)
if not class_labels:
raise ValueError("`class_labels` cannot be empty.")
if isinstance(class_labels[0], str):
return self.get_label_ids(class_labels) # type: ignore[arg-type]
return [int(class_id) for class_id in class_labels] # type: ignore[union-attr]
@staticmethod
def prepare_extra_step_kwargs(
scheduler: KarrasDiffusionSchedulers,
generator: Optional[Union[torch.Generator, List[torch.Generator]]],
) -> Dict[str, Any]:
kwargs: Dict[str, Any] = {}
step_params = set(inspect.signature(scheduler.step).parameters.keys())
if "generator" in step_params:
kwargs["generator"] = generator
return kwargs
@staticmethod
def _expand_timestep(timestep, batch_size: int, device: torch.device) -> torch.Tensor:
if not torch.is_tensor(timestep):
timestep = torch.tensor([timestep], dtype=torch.long, device=device)
elif timestep.ndim == 0:
timestep = timestep[None].to(device=device)
return timestep.expand(batch_size)
@staticmethod
def _prepare_grid_mask_size(
batch_size: int,
n_patch_h: int,
n_patch_w: int,
device: torch.device,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
grid_h = torch.arange(n_patch_h, dtype=torch.long, device=device)
grid_w = torch.arange(n_patch_w, dtype=torch.long, device=device)
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
grid = torch.cat([grid[0].reshape(1, -1), grid[1].reshape(1, -1)], dim=0).repeat(batch_size, 1, 1)
mask = torch.ones(batch_size, n_patch_h * n_patch_w, device=device, dtype=dtype)
size = torch.tensor((n_patch_h, n_patch_w), device=device, dtype=torch.long).repeat(batch_size, 1)[:, None, :]
return grid, mask, size
@torch.inference_mode()
def __call__(
self,
class_labels: Union[int, str, List[Union[int, str]], torch.Tensor] = 207,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 250,
guidance_scale: float = 1.5,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
output_type: str = "pil",
return_dict: bool = True,
) -> Union[ImagePipelineOutput, Tuple]:
class_labels_list = self._normalize_class_labels(class_labels)
batch_size = len(class_labels_list)
height = DEFAULT_NATIVE_RESOLUTION if height is None else int(height)
width = DEFAULT_NATIVE_RESOLUTION if width is None else int(width)
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
raise ValueError(
f"`height` and `width` must be divisible by {self.vae_scale_factor}, got ({height}, {width})."
)
if output_type not in {"pil", "np", "pt", "latent"}:
raise ValueError(f"Unsupported `output_type`: {output_type}")
device = self._execution_device
model_dtype = next(self.transformer.parameters()).dtype
latent_h = height // self.vae_scale_factor
latent_w = width // self.vae_scale_factor
patch_size = int(self.transformer.config.patch_size)
n_patch_h, n_patch_w = latent_h // patch_size, latent_w // patch_size
latent_channels = (patch_size**2) * int(self.transformer.in_channels)
extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator=generator)
self.scheduler.set_timesteps(num_inference_steps, device=device)
if latents is None:
latents = randn_tensor(
(batch_size, latent_channels, n_patch_h * n_patch_w),
generator=generator,
device=device,
dtype=model_dtype,
)
else:
latents = latents.to(device=device, dtype=model_dtype)
expected = (batch_size, latent_channels, n_patch_h * n_patch_w)
if tuple(latents.shape) != expected:
raise ValueError(f"Invalid `latents` shape: {tuple(latents.shape)}. Expected {expected}.")
grid, mask, size = self._prepare_grid_mask_size(batch_size, n_patch_h, n_patch_w, device, model_dtype)
class_labels_tensor = torch.tensor(class_labels_list, device=device, dtype=torch.long)
using_cfg = guidance_scale > 1.0
if using_cfg:
y_null = torch.full((batch_size,), int(self.config.null_class_id), device=device, dtype=torch.long)
y = torch.cat([class_labels_tensor, y_null], dim=0)
grid = torch.cat([grid, grid], dim=0)
mask = torch.cat([mask, mask], dim=0)
size = torch.cat([size, size], dim=0)
for timestep in self.progress_bar(self.scheduler.timesteps):
latent_model_input = latents
if using_cfg:
latent_model_input = torch.cat([latents, latents], dim=0)
timestep_tensor = self._expand_timestep(timestep, latent_model_input.shape[0], device)
if using_cfg:
model_out = self.transformer.forward_with_cfg(
latent_model_input,
timestep_tensor,
y=y,
grid=grid,
mask=mask,
size=size,
cfg_scale=guidance_scale,
)
model_out = model_out.chunk(2, dim=0)[0]
else:
model_out = self.transformer(
latents,
timestep_tensor,
y=class_labels_tensor,
grid=grid,
mask=mask,
size=size,
)
model_output = self._prepare_model_output_for_scheduler(model_out, latent_channels, self.scheduler)
latents = self.scheduler.step(model_output, timestep, latents, **extra_step_kwargs).prev_sample
latents = latents[..., : n_patch_h * n_patch_w]
latents = self.transformer.unpatchify(latents, (latent_h, latent_w))
if self.vae is not None:
vae_dtype = next(self.vae.parameters()).dtype
latents = latents.to(dtype=vae_dtype)
latents = self.vae.decode(latents / self.vae.config.scaling_factor).sample
image = self.image_processor.postprocess(latents, output_type=output_type)
elif output_type == "latent":
image = latents
else:
raise ValueError("Cannot decode latents without a VAE.")
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
__all__ = ["FiTPipeline"]