Spaces:
Running on Zero
Running on Zero
File size: 6,497 Bytes
c8619d3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 | #!/usr/bin/env python
import os
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Token budgets. Each one can be overridden by an environment variable of the
# same name; the string defaults below apply when the variable is unset.
MAX_NEW_TOKENS_LIMIT = int(os.environ.get("MAX_NEW_TOKENS_LIMIT", "2000"))
MAX_NEW_TOKENS_DEFAULT = int(os.environ.get("MAX_NEW_TOKENS_DEFAULT", "500"))
MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", "8192"))

# The tokenizer and the model are loaded once, at import time, and shared by
# every request handler defined below.
MODEL_ID = "cyberagent/CAT-Translate-7b"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# User-turn prompt; the placeholders are filled in by _build_messages().
PROMPT_TEMPLATE = "Translate the following {src_lang} text into {tgt_lang}.\n\n{src_text}"

# UI label -> (source language, target language) as substituted into the prompt.
DIRECTION_LANGS: dict[str, tuple[str, str]] = {
    "Japanese → English": (
        "Japanese",
        "English",
    ),
    "English → Japanese": (
        "English",
        "Japanese",
    ),
}
DEFAULT_DIRECTION = "Japanese → English"
def _build_messages(text: str, direction: str) -> list[dict]:
    """Wrap ``text`` in a single user-turn chat message for ``direction``.

    ``direction`` must be a key of ``DIRECTION_LANGS``; an unknown key raises
    ``KeyError``. The returned list holds exactly one ``{"role", "content"}``
    dict whose content is ``PROMPT_TEMPLATE`` filled with the language pair
    and the source text.
    """
    source_language, target_language = DIRECTION_LANGS[direction]
    prompt = PROMPT_TEMPLATE.format(
        src_text=text,
        src_lang=source_language,
        tgt_lang=target_language,
    )
    user_turn = {"role": "user", "content": prompt}
    return [user_turn]
def count_tokens(text: str, direction: str) -> str:
    """Report the prompt length in tokens as a short info string.

    Only the tokenizer is invoked here; the model is never called. The count
    covers the full chat-templated prompt (including the generation prompt),
    not just the raw ``text``.

    Returns:
        ``""`` when ``text`` is empty, otherwise ``"Input tokens: <n>"``.
    """
    if text:
        token_ids = tokenizer.apply_chat_template(
            _build_messages(text, direction),
            tokenize=True,
            add_generation_prompt=True,
            return_dict=False,
        )
        return f"Input tokens: {len(token_ids)}"
    return ""
@spaces.GPU(duration=60)
@torch.inference_mode()
def translate(text: str, direction: str, max_new_tokens: int) -> str:
    """Translate ``text`` in the given ``direction`` using greedy decoding.

    Args:
        text: Source text to translate.
        direction: A key of ``DIRECTION_LANGS`` selecting the language pair.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded model continuation only (prompt tokens are sliced off),
        with special tokens removed.

    Raises:
        gr.Error: If ``text`` is empty or whitespace-only, or if the prompt
            length plus ``max_new_tokens`` exceeds ``MAX_TOTAL_TOKENS``.
    """
    # Whitespace-only input has nothing to translate; reject it the same way
    # as an empty string instead of running a full generation on it.
    if not text or not text.strip():
        raise gr.Error("Please enter text to translate")

    # UI sliders may deliver the value as a float; the budget arithmetic and
    # generate() both want an integer token count.
    max_new_tokens = int(max_new_tokens)

    messages = _build_messages(text, direction)
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
    )
    # input_ids is (batch, sequence) with a single conversation in the batch.
    input_len = inputs["input_ids"].shape[1]

    # Enforce the total budget before copying any tensors to the model device,
    # so an over-long request fails without touching it.
    if input_len + max_new_tokens > MAX_TOTAL_TOKENS:
        error_message = (
            f"Input ({input_len} tokens) + max output ({max_new_tokens} tokens)"
            f" exceeds the total limit of {MAX_TOTAL_TOKENS} tokens."
        )
        raise gr.Error(error_message)

    inputs = inputs.to(model.device)
    generation = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens, use_cache=True)
    # generate() returns prompt + continuation; keep only the continuation.
    generation = generation[0][input_len:]
    return tokenizer.decode(generation, skip_special_tokens=True)
# Example inputs for the demo. Each row is [source text, direction label];
# the direction labels must match keys of DIRECTION_LANGS.
_short_examples = [
    ["今日はいい天気ですね。", "Japanese → English"],
    ["東京は世界で最も人口の多い都市の一つです。", "Japanese → English"],
    ["The cherry blossoms are beautiful this year.", "English → Japanese"],
    ["Technology is changing how we communicate with each other.", "English → Japanese"],
]
_long_examples = [
    [
        "近年、大規模言語モデルの発展により、機械翻訳の品質は飛躍的に向上した。"
        "従来の統計ベースの手法では、文脈を十分に考慮することが難しく、"
        "長文になるほど翻訳精度が低下する傾向があった。"
        "しかし、Transformerアーキテクチャの登場以降、"
        "文全体の意味を捉えた上で自然な訳文を生成することが可能になりつつある。"
        "特に、日本語と英語のように語順や文法構造が大きく異なる言語対においては、"
        "この進歩の恩恵は顕著である。"
        "一方で、専門用語や文化的なニュアンスの翻訳には依然として課題が残されており、"
        "人間の翻訳者との協働が重要視されている。",
        "Japanese → English",
    ],
    [
        "The rapid advancement of artificial intelligence has fundamentally transformed "
        "how software is developed, tested, and deployed. Modern development teams "
        "increasingly rely on AI-powered tools for code generation, automated testing, "
        "and even architectural design decisions. While these tools have dramatically "
        "improved productivity, they also introduce new challenges around code quality, "
        "security vulnerabilities, and the need for human oversight. The most effective "
        "approach appears to be a collaborative one, where AI handles repetitive and "
        "boilerplate tasks while human developers focus on creative problem-solving, "
        "system design, and ensuring that the generated code aligns with business "
        "requirements and ethical standards.",
        "English → Japanese",
    ],
]

with gr.Blocks() as demo:
    gr.Markdown("# CAT-Translate-7b")

    # Global controls, rendered above the input/output columns.
    direction = gr.Radio(
        choices=list(DIRECTION_LANGS),
        value=DEFAULT_DIRECTION,
        label="Translation Direction",
    )
    max_new_tokens = gr.Slider(
        minimum=50,
        maximum=MAX_NEW_TOKENS_LIMIT,
        value=MAX_NEW_TOKENS_DEFAULT,
        step=10,
        label="Max New Tokens",
        info="Higher values allow longer translations but take more time",
    )

    # Left column: source text, live token count, and the trigger button.
    # Right column: the translation result.
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(label="Input", lines=10, placeholder="Enter text to translate")
            token_info = gr.Textbox(label="Token Count", lines=1)
            translate_button = gr.Button("Translate", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Translation", lines=10, placeholder="Translation will appear here")

    # The token count depends on both the text and the direction (the prompt
    # wording changes with the language pair), so refresh it when either changes.
    token_count_inputs = [text, direction]
    text.change(fn=count_tokens, inputs=token_count_inputs, outputs=token_info)
    direction.change(fn=count_tokens, inputs=token_count_inputs, outputs=token_info)

    translate_button.click(
        fn=translate,
        inputs=[text, direction, max_new_tokens],
        outputs=output,
    )

    def translate_example(text: str, direction: str) -> str:
        """Translate an example row using the default output-token budget.

        Example rows carry only text and direction, so the slider value is
        replaced by ``MAX_NEW_TOKENS_DEFAULT``.
        """
        return translate(text, direction, MAX_NEW_TOKENS_DEFAULT)

    gr.Examples(
        examples=_short_examples,
        inputs=[text, direction],
        outputs=output,
        fn=translate_example,
        label="Short examples",
    )
    gr.Examples(
        examples=_long_examples,
        inputs=[text, direction],
        outputs=output,
        fn=translate_example,
        label="Long examples",
    )
# Start the web server only when this file is executed as a script, so that
# importing the module does not launch the app. The stylesheet path is
# relative to the working directory.
if __name__ == "__main__":
    demo.launch(css_paths="style.css")
|