from __future__ import annotations

import gc
import os
import re
import tempfile
import time
import unicodedata
from pathlib import Path
from typing import Any

import gradio as gr
import torch
import torchaudio
from faster_whisper import BatchedInferencePipeline, WhisperModel
from jiwer import cer, wer
from transformers import AutoModel

WHISPER_MODEL_ID = "Sh1man/whisper-large-v3-russian-ties-podlodka-v1.2-ct"
GIGAAM_MODEL_ID = "ai-sage/GigaAM-v3"
GIGAAM_REVISION = "e2e_rnnt"

TARGET_SAMPLE_RATE = 16_000

WHISPER_BEAM_SIZE = 5
WHISPER_BATCH_SIZE = 8 if torch.cuda.is_available() else 4
WHISPER_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
WHISPER_COMPUTE_TYPE = "float16" if torch.cuda.is_available() else "int8"

MODEL_LABELS = {
    "whisper": "Sh1man Whisper Large V3 CT",
    "gigaam": "GigaAM v3 e2e RNNT",
}

# Cache holding at most one loaded model; loading another model evicts it.
MODEL_STATE: dict[str, Any] = {"name": None, "instance": None}


def cleanup_loaded_model() -> None:
    """Drop the cached model and release its memory (including CUDA cache)."""
    loaded = MODEL_STATE["instance"]
    MODEL_STATE["name"] = None
    MODEL_STATE["instance"] = None
    if loaded is None:
        return
    # Remove the last strong reference so the weights can be garbage-collected.
    del loaded
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def get_model(model_name: str) -> Any:
    """Return the requested model, loading it (and unloading any other) on demand."""
    if MODEL_STATE["name"] == model_name and MODEL_STATE["instance"] is not None:
        return MODEL_STATE["instance"]

    cleanup_loaded_model()

    if model_name == "whisper":
        whisper_model = WhisperModel(
            WHISPER_MODEL_ID,
            device=WHISPER_DEVICE,
            compute_type=WHISPER_COMPUTE_TYPE,
        )
        model = BatchedInferencePipeline(model=whisper_model)
    elif model_name == "gigaam":
        model = AutoModel.from_pretrained(
            GIGAAM_MODEL_ID,
            revision=GIGAAM_REVISION,
            trust_remote_code=True,
        )
        if hasattr(model, "eval"):
            model.eval()
        if torch.cuda.is_available() and hasattr(model, "to"):
            model = model.to("cuda")
    else:
        raise ValueError(f"Unsupported model name: {model_name}")

    MODEL_STATE["name"] = model_name
    MODEL_STATE["instance"] = model
    return model


def collapse_spaces(text: str) -> str:
    return " ".join(text.split())


def normalize_for_metrics(text: str, enabled: bool) -> str:
    """Prepare text for WER/CER scoring.

    Always applies NFKC and collapses whitespace. When ``enabled`` is true it
    also lowercases, maps "ё" to "е", and turns punctuation and underscores
    into spaces.
    """
    text = unicodedata.normalize("NFKC", text.strip())
    if not enabled:
        return collapse_spaces(text)
    text = text.lower().replace("ё", "е")
    # `\w` also matches "_", so underscores survive the regex and are replaced separately.
    text = re.sub(r"[^\w\s]", " ", text, flags=re.UNICODE)
    text = text.replace("_", " ")
    return collapse_spaces(text)


def extract_text(result: Any) -> str:
    """Pull plain text out of the result shapes ASR pipelines commonly return."""
    if isinstance(result, str):
        return result
    if isinstance(result, dict):
        for key in ("text", "transcription", "prediction"):
            value = result.get(key)
            if isinstance(value, str):
                return value
        if "chunks" in result and isinstance(result["chunks"], list):
            return " ".join(
                extract_text(chunk) for chunk in result["chunks"] if chunk is not None
            ).strip()
    if isinstance(result, list):
        return " ".join(extract_text(item) for item in result if item is not None).strip()
    return str(result)


def prepare_audio_file(audio_path: str) -> tuple[tempfile.TemporaryDirectory, str, float]:
    """Downmix to mono, resample to 16 kHz, and write a temporary WAV copy.

    The caller owns the returned TemporaryDirectory and must call ``cleanup()``.
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sample_rate != TARGET_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sample_rate, TARGET_SAMPLE_RATE)
    duration_seconds = waveform.shape[1] / TARGET_SAMPLE_RATE

    temp_dir = tempfile.TemporaryDirectory()
    prepared_audio_path = Path(temp_dir.name) / "prepared_audio.wav"
    torchaudio.save(str(prepared_audio_path), waveform, TARGET_SAMPLE_RATE)
    return temp_dir, str(prepared_audio_path), duration_seconds


def transcribe_with_whisper(prepared_audio_path: str) -> tuple[str, str]:
    transcriber = get_model("whisper")
    # `segments` is lazy: decoding runs while it is consumed by the join below.
    segments, _ = transcriber.transcribe(
        prepared_audio_path,
        batch_size=WHISPER_BATCH_SIZE,
        beam_size=WHISPER_BEAM_SIZE,
        language="ru",
        word_timestamps=False,
    )
    transcription = collapse_spaces(" ".join(segment.text for segment in segments if segment.text))
    mode_note = (
        "Whisper ran via `faster-whisper` + `BatchedInferencePipeline` "
        f"with the default VAD, `beam_size={WHISPER_BEAM_SIZE}`, "
        f"`batch_size={WHISPER_BATCH_SIZE}`, `compute_type={WHISPER_COMPUTE_TYPE}`."
    )
    return transcription, mode_note


def format_boundary(boundary: Any) -> str:
    if not isinstance(boundary, (tuple, list)) or len(boundary) != 2:
        return ""
    start, end = boundary
    return f"[{start:.2f}-{end:.2f}]"


def extract_longform_text(result: Any) -> str:
    """Join the per-segment transcriptions returned by `transcribe_longform`."""
    if not isinstance(result, list):
        return collapse_spaces(extract_text(result))
    parts: list[str] = []
    for segment in result:
        segment_text = extract_text(segment)
        if segment_text:
            parts.append(collapse_spaces(segment_text))
    return collapse_spaces(" ".join(parts))


def transcribe_with_gigaam(audio_path: str) -> tuple[str, int]:
    """Transcribe with GigaAM longform mode; returns (text, number of VAD segments)."""
    if not os.getenv("HF_TOKEN"):
        raise ValueError(
            "GigaAM longform mode needs an HF_TOKEN secret with access to "
            "'pyannote/segmentation-3.0'. Add it under Settings -> Variables and secrets."
        )
    transcriber = get_model("gigaam")
    with torch.inference_mode():
        result = transcriber.transcribe_longform(audio_path)
    segment_count = len(result) if isinstance(result, list) else 0
    return extract_longform_text(result), segment_count


def load_reference_text(reference_text: str, reference_file: str | None) -> str:
    """Prefer the pasted reference; otherwise read the uploaded .txt file."""
    if (reference_text or "").strip():
        return reference_text.strip()
    if reference_file:
        # "utf-8-sig" decodes plain UTF-8 and also strips a leading BOM, which
        # would otherwise stick to the first word; cp1251 covers legacy
        # Windows-encoded Russian files.
        for encoding in ("utf-8-sig", "cp1251"):
            try:
                return Path(reference_file).read_text(encoding=encoding).strip()
            except UnicodeDecodeError:
                continue
        raise ValueError("Could not read the reference text file.")
    return ""


def format_metric(value: float | None) -> str:
    if value is None:
        return "n/a"
    return f"{value:.4f}"


def benchmark_audio(
    audio_path: str | None,
    reference_text: str,
    reference_file: str | None,
    selected_models: list[str],
    normalize_metrics: bool,
) -> tuple[list[list[Any]], str, str, str]:
    if not audio_path:
        raise gr.Error("Upload an audio file to transcribe.")
    if not selected_models:
        raise gr.Error("Select at least one model.")

    temporary_dir: tempfile.TemporaryDirectory | None = None
    try:
        # Read the reference inside `try` so decoding errors surface as a gr.Error.
        reference = load_reference_text(reference_text, reference_file)
        normalized_reference = normalize_for_metrics(reference, normalize_metrics) if reference else ""
        temporary_dir, prepared_audio_path, duration_seconds = prepare_audio_file(audio_path)

        whisper_text = "Model was not run."
        gigaam_text = "Model was not run."
        rows: list[list[Any]] = []
        whisper_mode_note: str | None = None
        gigaam_segment_count: int | None = None

        for model_name in selected_models:
            started_at = time.perf_counter()
            if model_name == "whisper":
                transcription, whisper_mode_note = transcribe_with_whisper(prepared_audio_path)
                whisper_text = transcription or "Empty result."
            elif model_name == "gigaam":
                transcription, gigaam_segment_count = transcribe_with_gigaam(prepared_audio_path)
                gigaam_text = transcription or "Empty result."
            else:
                continue
            # Wall-clock time: includes model loading (and the weight download on
            # the first run), because `get_model` is called inside the timed block.
            elapsed = time.perf_counter() - started_at

            current_wer: float | None = None
            current_cer: float | None = None
            if normalized_reference:
                normalized_prediction = normalize_for_metrics(transcription, normalize_metrics)
                current_wer = wer(normalized_reference, normalized_prediction)
                current_cer = cer(normalized_reference, normalized_prediction)

            rows.append(
                [
                    MODEL_LABELS[model_name],
                    format_metric(current_wer),
                    format_metric(current_cer),
                    round(elapsed, 2),
                ]
            )

        summary_lines = [f"- Audio duration: `{duration_seconds:.1f}` s."]
        if whisper_mode_note is not None:
            summary_lines.append(f"- {whisper_mode_note}")
        if gigaam_segment_count is not None:
            summary_lines.append(
                "- GigaAM used the built-in `transcribe_longform` and produced "
                f"`{gigaam_segment_count}` VAD segments."
            )
        if normalized_reference:
            normalization_note = "with normalization" if normalize_metrics else "without normalization"
            summary_lines.append(f"- `WER` and `CER` were computed {normalization_note}.")
        else:
            summary_lines.append("- No usable reference text; metrics were skipped.")

        return rows, whisper_text, gigaam_text, "\n".join(summary_lines)
    except Exception as error:
        raise gr.Error(f"Processing error: {error}") from error
    finally:
        if temporary_dir is not None:
            temporary_dir.cleanup()


with gr.Blocks(title="Russian ASR Benchmark Space") as demo:
    gr.Markdown(
        """
# Russian ASR Benchmark

Compares two ASR models:

- `Sh1man/whisper-large-v3-russian-ties-podlodka-v1.2-ct`
- `ai-sage/GigaAM-v3` with revision `e2e_rnnt`

Upload audio, paste the reference transcript or attach a `.txt` file, and the Space computes `WER` / `CER` for each model.
`GigaAM` runs in its built-in `transcribe_longform` mode, which requires an `HF_TOKEN` Space secret with access to `pyannote/segmentation-3.0`.
"""
    )

    with gr.Row():
        audio_input = gr.Audio(
            label="Audio file",
            type="filepath",
            sources=["upload", "microphone"],
        )
        with gr.Column():
            reference_input = gr.Textbox(
                label="Reference text",
                placeholder="Paste the correct transcript here",
                lines=10,
            )
            reference_file_input = gr.File(
                label="Or upload the reference text (.txt)",
                file_types=[".txt"],
                type="filepath",
            )

    with gr.Row():
        model_selector = gr.CheckboxGroup(
            label="Models to run",
            choices=[(label, key) for key, label in MODEL_LABELS.items()],
            value=list(MODEL_LABELS),
        )
        normalize_checkbox = gr.Checkbox(
            label="Normalize text before computing metrics",
            value=True,
            info="Lowercases, replaces ё with е, strips punctuation, and collapses whitespace.",
        )

    run_button = gr.Button("Transcribe and compute metrics", variant="primary")

    results_table = gr.Dataframe(
        headers=["Model", "WER", "CER", "Time (s)"],
        datatype=["str", "str", "str", "number"],
        label="Comparison results",
    )
    status_output = gr.Markdown("Status will appear after a run.")

    with gr.Row():
        whisper_output = gr.Textbox(
            label=f"Transcript: {MODEL_LABELS['whisper']}",
            lines=12,
        )
        gigaam_output = gr.Textbox(
            label=f"Transcript: {MODEL_LABELS['gigaam']}",
            lines=12,
        )

    run_button.click(
        fn=benchmark_audio,
        inputs=[
            audio_input,
            reference_input,
            reference_file_input,
            model_selector,
            normalize_checkbox,
        ],
        outputs=[
            results_table,
            whisper_output,
            gigaam_output,
            status_output,
        ],
    )

    gr.Markdown(
        f"""
The first inference run can take noticeably longer because the weights have to be downloaded.
`Whisper` runs as `faster-whisper` on CTranslate2 via `BatchedInferencePipeline` with the default VAD and `beam_size={WHISPER_BEAM_SIZE}`.
`GigaAM` uses its built-in longform mode (`transcribe_longform`) with VAD from `pyannote/segmentation-3.0`.
"""
    )


if __name__ == "__main__":
    # Process one request at a time: MODEL_STATE caches a single model, so
    # concurrent runs would keep evicting each other's model.
    demo.queue(default_concurrency_limit=1).launch()
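

# --- Illustrative cross-check (not called by the app) ---
# The WER column above comes from `jiwer.wer`. As a minimal sketch of what that
# metric measures, assuming whitespace-tokenized input that has already passed
# through `normalize_for_metrics`: WER = (S + D + I) / N, where S, D and I are
# the word substitutions, deletions and insertions of the cheapest alignment
# and N is the number of reference words. `word_error_rate_sketch` is a
# hypothetical helper added for illustration only; it is not part of jiwer or
# of the original app, though for a non-empty reference it should agree with
# `jiwer.wer`.
def word_error_rate_sketch(reference: str, hypothesis: str) -> float:
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    if not ref_words:
        raise ValueError("Reference must contain at least one word.")
    # At the start of iteration i, previous_row[j] is the edit distance between
    # the first i-1 reference words and the first j hypothesis words.
    previous_row = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        current_row = [i]  # i deletions against an empty hypothesis prefix
        for j, hyp_word in enumerate(hyp_words, start=1):
            substitution = previous_row[j - 1] + (ref_word != hyp_word)
            deletion = previous_row[j] + 1
            insertion = current_row[j - 1] + 1
            current_row.append(min(substitution, deletion, insertion))
        previous_row = current_row
    return previous_row[-1] / len(ref_words)


# Usage: one substitution (b -> x) plus one deletion (d) over 4 reference words:
# word_error_rate_sketch("a b c d", "a x c") returns 0.5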