File size: 2,290 Bytes
b10bf55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""Load this Q4_K GGUF DiT into the Ideogram 4 pipeline.

Approach: build the standard FP8 pipeline (architecture + text encoder + VAE),
then replace each quantized DiT linear with the Q4_K weights from this GGUF.
Dequantization uses gguf-py's reference decoder (the same decoder our quantizer
was verified bit-exact against).

REFERENCE IMPLEMENTATION — the dequant math is validated, but this standalone
loader has not been end-to-end tested on a GPU. Verify before relying on it.
GGUF tensor names follow `{branch}.{module}.weight` (branch in {cond, uncond}),
see recipe-q4_k.json.
"""
import numpy as np
import torch
import torch.nn as nn
from gguf import GGUFReader, dequantize, GGMLQuantizationType

# Tensor types that are stored unquantized: load_gguf_tensors() copies these
# directly instead of passing them through dequantize().
_PLAIN = {GGMLQuantizationType.F16, GGMLQuantizationType.F32}


def load_gguf_tensors(path):
    """Read every tensor of the GGUF file at `path` as a bfloat16 torch tensor.

    Tensors whose type is listed in ``_PLAIN`` (F16 / F32) are copied as-is;
    every other type goes through gguf-py's ``dequantize``. The result is
    reshaped to the reverse of the shape recorded in the file and cast to
    ``torch.bfloat16``.

    Returns:
        dict mapping each GGUF tensor name to its ``torch.Tensor``.
    """
    reader = GGUFReader(path)
    tensors = {}
    for entry in reader.tensors:
        if entry.tensor_type in _PLAIN:
            values = np.array(entry.data)
        else:
            values = dequantize(entry.data, entry.tensor_type)
        # GGUF records dimensions in reverse order, so flip them back.
        dims = tuple(map(int, reversed(entry.shape)))
        values = np.ascontiguousarray(values).reshape(dims)
        tensors[entry.name] = torch.from_numpy(values).to(torch.bfloat16)
    return tensors


def swap_branch(dit, tensors, branch):
    """Replace every fp8 linear in `dit` with the GGUF Q4_K weight for `branch`."""
    swapped = 0
    for pname, parent in dit.named_modules():
        for cname, child in list(parent.named_children()):
            w = getattr(child, "weight", None)
            if w is None or w.ndim != 2 or w.dtype != torch.float8_e4m3fn:
                continue
            full = f"{pname}.{cname}" if pname else cname
            wkey, bkey = f"{branch}.{full}.weight", f"{branch}.{full}.bias"
            if wkey not in tensors:
                continue
            W = tensors[wkey].to(w.device)
            lin = nn.Linear(W.shape[1], W.shape[0], bias=(bkey in tensors),
                            dtype=torch.bfloat16, device=w.device)
            lin.weight = nn.Parameter(W, requires_grad=False)
            if bkey in tensors:
                lin.bias = nn.Parameter(tensors[bkey].to(w.device), requires_grad=False)
            setattr(parent, cname, lin)
            swapped += 1
    return swapped