zeta-chroma-test-bench / zeta_model.py
chipfoxx's picture
Fix: return raw x0 from model, use custom ComfyUI-matching Euler sampler instead of diffusers scheduler
f7df763 verified
Raw
History Blame Contribute Delete
15.1 kB
"""
NextDiTPixelSpace — Zeta-Chroma model architecture
Re-implemented from ComfyUI comfy/ldm/lumina/model.py
Key names match the checkpoint exactly (split Q/K/V, ModuleDict final_layer).
This version returns the raw model output matching ComfyUI's forward() convention:
output = (x - neg_x0) / sigma
The sampling loop in app.py does the Euler step directly.
"""
import math
from functools import lru_cache
import torch
import torch.nn as nn
import torch.nn.functional as F
def timestep_embedding(t, dim, max_period=10000):
    """Build sinusoidal embeddings for a tensor of timesteps.

    Args:
        t: Tensor of timesteps. A 1-D tensor of length B yields a (B, dim) result.
        dim: Width of the returned embedding.
        max_period: Sets the lowest frequency used by the embedding.

    Returns:
        A float32 tensor whose last axis is laid out as ``[cos | sin]``. When
        ``dim`` is odd a single zero column is appended so the width is always
        exactly ``dim`` (previously an odd ``dim`` silently produced ``dim - 1``
        columns).
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(0, half, dtype=torch.float32, device=t.device)
        / half
    )
    args = t.float().unsqueeze(-1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # cos/sin only cover 2 * half columns; fill the remaining one with zeros.
        emb = torch.cat([emb, emb.new_zeros(*emb.shape[:-1], 1)], dim=-1)
    return emb
class TimestepEmbedder(nn.Module):
    """Turn timesteps into conditioning vectors: sinusoidal code followed by a 2-layer MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None):
        super().__init__()
        # A falsy output_size (None or 0) falls back to hidden_size.
        out_features = output_size or hidden_size
        stages = [
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, out_features),
        ]
        self.mlp = nn.Sequential(*stages)
        self.frequency_embedding_size = frequency_embedding_size

    def forward(self, t, dtype=None):
        """Embed timesteps ``t``.

        ``dtype`` is accepted for call-site compatibility but is not consulted:
        the sinusoidal code is cast to the dtype of the first linear's weight.
        """
        first_linear = self.mlp[0]
        freq_code = timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(freq_code.to(first_linear.weight.dtype))
# ── RoPE ──
def rope(pos, dim, theta):
    """Return 2x2 rotation matrices for one rotary-embedding axis.

    Args:
        pos: Tensor of positions; converted to float32 before use.
        dim: Number of channels assigned to this axis; ``dim // 2`` frequencies
            are generated (one per even index in ``range(0, dim, 2)``).
        theta: Base of the geometric frequency progression.

    Returns:
        float32 tensor of shape ``(*pos.shape, n_freqs, 2, 2)`` holding
        ``[[cos, -sin], [sin, cos]]`` for every position/frequency pair.
    """
    # Frequencies are computed in float64 and only reduced to float32 at the end.
    exponents = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    inv_freq = 1.0 / (theta ** exponents)
    angles = torch.einsum("...n,d->...nd", pos.float(), inv_freq)
    cos = torch.cos(angles)
    sin = torch.sin(angles)
    flat = torch.stack((cos, -sin, sin, cos), dim=-1)
    return flat.view(*flat.shape[:-1], 2, 2).float()
def apply_rope(xq, xk, freqs_cis):
    """Rotate query and key channel pairs by the matrices in ``freqs_cis``.

    The last axis of each input is treated as consecutive (even, odd) pairs.
    The arithmetic runs in float32 and each result is cast back to the dtype
    of its input.

    Args:
        xq: Query tensor.
        xk: Key tensor.
        freqs_cis: Rotation matrices whose trailing axes are ``(n_pairs, 2, 2)``
            and whose leading axes broadcast against the inputs.

    Returns:
        Tuple ``(xq_rotated, xk_rotated)`` with the same shapes as the inputs.
    """
    first_col = freqs_cis[..., 0]
    second_col = freqs_cis[..., 1]

    def rotate(x):
        # (..., D) -> (..., D/2, 1, 2): one trailing pair per rotation matrix.
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        rotated = first_col * pairs[..., 0] + second_col * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return rotate(xq), rotate(xk)
class EmbedND(nn.Module):
    """Multi-axis rotary embedding: one ``rope`` table per id axis, joined along the frequency axis."""

    def __init__(self, dim, theta, axes_dim):
        # ``dim`` is accepted for signature compatibility; only theta/axes_dim are stored.
        super().__init__()
        self.axes_dim = axes_dim
        self.theta = theta

    def forward(self, ids):
        """Build rotation matrices for position ids whose last axis indexes the rotary axes.

        Returns the per-axis tables concatenated on the frequency axis (``dim=-3``)
        with a singleton axis inserted at position 1.
        """
        tables = []
        for axis in range(ids.shape[-1]):
            tables.append(rope(ids[..., axis], self.axes_dim[axis], self.theta))
        return torch.cat(tables, dim=-3).unsqueeze(1)
class JointAttention(nn.Module):
    """Multi-head self-attention with split Q/K/V projections, optional QK RMSNorm and RoPE.

    Supports grouped-query attention: when ``n_kv_heads < n_heads`` each K/V head
    is repeated ``n_heads // n_kv_heads`` times before attention.
    Attribute names (``to_q``, ``to_k``, ``to_v``, ``to_out.0``, ``norm_q``,
    ``norm_k``) define the state-dict keys and must not be renamed.
    """

    def __init__(self, dim, n_heads, n_kv_heads=None, qk_norm=True):
        super().__init__()
        self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.to_q = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.to_k = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
        self.to_v = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
        # Wrapped in a ModuleList so the parameter is stored under "to_out.0".
        self.to_out = nn.ModuleList([nn.Linear(n_heads * self.head_dim, dim, bias=False)])
        self.norm_q = nn.RMSNorm(self.head_dim, elementwise_affine=True) if qk_norm else nn.Identity()
        self.norm_k = nn.RMSNorm(self.head_dim, elementwise_affine=True) if qk_norm else nn.Identity()

    def forward(self, x, x_mask, freqs_cis):
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape ``(B, S, dim)``.
            x_mask: ``None`` for unmasked attention, or a mask forwarded unchanged
                as ``attn_mask`` to ``F.scaled_dot_product_attention``; it must be
                broadcastable to ``(B, n_heads, S, S)``. Previously this argument
                was silently ignored.
            freqs_cis: Rotary rotation matrices consumed by ``apply_rope``.

        Returns:
            Tensor of shape ``(B, S, dim)``.
        """
        B, S, _ = x.shape
        xq = self.to_q(x).view(B, S, self.n_heads, self.head_dim)
        xk = self.to_k(x).view(B, S, self.n_kv_heads, self.head_dim)
        xv = self.to_v(x).view(B, S, self.n_kv_heads, self.head_dim)
        xq, xk = self.norm_q(xq), self.norm_k(xk)
        xq, xk = apply_rope(xq, xk, freqs_cis)
        n_rep = self.n_heads // self.n_kv_heads
        if n_rep > 1:
            # Grouped-query attention: repeat each K/V head n_rep times.
            xk = xk.unsqueeze(3).expand(-1, -1, -1, n_rep, -1).flatten(2, 3)
            xv = xv.unsqueeze(3).expand(-1, -1, -1, n_rep, -1).flatten(2, 3)
        # SDPA expects (B, heads, S, head_dim).
        out = F.scaled_dot_product_attention(
            xq.transpose(1, 2),
            xk.transpose(1, 2),
            xv.transpose(1, 2),
            attn_mask=x_mask,
        )
        return self.to_out[0](out.transpose(1, 2).contiguous().view(B, S, -1))
class FeedForward(nn.Module):
    """Gated SiLU feed-forward layer: ``w2(silu(w1(x)) * w3(x))`` with bias-free projections."""

    def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None):
        super().__init__()
        width = hidden_dim
        if ffn_dim_multiplier is not None:
            width = int(ffn_dim_multiplier * width)
        # Round the inner width up to the next multiple of `multiple_of`.
        n_chunks = (width + multiple_of - 1) // multiple_of
        width = multiple_of * n_chunks
        self.w1 = nn.Linear(dim, width, bias=False)
        self.w2 = nn.Linear(width, dim, bias=False)
        self.w3 = nn.Linear(dim, width, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
class JointTransformerBlock(nn.Module):
    """Transformer block with sandwich RMSNorms and optional adaLN scale/gate modulation.

    Each sub-layer (attention, feed-forward) is normalised before and after, and
    its output is added residually. With ``modulation=True`` the pre-norm output
    is scaled by ``1 + scale`` and the post-norm output is gated by ``tanh(gate)``,
    both produced from ``adaln_input``.
    """

    def __init__(self, dim, n_heads, n_kv_heads, multiple_of=256, ffn_dim_multiplier=4.0,
                 norm_eps=1e-5, qk_norm=True, modulation=True, z_image_modulation=False):
        super().__init__()

        def make_norm():
            return nn.RMSNorm(dim, eps=norm_eps, elementwise_affine=True)

        # Registration order is kept: it fixes state-dict ordering and init RNG order.
        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm)
        self.feed_forward = FeedForward(dim, dim, multiple_of, ffn_dim_multiplier)
        self.attention_norm1 = make_norm()
        self.ffn_norm1 = make_norm()
        self.attention_norm2 = make_norm()
        self.ffn_norm2 = make_norm()
        self.modulation = modulation
        if not modulation:
            return
        if z_image_modulation:
            # Single linear (no activation) fed from a conditioning vector of width min(dim, 256).
            stages = [nn.Linear(min(dim, 256), 4 * dim, bias=True)]
        else:
            stages = [nn.SiLU(), nn.Linear(min(dim, 1024), 4 * dim, bias=True)]
        self.adaLN_modulation = nn.Sequential(*stages)

    def forward(self, x, x_mask, freqs_cis, adaln_input=None):
        """Run the block on ``x``; ``adaln_input`` is required when modulation is enabled."""
        if not self.modulation:
            attn_out = self.attention(self.attention_norm1(x), x_mask, freqs_cis)
            x = x + self.attention_norm2(attn_out)
            ffn_out = self.feed_forward(self.ffn_norm1(x))
            return x + self.ffn_norm2(ffn_out)

        # One projection yields [scale_attn, gate_attn, scale_ffn, gate_ffn] along dim 1;
        # each gets a sequence axis so it broadcasts over tokens.
        scale_attn, gate_attn, scale_ffn, gate_ffn = (
            part.unsqueeze(1) for part in self.adaLN_modulation(adaln_input).chunk(4, dim=1)
        )
        attn_in = self.attention_norm1(x) * (1 + scale_attn)
        attn_out = self.attention(attn_in, x_mask, freqs_cis)
        x = x + gate_attn.tanh() * self.attention_norm2(attn_out)
        ffn_in = self.ffn_norm1(x) * (1 + scale_ffn)
        ffn_out = self.feed_forward(ffn_in)
        x = x + gate_ffn.tanh() * self.ffn_norm2(ffn_out)
        return x
class NerfEmbedder(nn.Module):
    """Concatenate inputs with a fixed cosine positional code and project them linearly.

    The positional code has ``max_freqs ** 2`` channels per position and is
    built on a square ``ps x ps`` grid over ``[0, 1]``.
    """

    # Upper bound on memoised positional tables kept per instance.
    _POS_CACHE_SIZE = 4

    def __init__(self, in_channels, hidden_size_input, max_freqs=8):
        super().__init__()
        self.max_freqs = max_freqs
        self.embedder = nn.Sequential(nn.Linear(in_channels + max_freqs ** 2, hidden_size_input))

    def _pos(self, ps, device, dtype):
        """Return the memoised ``(1, ps * ps, max_freqs ** 2)`` positional table.

        The cache is a plain dict stored on the instance. Unlike ``lru_cache`` on
        a method, it holds no reference to ``self`` from module-global state, so
        instances can be garbage collected, and it is not part of the state dict.
        """
        cache = self.__dict__.setdefault("_pos_cache", {})
        key = (ps, device, dtype)
        table = cache.get(key)
        if table is None:
            table = self._build_pos(ps, device, dtype)
            if len(cache) >= self._POS_CACHE_SIZE:
                # Drop the oldest entry (dicts preserve insertion order).
                cache.pop(next(iter(cache)))
            cache[key] = table
        return table

    def _build_pos(self, ps, device, dtype):
        """Compute the cosine positional table for a ``ps x ps`` grid."""
        px = torch.linspace(0, 1, ps, device=device, dtype=dtype)
        py = torch.linspace(0, 1, ps, device=device, dtype=dtype)
        py, px = torch.meshgrid(py, px, indexing="ij")
        fx = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device)
        # Down-weight higher frequency pairs: 1 / (1 + fx_i * fx_j).
        c = (1 + fx[None, :, None] * fx[None, None, :]) ** -1
        dx = torch.cos(px.reshape(-1, 1, 1) * fx[None, :, None] * torch.pi)
        dy = torch.cos(py.reshape(-1, 1, 1) * fx[None, None, :] * torch.pi)
        return (dx * dy * c).view(1, -1, self.max_freqs ** 2)

    def forward(self, x):
        """Embed ``x`` of shape ``(B, P2, C)``; ``P2`` must be a perfect square.

        The input is cast to the projection weight's dtype, joined with the
        positional table on the last axis, and passed through the linear layer.
        """
        B, P2, _ = x.shape
        weight_dtype = self.embedder[0].weight.dtype
        x = x.to(weight_dtype)
        # Exact integer square root; avoids float rounding in P2 ** 0.5.
        side = math.isqrt(P2)
        dct = self._pos(side, x.device, weight_dtype).expand(B, -1, -1)
        return self.embedder(torch.cat((x, dct), dim=-1))
class PixelResBlock(nn.Module):
    """Residual MLP block whose LayerNorm output is shifted, scaled and gated by a conditioning vector."""

    def __init__(self, ch):
        super().__init__()
        self.in_ln = nn.LayerNorm(ch, eps=1e-6)
        self.mlp = nn.Sequential(nn.Linear(ch, ch), nn.SiLU(), nn.Linear(ch, ch))
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(ch, 3 * ch))

    def forward(self, x, y):
        """Return ``x + gate * mlp(LN(x) * (1 + scale) + shift)`` with (shift, scale, gate) derived from ``y``."""
        modulation = self.adaLN_modulation(y)
        shift, scale, gate = torch.chunk(modulation, 3, dim=-1)
        normed = self.in_ln(x)
        modulated = normed * (1 + scale) + shift
        return x + gate * self.mlp(modulated)
class SimpleMLPAdaLN(nn.Module):
    """Conditioned MLP decoder: input embedder, a stack of ``PixelResBlock``s, then norm + linear head.

    ``final_layer`` is a ``ModuleDict`` with keys ``"norm_final"`` and ``"linear"``;
    these names are part of the state-dict keys.
    """

    def __init__(self, in_channels, model_channels, out_channels, z_channels, num_res_blocks=4, max_freqs=8):
        super().__init__()
        self.cond_embed = nn.Linear(z_channels, model_channels)
        self.input_embedder = NerfEmbedder(in_channels, model_channels, max_freqs)
        blocks = [PixelResBlock(model_channels) for _ in range(num_res_blocks)]
        self.res_blocks = nn.ModuleList(blocks)
        head = nn.ModuleDict()
        head["norm_final"] = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        head["linear"] = nn.Linear(model_channels, out_channels)
        self.final_layer = head

    def forward(self, x, c):
        """Decode ``x`` conditioned on ``c``.

        Both the embedded input and the conditioning are brought to the dtype of
        ``cond_embed``'s weight; the conditioning gets an extra axis at position 1
        so it broadcasts over the token axis inside each residual block.
        """
        compute_dtype = self.cond_embed.weight.dtype
        hidden = self.input_embedder(x).to(compute_dtype)
        cond = self.cond_embed(c.to(compute_dtype)).unsqueeze(1)
        for res_block in self.res_blocks:
            hidden = res_block(hidden, cond)
        normed = self.final_layer["norm_final"](hidden)
        return self.final_layer["linear"](normed)
class NextDiTPixelSpace(nn.Module):
    """Pixel-space NextDiT: joint caption/image transformer followed by a per-patch MLP decoder.

    Attribute names and their registration order are preserved because they
    determine the state-dict keys.
    """

    def __init__(self, patch_size=32, in_channels=3, dim=3840, n_layers=30, n_refiner_layers=2,
                 n_heads=30, n_kv_heads=30, multiple_of=256, ffn_dim_multiplier=4.0,
                 norm_eps=1e-5, qk_norm=True, cap_feat_dim=2560,
                 axes_dims=(32, 48, 48), axes_lens=(1536, 512, 512),
                 rope_theta=256.0, time_scale=1000.0, pad_tokens_multiple=32,
                 decoder_hidden_size=3840, decoder_num_res_blocks=4,
                 decoder_max_freqs=8, decoder_in_channels=None):
        # NOTE: `axes_lens` is accepted but not read anywhere in this constructor.
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.time_scale = time_scale
        self.pad_tokens_multiple = pad_tokens_multiple

        patch_dim = patch_size ** 2 * in_channels

        def make_block(**overrides):
            return JointTransformerBlock(
                dim, n_heads, n_kv_heads,
                multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier,
                norm_eps=norm_eps, qk_norm=qk_norm, **overrides)

        self.x_embedder = nn.Linear(patch_dim, dim, bias=True)
        self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256)
        self.cap_embedder = nn.Sequential(
            nn.RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True),
            nn.Linear(cap_feat_dim, dim, bias=True),
        )
        # Padding tokens are allocated uninitialised (torch.empty).
        self.x_pad_token = nn.Parameter(torch.empty(1, dim))
        self.cap_pad_token = nn.Parameter(torch.empty(1, dim))

        noise_blocks = [make_block(modulation=True, z_image_modulation=True)
                        for _ in range(n_refiner_layers)]
        self.noise_refiner = nn.ModuleList(noise_blocks)
        context_blocks = [make_block(modulation=False) for _ in range(n_refiner_layers)]
        self.context_refiner = nn.ModuleList(context_blocks)
        main_blocks = [make_block(z_image_modulation=True) for _ in range(n_layers)]
        self.layers = nn.ModuleList(main_blocks)

        # A falsy decoder_in_channels (None or 0) means "one flattened patch".
        dec_in = decoder_in_channels or patch_dim
        self.dec_net = SimpleMLPAdaLN(dec_in, decoder_hidden_size, dec_in, dim,
                                      decoder_num_res_blocks, decoder_max_freqs)
        # Empty buffer registered under the key "__x0__".
        self.register_buffer("__x0__", torch.tensor([]))
        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims)

    def _pad(self, feats, pad_token):
        """Append copies of ``pad_token`` so the token axis is a multiple of ``pad_tokens_multiple``.

        Returns ``(feats, n_added)``.
        """
        extra = (-feats.shape[1]) % self.pad_tokens_multiple
        if extra <= 0:
            return feats, extra
        token = pad_token.to(device=feats.device, dtype=feats.dtype)
        filler = token.unsqueeze(0).expand(feats.shape[0], extra, -1).clone()
        return torch.cat((feats, filler), dim=1), extra

    def _unpatchify(self, patches, H, W):
        """Reassemble flattened patches (each laid out as row, col, channel) into ``(B, C, H, W)``."""
        p = self.patch_size
        ch = self.out_channels
        grid = patches.view(-1, H // p, W // p, p, p, ch)
        # (B, rows, cols, p_row, p_col, C) -> (B, C, rows, p_row, cols, p_col)
        return grid.permute(0, 5, 1, 3, 2, 4).reshape(-1, ch, H, W)

    @torch.no_grad()
    def forward(self, x, timesteps, context):
        """Run the model and return ``(x + decoder_output) / sigma``.

        This equals ``(x - neg_x0) / sigma`` with ``neg_x0 = -decoder_output``.
        ``sigma`` is ``timesteps`` clamped to at least 1e-5. The caller (sampling
        loop) is expected to take the Euler step
        ``x_next = x + (sigma_next - sigma) * output``.

        Args:
            x: Image tensor of shape ``(B, C, H, W)``.
            timesteps: Tensor of sigma values, reshaped to ``(-1, 1, 1, 1)`` for the division.
            context: Caption features fed to ``cap_embedder``.
        """
        batch, channels, height, width = x.shape
        p = self.patch_size
        device = x.device
        param_dtype = self.x_embedder.weight.dtype

        x = x.to(param_dtype)
        context = context.to(param_dtype)

        # Zero-pad bottom/right so both spatial sizes are divisible by the patch size.
        pad_bottom = (p - height % p) % p
        pad_right = (p - width % p) % p
        if pad_bottom or pad_right:
            x = F.pad(x, (0, pad_right, 0, pad_bottom))
        padded_h, padded_w = x.shape[-2:]
        rows, cols = padded_h // p, padded_w // p
        n_patches = rows * cols

        # Conditioning uses the reversed time 1 - timesteps, scaled by time_scale.
        reversed_t = 1.0 - timesteps.float()
        adaln = self.t_embedder(reversed_t * self.time_scale, dtype=param_dtype)

        # (B, C, H, W) -> (B, N, p*p*C); each patch flattened in (row, col, channel) order.
        patches = (x.view(batch, channels, rows, p, cols, p)
                   .permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
        # The decoder sees every patch as its own single-token sequence.
        decoder_pixels = patches.reshape(batch * n_patches, 1, p * p * channels)

        # Caption stream: embed, pad, refine. Axis-0 positions run 1..cap_len.
        cap = self.cap_embedder(context)
        cap, _ = self._pad(cap, self.cap_pad_token)
        cap_len = cap.shape[1]
        cap_ids = torch.zeros(batch, cap_len, 3, dtype=torch.float32, device=device)
        cap_ids[:, :, 0] = torch.arange(cap_len, dtype=torch.float32, device=device) + 1.0
        cap_freq = self.rope_embedder(cap_ids).movedim(1, 2)
        for block in self.context_refiner:
            cap = block(cap, None, cap_freq)

        # Image stream: embed, pad, refine. Real patches get (cap_len + 1, row, col);
        # padding tokens keep all-zero position ids.
        img = self.x_embedder(patches)
        img, _ = self._pad(img, self.x_pad_token)
        img_ids = torch.zeros(batch, img.shape[1], 3, dtype=torch.float32, device=device)
        row_index = torch.arange(rows, dtype=torch.float32, device=device).repeat_interleave(cols)
        col_index = torch.arange(cols, dtype=torch.float32, device=device).repeat(rows)
        img_ids[:, :n_patches, 0] = cap_len + 1
        img_ids[:, :n_patches, 1] = row_index
        img_ids[:, :n_patches, 2] = col_index
        img_freq = self.rope_embedder(img_ids).movedim(1, 2)
        for block in self.noise_refiner:
            img = block(img, None, img_freq, adaln)

        # Joint stack over [caption | image] tokens.
        joint = torch.cat([cap, img], dim=1)
        joint_freq = torch.cat([cap_freq, img_freq], dim=1)
        for block in self.layers:
            joint = block(joint, None, joint_freq, adaln)

        # Keep only the real image tokens and decode each patch independently.
        patch_hidden = joint[:, cap_len:cap_len + n_patches, :]
        cond = patch_hidden.reshape(batch * n_patches, self.dim)
        decoded_patches = self.dec_net(decoder_pixels, cond).reshape(batch, n_patches, -1)
        decoded = self._unpatchify(decoded_patches, padded_h, padded_w)[:, :, :height, :width]

        # x - (-decoded) written directly as x + decoded (numerically identical).
        sigma = timesteps.float().view(-1, 1, 1, 1).clamp(min=1e-5)
        return (x[:, :, :height, :width] + decoded) / sigma