#!/usr/bin/env python3
"""
Transcribe audio using the exported ONNX models.

Usage:
    python transcribe.py                 # download demo samples and transcribe
    python transcribe.py audio.wav
    python transcribe.py audio_dir/
    python transcribe.py audio.wav es    # specify language
"""

import glob
import os
import re
import sys
import time
from pathlib import Path

import librosa
import numpy as np
import onnxruntime as ort
from transformers import AutoProcessor

MODEL_ID = "CohereLabs/cohere-transcribe-03-2026"
NUM_DECODER_LAYERS = 8
DECODER_HEADS = 8
HEAD_DIM = 128
# Audio is resampled to this rate on load and handed to the processor at this rate.
SAMPLE_RATE = 16000

# Encoder split file names, e.g. "encoder-3.onnx"; the number is the split's position.
_ENCODER_SPLIT_RE = re.compile(r"encoder-(\d+)\.onnx")


# ---------------------------------------------------------------------------
# ONNX model wrappers
# ---------------------------------------------------------------------------

class OnnxEncoder:
    """Chains encoder-0.onnx .. encoder-N.onnx.

    The first split is fed ``input_features`` and ``length``; each later split
    receives the previous split's outputs, matched positionally to its inputs.
    """

    def __init__(self, paths: list[str]):
        """Open one CPU inference session per split.

        Args:
            paths: split model files, already in execution order (split 0 first).

        Raises:
            ValueError: if ``paths`` is empty.
        """
        if not paths:
            raise ValueError("OnnxEncoder needs at least one encoder split")
        self.splits = [ort.InferenceSession(p, providers=["CPUExecutionProvider"]) for p in paths]

    def __call__(self, input_features: np.ndarray, length: np.ndarray):
        """Run every split in order and return ``(encoder_out, encoder_lengths)``."""
        outs = self.splits[0].run(None, {"input_features": input_features, "length": length})
        for split in self.splits[1:]:
            # Assumes split k's outputs are ordered exactly like split k+1's inputs.
            feeds = {inp.name: outs[j] for j, inp in enumerate(split.get_inputs())}
            outs = split.run(None, feeds)
        # Assumes the last split emits encoder_out first, then encoder_lengths — TODO confirm.
        return outs[0], outs[1]  # encoder_out, encoder_lengths


class OnnxCrossKV:
    """Runs cross_kv.onnx on the encoder output.

    The returned list is what OnnxDecoder feeds as ``cross_k_in_i`` /
    ``cross_v_in_i``: interleaved K, V entries, two per decoder layer.
    """

    def __init__(self, path: str):
        self.sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])

    def __call__(self, encoder_out: np.ndarray):
        return list(self.sess.run(None, {"encoder_out": encoder_out}))


class OnnxDecoder:
    """Runs one decoder.onnx step with explicit self- and cross-attention caches."""

    def __init__(self, path: str):
        self.sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])

    def __call__(self, input_ids, positions, cross_attn_mask, self_kv, cross_kv):
        """Decode one step.

        ``self_kv`` and ``cross_kv`` are flat lists laid out as
        ``[k_0, v_0, k_1, v_1, ...]`` over NUM_DECODER_LAYERS layers.

        Returns:
            ``(logits, self_kv_out)``, where ``self_kv_out`` holds every output
            after the logits and is passed back as ``self_kv`` on the next step.
        """
        feeds = {
            "input_ids": input_ids,
            "positions": positions,
            "cross_attention_mask": cross_attn_mask,
        }
        for i in range(NUM_DECODER_LAYERS):
            feeds[f"self_k_in_{i}"] = self_kv[i * 2]
            feeds[f"self_v_in_{i}"] = self_kv[i * 2 + 1]
        for i in range(NUM_DECODER_LAYERS):
            feeds[f"cross_k_in_{i}"] = cross_kv[i * 2]
            feeds[f"cross_v_in_{i}"] = cross_kv[i * 2 + 1]
        outs = self.sess.run(None, feeds)
        # Assumes the updated self K/V outputs follow the logits in k_0, v_0, ... order.
        return outs[0], list(outs[1:])  # logits, self_kv_out


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------

def greedy_decode(encoder, cross_kv_model, decoder, input_features, length,
                  prompt_ids, eos_token_id, max_new_tokens=256):
    """Greedily decode one utterance and return the generated token ids.

    The prompt tokens are fed one at a time to fill the self-attention cache.
    Then the argmax token is emitted until ``eos_token_id`` appears or
    ``max_new_tokens`` tokens have been produced. The result excludes the
    prompt and the EOS token.

    Raises:
        ValueError: if ``prompt_ids`` is empty, because there would be no
            logits to start generation from.
    """
    B = 1  # the mask and cache below are built for a single utterance

    prompt = np.asarray(prompt_ids).reshape(-1).tolist()
    if not prompt:
        raise ValueError("prompt_ids must contain at least one token")

    # Encode once; the cross-attention K/V are reused for every decoder step.
    encoder_out, enc_lengths = encoder(input_features, length)
    cross_kv = cross_kv_model(encoder_out)
    # Assumes cross K/V are (batch, heads, src_len, head_dim) — TODO confirm.
    src_len = cross_kv[0].shape[2]

    # Additive mask: 0 on valid encoder frames, -1e9 on frames past enc_len.
    enc_len = int(enc_lengths[0])
    cross_mask = np.zeros((B, 1, 1, src_len), dtype=np.float32)
    if enc_len < src_len:
        cross_mask[:, :, :, enc_len:] = -1e9

    # Empty self-attention cache: a K and a V entry per layer, each with zero cached steps.
    self_kv = [
        np.zeros((B, DECODER_HEADS, 0, HEAD_DIM), dtype=np.float32)
        for _ in range(2 * NUM_DECODER_LAYERS)
    ]

    # Feed the prompt; only the logits after its last token are used.
    logits = None
    for position, tok in enumerate(prompt):
        logits, self_kv = decoder(
            np.array([[tok]], dtype=np.int64),
            np.array([[position]], dtype=np.int64),
            cross_mask, self_kv, cross_kv,
        )

    # Generate
    generated = []
    position = len(prompt)
    for _ in range(max_new_tokens):
        next_id = int(np.argmax(logits[0, -1, :]))
        if next_id == eos_token_id:
            break
        generated.append(next_id)
        if len(generated) == max_new_tokens:
            # Budget exhausted: another decoder step's logits would never be read.
            break
        logits, self_kv = decoder(
            np.array([[next_id]], dtype=np.int64),
            np.array([[position]], dtype=np.int64),
            cross_mask, self_kv, cross_kv,
        )
        position += 1
    return generated


def _find_encoder_splits(model_dir):
    """Return the ``encoder-<k>.onnx`` paths in ``model_dir``, ordered by numeric ``k``.

    A plain lexicographic sort would place ``encoder-10.onnx`` before
    ``encoder-2.onnx`` and chain the splits in the wrong order.
    """
    indexed = []
    for path in glob.glob(os.path.join(glob.escape(model_dir), "encoder-*.onnx")):
        match = _ENCODER_SPLIT_RE.fullmatch(os.path.basename(path))
        if match:
            indexed.append((int(match.group(1)), path))
    return [path for _, path in sorted(indexed)]


def load_models(model_dir="."):
    """Load the processor and the ONNX models once.

    Args:
        model_dir: directory containing ``encoder-<k>.onnx``, ``cross_kv.onnx``
            and ``decoder.onnx``. Defaults to the current directory.

    Returns:
        ``(processor, encoder, cross_kv_model, decoder)``.

    Exits the process with status 1 if any required model file is missing.
    The file checks run before the processor is fetched, so a missing export
    fails fast.
    """
    encoder_paths = _find_encoder_splits(model_dir)
    if not encoder_paths:
        print("No encoder-*.onnx files found. Run export_onnx.py first.", file=sys.stderr)
        sys.exit(1)
    cross_kv_path = os.path.join(model_dir, "cross_kv.onnx")
    decoder_path = os.path.join(model_dir, "decoder.onnx")
    for required in (cross_kv_path, decoder_path):
        if not os.path.isfile(required):
            print(f"Missing {required}. Run export_onnx.py first.", file=sys.stderr)
            sys.exit(1)

    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

    print(f"Loading {len(encoder_paths)} encoder splits, cross_kv, decoder...")
    encoder = OnnxEncoder(encoder_paths)
    cross_kv_model = OnnxCrossKV(cross_kv_path)
    decoder = OnnxDecoder(decoder_path)
    return processor, encoder, cross_kv_model, decoder


def transcribe_file(processor, encoder, cross_kv_model, decoder, wav_path, language="en"):
    """Transcribe a single file. Returns (text, audio_duration, wall_time).

    ``wall_time`` covers only the ONNX encode and decode, not audio loading
    or feature extraction.
    """
    pnc_token = "<|pnc|>"
    prompt_text = (
        "<|startofcontext|><|startoftranscript|><|emo:undefined|>"
        f"<|{language}|><|{language}|>{pnc_token}<|noitn|><|notimestamp|><|nodiarize|>"
    )
    prompt_ids = np.array(
        processor.tokenizer.encode(prompt_text, add_special_tokens=False), dtype=np.int64
    )
    eos_token_id = processor.tokenizer.eos_token_id

    audio, _ = librosa.load(wav_path, sr=SAMPLE_RATE)
    audio_duration = len(audio) / SAMPLE_RATE

    inputs = processor(audio=[audio], text=[prompt_text], sampling_rate=SAMPLE_RATE,
                       return_tensors="np")
    features = inputs["input_features"]
    # Frame count taken from axis 2; presumably features are (batch, n_mels, frames) — TODO confirm.
    length = np.array([features.shape[2]], dtype=np.int64)

    t0 = time.perf_counter()
    generated_ids = greedy_decode(
        encoder, cross_kv_model, decoder, features, length, prompt_ids, eos_token_id,
    )
    wall_s = time.perf_counter() - t0

    text = processor.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    return text, audio_duration, wall_s


def _format_rtf(wall_s, audio_s):
    """Format the real-time factor (wall time / audio time) with two decimals.

    Returns ``"n/a"`` for zero-length audio instead of dividing by zero.
    """
    if audio_s <= 0:
        return "n/a"
    return f"{wall_s / audio_s:.2f}"


def transcribe_files(processor, encoder, cross_kv_model, decoder, wav_paths, language="en"):
    """Transcribe a list of files with the same language.

    Prints one line per file and, for more than one file, an aggregate RTF.
    """
    total_audio_s = 0.0
    total_wall_s = 0.0

    for wav_path in wav_paths:
        text, audio_duration, wall_s = transcribe_file(
            processor, encoder, cross_kv_model, decoder, wav_path, language,
        )
        total_audio_s += audio_duration
        total_wall_s += wall_s
        print(f"{wav_path} ({audio_duration:.1f}s, RTF={_format_rtf(wall_s, audio_duration)}): {text}")

    if len(wav_paths) > 1:
        print(
            f"\nTotal: {total_audio_s:.1f}s audio in {total_wall_s:.1f}s wall "
            f"(RTF={_format_rtf(total_wall_s, total_audio_s)})"
        )


def download_demo_samples(n_per_lang=2):
    """Download random English and Spanish samples via streaming.

    Writes WAV files under ``demo_audio/`` and returns a list of
    ``(wav_path, language)`` tuples.
    """
    import random

    import soundfile as sf
    from datasets import load_dataset

    demo_dir = Path("demo_audio")
    demo_dir.mkdir(exist_ok=True)

    wav_paths = []
    sources = [
        ("hf-internal-testing/librispeech_asr_dummy", "clean", "validation", "en"),
        ("facebook/multilingual_librispeech", "spanish", "test", "es"),
    ]

    for dataset_id, config, split, lang_tag in sources:
        print(f"Downloading {n_per_lang} random {lang_tag} samples from {dataset_id}...")
        ds = load_dataset(dataset_id, config, split=split, streaming=True)

        # Reservoir sampling over the first 51 streamed samples (indices 0..50),
        # so the whole split is never downloaded.
        reservoir = []
        for idx, sample in enumerate(ds):
            if idx < n_per_lang:
                reservoir.append(sample)
            else:
                j = random.randint(0, idx)
                if j < n_per_lang:
                    reservoir[j] = sample
            if idx >= 50:
                break

        for i, sample in enumerate(reservoir):
            audio = sample["audio"]
            wav_path = demo_dir / f"{lang_tag}_{i}.wav"
            sf.write(str(wav_path), audio["array"], audio["sampling_rate"])
            wav_paths.append((str(wav_path), lang_tag))
            print(f"  {wav_path} ({len(audio['array']) / audio['sampling_rate']:.1f}s)")

    return wav_paths


def main():
    """CLI entry point; see the module docstring for usage."""
    if len(sys.argv) < 2:
        # No args: download demo samples and transcribe them
        processor, encoder, cross_kv_model, decoder = load_models()
        samples = download_demo_samples()
        for wav_path, language in samples:
            text, dur, wall = transcribe_file(
                processor, encoder, cross_kv_model, decoder, wav_path, language,
            )
            print(f"{wav_path} ({dur:.1f}s, RTF={_format_rtf(wall, dur)}): {text}")
        return

    target = sys.argv[1]
    language = sys.argv[2] if len(sys.argv) > 2 else "en"

    # Validate the target before paying for model loading.
    if os.path.isdir(target):
        wav_paths = sorted(glob.glob(os.path.join(glob.escape(target), "*.wav")))
        if not wav_paths:
            print(f"No .wav files found in {target}/", file=sys.stderr)
            sys.exit(1)
    elif os.path.isfile(target):
        wav_paths = [target]
    else:
        print(f"No such file or directory: {target}", file=sys.stderr)
        sys.exit(1)

    processor, encoder, cross_kv_model, decoder = load_models()
    transcribe_files(processor, encoder, cross_kv_model, decoder, wav_paths, language)


if __name__ == "__main__":
    main()