tedi-resemble committed on
Commit
0d2648c
·
verified ·
1 Parent(s): ea9a032

Fix ZeroGPU and private model loading

Browse files
app.py CHANGED
@@ -6,6 +6,8 @@ import gradio as gr
6
  import spaces
7
 
8
  MODEL = None
 
 
9
 
10
  DEFAULT_CONFIG = {
11
  "audio": 'https://storage.googleapis.com/chatterbox-demo-samples/mtl-v3-single-language-prompts/es-latam/es_mx_f1.wav',
@@ -32,9 +34,15 @@ def default_text_for_ui():
32
  def get_or_load_model():
33
  global MODEL
34
  if MODEL is None:
35
- print("Model not loaded, initializing on CPU...")
36
- MODEL = ChatterboxTTS.from_pretrained("cpu")
37
- print("Model loaded.")
 
 
 
 
 
 
38
  return MODEL
39
 
40
 
@@ -57,9 +65,8 @@ def generate_tts_audio(
57
  cfgw_input: float = 0.5,
58
  ):
59
  """Generate speech from text with optional reference audio styling."""
60
- device = "cuda" if torch.cuda.is_available() else "cpu"
61
  current_model = get_or_load_model()
62
- current_model.to(device)
63
  if seed_num_input != 0:
64
  set_seed(int(seed_num_input), device)
65
  chosen_prompt = audio_prompt_path_input or default_audio_for_ui()
@@ -77,6 +84,9 @@ def generate_tts_audio(
77
  return (current_model.sr, wav.squeeze(0).cpu().numpy())
78
 
79
 
 
 
 
80
  with gr.Blocks() as demo:
81
  gr.Markdown(
82
  """
 
6
  import spaces
7
 
8
  MODEL = None
9
+ # ZeroGPU supports CUDA placement at module load time via CUDA emulation.
10
+ TARGET_DEVICE = "cuda"
11
 
12
  DEFAULT_CONFIG = {
13
  "audio": 'https://storage.googleapis.com/chatterbox-demo-samples/mtl-v3-single-language-prompts/es-latam/es_mx_f1.wav',
 
34
  def get_or_load_model():
35
  global MODEL
36
  if MODEL is None:
37
+ print(f"Model not loaded, initializing on {TARGET_DEVICE}...")
38
+ try:
39
+ MODEL = ChatterboxTTS.from_pretrained(TARGET_DEVICE)
40
+ except Exception as exc:
41
+ if TARGET_DEVICE != "cuda":
42
+ raise
43
+ print(f"CUDA model initialization failed, falling back to CPU: {exc}")
44
+ MODEL = ChatterboxTTS.from_pretrained("cpu")
45
+ print(f"Model loaded on {MODEL.device}.")
46
  return MODEL
47
 
48
 
 
65
  cfgw_input: float = 0.5,
66
  ):
67
  """Generate speech from text with optional reference audio styling."""
 
68
  current_model = get_or_load_model()
69
+ device = current_model.device
70
  if seed_num_input != 0:
71
  set_seed(int(seed_num_input), device)
72
  chosen_prompt = audio_prompt_path_input or default_audio_for_ui()
 
84
  return (current_model.sr, wav.squeeze(0).cpu().numpy())
85
 
86
 
87
+ get_or_load_model()
88
+
89
+
90
  with gr.Blocks() as demo:
91
  gr.Markdown(
92
  """
chatterbox/src/chatterbox/models/t3/inference/t3_hf_backend.py CHANGED
@@ -23,14 +23,12 @@ class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin):
23
  speech_head,
24
  latents_queue=None,
25
  logits_queue=None,
26
- alignment_stream_analyzer: 'AlignmentStreamAnalyzer'=None,
27
  ):
28
  super().__init__(config)
29
  self.model = llama
30
  self.speech_enc = speech_enc
31
  self.speech_head = speech_head
32
  self._added_cond = False
33
- self.alignment_stream_analyzer = alignment_stream_analyzer
34
 
35
  @torch.inference_mode()
36
  def prepare_inputs_for_generation(
@@ -105,9 +103,6 @@ class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin):
105
  logits = self.speech_head(hidden_states)
106
  # assert inputs_embeds.size(0) == 1 # (disabled for CFG)
107
 
108
- # NOTE: hallucination handler may modify logits to force emit an EOS token
109
- # logits = self.alignment_stream_analyzer.step(logits)
110
-
111
  return CausalLMOutputWithCrossAttentions(
112
  logits=logits,
113
  past_key_values=tfmr_out.past_key_values,
 
23
  speech_head,
24
  latents_queue=None,
25
  logits_queue=None,
 
26
  ):
27
  super().__init__(config)
28
  self.model = llama
29
  self.speech_enc = speech_enc
30
  self.speech_head = speech_head
31
  self._added_cond = False
 
32
 
33
  @torch.inference_mode()
34
  def prepare_inputs_for_generation(
 
103
  logits = self.speech_head(hidden_states)
104
  # assert inputs_embeds.size(0) == 1 # (disabled for CFG)
105
 
 
 
 
106
  return CausalLMOutputWithCrossAttentions(
107
  logits=logits,
108
  past_key_values=tfmr_out.past_key_values,
chatterbox/src/chatterbox/models/t3/t3.py CHANGED
@@ -16,7 +16,6 @@ from .modules.cond_enc import T3CondEnc, T3Cond
16
  from .modules.t3_config import T3ConfigMultilingual
17
  from .llama_configs import LLAMA_CONFIGS
18
  from .inference.t3_hf_backend import T3HuggingfaceBackend
19
- from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer
20
 
21
 
22
  logger = logging.getLogger(__name__)
@@ -255,21 +254,11 @@ class T3(nn.Module):
255
  # TODO? synchronize the expensive compile function
256
  # with self.compile_lock:
257
  if not self.compiled:
258
- alignment_stream_analyzer = AlignmentStreamAnalyzer(
259
- self.tfmr,
260
- None,
261
- text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
262
- alignment_layer_idx=9, # TODO: hparam or something?
263
- eos_idx=self.hp.stop_speech_token,
264
- )
265
- assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
266
-
267
  patched_model = T3HuggingfaceBackend(
268
  config=self.cfg,
269
  llama=self.tfmr,
270
  speech_enc=self.speech_emb,
271
  speech_head=self.speech_head,
272
- alignment_stream_analyzer=alignment_stream_analyzer,
273
  )
274
  self.patched_model = patched_model
275
  self.compiled = True
@@ -317,7 +306,7 @@ class T3(nn.Module):
317
  inputs_embeds=inputs_embeds,
318
  past_key_values=None,
319
  use_cache=True,
320
- output_attentions=True,
321
  output_hidden_states=True,
322
  return_dict=True,
323
  )
@@ -333,11 +322,6 @@ class T3(nn.Module):
333
  cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype)
334
  logits = cond + cfg * (cond - uncond)
335
 
336
- if self.patched_model.alignment_stream_analyzer is not None:
337
- if logits.dim() == 1: # guard in case something upstream squeezed
338
- logits = logits.unsqueeze(0) # (1, V)
339
- logits = self.patched_model.alignment_stream_analyzer.step(logits) # (1, V)
340
-
341
  # Apply repetition penalty
342
  ids_for_proc = generated_ids[:1, ...] # batch = 1
343
  logits = repetition_penalty_processor(ids_for_proc, logits) # expects (B,V)
@@ -372,7 +356,7 @@ class T3(nn.Module):
372
  output = self.patched_model(
373
  inputs_embeds=next_token_embed,
374
  past_key_values=past,
375
- output_attentions=True,
376
  output_hidden_states=True,
377
  return_dict=True,
378
  )
 
16
  from .modules.t3_config import T3ConfigMultilingual
17
  from .llama_configs import LLAMA_CONFIGS
18
  from .inference.t3_hf_backend import T3HuggingfaceBackend
 
19
 
20
 
21
  logger = logging.getLogger(__name__)
 
254
  # TODO? synchronize the expensive compile function
255
  # with self.compile_lock:
256
  if not self.compiled:
 
 
 
 
 
 
 
 
 
257
  patched_model = T3HuggingfaceBackend(
258
  config=self.cfg,
259
  llama=self.tfmr,
260
  speech_enc=self.speech_emb,
261
  speech_head=self.speech_head,
 
262
  )
263
  self.patched_model = patched_model
264
  self.compiled = True
 
306
  inputs_embeds=inputs_embeds,
307
  past_key_values=None,
308
  use_cache=True,
309
+ output_attentions=False,
310
  output_hidden_states=True,
311
  return_dict=True,
312
  )
 
322
  cfg = torch.as_tensor(cfg_weight, device=cond.device, dtype=cond.dtype)
323
  logits = cond + cfg * (cond - uncond)
324
 
 
 
 
 
 
325
  # Apply repetition penalty
326
  ids_for_proc = generated_ids[:1, ...] # batch = 1
327
  logits = repetition_penalty_processor(ids_for_proc, logits) # expects (B,V)
 
356
  output = self.patched_model(
357
  inputs_embeds=next_token_embed,
358
  past_key_values=past,
359
+ output_attentions=False,
360
  output_hidden_states=True,
361
  return_dict=True,
362
  )
chatterbox/src/chatterbox/tts.py CHANGED
@@ -11,7 +11,7 @@ from huggingface_hub import snapshot_download, hf_hub_download
11
 
12
  from .models.t3 import T3
13
  from .models.t3.modules.t3_config import T3ConfigMultilingual
14
- from .models.s3tokenizer import S3_SR, drop_invalid_tokens
15
  from .models.s3gen import S3GEN_SR, S3Gen
16
  from .models.tokenizers import MTLTokenizer
17
  from .models.voice_encoder import VoiceEncoder
@@ -312,5 +312,12 @@ class ChatterboxTTS:
312
  ref_dict=self.conds.gen,
313
  )
314
  wav = wav.squeeze(0).detach().cpu().numpy()
 
 
 
 
 
 
 
315
  watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
316
  return torch.from_numpy(watermarked_wav).unsqueeze(0)
 
11
 
12
  from .models.t3 import T3
13
  from .models.t3.modules.t3_config import T3ConfigMultilingual
14
+ from .models.s3tokenizer import S3_SR, S3_TOKEN_RATE, drop_invalid_tokens
15
  from .models.s3gen import S3GEN_SR, S3Gen
16
  from .models.tokenizers import MTLTokenizer
17
  from .models.voice_encoder import VoiceEncoder
 
312
  ref_dict=self.conds.gen,
313
  )
314
  wav = wav.squeeze(0).detach().cpu().numpy()
315
+
316
+ # Drop the final speech token's audio: it is emitted just before
317
+ # EOS with degraded attention and decodes to ~40 ms of noise.
318
+ n_tokens = int(speech_tokens.shape[-1])
319
+ st_len = max(1, n_tokens - 1)
320
+ wav = wav[: st_len * (S3GEN_SR // S3_TOKEN_RATE)]
321
+
322
  watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
323
  return torch.from_numpy(watermarked_wav).unsqueeze(0)
requirements.txt CHANGED
@@ -1,9 +1,12 @@
 
 
 
 
1
  gradio
2
  numpy==1.26.0
3
  resampy==0.4.3
4
  librosa==0.10.0
5
  s3tokenizer
6
- torchaudio<2.8
7
 
8
  transformers==4.46.3
9
  diffusers==0.29.0
@@ -11,4 +14,4 @@ omegaconf==2.3.0
11
  resemble-perth==1.0.1
12
  silero-vad==5.1.2
13
  conformer==0.3.2
14
- safetensors
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu128
2
+ torch==2.8.0
3
+ torchaudio==2.8.0
4
+
5
  gradio
6
  numpy==1.26.0
7
  resampy==0.4.3
8
  librosa==0.10.0
9
  s3tokenizer
 
10
 
11
  transformers==4.46.3
12
  diffusers==0.29.0
 
14
  resemble-perth==1.0.1
15
  silero-vad==5.1.2
16
  conformer==0.3.2
17
+ safetensors