# Copyright 2026 Sam McLeod # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Dynamic INT8 (weights-only) quantiser for the Granite Speech 4.1 ONNX exports. Wraps `onnxruntime.quantization.quantize_dynamic` with the conventions used by the Granite Speech ONNX bundles: - Single external-data sidecar per graph (mirrors the FP32 export layout). - Pure `ai.onnx` opset 20 / IR 10. The default operator set is restricted to `MatMul` so the dynamic quantiser emits `MatMulInteger` (standard `ai.onnx`) rather than the `com.microsoft.Attention` / `com.microsoft.EmbedLayerNormalization` quantised variants. Override at your own risk - those domain ops are forbidden by the parakeet-rs consumer contract. - `per_channel=True` and `weight_type=QInt8` by default (better accuracy on the LLM weight tensors with no measurable speed cost on arm64 / x86 CPU EP). The script is self-contained (no project-internal imports) so it ships inside each Hugging Face bundle alongside the export script. Usage: python quantise.py --input PATH --output PATH \\ [--per-channel | --no-per-channel] \\ [--reduce-range] \\ [--weight-type qint8|quint8] \\ [--op-types MatMul,Gemm] \\ [--exclude-pattern REGEX] \\ [--exclude-nodes NODE1,NODE2] Examples: # Quantise the NAR editor with defaults. python quantise.py \\ --input exports/granite-speech-4.1-2b-nar/editor.onnx \\ --output exports/granite-speech-4.1-2b-nar/editor_int8.onnx # Skip the lm_head MatMul if it hurts parity. python quantise.py \\ --input exports/granite-speech-4.1-2b-nar/editor.onnx \\ --output exports/granite-speech-4.1-2b-nar/editor_int8.onnx \\ --exclude-nodes /lm_head/MatMul """ from __future__ import annotations import argparse import re import sys import tempfile import time from pathlib import Path import onnx from onnxruntime.quantization import QuantType, quantize_dynamic WEIGHT_TYPE_MAP = { "qint8": QuantType.QInt8, "quint8": QuantType.QUInt8, } def parse_args(argv: list[str] | None = None) -> argparse.Namespace: p = argparse.ArgumentParser( description="Dynamic INT8 (weights-only) ONNX quantiser for Granite Speech 4.1 graphs.", ) p.add_argument( "--input", required=True, type=Path, help="Path to the FP32 .onnx graph (external sidecar must sit alongside it).", ) p.add_argument( "--output", required=True, type=Path, help="Destination .onnx path. A single sidecar named _data is written next to it.", ) p.add_argument( "--per-channel", dest="per_channel", action="store_true", default=True, help="Quantise weights per output channel (default: on).", ) p.add_argument( "--no-per-channel", dest="per_channel", action="store_false", help="Disable per-channel quantisation.", ) p.add_argument( "--reduce-range", action="store_true", default=False, help="Quantise to 7 bits instead of 8. Improves accuracy on non-VNNI hardware " "but reduces the quantisation gain. Off by default.", ) p.add_argument( "--weight-type", choices=sorted(WEIGHT_TYPE_MAP.keys()), default="qint8", help="Weight quantisation dtype (default: qint8).", ) p.add_argument( "--op-types", default="MatMul", help=( "Comma-separated op types to quantise. Default: 'MatMul' (emits " "MatMulInteger only, all ai.onnx). Adding 'Conv' enables ConvInteger " "for the Conformer encoder's depthwise convolutions; this shrinks the " "encoder INT8 sidecar by ~40 percent but on this model family feeds " "enough weight-quant noise into the LLM head that it flips " "capitalisation and drops sentence-final punctuation on short clips - " "see task 17 in dev-plan.md. MatMul-only is the validated default. " "Adding 'Attention' or 'EmbedLayerNormalization' would introduce " "com.microsoft domain ops, which are forbidden by the parakeet-rs " "contract." ), ) p.add_argument( "--exclude-pattern", default=None, help="Regex applied to ONNX node names. Matching nodes are excluded from " "quantisation. Useful for skipping e.g. lm_head if its quantisation " "breaks parity.", ) p.add_argument( "--exclude-nodes", default="", help="Explicit comma-separated list of node names to exclude from quantisation.", ) p.add_argument( "--ir-version", type=int, default=10, help="ONNX IR version to write (default: 10, matches the FP32 exports).", ) return p.parse_args(argv) def collect_excluded_nodes( input_path: Path, exclude_pattern: str | None, exclude_nodes: list[str], ) -> list[str]: """Resolve --exclude-pattern against the FP32 graph's node names and merge with the explicit --exclude-nodes list. Loaded without external data so we only touch the small graph proto. """ excluded = set(n for n in exclude_nodes if n) if exclude_pattern: rx = re.compile(exclude_pattern) proto = onnx.load(str(input_path), load_external_data=False) for node in proto.graph.node: if node.name and rx.search(node.name): excluded.add(node.name) return sorted(excluded) def assert_pure_ai_onnx(model_path: Path) -> list[str]: """Reload the produced graph and verify only `ai.onnx` nodes are present. Returns the sorted list of domains for reporting. """ proto = onnx.load(str(model_path), load_external_data=False) domains = sorted({(n.domain or "ai.onnx") for n in proto.graph.node}) forbidden = [d for d in domains if d not in ("ai.onnx", "")] if forbidden: raise RuntimeError( f"Quantised graph contains forbidden op domains {forbidden}. " "Re-run with a narrower --op-types list." ) return domains def consolidate_single_sidecar( quantised_in: Path, final_out: Path, ir_version: int, ) -> None: """The dynamic quantiser may scatter weights across multiple external-data files. Reload + resave through a tempdir to land on the single-sidecar layout that matches the FP32 exports. """ print(" consolidating to single .onnx_data sidecar") proto = onnx.load(str(quantised_in), load_external_data=True) if proto.ir_version < ir_version: proto.ir_version = ir_version for tensor in proto.graph.initializer: tensor.ClearField("data_location") tensor.ClearField("external_data") sidecar_name = final_out.name + "_data" if (final_out.parent / sidecar_name).exists(): (final_out.parent / sidecar_name).unlink() if final_out.exists(): final_out.unlink() final_out.parent.mkdir(parents=True, exist_ok=True) onnx.save_model( proto, str(final_out), save_as_external_data=True, all_tensors_to_one_file=True, location=sidecar_name, size_threshold=1024, convert_attribute=False, ) onnx.checker.check_model(str(final_out), full_check=False) def quantise_graph(args: argparse.Namespace) -> None: input_path: Path = args.input.resolve() output_path: Path = args.output.resolve() if not input_path.exists(): raise SystemExit(f"input not found: {input_path}") op_types = [s.strip() for s in args.op_types.split(",") if s.strip()] explicit_excludes = [s.strip() for s in args.exclude_nodes.split(",") if s.strip()] excluded = collect_excluded_nodes(input_path, args.exclude_pattern, explicit_excludes) weight_type = WEIGHT_TYPE_MAP[args.weight_type] print(f"input: {input_path}") print(f"output: {output_path}") print(f"op_types: {op_types}") print(f"per_channel: {args.per_channel}") print(f"reduce_range: {args.reduce_range}") print(f"weight_type: {args.weight_type}") if excluded: print(f"excluded nodes ({len(excluded)}): {excluded}") else: print("excluded nodes: (none)") fp32_size = input_path.stat().st_size sidecar = input_path.with_name(input_path.name + "_data") fp32_data_size = sidecar.stat().st_size if sidecar.exists() else 0 print( f" fp32 graph={fp32_size / 1e6:.2f} MB " f"sidecar={fp32_data_size / 1e9:.2f} GB" ) with tempfile.TemporaryDirectory(prefix="quantise_int8_") as scratch_dir: scratch_path = Path(scratch_dir) / output_path.name t0 = time.time() quantize_dynamic( model_input=input_path, model_output=scratch_path, op_types_to_quantize=op_types, per_channel=args.per_channel, reduce_range=args.reduce_range, weight_type=weight_type, nodes_to_exclude=excluded or None, use_external_data_format=True, ) print(f" quantize_dynamic done in {time.time() - t0:.1f}s") # Stage 2: consolidate any scattered external-data files into a single # sidecar at the final destination. consolidate_single_sidecar(scratch_path, output_path, args.ir_version) # Verify pure ai.onnx after the move. domains = assert_pure_ai_onnx(output_path) int8_size = output_path.stat().st_size int8_data = output_path.with_name(output_path.name + "_data") int8_data_size = int8_data.stat().st_size if int8_data.exists() else 0 print( f" saved {output_path} (+ {int8_data.name}) " f"graph={int8_size / 1e6:.2f} MB sidecar={int8_data_size / 1e9:.2f} GB" ) print(f" node-domains={domains}") if fp32_data_size > 0: ratio = int8_data_size / fp32_data_size print(f" sidecar size ratio (int8 / fp32) = {ratio:.3f}") def main(argv: list[str] | None = None) -> None: args = parse_args(argv) try: quantise_graph(args) except RuntimeError as exc: print(f"FAIL: {exc}", file=sys.stderr) raise SystemExit(2) from exc if __name__ == "__main__": main()