# SPDX-License-Identifier: Apache-2.0 """ Helpers for loading Qwen3 TTS with recent ``transformers`` + ``qwen-tts``. Import this **before** ``from qwen_tts import Qwen3TTSModel`` (or import this module first, which applies the check_model_inputs shim on load). """ from __future__ import annotations import json from pathlib import Path import transformers.utils.generic as _g _apply_done = False def apply_check_model_inputs_shim() -> None: """ qwen-tts uses ``@check_model_inputs()`` (call form). In current ``transformers``, ``check_model_inputs`` is a factory: ``check_model_inputs()`` returns a decorator that wraps ``forward``; the previous shim used ``lambda f: _orig(f)``, which is equivalent to ``_orig(f)`` and passes ``f`` as the *first parameter* (``tie_last_hidden_states``), not through the returned decorator. That left the wrong callable as ``forward`` and led to ``wrapped_fn() got an unexpected keyword argument 'inputs_embeds'``. Must run before importing qwen_tts. """ global _apply_done if _apply_done or getattr(_g, "_qwen_tts_check_model_inputs_shim_applied", False): return _orig = _g.check_model_inputs def check_model_inputs(*args, **kwargs): # @check_model_inputs() -> factory must return the inner decorator, same as HF if not args and not kwargs: return _orig() # @check_model_inputs (bare; rare) -> _orig()(func) if len(args) == 1 and callable(args[0]) and not kwargs: return _orig()(args[0]) return _orig(*args, **kwargs) _g.check_model_inputs = check_model_inputs _g._qwen_tts_check_model_inputs_shim_applied = True _apply_done = True def register_qwen3_tts_config_if_needed() -> None: from transformers import AutoConfig from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig try: AutoConfig.register("qwen3_tts", Qwen3TTSConfig) except ValueError as e: if "already" not in str(e).lower(): raise def ensure_qwen3_tts_config_pad_token_ids(config) -> None: """ Some model cards omit ``talker_config.pad_token_id`` and ``code_predictor_config.pad_token_id``; talker and code-predictor modules read ``config.pad_token_id`` for ``nn.Embedding`` padding_idx. """ tts_pad = getattr(config, "tts_pad_token_id", None) tc = getattr(config, "talker_config", None) if tc is None: return if getattr(tc, "pad_token_id", None) is None and tts_pad is not None: tc.pad_token_id = int(tts_pad) cpc = getattr(tc, "code_predictor_config", None) if cpc is not None and getattr(cpc, "pad_token_id", None) is None: codec_pad = getattr(tc, "codec_pad_id", None) if codec_pad is not None: cpc.pad_token_id = int(codec_pad) _UNWANTED_DTYPE_KEYS = frozenset({"dtype", "torch_dtype"}) # PretrainedConfig.to_dict() can put these inside nested blobs; some Qwen3 sub-configs # (e.g. Qwen3TTSSpeakerEncoderConfig) use a tight __init__ with no **kwargs and reject them. _NESTED_METADATA_KEYS = frozenset({"model_type"}) def _strip_config_for_qwen3_load(obj: object, depth: int = 0) -> int: """In-place: remove keys HuggingFace may serialize that Qwen3 sub-configs reject.""" n = 0 if isinstance(obj, dict): for k in _UNWANTED_DTYPE_KEYS: if k in obj: del obj[k] n += 1 if depth > 0: for k in _NESTED_METADATA_KEYS: if k in obj: del obj[k] n += 1 for v in obj.values(): n += _strip_config_for_qwen3_load(v, depth + 1) elif isinstance(obj, list): for v in obj: n += _strip_config_for_qwen3_load(v, depth) return n def sanitize_qwen3_tts_config_json(repo_or_config: Path) -> int: """ Fix ``config.json`` saved or merged in ways HuggingFace does not like for Qwen3: - Recursively remove ``dtype`` / ``torch_dtype`` (some stacks write these into nested dicts; e.g. ``Qwen3TTSSpeakerEncoderConfig`` rejects ``dtype``). - In nested dicts (not the root), remove ``model_type``: sub-configs such as ``Qwen3TTSSpeakerEncoderConfig`` have a fixed ``__init__`` with no ``**kwargs``; a nested ``"model_type": "..."`` from ``to_dict()`` then raises "unexpected keyword argument 'model_type'". The root key must stay for ``AutoConfig``. Pass a model directory (containing config.json) or a path to config.json. Returns number of removed key occurrences (0 if nothing to do or file missing). """ path = Path(repo_or_config) if path.is_dir(): path = path / "config.json" if not path.is_file(): return 0 with path.open("r", encoding="utf-8") as f: data = json.load(f) n = _strip_config_for_qwen3_load(data, 0) if n: with path.open("w", encoding="utf-8") as f: json.dump(data, f, indent=2, ensure_ascii=False) return n # Run shim on import so ``import qwen3_tts_load_utils`` before qwen_tts is enough. apply_check_model_inputs_shim()