txya900619 commited on
Commit
0cfd22f
·
unverified ·
1 Parent(s): c0e4a49

feat: add initial implementation of Taiwanese Hakka TTS system with example audio files

Browse files
app.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ from dataclasses import dataclass
6
+ from typing import Any
7
+
8
+ import gradio as gr
9
+ import numpy as np
10
+
11
+ MODEL_ID = "formospeech/omnivoice-taiwanese-hakka"
12
+ DIALECT_LABELS = [
13
+ "客語四縣腔",
14
+ "客語海陸腔",
15
+ "客語大埔腔",
16
+ "客語饒平腔",
17
+ "客語詔安腔",
18
+ "客語南四縣腔",
19
+ ]
20
+ DEFAULT_SPEED = 1.0
21
+ DEFAULT_STEPS = 32
22
+ EXAMPLES = [
23
+ [
24
+ "客語四縣腔",
25
+ "食飯愛正經食,正毋會食到半出半入。",
26
+ "refs/0000001_0.15-0.93.wav",
27
+ "恁早。",
28
+ DEFAULT_SPEED,
29
+ DEFAULT_STEPS,
30
+ ],
31
+ [
32
+ "客語四縣腔",
33
+ "食飯愛正經食,正毋會食到半出半入。",
34
+ "refs/0000002_0.15-2.73.wav",
35
+ "你今晡日著到恁派頭。",
36
+ DEFAULT_SPEED,
37
+ DEFAULT_STEPS,
38
+ ],
39
+ [
40
+ "客語四縣腔",
41
+ "歸條路吊等長長个花燈,祈求風調雨順,歸屋下人个心願,親像花燈下燒暖个光華。",
42
+ "refs/0000002_0.15-2.73.wav",
43
+ "你今晡日著到恁派頭。",
44
+ DEFAULT_SPEED,
45
+ DEFAULT_STEPS,
46
+ ],
47
+ ]
48
+
49
+
50
+ @dataclass
51
+ class RuntimeState:
52
+ model: Any | None
53
+ generation_config_cls: Any | None
54
+ sampling_rate: int | None
55
+ device: str
56
+ dtype_name: str
57
+ load_error: str | None = None
58
+
59
+
60
+ def get_best_device() -> str:
61
+ try:
62
+ import torch
63
+ except Exception:
64
+ return "cpu"
65
+
66
+ if torch.cuda.is_available():
67
+ return "cuda"
68
+ if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
69
+ return "mps"
70
+ return "cpu"
71
+
72
+
73
+ def load_runtime() -> RuntimeState:
74
+ device = get_best_device()
75
+ dtype_name = "float16" if device == "cuda" else "float32"
76
+
77
+ try:
78
+ import torch
79
+ from omnivoice import OmniVoice, OmniVoiceGenerationConfig
80
+ except Exception as exc:
81
+ return RuntimeState(
82
+ model=None,
83
+ generation_config_cls=None,
84
+ sampling_rate=None,
85
+ device=device,
86
+ dtype_name=dtype_name,
87
+ load_error=f"依賴載入失敗:{type(exc).__name__}: {exc}",
88
+ )
89
+
90
+ dtype = torch.float16 if device == "cuda" else torch.float32
91
+
92
+ try:
93
+ logging.info("Loading model %s on %s with %s", MODEL_ID, device, dtype_name)
94
+ model = OmniVoice.from_pretrained(
95
+ MODEL_ID,
96
+ device_map=device,
97
+ dtype=dtype,
98
+ load_asr=False,
99
+ )
100
+ except Exception as exc:
101
+ return RuntimeState(
102
+ model=None,
103
+ generation_config_cls=OmniVoiceGenerationConfig,
104
+ sampling_rate=None,
105
+ device=device,
106
+ dtype_name=dtype_name,
107
+ load_error=f"模型載入失敗:{type(exc).__name__}: {exc}",
108
+ )
109
+
110
+ return RuntimeState(
111
+ model=model,
112
+ generation_config_cls=OmniVoiceGenerationConfig,
113
+ sampling_rate=model.sampling_rate,
114
+ device=device,
115
+ dtype_name=dtype_name,
116
+ )
117
+
118
+
119
+ RUNTIME = load_runtime()
120
+
121
+
122
+ def startup_status() -> str:
123
+ if RUNTIME.load_error:
124
+ return RUNTIME.load_error
125
+ return (
126
+ f"模型已載入:{MODEL_ID}\n"
127
+ f"裝置:{RUNTIME.device}\n"
128
+ f"推論精度:{RUNTIME.dtype_name}"
129
+ )
130
+
131
+
132
+ def validate_inputs(
133
+ dialect: str | None,
134
+ text: str,
135
+ ref_audio: str | None,
136
+ ref_text: str,
137
+ ) -> str | None:
138
+ if dialect not in DIALECT_LABELS:
139
+ return "請先選擇客語腔調。"
140
+ if not text or not text.strip():
141
+ return "請輸入要合成的文字。"
142
+ if not ref_audio:
143
+ return "請上傳參考音檔。"
144
+ if not ref_text or not ref_text.strip():
145
+ return "請輸入參考文本。"
146
+ return None
147
+
148
+
149
+ def to_audio_output(audio: np.ndarray, sampling_rate: int) -> tuple[int, np.ndarray]:
150
+ waveform = np.asarray(audio)
151
+ if waveform.ndim > 1:
152
+ waveform = np.squeeze(waveform)
153
+ waveform = np.clip(waveform, -1.0, 1.0)
154
+ return sampling_rate, (waveform * 32767).astype(np.int16)
155
+
156
+
157
+ def synthesize(
158
+ dialect: str | None,
159
+ text: str,
160
+ ref_audio: str | None,
161
+ ref_text: str,
162
+ speed: float,
163
+ num_step: int,
164
+ ) -> tuple[tuple[int, np.ndarray] | None, str]:
165
+ error = validate_inputs(dialect, text, ref_audio, ref_text)
166
+ if error:
167
+ return None, error
168
+
169
+ if (
170
+ RUNTIME.load_error
171
+ or RUNTIME.model is None
172
+ or RUNTIME.generation_config_cls is None
173
+ ):
174
+ return None, startup_status()
175
+
176
+ try:
177
+ generation_config = RUNTIME.generation_config_cls(
178
+ num_step=int(num_step),
179
+ guidance_scale=2.0,
180
+ denoise=True,
181
+ preprocess_prompt=True,
182
+ postprocess_output=True,
183
+ )
184
+ voice_clone_prompt = RUNTIME.model.create_voice_clone_prompt(
185
+ ref_audio=ref_audio,
186
+ ref_text=ref_text.strip(),
187
+ preprocess_prompt=True,
188
+ )
189
+ generate_kwargs: dict[str, Any] = {
190
+ "text": text.strip(),
191
+ "voice_clone_prompt": voice_clone_prompt,
192
+ "instruct": dialect,
193
+ "generation_config": generation_config,
194
+ "language": "zh",
195
+ }
196
+ if speed != DEFAULT_SPEED:
197
+ generate_kwargs["speed"] = float(speed)
198
+
199
+ audio = RUNTIME.model.generate(**generate_kwargs)
200
+ if not audio:
201
+ return None, "模型沒有回傳音訊。"
202
+
203
+ return (
204
+ to_audio_output(audio[0], int(RUNTIME.sampling_rate or 24000)),
205
+ f"合成完成。腔調:{dialect};speed={speed:.2f};steps={int(num_step)}",
206
+ )
207
+ except Exception as exc:
208
+ return None, f"合成失敗:{type(exc).__name__}: {exc}"
209
+
210
+
211
+ def build_demo() -> gr.Blocks:
212
+ with gr.Blocks(title="臺灣客語語音生成系統") as demo:
213
+ with gr.Column():
214
+ gr.Markdown(
215
+ """
216
+ # 臺灣客語語音合成系統
217
+ ### Taiwanese Hakka Text-to-Speech System
218
+ ### 研發團隊
219
+ - **[李鴻欣 Hung-Shin Lee](mailto:hungshinlee@gmail.com)**
220
+ - **[陳力瑋 Li-Wei Chen](mailto:wayne900619@gmail.com)**
221
+ ### 合作單位
222
+ - **[國立聯合大學智慧客家實驗室](https://www.gohakka.org)**
223
+ """
224
+ )
225
+
226
+ with gr.Row(equal_height=False):
227
+ with gr.Column(scale=11, elem_classes="panel"):
228
+ dialect = gr.Dropdown(
229
+ choices=DIALECT_LABELS,
230
+ value=None,
231
+ allow_custom_value=False,
232
+ label="客語腔調",
233
+ info="此模型用 instruct 控制腔調,推論前必選。",
234
+ )
235
+ text = gr.Textbox(
236
+ label="要合成的文字",
237
+ lines=4,
238
+ placeholder="例如:這下來試看啊,客語語音合成聽起來仰般。",
239
+ )
240
+ ref_audio = gr.Audio(
241
+ label="參考音檔",
242
+ type="filepath",
243
+ )
244
+ ref_text = gr.Textbox(
245
+ label="參考文本",
246
+ lines=2,
247
+ placeholder="請填寫參考音檔對應的逐字文本。",
248
+ )
249
+ with gr.Accordion("進階設定", open=False):
250
+ speed = gr.Slider(
251
+ minimum=0.5,
252
+ maximum=1.5,
253
+ value=DEFAULT_SPEED,
254
+ step=0.05,
255
+ label="Speed",
256
+ info="1.0 為預設語速;越大越快。",
257
+ )
258
+ num_step = gr.Slider(
259
+ minimum=4,
260
+ maximum=32,
261
+ value=DEFAULT_STEPS,
262
+ step=1,
263
+ label="Inference Steps",
264
+ info="步數越高通常品質越穩,但速度較慢。",
265
+ )
266
+ submit = gr.Button("開始合成", variant="primary")
267
+
268
+ with gr.Column(scale=9):
269
+ output_audio = gr.Audio(
270
+ label="合成結果",
271
+ type="numpy",
272
+ )
273
+ status = gr.Textbox(
274
+ label="狀態",
275
+ value=startup_status(),
276
+ lines=6,
277
+ interactive=False,
278
+ )
279
+
280
+ submit.click(
281
+ fn=synthesize,
282
+ inputs=[dialect, text, ref_audio, ref_text, speed, num_step],
283
+ outputs=[output_audio, status],
284
+ )
285
+
286
+ gr.Examples(
287
+ examples=EXAMPLES,
288
+ inputs=[dialect, text, ref_audio, ref_text, speed, num_step],
289
+ label="範例",
290
+ )
291
+
292
+ return demo
293
+
294
+
295
+ demo = build_demo()
296
+
297
+
298
+ def main() -> None:
299
+ logging.basicConfig(
300
+ level=logging.INFO,
301
+ format="%(asctime)s %(levelname)s %(message)s",
302
+ )
303
+ demo.queue().launch(
304
+ css="@import url(https://tauhu.tw/tauhu-oo.css);",
305
+ theme=gr.themes.Default(
306
+ font=(
307
+ "tauhu-oo",
308
+ gr.themes.GoogleFont("Source Sans Pro"),
309
+ "ui-sans-serif",
310
+ "system-ui",
311
+ "sans-serif",
312
+ )
313
+ ),
314
+ server_name="0.0.0.0",
315
+ server_port=int(os.getenv("PORT", "7860")),
316
+ )
317
+
318
+
319
+ if __name__ == "__main__":
320
+ main()
refs/0000001_0.15-0.93.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78ef5c480c70cb01e6065a59bf896b4598a638a075d5a1d287342c018b7a7129
3
+ size 37484
refs/0000002_0.15-2.73.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:560d43f7e4dc1cc7a994f98ca9870a8c5dad7be1a89fe2d26b534519370a5bd2
3
+ size 123884
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ git+https://github.com/txya900619/OmniVoice-hakka.git