hayas committed on
Commit
c8619d3
·
1 Parent(s): 8f7d8ef

Add files

Browse files
Files changed (7) hide show
  1. .python-version +1 -0
  2. README.md +2 -1
  3. app.py +151 -0
  4. pyproject.toml +59 -0
  5. requirements.txt +305 -0
  6. style.css +4 -0
  7. uv.lock +0 -0
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
README.md CHANGED
@@ -1,10 +1,11 @@
1
  ---
2
  title: CAT Translate 7b
3
- emoji: 🦀
4
  colorFrom: indigo
5
  colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 6.8.0
 
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
  title: CAT Translate 7b
3
+ emoji: 🐱
4
  colorFrom: indigo
5
  colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 6.8.0
8
+ python_version: "3.12.12"
9
  app_file: app.py
10
  pinned: false
11
  ---
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+
5
+ import gradio as gr
6
+ import spaces
7
+ import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer
9
+
10
# Generation limits. Each one can be overridden through the environment
# variable of the same name; the string defaults are parsed as integers.
MAX_NEW_TOKENS_LIMIT = int(os.getenv("MAX_NEW_TOKENS_LIMIT", "2000"))  # upper bound of the "Max New Tokens" slider
MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "500"))  # initial slider value, also used by examples
MAX_TOTAL_TOKENS = int(os.getenv("MAX_TOTAL_TOKENS", "8192"))  # cap on prompt tokens + requested new tokens

MODEL_ID = "cyberagent/CAT-Translate-7b"

# Tokenizer and model are loaded once, at import time.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# `dtype` is the current keyword for the load precision; `torch_dtype` is its deprecated alias.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.bfloat16, device_map="auto")

# User-turn prompt; the placeholders are filled in by `_build_messages`.
PROMPT_TEMPLATE = "Translate the following {src_lang} text into {tgt_lang}.\n\n{src_text}"

# Maps the label shown in the UI to the (source language, target language) pair.
DIRECTION_LANGS: dict[str, tuple[str, str]] = {
    "Japanese → English": ("Japanese", "English"),
    "English → Japanese": ("English", "Japanese"),
}

DEFAULT_DIRECTION = "Japanese → English"
27
+
28
+
29
def _build_messages(text: str, direction: str) -> list[dict]:
    """Wrap ``text`` in a single user message that asks for a translation.

    ``direction`` must be a key of ``DIRECTION_LANGS``; it selects the source and
    target language names substituted into ``PROMPT_TEMPLATE``.
    """
    source, target = DIRECTION_LANGS[direction]
    prompt = PROMPT_TEMPLATE.format(src_text=text, src_lang=source, tgt_lang=target)
    user_turn = {"role": "user", "content": prompt}
    return [user_turn]
33
+
34
+
35
def count_tokens(text: str, direction: str) -> str:
    """Describe how many tokens the complete translation prompt occupies.

    Only the tokenizer is involved, so no GPU is required. The count covers the
    whole chat-templated prompt (including the generation prompt), not just
    ``text``. Empty input yields an empty string.
    """
    if text:
        token_ids = tokenizer.apply_chat_template(
            _build_messages(text, direction),
            tokenize=True,
            add_generation_prompt=True,
            return_dict=False,
        )
        return f"Input tokens: {len(token_ids)}"
    return ""
42
+
43
+
44
@spaces.GPU(duration=60)
@torch.inference_mode()
def translate(text: str, direction: str, max_new_tokens: int) -> str:
    """Translate ``text`` in the given ``direction`` with greedy decoding.

    Args:
        text: Source text to translate.
        direction: A key of ``DIRECTION_LANGS`` selecting source/target languages.
        max_new_tokens: Maximum number of tokens the model may generate.

    Returns:
        The decoded translation, with the prompt and special tokens removed.

    Raises:
        gr.Error: If ``text`` is empty or whitespace-only, or if the prompt length
            plus ``max_new_tokens`` exceeds ``MAX_TOTAL_TOKENS``.
    """
    # Whitespace-only input carries nothing to translate; reject it like empty input.
    if not text or not text.strip():
        raise gr.Error("Please enter text to translate")

    # Defensive: the value originates from a UI slider, so normalise it to an
    # int before using it for arithmetic and as a generation argument.
    max_new_tokens = int(max_new_tokens)

    messages = _build_messages(text, direction)
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
    )
    # input_ids is a (batch, sequence) tensor holding a single prompt.
    input_len = inputs["input_ids"].shape[1]

    # Enforce the budget before transferring anything to the model's device.
    if input_len + max_new_tokens > MAX_TOTAL_TOKENS:
        error_message = (
            f"Input ({input_len} tokens) + max output ({max_new_tokens} tokens)"
            f" exceeds the total limit of {MAX_TOTAL_TOKENS} tokens."
        )
        raise gr.Error(error_message)

    inputs = inputs.to(model.device)
    generation = model.generate(**inputs, do_sample=False, max_new_tokens=max_new_tokens, use_cache=True)
    # generate() returns prompt + continuation; keep only the newly generated part.
    new_tokens = generation[0][input_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
66
+
67
+
68
# UI layout: direction selector and token budget on top, then input and output
# side by side, followed by clickable short and long examples.
with gr.Blocks() as demo:
    gr.Markdown("# CAT-Translate-7b")
    direction = gr.Radio(
        label="Translation Direction",
        choices=list(DIRECTION_LANGS.keys()),
        value=DEFAULT_DIRECTION,
    )
    max_new_tokens = gr.Slider(
        label="Max New Tokens",
        info="Higher values allow longer translations but take more time",
        minimum=50,
        maximum=MAX_NEW_TOKENS_LIMIT,
        step=10,
        value=MAX_NEW_TOKENS_DEFAULT,
    )
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(label="Input", lines=10, placeholder="Enter text to translate")
            token_info = gr.Textbox(label="Token Count", lines=1)
            translate_button = gr.Button("Translate", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Translation", lines=10, placeholder="Translation will appear here")

    # The token count depends on both the text and the direction (the prompt
    # names the languages), so refresh it whenever either one changes.
    token_count_inputs = [text, direction]
    for component in token_count_inputs:
        component.change(fn=count_tokens, inputs=token_count_inputs, outputs=token_info)

    translate_button.click(
        fn=translate,
        inputs=[text, direction, max_new_tokens],
        outputs=output,
    )

    def translate_example(text: str, direction: str) -> str:
        """Translate an example using the default token budget.

        Examples only supply text and direction, so ``MAX_NEW_TOKENS_DEFAULT``
        stands in for the slider value.
        """
        return translate(text, direction, MAX_NEW_TOKENS_DEFAULT)

    gr.Examples(
        label="Short examples",
        examples=[
            ["今日はいい天気ですね。", "Japanese → English"],
            ["東京は世界で最も人口の多い都市の一つです。", "Japanese → English"],
            ["The cherry blossoms are beautiful this year.", "English → Japanese"],
            ["Technology is changing how we communicate with each other.", "English → Japanese"],
        ],
        inputs=[text, direction],
        outputs=output,
        fn=translate_example,
    )
    gr.Examples(
        label="Long examples",
        examples=[
            [
                "近年、大規模言語モデルの発展により、機械翻訳の品質は飛躍的に向上した。"
                "従来の統計ベースの手法では、文脈を十分に考慮することが難しく、"
                "長文になるほど翻訳精度が低下する傾向があった。"
                "しかし、Transformerアーキテクチャの登場以降、"
                "文全体の意味を捉えた上で自然な訳文を生成することが可能になりつつある。"
                # Restored from mis-encoded characters in this line ("異なる").
                "特に、日本語と英語のように語順や文法構造が大きく異なる言語対においては、"
                "この進歩の恩恵は顕著である。"
                "一方で、専門用語や文化的なニュアンスの翻訳には依然として課題が残されており、"
                "人間の翻訳者との協働が重要視されている。",
                "Japanese → English",
            ],
            [
                "The rapid advancement of artificial intelligence has fundamentally transformed "
                "how software is developed, tested, and deployed. Modern development teams "
                "increasingly rely on AI-powered tools for code generation, automated testing, "
                "and even architectural design decisions. While these tools have dramatically "
                "improved productivity, they also introduce new challenges around code quality, "
                "security vulnerabilities, and the need for human oversight. The most effective "
                "approach appears to be a collaborative one, where AI handles repetitive and "
                "boilerplate tasks while human developers focus on creative problem-solving, "
                "system design, and ensuring that the generated code aligns with business "
                "requirements and ethical standards.",
                "English → Japanese",
            ],
        ],
        inputs=[text, direction],
        outputs=output,
        fn=translate_example,
    )

if __name__ == "__main__":
    # Launch the app with the stylesheet from style.css applied.
    demo.launch(css_paths="style.css")
pyproject.toml ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "cat-translate-7b"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "accelerate>=1.12.0",
9
+ "gradio>=6.8.0",
10
+ "spaces>=0.47.0",
11
+ "torch==2.9.1",
12
+ "transformers>=5.2.0",
13
+ ]
14
+
15
+ [tool.ruff]
16
+ line-length = 119
17
+
18
+ [tool.ruff.lint]
19
+ select = ["ALL"]
20
+ ignore = [
21
+ "COM812", # missing-trailing-comma
22
+ "D203", # one-blank-line-before-class
23
+ "D213", # multi-line-summary-second-line
24
+ "E501", # line-too-long
25
+ "SIM117", # multiple-with-statements
26
+ #
27
+ "D100", # undocumented-public-module
28
+ "D101", # undocumented-public-class
29
+ "D102", # undocumented-public-method
30
+ "D103", # undocumented-public-function
31
+ "D104", # undocumented-public-package
32
+ "D105", # undocumented-magic-method
33
+ "D107", # undocumented-public-init
34
+ "EM101", # raw-string-in-exception
35
+ "FBT001", # boolean-type-hint-positional-argument
36
+ "FBT002", # boolean-default-value-positional-argument
37
+ "ISC001", # single-line-implicit-string-concatenation
38
+ "PGH003", # blanket-type-ignore
39
+ "PLR0913", # too-many-arguments
40
+ "PLR0915", # too-many-statements
41
+ "TRY003", # raise-vanilla-args
42
+ ]
43
+ unfixable = [
44
+ "F401", # unused-import
45
+ ]
46
+
47
+ [tool.ruff.lint.pydocstyle]
48
+ convention = "google"
49
+
50
+ [tool.ruff.format]
51
+ docstring-code-format = true
52
+
53
+ [dependency-groups]
54
+ dev = [
55
+ "ruff>=0.15.4",
56
+ ]
57
+ hf-spaces = [
58
+ "datasets",
59
+ ]
requirements.txt ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv export --no-hashes --no-dev --group hf-spaces --no-emit-package typer-slim -o requirements.txt
3
+ accelerate==1.12.0
4
+ # via cat-translate-7b
5
+ aiofiles==24.1.0
6
+ # via gradio
7
+ aiohappyeyeballs==2.6.1
8
+ # via aiohttp
9
+ aiohttp==3.13.3
10
+ # via fsspec
11
+ aiosignal==1.4.0
12
+ # via aiohttp
13
+ annotated-doc==0.0.4
14
+ # via
15
+ # fastapi
16
+ # typer
17
+ annotated-types==0.7.0
18
+ # via pydantic
19
+ anyio==4.12.1
20
+ # via
21
+ # gradio
22
+ # httpx
23
+ # starlette
24
+ attrs==25.4.0
25
+ # via aiohttp
26
+ audioop-lts==0.2.2 ; python_full_version >= '3.13'
27
+ # via gradio
28
+ brotli==1.2.0
29
+ # via gradio
30
+ certifi==2026.2.25
31
+ # via
32
+ # httpcore
33
+ # httpx
34
+ # requests
35
+ charset-normalizer==3.4.4
36
+ # via requests
37
+ click==8.3.1
38
+ # via
39
+ # typer
40
+ # uvicorn
41
+ colorama==0.4.6 ; sys_platform == 'win32'
42
+ # via
43
+ # click
44
+ # tqdm
45
+ datasets==4.6.1
46
+ dill==0.4.0
47
+ # via
48
+ # datasets
49
+ # multiprocess
50
+ fastapi==0.135.1
51
+ # via gradio
52
+ ffmpy==1.0.0
53
+ # via gradio
54
+ filelock==3.25.0
55
+ # via
56
+ # datasets
57
+ # huggingface-hub
58
+ # torch
59
+ frozenlist==1.8.0
60
+ # via
61
+ # aiohttp
62
+ # aiosignal
63
+ fsspec==2026.2.0
64
+ # via
65
+ # datasets
66
+ # gradio-client
67
+ # huggingface-hub
68
+ # torch
69
+ gradio==6.8.0
70
+ # via
71
+ # cat-translate-7b
72
+ # spaces
73
+ gradio-client==2.2.0
74
+ # via gradio
75
+ groovy==0.1.2
76
+ # via gradio
77
+ h11==0.16.0
78
+ # via
79
+ # httpcore
80
+ # uvicorn
81
+ hf-xet==1.3.2 ; platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
82
+ # via huggingface-hub
83
+ httpcore==1.0.9
84
+ # via httpx
85
+ httpx==0.28.1
86
+ # via
87
+ # datasets
88
+ # gradio
89
+ # gradio-client
90
+ # huggingface-hub
91
+ # safehttpx
92
+ # spaces
93
+ huggingface-hub==1.5.0
94
+ # via
95
+ # accelerate
96
+ # datasets
97
+ # gradio
98
+ # gradio-client
99
+ # tokenizers
100
+ # transformers
101
+ idna==3.11
102
+ # via
103
+ # anyio
104
+ # httpx
105
+ # requests
106
+ # yarl
107
+ jinja2==3.1.6
108
+ # via
109
+ # gradio
110
+ # torch
111
+ markdown-it-py==4.0.0
112
+ # via rich
113
+ markupsafe==3.0.3
114
+ # via
115
+ # gradio
116
+ # jinja2
117
+ mdurl==0.1.2
118
+ # via markdown-it-py
119
+ mpmath==1.3.0
120
+ # via sympy
121
+ multidict==6.7.1
122
+ # via
123
+ # aiohttp
124
+ # yarl
125
+ multiprocess==0.70.18
126
+ # via datasets
127
+ networkx==3.6.1
128
+ # via torch
129
+ numpy==2.4.2
130
+ # via
131
+ # accelerate
132
+ # datasets
133
+ # gradio
134
+ # pandas
135
+ # transformers
136
+ nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
137
+ # via
138
+ # nvidia-cudnn-cu12
139
+ # nvidia-cusolver-cu12
140
+ # torch
141
+ nvidia-cuda-cupti-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
142
+ # via torch
143
+ nvidia-cuda-nvrtc-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
144
+ # via torch
145
+ nvidia-cuda-runtime-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
146
+ # via torch
147
+ nvidia-cudnn-cu12==9.10.2.21 ; platform_machine == 'x86_64' and sys_platform == 'linux'
148
+ # via torch
149
+ nvidia-cufft-cu12==11.3.3.83 ; platform_machine == 'x86_64' and sys_platform == 'linux'
150
+ # via torch
151
+ nvidia-cufile-cu12==1.13.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
152
+ # via torch
153
+ nvidia-curand-cu12==10.3.9.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
154
+ # via torch
155
+ nvidia-cusolver-cu12==11.7.3.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
156
+ # via torch
157
+ nvidia-cusparse-cu12==12.5.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
158
+ # via
159
+ # nvidia-cusolver-cu12
160
+ # torch
161
+ nvidia-cusparselt-cu12==0.7.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
162
+ # via torch
163
+ nvidia-nccl-cu12==2.27.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
164
+ # via torch
165
+ nvidia-nvjitlink-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
166
+ # via
167
+ # nvidia-cufft-cu12
168
+ # nvidia-cusolver-cu12
169
+ # nvidia-cusparse-cu12
170
+ # torch
171
+ nvidia-nvshmem-cu12==3.3.20 ; platform_machine == 'x86_64' and sys_platform == 'linux'
172
+ # via torch
173
+ nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
174
+ # via torch
175
+ orjson==3.11.7
176
+ # via gradio
177
+ packaging==26.0
178
+ # via
179
+ # accelerate
180
+ # datasets
181
+ # gradio
182
+ # gradio-client
183
+ # huggingface-hub
184
+ # spaces
185
+ # transformers
186
+ pandas==3.0.1
187
+ # via
188
+ # datasets
189
+ # gradio
190
+ pillow==12.1.1
191
+ # via gradio
192
+ propcache==0.4.1
193
+ # via
194
+ # aiohttp
195
+ # yarl
196
+ psutil==5.9.8
197
+ # via
198
+ # accelerate
199
+ # spaces
200
+ pyarrow==23.0.1
201
+ # via datasets
202
+ pydantic==2.12.5
203
+ # via
204
+ # fastapi
205
+ # gradio
206
+ # spaces
207
+ pydantic-core==2.41.5
208
+ # via pydantic
209
+ pydub==0.25.1
210
+ # via gradio
211
+ pygments==2.19.2
212
+ # via rich
213
+ python-dateutil==2.9.0.post0
214
+ # via pandas
215
+ python-multipart==0.0.22
216
+ # via gradio
217
+ pytz==2025.2
218
+ # via gradio
219
+ pyyaml==6.0.3
220
+ # via
221
+ # accelerate
222
+ # datasets
223
+ # gradio
224
+ # huggingface-hub
225
+ # transformers
226
+ regex==2026.2.28
227
+ # via transformers
228
+ requests==2.32.5
229
+ # via
230
+ # datasets
231
+ # spaces
232
+ rich==14.3.3
233
+ # via typer
234
+ safehttpx==0.1.7
235
+ # via gradio
236
+ safetensors==0.7.0
237
+ # via
238
+ # accelerate
239
+ # transformers
240
+ semantic-version==2.10.0
241
+ # via gradio
242
+ setuptools==82.0.0
243
+ # via torch
244
+ shellingham==1.5.4
245
+ # via typer
246
+ six==1.17.0
247
+ # via python-dateutil
248
+ spaces==0.47.0
249
+ # via cat-translate-7b
250
+ starlette==0.52.1
251
+ # via
252
+ # fastapi
253
+ # gradio
254
+ sympy==1.14.0
255
+ # via torch
256
+ tokenizers==0.22.2
257
+ # via transformers
258
+ tomlkit==0.13.3
259
+ # via gradio
260
+ torch==2.9.1
261
+ # via
262
+ # accelerate
263
+ # cat-translate-7b
264
+ tqdm==4.67.3
265
+ # via
266
+ # datasets
267
+ # huggingface-hub
268
+ # transformers
269
+ transformers==5.2.0
270
+ # via cat-translate-7b
271
+ triton==3.5.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
272
+ # via torch
273
+ typer==0.24.1
274
+ # via
275
+ # gradio
276
+ # huggingface-hub
277
+ # typer-slim
278
+ typing-extensions==4.15.0
279
+ # via
280
+ # aiosignal
281
+ # anyio
282
+ # fastapi
283
+ # gradio
284
+ # gradio-client
285
+ # huggingface-hub
286
+ # pydantic
287
+ # pydantic-core
288
+ # spaces
289
+ # starlette
290
+ # torch
291
+ # typing-inspection
292
+ typing-inspection==0.4.2
293
+ # via
294
+ # fastapi
295
+ # pydantic
296
+ tzdata==2025.3 ; sys_platform == 'emscripten' or sys_platform == 'win32'
297
+ # via pandas
298
+ urllib3==2.6.3
299
+ # via requests
300
+ uvicorn==0.41.0
301
+ # via gradio
302
+ xxhash==3.6.0
303
+ # via datasets
304
+ yarl==1.23.0
305
+ # via aiohttp
style.css ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ display: block;
4
+ }
uv.lock ADDED
The diff for this file is too large to render. See raw diff