| from __future__ import annotations |
|
|
| import logging |
| import os |
| import re |
| from dataclasses import dataclass |
| from typing import Any |
|
|
| import gradio as gr |
| import numpy as np |
|
|
| try: |
| import spaces |
|
|
| USING_SPACES = True |
| except ImportError: |
| USING_SPACES = False |
|
|
# Hugging Face model repository loaded by load_runtime().
MODEL_ID = "formospeech/omnivoice-taiwanese-hakka"

# Dialect label (shown in the UI and passed to the model as `instruct`) ->
# formog2p language-group code used by apply_g2p().
DIALECT_TO_LANG_GROUP = {
    "客語四縣腔": "hak_sx",
    "客語海陸腔": "hak_hl",
    "客語大埔腔": "hak_dp",
    "客語饒平腔": "hak_rp",
    "客語詔安腔": "hak_za",
    "客語南四縣腔": "hak_nsx",
}
# Dropdown choices, in the mapping's insertion order; deriving them keeps the
# two collections from drifting apart.
DIALECT_LABELS = list(DIALECT_TO_LANG_GROUP)

DEFAULT_SPEED = 1.0
DEFAULT_STEPS = 32

# Each row follows the order of the Examples/submit inputs:
# dialect, text, reference audio path, reference text, speed, steps, use_g2p.
EXAMPLES = [
    [dialect, text, ref_audio, ref_text, DEFAULT_SPEED, DEFAULT_STEPS, False]
    for dialect, text, ref_audio, ref_text in (
        (
            "客語四縣腔",
            "食飯愛正經食,正毋會食到半出半入。",
            "refs/0000001_0.15-0.93.wav",
            "恁早。",
        ),
        (
            "客語四縣腔",
            "食飯愛正經食,正毋會食到半出半入。",
            "refs/0000002_0.15-2.73.wav",
            "你今晡日著到恁派頭。",
        ),
        (
            "客語四縣腔",
            "歸條路吊等長長个花燈,祈求風調雨順,歸屋下人个心願,親像花燈下燒暖个光華。",
            "refs/0000002_0.15-2.73.wav",
            "你今晡日著到恁派頭。",
        ),
    )
]
|
|
|
|
@dataclass
class RuntimeState:
    """Result of ``load_runtime``: the loaded model, or why it is missing.

    ``load_error`` stays None on success. When it is set, ``model`` is None
    (and ``generation_config_cls`` may be None too), and the message is what
    the status box displays.
    """

    # OmniVoice model instance; None when dependencies or weights failed to load.
    model: Any | None
    # OmniVoiceGenerationConfig class used by synthesize() to build configs.
    generation_config_cls: Any | None
    # Sample rate reported by the loaded model; None when no model is loaded.
    sampling_rate: int | None
    # Device name chosen by get_best_device(): "cuda", "mps" or "cpu".
    device: str
    # Precision label: "float16" on CUDA, otherwise "float32".
    dtype_name: str
    # User-facing failure message; None when loading succeeded.
    load_error: str | None = None
|
|
|
|
def gpu_decorator(func):
    """Return ``spaces.GPU(func)`` when the optional ``spaces`` package imported.

    When ``spaces`` is unavailable (``USING_SPACES`` is False), *func* is
    returned unchanged so the app also runs outside that environment.
    """
    return spaces.GPU(func) if USING_SPACES else func
|
|
|
|
def get_best_device() -> str:
    """Pick a torch device name, preferring "cuda", then "mps", then "cpu".

    Returns "cpu" when torch itself cannot be imported.
    """
    try:
        import torch
    except Exception:
        return "cpu"

    if torch.cuda.is_available():
        return "cuda"

    # Older torch builds have no `torch.backends.mps` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend and mps_backend.is_available():
        return "mps"
    return "cpu"
|
|
|
|
def load_runtime() -> RuntimeState:
    """Choose device/precision and load the OmniVoice model from ``MODEL_ID``.

    Never raises for dependency-import or model-loading failures. Each failure
    is logged with its traceback and recorded in ``RuntimeState.load_error``
    (a user-facing message) so the UI can still start and report it.

    Returns:
        A ``RuntimeState``. On success, ``model`` and ``sampling_rate`` are set.
    """
    device = get_best_device()
    # Half precision only on CUDA; MPS and CPU run in full precision.
    dtype_name = "float16" if device == "cuda" else "float32"

    try:
        import torch
        from omnivoice import OmniVoice, OmniVoiceGenerationConfig
    except Exception as exc:
        # Log the traceback; previously only the one-line message reached the UI.
        logging.exception("Failed to import inference dependencies")
        return RuntimeState(
            model=None,
            generation_config_cls=None,
            sampling_rate=None,
            device=device,
            dtype_name=dtype_name,
            load_error=f"依賴載入失敗:{type(exc).__name__}: {exc}",
        )

    # Derive the torch dtype from its name so the label and dtype cannot disagree.
    dtype = getattr(torch, dtype_name)

    try:
        logging.info("Loading model %s on %s with %s", MODEL_ID, device, dtype_name)
        model = OmniVoice.from_pretrained(
            MODEL_ID,
            device_map=device,
            dtype=dtype,
            load_asr=False,
        )
    except Exception as exc:
        logging.exception("Failed to load model %s", MODEL_ID)
        return RuntimeState(
            model=None,
            generation_config_cls=OmniVoiceGenerationConfig,
            sampling_rate=None,
            device=device,
            dtype_name=dtype_name,
            load_error=f"模型載入失敗:{type(exc).__name__}: {exc}",
        )

    return RuntimeState(
        model=model,
        generation_config_cls=OmniVoiceGenerationConfig,
        sampling_rate=model.sampling_rate,
        device=device,
        dtype_name=dtype_name,
    )
|
|
|
|
# Configure root logging before the model is loaded at import time. main()
# also calls logging.basicConfig, but only after this line has already run.
# Without this, load_runtime()'s INFO messages are dropped, because the root
# logger defaults to WARNING. A later basicConfig call is a no-op once the
# root logger has handlers.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
# Loaded once per process; synthesize() and startup_status() read this state.
RUNTIME = load_runtime()
|
|
|
|
def startup_status() -> str:
    """Return the text for the status box describing the model-loading result.

    Returns the stored load error verbatim if loading failed. Otherwise returns
    a three-line summary of the model id, device and precision.
    """
    if not RUNTIME.load_error:
        summary_lines = (
            f"模型已載入:{MODEL_ID}",
            f"裝置:{RUNTIME.device}",
            f"推論精度:{RUNTIME.dtype_name}",
        )
        return "\n".join(summary_lines)
    return RUNTIME.load_error
|
|
|
|
def apply_g2p(text: str, dialect: str) -> str:
    """Convert Hakka text to upper-case pinyin with formog2p.

    The dialect label selects the formog2p language group (unknown labels fall
    back to "hak_sx"). Pronunciations are joined with single spaces, and any
    whitespace adjacent to the listed punctuation marks is then removed.
    """
    from formog2p.hakka.g2p import g2p

    punctuation_class = r"[,。!?;:、…「」『』【】〔〕()]"

    group_code = DIALECT_TO_LANG_GROUP.get(dialect, "hak_sx")
    converted = g2p(text, lang_group=group_code, pronunciation_type="pinyin")
    spaced = " ".join(converted.pronunciations).upper()

    # join() puts a space on both sides of every token, including punctuation;
    # strip the whitespace before, then after, each punctuation mark.
    tight_before = re.sub(rf"\s+({punctuation_class})", r"\1", spaced)
    return re.sub(rf"({punctuation_class})\s+", r"\1", tight_before)
|
|
|
|
def validate_inputs(
    dialect: str | None,
    text: str,
    ref_audio: str | None,
    ref_text: str,
) -> str | None:
    """Check the form fields before synthesis.

    Checks dialect, text, reference audio and reference text, in that order.

    Returns:
        The user-facing message for the first failing check, or None when
        every field is usable.
    """

    def is_blank(value: str | None) -> bool:
        # Treat None, "" and whitespace-only strings as missing.
        return not value or not value.strip()

    rules = (
        (lambda: dialect not in DIALECT_LABELS, "請先選擇客語腔調。"),
        (lambda: is_blank(text), "請輸入要合成的文字。"),
        (lambda: not ref_audio, "請上傳參考音檔。"),
        (lambda: is_blank(ref_text), "請輸入參考文本。"),
    )
    # Predicates are evaluated lazily, so later checks never run after a failure.
    return next((message for failed, message in rules if failed()), None)
|
|
|
|
def to_audio_output(audio: np.ndarray, sampling_rate: int) -> tuple[int, np.ndarray]:
    """Convert a float waveform to ``(sampling_rate, int16 samples)``.

    This is the tuple form used by the numpy-typed output Audio component.
    For multi-dimensional input, all length-1 axes are squeezed away. NaN
    samples become silence, and samples are clipped to [-1, 1] before
    scaling to int16.
    """
    waveform = np.asarray(audio)
    # Compute in at least float32. In float16, 32767 rounds to 32768.0, so a
    # full-scale sample would overflow int16 after scaling. float32/float64
    # inputs are left unchanged.
    waveform = waveform.astype(np.promote_types(waveform.dtype, np.float32), copy=False)
    if waveform.ndim > 1:
        waveform = np.squeeze(waveform)
    # NaN survives np.clip, and casting NaN to int16 yields an arbitrary value.
    waveform = np.clip(np.nan_to_num(waveform, nan=0.0), -1.0, 1.0)
    return sampling_rate, (waveform * 32767).astype(np.int16)
|
|
|
|
def _estimate_duration_seconds(voice_clone_prompt: Any, text: str, speed: float) -> float:
    """Estimate the ``duration`` passed to ``generate`` for *text*, scaled by *speed*.

    Runs the model's duration estimator on *text* together with the prompt's
    reference text and reference audio-token count. The estimate is divided
    by *speed* and the audio tokenizer's frame rate (frames -> presumably
    seconds; confirm against the OmniVoice API).
    """
    num_ref_tokens = voice_clone_prompt.ref_audio_tokens.size(-1)
    frame_rate = RUNTIME.model.audio_tokenizer.config.frame_rate
    est_frames = RUNTIME.model.duration_estimator.estimate_duration(
        text, voice_clone_prompt.ref_text, num_ref_tokens
    )
    return est_frames / float(speed) / frame_rate


@gpu_decorator
def synthesize(
    dialect: str | None,
    text: str,
    ref_audio: str | None,
    ref_text: str,
    speed: float,
    num_step: int,
    use_g2p: bool,
) -> tuple[tuple[int, np.ndarray] | None, str]:
    """Handle one synthesis request from the Gradio form.

    Args:
        dialect: One of ``DIALECT_LABELS``; passed to the model as ``instruct``.
        text: Text to synthesize; whitespace at both ends is stripped.
        ref_audio: File path of the reference recording used for voice cloning.
        ref_text: Transcript of ``ref_audio``.
        speed: Speaking-rate multiplier; 1.0 is the default rate.
        num_step: Number of inference steps for the generation config.
        use_g2p: If True, feed the model upper-cased pinyin from ``apply_g2p``
            instead of the original text.

    Returns:
        ``(audio, status)``. ``audio`` is ``(sampling_rate, int16 samples)`` on
        success and None otherwise. ``status`` is a user-facing message.
        Exceptions raised during synthesis are logged and reported via
        ``status`` instead of propagating.
    """
    error = validate_inputs(dialect, text, ref_audio, ref_text)
    if error:
        return None, error

    if (
        RUNTIME.load_error
        or RUNTIME.model is None
        or RUNTIME.generation_config_cls is None
    ):
        return None, startup_status()

    try:
        original_text = text.strip()
        g2p_note = ""
        duration_override = None

        generation_config = RUNTIME.generation_config_cls(
            num_step=int(num_step),
            guidance_scale=2.0,
            denoise=True,
            preprocess_prompt=True,
            postprocess_output=True,
        )
        voice_clone_prompt = RUNTIME.model.create_voice_clone_prompt(
            ref_audio=ref_audio,
            ref_text=ref_text.strip(),
            preprocess_prompt=True,
        )

        if use_g2p:
            input_text = apply_g2p(original_text, dialect)
            g2p_note = f";G2P 轉換:{input_text}"
            # In G2P mode, speed is applied through an explicit duration estimated
            # from the original (pre-G2P) text, rather than through the `speed`
            # kwarg.
            duration_override = _estimate_duration_seconds(
                voice_clone_prompt, original_text, speed
            )
        else:
            input_text = original_text

        generate_kwargs: dict[str, Any] = {
            "text": input_text,
            "voice_clone_prompt": voice_clone_prompt,
            "instruct": dialect,
            "generation_config": generation_config,
            "language": "zh",
        }
        if duration_override is not None:
            generate_kwargs["duration"] = duration_override
        elif speed != DEFAULT_SPEED:
            generate_kwargs["speed"] = float(speed)

        audio = RUNTIME.model.generate(**generate_kwargs)
        # `not audio` raises "truth value is ambiguous" on a multi-element
        # ndarray; an explicit None/length check works for lists and arrays.
        if audio is None or len(audio) == 0:
            return None, "模型沒有回傳音訊。"

        return (
            to_audio_output(audio[0], int(RUNTIME.sampling_rate or 24000)),
            f"合成完成。腔調:{dialect};speed={speed:.2f};steps={int(num_step)}{g2p_note}",
        )
    except Exception as exc:
        # Keep the UI responsive, but record the full traceback in the server log.
        logging.exception("Synthesis failed")
        return None, f"合成失敗:{type(exc).__name__}: {exc}"
|
|
|
|
def build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks UI for Hakka voice-cloning TTS.

    Layout:
        - A Markdown header with the title, team members and partner lab.
        - A row holding an input panel and an output panel.
          - Input panel: dialect, text, reference audio and text, G2P toggle,
            speed/steps sliders under "進階設定", and the submit button.
          - Output panel: synthesized audio and a status box.

    The submit button calls ``synthesize``. ``gr.Examples`` pre-fills the same
    inputs.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` instance.
    """
    with gr.Blocks(title="臺灣客語語音生成系統") as demo:
        # Header: project title, developers and partner organisation.
        with gr.Column():
            gr.Markdown(
                """
# 臺灣客語語音合成系統
### Taiwanese Hakka Text-to-Speech System
### 研發團隊
- **[李鴻欣 Hung-Shin Lee](mailto:hungshinlee@gmail.com)**
- **[陳力瑋 Li-Wei Chen](mailto:wayne900619@gmail.com)**
### 合作單位
- **[國立聯合大學智慧客家實驗室](https://www.gohakka.org)**
"""
            )

        with gr.Row(equal_height=False):
            # Left panel: all synthesis inputs.
            with gr.Column(scale=11, elem_classes="panel"):
                # value=None means no dialect is preselected; validate_inputs
                # rejects anything not in DIALECT_LABELS.
                dialect = gr.Dropdown(
                    choices=DIALECT_LABELS,
                    value=None,
                    allow_custom_value=False,
                    label="客語腔調",
                    info="此模型用 instruct 控制腔調,推論前必選。",
                )
                text = gr.Textbox(
                    label="要合成的文字",
                    lines=4,
                    placeholder="例如:這下來試看啊,客語語音合成聽起來仰般。",
                )
                # type="filepath": synthesize() receives a path string.
                ref_audio = gr.Audio(
                    label="參考音檔",
                    type="filepath",
                )
                ref_text = gr.Textbox(
                    label="參考文本",
                    lines=2,
                    placeholder="請填寫參考音檔對應的逐字文本。",
                )
                use_g2p = gr.Checkbox(
                    value=False,
                    label="使用 G2P 轉換",
                    info="勾選後會先用 formog2p 將漢字轉成拼音(大寫)再輸入模型;不勾選則直接輸入原文。",
                )
                # Collapsed by default; defaults come from DEFAULT_SPEED/DEFAULT_STEPS.
                with gr.Accordion("進階設定", open=False):
                    speed = gr.Slider(
                        minimum=0.5,
                        maximum=1.5,
                        value=DEFAULT_SPEED,
                        step=0.05,
                        label="Speed",
                        info="1.0 為預設語速;越大越快。",
                    )
                    num_step = gr.Slider(
                        minimum=4,
                        maximum=32,
                        value=DEFAULT_STEPS,
                        step=1,
                        label="Inference Steps",
                        info="步數越高通常品質越穩,但速度較慢。",
                    )
                submit = gr.Button("開始合成", variant="primary")

            # Right panel: synthesized audio plus status/error messages.
            with gr.Column(scale=9):
                # type="numpy" matches the (sampling_rate, int16 array) tuple
                # produced by to_audio_output().
                output_audio = gr.Audio(
                    label="合成結果",
                    type="numpy",
                )
                # Initialised with the model-loading result computed at build time.
                status = gr.Textbox(
                    label="狀態",
                    value=startup_status(),
                    lines=6,
                    interactive=False,
                )

        # The inputs list must match synthesize()'s positional parameter order.
        submit.click(
            fn=synthesize,
            inputs=[dialect, text, ref_audio, ref_text, speed, num_step, use_g2p],
            outputs=[output_audio, status],
        )

        # No fn/outputs are given, so examples are not run through synthesize;
        # each EXAMPLES row must follow this same inputs order.
        gr.Examples(
            examples=EXAMPLES,
            inputs=[dialect, text, ref_audio, ref_text, speed, num_step, use_g2p],
            label="範例",
        )

    return demo
|
|
|
|
# Built at import time so the Blocks object exists as the module-level `demo`
# attribute; main() queues and launches this same object.
demo = build_demo()
|
|
|
|
def main() -> None:
    """Configure INFO-level logging and start the queued Gradio server.

    The server binds to 0.0.0.0 on the port from the ``PORT`` environment
    variable (7860 when unset).
    """
    log_format = "%(asctime)s %(levelname)s %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)

    app = demo.queue()
    # "tauhu-oo" is presumably provided by the stylesheet @import-ed in `css`;
    # the remaining entries are fallbacks.
    font_stack = (
        "tauhu-oo",
        gr.themes.GoogleFont("Source Sans Pro"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    )
    theme = gr.themes.Default(font=font_stack)
    port = int(os.getenv("PORT", "7860"))

    app.launch(
        css="@import url(https://tauhu.tw/tauhu-oo.css);",
        theme=theme,
        server_name="0.0.0.0",
        server_port=port,
    )
|
|
|
|
# Start the server only when run as a script; importing the module just builds
# RUNTIME and `demo`.
if __name__ == "__main__":
    main()
|
|