"""HSIGenePipeline - diffusers DiffusionPipeline for HSIGene hyperspectral generation.

AeroGen-style loading: use DiffusionPipeline.from_pretrained(path) - no sys.path.insert needed.
Self-contained: loading logic inlined (no separate modular_pipeline import).
"""

import importlib
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from diffusers import DDIMScheduler, DiffusionPipeline
from diffusers.utils import BaseOutput

# Re-export for diffusers component loading (load_method lookup)
DiffusionPipeline = DiffusionPipeline

# Inline path/loading (AeroGen-style) - self-contained for diffusers cache loading
_pipeline_dir = Path(__file__).resolve().parent
if str(_pipeline_dir) not in sys.path:
    sys.path.insert(0, str(_pipeline_dir))

# Register as "pipeline_hsigene" so diffusers' get_class_obj_and_candidates finds us when it does
# importlib.import_module("pipeline_hsigene") during component loading. (We may be loaded as
# "diffusers_modules.local.xxx.pipeline_hsigene" from cache, so this alias is required.)
sys.modules["pipeline_hsigene"] = sys.modules[__name__]

# Sub-directories of a model repo that hold one loadable component each.
_COMPONENT_NAMES = (
    "unet",
    "vae",
    "text_encoder",
    "local_adapter",
    "global_content_adapter",
    "global_text_adapter",
    "metadata_encoder",
)

# Maps `_target` strings found in component configs to the module path that is
# actually imported from the model repo.
_TARGET_MAP = {
    "hsigene_models.HSIGeneUNet": "unet.model.HSIGeneUNet",
    "hsigene.HSIGeneUNet": "unet.model.HSIGeneUNet",
    "hsigene_models.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
    "hsigene.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
    "ldm.modules.encoders.modules.FrozenCLIPEmbedder": "text_encoder.model.CLIPTextEncoder",
    "hsigene.CLIPTextEncoder": "text_encoder.model.CLIPTextEncoder",
    "models.local_adapter.LocalAdapter": "local_adapter.model.LocalAdapter",
    "hsigene.LocalAdapter": "local_adapter.model.LocalAdapter",
    "models.global_adapter.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
    "hsigene.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
    "models.global_adapter.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
    "hsigene.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
    "models.metadata_embedding.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
    "hsigene.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
}


def ensure_ldm_path(pretrained_model_name_or_path: Union[str, Path]) -> Path:
    """Add model repo to path so hsigene can be imported. Returns resolved path.

    If the argument is not an existing local path it is treated as a Hub repo
    id and fetched with ``huggingface_hub.snapshot_download``.
    """
    path = Path(pretrained_model_name_or_path)
    if not path.exists():
        from huggingface_hub import snapshot_download

        # snapshot_download expects a string repo id (consistent with
        # _resolve_model_root, which also stringifies its candidate).
        path = Path(snapshot_download(str(pretrained_model_name_or_path)))
    path = path.resolve()
    path_str = str(path)
    if path_str not in sys.path:
        sys.path.insert(0, path_str)
    return path


def _get_class(target: str):
    """Import ``module.path.ClassName`` and return the class object."""
    module_path, cls_name = target.rsplit(".", 1)
    mod = importlib.import_module(module_path)
    return getattr(mod, cls_name)


def load_component(model_path: Path, name: str):
    """Load a single component (unet, vae, text_encoder, etc.).

    ``model_path`` may be the model root (the component is then read from
    ``model_path / name``) or a component directory itself (a directory named
    like a component that contains ``config.json``).

    Raises:
        ValueError: if the component config has no ``_target`` entry.
    """
    path = Path(model_path)
    is_component_dir = path.name in _COMPONENT_NAMES and (path / "config.json").exists()
    root = path.parent if is_component_dir else path
    ensure_ldm_path(root)
    comp_path = path if is_component_dir else path / name

    with open(comp_path / "config.json", encoding="utf-8") as f:
        cfg = json.load(f)
    target = cfg.pop("_target", None)
    if not target:
        raise ValueError(f"No _target in {comp_path / 'config.json'}")
    target = _TARGET_MAP.get(target, target)
    cls_ref = _get_class(target)
    # Keys starting with "_" are config metadata, not constructor arguments.
    params = {k: v for k, v in cfg.items() if not k.startswith("_")}
    comp = cls_ref(**params)

    # Load the first weights file found; safetensors is preferred over .bin.
    # A component without either file keeps its constructor-initialised state.
    for wfile in ("diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin"):
        wp = comp_path / wfile
        if not wp.exists():
            continue
        if wfile.endswith(".safetensors"):
            from safetensors.torch import load_file

            state = load_file(str(wp))
        else:
            try:
                state = torch.load(wp, map_location="cpu", weights_only=True)
            except TypeError:
                # torch versions whose torch.load has no `weights_only` argument.
                state = torch.load(wp, map_location="cpu")
        comp.load_state_dict(state, strict=True)
        break
    comp.eval()
    return comp


def load_components(model_path: Union[str, Path]) -> dict:
    """Load all pipeline components.

    Returns a dict with one entry per name in ``_COMPONENT_NAMES`` plus
    ``"scheduler"`` (a ``DDIMScheduler``) and ``"scale_factor"`` (read from
    ``model_index.json`` when present, otherwise 0.18215).
    """
    path = Path(ensure_ldm_path(model_path))
    if path.name in _COMPONENT_NAMES and (path / "config.json").exists():
        # A component directory was given; the model root is its parent.
        path = path.parent
    scheduler = DDIMScheduler.from_pretrained(path / "scheduler")
    components = {name: load_component(path, name) for name in _COMPONENT_NAMES}

    scale_factor = 0.18215
    index_file = path / "model_index.json"
    if index_file.exists():
        with open(index_file, encoding="utf-8") as f:
            scale_factor = json.load(f).get("scale_factor", scale_factor)
    components["scheduler"] = scheduler
    components["scale_factor"] = scale_factor
    return components


class _CRSModelWrapper(torch.nn.Module):
    """Wrapper that mimics CRSControlNet interface."""

    def __init__(
        self,
        unet,
        vae,
        text_encoder,
        local_adapter,
        global_content_adapter,
        global_text_adapter,
        metadata_emb,
        scale_factor=0.18215,
        local_control_scales=None,
    ):
        super().__init__()
        # Keep diffusion_model as a properly registered submodule so
        # wrapper/device transfers (e.g., `.to("cuda")`) move UNet weights.
        self.model = torch.nn.Module()
        self.model.add_module("diffusion_model", unet)
        self.first_stage_model = vae
        self.cond_stage_model = text_encoder
        self.local_adapter = local_adapter
        self.global_content_adapter = global_content_adapter
        self.global_text_adapter = global_text_adapter
        self.metadata_emb = metadata_emb
        self.scale_factor = scale_factor
        self.local_control_scales = local_control_scales or [1.0] * 13

    @torch.no_grad()
    def get_learned_conditioning(self, prompts):
        """Encode ``prompts`` with the text encoder."""
        return self.cond_stage_model(prompts)

    def apply_model(self, x_noisy, t, cond, metadata=None, global_strength=1.0, text_strength=1.0, **kwargs):
        """Run one denoiser forward pass.

        ``cond`` is a dict with list-valued entries ``"c_crossattn"``,
        ``"global_control"`` and ``"local_control"``; ``"metadata"`` is read
        from it only when the ``metadata`` argument is None. Extra keyword
        arguments are accepted and ignored.
        """
        if metadata is None:
            metadata = cond["metadata"]
        metadata_emb = self.metadata_emb(metadata)

        content_t = cond["global_control"][0]
        global_control = self.global_content_adapter(content_t)

        cond_txt = torch.cat(cond["c_crossattn"], 1)
        cond_txt = self.global_text_adapter(cond_txt)
        # Both context streams are L2-normalised, then scaled by their strength.
        cond_txt = F.normalize(cond_txt, p=2, dim=-1) * text_strength
        global_control = F.normalize(global_control, p=2, dim=-1) * global_strength
        cond_txt = torch.cat([cond_txt, global_control], dim=1)

        local_control = torch.cat(cond["local_control"], 1)
        local_control = self.local_adapter(
            x=x_noisy, timesteps=t, context=cond_txt, local_conditions=local_control
        )
        local_control = [c * s for c, s in zip(local_control, self.local_control_scales)]

        return self.model.diffusion_model(
            x=x_noisy,
            timesteps=t,
            metadata=metadata_emb,
            context=cond_txt,
            local_control=local_control,
            meta=True,
        )

    def decode_first_stage(self, z):
        """Undo the latent scale factor and decode ``z`` with the VAE."""
        z = (1.0 / self.scale_factor) * z
        return self.first_stage_model.decode(z)

    def low_vram_shift(self, is_diffusing, device=None):
        """Swap module groups between the accelerator and the CPU.

        Args:
            is_diffusing: True places the UNet and adapters on the accelerator
                and the VAE / text encoder on the CPU; False does the reverse.
            device: accelerator to use. Defaults to the current CUDA device,
                which is what the previously hard-coded ``.cuda()`` calls did.
        """
        accelerator = torch.device("cuda" if device is None else device)
        denoisers = (
            self.model.diffusion_model,
            self.local_adapter,
            self.global_text_adapter,
            self.global_content_adapter,
        )
        codecs = (self.first_stage_model, self.cond_stage_model)
        incoming, outgoing = (denoisers, codecs) if is_diffusing else (codecs, denoisers)
        # Offload first so both groups are never resident on the accelerator at once.
        for module in outgoing:
            module.to("cpu")
        for module in incoming:
            module.to(accelerator)
        if is_diffusing:
            # apply_model() calls metadata_emb, so it must sit with the denoiser.
            self.metadata_emb.to(accelerator)


@dataclass
class HSIGeneOutput(BaseOutput):
    """Output class for HSIGene pipeline."""

    # Decoded images, set unless output_type == "latent".
    images: Optional[np.ndarray] = None
    # Final latents, set only when output_type == "latent".
    latents: Optional[torch.Tensor] = None


def _is_component_list(v):
    """Check if value is raw config format [library, class_name]."""
    return isinstance(v, (list, tuple)) and len(v) == 2 and isinstance(v[0], str) and isinstance(v[1], str)


def _resolve_model_root(candidate: Optional[Union[str, Path]]) -> Optional[Path]:
    """Resolve candidate path/repo to model root containing model_index.json.

    Checks the candidate itself and up to five ancestors. Returns None when the
    candidate is empty, nothing is found, or resolution fails for any reason.
    """
    if not candidate:
        return None
    try:
        path = Path(candidate)
        if not path.exists():
            from huggingface_hub import snapshot_download

            path = Path(snapshot_download(str(candidate)))
        path = path.resolve()
        if (path / "model_index.json").exists():
            return path
        cur = path
        for _ in range(5):
            parent = cur.parent
            if parent == cur:
                break
            if (parent / "model_index.json").exists():
                return parent
            cur = parent
    except Exception:
        # Best-effort lookup: callers try several candidates in turn and
        # report one error if none of them resolves.
        return None
    return None


def _module_device(module: torch.nn.Module) -> torch.device:
    """Return the device ``module`` should be sampled on.

    The first non-CPU parameter device wins. After a ``save_memory`` run part
    of the wrapper is parked on the CPU, so looking only at the first parameter
    would wrongly report "cpu" on the next call.
    """
    fallback = None
    for param in module.parameters():
        if param.device.type != "cpu":
            return param.device
        if fallback is None:
            fallback = param.device
    return fallback if fallback is not None else torch.device("cpu")


class HSIGenePipeline(DiffusionPipeline):
    """Pipeline for HSIGene hyperspectral image generation.

    AeroGen-style: load with DiffusionPipeline.from_pretrained(path) - no sys.path.insert.
    """

    def register_modules(self, **kwargs):
        """Override to handle list-format component specs from diffusers config."""
        for name, module in kwargs.items():
            if module is None or (isinstance(module, (tuple, list)) and len(module) > 0 and module[0] is None):
                spec = (None, None)
            elif _is_component_list(module):
                spec = (module[0], module[1])
            else:
                from diffusers.pipelines.pipeline_loading_utils import _fetch_class_library_tuple

                library, class_name = _fetch_class_library_tuple(module)
                spec = (library, class_name)
            self.register_to_config(**{name: spec})
            setattr(self, name, module)

    def __init__(
        self,
        unet=None,
        vae=None,
        text_encoder=None,
        local_adapter=None,
        global_content_adapter=None,
        global_text_adapter=None,
        metadata_encoder=None,
        scheduler=None,
        crs_model=None,
        scale_factor=0.18215,
        model_path: Optional[Union[str, Path]] = None,
        _name_or_path: Optional[Union[str, Path]] = None,
    ):
        """Build the pipeline from components or from a prebuilt ``crs_model``.

        When ``crs_model`` is given, only it and ``scheduler`` are registered.
        Otherwise the components are wrapped in a ``_CRSModelWrapper``. If any
        component is a raw ``[library, class_name]`` placeholder, all
        components are loaded from disk instead.

        Raises:
            ValueError: if placeholders were passed and no model root could be
                resolved from ``model_path``, ``_name_or_path`` or the config.
        """
        super().__init__()
        if crs_model is not None:
            self.register_modules(crs_model=crs_model, scheduler=scheduler)
            return

        raw_components = (
            unet,
            vae,
            text_encoder,
            local_adapter,
            global_content_adapter,
            global_text_adapter,
            metadata_encoder,
        )
        components_are_lists = any(_is_component_list(x) for x in raw_components if x is not None)
        if components_are_lists:
            # Diffusers custom_pipeline may pass raw [library, class] placeholders to __init__.
            # Resolve model root and materialize real components here.
            model_root = (
                _resolve_model_root(model_path)
                or _resolve_model_root(_name_or_path)
                or _resolve_model_root(getattr(getattr(self, "config", None), "_name_or_path", None))
            )
            if model_root is None:
                raise ValueError(
                    "HSIGene received raw config placeholders but could not resolve model path. "
                    "Pass `model_path` to HSIGenePipeline or load via "
                    "`DiffusionPipeline.from_pretrained(<model_dir>, custom_pipeline=<model_dir>)` "
                    "with a valid local model directory."
                )
            loaded = load_components(model_root)
            unet = loaded["unet"]
            vae = loaded["vae"]
            text_encoder = loaded["text_encoder"]
            local_adapter = loaded["local_adapter"]
            global_content_adapter = loaded["global_content_adapter"]
            global_text_adapter = loaded["global_text_adapter"]
            metadata_encoder = loaded["metadata_encoder"]
            scheduler = loaded["scheduler"] if scheduler is None else scheduler
            scale_factor = loaded["scale_factor"]

        crs_model = _CRSModelWrapper(
            unet=unet,
            vae=vae,
            text_encoder=text_encoder,
            local_adapter=local_adapter,
            global_content_adapter=global_content_adapter,
            global_text_adapter=global_text_adapter,
            metadata_emb=metadata_encoder,
            scale_factor=scale_factor,
        )
        self.register_modules(
            unet=unet,
            vae=vae,
            text_encoder=text_encoder,
            local_adapter=local_adapter,
            global_content_adapter=global_content_adapter,
            global_text_adapter=global_text_adapter,
            metadata_encoder=metadata_encoder,
            scheduler=scheduler,
            crs_model=crs_model,
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        device: Optional[Union[str, torch.device]] = None,
        subfolder: Optional[str] = None,
        **kwargs,
    ):
        """Load from diffusers-format directory. Supports subfolder for single-component loading.

        Returns a single component when ``subfolder`` names one or when the
        path is itself a component directory; otherwise returns the pipeline,
        moved to ``device`` if given. Extra keyword arguments are ignored.
        """
        path = Path(ensure_ldm_path(pretrained_model_name_or_path))
        if subfolder in _COMPONENT_NAMES:
            return load_component(path, subfolder)
        if path.name in _COMPONENT_NAMES and (path / "config.json").exists():
            ensure_ldm_path(path.parent)
            return load_component(path.parent, path.name)

        if not (path / "model_index.json").exists():
            # Look for the model root among the ancestors, but keep the
            # requested directory when none has model_index.json. The previous
            # walk overwrote `path` on every step and ended up five levels up.
            path = _resolve_model_root(path) or path

        components = load_components(path)
        pipe = cls(
            unet=components["unet"],
            vae=components["vae"],
            text_encoder=components["text_encoder"],
            local_adapter=components["local_adapter"],
            global_content_adapter=components["global_content_adapter"],
            global_text_adapter=components["global_text_adapter"],
            metadata_encoder=components["metadata_encoder"],
            scheduler=components["scheduler"],
            scale_factor=components["scale_factor"],
        )
        if device is not None:
            pipe = pipe.to(device)
        return pipe

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = "",
        num_samples: int = 1,
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 50,
        eta: float = 0.0,
        global_strength: float = 1.0,
        text_strength: Optional[float] = None,
        local_conditions: Optional[torch.Tensor] = None,
        global_conditions: Optional[torch.Tensor] = None,
        metadata: Optional[torch.Tensor] = None,
        condition_resolution: int = 512,
        guidance_scale: float = 1.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: str = "numpy",
        return_dict: bool = True,
        save_memory: bool = False,
    ):
        """Generate images with the scheduler-driven denoising loop.

        Args:
            prompt: one prompt (repeated ``num_samples`` times) or a list of
                prompts, whose length then overrides ``num_samples``.
            text_strength: defaults to ``global_strength`` when None.
            local_conditions / global_conditions / metadata: default to zero
                tensors of shape ``(num_samples, 18, condition_resolution,
                condition_resolution)``, ``(num_samples, 768)`` and ``(7,)``.
            guidance_scale: classifier-free guidance is used when > 1.0.
            latents: optional start latents; otherwise noise of shape
                ``(num_samples, 4, height // 4, width // 4)`` is drawn.
            output_type: ``"latent"`` returns the final latents undecoded.
            save_memory: swap module groups between accelerator and CPU
                around the text-encoding, denoising and decoding phases.

        Returns:
            ``HSIGeneOutput`` (or a 1-tuple when ``return_dict`` is False).
            Images are channel-last numpy arrays clipped to [0, 1].

        Raises:
            ValueError: if ``negative_prompt`` is a list whose length differs
                from the number of samples while guidance is active.
        """
        device = _module_device(self.crs_model)
        if latents is not None:
            device = latents.device
        elif generator is not None and hasattr(generator, "device"):
            # NOTE(review): a CPU generator pulls the whole model onto the CPU.
            # Kept as-is because callers may rely on it; it also makes the
            # generator-recreation branch below unreachable in practice.
            device = torch.device(generator.device)

        # Shifting only makes sense on an accelerator; on CPU there is nothing to offload.
        shift_modules = save_memory and device.type != "cpu"
        if not shift_modules:
            # Keep wrapper submodules on the same device used for sampling.
            self.crs_model.to(device)

        if text_strength is None:
            text_strength = global_strength

        if isinstance(prompt, str):
            prompts = [prompt] * num_samples
        else:
            prompts = list(prompt)
            num_samples = len(prompts)

        if shift_modules:
            self.crs_model.low_vram_shift(is_diffusing=False, device=device)

        text_embedding = self.crs_model.get_learned_conditioning(prompts)

        if local_conditions is None:
            local_conditions = torch.zeros(
                num_samples,
                18,
                condition_resolution,
                condition_resolution,
                device=device,
                dtype=torch.float32,
            )
        else:
            local_conditions = local_conditions.to(device=device, dtype=torch.float32)

        if global_conditions is None:
            global_conditions = torch.zeros(num_samples, 768, device=device, dtype=torch.float32)
        else:
            global_conditions = global_conditions.to(device=device, dtype=torch.float32)

        if metadata is None:
            metadata = torch.zeros(7, device=device, dtype=torch.float32)
        else:
            metadata = metadata.to(device=device, dtype=torch.float32)

        cond = {
            "local_control": [local_conditions],
            "c_crossattn": [text_embedding],
            "global_control": [global_conditions],
            "metadata": [metadata],
        }

        do_cfg = guidance_scale > 1.0
        if do_cfg:
            if negative_prompt is None:
                neg_prompts = [""] * num_samples
            elif isinstance(negative_prompt, str):
                neg_prompts = [negative_prompt] * num_samples
            else:
                neg_prompts = list(negative_prompt)
            if len(neg_prompts) != num_samples:
                raise ValueError(
                    f"negative_prompt has {len(neg_prompts)} entries but {num_samples} samples were requested."
                )
            uc_text = self.crs_model.get_learned_conditioning(neg_prompts)
            # The unconditional branch zeroes only the global control signal.
            uncond = {
                "local_control": [local_conditions],
                "c_crossattn": [uc_text],
                "global_control": [torch.zeros_like(global_conditions)],
                "metadata": [metadata],
            }

        latent_shape = (num_samples, 4, height // 4, width // 4)
        if latents is None:
            if generator is not None and hasattr(generator, "device"):
                gen_device = torch.device(generator.device)
                if gen_device.type != device.type:
                    # Recreate generator on target device while preserving seed
                    # so CPU/CUDA mismatch does not crash torch.randn.
                    if hasattr(generator, "initial_seed"):
                        generator = torch.Generator(device=device).manual_seed(generator.initial_seed())
                    else:
                        generator = torch.Generator(device=device)
            latents = torch.randn(
                latent_shape,
                device=device,
                generator=generator,
                dtype=torch.float32,
            )
        else:
            latents = latents.to(device)

        self.scheduler.set_timesteps(num_inference_steps, device=device)

        if shift_modules:
            self.crs_model.low_vram_shift(is_diffusing=True, device=device)

        model_kwargs = {
            "metadata": metadata,
            "global_strength": global_strength,
            "text_strength": text_strength,
        }
        for t in self.progress_bar(self.scheduler.timesteps):
            t_batch = t.expand(num_samples)
            noise_pred = self.crs_model.apply_model(latents, t_batch, cond, **model_kwargs)
            if do_cfg:
                noise_pred_uncond = self.crs_model.apply_model(latents, t_batch, uncond, **model_kwargs)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
                eta=eta,
                generator=generator,
            ).prev_sample

        if output_type == "latent":
            if not return_dict:
                return (latents,)
            return HSIGeneOutput(latents=latents)

        if shift_modules:
            self.crs_model.low_vram_shift(is_diffusing=False, device=device)

        images = self.crs_model.decode_first_stage(latents)
        images = images.permute(0, 2, 3, 1).cpu().numpy()
        # Rescale by 0.5 and shift by 0.5, then clip to [0, 1].
        images = images * 0.5 + 0.5
        images = np.clip(images, 0, 1)

        if not return_dict:
            return (images,)
        return HSIGeneOutput(images=images)