Li Wei Chen commited on
Commit ·
80cdf12
1
Parent(s): 14e913b
fix: correct G2P duration estimation and remove punctuation spaces
Browse filesWhen G2P is enabled, estimate audio duration from the original Chinese
text instead of the pinyin output to avoid weight inflation from tone
number digits (weight 3.5 vs CJK 3.0). Also strip spaces around
punctuation in the G2P output.
app.py
CHANGED
|
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
| 2 |
|
| 3 |
import logging
|
| 4 |
import os
|
|
|
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from typing import Any
|
| 7 |
|
|
@@ -158,7 +159,10 @@ def apply_g2p(text: str, dialect: str) -> str:
|
|
| 158 |
|
| 159 |
lang_group = DIALECT_TO_LANG_GROUP.get(dialect, "hak_sx")
|
| 160 |
result = g2p(text, lang_group=lang_group, pronunciation_type="pinyin")
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
|
| 164 |
def validate_inputs(
|
|
@@ -208,11 +212,9 @@ def synthesize(
|
|
| 208 |
return None, startup_status()
|
| 209 |
|
| 210 |
try:
|
| 211 |
-
|
| 212 |
g2p_note = ""
|
| 213 |
-
|
| 214 |
-
input_text = apply_g2p(input_text, dialect)
|
| 215 |
-
g2p_note = f";G2P 轉換:{input_text}"
|
| 216 |
|
| 217 |
generation_config = RUNTIME.generation_config_cls(
|
| 218 |
num_step=int(num_step),
|
|
@@ -226,6 +228,21 @@ def synthesize(
|
|
| 226 |
ref_text=ref_text.strip(),
|
| 227 |
preprocess_prompt=True,
|
| 228 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
generate_kwargs: dict[str, Any] = {
|
| 230 |
"text": input_text,
|
| 231 |
"voice_clone_prompt": voice_clone_prompt,
|
|
@@ -233,7 +250,9 @@ def synthesize(
|
|
| 233 |
"generation_config": generation_config,
|
| 234 |
"language": "zh",
|
| 235 |
}
|
| 236 |
-
if
|
|
|
|
|
|
|
| 237 |
generate_kwargs["speed"] = float(speed)
|
| 238 |
|
| 239 |
audio = RUNTIME.model.generate(**generate_kwargs)
|
|
|
|
| 2 |
|
| 3 |
import logging
|
| 4 |
import os
|
| 5 |
+
import re
|
| 6 |
from dataclasses import dataclass
|
| 7 |
from typing import Any
|
| 8 |
|
|
|
|
| 159 |
|
| 160 |
lang_group = DIALECT_TO_LANG_GROUP.get(dialect, "hak_sx")
|
| 161 |
result = g2p(text, lang_group=lang_group, pronunciation_type="pinyin")
|
| 162 |
+
joined = " ".join(result.pronunciations).upper()
|
| 163 |
+
joined = re.sub(r"\s+([,。!?;:、…「」『』【】〔〕()])", r"\1", joined)
|
| 164 |
+
joined = re.sub(r"([,。!?;:、…「」『』【】〔〕()])\s+", r"\1", joined)
|
| 165 |
+
return joined
|
| 166 |
|
| 167 |
|
| 168 |
def validate_inputs(
|
|
|
|
| 212 |
return None, startup_status()
|
| 213 |
|
| 214 |
try:
|
| 215 |
+
original_text = text.strip()
|
| 216 |
g2p_note = ""
|
| 217 |
+
duration_override = None
|
|
|
|
|
|
|
| 218 |
|
| 219 |
generation_config = RUNTIME.generation_config_cls(
|
| 220 |
num_step=int(num_step),
|
|
|
|
| 228 |
ref_text=ref_text.strip(),
|
| 229 |
preprocess_prompt=True,
|
| 230 |
)
|
| 231 |
+
|
| 232 |
+
if use_g2p:
|
| 233 |
+
input_text = apply_g2p(original_text, dialect)
|
| 234 |
+
g2p_note = f";G2P 轉換:{input_text}"
|
| 235 |
+
# Estimate duration from original Chinese text to avoid weight inflation
|
| 236 |
+
# caused by tone number digits (weight 3.5) in the G2P output.
|
| 237 |
+
num_ref_tokens = voice_clone_prompt.ref_audio_tokens.size(-1)
|
| 238 |
+
frame_rate = RUNTIME.model.audio_tokenizer.config.frame_rate
|
| 239 |
+
est_frames = RUNTIME.model.duration_estimator.estimate_duration(
|
| 240 |
+
original_text, voice_clone_prompt.ref_text, num_ref_tokens
|
| 241 |
+
)
|
| 242 |
+
duration_override = est_frames / float(speed) / frame_rate
|
| 243 |
+
else:
|
| 244 |
+
input_text = original_text
|
| 245 |
+
|
| 246 |
generate_kwargs: dict[str, Any] = {
|
| 247 |
"text": input_text,
|
| 248 |
"voice_clone_prompt": voice_clone_prompt,
|
|
|
|
| 250 |
"generation_config": generation_config,
|
| 251 |
"language": "zh",
|
| 252 |
}
|
| 253 |
+
if duration_override is not None:
|
| 254 |
+
generate_kwargs["duration"] = duration_override
|
| 255 |
+
elif speed != DEFAULT_SPEED:
|
| 256 |
generate_kwargs["speed"] = float(speed)
|
| 257 |
|
| 258 |
audio = RUNTIME.model.generate(**generate_kwargs)
|