deep1401 commited on
Commit
b10bf55
·
verified ·
1 Parent(s): ec1b13f

Upload gguf_loader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. gguf_loader.py +52 -0
gguf_loader.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Load this Q4_K GGUF DiT into the Ideogram 4 pipeline.
2
+
3
+ Approach: build the standard FP8 pipeline (architecture + text encoder + VAE),
4
+ then replace each quantized DiT linear with the Q4_K weights from this GGUF.
5
+ Dequantization uses gguf-py's reference decoder (the same decoder our quantizer
6
+ was verified bit-exact against).
7
+
8
+ REFERENCE IMPLEMENTATION — the dequant math is validated, but this standalone
9
+ loader has not been end-to-end tested on a GPU. Verify before relying on it.
10
+ GGUF tensor names follow `{branch}.{module}.weight` (branch in {cond, uncond}),
11
+ see recipe-q4_k.json.
12
+ """
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn as nn
16
+ from gguf import GGUFReader, dequantize, GGMLQuantizationType
17
+
18
+ _PLAIN = {GGMLQuantizationType.F16, GGMLQuantizationType.F32}
19
+
20
+
21
+ def load_gguf_tensors(path):
22
+ """Return {tensor_name: torch.Tensor(bf16)} dequantized from the GGUF."""
23
+ reader = GGUFReader(path)
24
+ out = {}
25
+ for t in reader.tensors:
26
+ arr = np.array(t.data) if t.tensor_type in _PLAIN else dequantize(t.data, t.tensor_type)
27
+ shape = tuple(int(d) for d in reversed(t.shape)) # GGUF stores dims reversed
28
+ out[t.name] = torch.from_numpy(np.ascontiguousarray(arr).reshape(shape)).to(torch.bfloat16)
29
+ return out
30
+
31
+
32
+ def swap_branch(dit, tensors, branch):
33
+ """Replace every fp8 linear in `dit` with the GGUF Q4_K weight for `branch`."""
34
+ swapped = 0
35
+ for pname, parent in dit.named_modules():
36
+ for cname, child in list(parent.named_children()):
37
+ w = getattr(child, "weight", None)
38
+ if w is None or w.ndim != 2 or w.dtype != torch.float8_e4m3fn:
39
+ continue
40
+ full = f"{pname}.{cname}" if pname else cname
41
+ wkey, bkey = f"{branch}.{full}.weight", f"{branch}.{full}.bias"
42
+ if wkey not in tensors:
43
+ continue
44
+ W = tensors[wkey].to(w.device)
45
+ lin = nn.Linear(W.shape[1], W.shape[0], bias=(bkey in tensors),
46
+ dtype=torch.bfloat16, device=w.device)
47
+ lin.weight = nn.Parameter(W, requires_grad=False)
48
+ if bkey in tensors:
49
+ lin.bias = nn.Parameter(tensors[bkey].to(w.device), requires_grad=False)
50
+ setattr(parent, cname, lin)
51
+ swapped += 1
52
+ return swapped