# vocence-tts / miner.py
# Last commit cdb447a by matthewliu0302: "adapt to vocence update" (file size at that commit: 5.91 kB).
from __future__ import annotations
from pathlib import Path
import numpy as np
import torch
import yaml
from safetensors.torch import load_file
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Maya1 / SNAC special token IDs ------------------------------------------
# Decoded and appended as the last piece of the prompt in build_prompt, so
# generation continues right after it (name suggests "start of codes/speech").
CODE_START_TOKEN_ID = 128257
# Passed to model.generate() as eos_token_id; extract_snac_codes truncates the
# generated ids at its first occurrence.
CODE_END_TOKEN_ID = 128258
# Subtracted from every SNAC token before the "% 4096" in unpack_snac_from_7.
CODE_TOKEN_OFFSET = 128266
# Inclusive id range that extract_snac_codes keeps as SNAC audio tokens.
# Note SNAC_MAX_ID == SNAC_MIN_ID + 7 * 4096 - 1, i.e. room for 7 slots of
# 4096 codes each.
SNAC_MIN_ID = 128266
SNAC_MAX_ID = 156937
# Number of consecutive SNAC tokens forming one frame in unpack_snac_from_7.
SNAC_TOKENS_PER_FRAME = 7
# Control tokens decoded into the prompt by build_prompt. Presumably
# start-of-human / end-of-human / start-of-AI markers (inferred from names only).
SOH_ID = 128259
EOH_ID = 128260
SOA_ID = 128261
# NOTE(review): not referenced in the visible code; build_prompt uses
# tokenizer.bos_token instead.
BOS_ID = 128000
# Decoded and placed right after the description/text segment in build_prompt.
TEXT_EOT_ID = 128009
def build_prompt(tokenizer, instruction: str, text: str) -> str:
    """Assemble the Maya1 prompt string for one utterance.

    The instruction and text are inserted verbatim (no escaping) into a
    ``<description="..."> ...`` segment. That segment is surrounded by
    control tokens decoded from the tokenizer, in this order:
    SOH, BOS, segment, EOT, EOH, SOA, code-start.

    Args:
        tokenizer: Tokenizer providing ``decode(list[int])`` and ``bos_token``.
        instruction: Voice description placed inside the description tag.
        text: Text to be spoken.

    Returns:
        The prompt string. It ends with the decoded ``CODE_START_TOKEN_ID``.
    """
    soh, eoh, soa, sos, eot = (
        tokenizer.decode([token_id])
        for token_id in (SOH_ID, EOH_ID, SOA_ID, CODE_START_TOKEN_ID, TEXT_EOT_ID)
    )
    bos = tokenizer.bos_token
    segment = f'<description="{instruction}"> {text}'
    return "".join((soh, bos, segment, eot, eoh, soa, sos))
def extract_snac_codes(token_ids: list) -> list:
    """Return the SNAC audio tokens from a generated id sequence.

    Only ids before the first ``CODE_END_TOKEN_ID`` are considered (the whole
    sequence if it never appears). Of those, the ids within
    ``[SNAC_MIN_ID, SNAC_MAX_ID]`` are kept in their original order.
    """
    kept = []
    for tok in token_ids:
        if tok == CODE_END_TOKEN_ID:
            break  # everything after the end marker is ignored
        if SNAC_MIN_ID <= tok <= SNAC_MAX_ID:
            kept.append(tok)
    return kept
def unpack_snac_from_7(snac_tokens: list) -> list:
    """Split flat SNAC tokens into three code lists, one per level.

    A single trailing ``CODE_END_TOKEN_ID`` is dropped first. The tokens are
    then read in frames of ``SNAC_TOKENS_PER_FRAME``, and an incomplete final
    frame is discarded. Each token becomes a code via
    ``(token - CODE_TOKEN_OFFSET) % 4096``. Within a frame, codes go to the
    levels as follows:

    * level 1: slot 0
    * level 2: slots 1, 4
    * level 3: slots 2, 3, 5, 6

    Returns:
        ``[level1, level2, level3]``, or ``[[], [], []]`` when there is no
        complete frame.
    """
    if snac_tokens and snac_tokens[-1] == CODE_END_TOKEN_ID:
        snac_tokens = snac_tokens[:-1]

    n_frames = len(snac_tokens) // SNAC_TOKENS_PER_FRAME
    if not n_frames:
        return [[], [], []]

    def to_code(tok):
        return (tok - CODE_TOKEN_OFFSET) % 4096

    lvl1, lvl2, lvl3 = [], [], []
    usable = n_frames * SNAC_TOKENS_PER_FRAME
    for start in range(0, usable, SNAC_TOKENS_PER_FRAME):
        s0, s1, s2, s3, s4, s5, s6 = snac_tokens[start:start + SNAC_TOKENS_PER_FRAME]
        lvl1.append(to_code(s0))
        lvl2 += [to_code(s1), to_code(s4)]
        lvl3 += [to_code(s2), to_code(s3), to_code(s5), to_code(s6)]
    return [lvl1, lvl2, lvl3]
def _load_snac(repo_path: Path) -> SNAC:
    """Build the SNAC model from ``<repo_path>/snac_model`` and load its weights.

    The directory must contain ``config.json`` (used to construct the model)
    and ``model.safetensors`` (state dict, read onto the CPU). Only the
    safetensors file is consulted; no ``.bin`` fallback exists.

    Returns:
        The SNAC model, switched to eval mode.

    Raises:
        FileNotFoundError: if either file is missing or is not a regular file.
    """
    asset_dir = repo_path / "snac_model"
    weights_file = asset_dir / "model.safetensors"
    config_file = asset_dir / "config.json"

    if not (weights_file.is_file() and config_file.is_file()):
        raise FileNotFoundError(
            f"SNAC assets missing under {asset_dir}: need config.json and model.safetensors"
        )

    snac = SNAC.from_config(str(config_file))
    state_dict = load_file(weights_file, device="cpu")
    snac.load_state_dict(state_dict)
    snac.eval()
    return snac
class Miner:
    """Vocence miner wrapper around a Maya1 causal LM and a SNAC decoder.

    The LM generates SNAC code tokens from a prompt built from a voice
    description and text. SNAC decodes those codes into a waveform.
    """

    def __init__(self, path_hf_repo: Path) -> None:
        """Load the language model, its tokenizer and the repo-local SNAC model.

        Args:
            path_hf_repo: Directory containing ``vocence_config.yaml`` (which
                must define ``model_name``) and the ``snac_model/`` assets.

        Raises:
            KeyError: if the YAML config has no ``model_name`` entry.
            FileNotFoundError: if the YAML config or SNAC assets are missing.
        """
        self._repo_path = Path(path_hf_repo).resolve()
        cuda_ok = torch.cuda.is_available()
        self._device = "cuda" if cuda_ok else "cpu"

        config_file = self._repo_path / "vocence_config.yaml"
        with config_file.open() as handle:
            settings = yaml.safe_load(handle) or {}  # empty file -> {}
        model_id = settings["model_name"]

        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True,
        )

        decoder = _load_snac(self._repo_path)
        # Keep SNAC on the same device that generate_wav builds code tensors on.
        self.snac_model = decoder.to("cuda") if cuda_ok else decoder

    def warmup(self) -> None:
        """Synthesize a fixed warm-up utterance and discard the result."""
        voice = (
            "A calm adult male speaker with an American accent, mid-pitched voice, "
            "normal speaking pace, and a formal tone."
        )
        self.generate_wav(
            instruction=voice,
            text="This is a warmup utterance for the voice engine.",
        )

    def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]:
        """Synthesize *text* in the voice described by *instruction*.

        Returns:
            ``(audio, 24000)``. ``audio`` is a float32 array taken from the
            first batch item and first channel of the SNAC decoder output.
            24000 is presumably the sample rate in Hz.

        Raises:
            RuntimeError: if fewer than ``SNAC_TOKENS_PER_FRAME`` SNAC tokens
                were generated.
        """
        prompt = build_prompt(self.tokenizer, instruction, text)
        encoded = self.tokenizer(prompt, return_tensors="pt")
        if torch.cuda.is_available():
            encoded = {name: tensor.to("cuda") for name, tensor in encoded.items()}
        prompt_len = encoded["input_ids"].shape[1]

        with torch.inference_mode():
            sequences = self.model.generate(
                **encoded,
                max_new_tokens=2048,
                min_new_tokens=28,
                temperature=0.4,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                eos_token_id=CODE_END_TOKEN_ID,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        # Keep only the newly generated ids (drop the echoed prompt).
        new_ids = sequences[0, prompt_len:].tolist()

        audio_tokens = extract_snac_codes(new_ids)
        if len(audio_tokens) < SNAC_TOKENS_PER_FRAME:
            raise RuntimeError("Not enough SNAC tokens generated for decoding.")

        # One (1, n) long tensor per level, on the device SNAC was moved to.
        level_tensors = [
            torch.tensor(codes, dtype=torch.long, device=self._device)[None, :]
            for codes in unpack_snac_from_7(audio_tokens)
        ]
        with torch.inference_mode():
            latent = self.snac_model.quantizer.from_codes(level_tensors)
            waveform = self.snac_model.decoder(latent)[0, 0].cpu().numpy()

        # Drop the first 2048 samples when the clip is longer than that.
        # Presumably this trims decoder start-up artifacts (TODO confirm).
        if len(waveform) > 2048:
            waveform = waveform[2048:]
        return waveform.astype(np.float32), 24000