import torch
from transformers import PreTrainedModel

from .modeling_biome import BioME
from .configuration_biome import BioMEConfig


class BioMEModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the ``BioME`` network.

    The wrapped network is stored on ``self.model``; ``forward`` delegates to
    it and repackages the first two of its six return values into a dict.
    """

    # Configuration class associated with this model type.
    config_class = BioMEConfig

    def __init__(self, config: BioMEConfig):
        """Build the wrapped ``BioME`` network from ``config``.

        Args:
            config: Configuration handed both to ``PreTrainedModel`` and to
                the ``BioME`` constructor.
        """
        super().__init__(config)
        self.model = BioME(config)
        # Standard PreTrainedModel hook invoked once all submodules exist.
        self.post_init()

    def forward(
        self,
        wavs: torch.Tensor,
        start_pos: int = 0,
        padding_mask: torch.Tensor = None,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ):
        """Run the wrapped network and return its two primary outputs.

        Every argument is passed through to ``self.model`` unchanged.

        Args:
            wavs: Input tensor, passed positionally to the wrapped network.
                (Presumably raw waveforms -- confirm against ``BioME``.)
            start_pos: Forwarded as ``start_pos``.
            padding_mask: Forwarded as ``padding_mask``; defaults to ``None``.
            fbank_mean: Forwarded as ``fbank_mean``. (Presumably a filterbank
                normalisation statistic -- confirm against ``BioME``.)
            fbank_std: Forwarded as ``fbank_std``. (Same caveat as above.)

        Returns:
            A dict with keys ``"last_hidden_state"`` and ``"hidden_states"``
            holding the first and second values returned by the wrapped
            network, respectively.
        """
        passthrough_kwargs = {
            "start_pos": start_pos,
            "padding_mask": padding_mask,
            "fbank_mean": fbank_mean,
            "fbank_std": fbank_std,
        }

        # The wrapped network is expected to return exactly six values; only
        # the first two are exposed, the remaining four are discarded.
        final_state, per_layer_states, _, _, _, _ = self.model(
            wavs, **passthrough_kwargs
        )

        return {
            "last_hidden_state": final_state,
            "hidden_states": per_layer_states,
        }