eschmidbauer's picture
update transcribe.py to download samples for inference
b76c79e
raw
history blame contribute delete
8.92 kB
#!/usr/bin/env python3
"""
Transcribe audio using the exported ONNX models.
Usage:
python transcribe.py # download demo samples and transcribe
python transcribe.py audio.wav
python transcribe.py audio_dir/
python transcribe.py audio.wav es # specify language
"""
import glob
import os
import sys
import time
from pathlib import Path
import librosa
import numpy as np
import onnxruntime as ort
from transformers import AutoProcessor
# Hugging Face repo whose processor/tokenizer load_models() fetches.
MODEL_ID: str = "CohereLabs/cohere-transcribe-03-2026"
# Decoder geometry. NUM_DECODER_LAYERS sets how many self_/cross_ K/V inputs
# OnnxDecoder feeds; DECODER_HEADS and HEAD_DIM shape the empty self-attention
# cache built in greedy_decode. Presumably these must match the exported
# decoder.onnx — confirm against export_onnx.py.
NUM_DECODER_LAYERS: int = 8
DECODER_HEADS: int = 8
HEAD_DIM: int = 128
# ---------------------------------------------------------------------------
# ONNX model wrappers
# ---------------------------------------------------------------------------
class OnnxEncoder:
    """Run the encoder as a pipeline of ONNX sessions (encoder-0.onnx .. encoder-N.onnx).

    The first split receives the ``input_features`` / ``length`` feeds. Every later
    split receives the previous split's outputs, bound positionally to its declared
    inputs.
    """

    def __init__(self, paths: list[str]):
        self.splits = [
            ort.InferenceSession(path, providers=["CPUExecutionProvider"])
            for path in paths
        ]

    def __call__(self, input_features: np.ndarray, length: np.ndarray):
        """Return ``(encoder_out, encoder_lengths)``: the first two outputs of the last split."""
        head = self.splits[0]
        outputs = head.run(None, {"input_features": input_features, "length": length})
        for session in self.splits[1:]:
            # Output #k of the previous split feeds input #k of this one.
            feeds = {node.name: outputs[pos] for pos, node in enumerate(session.get_inputs())}
            outputs = session.run(None, feeds)
        encoder_out = outputs[0]
        encoder_lengths = outputs[1]
        return encoder_out, encoder_lengths
class OnnxCrossKV:
    """Thin wrapper around cross_kv.onnx.

    Takes the encoder output and returns every graph output as a list. OnnxDecoder
    consumes that list as alternating cross-attention K and V tensors, one pair per
    decoder layer.
    """

    def __init__(self, path: str):
        self.sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])

    def __call__(self, encoder_out: np.ndarray):
        """Run cross_kv.onnx on ``encoder_out`` and return all outputs as a list."""
        feeds = {"encoder_out": encoder_out}
        return [*self.sess.run(None, feeds)]
class OnnxDecoder:
    """One forward pass of decoder.onnx with explicit self- and cross-attention caches."""

    def __init__(self, path: str):
        self.sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])

    def __call__(self, input_ids, positions, cross_attn_mask, self_kv, cross_kv):
        """Run the decoder once.

        ``self_kv`` and ``cross_kv`` are flat sequences laid out as
        ``[k_0, v_0, k_1, v_1, ...]``, one K/V pair per decoder layer. They are bound
        to the ``self_k_in_<i>`` / ``self_v_in_<i>`` and ``cross_k_in_<i>`` /
        ``cross_v_in_<i>`` graph inputs.

        Returns:
            tuple: ``(logits, self_kv_out)``, where ``self_kv_out`` is a list of every
            graph output after the first.
        """
        feeds = dict(
            input_ids=input_ids,
            positions=positions,
            cross_attention_mask=cross_attn_mask,
        )
        # All self-attention entries first, then all cross-attention entries.
        for prefix, cache in (("self", self_kv), ("cross", cross_kv)):
            for layer in range(NUM_DECODER_LAYERS):
                feeds[f"{prefix}_k_in_{layer}"] = cache[2 * layer]
                feeds[f"{prefix}_v_in_{layer}"] = cache[2 * layer + 1]
        outputs = self.sess.run(None, feeds)
        logits = outputs[0]
        self_kv_out = list(outputs[1:])
        return logits, self_kv_out
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def greedy_decode(encoder, cross_kv_model, decoder, input_features, length,
                  prompt_ids, eos_token_id, max_new_tokens=256):
    """Greedy (argmax) autoregressive decoding for a single utterance.

    Args:
        encoder: callable ``(input_features, length) -> (encoder_out, encoder_lengths)``,
            e.g. OnnxEncoder.
        cross_kv_model: callable ``encoder_out -> list`` of cross-attention arrays,
            alternating K and V per decoder layer, e.g. OnnxCrossKV.
        decoder: callable ``(input_ids, positions, cross_mask, self_kv, cross_kv)
            -> (logits, new_self_kv)``, e.g. OnnxDecoder.
        input_features, length: encoder inputs for one utterance (batch size 1).
        prompt_ids: array-like of prompt token ids; must contain at least one id.
        eos_token_id: id that stops generation; it is not included in the result.
        max_new_tokens: maximum number of tokens to generate.

    Returns:
        list[int]: generated token ids, excluding the prompt and the EOS token.

    Raises:
        ValueError: if ``prompt_ids`` is empty (there would be no logits to start from).
    """
    # Batch size is fixed at 1: the mask, KV cache and token feeds are built for one input.
    B = 1

    all_ids = np.asarray(prompt_ids, dtype=np.int64).reshape(-1).tolist()
    if not all_ids:
        raise ValueError("prompt_ids must contain at least one token")

    # Encode once; the cross-attention K/V are reused by every decoder step.
    encoder_out, enc_lengths = encoder(input_features, length)
    cross_kv = cross_kv_model(encoder_out)
    # NOTE(review): assumes cross K is (batch, heads, src_len, head_dim), like the
    # self-attention cache below — confirm against export_onnx.py.
    src_len = cross_kv[0].shape[2]

    # Block encoder frames past the valid length with a large negative value
    # (presumably added to the attention scores inside decoder.onnx).
    enc_len = int(enc_lengths[0])
    cross_mask = np.zeros((B, 1, 1, src_len), dtype=np.float32)
    if enc_len < src_len:
        cross_mask[:, :, :, enc_len:] = -1e9

    # Empty self-attention cache: a K and a V tensor per layer with a zero-length
    # sequence axis (presumably extended by the decoder on each step).
    self_kv = [
        np.zeros((B, DECODER_HEADS, 0, HEAD_DIM), dtype=np.float32)
        for _ in range(2 * NUM_DECODER_LAYERS)
    ]

    # Feed the prompt one token at a time to fill the cache. Only the logits after
    # the last prompt token are used.
    for step, tok in enumerate(all_ids):
        logits, self_kv = decoder(
            np.array([[tok]], dtype=np.int64),
            np.array([[step]], dtype=np.int64),
            cross_mask, self_kv, cross_kv,
        )

    generated = []
    for gen_step in range(max_new_tokens):
        next_id = int(np.argmax(logits[0, -1, :]))
        if next_id == eos_token_id:
            break
        generated.append(next_id)
        if gen_step + 1 == max_new_tokens:
            # Token budget reached: another decoder pass would only produce
            # logits that are never read.
            break
        logits, self_kv = decoder(
            np.array([[next_id]], dtype=np.int64),
            # The token just appended follows the prompt and the gen_step tokens before it.
            np.array([[len(all_ids) + gen_step]], dtype=np.int64),
            cross_mask, self_kv, cross_kv,
        )
    return generated
def load_models():
    """Load the processor and the exported ONNX models once.

    Expects ``encoder-*.onnx``, ``cross_kv.onnx`` and ``decoder.onnx`` in the current
    working directory. Encoder splits are chained in numeric order of their index, so
    ``encoder-10.onnx`` runs after ``encoder-9.onnx``, not after ``encoder-1.onnx``
    (which plain lexicographic sorting would produce).

    The ONNX files are checked before the processor is fetched, so a missing export
    fails fast.

    Returns:
        tuple: ``(processor, encoder, cross_kv_model, decoder)``.

    Exits the process with status 1 if any required ONNX file is missing.
    """
    def split_key(path):
        # "encoder-<N>.onnx" -> N. Non-numeric names sort after all numbered splits.
        middle = os.path.basename(path)[len("encoder-"):-len(".onnx")]
        return (0, int(middle), path) if middle.isdecimal() else (1, 0, path)

    encoder_paths = sorted(glob.glob("encoder-*.onnx"), key=split_key)
    if not encoder_paths:
        print("No encoder-*.onnx files found. Run export_onnx.py first.", file=sys.stderr)
        sys.exit(1)
    missing = [p for p in ("cross_kv.onnx", "decoder.onnx") if not os.path.isfile(p)]
    if missing:
        print(f"Missing {', '.join(missing)}. Run export_onnx.py first.", file=sys.stderr)
        sys.exit(1)

    # NOTE: trust_remote_code=True runs Python code shipped with the model repo;
    # only use it with a MODEL_ID you trust.
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

    print(f"Loading {len(encoder_paths)} encoder splits, cross_kv, decoder...")
    encoder = OnnxEncoder(encoder_paths)
    cross_kv_model = OnnxCrossKV("cross_kv.onnx")
    decoder = OnnxDecoder("decoder.onnx")
    return processor, encoder, cross_kv_model, decoder
def transcribe_file(processor, encoder, cross_kv_model, decoder, wav_path, language="en"):
    """Transcribe a single audio file.

    librosa loads the audio at 16 kHz, the processor turns it into features, and the
    features are decoded greedily after a fixed special-token prompt that includes
    ``<|{language}|>``.

    Returns:
        tuple: ``(text, audio_duration, wall_time)``. ``text`` is the transcript with
        special tokens removed, ``audio_duration`` is the audio length in seconds, and
        ``wall_time`` is the seconds spent in greedy_decode only (feature extraction
        is not timed).
    """
    tokenizer = processor.tokenizer
    pnc_token = "<|pnc|>"
    prompt_text = (
        "<|startofcontext|><|startoftranscript|><|emo:undefined|>"
        f"<|{language}|><|{language}|>{pnc_token}<|noitn|><|notimestamp|><|nodiarize|>"
    )
    prompt_ids = np.asarray(
        tokenizer.encode(prompt_text, add_special_tokens=False), dtype=np.int64
    )
    eos_token_id = tokenizer.eos_token_id

    sample_rate = 16000
    audio, _ = librosa.load(wav_path, sr=sample_rate)
    audio_duration = len(audio) / sample_rate

    features = processor(
        audio=[audio], text=[prompt_text], sampling_rate=sample_rate, return_tensors="np"
    )["input_features"]
    # Frame count taken from axis 2. NOTE(review): assumes features are
    # (batch, n_mels, frames) — confirm against the processor.
    length = np.array([features.shape[2]], dtype=np.int64)

    started = time.perf_counter()
    generated_ids = greedy_decode(
        encoder, cross_kv_model, decoder, features, length, prompt_ids, eos_token_id,
    )
    elapsed = time.perf_counter() - started

    text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    return text, audio_duration, elapsed
def transcribe_files(processor, encoder, cross_kv_model, decoder, wav_paths, language="en"):
    """Transcribe several files with the same language and print one line per file.

    Each line reports the audio duration and the real-time factor
    (RTF = decode wall time / audio duration). When more than one file is processed,
    a final line reports the totals. Zero-length audio is reported as ``RTF=n/a``
    instead of raising ZeroDivisionError.

    Args:
        wav_paths: any iterable of audio file paths.
        language: language code passed to transcribe_file for every file.
    """
    def fmt_rtf(wall_s, audio_s):
        return f"{wall_s / audio_s:.2f}" if audio_s > 0 else "n/a"

    total_audio_s = 0.0
    total_wall_s = 0.0
    n_files = 0  # counted rather than len(), so generators work too
    for wav_path in wav_paths:
        text, audio_duration, wall_s = transcribe_file(
            processor, encoder, cross_kv_model, decoder, wav_path, language,
        )
        n_files += 1
        total_audio_s += audio_duration
        total_wall_s += wall_s
        print(f"{wav_path} ({audio_duration:.1f}s, RTF={fmt_rtf(wall_s, audio_duration)}): {text}")
    if n_files > 1:
        print(
            f"\nTotal: {total_audio_s:.1f}s audio in {total_wall_s:.1f}s wall "
            f"(RTF={fmt_rtf(total_wall_s, total_audio_s)})"
        )
def download_demo_samples(n_per_lang=2, max_scan=51, seed=None):
    """Download random English and Spanish samples via streaming.

    For each source dataset, streams at most ``max_scan`` examples and keeps a
    uniformly random subset of ``n_per_lang`` of them (reservoir sampling). Each kept
    sample is written to ``demo_audio/<lang>_<i>.wav``.

    Args:
        n_per_lang: number of samples to keep per language.
        max_scan: number of streamed examples considered per dataset. The default of
            51 (indices 0..50) bounds download time; ``None`` scans the whole split.
        seed: optional seed for reproducible selection. With ``None`` the global
            ``random`` module state is used, so ``random.seed()`` still applies.

    Returns:
        list of ``(wav_path, language_tag)`` tuples.
    """
    import itertools
    import random

    import soundfile as sf
    from datasets import load_dataset

    rng = random if seed is None else random.Random(seed)

    demo_dir = Path("demo_audio")
    demo_dir.mkdir(exist_ok=True)
    wav_paths = []
    sources = [
        ("hf-internal-testing/librispeech_asr_dummy", "clean", "validation", "en"),
        ("facebook/multilingual_librispeech", "spanish", "test", "es"),
    ]
    for dataset_id, config, split, lang_tag in sources:
        print(f"Downloading {n_per_lang} random {lang_tag} samples from {dataset_id}...")
        ds = load_dataset(dataset_id, config, split=split, streaming=True)
        # Reservoir sampling (Algorithm R): after idx + 1 examples, each one seen so
        # far is in the reservoir with equal probability.
        reservoir = []
        for idx, sample in enumerate(itertools.islice(ds, max_scan)):
            if idx < n_per_lang:
                reservoir.append(sample)
            else:
                j = rng.randint(0, idx)
                if j < n_per_lang:
                    reservoir[j] = sample
        for i, sample in enumerate(reservoir):
            audio = sample["audio"]
            wav_path = demo_dir / f"{lang_tag}_{i}.wav"
            sf.write(str(wav_path), audio["array"], audio["sampling_rate"])
            wav_paths.append((str(wav_path), lang_tag))
            print(f" {wav_path} ({len(audio['array']) / audio['sampling_rate']:.1f}s)")
    return wav_paths
def main():
    """Command-line entry point.

    With no arguments, downloads demo samples via download_demo_samples() and
    transcribes each one in its own language. Otherwise ``argv[1]`` is a .wav file
    or a directory of .wav files, and the optional ``argv[2]`` is the language code
    (default "en").

    The input path is validated before any model is loaded, so a bad path fails
    immediately.
    """
    if len(sys.argv) < 2:
        # No args: download demo samples and transcribe them.
        processor, encoder, cross_kv_model, decoder = load_models()
        for wav_path, language in download_demo_samples():
            text, dur, wall = transcribe_file(
                processor, encoder, cross_kv_model, decoder, wav_path, language,
            )
            rtf = f"{wall / dur:.2f}" if dur > 0 else "n/a"
            print(f"{wav_path} ({dur:.1f}s, RTF={rtf}): {text}")
        return

    target = sys.argv[1]
    language = sys.argv[2] if len(sys.argv) > 2 else "en"
    if os.path.isdir(target):
        # Escape the directory so names containing '[', '*' or '?' are taken literally.
        wav_paths = sorted(glob.glob(os.path.join(glob.escape(target), "*.wav")))
        if not wav_paths:
            print(f"No .wav files found in {target}/", file=sys.stderr)
            sys.exit(1)
    elif os.path.isfile(target):
        wav_paths = [target]
    else:
        print(f"No such file or directory: {target}", file=sys.stderr)
        sys.exit(1)

    processor, encoder, cross_kv_model, decoder = load_models()
    transcribe_files(processor, encoder, cross_kv_model, decoder, wav_paths, language)


if __name__ == "__main__":
    main()