"""
Translation interface using the MADLAD-400 3B model.

Translates between 418 languages from the MADLAD-400 paper.
"""

import warnings
from collections.abc import Generator
from functools import lru_cache

import gradio as gr
import spaces
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from langmap.langid_mapping import langid_to_language

MODEL_NAME = "google/madlad400-3b-mt"


def _get_device() -> torch.device:
    """Return the CUDA device when available, otherwise the CPU (with a warning)."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    warnings.warn("No GPU available. Running on CPU — translation will be slow.", stacklevel=2)
    return torch.device("cpu")


@lru_cache(maxsize=1)
def _load_tokenizer() -> AutoTokenizer:
    """Load the fast tokenizer for ``MODEL_NAME`` once and cache it.

    Raises:
        RuntimeError: if ``from_pretrained`` returns ``None``.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    if tokenizer is None:
        raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}")
    return tokenizer


@lru_cache(maxsize=1)
def _load_model() -> AutoModelForSeq2SeqLM:
    """Load the seq2seq model once, move it to the selected device and cache it.

    Weights are loaded as float16 on CUDA (half the memory of float32) and as
    float32 on any other device.
    """
    device = _get_device()
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    return AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, dtype=dtype).to(device)


@lru_cache(maxsize=1)
def _build_language_mappings() -> tuple[dict[str, str], list[str]]:
    """Build the display-name → language-tag mapping used by the UI.

    Only tags from ``langid_to_language`` that are present in the tokenizer
    vocabulary (tags look like ``"<2fr>"``) are kept.

    Returns:
        ``(name_to_code, sorted_names)``: ``name_to_code`` maps a display name
        of the form ``"<name> (<locale>)"`` to its tag, and ``sorted_names``
        lists the display names ordered by region, then alphabetically.
    """
    tokenizer = _load_tokenizer()
    vocab = tokenizer.get_vocab()

    name_to_code: dict[str, str] = {}
    for code, info in langid_to_language.items():
        if code in vocab:
            locale = code[2:-1]  # <2fr> → fr
            display_name = f"{info['name']} ({locale})"
            name_to_code[display_name] = code

    # Sort by region, then alphabetically within each region
    sorted_names = sorted(
        name_to_code.keys(),
        key=lambda n: (langid_to_language[name_to_code[n]]["region"], n),
    )
    return name_to_code, sorted_names


@spaces.GPU
def translate(
    text: str,
    target_language_name: str,
    max_new_tokens: int = 512,
    num_beams: int = 1,
    temperature: float = 1.0,
) -> str:
    """Translate ``text`` into the language called ``target_language_name``.

    Only a target-language tag (e.g. ``<2fr>``) is prepended to the input;
    the source language is never passed to the model.

    Args:
        text: Text to translate. Empty or whitespace-only input returns ``""``
            without loading or running the model.
        target_language_name: A display name produced by
            ``_build_language_mappings``.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_beams: Beam width; ``1`` means no beam search.
        temperature: Sampling temperature, used only when ``num_beams == 1``.
            A value ``<= 0`` selects greedy (deterministic) decoding.

    Returns:
        The decoded translation with special tokens stripped.

    Raises:
        ValueError: if ``target_language_name`` is not a supported language.
        TypeError: if decoding does not produce a ``str``.
    """
    # Validate the language first so bad input is reported even for empty text.
    name_to_code, _ = _build_language_mappings()
    target_code = name_to_code.get(target_language_name)
    if target_code is None:
        raise ValueError(f"Unsupported language: {target_language_name}")

    # Nothing to translate: skip model loading and generation entirely.
    if not text or not text.strip():
        return ""

    tokenizer = _load_tokenizer()
    model = _load_model()
    device = model.device

    if num_beams > 1 and temperature != 1.0:
        gr.Info("Temperature has no effect when beam search is enabled (num_beams > 1).")

    input_ids = tokenizer(target_code + " " + text, return_tensors="pt").input_ids.to(device)

    generate_kwargs: dict = {
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
    }
    if num_beams == 1:
        if temperature > 0:
            generate_kwargs["do_sample"] = True
            generate_kwargs["temperature"] = temperature
        else:
            # transformers rejects non-positive temperatures when sampling;
            # interpret them as a request for greedy decoding instead.
            generate_kwargs["do_sample"] = False

    outputs = model.generate(**generate_kwargs)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if not isinstance(result, str):
        raise TypeError(f"Expected str from decode, got {type(result)}")
    return result


def _translate_with_loading(
    text: str,
    target_language_name: str,
) -> Generator[tuple, None, None]:
    """Run ``translate`` while the Translate button shows a busy state.

    Yields ``(button_update, output_update)`` pairs. The first disables the
    button and labels it "Translating...". The last re-enables it and carries
    the translation.

    If translation raises, the button is restored before the exception
    propagates, so the UI is not left stuck in the busy state. A
    ``ValueError`` is re-raised as ``gr.Error`` so its message is shown to
    the user.
    """
    yield gr.update(value="Translating...", interactive=False), gr.update()
    try:
        result = translate(text, target_language_name)
    except ValueError as exc:
        yield gr.update(value="Translate", interactive=True), gr.update()
        raise gr.Error(str(exc)) from exc
    except Exception:
        yield gr.update(value="Translate", interactive=True), gr.update()
        raise
    yield gr.update(value="Translate", interactive=True), result


def _swap_languages(
    source_lang: str, target_lang: str, source_text: str, target_text: str
) -> tuple[str, str, str, str]:
    """Swap source/target languages and their text."""
    return target_lang, source_lang, target_text, source_text


def _build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks UI.

    The UI has language dropdowns with a swap button, input and output text
    boxes, and a Translate button. Clicking Translate and pressing Enter in
    the input box both go through ``_translate_with_loading``, so they share
    the same busy-button behavior.
    """
    _, language_names = _build_language_mappings()

    with gr.Blocks(title="MADLAD-400 Translate") as demo:
        with gr.Row():
            source_language = gr.Dropdown(
                choices=language_names,
                value="English (en)",
                show_label=False,
                filterable=True,
            )
            swap_btn = gr.Button("⇄", scale=0, min_width=60)
            target_language = gr.Dropdown(
                choices=language_names,
                value="French (fr)",
                show_label=False,
                filterable=True,
            )

        with gr.Row(equal_height=True):
            input_text = gr.Textbox(
                lines=6,
                max_length=2000,
                show_label=False,
                buttons=["clear"],
            )
            output_text = gr.Textbox(
                placeholder="Translation",
                lines=6,
                show_label=False,
                interactive=False,
                buttons=["copy"],
            )

        translate_btn = gr.Button("Translate", variant="primary")

        swap_btn.click(
            fn=_swap_languages,
            inputs=[source_language, target_language, input_text, output_text],
            outputs=[source_language, target_language, input_text, output_text],
        )
        translate_btn.click(
            fn=_translate_with_loading,
            inputs=[input_text, target_language],
            outputs=[translate_btn, output_text],
        )
        # Same handler as the button, so pressing Enter also disables it
        # while the translation runs.
        input_text.submit(
            fn=_translate_with_loading,
            inputs=[input_text, target_language],
            outputs=[translate_btn, output_text],
        )

    return demo


demo = _build_demo()


def main() -> None:
    """Launch the demo with the Ocean theme."""
    demo.launch(theme=gr.themes.Ocean())


if __name__ == "__main__":
    main()