OmniVoice / omnivoice /utils /duration.py
zhu-han's picture
Upload 48 files
aa79b9c verified
#!/usr/bin/env python3
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text duration estimation for TTS generation.
Provides ``RuleDurationEstimator``, which estimates audio duration from text
using character phonetic weights across 600+ languages. Used by
``OmniVoice.generate()`` to determine output length when no duration is specified.
"""
import bisect
import unicodedata
from functools import lru_cache
from typing import Optional
class RuleDurationEstimator:
    """Rule-based estimator of speech duration from raw text.

    Each character is mapped to a relative "phonetic weight" (1.0 = one
    plain Latin letter) according to its Unicode general category and the
    script block its code point falls in. The weight of a text is the sum
    of its character weights. A reference (text, duration) pair calibrates
    the speaking rate, which is then applied to the target text.
    """

    def __init__(self):
        # ==========================================
        # 1. Phonetic Weights Table
        # ==========================================
        # The weight represents the relative speaking time compared to
        # a standard Latin letter.
        # Benchmark: 1.0 = One Latin Character (~40-50ms)
        self.weights = {
            # --- Logographic (1 char = full syllable/word) ---
            "cjk": 3.0,  # Chinese, Japanese Kanji, etc.
            # --- Syllabic / Blocks ---
            "hangul": 2.5,  # Korean Hangul
            "kana": 2.2,  # Japanese Hiragana/Katakana
            "ethiopic": 3.0,  # Amharic/Ge'ez
            "yi": 3.0,  # Yi script
            # --- Abugida (Consonant-Vowel complexes) ---
            "indic": 1.8,  # Hindi, Bengali, Tamil, etc.
            "thai_lao": 1.5,  # Thai, Lao
            "khmer_myanmar": 1.8,  # Khmer, Myanmar
            # --- Abjad (Consonant-heavy) ---
            "arabic": 1.5,  # Arabic, Persian, Urdu
            "hebrew": 1.5,  # Hebrew
            # --- Alphabet (Segmental) ---
            "latin": 1.0,  # English, Spanish, French, Vietnamese, etc. (Baseline)
            "cyrillic": 1.0,  # Russian, Ukrainian
            "greek": 1.0,  # Greek
            "armenian": 1.0,  # Armenian
            "georgian": 1.0,  # Georgian
            # --- Symbols & Misc ---
            "punctuation": 0.5,  # Pause capability
            "space": 0.2,  # Word boundary/Breath (0.05 / 0.22)
            "digit": 3.5,  # Numbers
            "mark": 0.0,  # Diacritics/Accents (Silent modifiers)
            "default": 1.0,  # Fallback for unknown scripts
        }
        # ==========================================
        # 2. Unicode Range Mapping
        # ==========================================
        # Format: (End_Codepoint, Type_Key)
        # Used for fast binary search (bisect). Entries MUST stay sorted by
        # end code point and MUST NOT leave gaps: a code point resolves to
        # the first entry whose end is >= the code point, so a missing block
        # silently inherits the type of the next listed block.
        self.ranges = [
            (0x02AF, "latin"),  # Latin (Basic, Supplement, Ext, IPA)
            (0x03FF, "greek"),  # Greek & Coptic
            (0x052F, "cyrillic"),  # Cyrillic
            (0x058F, "armenian"),  # Armenian
            (0x05FF, "hebrew"),  # Hebrew
            (0x077F, "arabic"),  # Arabic, Syriac, Arabic Supplement
            (0x089F, "arabic"),  # Arabic Extended-B (+ Syriac Supp)
            (0x08FF, "arabic"),  # Arabic Extended-A
            (0x097F, "indic"),  # Devanagari
            (0x09FF, "indic"),  # Bengali
            (0x0A7F, "indic"),  # Gurmukhi
            (0x0AFF, "indic"),  # Gujarati
            (0x0B7F, "indic"),  # Oriya
            (0x0BFF, "indic"),  # Tamil
            (0x0C7F, "indic"),  # Telugu
            (0x0CFF, "indic"),  # Kannada
            (0x0D7F, "indic"),  # Malayalam
            (0x0DFF, "indic"),  # Sinhala
            (0x0EFF, "thai_lao"),  # Thai & Lao
            (0x0FFF, "indic"),  # Tibetan (Abugida)
            (0x109F, "khmer_myanmar"),  # Myanmar
            (0x10FF, "georgian"),  # Georgian
            (0x11FF, "hangul"),  # Hangul Jamo
            (0x137F, "ethiopic"),  # Ethiopic
            (0x139F, "ethiopic"),  # Ethiopic Supplement
            (0x13FF, "default"),  # Cherokee
            (0x167F, "default"),  # Canadian Aboriginal Syllabics
            (0x169F, "default"),  # Ogham
            (0x16FF, "default"),  # Runic
            (0x171F, "default"),  # Tagalog (Baybayin)
            (0x173F, "default"),  # Hanunoo
            (0x175F, "default"),  # Buhid
            (0x177F, "default"),  # Tagbanwa
            (0x17FF, "khmer_myanmar"),  # Khmer
            (0x18AF, "default"),  # Mongolian
            (0x18FF, "default"),  # Canadian Aboriginal Syllabics Ext
            (0x194F, "indic"),  # Limbu
            (0x19DF, "indic"),  # Tai Le & New Tai Lue
            (0x19FF, "khmer_myanmar"),  # Khmer Symbols
            (0x1A1F, "indic"),  # Buginese
            (0x1AAF, "indic"),  # Tai Tham
            (0x1B7F, "indic"),  # Balinese
            (0x1BBF, "indic"),  # Sundanese
            (0x1BFF, "indic"),  # Batak
            (0x1C4F, "indic"),  # Lepcha
            (0x1C7F, "indic"),  # Ol Chiki (Santali)
            (0x1C8F, "cyrillic"),  # Cyrillic Extended-C
            (0x1CBF, "georgian"),  # Georgian Extended
            (0x1CCF, "indic"),  # Sundanese Supplement
            (0x1CFF, "indic"),  # Vedic Extensions
            (0x1D7F, "latin"),  # Phonetic Extensions
            (0x1DBF, "latin"),  # Phonetic Extensions Supplement
            (0x1DFF, "default"),  # Combining Diacritical Marks Supplement
            (0x1EFF, "latin"),  # Latin Extended Additional (Vietnamese)
            # U+1F00..U+2FFF used to be missing, so its letters fell through
            # to the Hiragana entry below and were weighted as kana.
            (0x1FFF, "greek"),  # Greek Extended (polytonic)
            (0x2BFF, "default"),  # General Punctuation .. Misc Symbols and Arrows
            (0x2C5F, "default"),  # Glagolitic
            (0x2C7F, "latin"),  # Latin Extended-C
            (0x2CFF, "greek"),  # Coptic
            (0x2D2F, "georgian"),  # Georgian Supplement
            (0x2D7F, "default"),  # Tifinagh
            (0x2DDF, "ethiopic"),  # Ethiopic Extended
            (0x2DFF, "cyrillic"),  # Cyrillic Extended-A
            (0x2FFF, "default"),  # Supplemental Punctuation, CJK/Kangxi Radicals
            (0x309F, "kana"),  # CJK Symbols & Punctuation letters, Hiragana
            (0x30FF, "kana"),  # Katakana
            (0x312F, "cjk"),  # Bopomofo (Pinyin)
            (0x318F, "hangul"),  # Hangul Compatibility Jamo
            (0x9FFF, "cjk"),  # CJK Unified Ideographs (Main)
            (0xA4CF, "yi"),  # Yi Syllables
            (0xA4FF, "default"),  # Lisu
            (0xA63F, "default"),  # Vai
            (0xA69F, "cyrillic"),  # Cyrillic Extended-B
            (0xA6FF, "default"),  # Bamum
            (0xA7FF, "latin"),  # Latin Extended-D
            (0xA82F, "indic"),  # Syloti Nagri
            (0xA87F, "default"),  # Phags-pa
            (0xA8DF, "indic"),  # Saurashtra
            (0xA8FF, "indic"),  # Devanagari Extended
            (0xA92F, "indic"),  # Kayah Li
            (0xA95F, "indic"),  # Rejang
            (0xA97F, "hangul"),  # Hangul Jamo Extended-A
            (0xA9DF, "indic"),  # Javanese
            (0xA9FF, "khmer_myanmar"),  # Myanmar Extended-B
            (0xAA5F, "indic"),  # Cham
            (0xAA7F, "khmer_myanmar"),  # Myanmar Extended-A
            (0xAADF, "indic"),  # Tai Viet
            (0xAAFF, "indic"),  # Meetei Mayek Extensions
            (0xAB2F, "ethiopic"),  # Ethiopic Extended-A
            (0xAB6F, "latin"),  # Latin Extended-E
            (0xABBF, "default"),  # Cherokee Supplement
            (0xABFF, "indic"),  # Meetei Mayek
            (0xD7AF, "hangul"),  # Hangul Syllables
            (0xFAFF, "cjk"),  # CJK Compatibility
            (0xFDFF, "arabic"),  # Arabic Presentation Forms-A
            (0xFE6F, "default"),  # Variation Selectors
            (0xFEFF, "arabic"),  # Arabic Presentation Forms-B
            (0xFFEF, "latin"),  # Fullwidth Latin
        ]
        self.breakpoints = [r[0] for r in self.ranges]
        # Per-instance memo of char -> weight. A class-level ``lru_cache`` on
        # the method would key on ``self`` and keep every estimator alive for
        # the lifetime of the cache. The key space is bounded by the number
        # of Unicode code points, so a plain dict is safe here.
        self._char_weight_cache = {}

    def _get_char_weight(self, char: str) -> float:
        """Return the weight of a single character, memoized per instance.

        Args:
            char: A one-character string (``ord`` raises ``TypeError``
                otherwise).
        """
        cache = self._char_weight_cache
        if char not in cache:
            cache[char] = self._classify_char(char)
        return cache[char]

    def _classify_char(self, char: str) -> float:
        """Compute the (uncached) weight of a single character."""
        code = ord(char)
        # 1. Fast paths for the most common input: ASCII letters and space.
        if (65 <= code <= 90) or (97 <= code <= 122):
            return self.weights["latin"]
        if code == 32:
            return self.weights["space"]
        # Ignore Arabic Tatweel (pure typographic elongation).
        if code == 0x0640:
            return self.weights["mark"]
        # 2. Classify by Unicode general category before looking at scripts.
        category = unicodedata.category(char)
        # Combining marks and invisible format controls (ZWJ/ZWNJ, soft
        # hyphen, emoji tag characters, ...) are never spoken on their own.
        if category.startswith("M") or category == "Cf":
            return self.weights["mark"]
        if category.startswith(("P", "S")):
            return self.weights["punctuation"]
        if category.startswith("Z"):
            return self.weights["space"]
        if category.startswith("N"):
            return self.weights["digit"]
        # 3. Binary search for the Unicode block. Punctuation, symbols,
        # digits and marks were already filtered out above, so only letters
        # (and other leftover categories) reach the script table.
        idx = bisect.bisect_left(self.breakpoints, code)
        if idx < len(self.ranges):
            script_type = self.ranges[idx][1]
            return self.weights.get(script_type, self.weights["default"])
        # 4. Handle upper planes (CJK Ext B/C/D, Historic scripts).
        # U+20000 is itself the first CJK Extension B ideograph, hence ">=".
        if code >= 0x20000:
            return self.weights["cjk"]
        return self.weights["default"]

    def calculate_total_weight(self, text):
        """Return the sum of the per-character weights of ``text``."""
        return sum(self._get_char_weight(c) for c in text)

    def estimate_duration(
        self,
        target_text: str,
        ref_text: str,
        ref_duration: float,
        low_threshold: Optional[float] = 50,
        boost_strength: float = 3,
    ) -> float:
        """Estimate how long ``target_text`` takes to speak.

        The speaking rate is calibrated as ``weight(ref_text) / ref_duration``
        and applied to ``weight(target_text)``. Estimates that fall below
        ``low_threshold`` are raised along a power curve that meets the
        linear estimate exactly at the threshold.

        Args:
            target_text (str): The text for which we want to estimate the duration.
            ref_text (str): The reference text that was used to measure
                the ref_duration.
            ref_duration (float): The actual duration it took
                to speak the ref_text.
            low_threshold (float): The minimum duration threshold below which the
                estimation will be considered unreliable (same unit as
                ``ref_duration``). ``None`` disables the boost.
            boost_strength (float): Controls the power-curve boost for short durations.
                Higher values boost small durations more aggressively.
                1 = no boost (linear), 2 = sqrt-like. Must be non-zero
                when the boost is applied.
        Returns:
            float: The estimated duration for the target_text based
                on the ref_text and ref_duration, in the unit of
                ``ref_duration``. ``0.0`` when the reference is unusable
                (empty text, non-positive duration, or zero total weight).
        """
        if ref_duration <= 0 or not ref_text:
            return 0.0
        ref_weight = self.calculate_total_weight(ref_text)
        if ref_weight == 0:
            # Reference made only of silent characters: no rate to calibrate.
            return 0.0
        speed_factor = ref_weight / ref_duration
        target_weight = self.calculate_total_weight(target_text)
        estimated_duration = target_weight / speed_factor
        if low_threshold is not None and estimated_duration < low_threshold:
            # threshold * (estimate / threshold) ** (1 / strength)
            alpha = 1.0 / boost_strength
            return low_threshold * (estimated_duration / low_threshold) ** alpha
        return estimated_duration
# ==========================================
# Example Usage
# ==========================================
if __name__ == "__main__":
    # Demo: calibrate the speaking rate on one English sentence, then print
    # the total weight and estimated duration for texts in several scripts.
    demo_estimator = RuleDurationEstimator()
    reference_text, reference_seconds = "Hello, world.", 1.5
    divider = "-" * 30
    # Label -> sample text; dict insertion order fixes the print order.
    samples = {
        "Hindi (With complex marks)": "नमस्ते दुनिया",
        "Arabic (With vowels)": "مَرْحَبًا بِالْعَالَم",
        "Vietnamese (Lots of diacritics)": "Chào thế giới",
        "Chinese": "你好,世界!",
        "Mixed Emoji": "Hello 🌍! This is fun 🎉",
    }

    print("--- Reference ---")
    print(f"Reference Text: '{reference_text}'")
    print(f"Reference Duration: {reference_seconds}s")
    print(divider)

    for label, sample in samples.items():
        seconds = demo_estimator.estimate_duration(
            sample, reference_text, reference_seconds
        )
        total = demo_estimator.calculate_total_weight(sample)
        print(f"[{label}]")
        print(f"Text: {sample}")
        print(f"Total Weight: {total:.2f}")
        print(f"Estimated Duration: {seconds:.2f} s")
        print(divider)