Add additional input embedding
Includes PR#119.
Previously, modded-nanogpt medium added x0 to the residual at the input of every layer:
# GPT
def forward(self, input_seq, ...):
    ...
    x = x0 = norm(self.embed(input_seq[None]))
    ...
    for i in range(len(self.blocks)):
        ...
        x = self.blocks[i](x, x0, lambdas, ...)
    ...
# Block
def forward(self, x, x0, lambdas, ...):
    x = lambdas[0] * x + lambdas[1] * x0
Here, lambdas are learned scalars.
This update adds a second embedding module, whose normalized output is also added to the residual at the input of every layer:
# GPT
def forward(self, input_seq, ...):
    ...
    x = x00 = norm(self.embed1(input_seq[None]))
    x01 = norm(self.embed2(input_seq[None]))
    ...
    for i in range(len(self.blocks)):
        ...
        x = self.blocks[i](x, x00, x01, lambdas, ...)
    ...
# Block
def forward(self, x, x00, x01, lambdas, ...):
    x = lambdas[0] * x + lambdas[1] * x00 + lambdas[2] * x01
While this slows down each training step, it increases learning per step, which allows us to reduce the step count to 5690.
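As a toy illustration of the per-block mix, the sketch below uses plain Python lists in place of tensors. The helper name `mix_residual` and all numeric values are hypothetical and not taken from the training code. It only shows that the three-term mix reduces to the previous two-term mix when `lambdas[2]` is zero, and that a nonzero `lambdas[2]` lets the second embedding contribute.

```python
# Hypothetical helper: plain lists stand in for the real tensors.
def mix_residual(x, x00, x01, lambdas):
    """Elementwise lambdas[0]*x + lambdas[1]*x00 + lambdas[2]*x01."""
    return [
        lambdas[0] * a + lambdas[1] * b + lambdas[2] * c
        for a, b, c in zip(x, x00, x01)
    ]

x = [0.5, -1.0, 2.0]     # residual stream entering a block (toy values)
x00 = [1.0, 0.0, -1.0]   # stand-in for norm(embed1(input_seq))
x01 = [0.25, 0.5, 0.75]  # stand-in for norm(embed2(input_seq))

# With lambdas[2] == 0 the second embedding has no effect,
# which matches the previous two-term formulation.
old_mix = [1.0 * a + 0.5 * b for a, b in zip(x, x00)]
new_mix_zero = mix_residual(x, x00, x01, [1.0, 0.5, 0.0])

# With a nonzero lambdas[2] the second embedding is mixed in as well.
new_mix = mix_residual(x, x00, x01, [1.0, 0.5, 1.0])

print(old_mix, new_mix_zero, new_mix)
```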
Here are the resulting final validation losses over 19 runs:
[2.919502, 2.91976, 2.920582, 2.919331, 2.919008, 2.919827, 2.918785, 2.918519, 2.919297, 2.920061, 2.918938, 2.919342, 2.918186, 2.920546, 2.91954, 2.919093, 2.918951, 2.919599, 2.919956]
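And these are the basic stats (Std is the population standard deviation, which is why it differs slightly from sample_std in the t-test results):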
- Mean: 2.9194117368421053
- Median: 2.919342
- Std: 0.000613243648848653
- Min: 2.918186
- Max: 2.920582
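These numbers can be reproduced with Python's standard library. The sketch below is illustrative and not necessarily the script used for this record; the variable names are hypothetical. It computes both the population and the sample standard deviation, since the two differ slightly for n = 19.

```python
import statistics

# Final validation losses of the 19 runs, copied from the list above.
val_losses = [
    2.919502, 2.91976, 2.920582, 2.919331, 2.919008, 2.919827, 2.918785,
    2.918519, 2.919297, 2.920061, 2.918938, 2.919342, 2.918186, 2.920546,
    2.91954, 2.919093, 2.918951, 2.919599, 2.919956,
]

mean = statistics.mean(val_losses)
median = statistics.median(val_losses)
pop_std = statistics.pstdev(val_losses)    # divides by n
sample_std = statistics.stdev(val_losses)  # divides by n - 1

print(f"Mean:       {mean}")
print(f"Median:     {median}")
print(f"Std (pop):  {pop_std}")
print(f"Std (samp): {sample_std}")
print(f"Min:        {min(val_losses)}")
print(f"Max:        {max(val_losses)}")
```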
And these are the results of a one-sided, one-sample t-test of the mean against the threshold of 2.92:
{
    'n': 19,
    'sample_mean': 2.9194117368421053,
    'sample_std': 0.0006300479560324045,
    't_stat': -4.069816643193549,
    'p_value': 0.00035946114919240566,
    'alpha': 0.05,
    'decision': 'REJECT H0 (mean < threshold)',
    'upper_conf_bound_mean': 2.919662383449226,
    'threshold': 2.92
}
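The t statistic and the upper confidence bound above can be checked by hand. The following is a minimal sketch using only the standard library, not necessarily how the results were produced. It assumes a one-sample, one-sided test, and it hardcodes the tabulated critical value for 18 degrees of freedom at alpha = 0.05 (about 1.734) instead of computing it, so it does not reproduce the p-value. All names are hypothetical.

```python
import math
import statistics

# Final validation losses of the 19 runs, copied from the list above.
val_losses = [
    2.919502, 2.91976, 2.920582, 2.919331, 2.919008, 2.919827, 2.918785,
    2.918519, 2.919297, 2.920061, 2.918938, 2.919342, 2.918186, 2.920546,
    2.91954, 2.919093, 2.918951, 2.919599, 2.919956,
]
threshold = 2.92

# Assumption: tabulated one-sided critical value t_{0.95, df=18},
# hardcoded to avoid a dependency outside the standard library.
T_CRIT_DF18 = 1.734064

n = len(val_losses)
mean = statistics.mean(val_losses)
std_err = statistics.stdev(val_losses) / math.sqrt(n)  # sample std / sqrt(n)

# H0: mean >= threshold, H1: mean < threshold.
t_stat = (mean - threshold) / std_err
upper_bound = mean + T_CRIT_DF18 * std_err
reject_h0 = t_stat < -T_CRIT_DF18

print(f"t_stat = {t_stat:.4f}")
print(f"upper confidence bound = {upper_bound:.6f}")
print(f"reject H0: {reject_h0}")
```

Rejecting H0 is equivalent to the upper confidence bound of the mean falling below the threshold, which is the case here.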
The test rejects H0 at alpha = 0.05 with p ≈ 0.00036, so the mean final validation loss is below 2.92 with >99% confidence.
Here are the corresponding run-times in seconds:
[1414.299, 1412.033, 1411.668, 1421.735, 1411.998, 1411.094, 1412.637, 1410.047, 1410.509, 1412.048, 1411.574, 1415.299, 1411.649, 1412.94, 1412.508, 1410.912, 1415.296, 1410.778, 1407.511]
This leads to the following stats (Std is the population standard deviation):
- Mean: 1412.4492105263155
- Median: 1411.998
- Std: 2.8062021268488864
- Min: 1407.511
- Max: 1421.735
The mean run-time is ~1412.4 seconds, or about 23.54 minutes.
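For a rough sense of per-step cost, the mean run-time can be divided by the step count. This sketch assumes that the reported run-times cover exactly the 5690 training steps; if they also include validation or other overhead, the true per-step time is lower. The variable names are hypothetical.

```python
import statistics

# Run-times in seconds of the 19 runs, copied from the list above.
run_times_s = [
    1414.299, 1412.033, 1411.668, 1421.735, 1411.998, 1411.094, 1412.637,
    1410.047, 1410.509, 1412.048, 1411.574, 1415.299, 1411.649, 1412.94,
    1412.508, 1410.912, 1415.296, 1410.778, 1407.511,
]
num_steps = 5690  # step count stated above

mean_s = statistics.mean(run_times_s)
mean_min = mean_s / 60
ms_per_step = 1000 * mean_s / num_steps  # assumes run-time == training time

print(f"mean run-time: {mean_s:.1f} s = {mean_min:.2f} min")
print(f"approx. time per step: {ms_per_step:.1f} ms")
```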