adambuttrick commited on
Commit
c099594
·
verified ·
1 Parent(s): ad18509

Upload block.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. block.py +470 -0
block.py ADDED
@@ -0,0 +1,470 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
2
+ # Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
3
+
4
+ # Copyright (c) 2024, Tri Dao.
5
+
6
+ from functools import partial
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.fx
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch import Tensor
14
+
15
+ from .mha import MHA
16
+ from .mlp import Mlp
17
+
18
+ try:
19
+ from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
20
+ except ImportError:
21
+ layer_norm_fn, RMSNorm = None, None
22
+
23
+
24
+ def stochastic_depth(
25
+ input: Tensor, p: float, mode: str, training: bool = True
26
+ ) -> Tensor:
27
+ """
28
+ Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
29
+ <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
30
+ branches of residual architectures.
31
+ Args:
32
+ input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
33
+ being its batch i.e. a batch with ``N`` rows.
34
+ p (float): probability of the input to be zeroed.
35
+ mode (str): ``"batch"`` or ``"row"``.
36
+ ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
37
+ randomly selected rows from the batch.
38
+ training: apply stochastic depth if is ``True``. Default: ``True``
39
+ Returns:
40
+ Tensor[N, ...]: The randomly zeroed tensor.
41
+ """
42
+ if p < 0.0 or p > 1.0:
43
+ raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
44
+ if mode not in ["batch", "row"]:
45
+ raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
46
+ if not training or p == 0.0:
47
+ return input
48
+
49
+ survival_rate = 1.0 - p
50
+ if mode == "row":
51
+ size = [input.shape[0]] + [1] * (input.ndim - 1)
52
+ else:
53
+ size = [1] * input.ndim
54
+ noise = torch.empty(size, dtype=input.dtype, device=input.device)
55
+ noise = noise.bernoulli_(survival_rate)
56
+ if survival_rate > 0.0:
57
+ noise.div_(survival_rate)
58
+ return input * noise
59
+
60
+
61
+ torch.fx.wrap("stochastic_depth")
62
+
63
+
64
+ class StochasticDepth(nn.Module):
65
+ """
66
+ See :func:`stochastic_depth`.
67
+ """
68
+
69
+ def __init__(self, p: float, mode: str) -> None:
70
+ super().__init__()
71
+ self.p = p
72
+ self.mode = mode
73
+
74
+ def forward(self, input: Tensor) -> Tensor:
75
+ return stochastic_depth(input, self.p, self.mode, self.training)
76
+
77
+ def __repr__(self) -> str:
78
+ s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
79
+ return s
80
+
81
+
82
+ class Block(nn.Module):
83
+ def __init__(
84
+ self,
85
+ dim,
86
+ mixer_cls=None,
87
+ mlp_cls=None,
88
+ norm_cls=nn.LayerNorm,
89
+ dropout_cls=nn.Dropout,
90
+ prenorm=True,
91
+ resid_dropout1=0.0,
92
+ resid_dropout2=0.0,
93
+ drop_path1=0.0,
94
+ drop_path2=0.0,
95
+ fused_dropout_add_ln=False,
96
+ return_residual=False,
97
+ residual_in_fp32=False,
98
+ sequence_parallel=False,
99
+ mark_shared_params=False,
100
+ ):
101
+ """
102
+ For prenorm=True, this Block has a slightly different structure compared to a regular
103
+ prenorm Transformer block.
104
+ The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
105
+ [Ref: https://arxiv.org/abs/2002.04745]
106
+ Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
107
+ the hidden_states (output of the MLP) and the residual.
108
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
109
+ The residual needs to be provided (except for the very first block).
110
+
111
+ For prenorm=False, this Block has the same structure as a regular postnorm Transformer
112
+ block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
113
+
114
+ return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
115
+ This is for performance reason: for post-norm architecture, returning the input allows us
116
+ to fuse the backward of nn.Linear with the residual connection.
117
+ """
118
+ super().__init__()
119
+ self.prenorm = prenorm
120
+ self.fused_dropout_add_ln = fused_dropout_add_ln
121
+ self.return_residual = return_residual
122
+ self.residual_in_fp32 = residual_in_fp32
123
+ if self.residual_in_fp32:
124
+ assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
125
+ if mixer_cls is None:
126
+ mixer_cls = partial(MHA, num_heads=dim // 64)
127
+ if mlp_cls is None:
128
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
129
+ self.mixer = mixer_cls(dim)
130
+ self.dropout1 = dropout_cls(resid_dropout1)
131
+ self.drop_path1 = StochasticDepth(drop_path1, mode="row")
132
+ self.norm1 = norm_cls(dim)
133
+ self.mlp = mlp_cls(dim)
134
+ if not isinstance(self.mlp, nn.Identity):
135
+ self.dropout2 = dropout_cls(resid_dropout2)
136
+ self.drop_path2 = StochasticDepth(drop_path2, mode="row")
137
+ self.norm2 = norm_cls(dim)
138
+
139
+ if self.fused_dropout_add_ln:
140
+ assert layer_norm_fn is not None, "Triton is not installed"
141
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
142
+ self.dropout1, nn.Dropout
143
+ )
144
+
145
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
146
+ # then the input to each worker in the tensor parallel group will be different.
147
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
148
+ # For now this is not an issue because we always use sequence_parallel=True during training
149
+ # and only use sequence_parallel=False during inference.
150
+
151
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
152
+ if sequence_parallel:
153
+ for p in self.norm1.parameters():
154
+ p._sequence_parallel = True
155
+ if hasattr(self, "norm2"):
156
+ for p in self.norm2.parameters():
157
+ p._sequence_parallel = True
158
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
159
+ if mark_shared_params:
160
+ for p in self.norm1.parameters():
161
+ p._shared_params = True
162
+ if hasattr(self, "norm2"):
163
+ for p in self.norm2.parameters():
164
+ p._shared_params = True
165
+
166
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
167
+ return self.mixer.allocate_inference_cache(
168
+ batch_size, max_seqlen, dtype=dtype, **kwargs
169
+ )
170
+
171
+ def forward(
172
+ self,
173
+ hidden_states: Tensor,
174
+ residual: Optional[Tensor] = None,
175
+ mixer_subset=None,
176
+ mixer_kwargs=None,
177
+ ):
178
+ r"""Pass the input through the encoder layer.
179
+
180
+ Args:
181
+ hidden_states: the sequence to the encoder layer (required).
182
+ residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
183
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
184
+ before applying the query projection. Useful for e.g., ViT where we only care
185
+ about the CLS token in the last layer.
186
+ """
187
+ if self.prenorm:
188
+ if not self.fused_dropout_add_ln:
189
+ dropped = self.drop_path1(self.dropout1(hidden_states))
190
+ residual = (dropped + residual) if residual is not None else dropped
191
+ hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
192
+ if self.residual_in_fp32:
193
+ residual = residual.to(torch.float32)
194
+ else:
195
+ if self.drop_path1.p == 0 or not self.training:
196
+ rowscale1 = None
197
+ else:
198
+ rowscale1 = self.drop_path1(
199
+ torch.ones(
200
+ hidden_states.shape[:-1],
201
+ device=hidden_states.device,
202
+ dtype=hidden_states.dtype,
203
+ )
204
+ )
205
+ hidden_states, residual = layer_norm_fn(
206
+ hidden_states,
207
+ self.norm1.weight,
208
+ self.norm1.bias,
209
+ residual=residual,
210
+ eps=self.norm1.eps,
211
+ dropout_p=self.dropout1.p if self.training else 0.0,
212
+ rowscale=rowscale1,
213
+ prenorm=True,
214
+ residual_in_fp32=self.residual_in_fp32,
215
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
216
+ )
217
+ if mixer_kwargs is None:
218
+ mixer_kwargs = {}
219
+ if mixer_subset is not None:
220
+ mixer_kwargs["mixer_subset"] = mixer_subset
221
+ hidden_states = self.mixer(hidden_states, **mixer_kwargs)
222
+ if mixer_subset is not None:
223
+ residual = residual[:, mixer_subset]
224
+ if not isinstance(self.mlp, nn.Identity):
225
+ if not self.fused_dropout_add_ln:
226
+ dropped = self.drop_path2(self.dropout2(hidden_states))
227
+ residual = (dropped + residual) if residual is not None else dropped
228
+ hidden_states = self.norm2(
229
+ residual.to(dtype=self.norm2.weight.dtype)
230
+ )
231
+ if self.residual_in_fp32:
232
+ residual = residual.to(torch.float32)
233
+ else:
234
+ if self.drop_path2.p == 0 or not self.training:
235
+ rowscale2 = None
236
+ else:
237
+ rowscale2 = self.drop_path2(
238
+ torch.ones(
239
+ hidden_states.shape[:-1],
240
+ device=hidden_states.device,
241
+ dtype=hidden_states.dtype,
242
+ )
243
+ )
244
+ hidden_states, residual = layer_norm_fn(
245
+ hidden_states,
246
+ self.norm2.weight,
247
+ self.norm2.bias,
248
+ residual=residual,
249
+ eps=self.norm2.eps,
250
+ dropout_p=self.dropout2.p if self.training else 0.0,
251
+ rowscale=rowscale2,
252
+ prenorm=True,
253
+ residual_in_fp32=self.residual_in_fp32,
254
+ is_rms_norm=isinstance(self.norm2, RMSNorm),
255
+ )
256
+ hidden_states = self.mlp(hidden_states)
257
+ return hidden_states, residual
258
+ else:
259
+ assert residual is None
260
+ mixer_out = self.mixer(
261
+ hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
262
+ )
263
+ if self.return_residual: # mixer out is actually a pair here
264
+ mixer_out, hidden_states = mixer_out
265
+ if not self.fused_dropout_add_ln:
266
+ hidden_states = self.norm1(
267
+ (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
268
+ dtype=self.norm1.weight.dtype
269
+ )
270
+ )
271
+ else:
272
+ if self.drop_path1.p == 0 or not self.training:
273
+ rowscale1 = None
274
+ else:
275
+ rowscale1 = self.drop_path1(
276
+ torch.ones(
277
+ mixer_out.shape[:-1],
278
+ device=mixer_out.device,
279
+ dtype=mixer_out.dtype,
280
+ )
281
+ )
282
+ hidden_states = layer_norm_fn(
283
+ mixer_out,
284
+ self.norm1.weight,
285
+ self.norm1.bias,
286
+ residual=hidden_states,
287
+ eps=self.norm1.eps,
288
+ dropout_p=self.dropout1.p if self.training else 0.0,
289
+ rowscale=rowscale1,
290
+ prenorm=False,
291
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
292
+ )
293
+ if not isinstance(self.mlp, nn.Identity):
294
+ mlp_out = self.mlp(hidden_states)
295
+ if self.return_residual: # mlp out is actually a pair here
296
+ mlp_out, hidden_states = mlp_out
297
+ if not self.fused_dropout_add_ln:
298
+ hidden_states = self.norm2(
299
+ (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
300
+ dtype=self.norm2.weight.dtype
301
+ )
302
+ )
303
+ else:
304
+ if self.drop_path2.p == 0 or not self.training:
305
+ rowscale2 = None
306
+ else:
307
+ rowscale2 = self.drop_path2(
308
+ torch.ones(
309
+ mlp_out.shape[:-1],
310
+ device=mlp_out.device,
311
+ dtype=mlp_out.dtype,
312
+ )
313
+ )
314
+ hidden_states = layer_norm_fn(
315
+ mlp_out,
316
+ self.norm2.weight,
317
+ self.norm2.bias,
318
+ residual=hidden_states,
319
+ eps=self.norm2.eps,
320
+ dropout_p=self.dropout2.p if self.training else 0.0,
321
+ rowscale=rowscale2,
322
+ prenorm=False,
323
+ is_rms_norm=isinstance(self.norm2, RMSNorm),
324
+ )
325
+ return hidden_states
326
+
327
+
328
+ class ParallelBlock(nn.Module):
329
+ """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
330
+ and PaLM.
331
+ """
332
+
333
+ def __init__(
334
+ self,
335
+ dim,
336
+ mixer_cls=None,
337
+ mlp_cls=None,
338
+ norm_cls=nn.LayerNorm,
339
+ dropout_cls=nn.Dropout,
340
+ resid_dropout1=0.0,
341
+ resid_dropout2=0.0,
342
+ tied_norm=False,
343
+ fused_dropout_add_ln=False,
344
+ residual_in_fp32=False,
345
+ sequence_parallel=False,
346
+ mark_shared_params=False,
347
+ ):
348
+ """
349
+ This Block has a slightly different structure compared to a regular
350
+ prenorm Transformer block.
351
+ The standard block is: LN -> MHA / MLP -> Dropout -> Add.
352
+ [Ref: https://arxiv.org/abs/2002.04745]
353
+ Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
354
+ the hidden_states (output1 of the MHA / MLP) and the residual.
355
+ This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
356
+ The residual needs to be provided (except for the very first block).
357
+ """
358
+ super().__init__()
359
+ self.tied_norm = tied_norm
360
+ self.fused_dropout_add_ln = fused_dropout_add_ln
361
+ self.residual_in_fp32 = residual_in_fp32
362
+ if mixer_cls is None:
363
+ mixer_cls = partial(MHA, num_heads=dim // 64)
364
+ if mlp_cls is None:
365
+ mlp_cls = partial(Mlp, hidden_features=4 * dim)
366
+ self.mixer = mixer_cls(dim)
367
+ self.dropout1 = dropout_cls(resid_dropout1)
368
+ self.norm1 = norm_cls(dim)
369
+ self.mlp = mlp_cls(dim)
370
+ self.dropout2 = dropout_cls(resid_dropout2)
371
+ if not self.tied_norm:
372
+ self.norm2 = norm_cls(dim)
373
+
374
+ if self.fused_dropout_add_ln:
375
+ assert layer_norm_fn is not None, "Triton is not installed"
376
+ assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
377
+ self.dropout1, nn.Dropout
378
+ )
379
+
380
+ # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
381
+ # then the input to each worker in the tensor parallel group will be different.
382
+ # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
383
+ # For now this is not an issue because we always use sequence_parallel=True during training
384
+ # and only use sequence_parallel=False during inference.
385
+
386
+ # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
387
+ if sequence_parallel:
388
+ for p in self.norm1.parameters():
389
+ p._sequence_parallel = True
390
+ if hasattr(self, "norm2"):
391
+ for p in self.norm2.parameters():
392
+ p._sequence_parallel = True
393
+ # Mark the norm parameters as "shared_params" so that we sync their values at init.
394
+ if mark_shared_params:
395
+ for p in self.norm1.parameters():
396
+ p._shared_params = True
397
+ if hasattr(self, "norm2"):
398
+ for p in self.norm2.parameters():
399
+ p._shared_params = True
400
+
401
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
402
+ return self.mixer.allocate_inference_cache(
403
+ batch_size, max_seqlen, dtype=dtype, **kwargs
404
+ )
405
+
406
+ def forward(
407
+ self,
408
+ hidden_states1: Tensor,
409
+ hidden_states2: Optional[Tensor] = None,
410
+ residual: Optional[Tensor] = None,
411
+ mixer_kwargs=None,
412
+ ):
413
+ r"""Pass the input through the encoder layer.
414
+
415
+ Args:
416
+ hidden_states1: the output of the previous attention (mixer) or embedding layer.
417
+ hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
418
+ residual.
419
+ """
420
+ # TODO: Ideally we should only do the allgather / allreduce once for
421
+ # the Linear to MLP & Attention
422
+ if not self.fused_dropout_add_ln:
423
+ dropped1 = self.dropout1(hidden_states1)
424
+ # For the very 1st block, we only want 1 dropout, not two different dropouts
425
+ if hidden_states2 is not None:
426
+ dropped2 = self.dropout2(hidden_states2)
427
+ residual = (
428
+ (residual + dropped1 + dropped2)
429
+ if residual is not None
430
+ else dropped1 + dropped2
431
+ )
432
+ else:
433
+ residual = (residual + dropped1) if residual is not None else dropped1
434
+ hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
435
+ hidden_states2 = (
436
+ self.norm2(residual.to(dtype=self.norm2.weight.dtype))
437
+ if not self.tied_norm
438
+ else hidden_states1
439
+ )
440
+ if self.residual_in_fp32:
441
+ residual = residual.to(torch.float32)
442
+ else:
443
+ weight2, bias2 = (
444
+ (self.norm2.weight, self.norm2.bias)
445
+ if not self.tied_norm
446
+ else (None, None)
447
+ )
448
+ hidden_states1, *rest, residual = layer_norm_fn(
449
+ hidden_states1,
450
+ self.norm1.weight,
451
+ self.norm1.bias,
452
+ residual=residual,
453
+ x1=hidden_states2,
454
+ weight1=weight2,
455
+ bias1=bias2,
456
+ eps=self.norm1.eps,
457
+ dropout_p=self.dropout1.p if self.training else 0.0,
458
+ prenorm=True,
459
+ residual_in_fp32=self.residual_in_fp32,
460
+ is_rms_norm=isinstance(self.norm1, RMSNorm),
461
+ )
462
+ if self.tied_norm:
463
+ hidden_states2 = hidden_states1
464
+ else:
465
+ (hidden_states2,) = rest
466
+ if mixer_kwargs is None:
467
+ mixer_kwargs = {}
468
+ hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
469
+ hidden_states2 = self.mlp(hidden_states2)
470
+ return hidden_states1, hidden_states2, residual