Spaces:
Running on Zero
Running on Zero
Fix ZeroGPU and private model loading
Browse files
app.py
CHANGED
|
@@ -6,6 +6,8 @@ import gradio as gr
|
|
| 6 |
import spaces
|
| 7 |
|
| 8 |
MODEL = None
|
|
|
|
|
|
|
| 9 |
|
| 10 |
DEFAULT_CONFIG = {
|
| 11 |
"audio": 'https://storage.googleapis.com/chatterbox-demo-samples/mtl-v3-single-language-prompts/es-latam/es_mx_f1.wav',
|
|
@@ -32,9 +34,15 @@ def default_text_for_ui():
|
|
| 32 |
def get_or_load_model():
|
| 33 |
global MODEL
|
| 34 |
if MODEL is None:
|
| 35 |
-
print("Model not loaded, initializing on
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
return MODEL
|
| 39 |
|
| 40 |
|
|
@@ -57,9 +65,8 @@ def generate_tts_audio(
|
|
| 57 |
cfgw_input: float = 0.5,
|
| 58 |
):
|
| 59 |
"""Generate speech from text with optional reference audio styling."""
|
| 60 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 61 |
current_model = get_or_load_model()
|
| 62 |
-
current_model.
|
| 63 |
if seed_num_input != 0:
|
| 64 |
set_seed(int(seed_num_input), device)
|
| 65 |
chosen_prompt = audio_prompt_path_input or default_audio_for_ui()
|
|
@@ -77,6 +84,9 @@ def generate_tts_audio(
|
|
| 77 |
return (current_model.sr, wav.squeeze(0).cpu().numpy())
|
| 78 |
|
| 79 |
|
|
|
|
|
|
|
|
|
|
| 80 |
with gr.Blocks() as demo:
|
| 81 |
gr.Markdown(
|
| 82 |
"""
|
|
|
|
| 6 |
import spaces
|
| 7 |
|
| 8 |
MODEL = None
|
| 9 |
+
# ZeroGPU supports CUDA placement at module load time via CUDA emulation.
|
| 10 |
+
TARGET_DEVICE = "cuda"
|
| 11 |
|
| 12 |
DEFAULT_CONFIG = {
|
| 13 |
"audio": 'https://storage.googleapis.com/chatterbox-demo-samples/mtl-v3-single-language-prompts/es-latam/es_mx_f1.wav',
|
|
|
|
| 34 |
def get_or_load_model():
|
| 35 |
global MODEL
|
| 36 |
if MODEL is None:
|
| 37 |
+
print(f"Model not loaded, initializing on {TARGET_DEVICE}...")
|
| 38 |
+
try:
|
| 39 |
+
MODEL = ChatterboxTTS.from_pretrained(TARGET_DEVICE)
|
| 40 |
+
except Exception as exc:
|
| 41 |
+
if TARGET_DEVICE != "cuda":
|
| 42 |
+
raise
|
| 43 |
+
print(f"CUDA model initialization failed, falling back to CPU: {exc}")
|
| 44 |
+
MODEL = ChatterboxTTS.from_pretrained("cpu")
|
| 45 |
+
print(f"Model loaded on {MODEL.device}.")
|
| 46 |
return MODEL
|
| 47 |
|
| 48 |
|
|
|
|
| 65 |
cfgw_input: float = 0.5,
|
| 66 |
):
|
| 67 |
"""Generate speech from text with optional reference audio styling."""
|
|
|
|
| 68 |
current_model = get_or_load_model()
|
| 69 |
+
device = current_model.device
|
| 70 |
if seed_num_input != 0:
|
| 71 |
set_seed(int(seed_num_input), device)
|
| 72 |
chosen_prompt = audio_prompt_path_input or default_audio_for_ui()
|
|
|
|
| 84 |
return (current_model.sr, wav.squeeze(0).cpu().numpy())
|
| 85 |
|
| 86 |
|
| 87 |
+
get_or_load_model()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
with gr.Blocks() as demo:
|
| 91 |
gr.Markdown(
|
| 92 |
"""
|
chatterbox/src/chatterbox/models/t3/inference/t3_hf_backend.py
CHANGED
|
@@ -23,14 +23,12 @@ class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin):
|
|
| 23 |
speech_head,
|
| 24 |
latents_queue=None,
|
| 25 |
logits_queue=None,
|
| 26 |
-
alignment_stream_analyzer: 'AlignmentStreamAnalyzer'=None,
|
| 27 |
):
|
| 28 |
super().__init__(config)
|
| 29 |
self.model = llama
|
| 30 |
self.speech_enc = speech_enc
|
| 31 |
self.speech_head = speech_head
|
| 32 |
self._added_cond = False
|
| 33 |
-
self.alignment_stream_analyzer = alignment_stream_analyzer
|
| 34 |
|
| 35 |
@torch.inference_mode()
|
| 36 |
def prepare_inputs_for_generation(
|
|
@@ -105,9 +103,6 @@ class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin):
|
|
| 105 |
logits = self.speech_head(hidden_states)
|
| 106 |
# assert inputs_embeds.size(0) == 1 # (disabled for CFG)
|
| 107 |
|
| 108 |
-
# NOTE: hallucination handler may modify logits to force emit an EOS token
|
| 109 |
-
# logits = self.alignment_stream_analyzer.step(logits)
|
| 110 |
-
|
| 111 |
return CausalLMOutputWithCrossAttentions(
|
| 112 |
logits=logits,
|
| 113 |
past_key_values=tfmr_out.past_key_values,
|
|
|
|
| 23 |
speech_head,
|
| 24 |
latents_queue=None,
|
| 25 |
logits_queue=None,
|
|
|
|
| 26 |
):
|
| 27 |
super().__init__(config)
|
| 28 |
self.model = llama
|
| 29 |
self.speech_enc = speech_enc
|
| 30 |
self.speech_head = speech_head
|
| 31 |
self._added_cond = False
|
|
|
|
| 32 |
|
| 33 |
@torch.inference_mode()
|
| 34 |
def prepare_inputs_for_generation(
|
|
|
|
| 103 |
logits = self.speech_head(hidden_states)
|
| 104 |
# assert inputs_embeds.size(0) == 1 # (disabled for CFG)
|
| 105 |
|
|
|
|
|
|
|
|
|
|
| 106 |
return CausalLMOutputWithCrossAttentions(
|
| 107 |
logits=logits,
|
| 108 |
past_key_values=tfmr_out.past_key_values,
|
chatterbox/src/chatterbox/models/t3/t3.py
CHANGED
|
@@ -16,7 +16,6 @@ from .modules.cond_enc import T3CondEnc, T3Cond
|
|
| 16 |
from .modules.t3_config import T3ConfigMultilingual
|
| 17 |
from .llama_configs import LLAMA_CONFIGS
|
| 18 |
from .inference.t3_hf_backend import T3HuggingfaceBackend
|
| 19 |
-
from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer
|
| 20 |
|
| 21 |
|
| 22 |
logger = logging.getLogger(__name__)
|
|
@@ -255,21 +254,11 @@ class T3(nn.Module):
|
|
| 255 |
# TODO? synchronize the expensive compile function
|
| 256 |
# with self.compile_lock:
|
| 257 |
if not self.compiled:
|
| 258 |
-
alignment_stream_analyzer = AlignmentStreamAnalyzer(
|
| 259 |
-
self.tfmr,
|
| 260 |
-
None,
|
| 261 |
-
text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
|
| 262 |
-
alignment_layer_idx=9, # TODO: hparam or something?
|
| 263 |
-
eos_idx=self.hp.stop_speech_token,
|
| 264 |
-
)
|
| 265 |
-
assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
|
| 266 |
-
|
| 267 |
patched_model = T3HuggingfaceBackend(
|
| 268 |
config=self.cfg,
|
| 269 |
llama=self.tfmr,
|
| 270 |
speech_enc=self.speech_emb,
|
| 271 |
speech_head=self.speech_head,
|
| 272 |
-
alignment_stream_analyzer=alignment_stream_analyzer,
|
| 273 |
)
|
| 274 |
self.patched_model = patched_model
|
| 275 |
self.compiled = True
|
|
@@ -317,7 +306,7 @@ class T3(nn.Module):
|
|
| 317 |
inputs_embeds=inputs_embeds,
|
| 318 |
past_key_values=None,
|
| 319 |
use_cache=True,
|
| 320 |
-
output_attentions=
|
| 321 |
output_hidden_states=True,
|
| 322 |
return_dict=True,
|
| 323 |
)
|
|
@@ -333,11 +322,6 @@ class T3(nn.Module):
|
|
| 333 |
cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype)
|
| 334 |
logits = cond + cfg * (cond - uncond)
|
| 335 |
|
| 336 |
-
if self.patched_model.alignment_stream_analyzer is not None:
|
| 337 |
-
if logits.dim() == 1: # guard in case something upstream squeezed
|
| 338 |
-
logits = logits.unsqueeze(0) # (1, V)
|
| 339 |
-
logits = self.patched_model.alignment_stream_analyzer.step(logits) # (1, V)
|
| 340 |
-
|
| 341 |
# Apply repetition penalty
|
| 342 |
ids_for_proc = generated_ids[:1, ...] # batch = 1
|
| 343 |
logits = repetition_penalty_processor(ids_for_proc, logits) # expects (B,V)
|
|
@@ -372,7 +356,7 @@ class T3(nn.Module):
|
|
| 372 |
output = self.patched_model(
|
| 373 |
inputs_embeds=next_token_embed,
|
| 374 |
past_key_values=past,
|
| 375 |
-
output_attentions=
|
| 376 |
output_hidden_states=True,
|
| 377 |
return_dict=True,
|
| 378 |
)
|
|
|
|
| 16 |
from .modules.t3_config import T3ConfigMultilingual
|
| 17 |
from .llama_configs import LLAMA_CONFIGS
|
| 18 |
from .inference.t3_hf_backend import T3HuggingfaceBackend
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
logger = logging.getLogger(__name__)
|
|
|
|
| 254 |
# TODO? synchronize the expensive compile function
|
| 255 |
# with self.compile_lock:
|
| 256 |
if not self.compiled:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
patched_model = T3HuggingfaceBackend(
|
| 258 |
config=self.cfg,
|
| 259 |
llama=self.tfmr,
|
| 260 |
speech_enc=self.speech_emb,
|
| 261 |
speech_head=self.speech_head,
|
|
|
|
| 262 |
)
|
| 263 |
self.patched_model = patched_model
|
| 264 |
self.compiled = True
|
|
|
|
| 306 |
inputs_embeds=inputs_embeds,
|
| 307 |
past_key_values=None,
|
| 308 |
use_cache=True,
|
| 309 |
+
output_attentions=False,
|
| 310 |
output_hidden_states=True,
|
| 311 |
return_dict=True,
|
| 312 |
)
|
|
|
|
| 322 |
cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype)
|
| 323 |
logits = cond + cfg * (cond - uncond)
|
| 324 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
# Apply repetition penalty
|
| 326 |
ids_for_proc = generated_ids[:1, ...] # batch = 1
|
| 327 |
logits = repetition_penalty_processor(ids_for_proc, logits) # expects (B,V)
|
|
|
|
| 356 |
output = self.patched_model(
|
| 357 |
inputs_embeds=next_token_embed,
|
| 358 |
past_key_values=past,
|
| 359 |
+
output_attentions=False,
|
| 360 |
output_hidden_states=True,
|
| 361 |
return_dict=True,
|
| 362 |
)
|
chatterbox/src/chatterbox/tts.py
CHANGED
|
@@ -11,7 +11,7 @@ from huggingface_hub import snapshot_download, hf_hub_download
|
|
| 11 |
|
| 12 |
from .models.t3 import T3
|
| 13 |
from .models.t3.modules.t3_config import T3ConfigMultilingual
|
| 14 |
-
from .models.s3tokenizer import S3_SR, drop_invalid_tokens
|
| 15 |
from .models.s3gen import S3GEN_SR, S3Gen
|
| 16 |
from .models.tokenizers import MTLTokenizer
|
| 17 |
from .models.voice_encoder import VoiceEncoder
|
|
@@ -312,5 +312,12 @@ class ChatterboxTTS:
|
|
| 312 |
ref_dict=self.conds.gen,
|
| 313 |
)
|
| 314 |
wav = wav.squeeze(0).detach().cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
|
| 316 |
return torch.from_numpy(watermarked_wav).unsqueeze(0)
|
|
|
|
| 11 |
|
| 12 |
from .models.t3 import T3
|
| 13 |
from .models.t3.modules.t3_config import T3ConfigMultilingual
|
| 14 |
+
from .models.s3tokenizer import S3_SR, S3_TOKEN_RATE, drop_invalid_tokens
|
| 15 |
from .models.s3gen import S3GEN_SR, S3Gen
|
| 16 |
from .models.tokenizers import MTLTokenizer
|
| 17 |
from .models.voice_encoder import VoiceEncoder
|
|
|
|
| 312 |
ref_dict=self.conds.gen,
|
| 313 |
)
|
| 314 |
wav = wav.squeeze(0).detach().cpu().numpy()
|
| 315 |
+
|
| 316 |
+
# Drop the final speech token's audio: it is emitted just before
|
| 317 |
+
# EOS with degraded attention and decodes to ~40 ms of noise.
|
| 318 |
+
n_tokens = int(speech_tokens.shape[-1])
|
| 319 |
+
st_len = max(1, n_tokens - 1)
|
| 320 |
+
wav = wav[: st_len * (S3GEN_SR // S3_TOKEN_RATE)]
|
| 321 |
+
|
| 322 |
watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
|
| 323 |
return torch.from_numpy(watermarked_wav).unsqueeze(0)
|
requirements.txt
CHANGED
|
@@ -1,9 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
gradio
|
| 2 |
numpy==1.26.0
|
| 3 |
resampy==0.4.3
|
| 4 |
librosa==0.10.0
|
| 5 |
s3tokenizer
|
| 6 |
-
torchaudio<2.8
|
| 7 |
|
| 8 |
transformers==4.46.3
|
| 9 |
diffusers==0.29.0
|
|
@@ -11,4 +14,4 @@ omegaconf==2.3.0
|
|
| 11 |
resemble-perth==1.0.1
|
| 12 |
silero-vad==5.1.2
|
| 13 |
conformer==0.3.2
|
| 14 |
-
safetensors
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cu128
|
| 2 |
+
torch==2.8.0
|
| 3 |
+
torchaudio==2.8.0
|
| 4 |
+
|
| 5 |
gradio
|
| 6 |
numpy==1.26.0
|
| 7 |
resampy==0.4.3
|
| 8 |
librosa==0.10.0
|
| 9 |
s3tokenizer
|
|
|
|
| 10 |
|
| 11 |
transformers==4.46.3
|
| 12 |
diffusers==0.29.0
|
|
|
|
| 14 |
resemble-perth==1.0.1
|
| 15 |
silero-vad==5.1.2
|
| 16 |
conformer==0.3.2
|
| 17 |
+
safetensors
|