# Copyright 2026 Sam McLeod # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Export Granite Speech 4.1 2b (autoregressive variants) to three ONNX graphs. Covers both `granite-speech-4.1-2b` (base) and `granite-speech-4.1-2b-plus`. The two share architecture - Conformer encoder + Blip2 Q-Former projector + Granite-4.0 1B causal LM with `logits_scaling=8` - and only differ in weights and chat template. Pass `--model-dir` and `--baseline` to select the variant. The NAR variant has a different topology and is exported by the `export_nar_*.py` scripts instead. Produces, under the configured `--out-dir`: - encoder.onnx : Conformer CTC encoder + Blip2 Q-Former projector. Input: input_features float32 [B, T, 160] Output: audio_embeds float32 [B, T_audio, 2048] audio_embed_sizes int64 [B] (per-sample valid lengths) - prompt_encode.onnx : LLM prefill over a fully spliced inputs_embeds. Inputs : inputs_embeds float32 [B, N, 2048] position_ids int64 [B, N] attention_mask float32 [B, 1, N, N] (additive) Outputs: logits float32 [B, N, V] (divided by 8) present..{key,value} for L in 0..39 - decode_step.onnx : Single-token decode with KV cache. Inputs : inputs_embeds float32 [B, 1, 2048] position_ids int64 [B, 1] attention_mask float32 [B, 1, 1, T_total] (additive) past_key_values..{key,value} for L in 0..39 Outputs: logits float32 [B, 1, V] (divided by 8) present..{key,value} The base/plus projector is `Blip2QFormerModel`, not the NAR custom projector. Q-Former self-attention is plain matmul-softmax already (Bert-style); only the Conformer encoder's SDPA + `if remainder > 0` guard need rewriting for clean tracing. Both LLM graphs apply `logits / config.text_config.logits_scaling` (=8). This matches `GraniteForCausalLM.forward`, which the reference autoregressive path goes through. Without it, ONNX logits are 8x the PyTorch reference even though argmax is preserved, which trips strict numeric parity bars. Usage: # Base 2b (defaults): HF_HOME=$TMPDIR/hf_home HF_MODULES_CACHE=$TMPDIR/hf_modules \\ uv run python src/export_speech_2b_ar.py # Plus 2b: HF_HOME=$TMPDIR/hf_home HF_MODULES_CACHE=$TMPDIR/hf_modules \\ uv run python src/export_speech_2b_ar.py \\ --model-dir models/granite-speech-4.1-2b-plus \\ --baseline test_data/baselines/plus.json \\ --out-dir exports/granite-speech-4.1-2b-plus # Just one stage: uv run python src/export_speech_2b_ar.py --stages encoder uv run python src/export_speech_2b_ar.py --stages prompt,decode --skip-export """ from __future__ import annotations import argparse import json import os import tempfile import time from pathlib import Path from typing import Any import numpy as np import soundfile as sf import torch import torch.nn as nn import torch.nn.functional as F # Resolve roots so the script works whether it lives at /src/.py # (project layout) or /.py (HF bundle layout). Defaults exist for # the project layout; bundle users should pass explicit --audio / --baseline / # --model-dir / --out-dir. SCRIPT_DIR = Path(__file__).resolve().parent REPO_ROOT = SCRIPT_DIR.parent if SCRIPT_DIR.name == "src" else SCRIPT_DIR DEFAULT_AUDIO = REPO_ROOT / "test_data" / "10226_10111_000000.wav" DEFAULT_BASELINE = REPO_ROOT / "test_data" / "baselines" / "base.json" DEFAULT_MODEL_DIR = REPO_ROOT / "models" / "granite-speech-4.1-2b" DEFAULT_OUT_DIR = REPO_ROOT / "exports" / "granite-speech-4.1-2b" USER_PROMPT_TRANSCRIBE = ( "<|audio|>transcribe the speech with proper punctuation and capitalization." ) # --------------------------------------------------------------------------- # Utilities. # --------------------------------------------------------------------------- def load_audio(path: Path) -> np.ndarray: waveform, sr = sf.read(str(path), dtype="float32") if waveform.ndim > 1: waveform = waveform.mean(axis=1) assert sr == 16000, f"expected 16 kHz, got {sr}" return waveform def tensor_stats(t: torch.Tensor | np.ndarray | None) -> dict[str, Any] | None: if t is None: return None if isinstance(t, torch.Tensor): x = t.detach().float().cpu().numpy() dtype_str = str(t.dtype).replace("torch.", "") else: x = np.asarray(t).astype(np.float32, copy=False) dtype_str = str(t.dtype) flat = x.flatten() return { "shape": list(x.shape), "dtype": dtype_str, "mean": float(flat.mean()) if flat.size else None, "std": float(flat.std()) if flat.size else None, "min": float(flat.min()) if flat.size else None, "max": float(flat.max()) if flat.size else None, "first10": [float(v) for v in flat[:10]], } def _resave_single_sidecar(scratch_path: Path, out_path: Path, ir_version: int) -> None: """Stage 2 of every export: re-save with one external-data sidecar in the final location so we end up with exactly two artefacts on disk.""" import onnx print(" stage-2: re-saving with single .onnx_data sidecar + ir bump") model_proto = onnx.load(str(scratch_path), load_external_data=True) if model_proto.ir_version < ir_version: model_proto.ir_version = ir_version for tensor in model_proto.graph.initializer: tensor.ClearField("data_location") tensor.ClearField("external_data") sidecar_name = out_path.name + "_data" if (out_path.parent / sidecar_name).exists(): (out_path.parent / sidecar_name).unlink() if out_path.exists(): out_path.unlink() onnx.save_model( model_proto, str(out_path), save_as_external_data=True, all_tensors_to_one_file=True, location=sidecar_name, size_threshold=1024, convert_attribute=False, ) onnx.checker.check_model(str(out_path), full_check=False) domains = sorted({n.domain for n in model_proto.graph.node}) print(f" saved {out_path} (+ {sidecar_name}) node-domains={domains}") # --------------------------------------------------------------------------- # Model loading (mirrors capture_baselines.py::capture_base_or_plus). # --------------------------------------------------------------------------- def load_base_model(model_dir: Path) -> tuple[nn.Module, Any]: from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor print(f" loading processor from {model_dir}") processor = AutoProcessor.from_pretrained(str(model_dir)) print(f" loading model from {model_dir} (eager, fp32)") t0 = time.time() # Blip2QFormerModel does not support SDPA in transformers 5.8; eager is mandatory. model = AutoModelForSpeechSeq2Seq.from_pretrained( str(model_dir), torch_dtype=torch.float32, attn_implementation="eager", ) model.eval() # The nested text_config / encoder_config / projector_config can carry # `dtype: bfloat16`; force fp32 across the whole module tree. model = model.to(torch.float32) print(f" loaded in {time.time() - t0:.1f}s") return model, processor # --------------------------------------------------------------------------- # Trace-friendly monkey-patches for the Conformer encoder. # --------------------------------------------------------------------------- def patch_conformer_for_tracing(model: nn.Module) -> None: """Rewrite the in-tree GraniteSpeechConformerAttention.forward so it traces: - SDPA -> plain matmul/softmax. - `if remainder > 0` guard -> always-pad by `(-num_features) % context_size`. Blip2QFormerMultiHeadAttention is already plain matmul-softmax (Bert-style), so no rewrite is needed for the projector's self-attention path. The projector's outer reshape/pad math is handled separately by patch_projector_for_tracing because it bakes T_audio into the graph if left as upstream's `math.ceil(seq_len / window_size)` pattern. """ encoder = model.encoder attn0 = encoder.layers[0].attn attn_cls = type(attn0) def attn_forward(self, hidden_states: torch.Tensor, attention_dists: torch.Tensor) -> torch.Tensor: hidden_states = self.pre_norm(hidden_states) bsz, num_features, _ = hidden_states.shape # Always-pad: pad amount may be zero. Use modulo so the graph is valid # for any T at runtime. pad_amount = (-num_features) % self.context_size num_blocks = (num_features + self.context_size - 1) // self.context_size hidden_states = F.pad(hidden_states, (0, 0, 0, pad_amount)) query_states = self.to_q(hidden_states) key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1) query_states = query_states.reshape( bsz, num_blocks, self.context_size, self.num_heads, -1 ).transpose(2, 3) key_states = key_states.reshape( bsz, num_blocks, self.context_size, self.num_heads, -1 ).transpose(2, 3) value_states = value_states.reshape( bsz, num_blocks, self.context_size, self.num_heads, -1 ).transpose(2, 3) # Shaw's relative positional embedding. rel_pos_emb = self.rel_pos_emb(attention_dists) # query_states: [B, M, H, C, D]; rel_pos_emb: [C, R, D] # Output: [B, M, H, C, R] pos_attn = torch.einsum( "b m h c d, c r d -> b m h c r", query_states, rel_pos_emb ) * self.scale # Plain matmul attention with the additive `pos_attn` bias inside the # softmax (matches the MATH SDPA backend numerically). attn_logits = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.scale attn_logits = attn_logits + pos_attn attn_weights = torch.softmax(attn_logits, dim=-1) out = torch.matmul(attn_weights, value_states) # [B, M, H, C, D] out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1) out = self.to_out(out[:, :num_features, :]) return self.dropout(out) attn_cls.forward = attn_forward def patch_projector_for_tracing(model: nn.Module) -> None: """Rewrite GraniteSpeechEncoderProjector.forward so the output time dimension (T_audio = nblocks * num_queries) stays dynamic in the exported graph. The upstream forward bakes T_audio because: 1. `seq_len = hidden_states.size(1)` is a Python int under TorchScript trace 2. `math.ceil(seq_len / self.window_size)` is Python int math, baked 3. The intermediate `.view(batch * nblocks, window_size, dim)` and final `.view(batch, nblocks * window_size // downsample_rate, -1)` both emit Reshape ops with a constant shape vector The rewrite uses `torch._shape_as_tensor` for dynamic shape access, an over-pad-then-tensor-slice idiom for the F.pad step, and `-1` for the intermediate batch*nblocks dim and the final T_audio dim. Batch is still baked at trace value (1) because reshape's target shape is a constant vector and we don't support multi-batch inference; T_audio is the audio-length-dependent dim that needs to be dynamic. """ projector = model.projector projector_cls = type(projector) def projector_forward_traceable(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size = hidden_states.shape[0] # static (B=1 at trace) dim = hidden_states.shape[2] # static encoder hidden_dim window_size = self.window_size # Dynamic seq_len via Shape op (emitted by torch._shape_as_tensor): shape_t = torch._shape_as_tensor(hidden_states) seq_len_t = shape_t[1] # 0-d int64 Tensor # nblocks * window_size = the padded length we want. nblocks_t = (seq_len_t + window_size - 1) // window_size final_len_t = nblocks_t * window_size # 0-d Tensor # Statically pad by (window_size - 1), the maximum pad ever needed, # then dynamically slice down to final_len_t. Avoids needing F.pad # with a tensor pad amount (which doesn't trace cleanly). hidden_states = nn.functional.pad( hidden_states, (0, 0, 0, window_size - 1), "constant", 0.0 ) hidden_states = hidden_states[:, :final_len_t, :] # [B, nblocks*window_size, dim] -> [B*nblocks, window_size, dim]. # `-1` lets ONNX infer batch*nblocks from numel at runtime. hidden_states = hidden_states.reshape(-1, window_size, dim) # Build an explicit all-ones encoder_attention_mask. Without this, the # QFormer auto-creates one via `torch.ones(encoder_hidden_states.size())` # which under tracing bakes batch*nblocks at the trace input's value. # `torch.ones_like` on a slice that drops the hidden dim keeps the # mask shape dynamic ([batch*nblocks, window_size]). encoder_attention_mask = torch.ones_like(hidden_states[..., 0]) query_output = self.qformer( query_embeds=self.query, encoder_hidden_states=hidden_states, encoder_attention_mask=encoder_attention_mask, return_dict=True, ) # qf_out: [B*nblocks, num_queries, qf_hidden] qf_out = query_output.last_hidden_state qf_hidden = qf_out.shape[-1] # static qformer hidden # [B*nblocks, num_queries, qf_hidden] -> [B, T_audio, qf_hidden]. # B is baked at trace (1); T_audio (= nblocks*num_queries) is inferred. qf_out = qf_out.reshape(batch_size, -1, qf_hidden) return self.linear(qf_out) projector_cls.forward = projector_forward_traceable # --------------------------------------------------------------------------- # Encoder + projector wrapper. # --------------------------------------------------------------------------- class EncoderProjectorWrapper(nn.Module): """Wrap encoder + projector into one ONNX graph. Inputs: input_features: float32 [B, T, 160] Outputs: audio_embeds: float32 [B, T_audio, 2048] = projector(encoder(input_features)) audio_embed_sizes: int64 [B] - count of valid audio tokens per sample, replicating the feature_extractor's projection-length math on the static input shape. Notes: - The Conformer encoder itself does not consume an attention mask; the feature extractor supplies a Python-int per-sample length, which is what `audio_embed_sizes` reproduces here from the static input shape T. Downstream Rust glue should compute the same size from the raw audio length and slice `audio_embeds[:, :size, :]` for the splice. - The projector output size is `nblocks * (window_size / downsample_rate)`. With T=844 (the reference clip), this gives `ceil(844/15) * 3 = 171`, which matches the captured PyTorch reference. """ def __init__(self, encoder: nn.Module, projector: nn.Module, window_size: int, downsample_rate: int): super().__init__() self.encoder = encoder self.projector = projector self.window_size = int(window_size) self.downsample_rate = int(downsample_rate) def forward(self, input_features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: enc_out = self.encoder(input_features, return_dict=True) audio_embeds = self.projector(enc_out.last_hidden_state) # Compute audio_embed_sizes dynamically so the value tracks T at runtime. # `torch._shape_as_tensor` emits an ONNX Shape op so seq_len_t is a 0-d # int64 Tensor rather than a baked Python int. The result is an int64 # tensor of shape [B] (B baked at 1, the only mode we trace; T_audio # tracks runtime input length). shape_t = torch._shape_as_tensor(input_features) seq_len_t = shape_t[1] num_queries = self.window_size // self.downsample_rate nblocks_t = (seq_len_t + self.window_size - 1) // self.window_size size_per_t = nblocks_t * num_queries # 0-d int64 Tensor audio_embed_sizes = size_per_t.unsqueeze(0) # [1] tensor return audio_embeds, audio_embed_sizes # --------------------------------------------------------------------------- # LLM wrappers (prompt_encode + decode_step). Adapted from # src/export_granite_llm_kv.py to take inputs_embeds instead of input_ids. # --------------------------------------------------------------------------- def _build_causal_mask_4d( attention_mask_2d: torch.Tensor, T_past: int, dtype: torch.dtype, ) -> torch.Tensor: """Build a 4-D additive attention mask `[B, 1, T_q, T_k]` from a 2-D padding mask `[B, T_k]`. Padding columns are -inf, and the trailing T_q query rows have an upper-triangular causal mask added. The Granite eager-mask path early-exits when handed a 4-D mask, so this short-circuits the v5 mask-helper crash under TorchScript trace. """ B, T_k = attention_mask_2d.shape T_q = T_k - T_past neg_inf = torch.finfo(dtype).min pad = (attention_mask_2d == 0).to(dtype) * neg_inf # [B, T_k] pad = pad.view(B, 1, 1, T_k).expand(B, 1, T_q, T_k) q_idx = torch.arange(T_q, device=attention_mask_2d.device).view(1, 1, T_q, 1) k_idx = torch.arange(T_k, device=attention_mask_2d.device).view(1, 1, 1, T_k) allowed = k_idx <= (q_idx + T_past) causal = torch.where( allowed, torch.zeros((), dtype=dtype, device=attention_mask_2d.device), torch.full((), neg_inf, dtype=dtype, device=attention_mask_2d.device), ) return pad + causal class PromptEncodeWrapper(nn.Module): """Prefill graph; consumes pre-spliced inputs_embeds. Forward signature (positional): inputs_embeds: float32 [B, N, H] position_ids: int64 [B, N] attention_mask: float32 [B, 1, N, N] additive 4-D causal+padding mask Outputs: logits: float32 [B, N, V] (divided by logits_scaling) present..key, present..value for L in 0..n_layers-1 """ def __init__( self, llm_model: nn.Module, lm_head: nn.Module, num_layers: int, logits_scaling: float ) -> None: super().__init__() self.llm_model = llm_model self.lm_head = lm_head self.num_layers = num_layers self.logits_scaling = logits_scaling def forward( self, inputs_embeds: torch.Tensor, position_ids: torch.Tensor, attention_mask: torch.Tensor, ) -> tuple[torch.Tensor, ...]: from transformers import DynamicCache cache = DynamicCache() out = self.llm_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, use_cache=True, past_key_values=cache, ) logits = self.lm_head(out.last_hidden_state) / self.logits_scaling present = out.past_key_values flat: list[torch.Tensor] = [logits] for layer in present.layers: flat.append(layer.keys) flat.append(layer.values) return tuple(flat) class DecodeStepWrapper(nn.Module): """Single-token decode graph. Forward signature (positional): inputs_embeds: float32 [B, 1, H] position_ids: int64 [B, 1] attention_mask: float32 [B, 1, 1, T_total] additive 4-D mask past_kv_flat: 2*n_layers tensors, each float32 [B, num_kv_heads, T_past, head_dim], in the order (past.0.key, past.0.value, past.1.key, ..., past..value) """ def __init__( self, llm_model: nn.Module, lm_head: nn.Module, num_layers: int, logits_scaling: float ) -> None: super().__init__() self.llm_model = llm_model self.lm_head = lm_head self.num_layers = num_layers self.logits_scaling = logits_scaling def forward( self, inputs_embeds: torch.Tensor, position_ids: torch.Tensor, attention_mask: torch.Tensor, *past_kv_flat: torch.Tensor, ) -> tuple[torch.Tensor, ...]: from transformers import DynamicCache if len(past_kv_flat) != 2 * self.num_layers: raise ValueError( f"expected {2 * self.num_layers} past_kv tensors, got {len(past_kv_flat)}" ) layer_pairs = [ (past_kv_flat[2 * i], past_kv_flat[2 * i + 1]) for i in range(self.num_layers) ] cache = DynamicCache(ddp_cache_data=layer_pairs) out = self.llm_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, use_cache=True, past_key_values=cache, ) logits = self.lm_head(out.last_hidden_state) / self.logits_scaling present = out.past_key_values flat: list[torch.Tensor] = [logits] for layer in present.layers: flat.append(layer.keys) flat.append(layer.values) return tuple(flat) # --------------------------------------------------------------------------- # Export functions. # --------------------------------------------------------------------------- def export_encoder( wrapper: EncoderProjectorWrapper, sample_input_features: torch.Tensor, out_path: Path, opset: int = 20, ir_version: int = 10, ) -> None: out_path.parent.mkdir(parents=True, exist_ok=True) print(f" exporting encoder to {out_path} (opset={opset}, ir_version={ir_version})") dynamic_axes = { "input_features": {0: "B", 1: "T"}, "audio_embeds": {0: "B", 1: "T_audio"}, "audio_embed_sizes": {0: "B"}, } with tempfile.TemporaryDirectory(prefix="speech2b_ar_encoder_onnx_") as scratch_dir: scratch_path = Path(scratch_dir) / "encoder.onnx" t0 = time.time() torch.onnx.export( wrapper, (sample_input_features,), str(scratch_path), input_names=["input_features"], output_names=["audio_embeds", "audio_embed_sizes"], dynamic_axes=dynamic_axes, opset_version=opset, do_constant_folding=True, export_params=True, dynamo=False, ) print(f" stage-1 torch.onnx.export done in {time.time() - t0:.1f}s") _resave_single_sidecar(scratch_path, out_path, ir_version) def export_prompt_encode( wrapper: PromptEncodeWrapper, sample_inputs_embeds: torch.Tensor, sample_position_ids: torch.Tensor, sample_attention_mask: torch.Tensor, out_path: Path, num_layers: int, opset: int = 20, ir_version: int = 10, ) -> None: out_path.parent.mkdir(parents=True, exist_ok=True) print(f" exporting prompt_encode to {out_path} (opset={opset}, ir_version={ir_version})") output_names: list[str] = ["logits"] for i in range(num_layers): output_names.append(f"present.{i}.key") output_names.append(f"present.{i}.value") dynamic_axes: dict[str, dict[int, str]] = { "inputs_embeds": {0: "B", 1: "N"}, "position_ids": {0: "B", 1: "N"}, "attention_mask": {0: "B", 2: "N", 3: "N"}, "logits": {0: "B", 1: "N"}, } for i in range(num_layers): dynamic_axes[f"present.{i}.key"] = {0: "B", 2: "N"} dynamic_axes[f"present.{i}.value"] = {0: "B", 2: "N"} with tempfile.TemporaryDirectory(prefix="speech2b_ar_prompt_onnx_") as scratch_dir: scratch_path = Path(scratch_dir) / "prompt_encode.onnx" t0 = time.time() torch.onnx.export( wrapper, (sample_inputs_embeds, sample_position_ids, sample_attention_mask), str(scratch_path), input_names=["inputs_embeds", "position_ids", "attention_mask"], output_names=output_names, dynamic_axes=dynamic_axes, opset_version=opset, do_constant_folding=True, export_params=True, dynamo=False, ) print(f" stage-1 torch.onnx.export done in {time.time() - t0:.1f}s") _resave_single_sidecar(scratch_path, out_path, ir_version) def export_decode_step( wrapper: DecodeStepWrapper, sample_inputs_embeds: torch.Tensor, sample_position_ids: torch.Tensor, sample_attention_mask: torch.Tensor, sample_past_kv_flat: tuple[torch.Tensor, ...], out_path: Path, num_layers: int, opset: int = 20, ir_version: int = 10, ) -> None: out_path.parent.mkdir(parents=True, exist_ok=True) print(f" exporting decode_step to {out_path} (opset={opset}, ir_version={ir_version})") input_names: list[str] = ["inputs_embeds", "position_ids", "attention_mask"] for i in range(num_layers): input_names.append(f"past_key_values.{i}.key") input_names.append(f"past_key_values.{i}.value") output_names: list[str] = ["logits"] for i in range(num_layers): output_names.append(f"present.{i}.key") output_names.append(f"present.{i}.value") dynamic_axes: dict[str, dict[int, str]] = { "inputs_embeds": {0: "B"}, "position_ids": {0: "B"}, "attention_mask": {0: "B", 3: "T_total"}, "logits": {0: "B"}, } for i in range(num_layers): dynamic_axes[f"past_key_values.{i}.key"] = {0: "B", 2: "T_past"} dynamic_axes[f"past_key_values.{i}.value"] = {0: "B", 2: "T_past"} dynamic_axes[f"present.{i}.key"] = {0: "B", 2: "T_total"} dynamic_axes[f"present.{i}.value"] = {0: "B", 2: "T_total"} with tempfile.TemporaryDirectory(prefix="speech2b_ar_decode_onnx_") as scratch_dir: scratch_path = Path(scratch_dir) / "decode_step.onnx" args = (sample_inputs_embeds, sample_position_ids, sample_attention_mask, *sample_past_kv_flat) t0 = time.time() torch.onnx.export( wrapper, args, str(scratch_path), input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, opset_version=opset, do_constant_folding=True, export_params=True, dynamo=False, ) print(f" stage-1 torch.onnx.export done in {time.time() - t0:.1f}s") _resave_single_sidecar(scratch_path, out_path, ir_version) # --------------------------------------------------------------------------- # Parity helpers. # --------------------------------------------------------------------------- def encoder_parity( wrapper: EncoderProjectorWrapper, processor: Any, waveform: np.ndarray, onnx_path: Path, abs_tol: float, argmax_only: bool = False, ) -> dict[str, Any]: import onnxruntime as ort print("\n=== encoder parity ===") inputs = processor(USER_PROMPT_TRANSCRIBE, [waveform], sampling_rate=16000, return_tensors="pt") input_features = inputs["input_features"].to(torch.float32) print(f" input_features: {tuple(input_features.shape)}") print(" PyTorch wrapper forward") t0 = time.time() with torch.inference_mode(): audio_pt, sizes_pt = wrapper(input_features) print(f" pt: {time.time() - t0:.2f}s audio_embeds={tuple(audio_pt.shape)}") print(f" ONNX inference: {onnx_path}") sess = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) t0 = time.time() audio_ort, sizes_ort = sess.run( ["audio_embeds", "audio_embed_sizes"], {"input_features": input_features.numpy().astype(np.float32)}, ) print(f" ort: {time.time() - t0:.2f}s audio_embeds={tuple(audio_ort.shape)}") pt_np = audio_pt.detach().float().cpu().numpy() abs_err = np.abs(pt_np - audio_ort) max_err = float(abs_err.max()) mean_err = float(abs_err.mean()) p99 = float(np.percentile(abs_err, 99)) sizes_pt_np = sizes_pt.detach().cpu().numpy().astype(np.int64) sizes_ok = bool(np.array_equal(sizes_pt_np, sizes_ort.astype(np.int64))) if argmax_only: # The encoder's audio_embeds feed into the LLM, where the actual ship # gate (transcript byte-exact, argmax stable) lives. The continuous # audio_embeds delta is informational only in INT8 mode. ok = sizes_ok else: ok = max_err <= abs_tol and sizes_ok print(f" max_abs_err={max_err:.3e} mean={mean_err:.3e} p99={p99:.3e}") print(f" audio_embed_sizes pt={sizes_pt_np.tolist()} ort={sizes_ort.tolist()} match={sizes_ok}") print(f" encoder parity: {'PASS' if ok else 'FAIL'}{' (argmax-only)' if argmax_only else ''}") return { "ok": ok, "abs_tol": abs_tol, "argmax_only": argmax_only, "max_abs_err": max_err, "mean_abs_err": mean_err, "p99_abs_err": p99, "audio_embeds_shape_pt": list(pt_np.shape), "audio_embeds_shape_ort": list(audio_ort.shape), "audio_embed_sizes_pt": sizes_pt_np.tolist(), "audio_embed_sizes_ort": sizes_ort.tolist(), "audio_embed_sizes_match": sizes_ok, "audio_embeds_stats_pt": tensor_stats(audio_pt), "audio_embeds_stats_ort": tensor_stats(audio_ort), } def build_inputs_embeds(model: nn.Module, processor: Any, waveform: np.ndarray) -> tuple[torch.Tensor, torch.Tensor, dict]: """Build the post-splice `inputs_embeds [1, N, 2048]` and `position_ids` for parity, exactly mirroring the PyTorch path: 1. Render the chat prompt with `<|audio|>` -> repeated audio token. 2. Run encoder + projector to get audio embeds. 3. masked_scatter audio embeds into the text embeddings at audio-token positions. """ chat = [{"role": "user", "content": USER_PROMPT_TRANSCRIBE}] rendered = processor.tokenizer.apply_chat_template( chat, tokenize=False, add_generation_prompt=True ) inputs = processor(rendered, [waveform], sampling_rate=16000, return_tensors="pt") input_ids = inputs["input_ids"].to(torch.long) input_features = inputs["input_features"].to(torch.float32) input_features_mask = inputs["input_features_mask"] print(f" prompt token ids shape={tuple(input_ids.shape)}") print(f" input_features shape={tuple(input_features.shape)} input_features_mask shape={tuple(input_features_mask.shape)}") with torch.inference_mode(): audio_outputs = model.get_audio_features(input_features, return_dict=True) audio_embeds = audio_outputs.pooler_output # The reference uses model.dtype; we forced fp32 at load. inputs_embeds = model.get_merged_audio_embeddings( input_ids=input_ids, audio_features=audio_embeds, input_features_mask=input_features_mask, ) inputs_embeds = inputs_embeds.to(torch.float32) N = inputs_embeds.shape[1] position_ids = torch.arange(N, dtype=torch.long).unsqueeze(0).expand(1, N).contiguous() info = { "input_ids_shape": list(input_ids.shape), "audio_embeds_shape": list(audio_embeds.shape), "input_features_mask_shape": list(input_features_mask.shape), "inputs_embeds_shape": list(inputs_embeds.shape), "n_audio_tokens": int((input_ids == model.config.audio_token_id).sum().item()), "input_features_mask_sum": int(input_features_mask.sum().item()), } print(f" inputs_embeds shape={tuple(inputs_embeds.shape)} audio_tokens={info['n_audio_tokens']}") return inputs_embeds, position_ids, info def llm_parity_e2e( model: nn.Module, processor: Any, waveform: np.ndarray, prompt_onnx: Path, decode_onnx: Path, baseline_json: Path, max_new_tokens: int, abs_tol: float, argmax_only: bool = False, ) -> dict[str, Any]: """Greedy-decode end-to-end through the ONNX graphs and compare against the captured PyTorch baseline transcript token-for-token. """ import onnxruntime as ort print("\n=== prompt_encode + decode_step end-to-end parity ===") inputs_embeds, position_ids, embed_info = build_inputs_embeds(model, processor, waveform) N = inputs_embeds.shape[1] # Build the 4-D additive causal+pad mask Python-side. attn_2d = torch.ones((1, N), dtype=torch.long) attn_4d_prompt = _build_causal_mask_4d(attn_2d, T_past=0, dtype=torch.float32) # Reference: PyTorch logits at the prompt's last position; expected first # generated token == baseline new_token_ids[0]. print(" loading PyTorch reference path (lm_head / logits_scaling)") with torch.inference_mode(): out = model.language_model( inputs_embeds=inputs_embeds, attention_mask=attn_4d_prompt, position_ids=position_ids, use_cache=True, past_key_values=None, ) # The language_model is a GraniteForCausalLM; out.logits is already # divided by logits_scaling. Use it as the strict-parity reference. pt_logits = out.logits.detach().float().cpu().numpy() pt_past = out.past_key_values print(f" pt prompt logits shape={pt_logits.shape} argmax_last={int(pt_logits[0, -1].argmax())}") # ---- ONNX: prompt_encode ---- print(f" loading ONNX sessions") so = ort.SessionOptions() sess_prompt = ort.InferenceSession(str(prompt_onnx), so, providers=["CPUExecutionProvider"]) sess_decode = ort.InferenceSession(str(decode_onnx), so, providers=["CPUExecutionProvider"]) num_layers = len(model.language_model.model.layers) feeds_prompt = { "inputs_embeds": inputs_embeds.numpy().astype(np.float32), "position_ids": position_ids.numpy().astype(np.int64), "attention_mask": attn_4d_prompt.numpy().astype(np.float32), } print(" running prompt_encode.onnx") t0 = time.time() prompt_outs = sess_prompt.run(None, feeds_prompt) print(f" forward: {time.time() - t0:.2f}s") prompt_logits = prompt_outs[0] past_kv_flat = list(prompt_outs[1:]) assert len(past_kv_flat) == 2 * num_layers # Compare prompt-stage logits. prompt_diff = np.abs(prompt_logits - pt_logits) prompt_max_err = float(prompt_diff.max()) prompt_mean_err = float(prompt_diff.mean()) pt_argmax = pt_logits.argmax(-1) ort_argmax = prompt_logits.argmax(-1) prompt_argmax_mismatches = int((pt_argmax != ort_argmax).sum()) print(f" prompt logits max_abs_err={prompt_max_err:.3e} mean={prompt_mean_err:.3e} " f"argmax_mismatches={prompt_argmax_mismatches}/{pt_argmax.size}") # First generated token: argmax at the last prompt position (this is what # GenerationMixin's greedy path does). embed_tokens = model.language_model.model.embed_tokens eos_id = int(model.config.text_config.eos_token_id) onnx_new_tokens: list[int] = [int(prompt_logits[0, -1].argmax())] onnx_step_logits: list[np.ndarray] = [prompt_logits[0, -1].astype(np.float32)] print(f" greedy-decoding up to {max_new_tokens} new tokens through decode_step.onnx") t0 = time.time() for step in range(1, max_new_tokens): prev_tok = onnx_new_tokens[-1] if prev_tok == eos_id: break T_past = N + step - 1 T_total = T_past + 1 # Build the next inputs_embeds via the model's embed_tokens. prev_id_tensor = torch.tensor([[prev_tok]], dtype=torch.long) with torch.inference_mode(): next_embed = embed_tokens(prev_id_tensor).to(torch.float32) # 4-D additive mask of zeros for unmasked positions; padding is irrelevant # because attention_mask_2d is all-ones throughout the decode loop. attn_2d_step = torch.ones((1, T_total), dtype=torch.long) attn_4d_step = _build_causal_mask_4d(attn_2d_step, T_past=T_past, dtype=torch.float32) feeds: dict[str, np.ndarray] = { "inputs_embeds": next_embed.numpy().astype(np.float32), "position_ids": np.array([[T_past]], dtype=np.int64), "attention_mask": attn_4d_step.numpy().astype(np.float32), } for i in range(num_layers): feeds[f"past_key_values.{i}.key"] = past_kv_flat[2 * i] feeds[f"past_key_values.{i}.value"] = past_kv_flat[2 * i + 1] outs = sess_decode.run(None, feeds) step_logits = outs[0] new_past = list(outs[1:]) assert len(new_past) == 2 * num_layers past_kv_flat = new_past nt = int(step_logits[0, 0].argmax()) onnx_step_logits.append(step_logits[0, 0].astype(np.float32)) onnx_new_tokens.append(nt) print(f" {len(onnx_new_tokens) - 1} decode_step forwards: {time.time() - t0:.2f}s") onnx_transcript = processor.tokenizer.decode( [t for t in onnx_new_tokens if t != eos_id], skip_special_tokens=True ) print(f" onnx new tokens: {onnx_new_tokens}") print(f" onnx transcript: {onnx_transcript!r}") baseline = json.loads(baseline_json.read_text()) baseline_tokens = baseline["new_token_ids"] baseline_transcript = baseline["transcript"] tokens_match = onnx_new_tokens == baseline_tokens transcript_match = onnx_transcript == baseline_transcript print(f" baseline transcript: {baseline_transcript!r}") print(f" tokens match: {tokens_match} transcript match: {transcript_match}") # Per-step parity vs PyTorch reference for the first 5 steps. per_step_compare: list[dict[str, Any]] = [] pt_step_logits = None if max_new_tokens >= 1: # Recompute PyTorch reference logits per step via model.generate, to # avoid having to maintain an alternate decode loop here. with torch.inference_mode(): chat = [{"role": "user", "content": USER_PROMPT_TRANSCRIBE}] rendered = processor.tokenizer.apply_chat_template( chat, tokenize=False, add_generation_prompt=True ) ref_inputs = processor(rendered, [waveform], sampling_rate=16000, return_tensors="pt") gen = model.generate( **ref_inputs, max_new_tokens=max_new_tokens, do_sample=False, num_beams=1, return_dict_in_generate=True, output_scores=True, ) pt_step_logits = [s[0].detach().float().cpu().numpy() for s in gen.scores] n_compare = min(len(pt_step_logits), len(onnx_step_logits)) for i in range(n_compare): ref = pt_step_logits[i].astype(np.float32) ours = onnx_step_logits[i].astype(np.float32) d = np.abs(ref - ours) per_step_compare.append({ "step": i, "ref_token": int(ref.argmax()), "onnx_token": int(ours.argmax()), "argmax_match": int(ref.argmax()) == int(ours.argmax()), "max_abs_err": float(d.max()), "mean_abs_err": float(d.mean()), }) overall_max = max((s["max_abs_err"] for s in per_step_compare), default=0.0) overall_argmax_mm = sum(0 if s["argmax_match"] else 1 for s in per_step_compare) if argmax_only: # INT8 ship gate: end-to-end transcript + decoded token IDs match the # baseline exactly. Prompt-stage and per-step max-abs deltas vs FP32 # are recorded for reporting but not blocking. ok = tokens_match and transcript_match else: ok = ( tokens_match and transcript_match and prompt_argmax_mismatches == 0 and overall_argmax_mm == 0 ) return { "ok": ok, "abs_tol": abs_tol, "argmax_only": argmax_only, "embed_info": embed_info, "N_prompt": N, "prompt_logits_max_abs_err": prompt_max_err, "prompt_logits_mean_abs_err": prompt_mean_err, "prompt_argmax_mismatches": prompt_argmax_mismatches, "prompt_argmax_total": int(pt_argmax.size), "onnx_new_tokens": onnx_new_tokens, "baseline_new_tokens": baseline_tokens, "tokens_match": tokens_match, "onnx_transcript": onnx_transcript, "baseline_transcript": baseline_transcript, "transcript_match": transcript_match, "per_step_compare": per_step_compare, "overall_max_abs_err_step": overall_max, "overall_argmax_mismatches_step": overall_argmax_mm, } # --------------------------------------------------------------------------- # Main. # --------------------------------------------------------------------------- def main() -> None: p = argparse.ArgumentParser() p.add_argument("--audio", default=str(DEFAULT_AUDIO)) p.add_argument("--baseline", default=str(DEFAULT_BASELINE)) p.add_argument("--model-dir", default=str(DEFAULT_MODEL_DIR)) p.add_argument("--out-dir", default=str(DEFAULT_OUT_DIR)) p.add_argument("--abs-tol", type=float, default=1e-3) p.add_argument( "--stages", default="encoder,prompt,decode", help="comma-separated subset of {encoder, prompt, decode}", ) p.add_argument( "--skip-export", action="store_true", help="skip export, run parity on existing files" ) p.add_argument("--max-new-tokens", type=int, default=80) p.add_argument( "--graph-suffix", default="", help="suffix appended to graph stems (e.g. '_int8') so parity runs against " "encoder.onnx etc. Parity output goes to parity.json. " "When set, --skip-export is implied.", ) args = p.parse_args() stages = {s.strip() for s in args.stages.split(",") if s.strip()} valid = {"encoder", "prompt", "decode"} bad = stages - valid if bad: raise SystemExit(f"unknown stage(s): {bad}; valid: {sorted(valid)}") out_dir = Path(args.out_dir) suffix = args.graph_suffix if suffix and not args.skip_export: print(f" --graph-suffix={suffix!r} set; implying --skip-export") args.skip_export = True encoder_path = out_dir / f"encoder{suffix}.onnx" prompt_path = out_dir / f"prompt_encode{suffix}.onnx" decode_path = out_dir / f"decode_step{suffix}.onnx" parity_json = out_dir / f"parity{suffix}.json" print(f"audio: {args.audio}") print(f"model_dir: {args.model_dir}") print(f"out_dir: {out_dir}") print(f"stages: {sorted(stages)}") waveform = load_audio(Path(args.audio)) print(f" duration={waveform.shape[0] / 16000:.2f}s") print("loading model...") model, processor = load_base_model(Path(args.model_dir)) print("patching conformer attention for tracing...") patch_conformer_for_tracing(model) print("patching projector for dynamic T_audio tracing...") patch_projector_for_tracing(model) # Useful constants from the loaded config. text_cfg = model.config.text_config num_layers = int(text_cfg.num_hidden_layers) logits_scaling = float(text_cfg.logits_scaling) print(f" num_layers={num_layers} logits_scaling={logits_scaling}") print(f" audio_token_id={model.config.audio_token_id} hidden_size={text_cfg.hidden_size}") # Sample inputs for tracing. sample_inputs = processor( USER_PROMPT_TRANSCRIBE, [waveform], sampling_rate=16000, return_tensors="pt" ) sample_features = sample_inputs["input_features"].to(torch.float32) parity_payload: dict[str, Any] = { "abs_tol": args.abs_tol, "stages_run": sorted(stages), "input_features_shape": list(sample_features.shape), } # ----- Encoder export + parity ----- if "encoder" in stages: wrapper = EncoderProjectorWrapper( encoder=model.encoder, projector=model.projector, window_size=int(model.config.window_size), downsample_rate=int(model.config.downsample_rate), ).eval() if not args.skip_export: with torch.inference_mode(): export_encoder( wrapper=wrapper, sample_input_features=sample_features, out_path=encoder_path, opset=20, ir_version=10, ) parity_payload["encoder"] = encoder_parity( wrapper=wrapper, processor=processor, waveform=waveform, onnx_path=encoder_path, abs_tol=args.abs_tol, argmax_only=bool(suffix), ) # ----- LLM (prompt + decode) export ----- if {"prompt", "decode"} & stages and not args.skip_export: # Build a sample inputs_embeds + position_ids by running encoder + splice. print("\nbuilding sample inputs_embeds for LLM export trace...") sample_embeds, sample_pos_ids, _info = build_inputs_embeds(model, processor, waveform) N = sample_embeds.shape[1] sample_attn_4d = _build_causal_mask_4d( torch.ones((1, N), dtype=torch.long), T_past=0, dtype=torch.float32 ) if "prompt" in stages: prompt_wrapper = PromptEncodeWrapper( llm_model=model.language_model.model, lm_head=model.language_model.lm_head, num_layers=num_layers, logits_scaling=logits_scaling, ).eval() with torch.inference_mode(): export_prompt_encode( wrapper=prompt_wrapper, sample_inputs_embeds=sample_embeds, sample_position_ids=sample_pos_ids, sample_attention_mask=sample_attn_4d, out_path=prompt_path, num_layers=num_layers, opset=20, ir_version=10, ) if "decode" in stages: # We need a sample past_kv set for the decode_step trace; harvest by # running the prompt wrapper once. prompt_wrapper = PromptEncodeWrapper( llm_model=model.language_model.model, lm_head=model.language_model.lm_head, num_layers=num_layers, logits_scaling=logits_scaling, ).eval() with torch.inference_mode(): p_outs = prompt_wrapper(sample_embeds, sample_pos_ids, sample_attn_4d) sample_past_kv_flat = tuple(t.detach().clone() for t in p_outs[1:]) assert len(sample_past_kv_flat) == 2 * num_layers embed_tokens = model.language_model.model.embed_tokens with torch.inference_mode(): sample_step_embed = ( embed_tokens(torch.tensor([[0]], dtype=torch.long)).to(torch.float32) ) sample_step_pos = torch.tensor([[N]], dtype=torch.long) sample_step_attn_2d = torch.ones((1, N + 1), dtype=torch.long) sample_step_attn_4d = _build_causal_mask_4d( sample_step_attn_2d, T_past=N, dtype=torch.float32 ) decode_wrapper = DecodeStepWrapper( llm_model=model.language_model.model, lm_head=model.language_model.lm_head, num_layers=num_layers, logits_scaling=logits_scaling, ).eval() with torch.inference_mode(): export_decode_step( wrapper=decode_wrapper, sample_inputs_embeds=sample_step_embed, sample_position_ids=sample_step_pos, sample_attention_mask=sample_step_attn_4d, sample_past_kv_flat=sample_past_kv_flat, out_path=decode_path, num_layers=num_layers, opset=20, ir_version=10, ) # ----- end-to-end LLM parity ----- if {"prompt", "decode"} <= stages and prompt_path.exists() and decode_path.exists(): parity_payload["llm_e2e"] = llm_parity_e2e( model=model, processor=processor, waveform=waveform, prompt_onnx=prompt_path, decode_onnx=decode_path, baseline_json=Path(args.baseline), max_new_tokens=args.max_new_tokens, abs_tol=args.abs_tol, argmax_only=bool(suffix), ) # ----- Per-graph size + int8-vs-fp32 deltas (only when graph-suffix set) ----- if suffix: parity_payload["graph_suffix"] = suffix parity_payload["graphs"] = {} for label, p in ( ("encoder", encoder_path), ("prompt_encode", prompt_path), ("decode_step", decode_path), ): if not p.exists(): continue data = p.with_name(p.name + "_data") entry = { "graph_path": str(p), "graph_size_bytes": int(p.stat().st_size), "sidecar_path": str(data) if data.exists() else None, "int8_size_bytes": int(data.stat().st_size) if data.exists() else None, } fp32 = p.with_name(p.name.replace(suffix, "")) fp32_data = fp32.with_name(fp32.name + "_data") if fp32.exists() and fp32_data.exists(): entry["fp32_sidecar_path"] = str(fp32_data) entry["fp32_size_bytes"] = int(fp32_data.stat().st_size) if entry["int8_size_bytes"]: entry["size_ratio"] = entry["int8_size_bytes"] / entry["fp32_size_bytes"] parity_payload["graphs"][label] = entry # ----- Write parity report ----- parity_json.parent.mkdir(parents=True, exist_ok=True) parity_json.write_text(json.dumps(parity_payload, indent=2)) print(f"\nwrote parity report -> {parity_json}") # ----- Final summary ----- failures = [] print("\n--- summary ---") if "encoder" in parity_payload: e = parity_payload["encoder"] print(f" encoder: {'PASS' if e['ok'] else 'FAIL'} max_abs_err={e['max_abs_err']:.3e}") if not e["ok"]: failures.append("encoder") if "llm_e2e" in parity_payload: l = parity_payload["llm_e2e"] print( f" llm_e2e: {'PASS' if l['ok'] else 'FAIL'} " f"prompt_argmax_mm={l['prompt_argmax_mismatches']} " f"step_argmax_mm={l['overall_argmax_mismatches_step']} " f"transcript_match={l['transcript_match']}" ) if not l["ok"]: failures.append("llm_e2e") if failures: raise SystemExit(f"failed: {failures}") if __name__ == "__main__": main()