tedi-resemble committed on
Commit
ea9a032
·
verified ·
1 Parent(s): f707318

Fix ZeroGPU and private model loading

Browse files
chatterbox/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py DELETED
@@ -1,154 +0,0 @@
1
- # Copyright (c) 2025 Resemble AI
2
- # Author: John Meade, Jeremy Hsu
3
- # MIT License
4
- import logging
5
- import torch
6
- from dataclasses import dataclass
7
- from types import MethodType
8
-
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
-
13
@dataclass
class AlignmentAnalysisResult:
    """
    Per-frame verdicts produced by the alignment heuristics.

    Attributes:
        false_start: the frame was detected as part of a noisy beginning chunk
            with potential hallucinations.
        long_tail: the frame was detected as part of a long tail with potential
            hallucinations.
        repetition: the frame was detected as repeating existing text content.
        discontinuity: the alignment position of the frame was too far from
            that of the previous frame.
        complete: inference has reached the end of the text tokens; this stays
            False if inference stops early.
        position: approximate position in the text token sequence. Can be used
            for generating online timestamps.
    """

    false_start: bool
    long_tail: bool
    repetition: bool
    discontinuity: bool
    complete: bool
    position: int
27
-
28
-
29
class AlignmentStreamAnalyzer:
    """
    Online text-speech alignment monitor driven by one attention layer.

    A forward hook on a single self-attention layer captures that layer's
    attention weights at every decoding step. `step()` turns them into a running
    alignment matrix (frames x text tokens) and uses heuristics on it to either
    force or suppress the EOS logit.
    """

    def __init__(self, tfmr, queue, text_tokens_slice, alignment_layer_idx=9, eos_idx=0):
        """
        Some transformer TTS models implicitly solve text-speech alignment in one or more of their self-attention
        activation maps. This module exploits this to perform online integrity checks while streaming.
        A hook is injected into the specified attention layer, and heuristics are used to determine alignment
        position, repetition, etc.

        Args:
            tfmr: model exposing `layers[idx].self_attn`; that sub-module is hooked and patched.
            queue: unused; accepted only so existing callers keep working.
            text_tokens_slice: `(i, j)` bounds of the text tokens along the attention key axis.
            alignment_layer_idx: index into `tfmr.layers` of the layer to observe.
            eos_idx: index of the EOS entry along the last axis of the logits.

        NOTE: currently requires no queues.
        """
        self.text_tokens_slice = (i, j) = text_tokens_slice
        self.eos_idx = eos_idx
        # Running alignment matrix: one row per generated frame, one column per text token.
        self.alignment = torch.zeros(0, j - i)
        self.curr_frame_pos = 0
        self.text_position = 0

        self.started = False
        self.started_at = None

        self.complete = False
        self.completed_at = None

        # Using `output_attentions=True` is incompatible with optimized attention kernels, so
        # using it for all layers slows things down too much. We can apply it to just one layer
        # by intercepting the kwargs and adding a forward hook (credit: jrm)
        self.last_aligned_attn = None

        # Bookkeeping that lets `close()` undo the hook and the `forward` patch.
        self._hook_handle = None
        self._patched_layer = None
        self._patched_forward = None
        self._original_forward = None
        self._had_instance_forward = False
        self._add_attention_spy(tfmr, alignment_layer_idx)

    def _add_attention_spy(self, tfmr, alignment_layer_idx):
        """
        Adds a forward hook to a specific attention layer to collect outputs.
        Using `output_attentions=True` is incompatible with optimized attention kernels, so
        using it for all layers slows things down too much.
        (credit: jrm)

        The layer's `forward` is also wrapped so that `output_attentions=True` is always passed.
        Both changes are recorded so `close()` can revert them.
        """

        def attention_forward_hook(module, inputs, output):
            """
            Stores the head-averaged attention weights of batch item 0 on the analyzer.

            Per the original author's note (see `LlamaAttention.forward`), the output is a 3-tuple:
            `attn_output, attn_weights, past_key_value`, and the weights have shape [B, H, T0, T0] for the
            first call and [B, H, 1, T0+i] for the i-th subsequent call — TODO confirm against the
            installed transformers version.
            """
            step_attention = output[1].cpu()  # (B, H, N, N)
            self.last_aligned_attn = step_attention[0].mean(0)  # (N, N)

        target_layer = tfmr.layers[alignment_layer_idx].self_attn

        # Remember whether `forward` was already an instance attribute (e.g. patched by someone else),
        # so that unpatching restores exactly the previous state.
        self._had_instance_forward = "forward" in vars(target_layer)
        original_forward = target_layer.forward

        def patched_forward(module, *args, **kwargs):
            kwargs['output_attentions'] = True
            return original_forward(*args, **kwargs)

        self._hook_handle = target_layer.register_forward_hook(attention_forward_hook)
        self._patched_forward = MethodType(patched_forward, target_layer)
        self._original_forward = original_forward
        self._patched_layer = target_layer
        target_layer.forward = self._patched_forward

    def close(self):
        """
        Removes the forward hook and restores the attention layer's previous `forward`.

        Safe to call more than once. If something else has re-patched `forward` since this analyzer was
        created, only the hook is removed and the foreign patch is left in place.
        """
        if self._hook_handle is not None:
            self._hook_handle.remove()
            self._hook_handle = None

        layer = self._patched_layer
        if layer is None:
            return
        if vars(layer).get("forward") is self._patched_forward:
            if self._had_instance_forward:
                layer.forward = self._original_forward
            else:
                # Drop the instance attribute so the class-level `forward` is visible again.
                del layer.__dict__["forward"]
        self._patched_layer = None
        self._patched_forward = None
        self._original_forward = None

    def step(self, logits):
        """
        Updates the alignment state from the most recently captured attention and potentially modifies
        the logits to force or suppress an EOS.

        Args:
            logits: tensor whose last axis is indexed by `eos_idx`. It is modified in place when EOS is
                suppressed; a new tensor is returned when EOS is forced.

        Returns:
            The (possibly modified) logits.

        Raises:
            RuntimeError: if the hooked attention layer has not produced any output yet.
        """
        aligned_attn = self.last_aligned_attn  # (N, N)
        if aligned_attn is None:
            raise RuntimeError(
                "no attention captured yet: run the hooked attention layer before calling step()"
            )

        # extract approximate alignment matrix chunk (1 frame at a time after the first chunk)
        i, j = self.text_tokens_slice
        if self.curr_frame_pos == 0:
            # first chunk has conditioning info, text tokens, and BOS token
            A_chunk = aligned_attn[j:, i:j].clone().cpu()  # (T, S)
        else:
            # subsequent chunks have 1 frame due to KV-caching
            A_chunk = aligned_attn[:, i:j].clone().cpu()  # (1, S)

        # TODO: monotonic masking; could have issue b/c spaces are often skipped.
        A_chunk[:, self.curr_frame_pos + 1:] = 0

        self.alignment = torch.cat((self.alignment, A_chunk), dim=0)

        A = self.alignment
        T, S = A.shape

        # update position (kept as a plain int so the state never holds tensors)
        cur_text_posn = int(A_chunk[-1].argmax())
        discontinuity = not (-4 < cur_text_posn - self.text_position < 7)  # NOTE: very lenient!
        if not discontinuity:
            self.text_position = cur_text_posn

        # Hallucinations at the start of speech show up as activations at the bottom of the attention maps!
        # To mitigate this, we just wait until there are no activations far off-diagonal in the last 2 tokens,
        # and there are some strong activations in the first few tokens.
        false_start = (not self.started) and bool(A[-2:, -2:].max() > 0.1 or A[:, :4].max() < 0.5)
        self.started = not false_start
        if self.started and self.started_at is None:
            self.started_at = T

        # Is generation likely complete?
        self.complete = self.complete or self.text_position >= S - 3
        if self.complete and self.completed_at is None:
            self.completed_at = T

        long_tail = False
        repetition = False
        if self.complete:
            after_completion = A[self.completed_at:]

            # NOTE: EOS rarely assigned activations, and second-last token is often punctuation, so use last 3 tokens.
            # Activations for the final token that last too long are likely hallucinations.
            long_tail = bool(after_completion[:, -3:].sum(dim=0).max() >= 10)  # 400ms (original author's note)

            # If there are activations in previous tokens after generation has completed, assume this is a
            # repetition error. The slice is empty when no frames follow completion or when the text span
            # has 5 or fewer tokens; reducing an empty axis would raise, so treat that as "no repetition".
            earlier_tokens = after_completion[:, :-5]
            if earlier_tokens.numel() > 0:
                repetition = bool(earlier_tokens.max(dim=1).values.sum() > 5)

        # If a bad ending is detected, force emit EOS by modifying logits
        # NOTE: this means logits may be inconsistent with latents!
        if long_tail or repetition:
            logger.warning("forcing EOS token, long_tail=%s, repetition=%s", long_tail, repetition)
            # (±2**15 is safe for all dtypes >= 16bit)
            logits = -(2**15) * torch.ones_like(logits)
            logits[..., self.eos_idx] = 2**15
        elif cur_text_posn < S - 3:  # FIXME: arbitrary
            # Suppress EoS to prevent early termination. This must not run after a forced EOS:
            # a repetition moves the current position back to early tokens, and suppressing then
            # would overwrite the forced EOS and leave every logit equal.
            logits[..., self.eos_idx] = -2**15

        self.curr_frame_pos += 1
        return logits