mamungtai-sat pormungtai commited on
Commit
140a849
·
1 Parent(s): 5800fd3

Add Thai->English auto-translate (NLLB-200 + Typhoon 2 selectable) (#11)

Browse files

- Add Thai->English auto-translate (NLLB-200 + Typhoon 2 selectable) (a19330a931065d6a485e95a03f39e6282f6c6dad)


Co-authored-by: pormungtailaw <pormungtai@users.noreply.huggingface.co>

Files changed (2) hide show
  1. app.py +21 -3
  2. pipeline_manager.py +66 -0
app.py CHANGED
@@ -79,7 +79,8 @@ def modes_for(models, model_id):
79
  # ---------------------------------------------------------------------------
80
  @spaces.GPU(duration=120)
81
  def generate(model_id, mode, prompt, negative_prompt, ref_image,
82
- steps, guidance, denoise, ip_scale, width, height, seed, randomize):
 
83
  models = load_models()
84
  cfg = pm.get_model(models, model_id)
85
  if cfg is None:
@@ -88,6 +89,14 @@ def generate(model_id, mode, prompt, negative_prompt, ref_image,
88
  if randomize or seed is None or int(seed) < 0:
89
  seed = random.randint(0, MAX_SEED)
90
 
 
 
 
 
 
 
 
 
91
  try:
92
  img = pm.run_generation(
93
  cfg=cfg, mode=mode, prompt=prompt, negative_prompt=negative_prompt,
@@ -98,7 +107,7 @@ def generate(model_id, mode, prompt, negative_prompt, ref_image,
98
  traceback.print_exc()
99
  raise gr.Error(str(e))
100
 
101
- status = f"✅ {cfg['label']} · {pm.MODE_LABELS.get(mode, mode)} · seed {seed}"
102
  return img, seed, status
103
 
104
 
@@ -189,6 +198,14 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft(primary_hue="blue"),
189
  label="โหมดรูปต้นแบบ / Input mode",
190
  )
191
 
 
 
 
 
 
 
 
 
192
  # ---- right: output ----
193
  with gr.Column(scale=1):
194
  output = gr.Image(label="Generated Image", height=560, elem_classes="card")
@@ -227,7 +244,8 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft(primary_hue="blue"),
227
  )
228
 
229
  gen_inputs = [selected_id, mode_radio, prompt, negative_prompt, ref_image,
230
- steps, guidance, denoise, ip_scale, width, height, seed, randomize]
 
231
  gen_btn.click(generate, inputs=gen_inputs, outputs=[output, seed, status])
232
  prompt.submit(generate, inputs=gen_inputs, outputs=[output, seed, status])
233
 
 
79
  # ---------------------------------------------------------------------------
80
  @spaces.GPU(duration=120)
81
  def generate(model_id, mode, prompt, negative_prompt, ref_image,
82
+ steps, guidance, denoise, ip_scale, width, height, seed, randomize,
83
+ translator):
84
  models = load_models()
85
  cfg = pm.get_model(models, model_id)
86
  if cfg is None:
 
89
  if randomize or seed is None or int(seed) < 0:
90
  seed = random.randint(0, MAX_SEED)
91
 
92
+ # Thai → English so the (English) text encoders understand the prompt.
93
+ note = ""
94
+ orig_prompt = prompt
95
+ prompt = pm.translate_prompt(prompt, translator)
96
+ negative_prompt = pm.translate_prompt(negative_prompt, translator)
97
+ if prompt != orig_prompt:
98
+ note = f" · 🌐 {translator}: _{prompt[:120]}_"
99
+
100
  try:
101
  img = pm.run_generation(
102
  cfg=cfg, mode=mode, prompt=prompt, negative_prompt=negative_prompt,
 
107
  traceback.print_exc()
108
  raise gr.Error(str(e))
109
 
110
+ status = f"✅ {cfg['label']} · {pm.MODE_LABELS.get(mode, mode)} · seed {seed}{note}"
111
  return img, seed, status
112
 
113
 
 
198
  label="โหมดรูปต้นแบบ / Input mode",
199
  )
200
 
201
+ translator = gr.Radio(
202
+ choices=[("ปิด / Off", "off"),
203
+ ("NLLB-200 (เร็ว)", "nllb"),
204
+ ("Typhoon 2 (ไทยแน่น)", "typhoon")],
205
+ value="nllb",
206
+ label="แปลไทย→อังกฤษ / Auto-translate (พิมพ์ไทยได้เลย)",
207
+ )
208
+
209
  # ---- right: output ----
210
  with gr.Column(scale=1):
211
  output = gr.Image(label="Generated Image", height=560, elem_classes="card")
 
244
  )
245
 
246
  gen_inputs = [selected_id, mode_radio, prompt, negative_prompt, ref_image,
247
+ steps, guidance, denoise, ip_scale, width, height, seed, randomize,
248
+ translator]
249
  gen_btn.click(generate, inputs=gen_inputs, outputs=[output, seed, status])
250
  prompt.submit(generate, inputs=gen_inputs, outputs=[output, seed, status])
251
 
pipeline_manager.py CHANGED
@@ -68,6 +68,72 @@ def get_model(models, model_id):
68
  return None
69
 
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  # ---------------------------------------------------------------------------
72
  # Download helpers (Civitai / arbitrary URL → local cache)
73
  # ---------------------------------------------------------------------------
 
68
  return None
69
 
70
 
71
+ # ---------------------------------------------------------------------------
72
+ # Thai → English prompt translation (the SD/SDXL/FLUX text encoders are English;
73
+ # Thai prompts otherwise produce unrelated images). Runs on the Space, no API key.
74
+ # ---------------------------------------------------------------------------
75
+ TRANSLATORS = {
76
+ "nllb": "facebook/nllb-200-distilled-600M",
77
+ "typhoon": "scb10x/llama3.2-typhoon2-3b-instruct",
78
+ }
79
+ _TRANSLATOR_CACHE = {}
80
+
81
+
82
+ def has_thai(text):
83
+ return any("฀" <= ch <= "๿" for ch in (text or ""))
84
+
85
+
86
+ def _load_translator(engine):
87
+ if engine in _TRANSLATOR_CACHE:
88
+ return _TRANSLATOR_CACHE[engine]
89
+ name = TRANSLATORS[engine]
90
+ if engine == "nllb":
91
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
92
+ tok = AutoTokenizer.from_pretrained(name)
93
+ model = AutoModelForSeq2SeqLM.from_pretrained(name, torch_dtype=DTYPE_SD)
94
+ else: # typhoon (causal LM)
95
+ from transformers import AutoTokenizer, AutoModelForCausalLM
96
+ tok = AutoTokenizer.from_pretrained(name)
97
+ model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=DTYPE_SD)
98
+ model.eval()
99
+ _TRANSLATOR_CACHE[engine] = (tok, model)
100
+ return tok, model
101
+
102
+
103
+ def translate_prompt(text, engine):
104
+ """Translate a Thai prompt to English. Pass-through if empty/English/off.
105
+ MUST be called inside the @spaces.GPU context (uses CUDA when available)."""
106
+ if not text or engine in (None, "off") or not has_thai(text):
107
+ return text
108
+ try:
109
+ tok, model = _load_translator(engine)
110
+ model = model.to(DEVICE)
111
+ if engine == "nllb":
112
+ tok.src_lang = "tha_Thai"
113
+ inputs = tok(text, return_tensors="pt", truncation=True,
114
+ max_length=400).to(DEVICE)
115
+ bos = tok.convert_tokens_to_ids("eng_Latn")
116
+ out = model.generate(**inputs, forced_bos_token_id=bos,
117
+ max_new_tokens=256, num_beams=4)
118
+ return tok.batch_decode(out, skip_special_tokens=True)[0].strip()
119
+ # typhoon: ask the LLM to rewrite as a clean English image prompt
120
+ msgs = [
121
+ {"role": "system", "content": "You convert Thai text-to-image prompts "
122
+ "into a single concise, vivid English prompt for Stable Diffusion. "
123
+ "Keep the described subject, clothing, pose, and scene. Output ONLY the "
124
+ "English prompt as a comma-separated phrase — no quotes, no explanation."},
125
+ {"role": "user", "content": text},
126
+ ]
127
+ ids = tok.apply_chat_template(msgs, add_generation_prompt=True,
128
+ return_tensors="pt").to(DEVICE)
129
+ out = model.generate(ids, max_new_tokens=256, do_sample=False,
130
+ pad_token_id=tok.eos_token_id)
131
+ return tok.decode(out[0][ids.shape[1]:], skip_special_tokens=True).strip()
132
+ except Exception as e: # noqa
133
+ print(f"[translate] {engine} failed, using original text: {e}")
134
+ return text
135
+
136
+
137
  # ---------------------------------------------------------------------------
138
  # Download helpers (Civitai / arbitrary URL → local cache)
139
  # ---------------------------------------------------------------------------