| """Load this Q4_K GGUF DiT into the Ideogram 4 pipeline. |
| |
| Approach: build the standard FP8 pipeline (architecture + text encoder + VAE), |
| then replace each FP8 DiT linear with a bf16 `nn.Linear` holding the weights |
| dequantized from this GGUF (the weights are not kept in Q4_K form in memory). |
| Dequantization uses gguf-py's reference decoder (the same decoder our quantizer |
| was verified bit-exact against). |
| |
| REFERENCE IMPLEMENTATION — the dequant math is validated, but this standalone |
| loader has not been end-to-end tested on a GPU. Verify before relying on it. |
| GGUF tensor names follow `{branch}.{module}.weight` (branch in {cond, uncond}), |
| see recipe-q4_k.json. |
| """ |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from gguf import GGUFReader, dequantize, GGMLQuantizationType |
|
|
# Tensor types that load_gguf_tensors copies out as-is; every other type is
# passed through gguf's dequantize().
_PLAIN = {
    GGMLQuantizationType.F32,
    GGMLQuantizationType.F16,
}
|
|
|
|
def load_gguf_tensors(path):
    """Dequantize every tensor stored in the GGUF file at `path`.

    Tensors whose type is in `_PLAIN` (F16/F32) are copied out directly; all
    other tensor types go through gguf-py's `dequantize`. Each resulting array
    is reshaped to the tensor's recorded shape taken in reverse order and then
    converted to bfloat16.

    Returns:
        dict mapping tensor name to a `torch.bfloat16` tensor.
    """
    reader = GGUFReader(path)
    decoded = {}
    for entry in reader.tensors:
        if entry.tensor_type in _PLAIN:
            values = np.array(entry.data)
        else:
            values = dequantize(entry.data, entry.tensor_type)
        # The file's dimension list is reversed to get the shape used here
        # (presumably GGUF lists the fastest-varying dimension first — confirm
        # against gguf-py).
        torch_shape = tuple(int(dim) for dim in reversed(entry.shape))
        dense = np.ascontiguousarray(values).reshape(torch_shape)
        decoded[entry.name] = torch.from_numpy(dense).to(torch.bfloat16)
    return decoded
|
|
|
|
| def swap_branch(dit, tensors, branch): |
| """Replace every fp8 linear in `dit` with the GGUF Q4_K weight for `branch`.""" |
| swapped = 0 |
| for pname, parent in dit.named_modules(): |
| for cname, child in list(parent.named_children()): |
| w = getattr(child, "weight", None) |
| if w is None or w.ndim != 2 or w.dtype != torch.float8_e4m3fn: |
| continue |
| full = f"{pname}.{cname}" if pname else cname |
| wkey, bkey = f"{branch}.{full}.weight", f"{branch}.{full}.bias" |
| if wkey not in tensors: |
| continue |
| W = tensors[wkey].to(w.device) |
| lin = nn.Linear(W.shape[1], W.shape[0], bias=(bkey in tensors), |
| dtype=torch.bfloat16, device=w.device) |
| lin.weight = nn.Parameter(W, requires_grad=False) |
| if bkey in tensors: |
| lin.bias = nn.Parameter(tensors[bkey].to(w.device), requires_grad=False) |
| setattr(parent, cname, lin) |
| swapped += 1 |
| return swapped |
|
|