hayas's picture
Add files
c8619d3
Raw
History Blame
6.5 kB
#!/usr/bin/env python
import os
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Token budgets. Each one can be overridden through an environment variable of
# the same name; the value must parse as an integer.
# Upper bound of the "Max New Tokens" slider.
MAX_NEW_TOKENS_LIMIT = int(os.environ.get("MAX_NEW_TOKENS_LIMIT", "2000"))
# Initial position of the slider, also used when running the examples.
MAX_NEW_TOKENS_DEFAULT = int(os.environ.get("MAX_NEW_TOKENS_DEFAULT", "500"))
# Ceiling on prompt tokens plus requested output tokens for a single request.
MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", "8192"))
# Identifier handed to both `from_pretrained` calls below.
MODEL_ID = "cyberagent/CAT-Translate-7b"

# Tokenizer and model are created once at import time; the request handlers
# in this module use these module-level objects directly.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# Template for the single user message sent to the model. The three
# placeholders are filled with `str.format`.
PROMPT_TEMPLATE = (
    "Translate the following {src_lang} text into {tgt_lang}."
    "\n\n"
    "{src_text}"
)

# Direction label shown in the UI -> (source language, target language).
DIRECTION_LANGS: dict[str, tuple[str, str]] = {
    "Japanese → English": ("Japanese", "English"),
    "English → Japanese": ("English", "Japanese"),
}

# Direction that is pre-selected in the UI; must be a key of DIRECTION_LANGS.
DEFAULT_DIRECTION = "Japanese → English"
def _build_messages(text: str, direction: str) -> list[dict]:
    """Wrap *text* in a one-message chat conversation for *direction*.

    ``direction`` is looked up in ``DIRECTION_LANGS`` (an unknown label raises
    ``KeyError``) and the resulting language pair, together with *text*, is
    substituted into ``PROMPT_TEMPLATE``. The return value is a list holding a
    single ``{"role": "user", "content": ...}`` message.
    """
    source_name, target_name = DIRECTION_LANGS[direction]
    prompt = PROMPT_TEMPLATE.format(
        src_text=text,
        src_lang=source_name,
        tgt_lang=target_name,
    )
    user_turn = {"role": "user", "content": prompt}
    return [user_turn]
def count_tokens(text: str, direction: str) -> str:
    """Return a short ``"Input tokens: N"`` string for the templated prompt.

    Only the tokenizer is involved, so this runs without the GPU. The count
    covers the whole chat-templated prompt (including the generation prompt),
    not just *text*. Falsy *text* yields an empty string so the info box is
    simply cleared.
    """
    if text:
        token_ids = tokenizer.apply_chat_template(
            _build_messages(text, direction),
            tokenize=True,
            add_generation_prompt=True,
            return_dict=False,
        )
        return f"Input tokens: {len(token_ids)}"
    return ""
@spaces.GPU(duration=60)
@torch.inference_mode()
def translate(text: str, direction: str, max_new_tokens: int) -> str:
    """Translate *text* in the given direction with sampling disabled.

    Args:
        text: Source text to translate.
        direction: Key of ``DIRECTION_LANGS`` selecting the language pair.
        max_new_tokens: Upper bound on the number of generated tokens. A float
            (which is what a Gradio slider may deliver) is accepted and
            truncated to ``int``.

    Returns:
        The decoded continuation only; the prompt tokens are sliced off and
        special tokens are skipped.

    Raises:
        gr.Error: If *text* is empty or whitespace-only, or if the prompt
            length plus ``max_new_tokens`` exceeds ``MAX_TOTAL_TOKENS``.
    """
    # Whitespace-only input carries nothing to translate; treat it as empty.
    if not text or not text.strip():
        raise gr.Error("Please enter text to translate")

    # Gradio documents Slider as passing its value to the callback as a float;
    # normalise it so the budget message and `generate` both see a plain int.
    max_new_tokens = int(max_new_tokens)

    messages = _build_messages(text, direction)
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )

    # Check the budget before moving tensors to the model device, so an
    # over-long request is rejected without any device transfer.
    input_len = len(inputs["input_ids"][0])
    if input_len + max_new_tokens > MAX_TOTAL_TOKENS:
        error_message = (
            f"Input ({input_len} tokens) + max output ({max_new_tokens} tokens)"
            f" exceeds the total limit of {MAX_TOTAL_TOKENS} tokens."
        )
        raise gr.Error(error_message)

    inputs = inputs.to(model.device)
    generation = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens, use_cache=True)
    # `generate` returns prompt + continuation; keep only the newly generated part.
    generation = generation[0][input_len:]
    return tokenizer.decode(generation, skip_special_tokens=True)
# NOTE(review): the source this was restored from had lost its indentation; the
# nesting below (button in the left column, examples inside the Blocks context)
# is the most plausible reconstruction — confirm against the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown("# CAT-Translate-7b")

    # Controls that apply to every translation request.
    direction = gr.Radio(
        choices=list(DIRECTION_LANGS),
        value=DEFAULT_DIRECTION,
        label="Translation Direction",
    )
    max_new_tokens = gr.Slider(
        minimum=50,
        maximum=MAX_NEW_TOKENS_LIMIT,
        value=MAX_NEW_TOKENS_DEFAULT,
        step=10,
        label="Max New Tokens",
        info="Higher values allow longer translations but take more time",
    )

    # One row, two columns: input widgets in the first, the result in the second.
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(label="Input", lines=10, placeholder="Enter text to translate")
            token_info = gr.Textbox(label="Token Count", lines=1)
            translate_button = gr.Button("Translate", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Translation", lines=10, placeholder="Translation will appear here")

    # Refresh the token count whenever the text or the direction changes,
    # because both end up in the prompt that gets tokenized.
    token_count_inputs = [text, direction]
    for watched in token_count_inputs:
        watched.change(fn=count_tokens, inputs=token_count_inputs, outputs=token_info)

    translate_button.click(
        fn=translate,
        inputs=[text, direction, max_new_tokens],
        outputs=output,
    )

    def translate_example(text: str, direction: str) -> str:
        """Run `translate` for an example row, which carries no token budget.

        Example rows only provide (text, direction), so the default
        ``MAX_NEW_TOKENS_DEFAULT`` budget is used.
        """
        return translate(text, direction, MAX_NEW_TOKENS_DEFAULT)

    # Each example row is [input text, direction label].
    short_examples = [
        ["今日はいい天気ですね。", "Japanese → English"],
        ["東京は世界で最も人口の多い都市の一つです。", "Japanese → English"],
        ["The cherry blossoms are beautiful this year.", "English → Japanese"],
        ["Technology is changing how we communicate with each other.", "English → Japanese"],
    ]
    gr.Examples(
        examples=short_examples,
        label="Short examples",
        fn=translate_example,
        inputs=[text, direction],
        outputs=output,
    )

    long_japanese = (
        "近年、大規模言語モデルの発展により、機械翻訳の品質は飛躍的に向上した。"
        "従来の統計ベースの手法では、文脈を十分に考慮することが難しく、"
        "長文になるほど翻訳精度が低下する傾向があった。"
        "しかし、Transformerアーキテクチャの登場以降、"
        "文全体の意味を捉えた上で自然な訳文を生成することが可能になりつつある。"
        "特に、日本語と英語のように語順や文法構造が大きく異なる言語対においては、"
        "この進歩の恩恵は顕著である。"
        "一方で、専門用語や文化的なニュアンスの翻訳には依然として課題が残されており、"
        "人間の翻訳者との協働が重要視されている。"
    )
    long_english = (
        "The rapid advancement of artificial intelligence has fundamentally transformed "
        "how software is developed, tested, and deployed. Modern development teams "
        "increasingly rely on AI-powered tools for code generation, automated testing, "
        "and even architectural design decisions. While these tools have dramatically "
        "improved productivity, they also introduce new challenges around code quality, "
        "security vulnerabilities, and the need for human oversight. The most effective "
        "approach appears to be a collaborative one, where AI handles repetitive and "
        "boilerplate tasks while human developers focus on creative problem-solving, "
        "system design, and ensuring that the generated code aligns with business "
        "requirements and ethical standards."
    )
    long_examples = [
        [long_japanese, "Japanese → English"],
        [long_english, "English → Japanese"],
    ]
    gr.Examples(
        examples=long_examples,
        label="Long examples",
        fn=translate_example,
        inputs=[text, direction],
        outputs=output,
    )
# Start the Gradio app only when this file is executed as a script; importing
# the module just builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch(
        css_paths="style.css",
    )