"""Gradio app that transcribes recorded or uploaded audio with Cohere's ASR API.

Each transcription is a single multipart POST to ``TRANSCRIPTIONS_API_URL``.
Configuration comes from environment variables:

* ``COHERE_API_KEY`` (required): bearer token sent with every request.
* ``COHERE_ASR_MODEL``: model id, defaults to ``DEFAULT_MODEL_ID``.
* ``COHERE_API_TIMEOUT_SECONDS``: request timeout, floored at
  ``MIN_TIMEOUT_SECONDS``.
* ``COHERE_ASR_DEBUG_RAW_RESPONSE``: ``1``/``true``/``yes``/``on`` shows the
  raw JSON response and the example status table.
"""

from __future__ import annotations

import json
import math
import mimetypes
import os
from pathlib import Path
from typing import Any

import gradio as gr
import requests

APP_ROOT = Path(__file__).parent
EXAMPLES_MANIFEST_PATH = APP_ROOT / "examples" / "manifest.json"
# Resolved against the app directory (like the manifest) so launching from a
# different working directory still finds the stylesheet.
STYLESHEET_PATH = APP_ROOT / "style.css"

TRANSCRIPTIONS_API_URL = "https://api.cohere.ai/compatibility/v1/audio/transcriptions"
DEFAULT_MODEL_ID = "cohere-transcribe-03-2026"
DEFAULT_TIMEOUT_SECONDS = 180.0
# Lower bound applied to a configured COHERE_API_TIMEOUT_SECONDS value.
MIN_TIMEOUT_SECONDS = 30.0
DEFAULT_LANGUAGE = "en"
# Files larger than this are rejected locally before any request is sent.
MAX_AUDIO_FILE_BYTES = 25 * 1024 * 1024
MAX_AUDIO_FILE_LABEL = "25 MB"

# (display label, language code) pairs offered in the language dropdown.
LANGUAGE_OPTIONS = [
    ("English", "en"),
    ("French", "fr"),
    ("German", "de"),
    ("Spanish", "es"),
    ("Portuguese", "pt"),
    ("Italian", "it"),
    ("Dutch", "nl"),
    ("Polish", "pl"),
    ("Greek", "el"),
    ("Arabic", "ar"),
    ("Japanese", "ja"),
    ("Korean", "ko"),
    ("Chinese", "zh"),
    ("Vietnamese", "vi"),
]
# Reverse lookup: language code -> display label.
LANGUAGE_LABELS = {code: label for label, code in LANGUAGE_OPTIONS}

# Warm light theme. Every *_dark value repeats its light counterpart, so the
# app looks the same when the browser is in dark mode.
APP_THEME = gr.themes.Soft(
    primary_hue="orange",
    secondary_hue="amber",
    neutral_hue="stone",
    font=[gr.themes.GoogleFont("Manrope"), "ui-sans-serif", "system-ui", "sans-serif"],
).set(
    body_background_fill="#fff7ed",
    body_background_fill_dark="#fff7ed",
    body_text_color="#6b4f3f",
    body_text_color_dark="#6b4f3f",
    body_text_color_subdued="#7c6558",
    body_text_color_subdued_dark="#7c6558",
    block_background_fill="#fffaf0",
    block_background_fill_dark="#fffaf0",
    block_border_color="rgba(194, 65, 12, 0.14)",
    block_border_color_dark="rgba(194, 65, 12, 0.14)",
    block_label_text_color="#7c2d12",
    block_label_text_color_dark="#7c2d12",
    input_background_fill="#fffdf7",
    input_background_fill_dark="#fffdf7",
    input_border_color="rgba(194, 65, 12, 0.18)",
    input_border_color_dark="rgba(194, 65, 12, 0.18)",
    button_primary_background_fill="#ea580c",
    button_primary_background_fill_dark="#ea580c",
    button_primary_background_fill_hover="#f97316",
    button_primary_background_fill_hover_dark="#f97316",
    button_primary_text_color="#fff7ed",
    button_primary_text_color_dark="#fff7ed",
    button_secondary_background_fill="#fffaf0",
    button_secondary_background_fill_dark="#fffaf0",
    button_secondary_background_fill_hover="#fff1e6",
    button_secondary_background_fill_hover_dark="#fff1e6",
    button_secondary_text_color="#7c2d12",
    button_secondary_text_color_dark="#7c2d12",
    button_secondary_border_color="rgba(194, 65, 12, 0.14)",
    button_secondary_border_color_dark="rgba(194, 65, 12, 0.14)",
    link_text_color="#c2410c",
    link_text_color_dark="#c2410c",
)


class TranscriptionError(Exception):
    """A configuration, validation, or API failure whose message is shown to the user."""


def get_api_key() -> str | None:
    """Return the stripped ``COHERE_API_KEY``, or ``None`` when unset or blank."""
    api_key = os.getenv("COHERE_API_KEY", "").strip()
    return api_key or None


def get_model_id() -> str:
    """Return ``COHERE_ASR_MODEL``, falling back to ``DEFAULT_MODEL_ID`` when unset or blank."""
    model_id = os.getenv("COHERE_ASR_MODEL", DEFAULT_MODEL_ID).strip()
    return model_id or DEFAULT_MODEL_ID


def get_timeout_seconds() -> float:
    """Return the HTTP timeout from ``COHERE_API_TIMEOUT_SECONDS``.

    Values below ``MIN_TIMEOUT_SECONDS`` are raised to that floor. Unparsable or
    non-finite values (``nan``, ``inf``) fall back to ``DEFAULT_TIMEOUT_SECONDS``.
    """
    raw_value = os.getenv("COHERE_API_TIMEOUT_SECONDS", str(DEFAULT_TIMEOUT_SECONDS))
    try:
        timeout = float(raw_value)
    except ValueError:
        return DEFAULT_TIMEOUT_SECONDS
    # max() with nan silently yields the floor, and inf would be handed to the
    # socket layer; treat both as misconfiguration.
    if not math.isfinite(timeout):
        return DEFAULT_TIMEOUT_SECONDS
    return max(MIN_TIMEOUT_SECONDS, timeout)


def get_debug_raw_response_enabled() -> bool:
    """Return True when ``COHERE_ASR_DEBUG_RAW_RESPONSE`` is 1/true/yes/on (case-insensitive)."""
    value = os.getenv("COHERE_ASR_DEBUG_RAW_RESPONSE", "").strip().lower()
    return value in {"1", "true", "yes", "on"}


def display_language(language_code: str | None) -> str:
    """Return a human-readable label for a language code.

    ``"fr"`` becomes ``"French"`` and ``"fr-CA"`` becomes ``"French (fr-CA)"``.
    Unknown codes are returned as given, after stripping. ``None`` or blank
    input yields ``"Not returned"``.
    """
    if language_code is None:
        return "Not returned"
    # str() guards against non-string values coming from the manifest JSON.
    normalized = str(language_code).strip()
    if not normalized:
        return "Not returned"
    base_code = normalized.split("-", 1)[0].lower()
    label = LANGUAGE_LABELS.get(base_code)
    if not label:
        return normalized
    if base_code == normalized.lower():
        return label
    return f"{label} ({normalized})"


def extract_transcript_text(payload: Any) -> str:
    """Best-effort extraction of transcript text from an API payload.

    The payload is checked in this order:

    * a non-empty string ``text`` or ``transcript`` key;
    * the joined text of a ``segments`` list;
    * recursion into ``result``, ``data``, or ``output``.

    Lists are joined with single spaces. Returns ``""`` when nothing is found.
    """
    if isinstance(payload, str):
        return payload.strip()

    if isinstance(payload, dict):
        for key in ("text", "transcript"):
            value = payload.get(key)
            if isinstance(value, str) and value.strip():
                return value.strip()

        segments = payload.get("segments")
        if isinstance(segments, list):
            parts = [extract_transcript_text(segment) for segment in segments]
            joined = " ".join(part for part in parts if part)
            if joined:
                return joined.strip()

        for key in ("result", "data", "output"):
            text = extract_transcript_text(payload.get(key))
            if text:
                return text

    if isinstance(payload, list):
        parts = [extract_transcript_text(item) for item in payload]
        return " ".join(part for part in parts if part).strip()

    return ""


def format_api_error(response: requests.Response) -> str:
    """Build a user-facing message for a non-2xx transcription response.

    A message is taken from ``error.message`` / ``error.type``, a string
    ``error``, or a top-level ``message`` / ``detail`` / ``error_description``,
    in that order. A non-JSON body is used verbatim.
    """
    message = ""
    try:
        payload = response.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        payload = response.text.strip()

    if isinstance(payload, dict):
        error = payload.get("error")
        if isinstance(error, dict):
            message = str(error.get("message") or error.get("type") or "").strip()
        elif isinstance(error, str):
            message = error.strip()

        if not message:
            for key in ("message", "detail", "error_description"):
                value = payload.get(key)
                if isinstance(value, str) and value.strip():
                    message = value.strip()
                    break
    elif isinstance(payload, str):
        message = payload

    if not message:
        message = "Unexpected API response."
    return f"Transcription request failed ({response.status_code}): {message}"


def validate_audio_file(audio_path: str) -> Path:
    """Return ``audio_path`` as a Path after checking it is readable and within the size limit.

    Raises:
        TranscriptionError: if the file cannot be stat'ed or exceeds
            ``MAX_AUDIO_FILE_BYTES``.
    """
    audio_file_path = Path(audio_path)
    try:
        file_size = audio_file_path.stat().st_size
    except OSError as exc:
        raise TranscriptionError("The selected audio file could not be read.") from exc

    if file_size > MAX_AUDIO_FILE_BYTES:
        raise TranscriptionError(
            f"Audio files must be {MAX_AUDIO_FILE_LABEL} or smaller. "
            "Please upload a shorter clip or a more compressed file."
        )
    return audio_file_path


def call_transcriptions_api(audio_path: str, language: str) -> dict[str, Any]:
    """Upload ``audio_path`` to the transcription endpoint and return the JSON body.

    Raises:
        TranscriptionError: missing API key, unreadable or oversized file,
            non-2xx response, or a response that is not a JSON object.
        requests.RequestException: network-level failures, including timeouts.
    """
    api_key = get_api_key()
    if not api_key:
        raise TranscriptionError(
            "COHERE_API_KEY is not configured. Add it locally or in Hugging Face Space secrets."
        )

    audio_file_path = validate_audio_file(audio_path)
    mime_type = mimetypes.guess_type(audio_file_path.name)[0] or "application/octet-stream"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Accept": "application/json",
    }

    # Open outside the `with` so only file errors are mapped here. Catching
    # OSError around requests.post would also swallow RequestException, which
    # subclasses OSError.
    try:
        audio_file = audio_file_path.open("rb")
    except OSError as exc:
        raise TranscriptionError("The selected audio file could not be read.") from exc

    with audio_file:
        # (None, value) tuples make requests send plain form fields in the
        # same multipart body as the file part.
        multipart_fields = [
            ("model", (None, get_model_id())),
            ("language", (None, language)),
            ("file", (audio_file_path.name, audio_file, mime_type)),
        ]
        response = requests.post(
            TRANSCRIPTIONS_API_URL,
            headers=headers,
            files=multipart_fields,
            timeout=get_timeout_seconds(),
        )

    if not response.ok:
        raise TranscriptionError(format_api_error(response))

    try:
        payload = response.json()
    except ValueError as exc:
        raise TranscriptionError("The transcription API returned a non-JSON response.") from exc

    if not isinstance(payload, dict):
        raise TranscriptionError("The transcription API returned an unexpected payload shape.")
    return payload


def load_example_manifest() -> list[dict[str, Any]]:
    """Return the dict entries under ``examples`` in ``examples/manifest.json``.

    A missing, unreadable, non-UTF-8, or malformed manifest yields ``[]``
    instead of failing at import time. So does a manifest whose top level is
    not an object.
    """
    try:
        payload = json.loads(EXAMPLES_MANIFEST_PATH.read_text(encoding="utf-8"))
    except (OSError, ValueError):  # ValueError covers JSONDecodeError and UnicodeDecodeError
        return []

    if not isinstance(payload, dict):
        return []
    raw_examples = payload.get("examples", [])
    if not isinstance(raw_examples, list):
        return []
    return [item for item in raw_examples if isinstance(item, dict)]


def resolve_example_audio_path(raw_path: str | None) -> Path | None:
    """Resolve a manifest audio path (relative paths are taken from ``APP_ROOT``).

    Returns the path only if it names an existing regular file. Otherwise,
    including for non-string values, returns ``None``.
    """
    if not isinstance(raw_path, str) or not raw_path:
        return None
    candidate = Path(raw_path)
    if not candidate.is_absolute():
        candidate = APP_ROOT / candidate
    return candidate if candidate.is_file() else None


def _ready_example_audio(example: dict[str, Any]) -> Path | None:
    """Return the example's audio path if it is enabled and present on disk, else None."""
    if not example.get("enabled"):
        return None
    return resolve_example_audio_path(example.get("audio_path"))


def build_example_table_rows(examples: list[dict[str, Any]]) -> list[list[str]]:
    """Return [label, language, description, status] rows for the debug example table."""
    rows: list[list[str]] = []
    for example in examples:
        is_ready = _ready_example_audio(example) is not None
        status = "Ready" if is_ready else "Waiting for approved audio"
        rows.append(
            [
                str(example.get("label") or "Untitled example"),
                display_language(example.get("language")),
                str(example.get("description") or ""),
                status,
            ]
        )
    return rows


def build_gradio_examples(examples: list[dict[str, Any]]) -> tuple[list[list[str]], list[str]]:
    """Return ``(inputs, labels)`` for ``gr.Examples``, using only ready examples.

    Each input row is ``[audio path, language code]``, matching the
    ``[audio_input, language_input]`` components.
    """
    example_inputs: list[list[str]] = []
    example_labels: list[str] = []
    for example in examples:
        audio_path = _ready_example_audio(example)
        if audio_path is None:
            continue
        example_inputs.append(
            [
                str(audio_path),
                str(example.get("language") or DEFAULT_LANGUAGE),
            ]
        )
        example_labels.append(str(example.get("label") or audio_path.stem))
    return example_inputs, example_labels


def build_transcription_outputs(audio_path: str | None, language: str) -> tuple[str, str]:
    """Transcribe ``audio_path`` and return ``(transcript, pretty-printed JSON payload)``.

    Every failure is re-raised as ``gr.Error`` so Gradio shows it to the user.
    """
    if not audio_path:
        raise gr.Error("Record or upload an audio clip before starting transcription.")

    try:
        payload = call_transcriptions_api(audio_path, language)
    except TranscriptionError as exc:
        raise gr.Error(str(exc)) from exc
    except requests.Timeout as exc:
        raise gr.Error(
            "The transcription request timed out. Try a shorter clip or increase the timeout."
        ) from exc
    except requests.RequestException as exc:
        raise gr.Error(f"Unable to reach the transcription API: {exc}") from exc

    transcript = extract_transcript_text(payload)
    if not transcript:
        raise gr.Error("The API returned a response, but no transcript text was found.")

    raw_json = json.dumps(payload, indent=2, ensure_ascii=False)
    return transcript, raw_json


def transcribe_audio(audio_path: str | None, language: str) -> str:
    """Click handler for normal mode: return only the transcript."""
    transcript, _raw_json = build_transcription_outputs(audio_path, language)
    return transcript


def transcribe_audio_debug(audio_path: str | None, language: str) -> tuple[str, str]:
    """Click handler for debug mode: return the transcript and the raw JSON payload."""
    return build_transcription_outputs(audio_path, language)


def reset_ui() -> tuple[None, str, str, dict[str, Any]]:
    """Clear the audio, restore the default language, clear the transcript, and disable Transcribe."""
    return (None, DEFAULT_LANGUAGE, "", gr.update(interactive=False))


def reset_ui_debug() -> tuple[None, str, str, str, dict[str, Any]]:
    """Debug-mode reset: same as ``reset_ui`` but also clears the raw JSON output."""
    return (None, DEFAULT_LANGUAGE, "", "", gr.update(interactive=False))


def update_transcribe_button(audio_path: str | None) -> dict[str, Any]:
    """Enable Transcribe only when an API key is configured and audio is present.

    ``api_ready`` is a module-level flag assigned below, before the UI is built.
    """
    return gr.update(interactive=api_ready and bool(audio_path))


example_manifest = load_example_manifest()
example_rows = build_example_table_rows(example_manifest)
example_inputs, example_labels = build_gradio_examples(example_manifest)
api_ready = get_api_key() is not None
debug_raw_response_enabled = get_debug_raw_response_enabled()

hero_markdown = """

Cohere Multilingual ASR 🎤

Upload audio or record directly in the browser, and inspect the transcript. By using this Space, you agree to the Cohere Privacy Policy.

"""
examples_copy = "Try one of the sample clips."
examples_source_copy = "Samples are sourced from `google/fleurs` validation audio."

with gr.Blocks(
    title="Cohere Multilingual ASR",
) as demo:
    with gr.Column(elem_classes="app-shell"):
        gr.Markdown(hero_markdown, sanitize_html=False)
        if not api_ready:
            gr.Markdown(
                (
                    "\nConfiguration required. "
                    "Set the COHERE_API_KEY environment variable to enable transcription."
                    "\n"
                ),
                sanitize_html=False,
            )

        with gr.Accordion("How to use", open=False, elem_classes=["panel", "guide-panel"]):
            gr.Markdown(
                "1. Record audio with your microphone or upload a file.\n"
                "2. Select the language spoken in the clip.\n"
                "3. Click `Transcribe` to send a single REST request.\n"
                "4. Review the transcript."
            )

        with gr.Row(equal_height=False):
            with gr.Column(scale=5, elem_classes="panel"):
                gr.Markdown("### Audio Input")
                audio_input = gr.Audio(
                    sources=["microphone", "upload"],
                    type="filepath",
                    label="Audio clip",
                )
                language_input = gr.Dropdown(
                    label="Language",
                    choices=LANGUAGE_OPTIONS,
                    value=DEFAULT_LANGUAGE,
                    info="Select the language spoken in the audio clip.",
                )
                with gr.Row():
                    # Starts disabled; update_transcribe_button enables it once audio exists.
                    transcribe_button = gr.Button(
                        "Transcribe",
                        variant="primary",
                        interactive=False,
                    )
                    clear_button = gr.Button("Reset", variant="secondary")
                gr.Markdown(
                    f"Maximum file size: {MAX_AUDIO_FILE_LABEL}.",
                    elem_classes="compact-note",
                )

            with gr.Column(scale=7, elem_classes="panel"):
                gr.Markdown("### Transcript")
                transcript_output = gr.Textbox(
                    label="Transcript",
                    lines=12,
                    max_lines=18,
                    interactive=False,
                    placeholder="The transcription will appear here.",
                    info="Use your browser copy shortcut to copy the transcript.",
                )
                if debug_raw_response_enabled:
                    with gr.Accordion("Raw API response", open=False):
                        raw_response_output = gr.Code(
                            label="JSON",
                            language="json",
                            interactive=False,
                            lines=14,
                        )

        audio_input.change(
            fn=update_transcribe_button,
            inputs=[audio_input],
            outputs=[transcribe_button],
            show_progress="hidden",
        )

        # Debug mode adds the raw JSON panel as an extra output. The handler
        # pair is chosen to match the number of outputs.
        if debug_raw_response_enabled:
            reset_handler, transcribe_handler = reset_ui_debug, transcribe_audio_debug
            result_outputs = [transcript_output, raw_response_output]
        else:
            reset_handler, transcribe_handler = reset_ui, transcribe_audio
            result_outputs = [transcript_output]

        clear_button.click(
            fn=reset_handler,
            outputs=[audio_input, language_input, *result_outputs, transcribe_button],
            show_progress="hidden",
        )
        transcribe_button.click(
            fn=transcribe_handler,
            inputs=[audio_input, language_input],
            outputs=result_outputs,
            api_name="transcribe",
        )

        with gr.Column(elem_classes=["panel", "example-panel"]):
            gr.Markdown("\nExamples\n", sanitize_html=False)
            gr.Markdown(examples_copy, elem_classes="section-copy")
            gr.Markdown(examples_source_copy, elem_classes="example-source")
            if example_inputs:
                # run_on_click=False only fills the inputs; the user still presses Transcribe.
                gr.Examples(
                    examples=example_inputs,
                    inputs=[audio_input, language_input],
                    example_labels=example_labels,
                    examples_per_page=len(example_inputs),
                    run_on_click=False,
                )
            else:
                gr.Markdown(
                    "No example audio is bundled yet. Add approved clips to `examples/` and update `examples/manifest.json` to activate them."
                )
            if debug_raw_response_enabled:
                gr.Dataframe(
                    headers=["Label", "Language", "Description", "Status"],
                    value=example_rows,
                    interactive=False,
                    wrap=True,
                    row_count=(len(example_rows), "fixed"),
                )

demo.queue(default_concurrency_limit=4)

if __name__ == "__main__":
    demo.launch(
        theme=APP_THEME,
        css_paths=str(STYLESHEET_PATH),
    )