zqhuang committed on
Commit
3836184
·
verified ·
1 Parent(s): aec6738

Upload UltravoxPipeline

Browse files
Files changed (1) hide show
  1. ultravox_model.py +0 -53
ultravox_model.py CHANGED
@@ -174,59 +174,6 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
174
  yield i_b, audio_index
175
  audio_index += 1
176
 
177
- def _select_embedings(
178
- self,
179
- inputs_embeds: torch.Tensor,
180
- start_idx: torch.Tensor,
181
- lengths: torch.Tensor,
182
- ) -> torch.Tensor:
183
- """
184
- Select a contiguous slice per batch starting at `start_idx[b]` with
185
- length `lengths[b]`, returned in a compact, front-aligned tensor.
186
- Any positions in the output that correspond to padding are zeroed out.
187
-
188
- Supports both 3D tensors (B, T, D) and 2D tensors (B, T).
189
- """
190
- B = inputs_embeds.size(0)
191
- T = inputs_embeds.size(1)
192
- max_length = int(lengths.max().item())
193
- if max_length == 0:
194
- # Return an empty slice with correct rank
195
- if inputs_embeds.dim() == 3:
196
- return inputs_embeds.new_zeros((B, 0, inputs_embeds.size(2)))
197
- else:
198
- return inputs_embeds.new_zeros((B, 0), dtype=inputs_embeds.dtype)
199
-
200
- # --- Create indices to gather ---
201
- idx = torch.arange(
202
- max_length, device=inputs_embeds.device, dtype=start_idx.dtype
203
- ) # (Lmax,)
204
- pos = start_idx.unsqueeze(1) + idx.unsqueeze(0) # (B, Lmax)
205
- # Clamp to prevent out-of-bounds gather, we will mask the invalid values later
206
- pos = pos.clamp_(0, T - 1)
207
-
208
- # --- Create mask for valid output positions ---
209
- mask = idx.unsqueeze(0) < lengths.unsqueeze(1) # (B, Lmax)
210
-
211
- # --- Gather and mask ---
212
- if inputs_embeds.dim() == 3:
213
- D = inputs_embeds.size(2)
214
- gathered = inputs_embeds.gather(
215
- 1, pos.unsqueeze(-1).expand(B, max_length, D)
216
- )
217
- # Zero out the padded values
218
- gathered = gathered * mask.unsqueeze(-1)
219
- return gathered
220
- elif inputs_embeds.dim() == 2:
221
- gathered = inputs_embeds.gather(1, pos)
222
- # Zero out the padded values
223
- gathered = gathered * mask
224
- return gathered
225
- else:
226
- raise ValueError(
227
- f"_select_embedings expects 2D or 3D tensors, got {inputs_embeds.dim()}D"
228
- )
229
-
230
  def _decoder_layers(self):
231
  """Return decoder blocks across architectures (LLaMA/GLM/etc.)."""
232
  lm = self.language_model
 
174
  yield i_b, audio_index
175
  audio_index += 1
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  def _decoder_layers(self):
178
  """Return decoder blocks across architectures (LLaMA/GLM/etc.)."""
179
  lm = self.language_model