""" NextDiTPixelSpace — Zeta-Chroma model architecture Re-implemented from ComfyUI comfy/ldm/lumina/model.py Key names match the checkpoint exactly (split Q/K/V, ModuleDict final_layer). This version returns the raw model output matching ComfyUI's forward() convention: output = (x - neg_x0) / sigma The sampling loop in app.py does the Euler step directly. """ import math from functools import lru_cache import torch import torch.nn as nn import torch.nn.functional as F def timestep_embedding(t, dim, max_period=10000): half = dim // 2 freqs = torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32, device=t.device) / half) args = t.float().unsqueeze(-1) * freqs.unsqueeze(0) return torch.cat([torch.cos(args), torch.sin(args)], dim=-1) class TimestepEmbedder(nn.Module): def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None): super().__init__() self.mlp = nn.Sequential( nn.Linear(frequency_embedding_size, hidden_size), nn.SiLU(), nn.Linear(hidden_size, output_size or hidden_size), ) self.frequency_embedding_size = frequency_embedding_size def forward(self, t, dtype=None): emb = timestep_embedding(t, self.frequency_embedding_size) weight_dtype = self.mlp[0].weight.dtype emb = emb.to(weight_dtype) return self.mlp(emb) # ── RoPE ── def rope(pos, dim, theta): scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim omega = 1.0 / (theta ** scale) out = torch.einsum("...n,d->...nd", pos.float(), omega) cos_out, sin_out = torch.cos(out), torch.sin(out) out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1) return out.view(*out.shape[:-1], 2, 2).float() def apply_rope(xq, xk, freqs_cis): xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) class EmbedND(nn.Module): def __init__(self, dim, theta, axes_dim): super().__init__() self.theta = theta self.axes_dim = axes_dim def forward(self, ids): n_axes = ids.shape[-1] emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3) return emb.unsqueeze(1) class JointAttention(nn.Module): def __init__(self, dim, n_heads, n_kv_heads=None, qk_norm=True): super().__init__() self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads self.n_heads = n_heads self.head_dim = dim // n_heads self.to_q = nn.Linear(dim, n_heads * self.head_dim, bias=False) self.to_k = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False) self.to_v = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False) self.to_out = nn.ModuleList([nn.Linear(n_heads * self.head_dim, dim, bias=False)]) self.norm_q = nn.RMSNorm(self.head_dim, elementwise_affine=True) if qk_norm else nn.Identity() self.norm_k = nn.RMSNorm(self.head_dim, elementwise_affine=True) if qk_norm else nn.Identity() def forward(self, x, x_mask, freqs_cis): B, S, _ = x.shape xq = self.to_q(x).view(B, S, self.n_heads, self.head_dim) xk = self.to_k(x).view(B, S, self.n_kv_heads, self.head_dim) xv = self.to_v(x).view(B, S, self.n_kv_heads, self.head_dim) xq, xk = self.norm_q(xq), self.norm_k(xk) xq, xk = apply_rope(xq, xk, freqs_cis) n_rep = self.n_heads // self.n_kv_heads if n_rep > 1: xk = xk.unsqueeze(3).expand(-1, -1, -1, n_rep, -1).flatten(2, 3) xv = xv.unsqueeze(3).expand(-1, -1, -1, n_rep, -1).flatten(2, 3) out = F.scaled_dot_product_attention(xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)) return self.to_out[0](out.transpose(1, 2).contiguous().view(B, S, -1)) class FeedForward(nn.Module): def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None): super().__init__() if ffn_dim_multiplier is not None: hidden_dim = int(ffn_dim_multiplier * hidden_dim) hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) self.w1 = nn.Linear(dim, hidden_dim, bias=False) self.w2 = nn.Linear(hidden_dim, dim, bias=False) self.w3 = nn.Linear(dim, hidden_dim, bias=False) def forward(self, x): return self.w2(F.silu(self.w1(x)) * self.w3(x)) class JointTransformerBlock(nn.Module): def __init__(self, dim, n_heads, n_kv_heads, multiple_of=256, ffn_dim_multiplier=4.0, norm_eps=1e-5, qk_norm=True, modulation=True, z_image_modulation=False): super().__init__() self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm) self.feed_forward = FeedForward(dim, dim, multiple_of, ffn_dim_multiplier) self.attention_norm1 = nn.RMSNorm(dim, eps=norm_eps, elementwise_affine=True) self.ffn_norm1 = nn.RMSNorm(dim, eps=norm_eps, elementwise_affine=True) self.attention_norm2 = nn.RMSNorm(dim, eps=norm_eps, elementwise_affine=True) self.ffn_norm2 = nn.RMSNorm(dim, eps=norm_eps, elementwise_affine=True) self.modulation = modulation if modulation: mod_in = min(dim, 256) if z_image_modulation else min(dim, 1024) if z_image_modulation: self.adaLN_modulation = nn.Sequential(nn.Linear(mod_in, 4 * dim, bias=True)) else: self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(mod_in, 4 * dim, bias=True)) def forward(self, x, x_mask, freqs_cis, adaln_input=None): if self.modulation: s_msa, g_msa, s_mlp, g_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) x = x + g_msa.unsqueeze(1).tanh() * self.attention_norm2( self.attention(self.attention_norm1(x) * (1 + s_msa.unsqueeze(1)), x_mask, freqs_cis)) x = x + g_mlp.unsqueeze(1).tanh() * self.ffn_norm2( self.feed_forward(self.ffn_norm1(x) * (1 + s_mlp.unsqueeze(1)))) else: x = x + self.attention_norm2(self.attention(self.attention_norm1(x), x_mask, freqs_cis)) x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) return x class NerfEmbedder(nn.Module): def __init__(self, in_channels, hidden_size_input, max_freqs=8): super().__init__() self.max_freqs = max_freqs self.embedder = nn.Sequential(nn.Linear(in_channels + max_freqs ** 2, hidden_size_input)) @lru_cache(maxsize=4) def _pos(self, ps, device, dtype): px = torch.linspace(0, 1, ps, device=device, dtype=dtype) py = torch.linspace(0, 1, ps, device=device, dtype=dtype) py, px = torch.meshgrid(py, px, indexing="ij") fx = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device) c = (1 + fx[None, :, None] * fx[None, None, :]) ** -1 dx = torch.cos(px.reshape(-1, 1, 1) * fx[None, :, None] * torch.pi) dy = torch.cos(py.reshape(-1, 1, 1) * fx[None, None, :] * torch.pi) return (dx * dy * c).view(1, -1, self.max_freqs ** 2) def forward(self, x): B, P2, C = x.shape weight_dtype = self.embedder[0].weight.dtype x = x.to(weight_dtype) dct = self._pos(int(P2 ** 0.5), x.device, weight_dtype).expand(B, -1, -1) return self.embedder(torch.cat((x, dct), dim=-1)) class PixelResBlock(nn.Module): def __init__(self, ch): super().__init__() self.in_ln = nn.LayerNorm(ch, eps=1e-6) self.mlp = nn.Sequential(nn.Linear(ch, ch), nn.SiLU(), nn.Linear(ch, ch)) self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(ch, 3 * ch)) def forward(self, x, y): shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1) h = self.in_ln(x) * (1 + scale) + shift return x + gate * self.mlp(h) class SimpleMLPAdaLN(nn.Module): def __init__(self, in_channels, model_channels, out_channels, z_channels, num_res_blocks=4, max_freqs=8): super().__init__() self.cond_embed = nn.Linear(z_channels, model_channels) self.input_embedder = NerfEmbedder(in_channels, model_channels, max_freqs) self.res_blocks = nn.ModuleList([PixelResBlock(model_channels) for _ in range(num_res_blocks)]) self.final_layer = nn.ModuleDict({ "norm_final": nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6), "linear": nn.Linear(model_channels, out_channels), }) def forward(self, x, c): weight_dtype = self.cond_embed.weight.dtype x = self.input_embedder(x) y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) x = x.to(weight_dtype) for block in self.res_blocks: x = block(x, y) return self.final_layer["linear"](self.final_layer["norm_final"](x)) class NextDiTPixelSpace(nn.Module): def __init__(self, patch_size=32, in_channels=3, dim=3840, n_layers=30, n_refiner_layers=2, n_heads=30, n_kv_heads=30, multiple_of=256, ffn_dim_multiplier=4.0, norm_eps=1e-5, qk_norm=True, cap_feat_dim=2560, axes_dims=(32, 48, 48), axes_lens=(1536, 512, 512), rope_theta=256.0, time_scale=1000.0, pad_tokens_multiple=32, decoder_hidden_size=3840, decoder_num_res_blocks=4, decoder_max_freqs=8, decoder_in_channels=None): super().__init__() self.in_channels = in_channels self.out_channels = in_channels self.patch_size = patch_size self.time_scale = time_scale self.pad_tokens_multiple = pad_tokens_multiple self.dim = dim self.x_embedder = nn.Linear(patch_size ** 2 * in_channels, dim, bias=True) self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256) self.cap_embedder = nn.Sequential( nn.RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True), nn.Linear(cap_feat_dim, dim, bias=True), ) self.x_pad_token = nn.Parameter(torch.empty(1, dim)) self.cap_pad_token = nn.Parameter(torch.empty(1, dim)) bk = dict(multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier, norm_eps=norm_eps, qk_norm=qk_norm) self.noise_refiner = nn.ModuleList([ JointTransformerBlock(dim, n_heads, n_kv_heads, modulation=True, z_image_modulation=True, **bk) for _ in range(n_refiner_layers)]) self.context_refiner = nn.ModuleList([ JointTransformerBlock(dim, n_heads, n_kv_heads, modulation=False, **bk) for _ in range(n_refiner_layers)]) self.layers = nn.ModuleList([ JointTransformerBlock(dim, n_heads, n_kv_heads, z_image_modulation=True, **bk) for _ in range(n_layers)]) dec_in = decoder_in_channels or (patch_size ** 2 * in_channels) self.dec_net = SimpleMLPAdaLN(dec_in, decoder_hidden_size, dec_in, dim, decoder_num_res_blocks, decoder_max_freqs) self.register_buffer("__x0__", torch.tensor([])) self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims) def _pad(self, feats, pad_token): extra = (-feats.shape[1]) % self.pad_tokens_multiple if extra > 0: pad = pad_token.to(device=feats.device, dtype=feats.dtype).unsqueeze(0).expand(feats.shape[0], extra, -1).clone() feats = torch.cat((feats, pad), dim=1) return feats, extra def _unpatchify(self, patches, H, W): pH = pW = self.patch_size C = self.out_channels return (patches.view(-1, H // pH, W // pW, pH, pW, C) .permute(0, 5, 1, 3, 2, 4).reshape(-1, C, H, W)) @torch.no_grad() def forward(self, x, timesteps, context): """ Returns the SAME output as ComfyUI's NextDiTPixelSpace.forward(): output = (x - neg_x0) / sigma where neg_x0 = -decoder_output (the _forward returns negated). The caller (sampling loop) should do: x_next = x + (sigma_next - sigma) * output """ B, C, H, W = x.shape pH = pW = self.patch_size device = x.device weight_dtype = self.x_embedder.weight.dtype x = x.to(weight_dtype) context = context.to(weight_dtype) pad_h = (pH - H % pH) % pH pad_w = (pW - W % pW) % pW if pad_h or pad_w: x = F.pad(x, (0, pad_w, 0, pad_h)) _, _, Hp, Wp = x.shape Ht, Wt = Hp // pH, Wp // pW N = Ht * Wt t = 1.0 - timesteps.float() adaln = self.t_embedder(t * self.time_scale, dtype=weight_dtype) patches_raw = (x.view(B, C, Ht, pH, Wt, pW) .permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) pixel_values = patches_raw.reshape(B * N, 1, pH * pW * C) cap = self.cap_embedder(context) cap, _ = self._pad(cap, self.cap_pad_token) cap_len = cap.shape[1] cap_ids = torch.zeros(B, cap_len, 3, dtype=torch.float32, device=device) cap_ids[:, :, 0] = torch.arange(cap_len, dtype=torch.float32, device=device) + 1.0 cap_freq = self.rope_embedder(cap_ids).movedim(1, 2) for layer in self.context_refiner: cap = layer(cap, None, cap_freq) img = self.x_embedder(patches_raw) img, _ = self._pad(img, self.x_pad_token) x_ids = torch.zeros(B, img.shape[1], 3, dtype=torch.float32, device=device) x_ids[:, :N, 0] = cap_len + 1 x_ids[:, :N, 1] = torch.arange(Ht, dtype=torch.float32, device=device).view(-1, 1).expand(-1, Wt).flatten() x_ids[:, :N, 2] = torch.arange(Wt, dtype=torch.float32, device=device).view(1, -1).expand(Ht, -1).flatten() img_freq = self.rope_embedder(x_ids).movedim(1, 2) for layer in self.noise_refiner: img = layer(img, None, img_freq, adaln) full = torch.cat([cap, img], dim=1) full_freq = torch.cat([cap_freq, img_freq], dim=1) for layer in self.layers: full = layer(full, None, full_freq, adaln) hidden = full[:, cap_len:cap_len + N, :] cond = hidden.reshape(B * N, self.dim) out = self.dec_net(pixel_values, cond).reshape(B, N, -1) # _forward returns -img_out (negated decoder output) neg_x0 = self._unpatchify(out, Hp, Wp)[:, :, :H, :W] # ComfyUI forward() does: return (x - neg_x0) / timesteps # This is the model_output that ComfyUI's sampler uses directly: # x_next = x + (sigma_next - sigma) * model_output x_crop = x[:, :, :H, :W] sigma = timesteps.float().view(-1, 1, 1, 1).clamp(min=1e-5) return (x_crop - (-neg_x0)) / sigma