from __future__ import annotations from pathlib import Path import numpy as np import torch import yaml from safetensors.torch import load_file from snac import SNAC from transformers import AutoModelForCausalLM, AutoTokenizer CODE_START_TOKEN_ID = 128257 CODE_END_TOKEN_ID = 128258 CODE_TOKEN_OFFSET = 128266 SNAC_MIN_ID = 128266 SNAC_MAX_ID = 156937 SNAC_TOKENS_PER_FRAME = 7 SOH_ID = 128259 EOH_ID = 128260 SOA_ID = 128261 BOS_ID = 128000 TEXT_EOT_ID = 128009 def build_prompt(tokenizer, instruction: str, text: str) -> str: """Build Maya1 prompt: control tokens + verbatim instruction/text.""" soh_token = tokenizer.decode([SOH_ID]) eoh_token = tokenizer.decode([EOH_ID]) soa_token = tokenizer.decode([SOA_ID]) sos_token = tokenizer.decode([CODE_START_TOKEN_ID]) eot_token = tokenizer.decode([TEXT_EOT_ID]) bos_token = tokenizer.bos_token formatted_text = f' {text}' prompt = ( soh_token + bos_token + formatted_text + eot_token + eoh_token + soa_token + sos_token ) return prompt def extract_snac_codes(token_ids: list) -> list: """Extract SNAC codes from generated tokens.""" try: eos_idx = token_ids.index(CODE_END_TOKEN_ID) except ValueError: eos_idx = len(token_ids) snac_codes = [ token_id for token_id in token_ids[:eos_idx] if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID ] return snac_codes def unpack_snac_from_7(snac_tokens: list) -> list: """Unpack 7-token SNAC frames to 3 hierarchical levels.""" if snac_tokens and snac_tokens[-1] == CODE_END_TOKEN_ID: snac_tokens = snac_tokens[:-1] frames = len(snac_tokens) // SNAC_TOKENS_PER_FRAME snac_tokens = snac_tokens[:frames * SNAC_TOKENS_PER_FRAME] if frames == 0: return [[], [], []] l1, l2, l3 = [], [], [] for i in range(frames): slots = snac_tokens[i*7:(i+1)*7] l1.append((slots[0] - CODE_TOKEN_OFFSET) % 4096) l2.extend([ (slots[1] - CODE_TOKEN_OFFSET) % 4096, (slots[4] - CODE_TOKEN_OFFSET) % 4096, ]) l3.extend([ (slots[2] - CODE_TOKEN_OFFSET) % 4096, (slots[3] - CODE_TOKEN_OFFSET) % 4096, (slots[5] - CODE_TOKEN_OFFSET) % 4096, (slots[6] - CODE_TOKEN_OFFSET) % 4096, ]) return [l1, l2, l3] def _load_snac(repo_path: Path) -> SNAC: """Load SNAC decoder weights from repo-local safetensors (no .bin).""" snac_dir = repo_path / "snac_model" weights_path = snac_dir / "model.safetensors" config_path = snac_dir / "config.json" if not weights_path.is_file() or not config_path.is_file(): raise FileNotFoundError( f"SNAC assets missing under {snac_dir}: need config.json and model.safetensors" ) model = SNAC.from_config(str(config_path)) model.load_state_dict(load_file(weights_path, device="cpu")) return model.eval() class Miner: """Vocence miner wrapper for Maya + SNAC inference.""" def __init__(self, path_hf_repo: Path) -> None: self._repo_path = Path(path_hf_repo).resolve() self._device = "cuda" if torch.cuda.is_available() else "cpu" with (self._repo_path / "vocence_config.yaml").open() as f: config = yaml.safe_load(f) or {} model_name = config["model_name"] self.model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True, ) self.tokenizer = AutoTokenizer.from_pretrained( model_name, trust_remote_code=True, ) self.snac_model = _load_snac(self._repo_path) if torch.cuda.is_available(): self.snac_model = self.snac_model.to("cuda") def warmup(self) -> None: _ = self.generate_wav( instruction=( "A calm adult male speaker with an American accent, mid-pitched voice, " "normal speaking pace, and a formal tone." ), text="This is a warmup utterance for the voice engine.", ) def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]: prompt = build_prompt(self.tokenizer, instruction, text) inputs = self.tokenizer(prompt, return_tensors="pt") if torch.cuda.is_available(): inputs = {k: v.to("cuda") for k, v in inputs.items()} with torch.inference_mode(): outputs = self.model.generate( **inputs, max_new_tokens=2048, min_new_tokens=28, temperature=0.4, top_p=0.9, repetition_penalty=1.1, do_sample=True, eos_token_id=CODE_END_TOKEN_ID, pad_token_id=self.tokenizer.pad_token_id, ) generated_ids = outputs[0, inputs["input_ids"].shape[1] :].tolist() snac_tokens = extract_snac_codes(generated_ids) if len(snac_tokens) < SNAC_TOKENS_PER_FRAME: raise RuntimeError("Not enough SNAC tokens generated for decoding.") levels = unpack_snac_from_7(snac_tokens) codes_tensor = [ torch.tensor(level, dtype=torch.long, device=self._device).unsqueeze(0) for level in levels ] with torch.inference_mode(): z_q = self.snac_model.quantizer.from_codes(codes_tensor) audio = self.snac_model.decoder(z_q)[0, 0].cpu().numpy() if len(audio) > 2048: audio = audio[2048:] return audio.astype(np.float32), 24000