# Add additional input embedding

Includes [PR#119](https://github.com/KellerJordan/modded-nanogpt/pull/119).

Previously, modded-nanogpt medium mixed `x0`, the normalized token embedding, into the residual stream at the input of every layer:

```python
# GPT
def forward(self, input_seq, ...):
    ...

    x = x0 = norm(self.embed(input_seq[None]))

    ...

    for i in range(len(self.blocks)):
        ...
        x = self.blocks[i](x, x0, lambdas, ...)
        ...

# Block
def forward(self, x, x0, lambdas, ...):
    x = lambdas[0] * x + lambdas[1] * x0
```

Here, `lambdas` are learned scalars.

This update adds a second embedding module and mixes its normalized output into the residual stream at every layer as well:

```python
# GPT
def forward(self, input_seq, ...):
    ...

    x = x00 = norm(self.embed1(input_seq[None]))
    x01 = norm(self.embed2(input_seq[None]))

    ...

    for i in range(len(self.blocks)):
        ...
        x = self.blocks[i](x, x00, x01, lambdas, ...)
        ...

# Block
def forward(self, x, x00, x01, lambdas, ...):
    x = lambdas[0] * x + lambdas[1] * x00 + lambdas[2] * x01
```
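To make the mixing step concrete, here is a minimal pure-Python sketch of the weighted sum a block applies at its input. The helper name `mix_residual`, the plain-list vectors, and the example scalar values are illustrative assumptions; the actual model works on tensors with learned `lambdas`. Setting the third scalar to zero recovers the previous single-embedding behavior.

```python
def mix_residual(x, x00, x01, lambdas):
    """Hypothetical helper: elementwise lambdas[0]*x + lambdas[1]*x00 + lambdas[2]*x01."""
    return [
        lambdas[0] * a + lambdas[1] * b + lambdas[2] * c
        for a, b, c in zip(x, x00, x01)
    ]


# Toy vectors standing in for the residual and the two normalized embeddings.
x = [1.0, 2.0]
x00 = [0.5, -0.5]
x01 = [2.0, 4.0]

# New behavior: both embeddings contribute to the block input.
new_mix = mix_residual(x, x00, x01, [1.0, 0.5, 0.25])

# With lambdas[2] == 0 the second embedding drops out (the old formula).
old_mix = mix_residual(x, x00, x01, [1.0, 0.5, 0.0])
```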

Although the extra embedding makes each training step slower, the model learns more per step, which allows us to reduce the step count to 5690.

Here are the resulting final validation losses over 19 runs:

```python
[2.919502, 2.91976, 2.920582, 2.919331, 2.919008, 2.919827, 2.918785, 2.918519, 2.919297, 2.920061, 2.918938, 2.919342, 2.918186, 2.920546, 2.91954, 2.919093, 2.918951, 2.919599, 2.919956]
```

And these are the basic stats:

- Mean: 2.9194117368421053
- Median: 2.919342
- Std: 0.000613243648848653 (population standard deviation; the t-test below uses the sample standard deviation)
- Min: 2.918186
- Max: 2.920582
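The summary statistics can be reproduced from the listed losses with Python's standard library. This is a sketch, not necessarily the script used to produce the numbers above; it assumes the reported Std is the population standard deviation (`statistics.pstdev`), which is what matches the listed value.

```python
import statistics

losses = [
    2.919502, 2.91976, 2.920582, 2.919331, 2.919008, 2.919827, 2.918785,
    2.918519, 2.919297, 2.920061, 2.918938, 2.919342, 2.918186, 2.920546,
    2.91954, 2.919093, 2.918951, 2.919599, 2.919956,
]

stats = {
    "mean": statistics.fmean(losses),
    "median": statistics.median(losses),
    "std": statistics.pstdev(losses),  # population std (divides by n)
    "min": min(losses),
    "max": max(losses),
}

for name, value in stats.items():
    print(f"{name}: {value}")
```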

A one-sided, one-sample t-test against the 2.92 threshold gives:

```python
{
    'n': 19,
    'sample_mean': 2.9194117368421053,
    'sample_std': 0.0006300479560324045,
    't_stat': -4.069816643193549,
    'p_value': 0.00035946114919240566,
    'alpha': 0.05,
    'decision': 'REJECT H0 (mean < threshold)',
    'upper_conf_bound_mean': 2.919662383449226,
    'threshold': 2.92
}
```
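The test statistic and the upper confidence bound follow directly from the summary values in the dictionary. The sketch below recomputes them with the standard library only; the critical value `t_crit = 1.734` is an approximate table value for a one-sided 5% test with 18 degrees of freedom, hard-coded here as an assumption instead of being computed from the t-distribution.

```python
import math

n = 19
sample_mean = 2.9194117368421053
sample_std = 0.0006300479560324045  # sample std (divides by n - 1)
threshold = 2.92
t_crit = 1.734  # assumed table value: one-sided, alpha = 0.05, 18 degrees of freedom

# Standard error of the mean and the one-sample t statistic.
standard_error = sample_std / math.sqrt(n)
t_stat = (sample_mean - threshold) / standard_error

# One-sided 95% upper confidence bound for the mean.
upper_conf_bound_mean = sample_mean + t_crit * standard_error

# H0: mean >= threshold. Reject when t_stat falls below -t_crit.
reject_h0 = t_stat < -t_crit
```

Because the upper confidence bound stays below the threshold, the rejection decision and the bound tell the same story.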

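The test rejects the null hypothesis that the mean final validation loss is at or above 2.92 (p ≈ 0.00036), so the mean is below 2.92 with more than 99% confidence.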

Here are the corresponding run-times in seconds:

```python
[1414.299, 1412.033, 1411.668, 1421.735, 1411.998, 1411.094, 1412.637, 1410.047, 1410.509, 1412.048, 1411.574, 1415.299, 1411.649, 1412.94, 1412.508, 1410.912, 1415.296, 1410.778, 1407.511]
```

Leading to the following stats:

- Mean: 1412.4492105263155
- Median: 1411.998
- Std: 2.8062021268488864
- Min: 1407.511
- Max: 1421.735

The mean run-time is ~1412.4 seconds, or 23.54 minutes.