#!/usr/bin/env python3
"""
MERaLiON-3-10B + TurboQuant KV Cache Compression

Demonstrates inference with MERaLiON-3-10B-preview using TurboQuant
for KV cache compression at various bit widths (2, 3, 4).

Architecture:
  - Speech encoder: Whisper-Large-V3 with weighted layer sum
  - Adaptor: MLP (scale_factor=5, SiLU, out_proj)
  - Text decoder: Gemma2-9B (42 layers, GQA 16/8 heads, head_dim=256)
  - TurboQuant compresses the Gemma2 decoder's KV cache

Requirements:
  pip install torch transformers turboquant soundfile librosa

Usage:
  python inference.py --audio test_speech.wav
  python inference.py --audio test_speech.wav --bits 4
  python inference.py --audio test_speech.wav --benchmark
"""

import argparse
import json
import time
import os
import sys

import torch
import numpy as np

# numpy 2.x removed np.trapz; turboquant 0.2.0 still uses it
if not hasattr(np, "trapz"):
    np.trapz = np.trapezoid

MODEL_ID = "MERaLiON/MERaLiON-3-10B-preview"

# (results key, banner title, TurboQuant bit width or None for no compression).
# The order is the order in which benchmark() runs the configurations.
BENCHMARK_CONFIGS = (
    ("baseline", "BASELINE (BF16 KV cache)", None),
    ("turboquant_4bit", "TURBOQUANT 4-BIT KV CACHE", 4),
    ("turboquant_2bit", "TURBOQUANT 2-BIT KV CACHE", 2),
)


def get_device():
    """Select the best available device: "cuda", then "mps", else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    # Guard the attribute lookup: torch builds without the MPS backend module
    # would otherwise raise AttributeError instead of falling back to CPU.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


def resolve_device(requested=None):
    """Turn the ``--device`` argument into a concrete device string.

    ``None``, an empty string and the literal ``"auto"`` all trigger
    auto-detection via :func:`get_device`; any other value is returned
    unchanged so explicit choices (e.g. ``"cuda:1"``) are passed through.
    """
    if not requested or requested == "auto":
        return get_device()
    return requested


def load_audio(path, sr=16000):
    """Load an audio file as mono float32, resampled to ``sr`` Hz.

    Multi-channel input is down-mixed by averaging over axis 1. Resampling
    is skipped when the file is already at the target rate.
    """
    import soundfile as sf
    import librosa

    audio, orig_sr = sf.read(path)
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    if orig_sr != sr:
        audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=sr)
    return audio.astype(np.float32)


def load_model_and_processor(device="mps", dtype=torch.bfloat16):
    """Load MERaLiON-3-10B-preview with trust_remote_code.

    Args:
        device: Device the model is moved to after loading; "cpu" skips the
            explicit move.
        dtype: Dtype passed to ``from_pretrained``.

    Returns:
        A ``(model, processor)`` tuple.
    """
    from transformers import AutoProcessor, AutoConfig, AutoModelForSpeechSeq2Seq

    print(f"Loading processor from {MODEL_ID}...")
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

    # Patch config: MERaLiON3Config is missing pad_token_id, which
    # causes an AttributeError in transformers >= 5.x
    config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
    config.pad_token_id = -1

    print(f"Loading model from {MODEL_ID} (dtype={dtype})...")
    t0 = time.time()
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_ID,
        config=config,
        dtype=dtype,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )

    if device != "cpu":
        t1 = time.time()
        print(f" Downloaded/loaded weights in {t1 - t0:.1f}s, moving to {device}...")
        model = model.to(device)
        print(f" Moved to {device} in {time.time() - t1:.1f}s")

    print(f"Model loaded in {time.time() - t0:.1f}s total")
    print(f" Parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B")
    return model, processor


def build_chat_text(processor, task="asr"):
    """Build the text prompt containing the speech placeholder.

    The placeholder is read from ``processor.speech_token``; the processor's
    _process_text() is expected to expand that single placeholder into
    fixed_speech_embeds_length * num_chunks copies.

    Unknown ``task`` values fall back to the "asr" instruction.
    """
    task_prompts = {
        "asr": "Transcribe the following speech.",
        "translate_zh": "Translate the following speech to Chinese.",
        "translate_ms": "Translate the following speech to Malay.",
        "translate_ta": "Translate the following speech to Tamil.",
    }
    instruction = task_prompts.get(task, task_prompts["asr"])

    # MERaLiON3 uses the Gemma2 chat format around the speech placeholder.
    # NOTE(review): the literal below carries no explicit Gemma turn markers
    # (e.g. "<start_of_turn>" / "<end_of_turn>"). Confirm against the model
    # card whether they are expected here or added by the processor.
    speech_token = processor.speech_token
    text = (
        f"user\n"
        f"{speech_token}\n"
        f"{instruction}\n"
        f"model\n"
    )
    return text


def run_inference(model, processor, audio, task="asr", past_key_values=None,
                  max_new_tokens=256):
    """Run speech-to-text inference, optionally with a custom KV cache.

    Args:
        model: Loaded model exposing ``parameters()`` and ``generate()``.
        processor: Matching processor (callable, with ``batch_decode``).
        audio: Waveform passed to the processor at 16 kHz.
        task: Key selecting the instruction in :func:`build_chat_text`.
        past_key_values: Optional cache object forwarded to ``generate``;
            omitted from the call when ``None``.
        max_new_tokens: Generation length limit.

    Returns:
        ``(text_out, elapsed, num_tokens)``: decoded text of the first
        sequence, wall-clock seconds spent in ``generate``, and the number
        of generated tokens.
    """
    device = next(model.parameters()).device

    # Build text prompt
    text = build_chat_text(processor, task=task)

    # Use processor to prepare inputs
    inputs = processor(
        text=text,
        audios=audio,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True,
    )

    # Cast float tensors to model dtype (bfloat16) to match model weights;
    # integer tensors (token ids, masks) are only moved to the device.
    model_dtype = next(model.parameters()).dtype
    for k, v in inputs.items():
        if hasattr(v, 'to'):
            if v.is_floating_point():
                inputs[k] = v.to(device=device, dtype=model_dtype)
            else:
                inputs[k] = v.to(device)

    generate_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )
    if past_key_values is not None:
        generate_kwargs["past_key_values"] = past_key_values

    # Single timed generation. No separate warm-up pass is performed, so the
    # first call in a process also pays any one-time start-up cost.
    t0 = time.time()
    with torch.no_grad():
        output_ids = model.generate(**inputs, **generate_kwargs)
    elapsed = time.time() - t0

    # Decode output
    text_out = processor.batch_decode(output_ids, skip_special_tokens=True)[0]

    # Count generated tokens. If the output is not longer than the prompt
    # (generate returned only the new tokens), count the whole output.
    input_len = inputs.get("input_ids", output_ids).shape[-1]
    num_tokens = output_ids.shape[-1] - input_len
    if num_tokens <= 0:
        num_tokens = output_ids.shape[-1]

    return text_out, elapsed, num_tokens


def tokens_per_second(num_tokens, elapsed):
    """Return generation throughput, or 0 when no time elapsed."""
    return num_tokens / elapsed if elapsed > 0 else 0


def benchmark(model, processor, audio, task="asr", max_new_tokens=256):
    """Run baseline + TurboQuant 4-bit + TurboQuant 2-bit, report results.

    Configurations run in the order given by ``BENCHMARK_CONFIGS``. The
    baseline runs first with no warm-up, so its timing includes any one-time
    start-up cost and the reported speedups should be read with that in mind.

    Returns:
        Dict mapping configuration name to a dict with keys "text", "time",
        "tokens" and "tps".
    """
    # Imported up front so a missing turboquant fails before any run starts.
    from turboquant import TurboQuantCache

    results = {}
    for name, title, bits in BENCHMARK_CONFIGS:
        print("\n" + "=" * 60)
        print(title)
        print("=" * 60)
        # A fresh cache per run; None selects the model's default KV cache.
        cache = TurboQuantCache(bits=bits) if bits is not None else None
        text, elapsed, ntok = run_inference(
            model, processor, audio, task=task,
            past_key_values=cache, max_new_tokens=max_new_tokens
        )
        tps = tokens_per_second(ntok, elapsed)
        print(f" Output: {text}")
        print(f" Tokens: {ntok}, Time: {elapsed:.2f}s, Speed: {tps:.1f} tok/s")
        results[name] = {"text": text, "time": elapsed, "tokens": ntok, "tps": tps}

    # --- Summary ---
    print("\n" + "=" * 60)
    print("BENCHMARK SUMMARY")
    print("=" * 60)
    print(f"{'Config':<22} {'Tokens':>6} {'Time':>8} {'Tok/s':>8} {'Speedup':>8}")
    print("-" * 60)
    base_time = results["baseline"]["time"]
    for name, r in results.items():
        speedup = base_time / r["time"] if r["time"] > 0 else 0
        print(f"{name:<22} {r['tokens']:>6} {r['time']:>7.2f}s {r['tps']:>7.1f} {speedup:>7.2f}x")

    return results


def build_arg_parser():
    """Create the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(
        description="MERaLiON-3-10B inference with TurboQuant KV cache compression"
    )
    parser.add_argument("--audio", type=str, required=True,
                        help="Path to audio file")
    parser.add_argument("--task", type=str, default="asr",
                        choices=["asr", "translate_zh", "translate_ms", "translate_ta"],
                        help="Task to perform (default: asr)")
    parser.add_argument("--bits", type=int, default=None,
                        help="TurboQuant bits (2, 3, or 4). Omit for no compression.")
    parser.add_argument("--benchmark", action="store_true",
                        help="Run full benchmark (baseline + 4-bit + 2-bit)")
    parser.add_argument("--device", type=str, default=None,
                        help="Device: auto, cpu, cuda, mps (default: auto-detect)")
    parser.add_argument("--max-tokens", type=int, default=256,
                        help="Maximum new tokens to generate (default: 256)")
    return parser


def main():
    """CLI entry point: load the model, then run one inference or the benchmark."""
    args = build_arg_parser().parse_args()

    # "--device auto" must resolve to a real device; passing the literal
    # string "auto" on to model.to() would fail.
    device = resolve_device(args.device)
    print(f"Device: {device}")
    print(f"Model: {MODEL_ID}")

    model, processor = load_model_and_processor(device=device)
    audio = load_audio(args.audio)
    print(f"Audio: {args.audio} ({len(audio)/16000:.1f}s)")

    if args.benchmark:
        results = benchmark(model, processor, audio, task=args.task,
                            max_new_tokens=args.max_tokens)
        out_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                "benchmark_results.json")
        # Transcripts/translations may be non-ASCII; write UTF-8 explicitly
        # rather than relying on the platform default encoding.
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"\nResults saved to {out_path}")
    else:
        cache = None
        label = "no compression"
        if args.bits is not None:
            from turboquant import TurboQuantCache
            cache = TurboQuantCache(bits=args.bits)
            label = f"TurboQuant {args.bits}-bit"
        print(f"Using {label} KV cache compression")

        text, elapsed, ntok = run_inference(
            model, processor, audio, task=args.task,
            past_key_values=cache, max_new_tokens=args.max_tokens
        )
        tps = tokens_per_second(ntok, elapsed)
        print(f"\nResult ({label}): {text}")
        print(f"Tokens: {ntok}, Time: {elapsed:.2f}s, Speed: {tps:.1f} tok/s")


if __name__ == "__main__":
    main()