"""Load this Q4_K GGUF DiT into the Ideogram 4 pipeline.

Approach: build the standard FP8 pipeline (architecture + text encoder + VAE),
then replace each quantized DiT linear with the Q4_K weights from this GGUF.
Dequantization uses gguf-py's reference decoder (the same decoder our
quantizer was verified bit-exact against).

REFERENCE IMPLEMENTATION — the dequant math is validated, but this standalone
loader has not been end-to-end tested on a GPU. Verify before relying on it.

GGUF tensor names follow `{branch}.{module}.weight` (branch in {cond, uncond});
see recipe-q4_k.json. When the GGUF has no `{branch}.{module}.bias` for a
swapped layer, the bias of the original FP8 layer is kept.
"""

import numpy as np
import torch
import torch.nn as nn
from gguf import GGUFReader, dequantize, GGMLQuantizationType

# Tensor types whose reader data is used as-is (no call to `dequantize`).
_PLAIN = {GGMLQuantizationType.F16, GGMLQuantizationType.F32}


def load_gguf_tensors(path):
    """Read every tensor of the GGUF file at `path` and dequantize it.

    Args:
        path: Filesystem path of the GGUF file.

    Returns:
        dict mapping GGUF tensor name -> CPU ``torch.Tensor`` in bfloat16,
        reshaped to PyTorch dimension order.
    """
    reader = GGUFReader(path)
    out = {}
    for t in reader.tensors:
        if t.tensor_type in _PLAIN:
            # Copy so the tensor owns its memory instead of aliasing the
            # reader's buffer.
            arr = np.array(t.data)
        else:
            arr = dequantize(t.data, t.tensor_type)
        # GGUF stores dims reversed relative to PyTorch's (outermost-first) order.
        shape = tuple(int(d) for d in reversed(t.shape))
        out[t.name] = torch.from_numpy(
            np.ascontiguousarray(arr).reshape(shape)
        ).to(torch.bfloat16)
    return out


def swap_branch(dit, tensors, branch):
    """Replace every FP8 linear in `dit` that has a GGUF weight for `branch`.

    A child module qualifies when its ``weight`` is a 2-D
    ``torch.float8_e4m3fn`` tensor and ``tensors`` contains
    ``f"{branch}.{path}.weight"``, where ``path`` is the child's dotted path
    inside ``dit``. The child is replaced in place by a frozen ``nn.Linear``
    holding the GGUF weight, on the same device as the FP8 weight.

    The bias is taken from ``f"{branch}.{path}.bias"`` when that key exists;
    otherwise the original FP8 module's ``bias`` (if it has one) is carried
    over, so layers whose bias is not stored in the GGUF do not lose it.
    Both are cast to the GGUF weight's dtype.

    Args:
        dit: Root module; patched in place.
        tensors: Mapping of GGUF tensor name -> tensor, e.g. the result of
            ``load_gguf_tensors``.
        branch: Name prefix selecting the weight set (e.g. ``"cond"``).

    Returns:
        Number of modules replaced.

    Raises:
        ValueError: if a GGUF weight is not 2-D or its element count differs
            from the FP8 weight it replaces (a tensor mapped to the wrong
            module).
    """
    swapped = 0
    # Replacing a child only rebinds an existing entry of the parent, so the
    # lazy named_modules() walk stays valid; any swapped-in nn.Linear it may
    # reach later has a non-fp8 weight and is skipped by the check below.
    for pname, parent in dit.named_modules():
        for cname, child in list(parent.named_children()):
            w = getattr(child, "weight", None)
            if w is None or w.ndim != 2 or w.dtype != torch.float8_e4m3fn:
                continue
            full = f"{pname}.{cname}" if pname else cname
            wkey, bkey = f"{branch}.{full}.weight", f"{branch}.{full}.bias"
            if wkey not in tensors:
                continue

            src = tensors[wkey]
            # Validate before copying to the device: catches a GGUF tensor
            # mapped onto the wrong module instead of building a bad layer.
            if src.ndim != 2 or src.numel() != w.numel():
                raise ValueError(
                    f"{wkey}: GGUF weight of shape {tuple(src.shape)} does not "
                    f"match fp8 weight of {full!r} with shape {tuple(w.shape)}"
                )
            W = src.to(w.device)

            if bkey in tensors:
                bias = tensors[bkey]
            else:
                # The GGUF may store only weights; keep the original bias.
                bias = getattr(child, "bias", None)

            # Construct on the meta device: skips allocating and randomly
            # initialising parameters that are overwritten right below.
            lin = nn.Linear(
                W.shape[1],
                W.shape[0],
                bias=bias is not None,
                dtype=W.dtype,
                device="meta",
            )
            lin.weight = nn.Parameter(W, requires_grad=False)
            if bias is not None:
                # F.linear needs bias and weight in the same dtype/device.
                lin.bias = nn.Parameter(
                    bias.detach().to(device=w.device, dtype=W.dtype),
                    requires_grad=False,
                )
            setattr(parent, cname, lin)
            swapped += 1
    return swapped