# Update Smoothing + Snoo

TL;DR: Combining #128 and #129 decreases the iteration count to 5590 (-20 from #129), and the p-values are much more robust.

This PR combines #128 (Snoo optimizer) and #129 (EMA on top of Muon). Both PRs are in some way “smoothing out” the updates: #129 smooths the Muon update, and #128 applies a lookahead smoothing wrapper to the entire optimizer. Here, we just apply #128 to #129. After combining the two, the total iteration count decreases to 5590.

## More detail on method

#129 smooths out the Muon updates:

```
muon_update = NS(EMA(grads))
final_update = EMA(muon_update)
```

Here, unlike in #129, we use a constant EMA coefficient of 0.2.

#128 applies a lookahead step to the updates: run an inner optimizer for K iterations, and treat the resulting parameter displacement as a “gradient” for an outer SGD optimizer. Note that if K=1 and the outer SGD optimizer does not employ Nesterov momentum, then I think the two are equivalent, with the exception that #128 works on every parameter rather than just the Muon parameters.

Here, we simply use the smoothed Muon updates of #129 as the inner optimizer for #128, in addition to importing some learning rate tuning from #129.

Overall, the total iteration count can be decreased to 5590 (from 5610 in #129 or 5640 in #128). I was also more stringent with the p-value criterion, so there is likely a bit more “slack” in this submission than in either #128 or #129.

## Baselines (80 runs each)

I have noticed that there is substantial variance in the p-values for these runs, so I ran 80 runs of each baseline, and then created 1000 bootstrap samples of size 40 to compute the fraction of times the p-value was less than 0.01. I’m not a real statistician, but I feel better about this methodology than the one employed in #129 to estimate the probability of seeing a p-value below 0.01.
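The bootstrap procedure described above can be sketched roughly as follows. This is a minimal stdlib-only illustration, not the script that produced the numbers in this PR: the function names, the seed, resampling with replacement, and the use of a one-sided one-sample t-test against 2.92 are all assumptions made for the example. The t CDF is computed by numerical integration only to avoid a dependency.

```python
import math
import random
import statistics


def student_t_cdf(t, df, steps=400):
    """CDF of Student's t distribution (hypothetical helper).

    Substituting x = sqrt(df) * tan(theta) turns the CDF into
    0.5 + C * integral from 0 to atan(t / sqrt(df)) of cos(theta)^(df - 1),
    a bounded integral that Simpson's rule handles well.
    """
    log_c = (math.lgamma((df + 1) / 2) - math.lgamma(df / 2)
             - 0.5 * math.log(math.pi))
    theta = math.atan(t / math.sqrt(df))
    h = theta / steps  # signed step; `steps` must be even for Simpson's rule
    total = 1.0 + math.cos(theta) ** (df - 1)  # the two endpoints
    for i in range(1, steps):
        total += (4 if i % 2 else 2) * math.cos(i * h) ** (df - 1)
    return min(1.0, max(0.0, 0.5 + math.exp(log_c) * total * h / 3))


def one_sided_p(losses, target=2.92):
    """p-value for the alternative mean(losses) < target (one-sample t-test)."""
    n = len(losses)
    sem = statistics.stdev(losses) / math.sqrt(n)
    t_stat = (statistics.mean(losses) - target) / sem
    return student_t_cdf(t_stat, n - 1)


def bootstrap_fraction_below(losses, alpha=0.01, n_boot=1000, size=40, seed=0):
    """Fraction of bootstrap resamples (with replacement) whose p-value < alpha."""
    rng = random.Random(seed)
    p_values = [one_sided_p(rng.choices(losses, k=size)) for _ in range(n_boot)]
    return sum(p < alpha for p in p_values) / n_boot
```

Calling `bootstrap_fraction_below(val_losses)` on a list of per-run validation losses gives a number comparable to the “Percentage of p-values below 0.01” lines below, up to the assumptions listed.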
#129:

```
--- Val Loss Stats ---
mean: 2.919815
std: 0.000751
val loss 99% confidence interval: (2.919594 - 2.920037)
val_loss t-test p=0.015461 (small means <2.92)
--- Bootstrap p-value analysis --- (1000 samples of size 40)
Mean p-value: 0.139028
Variance of p-values: 0.029849
Percentage of p-values below 0.01: 21.00%
--- Training Time Stats ---
train time (minutes): mean=23.4811, std=0.1983
train time 99% confidence interval: (23.4227 - 23.5396)
avg ms per iteration: 251.1352. 99%% confidence interval: (250.5097 - 251.7608)
```

#128 (here I use the current configuration with 5640 steps):

```
--- Val Loss Stats ---
mean: 2.919738
std: 0.000884
val loss 99% confidence interval: (2.919477 - 2.919999)
val_loss t-test p=0.004818 (small means <2.92)
--- Bootstrap p-value analysis --- (1000 samples of size 40)
Mean p-value: 0.092580
Variance of p-values: 0.018915
Percentage of p-values below 0.01: 32.10%
--- Training Time Stats ---
train time (minutes): mean=23.6421, std=0.1916
train time 99% confidence interval: (23.5856 - 23.6986)
avg ms per iteration: 251.5118. 99%% confidence interval: (250.9105 - 252.1131)
```

So, from this we see that both of these baselines have a reasonable chance (21.0% and 32.1%, respectively) of hitting the required p-value in 40 samples. The “mean p-value” for the bootstrap analysis is very high because the mean is dominated by the few resamples with large p-values.

## This PR

I ran 160 runs for the new changes in order to have more data, and from these again created 1000 bootstrap samples of size 40 each to get an idea of the variance in the p-value calculation.
Over these samples, we see:

```
--- Val Loss stats over all 160 runs ---
mean: 2.919547
std: 0.000798
val loss 99% confidence interval: (2.919383 - 2.919712)
val_loss t-test p=0.000000 (small means <2.92)
--- Bootstrap p-value analysis (1000 samples of size 40 each) ---
Mean p-value: 0.006984
Max p-value: 0.262882
Variance of p-values: 0.000433
Percentage of p-values below 0.01: 85.40%
--- Training Time Stats ---
train time (minutes): mean=23.4283, std=0.1866
train time 99% confidence interval: (23.3899 - 23.4668)
avg ms per iteration: 251.4670. 99%% confidence interval: (251.0542 - 251.8799)
```

## More Aggressive run with 5580 iterations:

I also checked 120 runs of 5580 iterations. As expected, this still hits the target, but the p-value is a bit less robust.

```
--- Val loss stats over all 120 runs ---
mean: 2.919583
std: 0.000897
val loss 99% confidence interval: (2.919368 - 2.919797)
val_loss t-test p=0.000001 (small means <2.92)
--- Bootstrap p-value analysis (1000 samples of size 40 each) ---
Mean p-value: 0.018333
Max p-value: 0.376981
Variance of p-values: 0.001668
Percentage of p-values below 0.01: 67.90%
--- Training time stats ---
train time (minutes): mean=23.4492, std=0.2011
train time 99% confidence interval: (23.4012 - 23.4973)
avg ms per iteration: 252.1423. 99%% confidence interval: (251.6256 - 252.6591)
```

## Ablation

To make sure that the improvement over #128 is not just from the new LR tuning, I turned off the update smoothing but kept the LR tuning.
I also increased the number of iterations to 5600, which I guessed would more than make up for any improved time-per-step:

```
--- Val Loss Stats ---
mean: 2.920357
std: 0.000802
val loss 99% confidence interval: (2.920120 - 2.920593)
val_loss t-test p=0.999924 (small means <2.92)
--- Bootstrap p-value analysis --- (1000 samples of size 40)
Mean p-value: 0.973404
Max p-value: 1.000000
Variance of p-values: 0.003703
Percentage of p-values below 0.01: 0.00%
--- Training Time Stats ---
train time (minutes): mean=23.5413, std=0.2211
train time 99% confidence interval: (23.4761 - 23.6065)
avg ms per iteration: 252.2285. 99%% confidence interval: (251.5297 - 252.9272)
```

So, it does not seem to hit the target without the smoothing. I also tried tuning the LR cooldown fraction a bit (both with and without smoothing), as suggested by @YouJiacheng in a comment on #129, but did not find any improvement from this either.

A list of all 160 validation losses:

```
2.919161 2.920945 2.91878 2.91945 2.920088 2.918436 2.919751 2.919509 2.919121 2.920388
2.920208 2.920169 2.920938 2.918948 2.919245 2.919653 2.918682 2.918916 2.919926 2.920458
2.918769 2.918555 2.91991 2.919425 2.922082 2.919508 2.920449 2.919091 2.921161 2.919444
2.920434 2.918194 2.919289 2.919533 2.9209 2.918483 2.919002 2.919399 2.920047 2.920363
2.918821 2.920426 2.920432 2.918828 2.918984 2.918681 2.918769 2.918822 2.919352 2.919853
2.919699 2.919783 2.918965 2.919565 2.918902 2.919225 2.920187 2.919625 2.921371 2.919239
2.919902 2.918071 2.919462 2.918726 2.920078 2.918884 2.919408 2.920146 2.919939 2.920311
2.920426 2.919574 2.919629 2.921047 2.918987 2.918633 2.918057 2.919441 2.920069 2.921082
2.920105 2.920009 2.918286 2.919617 2.920899 2.919312 2.919833 2.918901 2.920027 2.919553
2.918713 2.920759 2.919725 2.91843 2.919194 2.920136 2.919102 2.920179 2.919613 2.919428
2.920121 2.918931 2.918599 2.919027 2.918768 2.920173 2.91906 2.919343 2.921149 2.919538
2.919927 2.919984 2.920188 2.919886 2.918576 2.919965 2.919993 2.919684 2.918075 2.920297
2.920482 2.920536 2.919626 2.919845 2.919099 2.919832 2.918258 2.920294 2.920837 2.918292
2.918897 2.917934 2.919626 2.919178 2.918989 2.919164 2.918687 2.918274 2.918378 2.920714
2.920003 2.919554 2.918437 2.919514 2.920284 2.918734 2.920206 2.919427 2.918294 2.920774
2.918721 2.918992 2.919474 2.920078 2.918853 2.917999 2.919675 2.91946 2.920768 2.920036
```

A list of all 160 timings (ms):

```
1404459 1395516 1417959 1400905 1389809 1416569 1402959 1414998 1397510 1403570
1398782 1425294 1395871 1415545 1396661 1408827 1396578 1413940 1385842 1412084
1410787 1389944 1395654 1392580 1400412 1392653 1404503 1395197 1400111 1397917
1433637 1395602 1415162 1397540 1404055 1403213 1418843 1415625 1406476 1416766
1430944 1396947 1395281 1405680 1415661 1405027 1420947 1403780 1401405 1385218
1397050 1397721 1395464 1400580 1391885 1406463 1404080 1434562 1399548 1397347
1393351 1401256 1425077 1395779 1399501 1405212 1401776 1403159 1386425 1408381
1414500 1406785 1395996 1404508 1418131 1396965 1416503 1418343 1434041 1414912
1409539 1404085 1398737 1401971 1403821 1403490 1417833 1407924 1403386 1414721
1406877 1399508 1386470 1416935 1407935 1397397 1398797 1427337 1397961 1400335
1392212 1403485 1398353 1407128 1396716 1401178 1398133 1402724 1402052 1414994
1420134 1435721 1399607 1432821 1403566 1403252 1397278 1396447 1427209 1417239
1385722 1386063 1405059 1402496 1407869 1399238 1393754 1422199 1403489 1405654
1419264 1396958 1395686 1416636 1418610 1401210 1403842 1404138 1404863 1426769
1408366 1398821 1399002 1399979 1403091 1405100 1424880 1388376 1405475 1416867
1403473 1403733 1419152 1416690 1386661 1402672 1432678 1399397 1419047 1417291
```
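For reference, the combination described under “More detail on method” can be sketched on a toy problem. This is only an illustration and not the code from #128 or #129. `sign_ns` is a crude stand-in for the Newton-Schulz (NS) step, parameters are plain Python lists, and the class names, `k`, the learning rates, the momentum value, and the EMA convention (the coefficient weights the new value) are all assumptions chosen for the example. The outer step follows the usual SGD-with-Nesterov formula, and whether #128 matches these details is not confirmed here.

```python
def ema(prev, new, coeff):
    """Elementwise EMA; `coeff` weights the new value (assumed convention)."""
    return [(1 - coeff) * p + coeff * n for p, n in zip(prev, new)]


def sign_ns(vec):
    """Toy stand-in for Newton-Schulz orthogonalization (NOT the real NS)."""
    return [(v > 0) - (v < 0) for v in vec]


class SmoothedInner:
    """Inner optimizer: final_update = EMA(NS(EMA(grads)))."""

    def __init__(self, dim, lr=0.01, grad_coeff=0.5, update_coeff=0.2, ns=sign_ns):
        self.lr, self.ns = lr, ns
        self.grad_coeff, self.update_coeff = grad_coeff, update_coeff
        self.grad_ema = [0.0] * dim
        self.update_ema = [0.0] * dim

    def step(self, params, grads):
        self.grad_ema = ema(self.grad_ema, grads, self.grad_coeff)
        muon_update = self.ns(self.grad_ema)
        # The extra smoothing: an EMA over the post-NS update.
        self.update_ema = ema(self.update_ema, muon_update, self.update_coeff)
        return [p - self.lr * u for p, u in zip(params, self.update_ema)]


class LookaheadOuter:
    """Outer wrapper: every k inner steps, feed the displacement to SGD."""

    def __init__(self, params, k=4, lr=0.7, momentum=0.5, nesterov=True):
        self.slow = list(params)
        self.k, self.lr = k, lr
        self.momentum, self.nesterov = momentum, nesterov
        self.buf = [0.0] * len(params)
        self.count = 0

    def maybe_step(self, fast):
        self.count += 1
        if self.count % self.k:
            return fast  # still inside the inner loop
        # Displacement of the fast weights, used as a pseudo-gradient.
        grad = [s - f for s, f in zip(self.slow, fast)]
        self.buf = [self.momentum * b + g for b, g in zip(self.buf, grad)]
        if self.nesterov:
            step = [g + self.momentum * b for g, b in zip(grad, self.buf)]
        else:
            step = self.buf
        self.slow = [s - self.lr * d for s, d in zip(self.slow, step)]
        return list(self.slow)  # the inner optimizer restarts from the slow weights


def train(steps=600):
    """Minimize f(x) = 0.5 * ||x||^2, whose gradient is x itself."""
    params = [1.0, -2.0]
    inner, outer = SmoothedInner(dim=2), LookaheadOuter(params)
    for _ in range(steps):
        params = inner.step(params, grads=params)
        params = outer.maybe_step(params)
    return params
```

With `k=1`, `momentum=0.0` and `nesterov=False`, the outer step reduces to interpolating between the previous and the new weights, which is the sense in which the lookahead wrapper resembles an EMA on the update.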