""" utils.grokfast — accelerated grokking by amplifying slow-varying gradient components (Lee et al. 2024, arXiv:2405.20233). Maintain an EMA of gradients across steps; the slow-EMA component corresponds to the generalising circuit. Adding it back into the live gradient (scaled by `lamb`) accelerates the grokking transition 20-100×. """ from __future__ import annotations def gradfilter_ema(model, grads_ema, alpha: float = 0.98, lamb: float = 2.0): """ Call this AFTER `loss.backward()` and BEFORE `optimizer.step()`. Args: model: the network whose gradients we are filtering. grads_ema: dict {param_name: ema_grad}, or None on the first call. alpha: EMA decay (0.98 → very slow, emphasises persistent grads). lamb: amplification factor for the slow component. Returns: Updated `grads_ema` dict — pass it back in on the next step. """ if grads_ema is None: grads_ema = {} for name, p in model.named_parameters(): if p.requires_grad and p.grad is not None: if name not in grads_ema: grads_ema[name] = p.grad.data.detach().clone() else: grads_ema[name] = ( grads_ema[name] * alpha + p.grad.data.detach() * (1 - alpha) ) p.grad.data = p.grad.data + grads_ema[name] * lamb return grads_ema