| |
| """ |
| Transcribe audio using the exported ONNX models. |
| |
| Usage: |
| python transcribe.py # download demo samples and transcribe |
| python transcribe.py audio.wav |
| python transcribe.py audio_dir/ |
| python transcribe.py audio.wav es # specify language |
| """ |
|
|
| import glob |
| import os |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import librosa |
| import numpy as np |
| import onnxruntime as ort |
| from transformers import AutoProcessor |
|
|
# Hugging Face model id handed to AutoProcessor.from_pretrained() in load_models().
MODEL_ID: str = "CohereLabs/cohere-transcribe-03-2026"

# Number of decoder layers. The decoder graph is fed one self-attention K/V
# pair and one cross-attention K/V pair per layer.
NUM_DECODER_LAYERS: int = 8

# Head count and per-head width; greedy_decode() uses them to shape the empty
# self-attention cache as (batch, DECODER_HEADS, 0, HEAD_DIM).
DECODER_HEADS: int = 8
HEAD_DIM: int = 128
|
|
|
|
| |
| |
| |
|
|
class OnnxEncoder:
    """Run the encoder as a chain of ONNX graphs (encoder-0.onnx .. encoder-N.onnx).

    The first split is fed ``input_features`` and ``length``. Every later
    split receives the previous split's outputs, matched to its declared
    inputs by position.
    """

    def __init__(self, paths: list[str]):
        """Open one CPU inference session per path, keeping the given order."""
        sessions = []
        for onnx_path in paths:
            sessions.append(
                ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
            )
        self.splits = sessions

    def __call__(self, input_features: np.ndarray, length: np.ndarray):
        """Push the inputs through every split and return the last split's first two outputs."""
        first_feeds = {"input_features": input_features, "length": length}
        current = self.splits[0].run(None, first_feeds)

        for session in self.splits[1:]:
            # Output j of the previous split becomes input j of this split.
            stage_feeds = {}
            for position, node in enumerate(session.get_inputs()):
                stage_feeds[node.name] = current[position]
            current = session.run(None, stage_feeds)

        encoder_out = current[0]
        out_lengths = current[1]
        return encoder_out, out_lengths
|
|
|
|
class OnnxCrossKV:
    """Wrap the ONNX graph that turns the encoder output into cross-attention K/V tensors."""

    def __init__(self, path: str):
        """Open a CPU inference session for the graph stored at ``path``."""
        cpu_only = ["CPUExecutionProvider"]
        self.sess = ort.InferenceSession(path, providers=cpu_only)

    def __call__(self, encoder_out: np.ndarray):
        """Run the graph on ``encoder_out`` and return all of its outputs as a list."""
        feeds = {"encoder_out": encoder_out}
        outputs = self.sess.run(None, feeds)
        return [tensor for tensor in outputs]
|
|
|
|
class OnnxDecoder:
    """Wrap the single-step decoder ONNX graph with explicit K/V cache inputs."""

    def __init__(self, path: str):
        """Open a CPU inference session for the graph stored at ``path``."""
        cpu_only = ["CPUExecutionProvider"]
        self.sess = ort.InferenceSession(path, providers=cpu_only)

    def __call__(self, input_ids, positions, cross_attn_mask, self_kv, cross_kv):
        """Run one decoder step.

        ``self_kv`` and ``cross_kv`` are flat lists laid out as
        ``[k_0, v_0, k_1, v_1, ...]`` with one K/V pair per decoder layer.

        Returns:
            ``(first_output, remaining_outputs)``. greedy_decode() treats the
            first output as logits and the rest as the updated self-attention
            cache.
        """
        feeds = {
            "input_ids": input_ids,
            "positions": positions,
            "cross_attention_mask": cross_attn_mask,
        }

        # Self-attention cache first, then cross-attention cache.
        for layer in range(NUM_DECODER_LAYERS):
            key_slot = 2 * layer
            feeds[f"self_k_in_{layer}"] = self_kv[key_slot]
            feeds[f"self_v_in_{layer}"] = self_kv[key_slot + 1]
        for layer in range(NUM_DECODER_LAYERS):
            key_slot = 2 * layer
            feeds[f"cross_k_in_{layer}"] = cross_kv[key_slot]
            feeds[f"cross_v_in_{layer}"] = cross_kv[key_slot + 1]

        results = self.sess.run(None, feeds)
        head = results[0]
        tail = list(results[1:])
        return head, tail
|
|
|
|
| |
| |
| |
|
|
def greedy_decode(encoder, cross_kv_model, decoder, input_features, length,
                  prompt_ids, eos_token_id, max_new_tokens=256):
    """Greedily decode token ids for one utterance (batch size is fixed to 1).

    The encoder and cross-K/V models run once. The prompt is then fed to the
    decoder one token at a time to fill the self-attention cache, after which
    the arg-max token of each step is fed back until EOS or the token budget
    is reached.

    Args:
        encoder: Callable ``(input_features, length) -> (encoder_out, lengths)``.
        cross_kv_model: Callable ``(encoder_out) -> list`` of cross-attention
            K/V arrays; axis 2 of the first array is taken as the source length.
        decoder: Callable ``(input_ids, positions, cross_mask, self_kv,
            cross_kv) -> (logits, new_self_kv)``.
        input_features: Passed to ``encoder`` unchanged.
        length: Passed to ``encoder`` unchanged.
        prompt_ids: Non-empty sequence of prompt token ids (NumPy array or
            anything ``np.asarray`` accepts).
        eos_token_id: Token id that stops generation; it is not returned.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated token ids as a list of ints, without prompt or EOS.

    Raises:
        ValueError: If ``prompt_ids`` is empty, because there would be no
            logits to start generating from.
    """
    prompt = np.asarray(prompt_ids).reshape(-1).tolist()
    if not prompt:
        # Fail before running the encoder; without a prompt step `logits`
        # would never be assigned.
        raise ValueError("prompt_ids must contain at least one token")

    B = 1

    # Encoder and cross-attention K/V are computed once per utterance.
    encoder_out, enc_lengths = encoder(input_features, length)
    cross_kv = cross_kv_model(encoder_out)
    src_len = cross_kv[0].shape[2]

    # Additive mask: 0 on valid encoder frames, large negative on padding.
    enc_len = int(enc_lengths[0])
    cross_mask = np.zeros((B, 1, 1, src_len), dtype=np.float32)
    if enc_len < src_len:
        cross_mask[:, :, :, enc_len:] = -1e9

    # Empty self-attention cache: one zero-length K and V per layer.
    self_kv = [
        np.zeros((B, DECODER_HEADS, 0, HEAD_DIM), dtype=np.float32)
        for _ in range(2 * NUM_DECODER_LAYERS)
    ]

    # Prefill: feed the prompt token by token; keep the last step's logits.
    logits = None
    for step, tok in enumerate(prompt):
        logits, self_kv = decoder(
            np.array([[tok]], dtype=np.int64),
            np.array([[step]], dtype=np.int64),
            cross_mask, self_kv, cross_kv,
        )

    generated = []
    for gen_step in range(max_new_tokens):
        next_id = int(np.argmax(logits[0, -1, :]))
        if next_id == eos_token_id:
            break
        generated.append(next_id)
        if gen_step + 1 >= max_new_tokens:
            # Budget exhausted: the logits of another step would never be
            # read, so skip the redundant decoder call.
            break
        logits, self_kv = decoder(
            np.array([[next_id]], dtype=np.int64),
            np.array([[len(prompt) + gen_step]], dtype=np.int64),
            cross_mask, self_kv, cross_kv,
        )

    return generated
|
|
|
|
def _encoder_split_key(path: str) -> tuple[int, int, str]:
    """Sort key that orders ``encoder-<N>.onnx`` paths by the integer ``N``.

    Paths whose suffix after the first ``-`` is not a plain integer sort after
    all numbered splits, in lexicographic order.
    """
    suffix = Path(path).stem.partition("-")[2]
    if suffix.isdecimal():
        return (0, int(suffix), path)
    return (1, 0, path)


def load_models():
    """Load the processor and the ONNX models once.

    Encoder splits are discovered as ``encoder-*.onnx`` in the current working
    directory and chained in numeric order of their index. ``cross_kv.onnx``
    and ``decoder.onnx`` are loaded from the same directory.

    Returns:
        ``(processor, encoder, cross_kv_model, decoder)``.

    Exits the process with status 1 when no encoder split is found.
    """
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

    # A plain lexicographic sort would place encoder-10 before encoder-2 and
    # chain the splits in the wrong order, so sort on the numeric index.
    encoder_paths = sorted(glob.glob("encoder-*.onnx"), key=_encoder_split_key)
    if not encoder_paths:
        print("No encoder-*.onnx files found. Run export_onnx.py first.")
        sys.exit(1)

    print(f"Loading {len(encoder_paths)} encoder splits, cross_kv, decoder...")
    encoder = OnnxEncoder(encoder_paths)
    cross_kv_model = OnnxCrossKV("cross_kv.onnx")
    decoder = OnnxDecoder("decoder.onnx")
    return processor, encoder, cross_kv_model, decoder
|
|
|
|
def transcribe_file(processor, encoder, cross_kv_model, decoder, wav_path, language="en"):
    """Transcribe one audio file.

    The audio is loaded at 16 kHz, converted to features by ``processor`` and
    decoded with greedy_decode(). Only the decode call is timed; audio loading
    and feature extraction are excluded.

    Returns:
        ``(text, audio_duration, wall_time)``, with both durations in seconds.
    """
    sample_rate = 16000
    tokenizer = processor.tokenizer

    # Decoder prompt: fixed control tokens with the language tag repeated twice.
    pnc_token = "<|pnc|>"
    prompt_head = "<|startofcontext|><|startoftranscript|><|emo:undefined|>"
    prompt_tail = f"<|{language}|><|{language}|>{pnc_token}<|noitn|><|notimestamp|><|nodiarize|>"
    prompt_text = prompt_head + prompt_tail

    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
    prompt_ids = np.array(encoded_prompt, dtype=np.int64)
    eos_token_id = tokenizer.eos_token_id

    waveform, _ = librosa.load(wav_path, sr=sample_rate)
    audio_duration = len(waveform) / sample_rate
    model_inputs = processor(
        audio=[waveform],
        text=[prompt_text],
        sampling_rate=sample_rate,
        return_tensors="np",
    )
    features = model_inputs["input_features"]
    # The encoder's `length` input is the size of the features' axis 2.
    length = np.array([features.shape[2]], dtype=np.int64)

    started = time.perf_counter()
    generated_ids = greedy_decode(
        encoder, cross_kv_model, decoder, features, length, prompt_ids, eos_token_id,
    )
    wall_s = time.perf_counter() - started

    decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return decoded.strip(), audio_duration, wall_s
|
|
|
|
def _real_time_factor(wall_s: float, audio_s: float) -> float:
    """Return ``wall_s / audio_s``, or NaN when there is no audio to divide by."""
    return wall_s / audio_s if audio_s > 0 else float("nan")


def transcribe_files(processor, encoder, cross_kv_model, decoder, wav_paths, language="en"):
    """Transcribe a list of files with the same language and print the results.

    One line is printed per file with its duration, real-time factor
    (wall time / audio duration) and transcript. When more than one file is
    given, a summary line with the totals follows.

    A zero-length file is reported with ``RTF=nan`` rather than aborting the
    whole run with a ZeroDivisionError.
    """
    total_audio_s = 0.0
    total_wall_s = 0.0

    for wav_path in wav_paths:
        text, audio_duration, wall_s = transcribe_file(
            processor, encoder, cross_kv_model, decoder, wav_path, language,
        )
        total_audio_s += audio_duration
        total_wall_s += wall_s
        rtf = _real_time_factor(wall_s, audio_duration)
        print(f"{wav_path} ({audio_duration:.1f}s, RTF={rtf:.2f}): {text}")

    if len(wav_paths) > 1:
        rtf_total = _real_time_factor(total_wall_s, total_audio_s)
        print(f"\nTotal: {total_audio_s:.1f}s audio in {total_wall_s:.1f}s wall (RTF={rtf_total:.2f})")
|
|
|
|
def download_demo_samples(n_per_lang=2):
    """Stream two datasets, pick random samples and save them as WAV files.

    For each source, at most the first 51 streamed examples are considered
    and ``n_per_lang`` of them are kept by reservoir sampling. The kept clips
    are written to ``demo_audio/<lang>_<i>.wav``.

    Returns:
        A list of ``(wav_path, language_tag)`` tuples in the order written.
    """
    import random
    import soundfile as sf
    from datasets import load_dataset

    out_dir = Path("demo_audio")
    out_dir.mkdir(exist_ok=True)

    # (dataset id, config, split, language tag)
    corpora = [
        ("hf-internal-testing/librispeech_asr_dummy", "clean", "validation", "en"),
        ("facebook/multilingual_librispeech", "spanish", "test", "es"),
    ]

    saved = []
    for dataset_id, config, split, lang_tag in corpora:
        print(f"Downloading {n_per_lang} random {lang_tag} samples from {dataset_id}...")
        stream = load_dataset(dataset_id, config, split=split, streaming=True)

        picked = []
        for seen, example in enumerate(stream):
            if seen >= n_per_lang:
                # Reservoir step: replace a kept example with decreasing probability.
                slot = random.randint(0, seen)
                if slot < n_per_lang:
                    picked[slot] = example
            else:
                picked.append(example)
            if seen >= 50:
                break

        for rank, example in enumerate(picked):
            clip = example["audio"]
            samples = clip["array"]
            rate = clip["sampling_rate"]
            wav_path = out_dir / f"{lang_tag}_{rank}.wav"
            sf.write(str(wav_path), samples, rate)
            saved.append((str(wav_path), lang_tag))
            print(f" {wav_path} ({len(samples) / rate:.1f}s)")

    return saved
|
|
|
|
def main():
    """Command-line entry point.

    With no arguments, demo samples are downloaded and each is transcribed in
    its own language. Otherwise the first argument is a WAV file or a
    directory of ``*.wav`` files, and the optional second argument is the
    language tag (default ``"en"``).
    """
    models = load_models()
    cli_args = sys.argv[1:]

    if not cli_args:
        # Demo mode: every sample carries its own language tag.
        for wav_path, language in download_demo_samples():
            text, dur, wall = transcribe_file(*models, wav_path, language)
            print(f"{wav_path} ({dur:.1f}s, RTF={wall/dur:.2f}): {text}")
        return

    target = cli_args[0]
    language = cli_args[1] if len(cli_args) > 1 else "en"

    if not os.path.isdir(target):
        wav_paths = [target]
    else:
        wav_paths = sorted(glob.glob(os.path.join(target, "*.wav")))
        if not wav_paths:
            print(f"No .wav files found in {target}/")
            sys.exit(1)

    transcribe_files(*models, wav_paths, language)


if __name__ == "__main__":
    main()
|
|