Daryl Lim
fix: address code review findings
7633dda
"""
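Translation interface using the MADLAD-400 3B model.
Translates text into any of the 418 languages from the MADLAD-400 paper. Only the
target language is given to the model (as a ``<2xx>`` tag prepended to the text);
the source-language dropdown is used solely by the swap button.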
"""
import warnings
from collections.abc import Generator
from functools import lru_cache
import gradio as gr
import spaces
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from langmap.langid_mapping import langid_to_language
# Checkpoint identifier passed to ``from_pretrained`` for both the tokenizer and the model.
MODEL_NAME = "google/madlad400-3b-mt"
def _get_device() -> torch.device:
    """Choose the torch device the model should run on.

    Returns ``cuda`` when torch can see a CUDA device. Otherwise it emits a
    warning (CPU translation is slow) and falls back to ``cpu``.
    """
    if not torch.cuda.is_available():
        warnings.warn("No GPU available. Running on CPU — translation will be slow.", stacklevel=2)
        return torch.device("cpu")
    return torch.device("cuda")
@lru_cache(maxsize=1)
def _load_tokenizer() -> AutoTokenizer:
    """Load the tokenizer for ``MODEL_NAME`` once and return it.

    ``lru_cache(maxsize=1)`` memoises the result, so every caller shares one
    instance.

    Raises:
        RuntimeError: if ``from_pretrained`` hands back ``None``.
    """
    loaded = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    if loaded is not None:
        return loaded
    raise RuntimeError(f"Failed to load tokenizer for {MODEL_NAME}")
@lru_cache(maxsize=1)
def _load_model() -> AutoModelForSeq2SeqLM:
    """Load the seq2seq model for ``MODEL_NAME`` once and move it to the chosen device.

    On CUDA the weights are loaded as float16; on CPU they stay float32.
    ``lru_cache(maxsize=1)`` makes every caller share the same model object.
    """
    target = _get_device()
    on_gpu = target.type == "cuda"
    # NOTE(review): the ``dtype=`` keyword needs a transformers release that
    # accepts it (older releases spell it ``torch_dtype``) — confirm the pinned version.
    # NOTE(review): confirm float16 is numerically stable for this T5-style
    # checkpoint; bfloat16 is a common alternative if outputs degrade.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_NAME,
        dtype=torch.float16 if on_gpu else torch.float32,
    )
    return model.to(target)
@lru_cache(maxsize=1)
def _build_language_mappings() -> tuple[dict[str, str], list[str]]:
    """Map UI display names to MADLAD language tags.

    Only tags present in the tokenizer vocabulary are kept. Each display name is
    ``"<language name> (<locale>)"``, where the locale is the tag without its
    first two and last characters (``"<2fr>"`` -> ``"fr"``).

    Returns:
        ``(name_to_code, sorted_names)``: the display-name -> tag dict, and the
        display names ordered by region, then alphabetically within a region.
    """
    known_tokens = _load_tokenizer().get_vocab()
    name_to_code: dict[str, str] = {
        f"{meta['name']} ({tag[2:-1]})": tag
        for tag, meta in langid_to_language.items()
        if tag in known_tokens
    }

    def _region_then_name(display: str) -> tuple:
        # Primary key: the language's region; secondary: the display name itself.
        return langid_to_language[name_to_code[display]]["region"], display

    sorted_names = sorted(name_to_code, key=_region_then_name)
    return name_to_code, sorted_names
@spaces.GPU
def translate(
    text: str,
    target_language_name: str,
    max_new_tokens: int = 512,
    num_beams: int = 1,
    temperature: float = 1.0,
) -> str:
    """Translate ``text`` into the language shown as ``target_language_name``.

    The model receives only the target-language tag (e.g. ``<2fr>``) prepended
    to the text. No source language is passed.

    Args:
        text: Text to translate. Empty or whitespace-only input returns ``""``
            without loading the model or running generation.
        target_language_name: A display name produced by
            ``_build_language_mappings``, e.g. ``"French (fr)"``.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_beams: ``1`` samples with ``temperature``; ``> 1`` uses beam search.
        temperature: Sampling temperature. It is not used when ``num_beams > 1``.

    Returns:
        The decoded translation, with special tokens stripped.

    Raises:
        ValueError: if ``target_language_name`` is not a known display name.
        TypeError: if the tokenizer's ``decode`` does not return a ``str``.
    """
    name_to_code, _ = _build_language_mappings()
    target_code = name_to_code.get(target_language_name)
    if target_code is None:
        raise ValueError(f"Unsupported language: {target_language_name}")

    # Nothing to translate: skip model loading and generation. Generating from
    # the bare language tag alone would return text unrelated to any user input.
    if not (text and text.strip()):
        return ""

    if num_beams > 1 and temperature != 1.0:
        gr.Info("Temperature has no effect when beam search is enabled (num_beams > 1).")

    tokenizer = _load_tokenizer()
    model = _load_model()
    input_ids = tokenizer(target_code + " " + text, return_tensors="pt").input_ids.to(model.device)

    generate_kwargs: dict = {"input_ids": input_ids, "max_new_tokens": max_new_tokens, "num_beams": num_beams}
    if num_beams == 1:
        # NOTE(review): the UI handlers call this with the defaults, so they
        # always sample at temperature 1.0. Repeated translations of the same
        # text can therefore differ; confirm that sampling (not greedy decoding)
        # is intended.
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature
    outputs = model.generate(**generate_kwargs)

    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if not isinstance(result, str):
        raise TypeError(f"Expected str from decode, got {type(result)}")
    return result
def _translate_with_loading(
    text: str,
    target_language_name: str,
) -> Generator[tuple, None, None]:
    """Run ``translate`` while showing a busy state on the Translate button.

    Yields ``(button_update, output_update)`` pairs:
    1. The button is relabelled "Translating..." and disabled; the output is
       left unchanged.
    2. The button is restored and the translation is emitted.

    If ``translate`` raises, the button is restored first and the exception is
    then re-raised. Without this, the button would stay disabled permanently.
    """
    yield gr.update(value="Translating...", interactive=False), gr.update()
    try:
        result = translate(text, target_language_name)
    except Exception:
        # Re-enable the button before the error propagates; output stays unchanged.
        yield gr.update(value="Translate", interactive=True), gr.update()
        raise
    yield gr.update(value="Translate", interactive=True), result
def _swap_languages(
    source_lang: str, target_lang: str, source_text: str, target_text: str
) -> tuple[str, str, str, str]:
    """Exchange the source and target sides of the UI.

    Returns ``(new_source_lang, new_target_lang, new_source_text, new_target_text)``.
    Both language selections and both text boxes trade places.
    """
    languages = (target_lang, source_lang)
    texts = (target_text, source_text)
    return (*languages, *texts)
def _build_demo() -> gr.Blocks:
    """Assemble and return the Gradio UI.

    Layout:
    - source and target language dropdowns with a swap button between them;
    - an input textbox and a read-only output textbox;
    - a Translate button.

    Only the target language is passed to the translation handler. The source
    dropdown is used solely by the swap button.
    """
    _, language_names = _build_language_mappings()
    with gr.Blocks(title="MADLAD-400 Translate") as demo:
        with gr.Row():
            source_language = gr.Dropdown(
                choices=language_names,
                value="English (en)",
                show_label=False,
                filterable=True,
            )
            swap_btn = gr.Button("⇄", scale=0, min_width=60)
            target_language = gr.Dropdown(
                choices=language_names,
                value="French (fr)",
                show_label=False,
                filterable=True,
            )
        with gr.Row(equal_height=True):
            input_text = gr.Textbox(
                lines=6,
                max_length=2000,
                show_label=False,
                buttons=["clear"],
            )
            output_text = gr.Textbox(
                placeholder="Translation",
                lines=6,
                show_label=False,
                interactive=False,
                buttons=["copy"],
            )
        translate_btn = gr.Button("Translate", variant="primary")

        swap_btn.click(
            fn=_swap_languages,
            inputs=[source_language, target_language, input_text, output_text],
            outputs=[source_language, target_language, input_text, output_text],
        )
        # Clicking the button and submitting the textbox share one handler.
        # Both paths therefore show, and then reset, the busy button state.
        translate_btn.click(
            fn=_translate_with_loading,
            inputs=[input_text, target_language],
            outputs=[translate_btn, output_text],
        )
        input_text.submit(
            fn=_translate_with_loading,
            inputs=[input_text, target_language],
            outputs=[translate_btn, output_text],
        )
    return demo
# Built at import time so ``demo`` exists as a module attribute even when this
# file is imported rather than run. Building it loads the tokenizer (via
# ``_build_language_mappings``).
demo = _build_demo()


def main() -> None:
    """Launch the Gradio app using the Ocean theme."""
    ocean_theme = gr.themes.Ocean()
    demo.launch(theme=ocean_theme)


if __name__ == "__main__":
    main()