Spaces:
Sleeping
Sleeping
| """ | |
| Translation interface using the MADLAD-400 3B model. | |
| Translates between 418 languages from the MADLAD-400 paper. | |
| """ | |
| import warnings | |
| from collections.abc import Generator | |
| from functools import lru_cache | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| from langmap.langid_mapping import langid_to_language | |
# Checkpoint identifier handed to `from_pretrained` for both the tokenizer and the model.
MODEL_NAME = "google/madlad400-3b-mt"
def _get_device() -> torch.device:
    """Pick the torch device to run on: CUDA when available, otherwise CPU.

    Falling back to CPU emits a warning, since translation is slow there.
    """
    if not torch.cuda.is_available():
        warnings.warn("No GPU available. Running on CPU — translation will be slow.", stacklevel=2)
        return torch.device("cpu")
    return torch.device("cuda")
@lru_cache(maxsize=1)
def _load_tokenizer() -> AutoTokenizer:
    """Load the tokenizer for ``MODEL_NAME`` once and reuse it afterwards.

    The result is memoised because this loader is called both when the
    language mappings are built and on every translation request; without
    the cache each of those calls would go back through ``from_pretrained``.

    Returns:
        The tokenizer for ``MODEL_NAME`` (requested with ``use_fast=True``).

    Raises:
        RuntimeError: If ``from_pretrained`` returns ``None``.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    if tokenizer is None:
        raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}")
    return tokenizer
@lru_cache(maxsize=1)
def _load_model() -> AutoModelForSeq2SeqLM:
    """Load the ``MODEL_NAME`` seq2seq model once and place it on the device.

    Memoised so that the checkpoint is instantiated and moved to the device
    a single time per process instead of on every translation request.

    Returns:
        The model, in float16 on CUDA and float32 on CPU, already moved to
        the device chosen by ``_get_device``.
    """
    device = _get_device()
    # Half precision only on GPU; CPU stays in float32.
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, dtype=dtype)
    return model.to(device)
@lru_cache(maxsize=1)
def _build_language_mappings() -> tuple[dict[str, str], list[str]]:
    """Build the display-name lookup and the ordered list of language names.

    Only language codes that are present in the tokenizer vocabulary are
    kept. The result is memoised because it is needed both when the UI is
    built and on every translation request; the returned dict and list are
    therefore shared between callers and must be treated as read-only.

    Returns:
        A pair ``(name_to_code, sorted_names)`` where ``name_to_code`` maps a
        display name such as ``"<name> (<locale>)"`` to its language code,
        and ``sorted_names`` lists the display names ordered by region, then
        alphabetically within each region.
    """
    vocab = _load_tokenizer().get_vocab()
    name_to_code: dict[str, str] = {}
    region_by_name: dict = {}
    for code, info in langid_to_language.items():
        if code not in vocab:
            continue
        locale = code[2:-1]  # <2fr> → fr
        display_name = f"{info['name']} ({locale})"
        name_to_code[display_name] = code
        # Remember the region here so the sort key needs a single lookup.
        region_by_name[display_name] = info["region"]
    # Sort by region, then alphabetically within each region
    sorted_names = sorted(name_to_code, key=lambda name: (region_by_name[name], name))
    return name_to_code, sorted_names
def translate(
    text: str,
    target_language_name: str,
    max_new_tokens: int = 512,
    num_beams: int = 1,
    temperature: float = 1.0,
) -> str:
    """Translate *text* into the language named by *target_language_name*.

    Args:
        text: Text to translate. Empty or whitespace-only input yields ``""``
            without running the model.
        target_language_name: Display name as produced by
            ``_build_language_mappings`` (for example ``"French (fr)"``).
        max_new_tokens: Upper bound on generated tokens.
        num_beams: Number of beams. With ``1`` the output is sampled; with
            more than one, beam search is used and *temperature* is ignored.
        temperature: Sampling temperature, used only when ``num_beams == 1``.
            A non-positive value selects greedy decoding instead of sampling.

    Returns:
        The decoded translation with special tokens removed.

    Raises:
        ValueError: If *target_language_name* is not a known display name.
        TypeError: If the tokenizer's ``decode`` does not return a ``str``.
    """
    # Validate the request before paying for a model load.
    name_to_code, _ = _build_language_mappings()
    target_code = name_to_code.get(target_language_name)
    if target_code is None:
        raise ValueError(f"Unsupported language: {target_language_name}")

    # Nothing to translate: do not feed the bare language tag to the model.
    if not text or not text.strip():
        return ""

    if num_beams > 1 and temperature != 1.0:
        gr.Info("Temperature has no effect when beam search is enabled (num_beams > 1).")

    tokenizer = _load_tokenizer()
    model = _load_model()
    device = model.device

    # The target language tag (e.g. "<2fr>") is prepended to the source text.
    input_ids = tokenizer(f"{target_code} {text}", return_tensors="pt").input_ids.to(device)
    generate_kwargs: dict = {"input_ids": input_ids, "max_new_tokens": max_new_tokens, "num_beams": num_beams}
    if num_beams == 1:
        if temperature > 0:
            generate_kwargs["do_sample"] = True
            generate_kwargs["temperature"] = temperature
        else:
            # Sampling needs a strictly positive temperature; treat zero or
            # negative values as a request for deterministic greedy decoding.
            generate_kwargs["do_sample"] = False

    outputs = model.generate(**generate_kwargs)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if not isinstance(result, str):
        raise TypeError(f"Expected str from decode, got {type(result)}")
    return result
def _translate_with_loading(
    text: str,
    target_language_name: str,
) -> Generator[tuple, None, None]:
    """Run ``translate`` while toggling the translate button's loading state.

    Each yielded pair is ``(button_update, output_value)``: first the button
    is disabled and relabelled, then it is restored together with the result.

    If ``translate`` raises, the button is restored before the exception is
    re-raised, so a failed request does not leave it disabled.

    Args:
        text: Text to translate.
        target_language_name: Display name of the target language.

    Yields:
        Tuples of a button update and an output-textbox value or update.
    """
    yield gr.update(value="Translating...", interactive=False), gr.update()
    try:
        result = translate(text, target_language_name)
    except Exception:
        # Restore the button first; the error itself still propagates.
        yield gr.update(value="Translate", interactive=True), gr.update()
        raise
    yield gr.update(value="Translate", interactive=True), result
def _swap_languages(
    source_lang: str, target_lang: str, source_text: str, target_text: str
) -> tuple[str, str, str, str]:
    """Exchange the two languages and, likewise, the two texts.

    Returns:
        ``(target_lang, source_lang, target_text, source_text)``.
    """
    swapped_langs = (target_lang, source_lang)
    swapped_texts = (target_text, source_text)
    return (*swapped_langs, *swapped_texts)
def _build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks UI and wire up its event handlers.

    Layout: a row with the source dropdown, a swap button and the target
    dropdown; a row with the input and output textboxes; then the translate
    button. Translation handlers read only the input text and the target
    language; the source dropdown takes part only in the swap handler.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """
    language_names = _build_language_mappings()[1]
    # Options shared by both language selectors.
    dropdown_options = {"choices": language_names, "show_label": False, "filterable": True}

    # NOTE(review): the nesting below is reconstructed; the button is assumed
    # to sit under the text row rather than inside it — confirm.
    with gr.Blocks(title="MADLAD-400 Translate") as demo:
        with gr.Row():
            source_language = gr.Dropdown(value="English (en)", **dropdown_options)
            swap_btn = gr.Button("⇄", scale=0, min_width=60)
            target_language = gr.Dropdown(value="French (fr)", **dropdown_options)

        with gr.Row(equal_height=True):
            input_text = gr.Textbox(
                lines=6,
                max_length=2000,
                show_label=False,
                buttons=["clear"],
            )
            output_text = gr.Textbox(
                placeholder="Translation",
                lines=6,
                show_label=False,
                interactive=False,
                buttons=["copy"],
            )

        translate_btn = gr.Button("Translate", variant="primary")

        # The swap handler reads and writes the same four components.
        swap_fields = [source_language, target_language, input_text, output_text]
        swap_btn.click(fn=_swap_languages, inputs=swap_fields, outputs=list(swap_fields))

        # Button click goes through the loading wrapper; textbox submit calls
        # translate directly and only updates the output box.
        translation_inputs = [input_text, target_language]
        translate_btn.click(
            fn=_translate_with_loading,
            inputs=translation_inputs,
            outputs=[translate_btn, output_text],
        )
        input_text.submit(
            fn=translate,
            inputs=list(translation_inputs),
            outputs=output_text,
        )

    return demo
# Built at import time so the app object is available as a module attribute.
demo = _build_demo()


def main() -> None:
    """Launch the module-level Gradio app using the Ocean theme."""
    ocean_theme = gr.themes.Ocean()
    demo.launch(theme=ocean_theme)


if __name__ == "__main__":
    main()