ideogram-4-gguf-q4_k / gguf_loader.py
deep1401's picture
Upload gguf_loader.py with huggingface_hub
b10bf55 verified
"""Load this Q4_K GGUF DiT into the Ideogram 4 pipeline.
Approach: build the standard FP8 pipeline (architecture + text encoder + VAE),
then replace each quantized DiT linear with the Q4_K weights from this GGUF.
Dequantization uses gguf-py's reference decoder (the same decoder our quantizer
was verified bit-exact against).
REFERENCE IMPLEMENTATION — the dequant math is validated, but this standalone
loader has not been end-to-end tested on a GPU. Verify before relying on it.
GGUF tensor names follow `{branch}.{module}.weight` (branch in {cond, uncond}),
see recipe-q4_k.json.
"""
import numpy as np
import torch
import torch.nn as nn
from gguf import GGUFReader, dequantize, GGMLQuantizationType
# Tensor types the loader copies straight out of the file; every other type
# is routed through gguf-py's `dequantize`.
_PLAIN = frozenset((GGMLQuantizationType.F16, GGMLQuantizationType.F32))
def load_gguf_tensors(path):
    """Read every tensor of a GGUF file into a bfloat16 torch tensor.

    Tensors whose type is listed in ``_PLAIN`` (F16/F32) are copied as they
    are; all other types go through gguf-py's ``dequantize``.

    Args:
        path: Location of the GGUF file, passed unchanged to ``GGUFReader``.

    Returns:
        A dict mapping each GGUF tensor name to a ``torch.bfloat16`` tensor.
    """
    reader = GGUFReader(path)
    tensors = {}
    for entry in reader.tensors:
        if entry.tensor_type in _PLAIN:
            values = np.array(entry.data)
        else:
            values = dequantize(entry.data, entry.tensor_type)
        # GGUF stores dims reversed, so flip them before reshaping.
        dims = tuple(int(d) for d in entry.shape)[::-1]
        dense = np.ascontiguousarray(values).reshape(dims)
        tensors[entry.name] = torch.from_numpy(dense).to(torch.bfloat16)
    return tensors
def swap_branch(dit, tensors, branch):
"""Replace every fp8 linear in `dit` with the GGUF Q4_K weight for `branch`."""
swapped = 0
for pname, parent in dit.named_modules():
for cname, child in list(parent.named_children()):
w = getattr(child, "weight", None)
if w is None or w.ndim != 2 or w.dtype != torch.float8_e4m3fn:
continue
full = f"{pname}.{cname}" if pname else cname
wkey, bkey = f"{branch}.{full}.weight", f"{branch}.{full}.bias"
if wkey not in tensors:
continue
W = tensors[wkey].to(w.device)
lin = nn.Linear(W.shape[1], W.shape[0], bias=(bkey in tensors),
dtype=torch.bfloat16, device=w.device)
lin.weight = nn.Parameter(W, requires_grad=False)
if bkey in tensors:
lin.bias = nn.Parameter(tensors[bkey].to(w.device), requires_grad=False)
setattr(parent, cname, lin)
swapped += 1
return swapped