"""HSIGene UNet - LocalControlUNetModel for hyperspectral generation.""" import torch from .diffusion import UNetModel from .utils import timestep_embedding class HSIGeneUNet(UNetModel): """UNet that accepts metadata and local_control from LocalAdapter.""" def forward( self, x, timesteps=None, metadata=None, context=None, local_control=None, meta=False, **kwargs, ): hs = [] t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) emb = self.time_embed(t_emb) + metadata h = x.type(self.dtype) for module in self.input_blocks: h = module(h, emb, context) hs.append(h) h = self.middle_block(h, emb, context) h += local_control.pop() for module in self.output_blocks: h = torch.cat([h, hs.pop() + local_control.pop()], dim=1) h = module(h, emb, context) h = h.type(x.dtype) return self.out(h)