#!/usr/bin/env python3
"""Load mixed precision Bernini checkpoints (INT4 transformer + INT8 T5/embeddings)."""

from pathlib import Path

import torch

try:
    from safetensors.torch import load_file
except ImportError:
    # safetensors is only needed by load_mixed(); keep the pure-tensor
    # dequantization helpers importable without it.
    load_file = None

# Suffixes of the side-car tensors that describe a quantized tensor. They are
# consumed while dequantizing their parent key and never copied to the output.
_METADATA_SUFFIXES = (
    ".int4_scale",
    ".int4_zero_point",
    ".int4_orig_shape",
    ".int8_scale",
    ".int8_zero_point",
    ".int8_orig_shape",
)


def _restore_shape(t_2d, orig_shape):
    """Undo the 2D working layout used during dequantization.

    Args:
        t_2d: Dequantized tensor in its 2D working layout.
        orig_shape: Sequence of ints giving the tensor's original shape.

    Returns:
        ``t_2d`` with its leading dimension squeezed when the original was 1D,
        ``t_2d`` unchanged when the original was 2D, otherwise ``t_2d``
        reshaped to ``orig_shape``.
    """
    shape = tuple(orig_shape)
    if len(shape) == 1:
        return t_2d.squeeze(0)
    if len(shape) == 2:
        return t_2d
    return t_2d.reshape(shape)


def _affine_dequantize(q_2d, scale, zero_point):
    """Return ``(q_2d - zero_point) * scale`` with one scale/zero point per row.

    ``zero_point`` may be ``None``, in which case zeros shaped like ``scale``
    are used.
    """
    if zero_point is None:
        zero_point = torch.zeros_like(scale)
    return (q_2d - zero_point.unsqueeze(1)) * scale.unsqueeze(1)


def _dequantize_int4_packed(key, packed, state_dict):
    """Unpack a uint8 tensor holding two 4-bit codes per byte and dequantize it.

    The low nibble of each byte comes first, followed by the high nibble, so
    every packed row expands to twice as many columns.
    """
    orig_shape = tuple(state_dict[f"{key}.int4_orig_shape"].tolist())
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    codes_2d = (
        torch.stack([low, high], dim=-1)
        .reshape(packed.shape[0], -1)
        .to(torch.float32)
    )
    deq_2d = _affine_dequantize(
        codes_2d,
        state_dict[f"{key}.int4_scale"],
        state_dict.get(f"{key}.int4_zero_point"),
    )
    return _restore_shape(deq_2d, orig_shape)


def _dequantize_int8(key, stored, state_dict):
    """Dequantize an int8 tensor using its per-row scale and zero point."""
    orig_shape = tuple(state_dict[f"{key}.int8_orig_shape"].tolist())
    # Bring the tensor into a (rows, cols) layout so the per-row scale and
    # zero point broadcast along dim 1.
    if len(orig_shape) == 1:
        stored_2d = stored.unsqueeze(0)
    elif len(orig_shape) == 2:
        stored_2d = stored.reshape(orig_shape)
    else:
        stored_2d = stored.reshape(orig_shape[0], -1)
    # Stored values are shifted by +128 before the zero point is applied
    # (presumably the writer offset unsigned codes into the int8 range --
    # confirm against the quantizer).
    codes_2d = stored_2d.to(torch.float32) + 128.0
    deq_2d = _affine_dequantize(
        codes_2d,
        state_dict[f"{key}.int8_scale"],
        state_dict.get(f"{key}.int8_zero_point"),
    )
    return _restore_shape(deq_2d, orig_shape)


def dequantize_state_dict(state_dict, torch_dtype=None):
    """Dequantize every INT4-packed / INT8 tensor in ``state_dict``.

    A tensor is treated as INT4-packed when its dtype is ``torch.uint8`` and a
    ``<key>.int4_scale`` entry exists, and as INT8 when its dtype is
    ``torch.int8`` and a ``<key>.int8_scale`` entry exists. Metadata entries
    (scale, zero point, original shape) are dropped from the result.

    Args:
        state_dict: Mapping of tensor name to tensor, including the metadata
            entries for quantized tensors.
        torch_dtype: Floating dtype for the output tensors. Defaults to
            ``torch.float32``.

    Returns:
        A tuple ``(result, count_int4p, count_int8)``: the dequantized mapping
        and the number of INT4-packed and INT8 tensors that were converted.

    Raises:
        KeyError: If a quantized tensor has no matching ``*_orig_shape`` entry.
    """
    if torch_dtype is None:
        torch_dtype = torch.float32

    result = {}
    count_int4p = 0
    count_int8 = 0

    for key, tensor in state_dict.items():
        if key.endswith(_METADATA_SUFFIXES):
            continue

        if tensor.dtype == torch.uint8 and f"{key}.int4_scale" in state_dict:
            result[key] = _dequantize_int4_packed(key, tensor, state_dict).to(torch_dtype)
            count_int4p += 1
        elif tensor.dtype == torch.int8 and f"{key}.int8_scale" in state_dict:
            result[key] = _dequantize_int8(key, tensor, state_dict).to(torch_dtype)
            count_int8 += 1
        elif tensor.is_floating_point():
            # Unquantized floating tensor: only the dtype changes.
            result[key] = tensor.to(torch_dtype)
        else:
            # Integer / bool tensors (e.g. index buffers) must keep their
            # dtype; casting them to a float type would corrupt large values.
            result[key] = tensor

    return result, count_int4p, count_int8


def load_mixed(path, torch_dtype=None):
    """Load a mixed INT8/INT4-packed safetensors checkpoint and dequantize it.

    Args:
        path: Either a checkpoint file, or a directory containing a
            ``*.mixed-int8-int4p.safetensors`` file. When several files match,
            the first one in sorted order is used.
        torch_dtype: Floating dtype for the output tensors. Defaults to
            ``torch.float32``.

    Returns:
        The dequantized state dict, loaded on CPU.

    Raises:
        FileNotFoundError: If ``path`` is a directory without a matching file.
        ImportError: If ``safetensors`` is not installed.
    """
    path = Path(path)
    if path.is_dir():
        # Sort so the choice is deterministic when several files match.
        candidates = sorted(path.glob("*.mixed-int8-int4p.safetensors"))
        if not candidates:
            raise FileNotFoundError(f"No .mixed-int8-int4p.safetensors in {path}")
        path = candidates[0]

    if load_file is None:
        raise ImportError("safetensors is required to load checkpoints")

    # Resolve the default here so the summary line reports the real dtype.
    if torch_dtype is None:
        torch_dtype = torch.float32

    sd = load_file(str(path), device="cpu")
    result, c_int4p, c_int8 = dequantize_state_dict(sd, torch_dtype)
    print(f"Dequantized: {c_int4p} INT4 packed + {c_int8} INT8 -> {torch_dtype}")
    return result


if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        sd = load_mixed(sys.argv[1])
        print(f"Loaded {len(sd)} tensors, {sum(t.numel() for t in sd.values())/1e9:.2f}B params")