Li Wei Chen commited on
Commit
80cdf12
·
1 Parent(s): 14e913b

fix: correct G2P duration estimation and remove punctuation spaces

Browse files

When G2P is enabled, estimate audio duration from the original Chinese
text instead of the pinyin output to avoid weight inflation from tone
number digits (weight 3.5 vs CJK 3.0). Also strip spaces around
punctuation in the G2P output.

Files changed (1) hide show
  1. app.py +25 -6
app.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
 
3
  import logging
4
  import os
 
5
  from dataclasses import dataclass
6
  from typing import Any
7
 
@@ -158,7 +159,10 @@ def apply_g2p(text: str, dialect: str) -> str:
158
 
159
  lang_group = DIALECT_TO_LANG_GROUP.get(dialect, "hak_sx")
160
  result = g2p(text, lang_group=lang_group, pronunciation_type="pinyin")
161
- return " ".join(result.pronunciations).upper()
 
 
 
162
 
163
 
164
  def validate_inputs(
@@ -208,11 +212,9 @@ def synthesize(
208
  return None, startup_status()
209
 
210
  try:
211
- input_text = text.strip()
212
  g2p_note = ""
213
- if use_g2p:
214
- input_text = apply_g2p(input_text, dialect)
215
- g2p_note = f";G2P 轉換:{input_text}"
216
 
217
  generation_config = RUNTIME.generation_config_cls(
218
  num_step=int(num_step),
@@ -226,6 +228,21 @@ def synthesize(
226
  ref_text=ref_text.strip(),
227
  preprocess_prompt=True,
228
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  generate_kwargs: dict[str, Any] = {
230
  "text": input_text,
231
  "voice_clone_prompt": voice_clone_prompt,
@@ -233,7 +250,9 @@ def synthesize(
233
  "generation_config": generation_config,
234
  "language": "zh",
235
  }
236
- if speed != DEFAULT_SPEED:
 
 
237
  generate_kwargs["speed"] = float(speed)
238
 
239
  audio = RUNTIME.model.generate(**generate_kwargs)
 
2
 
3
  import logging
4
  import os
5
+ import re
6
  from dataclasses import dataclass
7
  from typing import Any
8
 
 
159
 
160
  lang_group = DIALECT_TO_LANG_GROUP.get(dialect, "hak_sx")
161
  result = g2p(text, lang_group=lang_group, pronunciation_type="pinyin")
162
+ joined = " ".join(result.pronunciations).upper()
163
+ joined = re.sub(r"\s+([,。!?;:、…「」『』【】〔〕()])", r"\1", joined)
164
+ joined = re.sub(r"([,。!?;:、…「」『』【】〔〕()])\s+", r"\1", joined)
165
+ return joined
166
 
167
 
168
  def validate_inputs(
 
212
  return None, startup_status()
213
 
214
  try:
215
+ original_text = text.strip()
216
  g2p_note = ""
217
+ duration_override = None
 
 
218
 
219
  generation_config = RUNTIME.generation_config_cls(
220
  num_step=int(num_step),
 
228
  ref_text=ref_text.strip(),
229
  preprocess_prompt=True,
230
  )
231
+
232
+ if use_g2p:
233
+ input_text = apply_g2p(original_text, dialect)
234
+ g2p_note = f";G2P 轉換:{input_text}"
235
+ # Estimate duration from original Chinese text to avoid weight inflation
236
+ # caused by tone number digits (weight 3.5) in the G2P output.
237
+ num_ref_tokens = voice_clone_prompt.ref_audio_tokens.size(-1)
238
+ frame_rate = RUNTIME.model.audio_tokenizer.config.frame_rate
239
+ est_frames = RUNTIME.model.duration_estimator.estimate_duration(
240
+ original_text, voice_clone_prompt.ref_text, num_ref_tokens
241
+ )
242
+ duration_override = est_frames / float(speed) / frame_rate
243
+ else:
244
+ input_text = original_text
245
+
246
  generate_kwargs: dict[str, Any] = {
247
  "text": input_text,
248
  "voice_clone_prompt": voice_clone_prompt,
 
250
  "generation_config": generation_config,
251
  "language": "zh",
252
  }
253
+ if duration_override is not None:
254
+ generate_kwargs["duration"] = duration_override
255
+ elif speed != DEFAULT_SPEED:
256
  generate_kwargs["speed"] = float(speed)
257
 
258
  audio = RUNTIME.model.generate(**generate_kwargs)