parakeet-tdt-0.6b-v3-smoothquant-onnx / scripts /quantize-int8-smoothquant.py
thiswillbeyourgithub
smoothquant: fix istupakov int8 size note comparing output to itself
72a6e63
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "onnx",
# "onnxruntime",
# "onnx-neural-compressor",
# "numpy",
# "sympy",
# "prettytable",
# "psutil",
# "scipy",
# "loguru",
# ]
# ///
"""Export a *better* int8 Parakeet encoder using SmoothQuant static quantization.
Why this exists (see CLAUDE.md / ARCHITECTURE.md for the full story): the int8
encoder we currently ship (istupakov's) silently loses long-range information
past ~20 s within a single chunk, so the WASM backend is pinned to a 20 s chunk
window while fp16/fp32 happily run 60 s. Crucially the model architecture is NOT
the problem (fp16 holds flat at long windows); it is an int8 *numerics* problem:
a single per-tensor activation scale copes badly once a longer sequence widens
the activation distribution. That is exactly the regime SmoothQuant targets: it
migrates the per-channel activation outliers into the weights (a folded Mul),
then static-quantizes activations + per-channel weights. The bet is that a
SmoothQuant + per-channel int8 encoder degrades far less over a long chunk,
which would let WASM use the full 60 s window.
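
A minimal sketch of the smoothing step described above, assuming the usual
SmoothQuant per-channel scale s_j = max|X_j|^alpha / max|W_j|^(1 - alpha). The
names `smooth_scales`, `X` and `W` are illustrative only; this is not the code
path onnx-neural-compressor runs, just the arithmetic it is built around.

```python
import numpy as np

def smooth_scales(act_absmax, weight, alpha=0.5, eps=1e-8):
    # weight is laid out (in_channels, out_channels); one scale per in-channel.
    w_absmax = np.abs(weight).max(axis=1)
    return np.maximum(act_absmax, eps) ** alpha / np.maximum(w_absmax, eps) ** (1.0 - alpha)

rng = np.random.default_rng(0)
X = rng.normal(size=(16, 8))   # (tokens, in_channels) activations
X[:, 3] *= 50.0                # one outlier activation channel
W = rng.normal(size=(8, 4))    # (in_channels, out_channels) weights

s = smooth_scales(np.abs(X).max(axis=0), W, alpha=0.5)
X_s = X / s                    # activations become easier to quantize
W_s = W * s[:, None]           # the difficulty is folded into the weights

# The float product is unchanged; only the value ranges move.
print(np.allclose(X @ W, X_s @ W_s))
print(np.abs(X).max(), "->", np.abs(X_s).max())
```

With alpha = 0.5 the per-channel maxima of the smoothed activations and weights
meet in the middle, which is why a single activation scale then wastes far fewer
of its int8 levels on one outlier channel.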

This produces ONLY the encoder int8 (`encoder-model.int8.smoothquant.onnx`). The
decoder is tiny and is not where the long-range loss lives, so we deliberately
reuse istupakov's existing `decoder_joint-model.int8.onnx`; that isolates the
comparison to the encoder change.

Calibration data: SmoothQuant needs representative *activations*, not labels, so
any speech works. Pass clips or folders with --audio (a folder is expanded to
every audio file inside it); by default the ./calibration_audio folder is used.
Each clip is sliced into deliberately LONG windows (default 30 s) so the smoothing
scales are computed over the very long-range distribution we are trying to fix.
The encoder takes mel features, not raw audio, so each window is first run through
the `nemo128.onnx` preprocessor (raw waveform -> 128-bin mel features) and those
features are fed to the encoder, exactly as the real pipeline does.

After export, compare against fp16 with the per-section WER harness (wer-quants.py
lives in the parakeet_web repo, not this model repo:
https://github.com/thiswillbeyourgithub/parakeet_web):

    # the NEW SmoothQuant int8 (served from the symlinked candidate dir):
    uv run scripts/wer-quants.py --model-dir candidates --quants int8
    # the OLD istupakov int8 + the fp16 reference, for the baseline:
    uv run scripts/wer-quants.py --model-dir . --quants int8,fp16

Both use the same fp32 oracle reference, so a per-section WER that rises less
steeply for the new int8 (closer to fp16) is the win we are after. This script
prints those two commands at the end and, unless --no-candidate is passed, builds
the `candidates` symlink farm they need.

By default only MatMul ops are quantized (the conv subsampling front-end stays
fp32: it is quant-fragile and collapsed the encoder when quantized) and
activations are calibrated with the Percentile method (MinMax let a single
long-tail outlier crush the scale). A post-export fidelity check compares the new
encoder's output to the fp32 encoder by cosine similarity and warns loudly on a
likely collapse, instead of only checking output shape.
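
To illustrate why a single long-tail outlier hurts MinMax calibration, here is a
small NumPy sketch. It is an assumption-laden toy (a plain uint8 affine
quantizer and `np.percentile` clipping), not ONNX Runtime's calibrator, and the
helper name `quantize_uint8` is made up for this example.

```python
import numpy as np

def quantize_uint8(x, lo, hi):
    # Affine uint8 quantize -> dequantize over the calibrated range [lo, hi].
    scale = (hi - lo) / 255.0
    q = np.clip(np.round((x - lo) / scale), 0, 255)
    return q * scale + lo

rng = np.random.default_rng(0)
acts = rng.normal(size=100_000)
acts[0] = 1000.0  # a single long-tail outlier

# MinMax: the range is stretched to cover the outlier.
minmax = quantize_uint8(acts, acts.min(), acts.max())
# Percentile: the tail is clipped, so the bulk keeps its resolution.
lo, hi = np.percentile(acts, [0.001, 99.999])
pct = quantize_uint8(acts, lo, hi)

bulk = np.abs(acts) < 5  # everything except the outlier
err_minmax = np.abs(minmax - acts)[bulk].mean()
err_pct = np.abs(pct - acts)[bulk].mean()
print(err_minmax, err_pct)
```

The outlier itself is clipped under the percentile range, which is the trade
being made: one badly represented value in exchange for usable resolution on
everything else.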
Alpha defaults to "auto": the smoother searches a per-layer optimal alpha over a
[0,1] grid (minimising each layer's QDQ error vs fp32) instead of forcing one
global value onto FastConformer's uneven outlier profile. Same artifact size and
runtime as a fixed alpha, just better-placed smoothing; the only cost is a slower
export. Pass an explicit float to pin it (e.g. --alpha 0.5) and skip the search.
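
The per-layer search can be pictured with the sketch below. It assumes a
symmetric per-tensor fake-quantizer and a mean-squared-error score, both chosen
for brevity; the library's actual auto-alpha loop differs in detail, and
`fake_quant` / `search_alpha` are hypothetical names.

```python
import numpy as np

def fake_quant(x, bits=8):
    # Symmetric per-tensor quantize -> dequantize round trip.
    qmax = 2 ** (bits - 1) - 1
    scale = np.abs(x).max() / qmax
    return np.round(x / scale).clip(-qmax, qmax) * scale

def search_alpha(X, W, grid):
    # Score every alpha on the grid for ONE MatMul by comparing the
    # fake-quantized product against the float reference.
    ref = X @ W
    act_max = np.abs(X).max(axis=0)
    w_max = np.abs(W).max(axis=1)
    errors = {}
    for alpha in grid:
        s = act_max ** alpha / w_max ** (1.0 - alpha)
        out = fake_quant(X / s) @ fake_quant(W * s[:, None])
        errors[float(alpha)] = float(np.mean((out - ref) ** 2))
    best = min(errors, key=errors.get)
    return best, errors

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 16))
X[:, 5] *= 40.0  # outlier activation channel
W = rng.normal(size=(16, 8))
grid = np.round(np.arange(0.0, 1.0 + 1e-9, 0.1), 2)  # mirrors the [0,1] / 0.1 defaults
best_alpha, errors = search_alpha(X, W, grid)
print(best_alpha, errors[best_alpha])
```

Running this once per layer is what makes the export slower; the chosen alpha is
baked into the folded scales, so nothing changes at inference time.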

Usage (run from this model-repo root; the onnx files default to .):

    uv run scripts/quantize-int8-smoothquant.py  # auto-alpha, ./calibration_audio
    uv run scripts/quantize-int8-smoothquant.py --alpha 0.6  # pin a fixed alpha instead
    uv run scripts/quantize-int8-smoothquant.py --num-windows 32 --window-sec 30
    uv run scripts/quantize-int8-smoothquant.py --audio clips/ --audio extra.wav  # folder and/or files
    uv run scripts/quantize-int8-smoothquant.py --op-types MatMul,Conv  # also quantize convs
    uv run scripts/quantize-int8-smoothquant.py --calibrate-method entropy
    uv run scripts/quantize-int8-smoothquant.py --quant-format qdq

Built with Claude Code.
"""
import argparse
import os
import shutil
import subprocess
import sys
import time
from pathlib import Path
import numpy as np
import onnx
import onnxruntime as ort
from loguru import logger
from onnxruntime.quantization import CalibrationMethod, QuantFormat, QuantType
from onnx_neural_compressor import data_reader
from onnx_neural_compressor.quantization import config, quantize
from onnx_neural_compressor.algorithms.smoother import core as _sq_core
def configure_logging(log_file):
    """Send logs to stderr (human, colored) AND a local file (full, timestamped).

    Called once at the start of main() so every message after arg-parsing lands
    in both places. The file sink keeps a permanent record of an export run
    (calibration windows, alpha, fidelity cosine, the wer-quants commands) so a
    later A/B can be traced without re-running."""
    logger.remove()
    logger.add(
        sys.stderr,
        level="INFO",
        format="<green>{time:HH:mm:ss}</green> | <level>{level: <7}</level> | <level>{message}</level>",
    )
    logger.add(
        str(log_file),
        level="DEBUG",
        format="{time:YYYY-MM-DD HH:mm:ss} | {level: <7} | {message}",
        mode="a",  # append: one fixed file accumulates every run
        enqueue=True,
    )
    logger.info(f"[sq] logging to {log_file}")
# --- FastConformer compatibility shim for onnx-neural-compressor's SmoothQuant -
# The library's smoother hard-assumes a 3D activation is (batch, seq, in_channel)
# with the in-channel LAST (there is a literal TODO admitting this in
# Calibrator._get_max_per_channel). That holds for BERT-style graphs but NOT for
# a few FastConformer MatMuls (the relative-position attention projections, where
# the weight is the first operand and the activation contracts over the sequence
# axis). For those, the per-channel activation max is taken over the wrong axis
# and no longer matches the weight's in-channel length, so _get_smooth_scale dies
# broadcasting e.g. (101,) against (2048,).
#
# These two wrappers make the smoother SKIP exactly those unresolvable nodes
# (return None -> stripped before any Mul is inserted) instead of crashing. All
# the well-behaved linears (FFN, standard projections, the bulk of the weights)
# are still smoothed; the skipped handful simply fall through to plain static
# int8. _insert_smooth_mul_op iterates scales.keys() and _adjust_weights guards
# with `if key not in scales`, so omitting a node is safe. NOTE: this monkeypatch
# reaches into library internals and may need revisiting on a neural-compressor
# upgrade; it is contained to this experimental export script.
_SKIPPED = {"count": 0}
_orig_get_smooth_scale = _sq_core.Smoother._get_smooth_scale
_orig_get_smooth_scales = _sq_core.Smoother._get_smooth_scales


def _safe_get_smooth_scale(self, weights, specific_alpha, tensor):
    weights_max = np.amax(np.abs(weights.reshape(weights.shape[0], -1)), axis=-1)
    if self.max_vals_per_channel[tensor].shape != weights_max.shape:
        _SKIPPED["count"] += 1
        return None  # layout the per-channel logic can't resolve: don't smooth it
    return _orig_get_smooth_scale(self, weights, specific_alpha, tensor)


def _safe_get_smooth_scales(self, alpha, target_list=[]):
    scales = _orig_get_smooth_scales(self, alpha, target_list)
    return {k: v for k, v in scales.items() if v is not None}


_sq_core.Smoother._get_smooth_scale = _safe_get_smooth_scale
_sq_core.Smoother._get_smooth_scales = _safe_get_smooth_scales
# --- alpha (and auto-alpha / folding) pass-through fix ------------------------
# onnx-neural-compressor 1.0 SILENTLY DROPS the SmoothQuant* knobs you set in
# StaticQuantConfig.extra_options. quantize() hands the StaticQuantConfig to
# smooth_quant_entry(), which smooths via `Smoother.transform(**config.to_dict())`
# -- but StaticQuantConfig.to_dict() buries SmoothQuantAlpha / AutoAlphaArgs /
# SmoothQuantFolding inside a nested "extra_options" dict and never emits the
# top-level `alpha` / `auto_alpha_args` / `folding` kwargs that transform() reads.
# So transform() ALWAYS falls back to its hard-coded defaults (alpha=0.5, the
# [0.3,0.7] auto grid) regardless of --alpha. The intended path is the dedicated
# SmoothQuantConfig (whose to_dict() does surface those names), but quantize()
# never builds one. We wrap transform() to inject the values from extra_options
# under the names transform() expects. main() only ARMS this (populates
# _SMOOTH_OVERRIDE) after a runtime guard confirms the bug is still present, so a
# future fixed library that forwards alpha itself transparently disables the shim.
# Like the layout shim above, this reaches into library internals and may need
# revisiting on a neural-compressor upgrade; contained to this export script.
_SMOOTH_OVERRIDE = {}
_orig_transform = _sq_core.Smoother.transform


def _forced_transform(self, *args, **kwargs):
    if _SMOOTH_OVERRIDE:
        kwargs.update(_SMOOTH_OVERRIDE)
    return _orig_transform(self, *args, **kwargs)


_sq_core.Smoother.transform = _forced_transform
# Audio file extensions recognised when an --audio entry is a folder (or the
# default ./calibration_audio folder is scanned). SmoothQuant calibrates on
# activations (not labels), so any speech clip works.
AUDIO_EXTS = {".wav", ".mp3", ".flac", ".m4a", ".aac", ".ogg", ".opus", ".wma"}
# Default calibration source when --audio is omitted: a folder you drop clips in.
DEFAULT_CALIB_DIR = "calibration_audio"
SAMPLE_RATE = 16000
# Upstream istupakov int8 encoder size, for the post-export download-size note.
# Measured from HF on 2026-06-09:
# istupakov/parakeet-tdt-0.6b-v3-onnx / encoder-model.int8.onnx = 652,183,999 B.
# Hardcoded because this script's own output usually overwrites that filename in
# the model dir (when --out-name is the canonical encoder-model.int8.onnx), so the
# on-disk copy can't be stat'd as a baseline without reading our own output back.
# NOTE: istupakov also quantizes the convs (--op-types MatMul,Conv), which is why
# their encoder is smaller than this script's MatMul-only default.
ISTUPAKOV_INT8_ENCODER_BYTES = 652_183_999
def expand_audio(inputs):
    """Resolve --audio entries (files and/or folders) to a flat list of audio files.

    A directory entry is expanded to every audio file (by extension) directly
    inside it, so you can keep all calibration clips in one folder and just point
    --audio at the folder. File entries are passed through untouched (their
    existence is checked later in collect_windows)."""
    out = []
    for entry in inputs:
        p = Path(entry)
        if p.is_dir():
            out.extend(sorted(c for c in p.iterdir()
                              if c.is_file() and c.suffix.lower() in AUDIO_EXTS))
        else:
            out.append(p)
    return out
def human(n):
    # Format a byte count with binary (1024-based) units, capped at GB.
    for unit in ("B", "KB", "MB", "GB"):
        if n < 1024 or unit == "GB":
            return f"{n:.1f} {unit}"
        n /= 1024
def find_ffmpeg(explicit=None):
    cand = explicit or os.environ.get("FFMPEG") or shutil.which("ffmpeg")
    # Accept either a name resolvable on PATH or an explicit path to a binary.
    if not cand or (not shutil.which(cand) and not os.path.exists(cand)):
        logger.error("ffmpeg not found (set $FFMPEG or pass --ffmpeg).")
        sys.exit(1)
    return cand
def decode_pcm(ffmpeg, path):
    """Decode any audio file to mono 16 kHz float32 PCM via ffmpeg."""
    cmd = [ffmpeg, "-v", "error", "-i", str(path),
           "-f", "f32le", "-ac", "1", "-ar", str(SAMPLE_RATE), "-"]
    out = subprocess.run(cmd, capture_output=True)
    if out.returncode != 0:
        raise RuntimeError(f"ffmpeg failed on {path}: {out.stderr.decode()[-300:]}")
    return np.frombuffer(out.stdout, dtype=np.float32)
def collect_windows(ffmpeg, audio_paths, window_sec, num_windows):
    """Slice every available clip into non-overlapping FULL-length windows, then
    evenly subsample down to num_windows so calibration stays quick but diverse.

    All windows are exactly `win` samples long on purpose: SmoothQuant's
    calibrator np.stacks the per-op activations across calibration samples, so a
    variable-length tail window (different T -> different activation shape) makes
    it raise 'all input arrays must have the same shape'. We therefore drop any
    partial tail rather than pad it."""
    win = int(window_sec * SAMPLE_RATE)
    windows = []
    for p in audio_paths:
        if not Path(p).exists():
            continue
        pcm = decode_pcm(ffmpeg, p)
        n = len(pcm)
        count = 0
        start = 0
        while start + win <= n:
            windows.append(pcm[start:start + win])
            start += win
            count += 1
        logger.info(f" [calib] {Path(p).name}: {n / SAMPLE_RATE:.0f}s -> {count} full window(s)")
    if not windows:
        logger.error(f"No calibration audio yielded a full {window_sec:g}s window. "
                     f"Drop clips in ./{DEFAULT_CALIB_DIR}/, pass --audio <file-or-folder>, "
                     "or lower --window-sec.")
        sys.exit(1)
    if len(windows) > num_windows:
        # Even stride across the whole pool for speaker/content diversity.
        idx = np.linspace(0, len(windows) - 1, num_windows).round().astype(int)
        windows = [windows[i] for i in dict.fromkeys(idx)]
    return windows
def build_features(pre_path, windows):
    """Run each raw-audio window through nemo128.onnx -> encoder mel features.

    Precomputed once into memory so the calibration reader can rewind cheaply
    (the SmoothQuant, static min/max, and calibration passes each re-read it)."""
    sess = ort.InferenceSession(str(pre_path), providers=["CPUExecutionProvider"])
    feats = []
    for w in windows:
        wav = w.astype(np.float32)[None, :]
        lens = np.array([wav.shape[1]], dtype=np.int64)
        features, features_lens = sess.run(None, {"waveforms": wav, "waveforms_lens": lens})
        feats.append({
            "audio_signal": features.astype(np.float32),
            "length": features_lens.astype(np.int64),
        })
    return feats
class FeatureReader(data_reader.CalibrationDataReader):
    """Feeds the encoder its real (audio_signal, length) inputs for calibration."""

    def __init__(self, feats):
        self.feats = feats
        self.i = 0

    def get_next(self):
        if self.i >= len(self.feats):
            return None
        item = self.feats[self.i]
        self.i += 1
        return item

    def rewind(self):
        self.i = 0
def prune_unused_initializers(graph):
    """Drop top-level initializers that no node references, returning the count.

    SmoothQuant folding orphans the per-channel `*_smooth_scale` tensors: with
    SmoothQuantFolding the smoothing Mul is folded into the adjacent weights,
    leaving the standalone scale initializer unreferenced. ORT prunes these at
    session-load time (the noisy `CleanUnusedInitializersAndNodeArgs` warnings);
    doing it here bakes the cleanup into the saved file (quieter logs, a few KB
    smaller, identical numerics). Recurses into subgraph attributes (If/Loop/Scan
    bodies) so an initializer used only via an outer-scope reference is kept."""
    used = set()

    def walk(g):
        for node in g.node:
            used.update(node.input)
            for attr in node.attribute:
                if attr.type == onnx.AttributeProto.GRAPH:
                    walk(attr.g)
                elif attr.type == onnx.AttributeProto.GRAPHS:
                    for sub in attr.graphs:
                        walk(sub)

    walk(graph)
    keep = [init for init in graph.initializer if init.name in used]
    removed = len(graph.initializer) - len(keep)
    if removed:
        del graph.initializer[:]
        graph.initializer.extend(keep)
    return removed
def build_candidate_dir(model_dir, new_encoder, candidate_dir):
    """Symlink-farm a model dir where encoder-model.int8.onnx IS the new encoder,
    so wer-quants.py (which loads int8 by that canonical name via onnx-asr) serves
    the SmoothQuant encoder while reusing every other unchanged file."""
    model_dir = Path(model_dir).resolve()
    candidate_dir = Path(candidate_dir).resolve()
    candidate_dir.mkdir(parents=True, exist_ok=True)
    for f in model_dir.iterdir():
        if f.is_dir():
            continue
        link = candidate_dir / f.name
        if link.is_symlink() or link.exists():
            link.unlink()
        link.symlink_to(f.resolve())
    # Override the int8 encoder to point at the freshly exported SmoothQuant file.
    enc_link = candidate_dir / "encoder-model.int8.onnx"
    if enc_link.is_symlink() or enc_link.exists():
        enc_link.unlink()
    enc_link.symlink_to(Path(new_encoder).resolve())
    return candidate_dir
def main():
    ap = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--model-dir", default=".",
                    help="dir holding encoder-model.onnx (+.data) and nemo128.onnx "
                         "(default: current directory)")
    ap.add_argument("--out-name", default="encoder-model.int8.smoothquant.onnx",
                    help="output filename (written into --model-dir)")
    ap.add_argument("--candidate-dir", default="candidates",
                    help="symlink-farm dir wer-quants.py points at for the new int8 "
                         "(default: ./candidates)")
    ap.add_argument("--no-candidate", action="store_true",
                    help="skip building the wer-quants candidate symlink dir")
    ap.add_argument("--alpha", default="auto",
                    help="SmoothQuant alpha. Either a float in 0..1 (higher migrates more "
                         "difficulty to the weights, better for big activation outliers) or "
                         "'auto' (the default): the library searches a per-layer optimal alpha "
                         "over [--auto-alpha-min, --auto-alpha-max] in --auto-alpha-step steps, "
                         "minimising each layer's output error vs fp32. Slower to export but the "
                         "shipped model is the same size/speed with better-placed smoothing, "
                         "which is exactly what FastConformer's uneven outlier profile wants.")
    ap.add_argument("--auto-alpha-min", type=float, default=0.0,
                    help="auto-alpha search: lowest alpha tried (only used when --alpha=auto)")
    ap.add_argument("--auto-alpha-max", type=float, default=1.0,
                    help="auto-alpha search: highest alpha tried (only used when --alpha=auto)")
    ap.add_argument("--auto-alpha-step", type=float, default=0.1,
                    help="auto-alpha search: grid step (only used when --alpha=auto)")
    ap.add_argument("--num-windows", type=int, default=24,
                    help="max calibration windows (evenly sampled across all audio)")
    ap.add_argument("--window-sec", type=float, default=30.0,
                    help="calibration window length; long on purpose (the bug is long-range)")
    ap.add_argument("--audio", action="append", default=None,
                    help="calibration audio file OR folder; repeatable. A folder is "
                         f"expanded to every audio file inside it. Default: ./{DEFAULT_CALIB_DIR}")
    ap.add_argument("--quant-format", choices=["qoperator", "qdq"], default="qoperator",
                    help="QOperator (QLinear* ops, matches the shipped int8) or QDQ")
    ap.add_argument("--op-types", default="MatMul",
                    help="comma-separated op types to quantize. Default MatMul ONLY: the "
                         "conv subsampling front-end is quant-fragile and is the prime suspect "
                         "for a collapsed encoder, so convs stay fp32. Pass 'MatMul,Conv' to "
                         "also quantize convs (matches istupakov's scope).")
    ap.add_argument("--calibrate-method", choices=["minmax", "entropy", "percentile"],
                    default="percentile",
                    help="static activation calibration. MinMax (the library default) lets a "
                         "single long-tail outlier crush the scale and can collapse the encoder; "
                         "percentile/entropy clip the tail and are far more robust here.")
    ap.add_argument("--fidelity-warn", type=float, default=0.90,
                    help="cosine-similarity floor (vs the fp32 encoder, one window) below which "
                         "the export is flagged as a likely collapse before any WER run. This is "
                         "a COLLAPSE detector, not a quality score: a healthy MatMul-only export "
                         "measured ~0.96 cosine yet tracked fp16 WER (10.9%% vs 10.2%%), so the "
                         "floor sits well below that. A true collapse lands far lower.")
    ap.add_argument("--ffmpeg", default=None, help="ffmpeg binary (else $FFMPEG / PATH)")
    ap.add_argument("--log-file", default="quantize-int8-smoothquant.log",
                    help="path for the run log file, appended to across runs (logs always "
                         "also go to stderr). Default: ./quantize-int8-smoothquant.log")
    args = ap.parse_args()
    configure_logging(args.log_file)
    model_dir = Path(args.model_dir)
    in_encoder = model_dir / "encoder-model.onnx"
    pre_path = model_dir / "nemo128.onnx"
    out_encoder = model_dir / args.out_name
    for p in (in_encoder, pre_path):
        if not p.exists():
            logger.error(f"missing required file: {p}")
            sys.exit(1)
    # alpha is either the literal "auto" (per-layer search) or a fixed float.
    if str(args.alpha).strip().lower() == "auto":
        alpha = "auto"
    else:
        try:
            alpha = float(args.alpha)
        except ValueError:
            logger.error(f"--alpha must be a float in 0..1 or 'auto', got {args.alpha!r}")
            sys.exit(1)
    ffmpeg = find_ffmpeg(args.ffmpeg)
    audio = expand_audio(args.audio or [DEFAULT_CALIB_DIR])
    logger.info(f"[sq] calibration: up to {args.num_windows} x {args.window_sec:g}s windows")
    windows = collect_windows(ffmpeg, audio, args.window_sec, args.num_windows)
    logger.info(f"[sq] using {len(windows)} calibration window(s); extracting mel features...")
    feats = build_features(pre_path, windows)
    fmt = QuantFormat.QOperator if args.quant_format == "qoperator" else QuantFormat.QDQ
    calib = {"minmax": CalibrationMethod.MinMax,
             "entropy": CalibrationMethod.Entropy,
             "percentile": CalibrationMethod.Percentile}[args.calibrate_method]
    op_types = [t.strip() for t in args.op_types.split(",") if t.strip()]
    cfg = config.StaticQuantConfig(
        calibration_data_reader=FeatureReader(feats),
        quant_format=fmt,
        calibrate_method=calib,
        activation_type=QuantType.QUInt8,
        weight_type=QuantType.QInt8,
        # Which weight-bearing ops to quantize. Default is MatMul ONLY: the conv
        # subsampling front-end (pre_encode.*) sees the raw mel features with a
        # wide dynamic range and is notoriously quant-fragile; statically
        # quantizing it can produce garbage that propagates and empties the
        # transcript, so we leave all convs fp32 (the user is fine trading the
        # extra size for safety). MatMul-only also dodges the static quantizer's
        # Pad handler, which trips on FastConformer's optional/empty Pad inputs
        # ("Quantization parameters are not specified for param .").
        op_types_to_quantize=op_types,
        per_channel=True,  # the other half of the fix: per-channel weights
        reduce_range=True,  # recommended on non-VNNI CPUs (the WASM target)
        use_external_data_format=False,  # int8 encoder ~600 MB, fits a single file
        calibration_sampling_size=len(feats),
        execution_provider="CPUExecutionProvider",
        extra_options={
            "SmoothQuant": True,
            # alpha="auto" makes the smoother search a per-layer optimal alpha
            # (minimising each layer's QDQ output error vs fp32) instead of forcing
            # one global value onto FastConformer's very uneven outlier profile.
            "SmoothQuantAlpha": alpha,
            "SmoothQuantFolding": True,
            # Only consulted when alpha=="auto"; ignored for a fixed float.
            "AutoAlphaArgs": {
                "alpha_min": args.auto_alpha_min,
                "alpha_max": args.auto_alpha_max,
                "alpha_step": args.auto_alpha_step,
                "attn_method": "max",
            },
        },
    )
    # Arm the alpha pass-through shim (see _forced_transform): translate the
    # SmoothQuant* keys we set in extra_options into the top-level transform()
    # kwargs the library forgets to forward. Guard: if a future neural-compressor
    # fixes this, cfg.to_dict() will surface a top-level "alpha" on its own, so we
    # leave the shim disarmed (and say so) rather than double-driving transform().
    eo = cfg.extra_options
    if "alpha" in cfg.to_dict():
        logger.info("[sq] note: onnx-neural-compressor now forwards SmoothQuantAlpha "
                    "itself; the alpha pass-through shim is obsolete and stays disarmed.")
    else:
        _SMOOTH_OVERRIDE.update({
            "alpha": eo["SmoothQuantAlpha"],
            "folding": eo["SmoothQuantFolding"],
            "auto_alpha_args": eo["AutoAlphaArgs"],
        })
    logger.info(f"[sq] SmoothQuant(alpha={alpha}) static int8, per-channel, "
                f"calib={args.calibrate_method}, ops={op_types}, format={args.quant_format} ...")
    # The fp32 weights normally live in an external `.data` sidecar; only add its
    # size when it exists so a single-file fp32 export does not crash the log line.
    fp32_bytes = os.path.getsize(in_encoder)
    fp32_sidecar = Path(str(in_encoder) + ".data")
    if fp32_sidecar.exists():
        fp32_bytes += os.path.getsize(fp32_sidecar)
    logger.info(f"[sq] {human(fp32_bytes)} fp32 encoder")
    t0 = time.time()
    # ORT_DISABLE_ALL skips neural-compressor's pre-optimization InferenceSession
    # (which has a `provides=` kwarg typo that crashes on this version) and avoids
    # re-serializing the 2.4 GB fp32 graph.
    quantize(str(in_encoder), str(out_encoder), cfg,
             optimization_level=ort.GraphOptimizationLevel.ORT_DISABLE_ALL)
    dt = time.time() - t0
    if _SKIPPED["count"]:
        logger.info(f"[sq] note: {_SKIPPED['count']} node(s) had a layout SmoothQuant could not "
                    f"resolve and were left as plain static int8 (everything else was smoothed)")
    # neural-compressor always writes the quantized weights to an external
    # `<name>_data` sidecar for a model this size, ignoring use_external_data_format.
    # The int8 weights are ~620 MB, well under the 2 GB single-protobuf cap, so
    # fold them back into ONE self-contained .onnx (matching the shipped
    # single-file int8 and keeping the candidate symlink dir trivial). While the
    # graph is loaded we also strip the orphaned folded-smooth-scale initializers
    # (see prune_unused_initializers) so the saved file no longer logs ORT's
    # CleanUnusedInitializersAndNodeArgs warnings on every load.
    sidecar = str(out_encoder) + "_data"
    merged = onnx.load(str(out_encoder), load_external_data=True)
    pruned = prune_unused_initializers(merged.graph)
    onnx.save(merged, str(out_encoder), save_as_external_data=False)
    if os.path.exists(sidecar):
        os.remove(sidecar)
    if pruned:
        logger.info(f"[sq] pruned {pruned} orphaned initializer(s) (folded smooth scales)")
    out_size = os.path.getsize(out_encoder)
    # Download-size comparison vs the upstream istupakov int8 encoder. Only stat an
    # on-disk istupakov file when it is a DIFFERENT path than our output: when
    # --out-name is the canonical encoder-model.int8.onnx, out_encoder overwrites
    # that file, so stat'ing it would read our own output back and print the
    # tautology "X (istupakov int8 is X)". Otherwise fall back to the HF size.
    baseline = model_dir / "encoder-model.int8.onnx"
    if baseline.exists() and baseline.resolve() != out_encoder.resolve():
        base_bytes = os.path.getsize(baseline)
    else:
        base_bytes = ISTUPAKOV_INT8_ENCODER_BYTES
    base_note = f" (istupakov int8 is {human(base_bytes)})"
    logger.info(f"[sq] done in {dt:.0f}s -> {out_encoder.name} {human(out_size)}{base_note}")
    # Fidelity smoke test (NOT just shape): run one calibration window through both
    # the fp32 reference and the new int8 encoder and compare the encoder outputs by
    # cosine similarity. A shape-only check let a fully collapsed encoder (empty
    # transcript everywhere) pass silently once; this catches that in ~30 s instead
    # of after a multi-minute WER run. A healthy MatMul-only export measured ~0.96
    # (see --fidelity-warn); a true collapse lands far below the warning floor.
    try:
        inp = {"audio_signal": feats[0]["audio_signal"], "length": feats[0]["length"]}
        s_q = ort.InferenceSession(str(out_encoder), providers=["CPUExecutionProvider"])
        out_q = s_q.run(None, inp)[0].astype(np.float64).ravel()
        s_f = ort.InferenceSession(str(in_encoder), providers=["CPUExecutionProvider"])
        out_f = s_f.run(None, inp)[0].astype(np.float64).ravel()
        denom = (np.linalg.norm(out_q) * np.linalg.norm(out_f)) or 1.0
        cos = float(np.dot(out_q, out_f) / denom)
        if cos < args.fidelity_warn:
            logger.warning(f"[sq] encoder-output cosine vs fp32 is {cos:.4f} "
                           f"(< {args.fidelity_warn}). This export likely COLLAPSED; expect a near-100% "
                           f"WER. Try a different --calibrate-method/--alpha or keep more ops fp32.")
        else:
            logger.info(f"[sq] fidelity: encoder-output cosine vs fp32 = {cos:.4f} (>= "
                        f"{args.fidelity_warn}). Looks healthy.")
    except Exception as e:
        logger.warning(f"[sq] exported encoder failed the fidelity smoke test: {e}")
    if not args.no_candidate:
        cand = build_candidate_dir(model_dir, out_encoder, args.candidate_dir)
        logger.info(f"[sq] candidate model dir (for wer-quants): {cand}")
    rel_cand = os.path.relpath(args.candidate_dir)
    rel_model = os.path.relpath(model_dir)
    logger.info("Compare per-section degradation vs fp16 (wer-quants.py lives in the "
                "parakeet_web repo: https://github.com/thiswillbeyourgithub/parakeet_web):")
    logger.info(f" uv run scripts/wer-quants.py --model-dir {rel_cand} --quants int8")
    logger.info(f" uv run scripts/wer-quants.py --model-dir {rel_model} --quants int8,fp16")
    logger.info("A new-int8 per-section WER that tracks fp16 (instead of climbing) is the win.")


if __name__ == "__main__":
    main()