Upload gguf_loader.py with huggingface_hub
Browse files- gguf_loader.py +52 -0
gguf_loader.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Load this Q4_K GGUF DiT into the Ideogram 4 pipeline.
|
| 2 |
+
|
| 3 |
+
Approach: build the standard FP8 pipeline (architecture + text encoder + VAE),
|
| 4 |
+
then replace each quantized DiT linear with the Q4_K weights from this GGUF.
|
| 5 |
+
Dequantization uses gguf-py's reference decoder (the same decoder our quantizer
|
| 6 |
+
was verified bit-exact against).
|
| 7 |
+
|
| 8 |
+
REFERENCE IMPLEMENTATION — the dequant math is validated, but this standalone
|
| 9 |
+
loader has not been end-to-end tested on a GPU. Verify before relying on it.
|
| 10 |
+
GGUF tensor names follow `{branch}.{module}.weight` (branch in {cond, uncond}),
|
| 11 |
+
see recipe-q4_k.json.
|
| 12 |
+
"""
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
from gguf import GGUFReader, dequantize, GGMLQuantizationType
|
| 17 |
+
|
| 18 |
+
# Tensor types that load_gguf_tensors copies straight out of the file; every
# other tensor type is passed through gguf's `dequantize` instead.
_PLAIN = {GGMLQuantizationType.F16, GGMLQuantizationType.F32}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def load_gguf_tensors(path):
    """Read every tensor of the GGUF file at `path` as a bfloat16 torch tensor.

    Tensors whose type is listed in `_PLAIN` are copied out unchanged; all
    other types are decoded with gguf's `dequantize`. The result is a dict
    mapping each GGUF tensor name to its `torch.bfloat16` tensor.
    """
    reader = GGUFReader(path)
    tensors = {}
    for entry in reader.tensors:
        if entry.tensor_type in _PLAIN:
            values = np.array(entry.data)
        else:
            values = dequantize(entry.data, entry.tensor_type)
        # GGUF stores dims reversed, so flip them to get the torch-side shape.
        dims = tuple(map(int, reversed(entry.shape)))
        dense = np.ascontiguousarray(values).reshape(dims)
        tensors[entry.name] = torch.from_numpy(dense).to(torch.bfloat16)
    return tensors
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def swap_branch(dit, tensors, branch):
|
| 33 |
+
"""Replace every fp8 linear in `dit` with the GGUF Q4_K weight for `branch`."""
|
| 34 |
+
swapped = 0
|
| 35 |
+
for pname, parent in dit.named_modules():
|
| 36 |
+
for cname, child in list(parent.named_children()):
|
| 37 |
+
w = getattr(child, "weight", None)
|
| 38 |
+
if w is None or w.ndim != 2 or w.dtype != torch.float8_e4m3fn:
|
| 39 |
+
continue
|
| 40 |
+
full = f"{pname}.{cname}" if pname else cname
|
| 41 |
+
wkey, bkey = f"{branch}.{full}.weight", f"{branch}.{full}.bias"
|
| 42 |
+
if wkey not in tensors:
|
| 43 |
+
continue
|
| 44 |
+
W = tensors[wkey].to(w.device)
|
| 45 |
+
lin = nn.Linear(W.shape[1], W.shape[0], bias=(bkey in tensors),
|
| 46 |
+
dtype=torch.bfloat16, device=w.device)
|
| 47 |
+
lin.weight = nn.Parameter(W, requires_grad=False)
|
| 48 |
+
if bkey in tensors:
|
| 49 |
+
lin.bias = nn.Parameter(tensors[bkey].to(w.device), requires_grad=False)
|
| 50 |
+
setattr(parent, cname, lin)
|
| 51 |
+
swapped += 1
|
| 52 |
+
return swapped
|