
-------------------------- DeepSpeed Flops Profiler --------------------------
Profile Summary at step 2:
Notations:
data parallel size (dp_size), model parallel size(mp_size),
number of parameters (params), number of multiply-accumulate operations(MACs),
number of floating-point operations (flops), floating-point operations per second (FLOPS),
fwd latency (forward propagation latency), bwd latency (backward propagation latency),
step (weights update latency), iter latency (sum of fwd, bwd and step latency)

world size:                                                             16      
data parallel size:                                                     16      
model parallel size:                                                    1       
batch size per GPU:                                                     32      
params per GPU:                                                         1.03 B  
params of model = params per GPU * mp_size:                             1.03 B  
fwd MACs per GPU:                                                       6.66 TMACs
fwd flops per GPU:                                                      13.32 T 
fwd flops of model = fwd flops per GPU * mp_size:                       13.32 T 
fwd latency:                                                            88.54 ms
fwd FLOPS per GPU = fwd flops per GPU / fwd latency:                    150.44 TFLOPS
bwd latency:                                                            206.11 ms
bwd FLOPS per GPU = 2 * fwd flops per GPU / bwd latency:                129.24 TFLOPS
fwd+bwd FLOPS per GPU = 3 * fwd flops per GPU / (fwd+bwd latency):      135.61 TFLOPS
step latency:                                                           57.92 ms
iter latency:                                                           352.56 ms
FLOPS per GPU = 3 * fwd flops per GPU / iter latency:                   113.33 TFLOPS
samples/second:                                                         1452.22 
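The summary metrics above are derived from a few raw measurements. The sketch below recomputes them from the printed inputs using the formulas given in the log's own labels. It assumes one MAC counts as two flops, which is consistent with 6.66 TMACs giving 13.32 T flops. The profiler works from unrounded values, so the last digit can differ from the printed figures.

```python
# Recompute the "Profile Summary" metrics from the rounded values printed in the log.
world_size = 16            # "world size" (all ranks are data parallel here)
batch_per_gpu = 32         # "batch size per GPU"
fwd_macs_per_gpu = 6.66e12 # "fwd MACs per GPU"
fwd_latency_s = 88.54e-3
bwd_latency_s = 206.11e-3
step_latency_s = 57.92e-3

# One multiply-accumulate = two floating-point operations.
fwd_flops = 2 * fwd_macs_per_gpu

# The log's formulas: backward is taken as 2x forward flops, fwd+bwd as 3x.
fwd_tflops = fwd_flops / fwd_latency_s / 1e12
bwd_tflops = 2 * fwd_flops / bwd_latency_s / 1e12
fwd_bwd_tflops = 3 * fwd_flops / (fwd_latency_s + bwd_latency_s) / 1e12

# Iteration latency is the sum of forward, backward and optimizer step.
iter_latency_s = fwd_latency_s + bwd_latency_s + step_latency_s
iter_tflops = 3 * fwd_flops / iter_latency_s / 1e12

# Global throughput: every data-parallel rank processes its own batch per iteration.
samples_per_s = batch_per_gpu * world_size / iter_latency_s

print(f"fwd FLOPS per GPU:     {fwd_tflops:8.2f} TFLOPS")
print(f"bwd FLOPS per GPU:     {bwd_tflops:8.2f} TFLOPS")
print(f"fwd+bwd FLOPS per GPU: {fwd_bwd_tflops:8.2f} TFLOPS")
print(f"iter latency:          {iter_latency_s * 1e3:8.2f} ms")
print(f"FLOPS per GPU:         {iter_tflops:8.2f} TFLOPS")
print(f"samples/second:        {samples_per_s:8.2f}")
```

Step latency (57.92 ms) does no counted work but is included in the iteration time. This explains why the overall 113.33 TFLOPS is lower than the 135.61 TFLOPS fwd+bwd figure.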

----------------------------- Aggregated Profile per GPU -----------------------------
Top 1 modules in terms of params, MACs or fwd latency at different model depths:
depth 0:
    params      - {'DiT': '1.03 B'}
    MACs        - {'DiT': '6.66 TMACs'}
    fwd latency - {'DiT': '88.36 ms'}
depth 1:
    params      - {'ModuleList': '1.01 B'}
    MACs        - {'ModuleList': '6.54 TMACs'}
    fwd latency - {'ModuleList': '83.5 ms'}
depth 2:
    params      - {'DiTLayer': '1.01 B'}
    MACs        - {'DiTLayer': '6.54 TMACs'}
    fwd latency - {'DiTLayer': '83.5 ms'}
depth 3:
    params      - {'GemmaMLP': '503.32 M'}
    MACs        - {'GemmaMLP': '4.12 TMACs'}
    fwd latency - {'DiTSelfAttention': '43.56 ms'}

------------------------------ Detailed Profile per GPU ------------------------------
Each module profile is listed after its name in the following order: 
params, percentage of total params, MACs, percentage of total MACs, fwd latency, percentage of total fwd latency, fwd FLOPS

Note: 1. A module can have torch.nn.module or torch.nn.functional to compute logits (e.g. CrossEntropyLoss). They are not counted as submodules, thus not to be printed out. However they make up the difference between a parent's MACs (or latency) and the sum of its submodules'.
2. Number of floating-point operations is a theoretical estimation, thus FLOPS computed using that could be larger than the maximum system throughput.
3. The fwd latency listed in the top module's profile is directly captured at the module forward function in PyTorch, thus it's less than the fwd latency shown above which is captured in DeepSpeed.
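Note 1 can be checked against the layer 0 figures in the detailed profile that follows. Each Linear's MACs equal rows × in_features × out_features. Dividing q_proj's 68.72 GMACs by 2048 × 4096 gives 8192 token rows per GPU. Dividing the AdaLayerNormZero Linear's 805.31 MMACs by 2048 × 12288 gives 32 rows, one per sample. DiTSelfAttention reports 240.52 GMACs, but its six projections sum to only 206.16 GMACs. The 34.36 GMACs gap is the functional attention that Note 1 says is not printed.

The sketch below reproduces that gap under the following assumptions, none of which the log states:
- head_dim is 128, inferred from q_norm's 128 params.
- There are 32 query heads (4096 / 128).
- Each sample has 256 query tokens.
- Each sample has 512 keys: 256 image tokens plus 256 text tokens, from text_k_proj.

```python
# Hypothetical reconstruction of the layer-0 MAC counts in the detailed profile.
# Layer widths come from the printed in_features/out_features; the row counts,
# head count and sequence lengths are inferred assumptions, not values in the log.

def linear_macs(rows: int, in_features: int, out_features: int) -> int:
    """MACs for y = x @ W.T over `rows` input rows (bias adds are not counted)."""
    return rows * in_features * out_features

tokens = 8192   # inferred: 68.72 GMACs / (2048 * 4096)
batch = 32      # "batch size per GPU"
hidden, q_dim, kv_dim, ffn, ada_out = 2048, 4096, 1024, 8192, 12288

q = linear_macs(tokens, hidden, q_dim)
kv = linear_macs(tokens, hidden, kv_dim)  # k_proj, v_proj, text_k_proj, text_v_proj each
o = linear_macs(tokens, q_dim, hidden)
proj_total = q + 4 * kv + o

# Assumed attention shape: 32 heads of dim 128, 256 queries and 512 keys per sample.
heads, head_dim, q_len, k_len = 32, 128, 256, 512
# QK^T and attn @ V each cost batch * heads * q_len * k_len * head_dim MACs.
sdpa = 2 * batch * heads * q_len * k_len * head_dim

attn_total = proj_total + sdpa
mlp_total = 3 * linear_macs(tokens, hidden, ffn)  # gate_proj, up_proj, down_proj
# Inferred: the AdaLayerNormZero Linear sees one row per sample (805.31 MMACs / (2048 * 12288) = 32).
ada_total = linear_macs(batch, hidden, ada_out)

layer_total = ada_total + attn_total + mlp_total

for name, macs in [("q_proj", q), ("k_proj", kv), ("attention gap", sdpa),
                   ("self_attn", attn_total), ("mlp", mlp_total),
                   ("adaLN linear", ada_total), ("DiTLayer", layer_total)]:
    print(f"{name:>14}: {macs / 1e9:9.2f} GMACs")
```

If these assumptions hold, the totals match the printed 240.52 GMACs (self_attn), 412.32 GMACs (mlp) and 653.64 GMACs (DiTLayer). Ten such layers give the 6.54 TMACs reported for the ModuleList.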

DiT(
  1.03 B = 100% Params, 6.66 TMACs = 100% MACs, 88.36 ms = 100% latency, 150.74 TFLOPS
  (layers): ModuleList(
    (0): DiTLayer(
      100.68 M = 9.73% Params, 653.64 GMACs = 9.82% MACs, 8.43 ms = 9.54% latency, 155.13 TFLOPS
      (input_layernorm): AdaLayerNormZero(
        25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 901.94 us = 1.02% latency, 1.79 TFLOPS
        (silu): SiLU(0 = 0% Params, 0 MACs = 0% MACs, 37.67 us = 0.04% latency, 1.74 GFLOPS)
        (linear): Linear(25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 176.91 us = 0.2% latency, 9.1 TFLOPS, in_features=2048, out_features=12288, bias=True)
        (norm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 461.34 us = 0.52% latency, 0 FLOPS)
      )
      (self_attn): DiTSelfAttention(
        25.17 M = 2.43% Params, 240.52 GMACs = 3.61% MACs, 4.39 ms = 4.97% latency, 109.58 TFLOPS
        (q_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 860.93 us = 0.97% latency, 0 FLOPS)
        (k_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 487.33 us = 0.55% latency, 0 FLOPS)
        (q_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 367.64 us = 0.42% latency, 373.84 TFLOPS, in_features=2048, out_features=4096, bias=False)
        (k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 203.85 us = 0.23% latency, 168.56 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 200.03 us = 0.23% latency, 171.77 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 164.51 us = 0.19% latency, 208.86 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 162.36 us = 0.18% latency, 211.62 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (o_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 367.4 us = 0.42% latency, 374.08 TFLOPS, in_features=4096, out_features=2048, bias=False)
      )
      (post_attention_layernorm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 459.19 us = 0.52% latency, 0 FLOPS)
      (mlp): GemmaMLP(
        50.33 M = 4.86% Params, 412.32 GMACs = 6.19% MACs, 2 ms = 2.27% latency, 411.45 TFLOPS
        (gate_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 536.92 us = 0.61% latency, 511.95 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (up_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 527.86 us = 0.6% latency, 520.74 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (down_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 514.03 us = 0.58% latency, 534.75 TFLOPS, in_features=8192, out_features=2048, bias=False)
        (act_fn): PytorchGELUTanh(0 = 0% Params, 0 MACs = 0% MACs, 156.88 us = 0.18% latency, 427.77 GFLOPS)
      )
    )
    (1): DiTLayer(
      100.68 M = 9.73% Params, 653.64 GMACs = 9.82% MACs, 8.35 ms = 9.45% latency, 156.66 TFLOPS
      (input_layernorm): AdaLayerNormZero(
        25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 877.86 us = 0.99% latency, 1.83 TFLOPS
        (silu): SiLU(0 = 0% Params, 0 MACs = 0% MACs, 36.72 us = 0.04% latency, 1.78 GFLOPS)
        (linear): Linear(25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 169.52 us = 0.19% latency, 9.5 TFLOPS, in_features=2048, out_features=12288, bias=True)
        (norm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 458.96 us = 0.52% latency, 0 FLOPS)
      )
      (self_attn): DiTSelfAttention(
        25.17 M = 2.43% Params, 240.52 GMACs = 3.61% MACs, 4.36 ms = 4.93% latency, 110.38 TFLOPS
        (q_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 860.45 us = 0.97% latency, 0 FLOPS)
        (k_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 493.29 us = 0.56% latency, 0 FLOPS)
        (q_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 344.75 us = 0.39% latency, 398.66 TFLOPS, in_features=2048, out_features=4096, bias=False)
        (k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 201.23 us = 0.23% latency, 170.75 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 198.36 us = 0.22% latency, 173.22 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 164.99 us = 0.19% latency, 208.26 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 162.84 us = 0.18% latency, 211 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (o_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 373.36 us = 0.42% latency, 368.11 TFLOPS, in_features=4096, out_features=2048, bias=False)
      )
      (post_attention_layernorm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 458.24 us = 0.52% latency, 0 FLOPS)
      (mlp): GemmaMLP(
        50.33 M = 4.86% Params, 412.32 GMACs = 6.19% MACs, 1.99 ms = 2.25% latency, 414.41 TFLOPS
        (gate_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 536.2 us = 0.61% latency, 512.64 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (up_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 526.67 us = 0.6% latency, 521.92 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (down_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 508.79 us = 0.58% latency, 540.26 TFLOPS, in_features=8192, out_features=2048, bias=False)
        (act_fn): PytorchGELUTanh(0 = 0% Params, 0 MACs = 0% MACs, 148.53 us = 0.17% latency, 451.81 GFLOPS)
      )
    )
    (2): DiTLayer(
      100.68 M = 9.73% Params, 653.64 GMACs = 9.82% MACs, 8.36 ms = 9.46% latency, 156.42 TFLOPS
      (input_layernorm): AdaLayerNormZero(
        25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 878.57 us = 0.99% latency, 1.83 TFLOPS
        (silu): SiLU(0 = 0% Params, 0 MACs = 0% MACs, 36 us = 0.04% latency, 1.82 GFLOPS)
        (linear): Linear(25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 167.37 us = 0.19% latency, 9.62 TFLOPS, in_features=2048, out_features=12288, bias=True)
        (norm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 460.15 us = 0.52% latency, 0 FLOPS)
      )
      (self_attn): DiTSelfAttention(
        25.17 M = 2.43% Params, 240.52 GMACs = 3.61% MACs, 4.36 ms = 4.94% latency, 110.22 TFLOPS
        (q_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 860.21 us = 0.97% latency, 0 FLOPS)
        (k_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 491.62 us = 0.56% latency, 0 FLOPS)
        (q_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 351.67 us = 0.4% latency, 390.82 TFLOPS, in_features=2048, out_features=4096, bias=False)
        (k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 203.61 us = 0.23% latency, 168.75 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 199.32 us = 0.23% latency, 172.39 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 167.13 us = 0.19% latency, 205.59 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 165.7 us = 0.19% latency, 207.36 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (o_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 361.44 us = 0.41% latency, 380.25 TFLOPS, in_features=4096, out_features=2048, bias=False)
      )
      (post_attention_layernorm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 457.29 us = 0.52% latency, 0 FLOPS)
      (mlp): GemmaMLP(
        50.33 M = 4.86% Params, 412.32 GMACs = 6.19% MACs, 2 ms = 2.26% latency, 413.27 TFLOPS
        (gate_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 539.78 us = 0.61% latency, 509.24 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (up_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 524.76 us = 0.59% latency, 523.82 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (down_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 510.69 us = 0.58% latency, 538.25 TFLOPS, in_features=8192, out_features=2048, bias=False)
        (act_fn): PytorchGELUTanh(0 = 0% Params, 0 MACs = 0% MACs, 147.1 us = 0.17% latency, 456.2 GFLOPS)
      )
    )
    (3): DiTLayer(
      100.68 M = 9.73% Params, 653.64 GMACs = 9.82% MACs, 8.33 ms = 9.43% latency, 156.87 TFLOPS
      (input_layernorm): AdaLayerNormZero(
        25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 878.57 us = 0.99% latency, 1.83 TFLOPS
        (silu): SiLU(0 = 0% Params, 0 MACs = 0% MACs, 36.24 us = 0.04% latency, 1.81 GFLOPS)
        (linear): Linear(25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 169.99 us = 0.19% latency, 9.47 TFLOPS, in_features=2048, out_features=12288, bias=True)
        (norm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 457.76 us = 0.52% latency, 0 FLOPS)
      )
      (self_attn): DiTSelfAttention(
        25.17 M = 2.43% Params, 240.52 GMACs = 3.61% MACs, 4.35 ms = 4.92% latency, 110.58 TFLOPS
        (q_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 858.55 us = 0.97% latency, 0 FLOPS)
        (k_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 487.57 us = 0.55% latency, 0 FLOPS)
        (q_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 353.81 us = 0.4% latency, 388.45 TFLOPS, in_features=2048, out_features=4096, bias=False)
        (k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 203.37 us = 0.23% latency, 168.95 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 198.13 us = 0.22% latency, 173.42 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 166.42 us = 0.19% latency, 206.47 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 163.32 us = 0.18% latency, 210.39 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (o_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 356.67 us = 0.4% latency, 385.33 TFLOPS, in_features=4096, out_features=2048, bias=False)
      )
      (post_attention_layernorm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 462.29 us = 0.52% latency, 0 FLOPS)
      (mlp): GemmaMLP(
        50.33 M = 4.86% Params, 412.32 GMACs = 6.19% MACs, 1.98 ms = 2.24% latency, 416.25 TFLOPS
        (gate_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 533.34 us = 0.6% latency, 515.39 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (up_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 526.67 us = 0.6% latency, 521.92 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (down_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 507.59 us = 0.57% latency, 541.53 TFLOPS, in_features=8192, out_features=2048, bias=False)
        (act_fn): PytorchGELUTanh(0 = 0% Params, 0 MACs = 0% MACs, 144.48 us = 0.16% latency, 464.48 GFLOPS)
      )
    )
    (4): DiTLayer(
      100.68 M = 9.73% Params, 653.64 GMACs = 9.82% MACs, 8.33 ms = 9.42% latency, 157 TFLOPS
      (input_layernorm): AdaLayerNormZero(
        25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 867.37 us = 0.98% latency, 1.86 TFLOPS
        (silu): SiLU(0 = 0% Params, 0 MACs = 0% MACs, 30.76 us = 0.03% latency, 2.13 GFLOPS)
        (linear): Linear(25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 153.78 us = 0.17% latency, 10.47 TFLOPS, in_features=2048, out_features=12288, bias=True)
        (norm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 462.29 us = 0.52% latency, 0 FLOPS)
      )
      (self_attn): DiTSelfAttention(
        25.17 M = 2.43% Params, 240.52 GMACs = 3.61% MACs, 4.36 ms = 4.93% latency, 110.41 TFLOPS
        (q_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 861.41 us = 0.97% latency, 0 FLOPS)
        (k_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 490.9 us = 0.56% latency, 0 FLOPS)
        (q_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 346.9 us = 0.39% latency, 396.19 TFLOPS, in_features=2048, out_features=4096, bias=False)
        (k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 200.75 us = 0.23% latency, 171.16 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 197.89 us = 0.22% latency, 173.63 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 167.13 us = 0.19% latency, 205.59 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 166.65 us = 0.19% latency, 206.17 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (o_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 359.06 us = 0.41% latency, 382.78 TFLOPS, in_features=4096, out_features=2048, bias=False)
      )
      (post_attention_layernorm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 459.91 us = 0.52% latency, 0 FLOPS)
      (mlp): GemmaMLP(
        50.33 M = 4.86% Params, 412.32 GMACs = 6.19% MACs, 1.98 ms = 2.24% latency, 416.05 TFLOPS
        (gate_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 529.77 us = 0.6% latency, 518.87 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (up_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 527.62 us = 0.6% latency, 520.98 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (down_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 508.07 us = 0.58% latency, 541.02 TFLOPS, in_features=8192, out_features=2048, bias=False)
        (act_fn): PytorchGELUTanh(0 = 0% Params, 0 MACs = 0% MACs, 145.91 us = 0.17% latency, 459.93 GFLOPS)
      )
    )
    (5): DiTLayer(
      100.68 M = 9.73% Params, 653.64 GMACs = 9.82% MACs, 8.33 ms = 9.43% latency, 156.87 TFLOPS
      (input_layernorm): AdaLayerNormZero(
        25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 874.52 us = 0.99% latency, 1.84 TFLOPS
        (silu): SiLU(0 = 0% Params, 0 MACs = 0% MACs, 35.52 us = 0.04% latency, 1.84 GFLOPS)
        (linear): Linear(25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 166.89 us = 0.19% latency, 9.65 TFLOPS, in_features=2048, out_features=12288, bias=True)
        (norm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 457.76 us = 0.52% latency, 0 FLOPS)
      )
      (self_attn): DiTSelfAttention(
        25.17 M = 2.43% Params, 240.52 GMACs = 3.61% MACs, 4.35 ms = 4.92% latency, 110.68 TFLOPS
        (q_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 859.98 us = 0.97% latency, 0 FLOPS)
        (k_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 489.23 us = 0.55% latency, 0 FLOPS)
        (q_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 345.47 us = 0.39% latency, 397.83 TFLOPS, in_features=2048, out_features=4096, bias=False)
        (k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 201.23 us = 0.23% latency, 170.75 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 199.56 us = 0.23% latency, 172.18 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 167.13 us = 0.19% latency, 205.59 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 164.75 us = 0.19% latency, 208.56 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (o_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 361.68 us = 0.41% latency, 380 TFLOPS, in_features=4096, out_features=2048, bias=False)
      )
      (post_attention_layernorm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 458 us = 0.52% latency, 0 FLOPS)
      (mlp): GemmaMLP(
        50.33 M = 4.86% Params, 412.32 GMACs = 6.19% MACs, 1.99 ms = 2.25% latency, 414.51 TFLOPS
        (gate_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 534.77 us = 0.61% latency, 514.01 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (up_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 525 us = 0.59% latency, 523.58 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (down_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 512.12 us = 0.58% latency, 536.74 TFLOPS, in_features=8192, out_features=2048, bias=False)
        (act_fn): PytorchGELUTanh(0 = 0% Params, 0 MACs = 0% MACs, 144.72 us = 0.16% latency, 463.71 GFLOPS)
      )
    )
    (6): DiTLayer(
      100.68 M = 9.73% Params, 653.64 GMACs = 9.82% MACs, 8.36 ms = 9.47% latency, 156.33 TFLOPS
      (input_layernorm): AdaLayerNormZero(
        25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 886.68 us = 1% latency, 1.82 TFLOPS
        (silu): SiLU(0 = 0% Params, 0 MACs = 0% MACs, 37.91 us = 0.04% latency, 1.73 GFLOPS)
        (linear): Linear(25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 171.9 us = 0.19% latency, 9.37 TFLOPS, in_features=2048, out_features=12288, bias=True)
        (norm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 458.96 us = 0.52% latency, 0 FLOPS)
      )
      (self_attn): DiTSelfAttention(
        25.17 M = 2.43% Params, 240.52 GMACs = 3.61% MACs, 4.36 ms = 4.94% latency, 110.23 TFLOPS
        (q_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 863.31 us = 0.98% latency, 0 FLOPS)
        (k_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 493.05 us = 0.56% latency, 0 FLOPS)
        (q_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 342.85 us = 0.39% latency, 400.88 TFLOPS, in_features=2048, out_features=4096, bias=False)
        (k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 201.46 us = 0.23% latency, 170.55 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 199.32 us = 0.23% latency, 172.39 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 168.56 us = 0.19% latency, 203.84 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 168.32 us = 0.19% latency, 204.13 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (o_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 360.73 us = 0.41% latency, 381.01 TFLOPS, in_features=4096, out_features=2048, bias=False)
      )
      (post_attention_layernorm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 460.62 us = 0.52% latency, 0 FLOPS)
      (mlp): GemmaMLP(
        50.33 M = 4.86% Params, 412.32 GMACs = 6.19% MACs, 1.98 ms = 2.24% latency, 415.8 TFLOPS
        (gate_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 535.49 us = 0.61% latency, 513.32 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (up_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 524.04 us = 0.59% latency, 524.53 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (down_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 508.31 us = 0.58% latency, 540.77 TFLOPS, in_features=8192, out_features=2048, bias=False)
        (act_fn): PytorchGELUTanh(0 = 0% Params, 0 MACs = 0% MACs, 145.44 us = 0.16% latency, 461.43 GFLOPS)
      )
    )
    (7): DiTLayer(
      100.68 M = 9.73% Params, 653.64 GMACs = 9.82% MACs, 8.35 ms = 9.45% latency, 156.52 TFLOPS
      (input_layernorm): AdaLayerNormZero(
        25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 875.95 us = 0.99% latency, 1.84 TFLOPS
        (silu): SiLU(0 = 0% Params, 0 MACs = 0% MACs, 37.19 us = 0.04% latency, 1.76 GFLOPS)
        (linear): Linear(25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 168.09 us = 0.19% latency, 9.58 TFLOPS, in_features=2048, out_features=12288, bias=True)
        (norm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 459.19 us = 0.52% latency, 0 FLOPS)
      )
      (self_attn): DiTSelfAttention(
        25.17 M = 2.43% Params, 240.52 GMACs = 3.61% MACs, 4.35 ms = 4.93% latency, 110.52 TFLOPS
        (q_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 860.21 us = 0.97% latency, 0 FLOPS)
        (k_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 489.47 us = 0.55% latency, 0 FLOPS)
        (q_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 342.85 us = 0.39% latency, 400.88 TFLOPS, in_features=2048, out_features=4096, bias=False)
        (k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 203.13 us = 0.23% latency, 169.15 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 200.27 us = 0.23% latency, 171.57 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 170.23 us = 0.19% latency, 201.84 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 166.42 us = 0.19% latency, 206.47 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (o_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 362.87 us = 0.41% latency, 378.75 TFLOPS, in_features=4096, out_features=2048, bias=False)
      )
      (post_attention_layernorm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 458.72 us = 0.52% latency, 0 FLOPS)
      (mlp): GemmaMLP(
        50.33 M = 4.86% Params, 412.32 GMACs = 6.19% MACs, 2 ms = 2.26% latency, 412.68 TFLOPS
        (gate_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 539.3 us = 0.61% latency, 509.69 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (up_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 526.43 us = 0.6% latency, 522.16 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (down_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 514.27 us = 0.58% latency, 534.5 TFLOPS, in_features=8192, out_features=2048, bias=False)
        (act_fn): PytorchGELUTanh(0 = 0% Params, 0 MACs = 0% MACs, 146.39 us = 0.17% latency, 458.43 GFLOPS)
      )
    )
    (8): DiTLayer(
      100.68 M = 9.73% Params, 653.64 GMACs = 9.82% MACs, 8.33 ms = 9.42% latency, 157.03 TFLOPS
      (input_layernorm): AdaLayerNormZero(
        25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 887.87 us = 1% latency, 1.81 TFLOPS
        (silu): SiLU(0 = 0% Params, 0 MACs = 0% MACs, 38.15 us = 0.04% latency, 1.72 GFLOPS)
        (linear): Linear(25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 171.9 us = 0.19% latency, 9.37 TFLOPS, in_features=2048, out_features=12288, bias=True)
        (norm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 463.25 us = 0.52% latency, 0 FLOPS)
      )
      (self_attn): DiTSelfAttention(
        25.17 M = 2.43% Params, 240.52 GMACs = 3.61% MACs, 4.34 ms = 4.92% latency, 110.77 TFLOPS
        (q_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 860.21 us = 0.97% latency, 0 FLOPS)
        (k_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 484.94 us = 0.55% latency, 0 FLOPS)
        (q_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 349.52 us = 0.4% latency, 393.22 TFLOPS, in_features=2048, out_features=4096, bias=False)
        (k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 201.7 us = 0.23% latency, 170.35 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 197.65 us = 0.22% latency, 173.84 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 167.13 us = 0.19% latency, 205.59 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 165.46 us = 0.19% latency, 207.66 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (o_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 355.48 us = 0.4% latency, 386.63 TFLOPS, in_features=4096, out_features=2048, bias=False)
      )
      (post_attention_layernorm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 457.53 us = 0.52% latency, 0 FLOPS)
      (mlp): GemmaMLP(
        50.33 M = 4.86% Params, 412.32 GMACs = 6.19% MACs, 1.98 ms = 2.24% latency, 416.45 TFLOPS
        (gate_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 530.72 us = 0.6% latency, 517.93 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (up_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 524.76 us = 0.59% latency, 523.82 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (down_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 508.07 us = 0.58% latency, 541.02 TFLOPS, in_features=8192, out_features=2048, bias=False)
        (act_fn): PytorchGELUTanh(0 = 0% Params, 0 MACs = 0% MACs, 144.96 us = 0.16% latency, 462.95 GFLOPS)
      )
    )
    (9): DiTLayer(
      100.68 M = 9.73% Params, 653.64 GMACs = 9.82% MACs, 8.33 ms = 9.43% latency, 156.96 TFLOPS
      (input_layernorm): AdaLayerNormZero(
        25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 887.63 us = 1% latency, 1.81 TFLOPS
        (silu): SiLU(0 = 0% Params, 0 MACs = 0% MACs, 38.39 us = 0.04% latency, 1.71 GFLOPS)
        (linear): Linear(25.18 M = 2.43% Params, 805.31 MMACs = 0.01% MACs, 169.52 us = 0.19% latency, 9.5 TFLOPS, in_features=2048, out_features=12288, bias=True)
        (norm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 460.62 us = 0.52% latency, 0 FLOPS)
      )
      (self_attn): DiTSelfAttention(
        25.17 M = 2.43% Params, 240.52 GMACs = 3.61% MACs, 4.34 ms = 4.91% latency, 110.89 TFLOPS
        (q_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 859.02 us = 0.97% latency, 0 FLOPS)
        (k_norm): GemmaRMSNorm(128 = 0% Params, 0 MACs = 0% MACs, 490.9 us = 0.56% latency, 0 FLOPS)
        (q_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 352.86 us = 0.4% latency, 389.5 TFLOPS, in_features=2048, out_features=4096, bias=False)
        (k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 201.23 us = 0.23% latency, 170.75 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 198.13 us = 0.22% latency, 173.42 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_k_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 163.79 us = 0.19% latency, 209.77 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (text_v_proj): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 160.69 us = 0.18% latency, 213.82 TFLOPS, in_features=2048, out_features=1024, bias=False)
        (o_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 359.3 us = 0.41% latency, 382.52 TFLOPS, in_features=4096, out_features=2048, bias=False)
      )
      (post_attention_layernorm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 457.53 us = 0.52% latency, 0 FLOPS)
      (mlp): GemmaMLP(
        50.33 M = 4.86% Params, 412.32 GMACs = 6.19% MACs, 1.98 ms = 2.24% latency, 417.46 TFLOPS
        (gate_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 531.2 us = 0.6% latency, 517.47 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (up_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 524.52 us = 0.59% latency, 524.06 TFLOPS, in_features=2048, out_features=8192, bias=False)
        (down_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 504.26 us = 0.57% latency, 545.12 TFLOPS, in_features=8192, out_features=2048, bias=False)
        (act_fn): PytorchGELUTanh(0 = 0% Params, 0 MACs = 0% MACs, 145.44 us = 0.16% latency, 461.43 GFLOPS)
      )
    )
  )
  (patch_embed): PatchEmbed(
    133.12 K = 0.01% Params, 1.07 GMACs = 0.02% MACs, 628.95 us = 0.71% latency, 3.44 TFLOPS
    (proj): Conv2d(133.12 K = 0.01% Params, 1.07 GMACs = 0.02% MACs, 374.79 us = 0.42% latency, 5.77 TFLOPS, 16, 2048, kernel_size=(2, 2), stride=(2, 2))
  )
  (rotary_emb): GemmaRotaryEmbedding(0 = 0% Params, 0 MACs = 0% MACs, 0 s = 0% latency, 0 FLOPS)
  (time_proj): Timesteps(0 = 0% Params, 0 MACs = 0% MACs, 258.68 us = 0.29% latency, 0 FLOPS)
  (timestep_embedder): Sequential(
    4.72 M = 0.46% Params, 150.99 MMACs = 0% MACs, 543.36 us = 0.61% latency, 555.91 GFLOPS
    (0): Linear(526.34 K = 0.05% Params, 16.78 MMACs = 0% MACs, 233.41 us = 0.26% latency, 143.76 GFLOPS, in_features=256, out_features=2048, bias=True)
    (1): SiLU(0 = 0% Params, 0 MACs = 0% MACs, 43.63 us = 0.05% latency, 1.5 GFLOPS)
    (2): Linear(4.2 M = 0.41% Params, 134.22 MMACs = 0% MACs, 187.64 us = 0.21% latency, 1.43 TFLOPS, in_features=2048, out_features=2048, bias=True)
  )
  (context_embedder): Sequential(
    4.2 M = 0.41% Params, 34.36 GMACs = 0.52% MACs, 607.73 us = 0.69% latency, 113.08 TFLOPS
    (0): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 279.66 us = 0.32% latency, 0 FLOPS)
    (1): Linear(4.2 M = 0.41% Params, 34.36 GMACs = 0.52% MACs, 271.56 us = 0.31% latency, 253.06 TFLOPS, in_features=2048, out_features=2048, bias=True)
  )
  (norm_out): AdaLayerNormOut(
    8.39 M = 0.81% Params, 268.44 MMACs = 0% MACs, 849.96 us = 0.96% latency, 631.72 GFLOPS
    (silu): SiLU(0 = 0% Params, 0 MACs = 0% MACs, 36.72 us = 0.04% latency, 1.78 GFLOPS)
    (linear): Linear(8.39 M = 0.81% Params, 268.44 MMACs = 0% MACs, 148.3 us = 0.17% latency, 3.62 TFLOPS, in_features=2048, out_features=4096, bias=True)
    (norm): GemmaRMSNorm(2.05 K = 0% Params, 0 MACs = 0% MACs, 456.09 us = 0.52% latency, 0 FLOPS)
  )
  (proj_out): Linear(131.14 K = 0.01% Params, 1.07 GMACs = 0.02% MACs, 183.58 us = 0.21% latency, 11.7 TFLOPS, in_features=2048, out_features=64, bias=True)
  (repa_projector): Sequential(
    10.49 M = 1.01% Params, 85.9 GMACs = 1.29% MACs, 908.61 us = 1.03% latency, 189.11 TFLOPS
    (0): Linear(4.2 M = 0.41% Params, 34.36 GMACs = 0.52% MACs, 303.03 us = 0.34% latency, 226.77 TFLOPS, in_features=2048, out_features=2048, bias=True)
    (1): SiLU(0 = 0% Params, 0 MACs = 0% MACs, 50.31 us = 0.06% latency, 333.5 GFLOPS)
    (2): Linear(4.2 M = 0.41% Params, 34.36 GMACs = 0.52% MACs, 235.32 us = 0.27% latency, 292.03 TFLOPS, in_features=2048, out_features=2048, bias=True)
    (3): SiLU(0 = 0% Params, 0 MACs = 0% MACs, 46.49 us = 0.05% latency, 360.87 GFLOPS)
    (4): Linear(2.1 M = 0.2% Params, 17.18 GMACs = 0.26% MACs, 169.04 us = 0.19% latency, 203.27 TFLOPS, in_features=2048, out_features=1024, bias=True)
  )
)
------------------------------------------------------------------------------
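The derived columns in each Linear row can be cross-checked against each other. Below is a minimal Python sketch that assumes three things. First, a Linear layer's MACs are counted as (token vectors processed) × in_features × out_features, with the bias excluded. Second, the throughput column is 2 × MACs / latency. Third, the parameter count is in_features × out_features plus out_features when bias=True. These assumptions are consistent with the numbers printed above, but the report does not state them. The regex, the `check` helper, and the unit tables are hypothetical and were written only for this illustration. The three rows are copied verbatim from the profile above (layer 9 `q_proj`, layer 9 `down_proj`, and `context_embedder` item `(1)`).

```python
import re

# Rows copied verbatim from the profile above (leading indentation dropped).
ROWS = [
    "(q_proj): Linear(8.39 M = 0.81% Params, 68.72 GMACs = 1.03% MACs, 352.86 us = 0.4% latency, 389.5 TFLOPS, in_features=2048, out_features=4096, bias=False)",
    "(down_proj): Linear(16.78 M = 1.62% Params, 137.44 GMACs = 2.06% MACs, 504.26 us = 0.57% latency, 545.12 TFLOPS, in_features=8192, out_features=2048, bias=False)",
    "(1): Linear(4.2 M = 0.41% Params, 34.36 GMACs = 0.52% MACs, 271.56 us = 0.31% latency, 253.06 TFLOPS, in_features=2048, out_features=2048, bias=True)",
]

# Unit prefixes as they appear in the printed numbers (hypothetical helper tables).
SCALE = {"": 1.0, "K": 1e3, "M": 1e6, "G": 1e9, "T": 1e12}
TIME = {"s": 1.0, "ms": 1e-3, "us": 1e-6}

# Matches one Linear row: params, MACs, latency, throughput, and layer shape.
PATTERN = re.compile(
    r"\((?P<name>\w+)\): Linear\("
    r"(?P<params>[\d.]+) (?P<pu>[KMG]?) = [\d.]+% Params, "
    r"(?P<macs>[\d.]+) (?P<mu>[KMGT]?)MACs = [\d.]+% MACs, "
    r"(?P<lat>[\d.]+) (?P<lu>s|ms|us) = [\d.]+% latency, "
    r"(?P<flops>[\d.]+) (?P<fu>[KMGT]?)FLOPS, "
    r"in_features=(?P<fin>\d+), out_features=(?P<fout>\d+), bias=(?P<bias>True|False)\)"
)


def check(row: str) -> dict:
    """Recompute params, token count, and throughput for one Linear row."""
    m = PATTERN.search(row)
    if m is None:
        raise ValueError(f"not a Linear profile row: {row!r}")
    fin, fout = int(m["fin"]), int(m["fout"])
    # Weight matrix plus optional bias vector.
    weights = fin * fout + (fout if m["bias"] == "True" else 0)
    macs = float(m["macs"]) * SCALE[m["mu"]]
    latency_s = float(m["lat"]) * TIME[m["lu"]]
    return {
        "name": m["name"],
        "params_reported": float(m["params"]) * SCALE[m["pu"]],
        "params_expected": weights,
        # Assumed MAC rule: tokens * in_features * out_features (bias excluded).
        "tokens": round(macs / (fin * fout)),
        "flops_reported": float(m["flops"]) * SCALE[m["fu"]],
        # Assumed throughput rule: 2 FLOPs per MAC, divided by measured latency.
        "flops_expected": 2 * macs / latency_s,
    }


for row in ROWS:
    r = check(row)
    print(
        f"{r['name']}: params {r['params_expected']:,} (reported {r['params_reported']:,.0f}), "
        f"tokens {r['tokens']}, {r['flops_expected'] / 1e12:.1f} TFLOPS "
        f"(reported {r['flops_reported'] / 1e12:.1f})"
    )
```

Under these assumptions, each of the three rows works out to 8192 token vectors per profiled forward pass. The recomputed throughput also matches the printed column to within rounding, which suggests the per-module TFLOPS figures are simply 2 × MACs divided by that module's measured latency.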