#!/usr/bin/env python3
"""
Export Cohere Transcribe to ONNX.

Produces self-contained ONNX files in the current working directory
(intended to be the repo root):

    encoder-0..3.onnx  -- conformer encoder (4 splits, ~1.3 GB each)
    cross_kv.onnx      -- encoder states -> cross-attention KV for decoder
    decoder.onnx       -- autoregressive token decoder with KV cache

Intermediate exports go to ``onnx_out/``. That directory is emptied at the
start of a run and deleted after its models have been consolidated.

Usage:
    python export_onnx.py
"""

import os
import shutil
from pathlib import Path

import onnx
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForSpeechSeq2Seq

MODEL_ID = "CohereLabs/cohere-transcribe-03-2026"
DATA_DIR = Path("onnx_out")
OPSET = 18
NUM_ENCODER_SPLITS = 4

# Decoder geometry. The decoder wrapper below hard-codes one self-attention
# and one cross-attention K/V pair per layer, so NUM_DECODER_LAYERS must match
# the checkpoint (checked by _checked_decoder_layers).
NUM_DECODER_LAYERS = 8
DECODER_HIDDEN = 1024
DECODER_HEADS = 8
HEAD_DIM = DECODER_HIDDEN // DECODER_HEADS


# ---------------------------------------------------------------------------
# Encoder splits
# ---------------------------------------------------------------------------

class EncoderSplitForExport(nn.Module):
    """A slice of the conformer encoder.

    First split:   conv subsampling + pos_enc + mask creation + layers.
    Middle splits: conformer layers only.
    Last split:    final layers + encoder_decoder_proj.

    Non-last splits return ``(x, pos_emb, att_mask, pad_mask, length)`` so the
    next split can take them as its inputs. The last split returns
    ``(x, length)``.
    """

    def __init__(self, model, layers, is_first=False, is_last=False):
        super().__init__()
        self.is_first = is_first
        self.is_last = is_last
        self.layers = nn.ModuleList(layers)
        if is_first:
            self.pre_encode = model.encoder.pre_encode
            self.pos_enc = model.encoder.pos_enc
            # Build the positional-encoding table for the full max_len up
            # front (fp32, CPU). Presumably this keeps the traced graph from
            # depending on pos_enc rebuilding it lazily.
            self.pos_enc._materialize_pe(
                length=self.pos_enc.max_len,
                device=torch.device("cpu"),
                dtype=torch.float32,
            )
        if is_last:
            self.encoder_decoder_proj = model.encoder_decoder_proj

    def _create_masks(self, length, max_len):
        """Build padding and attention masks from per-example lengths.

        Both masks are returned inverted, so True marks padding.

        Returns:
            pad_mask: (B, max_len). True at positions >= length.
            att_mask: (B, max_len, max_len). True for any query/key pair
                where either side is padding.
        """
        att_mask = torch.ones(1, max_len, max_len, dtype=torch.bool, device=length.device)
        pad_mask = torch.arange(0, max_len, device=length.device).expand(
            length.size(0), -1
        ) < length.unsqueeze(-1)
        # A (query, key) pair is valid only if both positions are valid.
        pad_mask_for_att = pad_mask.unsqueeze(1).repeat([1, max_len, 1])
        pad_mask_for_att = torch.logical_and(pad_mask_for_att, pad_mask_for_att.transpose(1, 2))
        att_mask = torch.logical_and(att_mask, pad_mask_for_att)
        return ~pad_mask, ~att_mask

    def forward(self, *args):
        """Run this slice of the encoder.

        First split: ``args`` is ``(input_features, length)``.
        Other splits: ``args`` is ``(x, pos_emb, att_mask, pad_mask, length)``,
        as produced by the previous split.
        """
        if self.is_first:
            input_features, length = args
            x, length = self.pre_encode(input_features, length)
            length = length.to(torch.int64)
            max_len = x.size(1)
            x, pos_emb = self.pos_enc(x)
            pad_mask, att_mask = self._create_masks(length, max_len)
        else:
            x, pos_emb, att_mask, pad_mask, length = args

        for layer in self.layers:
            x = layer(x, pos_emb, mask=att_mask, pad_mask=pad_mask)

        if self.is_last:
            if self.encoder_decoder_proj is not None:
                x = self.encoder_decoder_proj(x)
            return x, length
        return x, pos_emb, att_mask, pad_mask, length


# ---------------------------------------------------------------------------
# Decoder layer validation
# ---------------------------------------------------------------------------

def _checked_decoder_layers(model):
    """Return ``model.transf_decoder._decoder.layers``, checking the layer count.

    The export wrappers, and the ONNX input/output names built in the export
    helpers, assume exactly NUM_DECODER_LAYERS decoder layers. A mismatch is
    reported here with a clear message rather than failing obscurely during
    tracing.

    Raises:
        ValueError: if the decoder does not have NUM_DECODER_LAYERS layers.
    """
    layers = model.transf_decoder._decoder.layers
    if len(layers) != NUM_DECODER_LAYERS:
        raise ValueError(
            f"Decoder has {len(layers)} layers but NUM_DECODER_LAYERS = "
            f"{NUM_DECODER_LAYERS}; update the constant and DecoderForExport.forward."
        )
    return layers


# ---------------------------------------------------------------------------
# Cross-KV projections
# ---------------------------------------------------------------------------

class CrossKVForExport(nn.Module):
    """Project encoder output to cross-attention K/V for all decoder layers."""

    def __init__(self, model):
        super().__init__()
        self.projections = nn.ModuleList()
        for layer in _checked_decoder_layers(model):
            cross_attn = layer.second_sub_layer
            # Lightweight container holding only what forward() needs from
            # each layer's cross-attention block.
            proj = nn.Module()
            proj.key_net = cross_attn.key_net
            proj.value_net = cross_attn.value_net
            proj.num_heads = cross_attn.num_heads
            proj.head_dim = cross_attn.head_dim
            self.projections.append(proj)

    def forward(self, encoder_out: torch.Tensor):
        """Return ``(k_0, v_0, k_1, v_1, ...)``, one K/V pair per decoder layer.

        Each tensor is reshaped to (B, num_heads, S, head_dim) via
        ``view(B, -1, num_heads, head_dim).transpose(1, 2)``.
        """
        b = encoder_out.shape[0]
        kvs = []
        for proj in self.projections:
            k = proj.key_net(encoder_out).view(b, -1, proj.num_heads, proj.head_dim).transpose(1, 2)
            v = proj.value_net(encoder_out).view(b, -1, proj.num_heads, proj.head_dim).transpose(1, 2)
            kvs.extend([k, v])
        return tuple(kvs)


# ---------------------------------------------------------------------------
# Decoder
# ---------------------------------------------------------------------------

class DecoderForExport(nn.Module):
    """Transformer decoder with pre-computed cross-attention KV and self-attention KV cache."""

    def __init__(self, model):
        super().__init__()
        self.embedding = model.transf_decoder._embedding
        self.decoder_layers = _checked_decoder_layers(model)
        self.final_layer_norm = model.transf_decoder._decoder.final_layer_norm
        # Output head. Exported under the name "logits", but it is
        # model.log_softmax, so the values may be log-probabilities -- confirm.
        self.classifier = model.log_softmax

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        cross_attention_mask: torch.Tensor,
        self_k_0: torch.Tensor, self_v_0: torch.Tensor,
        self_k_1: torch.Tensor, self_v_1: torch.Tensor,
        self_k_2: torch.Tensor, self_v_2: torch.Tensor,
        self_k_3: torch.Tensor, self_v_3: torch.Tensor,
        self_k_4: torch.Tensor, self_v_4: torch.Tensor,
        self_k_5: torch.Tensor, self_v_5: torch.Tensor,
        self_k_6: torch.Tensor, self_v_6: torch.Tensor,
        self_k_7: torch.Tensor, self_v_7: torch.Tensor,
        cross_k_0: torch.Tensor, cross_v_0: torch.Tensor,
        cross_k_1: torch.Tensor, cross_v_1: torch.Tensor,
        cross_k_2: torch.Tensor, cross_v_2: torch.Tensor,
        cross_k_3: torch.Tensor, cross_v_3: torch.Tensor,
        cross_k_4: torch.Tensor, cross_v_4: torch.Tensor,
        cross_k_5: torch.Tensor, cross_v_5: torch.Tensor,
        cross_k_6: torch.Tensor, cross_v_6: torch.Tensor,
        cross_k_7: torch.Tensor, cross_v_7: torch.Tensor,
    ):
        """Decode ``tgt_len`` new tokens given cached self-attention K/V.

        Each cached self K/V is laid out (B, heads, past_len, head_dim), as
        in the dummy inputs built by export_decoder. The new step's K/V are
        appended along dim 2.

        Returns:
            ``(logits, self_k_0', self_v_0', ..., self_k_7', self_v_7')``,
            where each updated cache tensor has length past_len + tgt_len
            along dim 2.
        """
        # Explicit per-layer arguments, rather than *args, give the exported
        # graph one named input per tensor.
        self_k_in = [self_k_0, self_k_1, self_k_2, self_k_3, self_k_4, self_k_5, self_k_6, self_k_7]
        self_v_in = [self_v_0, self_v_1, self_v_2, self_v_3, self_v_4, self_v_5, self_v_6, self_v_7]
        cross_k_in = [cross_k_0, cross_k_1, cross_k_2, cross_k_3, cross_k_4, cross_k_5, cross_k_6, cross_k_7]
        cross_v_in = [cross_v_0, cross_v_1, cross_v_2, cross_v_3, cross_v_4, cross_v_5, cross_v_6, cross_v_7]

        batch_size, tgt_len = input_ids.shape
        dtype = cross_k_in[0].dtype
        past_len = self_k_in[0].shape[2]
        total_kv_len = past_len + tgt_len

        # Causal self-attention mask. The query at absolute position
        # past_len + j may attend to keys 0 .. past_len + j; later keys get -inf.
        q_pos = torch.arange(past_len, past_len + tgt_len, device=input_ids.device)[:, None]
        k_pos = torch.arange(total_kv_len, device=input_ids.device)[None, :]
        self_attn_mask = torch.zeros(
            (batch_size, 1, tgt_len, total_kv_len), device=input_ids.device, dtype=dtype
        )
        self_attn_mask.masked_fill_((k_pos > q_pos)[None, None], float("-inf"))

        hidden = self.embedding(input_ids, positions)

        self_k_out, self_v_out = [], []
        for i, layer in enumerate(self.decoder_layers):
            attn_self = layer.first_sub_layer
            attn_cross = layer.second_sub_layer

            # Self-attention (pre-norm). Extend the KV cache with this step.
            residual = hidden
            h = layer.layer_norm_1(hidden)
            q = attn_self._reshape(attn_self.query_net(h))
            k = torch.cat([self_k_in[i], attn_self._reshape(attn_self.key_net(h))], dim=2)
            v = torch.cat([self_v_in[i], attn_self._reshape(attn_self.value_net(h))], dim=2)
            self_k_out.append(k)
            self_v_out.append(v)
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=self_attn_mask, dropout_p=0.0, scale=attn_self.scale)
            out = out.transpose(1, 2).contiguous().view(batch_size, tgt_len, attn_self.hidden_size)
            hidden = residual + attn_self.out_projection(out)

            # Cross-attention against the precomputed encoder K/V.
            residual = hidden
            h = layer.layer_norm_2(hidden)
            q = attn_cross._reshape(attn_cross.query_net(h))
            out = F.scaled_dot_product_attention(
                q, cross_k_in[i], cross_v_in[i],
                attn_mask=cross_attention_mask, dropout_p=0.0, scale=attn_cross.scale,
            )
            out = out.transpose(1, 2).contiguous().view(batch_size, tgt_len, attn_cross.hidden_size)
            hidden = residual + attn_cross.out_projection(out)

            # FFN
            residual = hidden
            hidden = residual + layer.third_sub_layer(layer.layer_norm_3(hidden))

        logits = self.classifier(self.final_layer_norm(hidden))

        # Interleave outputs as k_0, v_0, k_1, v_1, ... to match output_names.
        kv_out = []
        for k, v in zip(self_k_out, self_v_out):
            kv_out.extend([k, v])
        return (logits, *kv_out)


# ---------------------------------------------------------------------------
# Export helpers
# ---------------------------------------------------------------------------

def _split_layer_ranges(num_layers, num_splits):
    """Split ``num_layers`` into ``num_splits`` contiguous ``(start, end)`` ranges.

    Ranges are half-open. The first ``num_layers % num_splits`` ranges get one
    extra layer. For example, ``(10, 4)`` gives
    ``[(0, 3), (3, 6), (6, 8), (8, 10)]``.
    """
    per_split = num_layers // num_splits
    remainder = num_layers % num_splits
    ranges, start = [], 0
    for i in range(num_splits):
        end = start + per_split + (1 if i < remainder else 0)
        ranges.append((start, end))
        start = end
    return ranges


def export_encoder_splits(model, data_dir: Path):
    """Export the encoder as NUM_ENCODER_SPLITS chained ONNX files.

    Files are written as ``encoder-{idx}.onnx`` in ``data_dir``. Each
    non-last split outputs exactly the tensors the next split takes as input.
    """
    encoder = model.encoder
    num_layers = len(encoder.layers)
    ranges = _split_layer_ranges(num_layers, NUM_ENCODER_SPLITS)
    print(f"Encoder: {num_layers} layers -> {NUM_ENCODER_SPLITS} splits: "
          + ", ".join(f"[{s}:{e}]" for s, e in ranges))

    encoder.pos_enc._materialize_pe(
        length=encoder.pos_enc.max_len,
        device=torch.device("cpu"),
        dtype=torch.float32,
    )

    # Probe shapes for intermediate tensors (dummy inputs of middle/last splits).
    B, MEL_BINS, T = 1, 128, 500
    with torch.no_grad():
        x, _ = encoder.pre_encode(torch.randn(B, MEL_BINS, T), torch.tensor([T], dtype=torch.long))
        t_sub = x.size(1)
        x, pos_emb = encoder.pos_enc(x)
        d_model, pos_len = x.shape[2], pos_emb.shape[1]

    for idx, (start, end) in enumerate(ranges):
        is_first = (idx == 0)
        is_last = (idx == len(ranges) - 1)
        tag = "first" if is_first else ("last" if is_last else "middle")
        print(f"\n encoder-{idx} (layers {start}-{end-1}, {tag})")

        wrapper = EncoderSplitForExport(
            model, list(encoder.layers[start:end]),
            is_first=is_first, is_last=is_last,
        )
        wrapper.eval()

        if is_first:
            dummy_args = (torch.randn(B, MEL_BINS, T), torch.tensor([T], dtype=torch.long))
            input_names = ["input_features", "length"]
            dynamic_axes = {"input_features": {0: "B", 2: "T"}, "length": {0: "B"}}
        else:
            dummy_args = (
                torch.randn(B, t_sub, d_model),
                torch.randn(1, pos_len, d_model),
                torch.zeros(B, t_sub, t_sub, dtype=torch.bool),
                torch.zeros(B, t_sub, dtype=torch.bool),
                torch.tensor([t_sub], dtype=torch.long),
            )
            input_names = ["hidden_states", "pos_emb", "att_mask", "pad_mask", "length"]
            dynamic_axes = {
                "hidden_states": {0: "B", 1: "S"},
                "pos_emb": {1: "P"},
                "att_mask": {0: "B", 1: "S", 2: "S"},
                "pad_mask": {0: "B", 1: "S"},
                "length": {0: "B"},
            }

        if is_last:
            output_names = ["encoder_out", "encoder_lengths"]
            dynamic_axes.update({"encoder_out": {0: "B", 1: "S"}, "encoder_lengths": {0: "B"}})
        else:
            output_names = ["hidden_states_out", "pos_emb_out", "att_mask_out", "pad_mask_out", "length_out"]
            dynamic_axes.update({
                "hidden_states_out": {0: "B", 1: "S"},
                "pos_emb_out": {1: "P"},
                "att_mask_out": {0: "B", 1: "S", 2: "S"},
                "pad_mask_out": {0: "B", 1: "S"},
                "length_out": {0: "B"},
            })

        out_path = str(data_dir / f"encoder-{idx}.onnx")
        torch.onnx.export(
            wrapper, dummy_args, out_path,
            opset_version=OPSET, dynamo=False,
            input_names=input_names, output_names=output_names,
            dynamic_axes=dynamic_axes,
        )
        print(f" -> {os.path.getsize(out_path) / (1024**2):.0f} MB")


def export_cross_kv(model, data_dir: Path):
    """Export CrossKVForExport to ``data_dir/cross_kv.onnx``.

    Output names are ``cross_{k,v}_{i}``, one pair per decoder layer.
    """
    print("\n cross_kv")
    wrapper = CrossKVForExport(model)
    wrapper.eval()

    B, src_len = 1, 62
    output_names = []
    dynamic_axes = {"encoder_out": {0: "B", 1: "S"}}
    for i in range(NUM_DECODER_LAYERS):
        for kv in ["k", "v"]:
            name = f"cross_{kv}_{i}"
            output_names.append(name)
            dynamic_axes[name] = {0: "B", 2: "S"}

    out_path = str(data_dir / "cross_kv.onnx")
    torch.onnx.export(
        wrapper, (torch.randn(B, src_len, DECODER_HIDDEN),), out_path,
        opset_version=OPSET, dynamo=False,
        input_names=["encoder_out"], output_names=output_names,
        dynamic_axes=dynamic_axes,
    )
    print(f" -> {os.path.getsize(out_path) / (1024**2):.0f} MB")


def export_decoder(model, data_dir: Path):
    """Export DecoderForExport to ``data_dir/decoder.onnx``.

    The ONNX input order is:
      1. input_ids, positions, cross_attention_mask;
      2. self_k_in_i / self_v_in_i for every layer;
      3. cross_k_in_i / cross_v_in_i for every layer.

    The outputs are logits followed by self_k_out_i / self_v_out_i.
    """
    print("\n decoder")
    dec = DecoderForExport(model)
    dec.eval()

    B, tgt_len, src_len, past_len = 1, 1, 62, 10
    args = (
        torch.ones(B, tgt_len, dtype=torch.long),
        torch.tensor([[past_len]], dtype=torch.long),
        torch.zeros(B, 1, 1, src_len),
    )
    for _ in range(NUM_DECODER_LAYERS):
        args += (torch.randn(B, DECODER_HEADS, past_len, HEAD_DIM),
                 torch.randn(B, DECODER_HEADS, past_len, HEAD_DIM))
    for _ in range(NUM_DECODER_LAYERS):
        args += (torch.randn(B, DECODER_HEADS, src_len, HEAD_DIM),
                 torch.randn(B, DECODER_HEADS, src_len, HEAD_DIM))

    input_names = ["input_ids", "positions", "cross_attention_mask"]
    output_names = ["logits"]
    dynamic_axes = {
        "input_ids": {0: "B", 1: "T"},
        "positions": {0: "B", 1: "T"},
        "cross_attention_mask": {0: "B", 3: "S"},
        "logits": {0: "B", 1: "T"},
    }
    for i in range(NUM_DECODER_LAYERS):
        input_names += [f"self_k_in_{i}", f"self_v_in_{i}"]
        output_names += [f"self_k_out_{i}", f"self_v_out_{i}"]
        dynamic_axes[f"self_k_in_{i}"] = {0: "B", 2: "P"}
        dynamic_axes[f"self_v_in_{i}"] = {0: "B", 2: "P"}
        dynamic_axes[f"self_k_out_{i}"] = {0: "B", 2: "KV"}
        dynamic_axes[f"self_v_out_{i}"] = {0: "B", 2: "KV"}
    for i in range(NUM_DECODER_LAYERS):
        input_names += [f"cross_k_in_{i}", f"cross_v_in_{i}"]
        dynamic_axes[f"cross_k_in_{i}"] = {0: "B", 2: "S"}
        dynamic_axes[f"cross_v_in_{i}"] = {0: "B", 2: "S"}

    out_path = str(data_dir / "decoder.onnx")
    torch.onnx.export(
        dec, args, out_path,
        opset_version=OPSET, dynamo=False,
        input_names=input_names, output_names=output_names,
        dynamic_axes=dynamic_axes,
    )
    print(f" -> {os.path.getsize(out_path) / (1024**2):.0f} MB")


# ---------------------------------------------------------------------------
# Consolidation — inline all external data into self-contained .onnx files
# ---------------------------------------------------------------------------

def _force_inline(model):
    """Clear external_data references so onnx.save writes everything inline.

    Covers graph initializers and tensor-valued node attributes (``t`` and
    ``tensors``), recursing into subgraph attributes (``g`` and ``graphs``).
    Intended to run after the tensor bytes have been loaded into memory.
    """
    def _clear(tensor):
        if tensor.external_data:
            del tensor.external_data[:]
            tensor.data_location = onnx.TensorProto.DEFAULT

    def _walk(graph):
        for tensor in graph.initializer:
            _clear(tensor)
        for node in graph.node:
            for attr in node.attribute:
                # HasField: a protobuf message is always truthy, so testing
                # `attr.t` directly cannot tell whether the field is set.
                if attr.HasField("t"):
                    _clear(attr.t)
                for tensor in attr.tensors:
                    _clear(tensor)
                if attr.HasField("g"):
                    _walk(attr.g)
                for subgraph in attr.graphs:
                    _walk(subgraph)

    _walk(model.graph)


def consolidate_to_root(data_dir: Path):
    """Load ONNX models with external data, save as self-contained files to repo root.

    Every ``*.onnx`` file in ``data_dir`` is written under the same file name
    to the current working directory (the repo root when the script is run
    from there). ``data_dir`` is then deleted.
    """
    for onnx_path in sorted(data_dir.glob("*.onnx")):
        model = onnx.load(str(onnx_path), load_external_data=True)
        _force_inline(model)
        dest = Path(onnx_path.name)
        onnx.save(model, str(dest))
        print(f" {dest} ({dest.stat().st_size / (1024**2):.0f} MB)")
    shutil.rmtree(data_dir)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Load the model, export all ONNX pieces, then consolidate them."""
    # Start from an empty scratch directory. consolidate_to_root() globs every
    # *.onnx in it, so leftovers from an interrupted run, or from one with a
    # different NUM_ENCODER_SPLITS, would otherwise be copied out too. The
    # directory is deleted at the end of a successful run anyway.
    if DATA_DIR.exists():
        shutil.rmtree(DATA_DIR)
    DATA_DIR.mkdir()

    print(f"Loading {MODEL_ID}")
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_ID, trust_remote_code=True, dtype=torch.float32,
    )
    model.eval()

    print("\nExporting...")
    with torch.no_grad():
        export_encoder_splits(model, DATA_DIR)
        export_cross_kv(model, DATA_DIR)
        export_decoder(model, DATA_DIR)

    print("\nConsolidating to root...")
    consolidate_to_root(DATA_DIR)
    print("\nDone.")


if __name__ == "__main__":
    main()