Automatic Speech Recognition
Transformers
Safetensors
DiCoW
speech
whisper
multilingual
speaker-diarization
meeting-transcription
BUT-FIT
custom_code
Instructions to use BUT-FIT/DiCoW_v3_2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use BUT-FIT/DiCoW_v3_2 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="BUT-FIT/DiCoW_v3_2", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForSpeechSeq2Seq

model = AutoModelForSpeechSeq2Seq.from_pretrained("BUT-FIT/DiCoW_v3_2", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| from torch import nn | |
class MultiHeadCoAttention(nn.Module):
    """Multi-head co-attention: one attention map applied to two value streams.

    A single (time x time) attention matrix is computed per head from
    ``query`` and ``key`` and then used to aggregate both ``multi_value``
    (feature size ``multi_dim``) and ``single_value`` (feature size
    ``single_dim``) along the time axis.

    Args:
        multi_dim: Feature size of the multi-channel value stream (D').
        single_dim: Feature size of the query/key and single value stream (D).
        num_heads: Number of attention heads; must divide both dimensions.
    """

    def __init__(self, multi_dim, single_dim, num_heads):
        assert multi_dim % num_heads == 0, 'multi_dim must be divisible by num_heads'
        assert single_dim % num_heads == 0, 'single_dim must be divisible by num_heads'
        super().__init__()
        self.q_proj = nn.Linear(single_dim, single_dim)
        self.k_proj = nn.Linear(single_dim, single_dim)
        self.multi_v_proj = nn.Linear(multi_dim, multi_dim)  # D'
        self.single_v_proj = nn.Linear(single_dim, single_dim)  # D
        self.multi_out_proj = nn.Linear(multi_dim, multi_dim)  # D'
        self.single_out_proj = nn.Linear(single_dim, single_dim)  # D
        self.multi_dim = multi_dim
        self.single_dim = single_dim
        self.num_heads = num_heads

    def _split_heads(self, x, dim):
        """Split the last axis (size ``dim``) into heads, stacked on a new axis 1."""
        return torch.stack(torch.split(x, dim // self.num_heads, dim=-1), dim=1)

    def forward(self, query, key, multi_value, single_value):
        """Apply co-attention over the time axis.

        Args:
            query: Tensor of shape (T, B, C, D) with D == ``single_dim``.
            key: Tensor of the same shape as ``query``.
            multi_value: Tensor of shape (T, B, ch, D') with D' == ``multi_dim``.
            single_value: Tensor of shape (T, B, C, D).

        Returns:
            Tuple ``(multi_result, single_result)`` with the same shapes as
            ``multi_value`` and ``single_value`` respectively.
        """
        query = query.transpose(0, 1)  # (B, T, C, D)
        key = key.transpose(0, 1)  # (B, T, C, D)
        multi_value = multi_value.permute(1, 2, 0, 3)  # (B, ch, T, D')
        single_value = single_value.permute(1, 2, 0, 3)  # (B, C, T, D)

        q = self._split_heads(self.q_proj(query), self.single_dim)  # (B, h, T, C, D/h)
        k = self._split_heads(self.k_proj(key), self.single_dim)  # (B, h, T, C, D/h)
        multi_v = self._split_heads(self.multi_v_proj(multi_value), self.multi_dim)  # (B, h, ch, T, D'/h)
        single_v = self._split_heads(self.single_v_proj(single_value), self.single_dim)  # (B, h, C, T, D/h)

        # Fold the channel axis into the feature axis so attention is computed
        # between time frames only.
        q = q.flatten(-2)  # (B, h, T, C*D/h)
        k = k.flatten(-2)  # (B, h, T, C*D/h)

        # Use a Python scalar for the scale: dividing by a float32 tensor would
        # promote fp16/bf16 scores to float32 and make the matmuls below fail
        # with a dtype mismatch against the half-precision values.
        scale = q.shape[-1] ** 0.5
        sim_mat = torch.matmul(q, k.transpose(-2, -1)) / scale  # (B, h, T, T)
        # Extra axis lets the same map broadcast over the channel axis of the values.
        att_mat = nn.functional.softmax(sim_mat, dim=-1).unsqueeze(2)  # (B, h, 1, T, T)

        # co-attention
        multi_result = torch.matmul(att_mat, multi_v)  # (B, h, ch, T, D'/h)
        single_result = torch.matmul(att_mat, single_v)  # (B, h, C, T, D/h)

        # Move time first and merge the heads back into the feature axis.
        multi_result = multi_result.permute(3, 0, 2, 1, 4).flatten(-2)  # (T, B, ch, D')
        single_result = single_result.permute(3, 0, 2, 1, 4).flatten(-2)  # (T, B, C, D)

        multi_result = self.multi_out_proj(multi_result)
        single_result = self.single_out_proj(single_result)
        return multi_result, single_result
class CoAttention(nn.Module):
    """Co-attention block mixing a channel-averaged stream with per-channel streams.

    The input is projected twice: once after averaging over the channel axis
    (the "single" stream, size ``single_dim``) and once per channel (the
    "multi" stream, size ``multi_dim``). Both pass through
    ``MultiHeadCoAttention``; the single stream additionally passes through a
    standard multi-head self-attention over time. The two streams are then
    concatenated and projected back to ``embed_dim``.

    Args:
        embed_dim: Feature size of the input and output (F).
        single_dim: Feature size of the channel-averaged stream (D).
        multi_dim: Feature size of the per-channel stream (D').
        n_heads: Number of heads for both attention modules.
        attn_dropout: Dropout passed to ``nn.MultiheadAttention``.
        init_mult: Factor applied to the output projection weights by
            ``scale_weights``.
    """

    def __init__(self, embed_dim=768, single_dim=256, multi_dim=64, n_heads=8, attn_dropout=0.,
                 init_mult=1e-2):
        super().__init__()
        self.init_mult = init_mult

        # Input projections (construction order kept so parameter init is unchanged).
        self.in_single_proj = nn.Linear(embed_dim, single_dim)  # single_dim == D
        self.in_single_ln = nn.LayerNorm(single_dim)
        self.in_multi_proj = nn.Linear(embed_dim, multi_dim)  # multi_dim == D'
        self.in_multi_ln = nn.LayerNorm(multi_dim)

        # Co-attention and its output norms.
        self.mca = MultiHeadCoAttention(multi_dim, single_dim, n_heads)
        self.mca_multi_out_ln = nn.LayerNorm(multi_dim)
        self.mca_single_out_ln = nn.LayerNorm(single_dim)

        # Self-attention over frames; nn.MultiheadAttention expects (seq, batch, feature) by default.
        self.cross_frame_mha = nn.MultiheadAttention(single_dim, n_heads, dropout=attn_dropout)
        self.mha_ln = nn.LayerNorm(single_dim)

        # Output projection back to the embedding size.
        self.cat_proj = nn.Linear(single_dim + multi_dim, embed_dim)
        self.miso = False

    def scale_weights(self):
        """Multiply the output projection bias by zero and its weights by ``init_mult``."""
        out_proj = self.cat_proj
        out_proj.bias.data.mul_(0.)
        out_proj.weight.data.mul_(self.init_mult)

    def forward(self, x):
        """Run the block on ``x`` of shape (T, B, ch, F) and return the same shape."""
        _, _, chans, _ = x.shape

        # Channel-averaged stream, kept 4-D with a singleton channel axis.
        pooled = x.mean(dim=2)  # (T, B, F)
        single_x = self.in_single_ln(self.in_single_proj(pooled)).unsqueeze(dim=-2)  # (T, B, 1, D)
        # Per-channel stream.
        multi_x = self.in_multi_ln(self.in_multi_proj(x))  # (T, B, ch, D')

        # MCA: queries/keys come from the single stream, values from both streams.
        multi_mca, single_mca = self.mca(single_x, single_x, multi_x, single_x)
        multi_x = self.mca_multi_out_ln(multi_x + multi_mca)  # (T, B, ch, D')
        single_x = self.mca_single_out_ln(single_x + single_mca).squeeze(-2)  # (T, B, D)

        # MHA over frames with a residual connection.
        single_mha, _ = self.cross_frame_mha(single_x, single_x, single_x, need_weights=False)  # (T, B, D)
        single_x = self.mha_ln(single_mha + single_x)

        # Join representations: copy the single stream to every channel and concatenate.
        single_x_tile = single_x.unsqueeze(-2).repeat(1, 1, chans, 1)  # (T, B, ch, D)
        cat_x = torch.cat([single_x_tile, multi_x], dim=-1)  # (T, B, ch, D + D')
        return self.cat_proj(cat_x)  # (T, B, ch, F)