from __future__ import annotations import json import mimetypes import os from pathlib import Path from typing import Any import gradio as gr import requests APP_ROOT = Path(__file__).parent EXAMPLES_MANIFEST_PATH = APP_ROOT / "examples" / "manifest.json" TRANSCRIPTIONS_API_URL = "https://api.cohere.ai/compatibility/v1/audio/transcriptions" DEFAULT_MODEL_ID = "cohere-transcribe-03-2026" DEFAULT_TIMEOUT_SECONDS = 180.0 DEFAULT_LANGUAGE = "en" MAX_AUDIO_FILE_BYTES = 25 * 1024 * 1024 MAX_AUDIO_FILE_LABEL = "25 MB" LANGUAGE_OPTIONS = [ ("English", "en"), ("French", "fr"), ("German", "de"), ("Spanish", "es"), ("Portuguese", "pt"), ("Italian", "it"), ("Dutch", "nl"), ("Polish", "pl"), ("Greek", "el"), ("Arabic", "ar"), ("Japanese", "ja"), ("Korean", "ko"), ("Chinese", "zh"), ("Vietnamese", "vi"), ] LANGUAGE_LABELS = {code: label for label, code in LANGUAGE_OPTIONS} APP_THEME = gr.themes.Soft( primary_hue="orange", secondary_hue="amber", neutral_hue="stone", font=[gr.themes.GoogleFont("Manrope"), "ui-sans-serif", "system-ui", "sans-serif"], ).set( body_background_fill="#fff7ed", body_background_fill_dark="#fff7ed", body_text_color="#6b4f3f", body_text_color_dark="#6b4f3f", body_text_color_subdued="#7c6558", body_text_color_subdued_dark="#7c6558", block_background_fill="#fffaf0", block_background_fill_dark="#fffaf0", block_border_color="rgba(194, 65, 12, 0.14)", block_border_color_dark="rgba(194, 65, 12, 0.14)", block_label_text_color="#7c2d12", block_label_text_color_dark="#7c2d12", input_background_fill="#fffdf7", input_background_fill_dark="#fffdf7", input_border_color="rgba(194, 65, 12, 0.18)", input_border_color_dark="rgba(194, 65, 12, 0.18)", button_primary_background_fill="#ea580c", button_primary_background_fill_dark="#ea580c", button_primary_background_fill_hover="#f97316", button_primary_background_fill_hover_dark="#f97316", button_primary_text_color="#fff7ed", button_primary_text_color_dark="#fff7ed", button_secondary_background_fill="#fffaf0", button_secondary_background_fill_dark="#fffaf0", button_secondary_background_fill_hover="#fff1e6", button_secondary_background_fill_hover_dark="#fff1e6", button_secondary_text_color="#7c2d12", button_secondary_text_color_dark="#7c2d12", button_secondary_border_color="rgba(194, 65, 12, 0.14)", button_secondary_border_color_dark="rgba(194, 65, 12, 0.14)", link_text_color="#c2410c", link_text_color_dark="#c2410c", ) class TranscriptionError(Exception): pass def get_api_key() -> str | None: api_key = os.getenv("COHERE_API_KEY", "").strip() return api_key or None def get_model_id() -> str: model_id = os.getenv("COHERE_ASR_MODEL", DEFAULT_MODEL_ID).strip() return model_id or DEFAULT_MODEL_ID def get_timeout_seconds() -> float: raw_value = os.getenv("COHERE_API_TIMEOUT_SECONDS", str(DEFAULT_TIMEOUT_SECONDS)) try: return max(30.0, float(raw_value)) except ValueError: return DEFAULT_TIMEOUT_SECONDS def get_debug_raw_response_enabled() -> bool: value = os.getenv("COHERE_ASR_DEBUG_RAW_RESPONSE", "").strip().lower() return value in {"1", "true", "yes", "on"} def display_language(language_code: str | None) -> str: if not language_code: return "Not returned" normalized = language_code.strip() base_code = normalized.split("-", 1)[0].lower() label = LANGUAGE_LABELS.get(base_code) if not label: return normalized if base_code == normalized.lower(): return label return f"{label} ({normalized})" def extract_transcript_text(payload: Any) -> str: if isinstance(payload, str): return payload.strip() if isinstance(payload, dict): for key in ("text", "transcript"): value = payload.get(key) if isinstance(value, str) and value.strip(): return value.strip() segments = payload.get("segments") if isinstance(segments, list): parts = [extract_transcript_text(segment) for segment in segments] joined = " ".join(part for part in parts if part) if joined: return joined.strip() for key in ("result", "data", "output"): nested = payload.get(key) text = extract_transcript_text(nested) if text: return text if isinstance(payload, list): parts = [extract_transcript_text(item) for item in payload] joined = " ".join(part for part in parts if part) return joined.strip() return "" def format_api_error(response: requests.Response) -> str: message = "" try: payload = response.json() except ValueError: payload = response.text.strip() if isinstance(payload, dict): error = payload.get("error") if isinstance(error, dict): message = str(error.get("message") or error.get("type") or "") elif isinstance(error, str): message = error if not message: for key in ("message", "detail", "error_description"): value = payload.get(key) if isinstance(value, str) and value.strip(): message = value.strip() break elif isinstance(payload, str): message = payload if not message: message = "Unexpected API response." return f"Transcription request failed ({response.status_code}): {message}" def validate_audio_file(audio_path: str) -> Path: audio_file_path = Path(audio_path) try: file_size = audio_file_path.stat().st_size except OSError as exc: raise TranscriptionError("The selected audio file could not be read.") from exc if file_size > MAX_AUDIO_FILE_BYTES: raise TranscriptionError( f"Audio files must be {MAX_AUDIO_FILE_LABEL} or smaller. " "Please upload a shorter clip or a more compressed file." ) return audio_file_path def call_transcriptions_api(audio_path: str, language: str) -> dict[str, Any]: api_key = get_api_key() if not api_key: raise TranscriptionError( "COHERE_API_KEY is not configured. Add it locally or in Hugging Face Space secrets." ) audio_file_path = validate_audio_file(audio_path) mime_type = mimetypes.guess_type(audio_file_path.name)[0] or "application/octet-stream" headers = { "Authorization": f"Bearer {api_key}", "Accept": "application/json", } with audio_file_path.open("rb") as audio_file: multipart_fields = [ ("model", (None, get_model_id())), ("language", (None, language)), ("file", (audio_file_path.name, audio_file, mime_type)), ] response = requests.post( TRANSCRIPTIONS_API_URL, headers=headers, files=multipart_fields, timeout=get_timeout_seconds(), ) if not response.ok: raise TranscriptionError(format_api_error(response)) try: payload = response.json() except ValueError as exc: raise TranscriptionError("The transcription API returned a non-JSON response.") from exc if not isinstance(payload, dict): raise TranscriptionError("The transcription API returned an unexpected payload shape.") return payload def load_example_manifest() -> list[dict[str, Any]]: if not EXAMPLES_MANIFEST_PATH.exists(): return [] try: payload = json.loads(EXAMPLES_MANIFEST_PATH.read_text(encoding="utf-8")) except (OSError, json.JSONDecodeError): return [] raw_examples = payload.get("examples", []) if not isinstance(raw_examples, list): return [] return [item for item in raw_examples if isinstance(item, dict)] def resolve_example_audio_path(raw_path: str | None) -> Path | None: if not raw_path: return None candidate = Path(raw_path) if not candidate.is_absolute(): candidate = APP_ROOT / candidate if candidate.exists(): return candidate return None def build_example_table_rows(examples: list[dict[str, Any]]) -> list[list[str]]: rows: list[list[str]] = [] for example in examples: audio_path = resolve_example_audio_path(example.get("audio_path")) is_ready = bool(example.get("enabled")) and audio_path is not None status = "Ready" if is_ready else "Waiting for approved audio" rows.append( [ str(example.get("label") or "Untitled example"), display_language(example.get("language")), str(example.get("description") or ""), status, ] ) return rows def build_gradio_examples(examples: list[dict[str, Any]]) -> tuple[list[list[str]], list[str]]: example_inputs: list[list[str]] = [] example_labels: list[str] = [] for example in examples: audio_path = resolve_example_audio_path(example.get("audio_path")) if not example.get("enabled") or audio_path is None: continue example_inputs.append( [ str(audio_path), str(example.get("language") or DEFAULT_LANGUAGE), ] ) example_labels.append(str(example.get("label") or audio_path.stem)) return example_inputs, example_labels def build_transcription_outputs(audio_path: str | None, language: str) -> tuple[str, str]: if not audio_path: raise gr.Error("Record or upload an audio clip before starting transcription.") try: payload = call_transcriptions_api(audio_path, language) except TranscriptionError as exc: raise gr.Error(str(exc)) from exc except requests.Timeout as exc: raise gr.Error( "The transcription request timed out. Try a shorter clip or increase the timeout." ) from exc except requests.RequestException as exc: raise gr.Error(f"Unable to reach the transcription API: {exc}") from exc transcript = extract_transcript_text(payload) if not transcript: raise gr.Error("The API returned a response, but no transcript text was found.") raw_json = json.dumps(payload, indent=2, ensure_ascii=False) return ( transcript, raw_json, ) def transcribe_audio(audio_path: str | None, language: str) -> str: transcript, _raw_json = build_transcription_outputs(audio_path, language) return transcript def transcribe_audio_debug(audio_path: str | None, language: str) -> tuple[str, str]: return build_transcription_outputs(audio_path, language) def reset_ui() -> tuple[None, str, str, dict[str, Any]]: return (None, DEFAULT_LANGUAGE, "", gr.update(interactive=False)) def reset_ui_debug() -> tuple[None, str, str, str, dict[str, Any]]: return (None, DEFAULT_LANGUAGE, "", "", gr.update(interactive=False)) def update_transcribe_button(audio_path: str | None) -> dict[str, bool]: return gr.update(interactive=api_ready and bool(audio_path)) example_manifest = load_example_manifest() example_rows = build_example_table_rows(example_manifest) example_inputs, example_labels = build_gradio_examples(example_manifest) api_ready = get_api_key() is not None debug_raw_response_enabled = get_debug_raw_response_enabled() hero_markdown = """
Upload audio or record directly in the browser, and inspect the transcript. By using this Space, you agree to the Cohere Privacy Policy.