Spaces:
Running on Zero
Running on Zero
Fix ZeroGPU and private model loading
Browse files
chatterbox/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py
DELETED
|
@@ -1,154 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2025 Resemble AI
|
| 2 |
-
# Author: John Meade, Jeremy Hsu
|
| 3 |
-
# MIT License
|
| 4 |
-
import logging
|
| 5 |
-
import torch
|
| 6 |
-
from dataclasses import dataclass
|
| 7 |
-
from types import MethodType
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
logger = logging.getLogger(__name__)
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
@dataclass
class AlignmentAnalysisResult:
    """Per-frame outcome of the alignment heuristics used during streaming inference."""

    # True when the frame looks like part of a noisy opening chunk (possible hallucination).
    false_start: bool
    # True when the frame looks like part of an over-long tail (possible hallucination).
    long_tail: bool
    # True when the frame appears to repeat text content that was already spoken.
    repetition: bool
    # True when the frame's alignment position jumped too far from the previous frame's.
    discontinuity: bool
    # True once inference has reached the end of the text tokens; stays False if inference stops early.
    complete: bool
    # Approximate index into the text token sequence; usable for generating online timestamps.
    position: int
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
class AlignmentStreamAnalyzer:
    """
    Tracks text/speech alignment from one self-attention layer while tokens are generated, and
    rewrites the logits to suppress or force the EOS token based on simple heuristics.
    """

    def __init__(self, tfmr, queue, text_tokens_slice, alignment_layer_idx=9, eos_idx=0):
        """
        Some transformer TTS models implicitly solve text-speech alignment in one or more of their self-attention
        activation maps. This module exploits this to perform online integrity checks while streaming.
        A hook is injected into the specified attention layer, and heuristics are used to determine alignment
        position, repetition, etc.

        NOTE: currently requires no queues.

        Args:
            tfmr: model exposing `layers[idx].self_attn`; that attention module gets hooked and patched.
            queue: accepted for interface compatibility; not used.
            text_tokens_slice: `(i, j)` bounds of the text tokens along the attention key axis.
            alignment_layer_idx: index of the layer whose attention map is observed.
            eos_idx: index of the EOS entry along the last axis of the logits.
        """
        self.text_tokens_slice = (i, j) = text_tokens_slice
        self.eos_idx = eos_idx
        # Accumulated alignment matrix: one row per generated frame, one column per text token.
        self.alignment = torch.zeros(0, j - i)
        self.curr_frame_pos = 0
        self.text_position = 0

        self.started = False
        self.started_at = None

        self.complete = False
        self.completed_at = None

        # Bookkeeping needed by `close()` to undo the hook and the forward patch.
        self._hook_handle = None
        self._patched_layer = None
        self._patched_forward = None
        self._original_forward = None
        self._had_instance_forward = False

        # Using `output_attentions=True` is incompatible with optimized attention kernels, so
        # using it for all layers slows things down too much. We can apply it to just one layer
        # by intercepting the kwargs and adding a forward hook (credit: jrm)
        self.last_aligned_attn = None
        self._add_attention_spy(tfmr, alignment_layer_idx)

    def _add_attention_spy(self, tfmr, alignment_layer_idx):
        """
        Adds a forward hook to a specific attention layer to collect outputs.
        Using `output_attentions=True` is incompatible with optimized attention kernels, so
        using it for all layers slows things down too much.
        (credit: jrm)

        The layer's `forward` is also replaced by a wrapper that forces `output_attentions=True`.
        Both changes are recorded so that `close()` can revert them.
        """

        def attention_forward_hook(module, input, output):
            """
            Store the head-averaged attention weights of the first batch element.

            `output[1]` is taken to be the attention weights (the layer is presumably a
            `LlamaAttention`-style module returning `attn_output, attn_weights, past_key_value`
            -- confirm against the model in use).
            """
            step_attention = output[1].cpu()  # presumably (B, H, N_query, N_key)
            self.last_aligned_attn = step_attention[0].mean(0)  # (N_query, N_key)

        target_layer = tfmr.layers[alignment_layer_idx].self_attn
        self._hook_handle = target_layer.register_forward_hook(attention_forward_hook)

        # Remember whether `forward` was already an instance attribute so `close()` can restore it exactly.
        self._had_instance_forward = "forward" in vars(target_layer)
        original_forward = target_layer.forward

        def patched_forward(_layer, *args, **kwargs):
            # `_layer` is the bound module; `original_forward` is already bound, so it is not forwarded.
            kwargs['output_attentions'] = True
            return original_forward(*args, **kwargs)

        self._original_forward = original_forward
        self._patched_forward = MethodType(patched_forward, target_layer)
        self._patched_layer = target_layer
        target_layer.forward = self._patched_forward

    def close(self):
        """
        Remove the forward hook and restore the attention layer's original `forward`.

        Safe to call more than once. The `forward` attribute is only restored if it is still the
        wrapper installed by this analyzer, so a patch applied later by someone else is left alone.
        """
        if self._hook_handle is not None:
            self._hook_handle.remove()
            self._hook_handle = None

        layer = self._patched_layer
        if layer is not None:
            if vars(layer).get("forward") is self._patched_forward:
                if self._had_instance_forward:
                    layer.forward = self._original_forward
                else:
                    # The original was the class-level method; dropping the override re-exposes it.
                    del layer.forward
            self._patched_layer = None
            self._patched_forward = None
            self._original_forward = None

    def step(self, logits):
        """
        Update the alignment state from the most recently captured attention map and return the
        logits, possibly modified to suppress or force the EOS token.

        EOS suppression writes into `logits` in place; forcing EOS returns a new tensor.
        The hooked attention layer must have run at least once before the first call.
        """
        # extract approximate alignment matrix chunk (1 frame at a time after the first chunk)
        aligned_attn = self.last_aligned_attn  # (N_query, N_key)
        i, j = self.text_tokens_slice
        if self.curr_frame_pos == 0:
            # first chunk has conditioning info, text tokens, and BOS token
            A_chunk = aligned_attn[j:, i:j].clone().cpu()  # (T, S)
        else:
            # subsequent chunks have 1 frame due to KV-caching
            A_chunk = aligned_attn[:, i:j].clone().cpu()  # (1, S)

        # TODO: monotonic masking; could have issue b/c spaces are often skipped.
        A_chunk[:, self.curr_frame_pos + 1:] = 0

        self.alignment = torch.cat((self.alignment, A_chunk), dim=0)

        A = self.alignment
        T, S = A.shape

        # update position (kept as a plain int so the derived flags are plain bools)
        cur_text_posn = int(A_chunk[-1].argmax())
        discontinuity = not (-4 < cur_text_posn - self.text_position < 7)  # NOTE: very lenient!
        if not discontinuity:
            self.text_position = cur_text_posn

        # Hallucinations at the start of speech show up as activations at the bottom of the attention maps!
        # To mitigate this, we just wait until there are no activations far off-diagonal in the last 2 tokens,
        # and there are some strong activations in the first few tokens.
        false_start = (not self.started) and bool(A[-2:, -2:].max() > 0.1 or A[:, :4].max() < 0.5)
        self.started = not false_start
        if self.started and self.started_at is None:
            self.started_at = T

        # Is generation likely complete?
        self.complete = bool(self.complete or self.text_position >= S - 3)
        if self.complete and self.completed_at is None:
            self.completed_at = T

        # NOTE: EOS rarely assigned activations, and second-last token is often punctuation, so use last 3 tokens.
        # Activations for the final token that last too long are likely hallucinations.
        long_tail = self.complete and bool(A[self.completed_at:, -3:].sum(dim=0).max() >= 10)  # 400ms

        # If there are activations in previous tokens after generation has completed, assume this is a repetition error.
        # With 5 or fewer text tokens there are no "previous" columns, and reducing over an empty
        # dimension would raise, so the check is skipped in that case.
        repetition = False
        if self.complete:
            earlier_tokens = A[self.completed_at:, :-5]
            if earlier_tokens.shape[1] > 0:
                repetition = bool(earlier_tokens.max(dim=1).values.sum() > 5)

        # If a bad ending is detected, force emit EOS by modifying logits
        # NOTE: this means logits may be inconsistent with latents!
        if long_tail or repetition:
            logger.warning(f"forcing EOS token, {long_tail=}, {repetition=}")
            # (±2**15 is safe for all dtypes >= 16bit)
            logits = -(2**15) * torch.ones_like(logits)
            logits[..., self.eos_idx] = 2**15

        # Suppress EoS to prevent early termination
        if cur_text_posn < S - 3:  # FIXME: arbitrary
            logits[..., self.eos_idx] = -2**15

        self.curr_frame_pos += 1
        return logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|