Automatic Speech Recognition
Transformers
Safetensors
DiCoW
speech
whisper
multilingual
speaker-diarization
meeting-transcription
target-speaker-asr
BUT-FIT
custom_code
Instructions to use bohatey/DiCoW_v3_2_SF with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use bohatey/DiCoW_v3_2_SF with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="bohatey/DiCoW_v3_2_SF", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForSpeechSeq2Seq

model = AutoModelForSpeechSeq2Seq.from_pretrained("bohatey/DiCoW_v3_2_SF", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from typing import Optional, Union | |
| import re | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.utils.checkpoint | |
| from torch.nn import CrossEntropyLoss | |
| from transformers import Cache | |
| from transformers.modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput | |
| from transformers.models.whisper.modeling_whisper import ( | |
| WhisperForConditionalGeneration, | |
| shift_tokens_right, | |
| WhisperModel | |
| ) | |
| from transformers.utils import logging | |
| from .config import DiCoWConfig | |
| from .encoder import DiCoWEncoder | |
| from .generation import DiCoWGenerationMixin | |
# NOTE(review): this runs at import time. `logging` here is `transformers.utils.logging`,
# so this presumably raises the verbosity of the whole transformers library to DEBUG for
# any process that imports this module -- confirm that is intended outside development.
logging.set_verbosity_debug()
# Module logger; it is bound to the shared "transformers" logger name rather than a
# module-specific one.
logger = logging.get_logger("transformers")
class SoftLabelCreator(torch.nn.Module):
    """
    Handles label smoothing for timestamps and the dual-loss logic (Upper vs Lower case).

    Timestamp tokens (vocabulary entries that look like ``<|12.34|>``) get a soft
    target: a Gaussian over all timestamp tokens centred on the labelled time.
    Every other token keeps a one-hot target.

    Args:
        tokenizer: Object exposing ``get_vocab()`` returning ``{token_str: token_id}``.
        timestamp_sigma: Standard deviation of the Gaussian, in the same unit as the
            numbers embedded in the timestamp tokens.
    """

    def __init__(self, tokenizer, timestamp_sigma=0.08):
        super().__init__()
        self.tokenizer = tokenizer
        self.timestamp_sigma = timestamp_sigma
        # Sorted timestamp token ids (plain tensor attribute, not a buffer). It stays
        # None when the vocabulary has no timestamp tokens.
        self.sorted_ts_ids = None
        # Pre-compute the Gaussian smoothing matrix (None if there are no timestamps).
        self.register_buffer('ts_smoothing_matrix', self._build_smoothing_matrix())

    def _build_smoothing_matrix(self):
        """Build the (num_timestamps x vocab_width) matrix of soft timestamp targets.

        Row ``i`` is the target distribution for the ``i``-th timestamp token (in
        ascending token-id order). Also sets ``self.sorted_ts_ids``.

        Returns:
            The smoothing matrix, or ``None`` when the vocabulary has no timestamp tokens.
        """
        vocab = self.tokenizer.get_vocab()
        timestamp_pattern = re.compile(r'<\|(\d+\.\d+)\|>')

        # 1. Map token ids to time values.
        id_to_time = {}
        for token_str, token_id in vocab.items():
            match = timestamp_pattern.match(token_str)
            if match:
                id_to_time[token_id] = float(match.group(1))
        if not id_to_time:
            return None

        # Sorted ids allow searchsorted-based row lookup later on.
        sorted_ids = sorted(id_to_time)
        self.sorted_ts_ids = torch.tensor(sorted_ids)
        times = torch.tensor([id_to_time[i] for i in sorted_ids])

        # 2. Pairwise Gaussian weights between timestamps, normalised per row.
        diff_sq = (times.unsqueeze(1) - times.unsqueeze(0)) ** 2
        weights = torch.exp(-diff_sq / (2 * self.timestamp_sigma ** 2))
        weights = weights / weights.sum(dim=1, keepdim=True)

        # 3. Scatter the weights into vocabulary-wide rows. The width is derived from
        #    the largest token id as well as the entry count, so a vocabulary with
        #    non-contiguous ids cannot index out of bounds.
        vocab_width = max(len(vocab), max(vocab.values()) + 1)
        smoothing_matrix = torch.zeros(len(sorted_ids), vocab_width)
        smoothing_matrix[:, self.sorted_ts_ids] = weights
        return smoothing_matrix

    def _get_soft_distribution(self, labels, vocab_size):
        """Internal helper to convert hard labels -> soft timestamp labels.

        Args:
            labels: 1-D tensor of token ids; negative ids (e.g. -100) are clamped to 0
                here and must be masked out by the caller.
            vocab_size: Width of the returned distributions (last dim of the logits).

        Returns:
            Float tensor of shape ``(len(labels), vocab_size)``.
        """
        device = labels.device
        # Start with one-hot targets (clamp -100 to 0 temporarily).
        labels_clamped = labels.clamp(min=0)
        soft_labels = F.one_hot(labels_clamped, num_classes=vocab_size).float()

        smoothing_matrix = getattr(self, 'ts_smoothing_matrix', None)
        if smoothing_matrix is None:
            return soft_labels

        sorted_ts_ids = self.sorted_ts_ids.to(device)
        is_timestamp = torch.isin(labels, sorted_ts_ids)
        if is_timestamp.any():
            smoothing_matrix = smoothing_matrix.to(device)
            # The tokenizer vocabulary and the logits can differ in width (e.g. a
            # padded embedding matrix); align the rows to the logits.
            width = smoothing_matrix.size(1)
            if width < vocab_size:
                smoothing_matrix = F.pad(smoothing_matrix, (0, vocab_size - width))
            elif width > vocab_size:
                smoothing_matrix = smoothing_matrix[:, :vocab_size]
            ts_indices = torch.searchsorted(sorted_ts_ids, labels[is_timestamp])
            soft_labels[is_timestamp] = smoothing_matrix[ts_indices]
        return soft_labels

    def compute_loss(self, logits, labels, upp_labels):
        """
        Computes the enhanced SOT loss:
        1. Generates soft labels (timestamp smoothed) for both 'labels' and 'upp_labels'.
        2. Computes cross entropy against the soft targets for both.
        3. Takes the minimum loss per token (case invariance).
        4. Applies padding mask (positions where ``labels == -100``).

        Args:
            logits: Tensor whose last dimension is the vocabulary.
            labels: Token ids with one entry per logit position; -100 marks padding.
            upp_labels: Alternative ids of the same shape, or ``None`` to use ``labels``
                for both branches.

        Returns:
            Scalar tensor: summed per-token minimum loss divided by the number of
            non-padding tokens (at least 1).
        """
        vocab_size = logits.size(-1)
        device = logits.device

        # Ensure labels are on the correct device.
        labels = labels.to(device)
        if upp_labels is not None:
            upp_labels = upp_labels.to(device)

        # Flatten inputs; reshape (unlike view) also accepts non-contiguous logits.
        flat_logits = logits.reshape(-1, vocab_size)
        flat_labels = labels.reshape(-1)

        # Soft targets for both label variants.
        soft_lower = self._get_soft_distribution(flat_labels, vocab_size)
        if upp_labels is not None:
            soft_upper = self._get_soft_distribution(upp_labels.reshape(-1), vocab_size)
        else:
            # Fallback if no upper labels are provided.
            soft_upper = soft_lower

        # Cross entropy with soft targets = -sum(target * log_prob).
        loss_fct = CrossEntropyLoss(reduction='none')
        loss_lower = loss_fct(flat_logits, soft_lower)
        loss_upper = loss_fct(flat_logits, soft_upper)

        # Soft-target CE has no ignore_index, so mask padding explicitly.
        mask = (flat_labels != -100).float()
        loss_lower = loss_lower * mask
        loss_upper = loss_upper * mask

        # Per-token minimum (case invariance), averaged over non-padding tokens.
        combined_min = torch.min(loss_lower, loss_upper)
        return combined_min.sum() / mask.sum().clamp(min=1)
class DiCoW(WhisperModel):
    """Whisper encoder-decoder model whose encoder is replaced by a ``DiCoWEncoder``."""

    def __init__(self, config: DiCoWConfig):
        super().__init__(config)
        # Swap the encoder built by the parent constructor for the DiCoW encoder.
        self.encoder = DiCoWEncoder(config)
        self.post_init()

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        stno_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Cache] = None,
        decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
        decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        enrollments=None
    ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
        """Run the encoder (unless ``encoder_outputs`` is given) and then the decoder.

        ``stno_mask`` and ``enrollments`` are forwarded only to the encoder. Flags left
        as ``None`` fall back to the corresponding values on ``self.config``.

        Returns:
            A ``Seq2SeqModelOutput`` when ``return_dict`` is true, otherwise the decoder
            outputs concatenated with the encoder outputs.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if use_cache is None:
            use_cache = cfg.use_cache
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # Output-control flags shared by the encoder and decoder calls.
        shared_flags = {
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }

        if encoder_outputs is None:
            masked_features = self._mask_input_features(input_features, attention_mask=attention_mask)
            encoder_outputs = self.encoder(
                masked_features,
                head_mask=head_mask,
                stno_mask=stno_mask,
                enrollments=enrollments,
                **shared_flags,
            )

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            position_ids=decoder_position_ids,
            use_cache=use_cache,
            cache_position=cache_position,
            **shared_flags,
        )

        if return_dict:
            return Seq2SeqModelOutput(
                last_hidden_state=decoder_outputs.last_hidden_state,
                past_key_values=decoder_outputs.past_key_values,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            )
        return decoder_outputs + encoder_outputs
class DiCoWForConditionalGeneration(DiCoWGenerationMixin, WhisperForConditionalGeneration):
    """Whisper conditional-generation head on top of the ``DiCoW`` backbone.

    The training loss is a decoder loss that takes the per-token minimum over
    ``labels`` and ``upp_labels``, optionally mixed with a CTC loss from the encoder
    when ``config.ctc_weight > 0``.
    """

    config_class = DiCoWConfig

    def __init__(self, config: DiCoWConfig):
        super().__init__(config)
        self.model = DiCoW(config)
        # State slots initialised empty; presumably filled by the generation mixin
        # or the training setup -- TODO confirm.
        self.encoder_logits = None
        self.tokenizer = None
        self.stno_mask = None
        self.stno_mask_seek = None
        # Created by set_tokenizer(); while None, forward() uses the hard-label loss.
        self.soft_label_creator = None
        self.post_init()

    def set_tokenizer(self, tokenizer):
        """Attach the tokenizer and build the soft-label loss helper from it."""
        self.tokenizer = tokenizer
        self.soft_label_creator = SoftLabelCreator(tokenizer)

    def get_enc_logits(self, hidden_states):
        """Project encoder hidden states to logits with the encoder's own ``lm_head``."""
        encoder = self.model.get_encoder()
        hidden_states = encoder.possibly_update_last_hidden_states(hidden_states)
        logits = encoder.lm_head(hidden_states)
        return logits

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        stno_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Cache] = None,
        decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
        decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        upp_labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        forced_decoder_ids: Optional[torch.LongTensor] = None,
        enrollments=None,
    ) -> Union[tuple[torch.Tensor], Seq2SeqLMOutput]:
        """Run the backbone, project to vocabulary logits and optionally compute the loss.

        Args:
            labels: Target token ids. When given without ``decoder_input_ids`` /
                ``decoder_inputs_embeds``, decoder inputs are derived by shifting the
                labels right.
            upp_labels: Alternative target ids; the decoder loss takes the per-token
                minimum over ``labels`` and ``upp_labels``.
            forced_decoder_ids: Accepted but not used by this method.
            return_dict: Falls back to ``config.use_return_dict`` when ``None``.

        The remaining arguments are forwarded to the ``DiCoW`` backbone unchanged.

        Returns:
            A ``Seq2SeqLMOutput`` when ``return_dict`` is true; otherwise a tuple
            ``(loss?, logits, *remaining backbone outputs)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        # Always request a ModelOutput from the backbone: the code below reads
        # `.last_hidden_state` / `.encoder_last_hidden_state`, which a plain tuple
        # (the return_dict=False result) does not have. The tuple form is produced
        # at the end of this method instead.
        outputs = self.model(
            input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            stno_mask=stno_mask,
            enrollments=enrollments,
        )
        dec_lm_logits = self.proj_out(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            if self.soft_label_creator is not None:
                # Soft label creation, flattening and min-loss logic live in the helper.
                dec_loss = self.soft_label_creator.compute_loss(dec_lm_logits, labels, upp_labels)
            else:
                # Fallback to the hard-label implementation when no tokenizer was set.
                loss_fct = CrossEntropyLoss(reduction='none')
                labels = labels.to(dec_lm_logits.device)
                flat_logits = dec_lm_logits.view(-1, self.config.vocab_size)
                dec_loss1 = loss_fct(flat_logits, labels.reshape(-1))
                if upp_labels is not None:
                    upp_labels = upp_labels.to(dec_lm_logits.device)
                    dec_loss2 = loss_fct(flat_logits, upp_labels.reshape(-1))
                    # Per-token minimum over the two label variants, then mean.
                    dec_loss = torch.hstack((dec_loss1[..., None], dec_loss2[..., None])).min(dim=-1).values.mean()
                else:
                    dec_loss = dec_loss1.mean()

            if self.config.ctc_weight > 0.0:
                # This branch reads self.tokenizer.prefix_tokens, so set_tokenizer()
                # must have been called beforehand.
                enc_lm_logits = self.get_enc_logits(outputs.encoder_last_hidden_state)
                # Prepare CTC labels: drop a leading prefix token when every row
                # starts with it, and exclude EOS.
                enc_labels = labels.clone().to(dec_lm_logits.device)
                for token in self.tokenizer.prefix_tokens:
                    if (enc_labels[:, 0] == token).all():
                        enc_labels = enc_labels[:, 1:]
                enc_labels[enc_labels == self.config.eos_token_id] = -100
                ctc_loss = self.get_encoder().get_loss(enc_lm_logits, enc_labels)
                loss = (1 - self.config.ctc_weight) * dec_loss + self.config.ctc_weight * ctc_loss
            else:
                loss = dec_loss

        if not return_dict:
            # Slicing a ModelOutput yields a tuple of its non-None fields.
            output = (dec_lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=dec_lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def _get_feat_extract_output_lengths(self, attention_mask: torch.LongTensor) -> torch.LongTensor:
        """Return the encoder's feature-extractor output lengths divided by 4, rounded up.

        NOTE(review): true division followed by ``ceil()`` presumably yields a
        floating-point tensor despite the ``LongTensor`` annotation -- confirm callers
        do not rely on an integer dtype.
        """
        return (self.model.get_encoder()._get_feat_extract_output_lengths(attention_mask) / 4).ceil()