Hguimaraes committed on
Commit 8d3f9a4 · verified
1 Parent(s): 9804c97

Upload model

Files changed (7)
  1. README.md +199 -0
  2. biome_model.py +35 -0
  3. biome_modules.py +246 -0
  4. config.json +36 -0
  5. configuration_biome.py +62 -0
  6. model.safetensors +3 -0
  7. modeling_biome.py +245 -0
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
biome_model.py ADDED
@@ -0,0 +1,35 @@
+ import torch
+ from transformers import PreTrainedModel
+
+ from .modeling_biome import BioME
+ from .configuration_biome import BioMEConfig
+
+
+ class BioMEModel(PreTrainedModel):
+     config_class = BioMEConfig
+
+     def __init__(self, config: BioMEConfig):
+         super().__init__(config)
+         self.model = BioME(config)
+         self.post_init()
+
+     def forward(
+         self,
+         wavs: torch.Tensor,
+         start_pos: int = 0,
+         padding_mask: torch.Tensor = None,
+         fbank_mean: float = 15.41663,
+         fbank_std: float = 6.55582,
+     ):
+         output, hidden_states, _, _, _, _ = self.model(
+             wavs,
+             start_pos=start_pos,
+             padding_mask=padding_mask,
+             fbank_mean=fbank_mean,
+             fbank_std=fbank_std,
+         )
+
+         return {
+             "last_hidden_state": output,
+             "hidden_states": hidden_states,
+         }
@@ -0,0 +1,246 @@
+ """
+ Our Transformer-based model for the AudioSet dataset.
+ The model is heavily inspired by the Llama 3 model.
+ Reference: https://github.com/meta-llama/llama3/blob/main/llama/model.py
+ """
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional, Tuple
+
+ from .configuration_biome import BioMEConfig
+
+
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     t = torch.arange(end, device=freqs.device, dtype=torch.float32)
+     freqs = torch.outer(t, freqs)
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+     return freqs_cis
+
+
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+     ndim = x.ndim
+     assert 0 <= 1 < ndim
+     assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+     return freqs_cis.view(*shape)
+
+
+ def apply_rotary_emb(
+     xq: torch.Tensor,
+     xk: torch.Tensor,
+     freqs_cis: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+     return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+ class GroupedQueryAttention(nn.Module):
+     """
+     A multi-head grouped-query self-attention implementation.
+     Paper: 'GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints' (https://arxiv.org/pdf/2305.13245)
+     Code heavily inspired by:
+     - https://github.com/meta-llama/llama3/blob/main/llama/model.py
+     - https://docs.pytorch.org/torchtune/0.4/_modules/torchtune/modules/attention.html
+
+     Args:
+         dim (int): Input and output embedding dimension. Must be divisible
+             by num_q_heads. Default: 512
+         num_q_heads (int): Number of query heads. Default: 16
+         num_kv_heads (int): Number of key/value heads, each shared by a group
+             of query heads. Default: 4
+         dropout (float): Dropout probability applied inside the attention
+             during training. Default: 0.0
+         bias (bool): Use bias in the q/k/v and output projections.
+             Default: True
+         device (torch.device, optional): Device for parameters
+         dtype (torch.dtype, optional): Data type for parameters
+
+     Shape:
+         - Input x: (B, L, dim)
+         - Query heads: (B, L, num_q_heads, dim // num_q_heads)
+         - Key/value heads: (B, L, num_kv_heads, dim // num_q_heads)
+         - Output: (B, L, dim)
+         where B is batch size and L is sequence length
+     """
+
+     def __init__(
+         self,
+         dim: int = 512,
+         num_q_heads: int = 16,
+         num_kv_heads: int = 4,
+         dropout: float = 0.0,
+         bias: bool = True,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+     ) -> None:
+         super().__init__()
+         self.dim = dim
+         self.num_q_heads = num_q_heads
+         self.num_kv_heads = num_kv_heads
+         self.dropout = dropout
+         self.bias = bias
+         factory_kwargs = {"device": device, "dtype": dtype}
+
+         assert dim % num_q_heads == 0, "Embedding dim is not divisible by num_q_heads"
+         self.dim_per_head = dim // num_q_heads
+
+         self.q_proj = nn.Linear(self.dim, num_q_heads * self.dim_per_head, bias=bias, **factory_kwargs)
+         self.k_proj = nn.Linear(self.dim, num_kv_heads * self.dim_per_head, bias=bias, **factory_kwargs)
+         self.v_proj = nn.Linear(self.dim, num_kv_heads * self.dim_per_head, bias=bias, **factory_kwargs)
+         self.out_proj = nn.Linear(num_q_heads * self.dim_per_head, self.dim, bias=bias, **factory_kwargs)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         start_pos: int,
+         freqs_cis: torch.Tensor,
+         attn_mask: Optional[torch.Tensor] = None,
+         is_causal: bool = False,
+     ) -> torch.Tensor:
+         """
+         Args:
+             x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim)
+             start_pos (int): Start position for rotary embeddings (unused here; the caller already slices freqs_cis)
+             freqs_cis (torch.Tensor): Complex rotary frequencies of shape (seq_len, dim_per_head // 2)
+             attn_mask (torch.Tensor): Attention mask
+             is_causal (bool): If True, applies a causal mask to prevent attending to future positions.
+
+         Returns:
+             torch.Tensor: Output tensor of shape (batch_size, seq_len, dim)
+         """
+         bsz, seqlen, _ = x.shape
+
+         # Step 1: Apply projections
+         xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+         # Step 2: Split the heads before the scaled dot-product attention
+         xq = xq.view(bsz, seqlen, self.num_q_heads, self.dim_per_head)
+         xk = xk.view(bsz, seqlen, self.num_kv_heads, self.dim_per_head)
+         xv = xv.view(bsz, seqlen, self.num_kv_heads, self.dim_per_head)
+
+         # Step 3: Apply rotary embeddings
+         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+         # Step 4: Apply scaled dot-product attention
+         # Note: torch sdpa expects (batch_size, num_heads, seq_len, dim_per_head)
+         attn_output = (
+             F.scaled_dot_product_attention(
+                 xq.transpose(1, 2),
+                 xk.transpose(1, 2),
+                 xv.transpose(1, 2),
+                 attn_mask=attn_mask,
+                 dropout_p=self.dropout if self.training else 0.0,
+                 is_causal=is_causal,
+                 enable_gqa=True,
+             )
+             .transpose(1, 2)
+             .flatten(-2)  # (B, L, nheads, dim_per_head) -> (B, L, dim)
+         )
+
+         return self.out_proj(attn_output)
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
+
+
+ class FiLM(nn.Module):
+     """
+     A Feature-wise Linear Modulation Layer from
+     'FiLM: Visual Reasoning with a General Conditioning Layer'
+     """
+     def __init__(self, d_model: int, context_dim: int):
+         super().__init__()
+         self.d_model = d_model
+         self.context_dim = context_dim
+
+         self.shared_modulator = nn.Linear(context_dim, 2 * d_model)
+
+     def forward(self, x, ctx):
+         """
+         Arguments
+         ----------
+         x: torch.Tensor
+             Activations / Tensor in the Transformer of shape (B, T, d_model)
+         ctx: torch.Tensor
+             Side-channel information of shape (B, F). The projected parameters
+             are reshaped to (B, 1, 2 * d_model), so gamma and beta broadcast
+             over the sequence dimension T; a (B, T, F) context with T > 1 does not fit this reshape.
+         """
+         params = self.shared_modulator(ctx)
+         params = params.view(params.size(0), 1, -1)
+         gammas, betas = params.chunk(2, dim=-1)
+
+         return (gammas * x) + betas
+
+
+ class TransformerFFN(nn.Module):
+     def __init__(self, dim, hidden_dim, bias: bool = False):
+         super().__init__()
+         self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
+         self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
+         self.w3 = nn.Linear(dim, hidden_dim, bias=bias)
+
+     def forward(self, x):
+         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+ class TransformerEncoderLayer(nn.Module):
+     def __init__(self, config: BioMEConfig):
+         super().__init__()
+         self.use_context = config.use_context
+         if self.use_context:
+             self.film = FiLM(
+                 d_model=config.hidden_size, context_dim=config.ctx_hidden_size
+             )
+             self.film_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
+             self.film_norm_ctx = RMSNorm(config.ctx_hidden_size, eps=config.norm_eps)
+
+         self.attention = GroupedQueryAttention(
+             dim=config.hidden_size,
+             num_q_heads=config.num_query_heads,
+             num_kv_heads=config.num_kv_heads,
+             dropout=config.dropout,
+             bias=config.bias,
+         )
+
+         self.feed_forward = TransformerFFN(
+             dim=config.hidden_size,
+             hidden_dim=config.ffn_hidden_size,
+         )
+
+         self.attention_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
+         self.ffn_norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         start_pos: int,
+         freqs_cis: torch.Tensor,
+         ctx: torch.Tensor = None,
+         padding_mask: torch.Tensor = None,
+     ):
+         if padding_mask is not None:
+             x[padding_mask] = 0
+
+         h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis)
+         if self.use_context:
+             h = self.film(self.film_norm(h), self.film_norm_ctx(ctx))
+         out = h + self.feed_forward(self.ffn_norm(h))
+         return out
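
Review note on the rotary embedding helpers in biome_modules.py: `precompute_freqs_cis` and `apply_rotary_emb` rotate consecutive feature pairs by position-dependent angles, so the query/key dot product depends only on the relative offset between positions. The following is a minimal pure-Python sketch of that property, not the committed implementation. The helper names `rope_freqs`, `apply_rope`, and `dot` are hypothetical, and plain lists stand in for tensors.

```python
import cmath


def rope_freqs(head_dim: int, seq_len: int, theta: float = 10000.0):
    """Return seq_len rows of head_dim // 2 unit complex numbers exp(i * t * f_k)."""
    # Same frequency schedule as precompute_freqs_cis: 1 / theta^(k / head_dim), k even.
    inv_freq = [1.0 / (theta ** (k / head_dim)) for k in range(0, head_dim, 2)]
    return [[cmath.rect(1.0, t * f) for f in inv_freq] for t in range(seq_len)]


def apply_rope(vec, freqs_row):
    """Rotate consecutive (even, odd) pairs of vec by the matching complex factor."""
    out = []
    for k, rot in enumerate(freqs_row):
        z = complex(vec[2 * k], vec[2 * k + 1]) * rot
        out.extend([z.real, z.imag])
    return out


def dot(a, b):
    """Plain real-valued dot product."""
    return sum(x * y for x, y in zip(a, b))


freqs = rope_freqs(head_dim=4, seq_len=8)
q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.25]

# Same relative offset (3) at two different absolute positions.
score_a = dot(apply_rope(q, freqs[5]), apply_rope(k, freqs[2]))
score_b = dot(apply_rope(q, freqs[6]), apply_rope(k, freqs[3]))
```

Both scores use an offset of 3 between the query and key positions, so they agree up to floating-point error, and the rotation leaves vector norms unchanged.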
config.json ADDED
@@ -0,0 +1,36 @@
+ {
+   "architectures": [
+     "BioMEModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_biome.BioMEConfig",
+     "AutoModel": "biome_model.BioMEModel"
+   },
+   "bias": false,
+   "context_type": "mss",
+   "ctx_hidden_size": 258,
+   "dropout": 0.1,
+   "dtype": "float32",
+   "embed_dim": 384,
+   "ffn_hidden_size": 1344,
+   "frame_length": 25,
+   "frame_shift": 10,
+   "hidden_size": 384,
+   "input_patch_size": 16,
+   "max_cache_size": 10,
+   "max_seq_len": 1024,
+   "model_type": "biome",
+   "mss_n_fft1": 256,
+   "mss_n_fft2": 256,
+   "mss_win_shift": 128,
+   "mss_win_size": 256,
+   "n_mels": 128,
+   "norm_eps": 1e-05,
+   "num_kv_heads": 4,
+   "num_layers": 12,
+   "num_query_heads": 8,
+   "rope_theta": 10000.0,
+   "sample_rate": 16000,
+   "transformers_version": "5.0.0",
+   "use_context": true
+ }
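
Review note on config.json: these values fix several derived sizes that the Python files rely on implicitly. The sketch below recomputes them in plain Python as a sanity check. It assumes Kaldi-style framing without edge padding (the `snip_edges=True` default of `torchaudio.compliance.kaldi.fbank`) and an unpadded patch convolution whose stride equals its kernel size, as in `BioME.patch_embedding`. The helper names are hypothetical.

```python
# Subset of config.json copied verbatim; the helper names below are hypothetical.
cfg = {
    "sample_rate": 16000, "frame_length": 25, "frame_shift": 10,
    "n_mels": 128, "input_patch_size": 16, "hidden_size": 384,
    "num_query_heads": 8, "num_kv_heads": 4, "max_seq_len": 1024,
    "mss_n_fft1": 256, "mss_n_fft2": 256, "ctx_hidden_size": 258,
}


def num_fbank_frames(n_samples: int) -> int:
    """Frame count for Kaldi-style framing with snip_edges=True (assumed)."""
    win = cfg["sample_rate"] * cfg["frame_length"] // 1000  # 25 ms window in samples
    hop = cfg["sample_rate"] * cfg["frame_shift"] // 1000   # 10 ms shift in samples
    return 0 if n_samples < win else 1 + (n_samples - win) // hop


def num_tokens(n_frames: int) -> int:
    """Token count after non-overlapping square patches over (frames, mel bins)."""
    p = cfg["input_patch_size"]
    return (n_frames // p) * (cfg["n_mels"] // p)


# Per-head width used for the rotary embeddings.
head_dim = cfg["hidden_size"] // cfg["num_query_heads"]
# Number of query heads that share one key/value head.
queries_per_kv_head = cfg["num_query_heads"] // cfg["num_kv_heads"]
# get_modulation_spectrum concatenates two averages of one-sided spectra.
ctx_dim = (cfg["mss_n_fft2"] // 2 + 1) + (cfg["mss_n_fft1"] // 2 + 1)

frames_10s = num_fbank_frames(10 * cfg["sample_rate"])
tokens_10s = num_tokens(frames_10s)
```

Under these assumptions a 10 s clip yields 998 frames and 496 tokens, which fits within `max_seq_len` (1024), and `ctx_dim` matches `ctx_hidden_size` (258).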
configuration_biome.py ADDED
@@ -0,0 +1,62 @@
+ from transformers import PreTrainedConfig
+
+ class BioMEConfig(PreTrainedConfig):
+     model_type = "biome"
+
+     def __init__(
+         self,
+         num_layers: int = 12,
+         num_query_heads: int = 12,
+         num_kv_heads: int = 4,
+         embed_dim: int = 512,
+         hidden_size: int = 384,
+         ffn_hidden_size: int = 1344,
+         dropout: float = 0.1,
+         sample_rate: int = 16000,
+         frame_length: int = 25,
+         frame_shift: int = 10,
+         n_mels: int = 128,
+         input_patch_size: int = 16,
+         norm_eps: float = 1e-5,
+         max_seq_len: int = 1024,
+         rope_theta: float = 10000.0,
+         bias: bool = False,
+         use_context: bool = True,
+         context_type: str = "mss",
+         max_cache_size: int = 10,
+         ctx_hidden_size: int = 258,
+         mss_n_fft1: int = 256,
+         mss_n_fft2: int = 256,
+         mss_win_size: int = 256,
+         mss_win_shift: int = 128,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         # Transformer Parameters
+         self.num_layers = num_layers
+         self.num_query_heads = num_query_heads
+         self.num_kv_heads = num_kv_heads
+         self.embed_dim = embed_dim
+         self.hidden_size = hidden_size
+         self.ffn_hidden_size = ffn_hidden_size
+         self.dropout = dropout
+         self.sample_rate = sample_rate
+         self.frame_length = frame_length
+         self.frame_shift = frame_shift
+         self.n_mels = n_mels
+         self.input_patch_size = input_patch_size
+         self.norm_eps = norm_eps
+         self.max_seq_len = max_seq_len
+         self.rope_theta = rope_theta
+         self.bias = bias
+
+         # Context Parameters
+         self.use_context = use_context
+         self.context_type = context_type
+         self.max_cache_size = max_cache_size
+         self.ctx_hidden_size = ctx_hidden_size
+         self.mss_n_fft1 = mss_n_fft1
+         self.mss_n_fft2 = mss_n_fft2
+         self.mss_win_size = mss_win_size
+         self.mss_win_shift = mss_win_shift
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c3586650791ff61bdfe48be1f7c7564c97bb052dbc53b3275d12135fb576146a
+ size 105578728
modeling_biome.py ADDED
@@ -0,0 +1,245 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torchaudio.compliance.kaldi as ta_kaldi
+
+ from .biome_modules import RMSNorm
+ from .configuration_biome import BioMEConfig
+ from .biome_modules import precompute_freqs_cis
+ from .biome_modules import TransformerEncoderLayer
+
+
+ class BioME(nn.Module):
+     def __init__(self, cfg: BioMEConfig):
+         super().__init__()
+         self.cfg = cfg
+
+         self.n_layers = cfg.num_layers
+         self.patch_embedding = nn.Conv2d(
+             1,
+             cfg.embed_dim,
+             kernel_size=cfg.input_patch_size,
+             stride=cfg.input_patch_size,
+             bias=False,
+         )
+
+         self.dropout_input = nn.Dropout(cfg.dropout)
+
+         self.post_extract_proj = (
+             nn.Linear(cfg.embed_dim, cfg.hidden_size)
+             if cfg.embed_dim != cfg.hidden_size
+             else nn.Identity()
+         )
+
+         self.layers = torch.nn.ModuleList()
+         for _ in range(cfg.num_layers):
+             self.layers.append(TransformerEncoderLayer(cfg))
+
+         self.feature_norm = RMSNorm(cfg.embed_dim, eps=cfg.norm_eps)
+         self.freqs_cis = precompute_freqs_cis(
+             cfg.hidden_size // cfg.num_query_heads,
+             cfg.max_seq_len * 2,
+             cfg.rope_theta,
+         )
+
+         self.modulation_cache = {}
+
+         # Weights initialization
+         deep_norm_beta = math.pow(8 * cfg.num_layers, -1 / 4)
+         for i in range(cfg.num_layers):
+             nn.init.xavier_normal_(self.layers[i].attention.k_proj.weight, gain=1)
+             nn.init.xavier_normal_(
+                 self.layers[i].attention.v_proj.weight, gain=deep_norm_beta
+             )
+             nn.init.xavier_normal_(self.layers[i].attention.q_proj.weight, gain=1)
+             nn.init.xavier_normal_(
+                 self.layers[i].attention.out_proj.weight, gain=deep_norm_beta
+             )
+             nn.init.xavier_normal_(
+                 self.layers[i].feed_forward.w1.weight, gain=deep_norm_beta
+             )
+             nn.init.xavier_normal_(
+                 self.layers[i].feed_forward.w2.weight, gain=deep_norm_beta
+             )
+             nn.init.xavier_normal_(
+                 self.layers[i].feed_forward.w3.weight, gain=deep_norm_beta
+             )
+
+     def forward_padding_mask(
+         self,
+         features: torch.Tensor,
+         padding_mask: torch.Tensor,
+     ) -> torch.Tensor:
+         extra = padding_mask.size(1) % features.size(1)
+         if extra > 0:
+             padding_mask = padding_mask[:, :-extra]
+         padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
+         padding_mask = padding_mask.all(-1)
+         return padding_mask
+
+     def forward(
+         self,
+         wavs: torch.Tensor,
+         start_pos: int,
+         padding_mask: torch.Tensor = None,
+         fbank_mean: float = 15.41663,
+         fbank_std: float = 6.55582,
+         apply_mask: bool = False,
+     ):
+         # 1. Get input features
+         fbank = self.wav_to_fbank(wavs, fbank_mean=fbank_mean, fbank_std=fbank_std)
+         ctx = self.get_modulation_spectrum(wavs)  # Side-channel (MSAB) features
+
+         # 2. Patchify the input
+         features = self.feature_patchfy(fbank)
+
+         patch_padding_mask = None
+         if padding_mask is not None:
+             padding_mask = self.forward_padding_mask(features, padding_mask)
+             patch_padding_mask = padding_mask.clone()
+
+         ids_restore, kept_mask = None, None
+         if apply_mask:
+             B, T, F = features.shape
+             u = torch.rand(B, T, device=features.device)
+             to_mask = (u < self.cfg.mlm_mask_prob)  # NOTE: BioMEConfig defines no default for mlm_mask_prob
+
+             kept_mask = ~to_mask
+             features = features.masked_fill(~kept_mask.unsqueeze(-1), 0.0)
+
+         features = self.post_extract_proj(features)
+
+         _, seqlen, _ = features.shape
+
+         # 3. Apply positional encoding
+         if self.freqs_cis.device.type == "meta":
+             self.freqs_cis = self._get_freqs_cis()
+
+         self.freqs_cis = self.freqs_cis.to(features.device)
+         freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+
+         # 4. Apply transformer layers
+         x = self.dropout_input(features)
+
+         layer_results = []
+         for layer in self.layers:
+             x = layer(
+                 x, start_pos=0, freqs_cis=freqs_cis, ctx=ctx, padding_mask=padding_mask
+             )
+             layer_results.append(x)
+
+         # 5. Return the outputs (no post-processing is applied)
+         return x, layer_results, padding_mask, ids_restore, kept_mask, patch_padding_mask
+
+     def wav_to_fbank(
+         self,
+         source: torch.Tensor,
+         fbank_mean: float = -4.268,
+         fbank_std: float = 4.569,
+     ):
+         fbanks = []
+         for waveform in source:
+             waveform = waveform.unsqueeze(0) * 2**15
+             fbank = ta_kaldi.fbank(
+                 waveform,
+                 num_mel_bins=self.cfg.n_mels,
+                 sample_frequency=self.cfg.sample_rate,
+                 frame_length=self.cfg.frame_length,
+                 frame_shift=self.cfg.frame_shift,
+                 use_energy=False,
+                 window_type="hanning",
+                 dither=0.0,
+             )
+             fbanks.append(fbank)
+         fbank = torch.stack(fbanks, dim=0)
+         fbank = (fbank - fbank_mean) / (2 * fbank_std)
+         return fbank
+
+     def feature_patchfy(self, rep: torch.Tensor) -> torch.Tensor:
+         """
+         Patchify the feature representation.
+         """
+         rep = rep.unsqueeze(1)
+         features = self.patch_embedding(rep)
+         features = features.reshape(features.shape[0], features.shape[1], -1)
+         features = features.transpose(1, 2)
+         features = self.feature_norm(features)
+
+         return features
+
+     def _get_freqs_cis(self):
+         return precompute_freqs_cis(
+             self.cfg.hidden_size // self.cfg.num_query_heads,
+             self.cfg.max_seq_len * 2,
+             self.cfg.rope_theta,
+         )
+
+     @torch.no_grad()
+     def normalize_fft(self, spec_data, window, n_samples, n_fft, fs):
+         # Normalizations
+         win_rms = torch.sqrt(window.pow(2.0).sum() / n_samples)
+
+         # Compute the power spectrogram
+         spec_data /= win_rms
+         spec_data = spec_data.abs().pow(
+             2.0
+         )  # same as X_pwr = abs(np.multiply(Xt, np.conj(Xt)))
+         spec_data *= 1.0 / n_fft**2  # make it orthonormal
+
+         if n_fft % 2 != 0:
+             n_freqs = (n_fft + 1) / 2
+             spec_data[
+                 :, 1:, :
+             ] *= 2  # double all frequency components except DC component
+         else:
+             n_freqs = (n_fft / 2) + 1
+             spec_data[
+                 :, 1:-1, :
+             ] *= 2  # double all frequency components except DC and fs/2 components
+
+         f_delta = fs / n_fft
+         spec_data = torch.divide(spec_data, f_delta)  # scale by frequency delta
+
+         return f_delta, spec_data
+
+     @torch.no_grad()
+     def get_modulation_spectrum(self, wavs: torch.Tensor):
+         # batch dimension (unused) and number of samples per waveform
+         _, n_samples = wavs.shape
+
+         # Step 1: compute STFT spectrogram
+         window = torch.hamming_window(
+             self.cfg.mss_win_size, periodic=True, device=wavs.device
+         )
+         spec_data = torch.stft(
+             wavs,
+             n_fft=self.cfg.mss_n_fft1,
+             win_length=self.cfg.mss_win_size,
+             hop_length=self.cfg.mss_win_shift,
+             window=window,
+             return_complex=True,
+             onesided=True,
+         )  # torch.stft pads the signal, whereas the old code removed the last window if necessary
+         _, _, n_windows = spec_data.shape
+
+         # Normalizations
+         _, spec_data = self.normalize_fft(
+             spec_data, window, n_samples, self.cfg.mss_n_fft1, self.cfg.sample_rate
+         )
+
+         # Step 2: Modulation Features
+         # modulation sampling frequency
+         fs_mod = 1 / (self.cfg.mss_win_shift / self.cfg.sample_rate)
+
+         n_fft2 = self.cfg.mss_n_fft2
+         if n_fft2 is None:
+             n_fft2 = n_windows
+
+         # The AM analysis is performed on the amplitude derived from the power spectrogram
+         window = torch.hamming_window(n_windows, periodic=True, device=wavs.device)
+         spec_data = torch.multiply(spec_data, window)
+         mod_psd = torch.fft.rfft(spec_data, n=n_fft2, dim=2)
+
+         _, mod_psd = self.normalize_fft(mod_psd, window, n_samples, n_fft2, fs_mod)
+
+         return torch.cat([mod_psd.mean(dim=1), mod_psd.mean(dim=2)], dim=1)
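
Review note on `forward_padding_mask` in modeling_biome.py: the method reduces a per-sample padding mask to one flag per token by trimming the remainder, grouping the samples evenly, and marking a token as padding only when its whole group is padding. The sketch below restates that logic for a single sequence in plain Python. The function name `downsample_padding_mask` is hypothetical, and a list of booleans stands in for the tensor.

```python
from typing import List


def downsample_padding_mask(mask: List[bool], n_tokens: int) -> List[bool]:
    """Collapse a per-sample padding mask to one flag per token."""
    extra = len(mask) % n_tokens
    if extra > 0:
        mask = mask[:-extra]  # drop the remainder so the mask splits evenly
    group = len(mask) // n_tokens
    # A token counts as padding only if every sample in its group is padding.
    return [all(mask[i * group:(i + 1) * group]) for i in range(n_tokens)]


# Ten samples, the last five padded, reduced to three tokens.
sample_mask = [False] * 5 + [True] * 5
token_mask = downsample_padding_mask(sample_mask, n_tokens=3)
```

Here one trailing sample is dropped, the middle group mixes real and padded samples and so stays unmasked, and only the last token is flagged as padding.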