# File size: 3,726 Bytes
# Revision: 6b26697
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
Darwin-TTS-1.7B-Cross: Cross-Modal LLM→TTS FFN Blending
=======================================================
World's first cross-modal FFN transfer from LLM to TTS.
No training. 84 FFN tensors. Shape 100% match.

Usage:
    python darwin_tts_blend.py --alpha 3 --text "안녕하세요!"
    python darwin_tts_blend.py --alpha 5 --ref voice.wav --text "Hello!"

Alpha guide (--alpha is a percentage; the script divides it by 100 before blending):
    0  = Original Qwen3-TTS (no blending)
    1  = Subtle (barely noticeable)
    3  = Recommended (emotion appears) ★
    5  = Maximum stable (emotion intensified) ★★
    10 = BROKEN (do not use)
"""
import argparse
import torch
import numpy as np
import soundfile as sf
from pathlib import Path
from safetensors import safe_open


def load_llm_ffn(model_id="Qwen/Qwen3-1.7B"):
    """Collect the FFN projection weights of a Qwen3 LLM checkpoint.

    Fetches (or reuses the cached) Hugging Face snapshot of ``model_id``,
    skipping ``*.bin``, ``*.ot`` and ``*.msgpack`` files, then walks every
    ``*.safetensors`` shard under it in sorted path order.

    Args:
        model_id: Hub repository id (or anything ``snapshot_download`` accepts).

    Returns:
        dict mapping each tensor name that contains ``gate_proj``,
        ``up_proj`` or ``down_proj`` to the loaded PyTorch tensor.
    """
    from huggingface_hub import snapshot_download

    ffn_markers = ("gate_proj", "up_proj", "down_proj")
    snapshot_dir = snapshot_download(
        model_id, ignore_patterns=["*.bin", "*.ot", "*.msgpack"]
    )

    weights = {}
    for shard in sorted(Path(snapshot_dir).rglob("*.safetensors")):
        with safe_open(str(shard), framework="pt") as reader:
            # Only FFN projections are needed; everything else in the shard is skipped.
            weights.update(
                (name, reader.get_tensor(name))
                for name in reader.keys()
                if any(marker in name for marker in ffn_markers)
            )

    print(f"Loaded {len(weights)} LLM FFN tensors")
    return weights


def blend_tts(alpha=0.03, tts_model="Qwen/Qwen3-TTS-12Hz-1.7B-Base",
              device="cuda:0", llm_model="Qwen/Qwen3-1.7B"):
    """
    Load a Qwen3-TTS model and blend LLM FFN weights into its talker.

    Every talker parameter whose name contains ``gate_proj``, ``up_proj`` or
    ``down_proj`` (excluding anything under ``code_predictor``) is updated in
    place with ``p.lerp_(w_llm, alpha)``, i.e. ``p + alpha * (w_llm - p)``.
    The matching LLM tensor is found by removing ``"talker."`` from the
    parameter name; parameters with no LLM counterpart are left unchanged.

    Args:
        alpha: Blend ratio (0.0 to 0.05 recommended, default 0.03).
            Values <= 0 skip blending and return the unmodified model.
        tts_model: TTS model ID or path.
        device: ``device_map`` passed to ``Qwen3TTSModel.from_pretrained``
            (default ``"cuda:0"``).
        llm_model: Model ID of the LLM whose FFN weights are blended in
            (default ``"Qwen/Qwen3-1.7B"``).

    Returns:
        Blended Qwen3TTSModel ready for inference.

    Raises:
        RuntimeError: if a matched LLM tensor's shape differs from the talker
            parameter; without this check ``lerp_`` could fail obscurely or
            silently broadcast a wrong-shaped tensor.
    """
    from qwen_tts import Qwen3TTSModel

    print(f"Loading TTS: {tts_model}")
    model = Qwen3TTSModel.from_pretrained(
        tts_model, device_map=device, dtype=torch.bfloat16
    )

    if alpha <= 0:
        return model

    llm_ffn = load_llm_ffn(llm_model)
    ffn_markers = ("gate_proj", "up_proj", "down_proj")
    cnt = 0
    # In-place edits of model parameters must not be tracked by autograd.
    with torch.no_grad():
        for n, p in model.model.named_parameters():
            if "talker" not in n or "code_predictor" in n:
                continue
            if not any(x in n for x in ffn_markers):
                continue
            llm_key = n.replace("talker.", "")
            src = llm_ffn.get(llm_key)
            if src is None:
                continue
            if src.shape != p.shape:
                raise RuntimeError(
                    f"Shape mismatch for {n}: TTS {tuple(p.shape)} "
                    f"vs LLM {tuple(src.shape)}"
                )
            p.lerp_(src.to(p.device, p.dtype), alpha)
            cnt += 1

    if cnt == 0:
        # Otherwise a naming mismatch would silently yield an unblended model.
        print("WARNING: no talker FFN parameters matched the LLM weights; "
              "model is unchanged")
    # Every blended tensor passed the explicit shape check above.
    print(f"Blended {cnt} FFN tensors (alpha={alpha}, shape 100% match)")
    return model


if __name__ == "__main__":
    # Command-line entry point: optionally build a synthetic reference clip,
    # blend the model, synthesize --text and write it to --output.
    import tempfile

    parser = argparse.ArgumentParser(description="Darwin-TTS: LLM→TTS FFN Blending")
    # Percent, so fractional values like 2.5 are meaningful; converted to a
    # 0..1 ratio before calling blend_tts.
    parser.add_argument("--alpha", type=float, default=3,
                        help="Blend %% (0=original, 3=recommended, 5=max stable)")
    parser.add_argument("--text", type=str,
                        default="안녕하세요, 저는 다윈 인공지능입니다.")
    parser.add_argument("--ref", type=str, default=None,
                        help="Reference audio for voice cloning")
    parser.add_argument("--output", type=str, default="darwin_output.wav")
    args = parser.parse_args()

    if args.alpha < 0:
        parser.error("--alpha must be >= 0 (percent)")

    if args.ref is None:
        # Fallback reference: 3 s of a 200 Hz sine at 24 kHz, written to the
        # platform temp dir (not a hard-coded /tmp path).
        ref_sr = 24000
        args.ref = str(Path(tempfile.gettempdir()) / "_darwin_ref.wav")
        t = np.arange(3 * ref_sr) / ref_sr  # exact 1/ref_sr sample spacing
        sf.write(args.ref,
                 (0.1 * np.sin(2 * np.pi * 200 * t)).astype(np.float32),
                 ref_sr)
        print("Using default sine reference (provide --ref for better quality)")

    model = blend_tts(alpha=args.alpha / 100.0)
    wavs, sr = model.generate_voice_clone(
        text=args.text, ref_audio=args.ref,
        ref_text="ref", x_vector_only_mode=True
    )
    # generate_voice_clone may return tensors or array-likes; normalize to numpy.
    wav = wavs[0].cpu().numpy() if hasattr(wavs[0], "cpu") else np.array(wavs[0])
    sf.write(args.output, wav, sr)
    print(f"Saved: {args.output} ({len(wav)/sr:.1f}s)")