| """ |
| Hub utilities for downloading and managing Chiluka TTS models. |
| |
| Supports: |
| - HuggingFace Hub integration |
| - Automatic model downloading |
| - Local caching |
| - Multiple model variants |
| """ |
|
|
| import os |
| import shutil |
| from pathlib import Path |
| from typing import Optional, Union |
|
|
| |
# Default HuggingFace Hub repository that hosts the Chiluka model files.
DEFAULT_HF_REPO = "Seemanth/chiluka-tts"

# Default local cache root (~/.cache/chiluka). get_cache_dir() lets the
# CHILUKA_CACHE environment variable override it.
CACHE_DIR = Path.home() / ".cache" / "chiluka"

# Model variants available in the Hub repository, keyed by variant name.
# "config" and "checkpoint" are paths relative to the downloaded snapshot
# directory; "languages" lists the supported language codes and
# "description" is the human-readable summary returned by list_models().
MODEL_REGISTRY = {
    "telugu": {
        "config": "configs/config_ft.yml",
        "checkpoint": "checkpoints/epoch_2nd_00017.pth",
        "languages": ["te", "en"],
        "description": "Telugu + English single-speaker TTS",
    },
    "hindi_english": {
        "config": "configs/config_hindi_english.yml",
        "checkpoint": "checkpoints/epoch_2nd_00029.pth",
        "languages": ["hi", "en"],
        "description": "Hindi + English multi-speaker TTS (5 speakers)",
    },
}

# Variant used when no model name is given (default of get_model_paths()).
DEFAULT_MODEL = "hindi_english"

# Pretrained component files (ASR, F0 and PL-BERT models plus configs)
# shared by all variants, relative to the snapshot directory.
# is_model_cached() requires every one of these to exist.
PRETRAINED_FILES = {
    "asr_config": "pretrained/ASR/config.yml",
    "asr_model": "pretrained/ASR/epoch_00080.pth",
    "f0_model": "pretrained/JDC/bst.t7",
    "plbert_config": "pretrained/PLBERT/config.yml",
    "plbert_model": "pretrained/PLBERT/step_1000000.t7",
}
|
|
|
|
def list_models() -> dict:
    """
    List all available model variants.

    Returns:
        A new dictionary mapping each variant name to a dict with its
        ``"languages"`` (list of language codes) and ``"description"``.
        The language lists are copies, so mutating the result never
        alters ``MODEL_REGISTRY``.

    Example:
        >>> from chiluka import hub
        >>> hub.list_models()
        {'telugu': {...}, 'hindi_english': {...}}
    """
    return {
        name: {
            # Copy the list: handing out the registry's own list object
            # would let callers corrupt the registry by mutating the result.
            "languages": list(info["languages"]),
            "description": info["description"],
        }
        for name, info in MODEL_REGISTRY.items()
    }
|
|
|
|
def get_cache_dir() -> Path:
    """
    Get the cache directory for Chiluka models, creating it if needed.

    The location comes from the ``CHILUKA_CACHE`` environment variable when
    it is set to a non-empty value (a leading ``~`` is expanded), otherwise
    from ``CACHE_DIR``.

    Returns:
        Path to an existing cache directory.
    """
    override = os.environ.get("CHILUKA_CACHE")
    # An empty override would become Path("") == CWD, so treat it as unset.
    # expanduser() handles values like "~/models" that reach us without
    # shell expansion (.env files, container env settings).
    cache_dir = Path(override).expanduser() if override else CACHE_DIR
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
|
|
|
|
def is_model_cached(
    repo_id: str = DEFAULT_HF_REPO,
    model: Optional[str] = None,
) -> bool:
    """
    Check if a usable model snapshot is already cached locally.

    A snapshot is usable when every file in ``PRETRAINED_FILES`` exists and
    the config and checkpoint of a model variant exist. By default any
    registered variant counts; pass ``model`` to require a specific one.

    Args:
        repo_id: HuggingFace Hub repository ID whose local snapshot to check.
        model: Optional variant name from ``MODEL_REGISTRY``. When given,
            only that variant's config and checkpoint are checked.

    Returns:
        True if the snapshot is usable, False otherwise.

    Raises:
        ValueError: If ``model`` is given but is not a registered variant.
    """
    if model is not None and model not in MODEL_REGISTRY:
        available = ", ".join(MODEL_REGISTRY.keys())
        raise ValueError(
            f"Unknown model '{model}'. Available models: {available}"
        )

    cache_path = get_cache_dir() / repo_id.replace("/", "_")
    if not cache_path.exists():
        return False

    def exists(rel_path: str) -> bool:
        return (cache_path / rel_path).exists()

    # Shared pretrained components are required by every variant.
    if not all(exists(rel_path) for rel_path in PRETRAINED_FILES.values()):
        return False

    candidates = (
        MODEL_REGISTRY.values() if model is None else [MODEL_REGISTRY[model]]
    )
    return any(
        exists(info["config"]) and exists(info["checkpoint"])
        for info in candidates
    )
|
|
|
|
def download_from_hf(
    repo_id: str = DEFAULT_HF_REPO,
    revision: str = "main",
    force_download: bool = False,
    token: Optional[str] = None,
) -> Path:
    """
    Download model files from HuggingFace Hub into the Chiluka cache.

    Files are placed in ``get_cache_dir() / repo_id.replace("/", "_")``.
    When that directory already holds a usable snapshot (see
    ``is_model_cached``) and ``force_download`` is False, no download is
    attempted and the cached directory is returned as-is.

    Args:
        repo_id: HuggingFace Hub repository ID (e.g., 'Seemanth/chiluka-tts')
        revision: Git revision to download (branch, tag, or commit hash).
            Only used when a download happens; an existing cached snapshot
            is reused regardless of the revision it came from.
        force_download: If True, bypass the local cache check and ask the
            Hub client to re-download every file.
        token: HuggingFace API token for private repos

    Returns:
        Path to the downloaded model directory

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError as err:
        raise ImportError(
            "huggingface_hub is required for downloading models. "
            "Install with: pip install huggingface_hub"
        ) from err

    cache_dir = get_cache_dir()
    cache_path = cache_dir / repo_id.replace("/", "_")

    if not force_download and is_model_cached(repo_id):
        print(f"Using cached model from {cache_path}")
        return cache_path

    print(f"Downloading model from HuggingFace Hub: {repo_id}...")

    downloaded_path = snapshot_download(
        repo_id=repo_id,
        revision=revision,
        cache_dir=cache_dir / "hf_cache",
        token=token,
        local_dir=cache_path,
        # Deprecated in recent huggingface_hub releases; kept so older
        # versions copy real files into local_dir instead of symlinking.
        local_dir_use_symlinks=False,
        # Forward the flag so "re-download even if cached" is honoured by
        # the Hub client too, not just by our local shortcut above.
        force_download=force_download,
    )

    print(f"Model downloaded to {cache_path}")
    return Path(downloaded_path)
|
|
|
|
def get_model_paths(
    model: str = DEFAULT_MODEL,
    repo_id: str = DEFAULT_HF_REPO,
) -> dict:
    """
    Get paths to all model files after downloading.

    Args:
        model: Model variant name ('telugu', 'hindi_english')
        repo_id: HuggingFace Hub repository ID

    Returns:
        Dictionary with string paths under the keys ``"config_path"``,
        ``"checkpoint_path"`` and ``"pretrained_dir"``.

    Raises:
        ValueError: If ``model`` is not a registered variant.
        FileNotFoundError: If the requested variant's config or checkpoint
            is still missing after a forced re-download.
    """
    if model not in MODEL_REGISTRY:
        available = ", ".join(MODEL_REGISTRY.keys())
        raise ValueError(
            f"Unknown model '{model}'. Available models: {available}"
        )

    model_info = MODEL_REGISTRY[model]
    required = (model_info["config"], model_info["checkpoint"])

    model_dir = download_from_hf(repo_id)

    # download_from_hf() accepts the cache when *any* registered variant is
    # complete, so the requested variant may be absent (e.g. the snapshot
    # predates it or a download was interrupted). Refresh once in that case.
    if not all((model_dir / rel).exists() for rel in required):
        model_dir = download_from_hf(repo_id, force_download=True)
        missing = [rel for rel in required if not (model_dir / rel).exists()]
        if missing:
            raise FileNotFoundError(
                f"Model '{model}' files not found in {repo_id}: "
                + ", ".join(missing)
            )

    return {
        "config_path": str(model_dir / model_info["config"]),
        "checkpoint_path": str(model_dir / model_info["checkpoint"]),
        "pretrained_dir": str(model_dir / "pretrained"),
    }
|
|
|
|
def _repo_cache_path(cache_dir: Path, repo_id: str) -> Path:
    """
    Map ``repo_id`` to its snapshot directory directly under ``cache_dir``.

    Raises:
        ValueError: If the mapped path is not a direct child of ``cache_dir``
            (e.g. ``""``, ``"."`` or ``".."``), which would make a deletion
            hit the cache root or one of its parents.
    """
    candidate = cache_dir / repo_id.replace("/", "_")
    # Lexical normalization (no symlink resolution) so a deliberately
    # symlinked snapshot directory is still accepted.
    if Path(os.path.abspath(candidate)).parent != Path(os.path.abspath(cache_dir)):
        raise ValueError(f"Invalid repo_id for cache clearing: {repo_id!r}")
    return candidate


def clear_cache(repo_id: Optional[str] = None):
    """
    Clear cached models.

    Args:
        repo_id: If specified, only clear cache for this repo.
            If None, clear entire cache.

    Raises:
        ValueError: If ``repo_id`` does not map to a directory inside the
            cache (e.g. an empty string, ``"."`` or ``".."``). Previously an
            empty string silently wiped the entire cache.
    """
    cache_dir = get_cache_dir()

    # Compare against None explicitly: a falsy repo_id such as "" must not
    # fall through to the "delete everything" branch.
    if repo_id is None:
        if cache_dir.exists():
            shutil.rmtree(cache_dir)
            print("Cleared entire Chiluka cache")
        return

    cache_path = _repo_cache_path(cache_dir, repo_id)
    if cache_path.exists():
        shutil.rmtree(cache_path)
        print(f"Cleared cache for {repo_id}")
|
|
|
|
# Paths (relative to the uploaded folder) that push_to_hub must not upload.
# huggingface_hub matches these fnmatch-style, where "*" also matches "/".
# A bare directory name only matches the directory entry itself, so each
# excluded directory also gets a pattern covering the files inside it.
_UPLOAD_IGNORE_PATTERNS = (
    "*.pyc",
    "__pycache__",
    "*__pycache__/*",
    "*.egg-info",
    "*.egg-info/*",
    ".git",
    ".git/*",
    "*/.git/*",
)


def push_to_hub(
    local_dir: str,
    repo_id: str,
    token: Optional[str] = None,
    private: bool = False,
    commit_message: str = "Upload Chiluka TTS model",
):
    """
    Push a local model to HuggingFace Hub.

    Args:
        local_dir: Local directory containing model files
        repo_id: Target HuggingFace Hub repository ID
        token: HuggingFace API token (or set HF_TOKEN env var)
        private: Whether to create a private repository
        commit_message: Commit message for the upload

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.

    Example:
        >>> push_to_hub(
        ...     local_dir="./chiluka",
        ...     repo_id="Seemanth/chiluka-tts",
        ...     private=False
        ... )
    """
    try:
        from huggingface_hub import HfApi, create_repo
    except ImportError as err:
        raise ImportError(
            "huggingface_hub is required for pushing models. "
            "Install with: pip install huggingface_hub"
        ) from err

    api = HfApi(token=token)

    # Best effort: with exist_ok=True an existing repo is not an error, and
    # the upload below may still succeed even if creation failed here.
    try:
        create_repo(repo_id, private=private, token=token, exist_ok=True)
    except Exception as e:
        print(f"Note: {e}")

    print(f"Uploading to {repo_id}...")
    api.upload_folder(
        folder_path=local_dir,
        repo_id=repo_id,
        commit_message=commit_message,
        # Pass a fresh list so the module-level patterns can never be
        # modified, even if the library extends the list it receives.
        ignore_patterns=list(_UPLOAD_IGNORE_PATTERNS),
    )

    print(f"Model uploaded to: https://huggingface.co/{repo_id}")
|
|
|
|
def create_model_card(repo_id: str, save_path: Optional[str] = None) -> str:
    """
    Generate a model card (README.md) for HuggingFace Hub.

    The "Available Models" table is built from ``MODEL_REGISTRY``; the
    GitHub and torch.hub URLs assume the GitHub owner matches the Hub owner
    (the part of ``repo_id`` before the first "/").

    Args:
        repo_id: Repository ID for the model
        save_path: If provided, save the model card to this path (UTF-8)

    Returns:
        Model card content as string
    """
    owner = repo_id.split("/")[0]

    rows = []
    for name, info in MODEL_REGISTRY.items():
        langs = ", ".join(info["languages"])
        rows.append(f"| `{name}` | {info['description']} | {langs} |")
    model_rows = "\n".join(rows)

    model_card = f"""---
language:
- en
- te
- hi
license: mit
library_name: chiluka
tags:
- text-to-speech
- tts
- styletts2
- voice-cloning
- multi-language
---

# Chiluka TTS

Chiluka (చిలుక - Telugu for "parrot") is a lightweight Text-to-Speech model based on StyleTTS2.

## Available Models

| Model | Description | Languages |
|-------|-------------|-----------|
{model_rows}

## Installation

```bash
pip install chiluka
```

Or install from source:

```bash
pip install git+https://github.com/{owner}/chiluka.git
```

## Usage

### Hindi + English (default)

```python
from chiluka import Chiluka

tts = Chiluka.from_pretrained()

wav = tts.synthesize(
    text="Hello, world!",
    reference_audio="reference.wav",
    language="en"
)
tts.save_wav(wav, "output.wav")
```

### Telugu

```python
tts = Chiluka.from_pretrained(model="telugu")

wav = tts.synthesize(
    text="నమస్కారం",
    reference_audio="reference.wav",
    language="te"
)
```

### PyTorch Hub

```python
import torch

tts = torch.hub.load('{owner}/chiluka', 'chiluka')
tts = torch.hub.load('{owner}/chiluka', 'chiluka_telugu')
```

## License

MIT License

## Citation

Based on StyleTTS2 by Yinghao Aaron Li et al.
"""

    if save_path:
        # The card contains Telugu text; without an explicit encoding, open()
        # uses the locale default and fails on e.g. cp1252 systems.
        with open(save_path, "w", encoding="utf-8") as f:
            f.write(model_card)
        print(f"Model card saved to {save_path}")

    return model_card
|
|