from itertools import repeat
from math import inf, sqrt

import numpy as np
import torch

from .matmul_transpose_triton import matmul_transpose_assign

COMM_DTYPE = torch.bfloat16
DEFAULT_CHUNK_SIZE_RATIO = 4


def _optimal_quintic(l, u, max_iter=1000):
    """
    Use the simplified Remez algorithm to find the optimal odd quintic
    approximant to the constant function x -> 1 over the interval [l, u].

    Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum
    approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the
    two interior equioscillation nodes q, r until convergence. Returns the
    closed-form equioscillating solution when l ≈ u.

    Args:
        l: Lower end of the interval; must satisfy 0 <= l <= u.
        u: Upper end of the interval; must be strictly positive.
        max_iter: Maximum number of node-update iterations before giving up.

    Returns:
        Tuple (a, b, c) of polynomial coefficients.

    Raises:
        ValueError: if u == 0, if the node-update discriminant is negative, or
            if any intermediate value (a, b, c, E, q, r) is non-finite
            (NaN or inf).
        RuntimeError: if convergence is not reached within max_iter iterations.
    """
    assert 0 <= l <= u
    if u == 0:
        # l / u below would otherwise fail with a bare ZeroDivisionError.
        raise ValueError("_optimal_quintic: degenerate interval, u must be > 0")
    if 1 - 5e-6 <= l / u:
        # Interval is numerically a single point: use the closed-form solution.
        return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)

    # Initial guess for the two interior equioscillation nodes.
    q = (3 * l + u) / 4
    r = (l + 3 * u) / 4
    E = inf
    for _ in range(max_iter):
        old_E = E
        # Equioscillation conditions p(x_i) + (-1)^i * E = 1 at l, q, r, u.
        LHS = np.array(
            [
                [l, l**3, l**5, 1],
                [q, q**3, q**5, -1],
                [r, r**3, r**5, 1],
                [u, u**3, u**5, -1],
            ]
        )
        a, b, c, E = np.linalg.solve(LHS, np.ones(4))
        if not np.all(np.isfinite([a, b, c, E])):
            raise ValueError(
                f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
            )
        # New interior nodes are the positive roots of
        # p'(x) = a + 3b x^2 + 5c x^4, a quadratic in x^2.
        discriminant = 9 * b**2 - 20 * a * c
        if discriminant < 0:
            # math.sqrt would raise an opaque "math domain error" here.
            raise ValueError(
                f"_optimal_quintic: negative discriminant {discriminant} "
                f"in node update (a={a}, b={b}, c={c})"
            )
        q, r = np.sqrt(
            (-3 * b + np.array([-1, 1]) * sqrt(discriminant)) / (10 * c)
        )
        if not np.all(np.isfinite([q, r])):
            raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
        if abs(old_E - E) <= 1e-15:
            break
    else:
        raise RuntimeError(
            f"_optimal_quintic: did not converge after {max_iter} iterations"
        )
    return float(a), float(b), float(c)


def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
    """
    Compute the Polar Express coefficient series for `num_iters` quintic iterations.

    Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that
    compose to map singular values from [l, 1] toward 1. At each step:

    1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion`
       prevents near-zero singular values from stalling by raising the effective
       lower bound; if it is active (cushion*u > l), the coefficients are
       rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u].
    2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the
       last iteration, providing numerical headroom at the cost of a slightly slower
       final convergence step.
    3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1).

    Args:
        l: Initial lower bound of the interval; must satisfy 0 <= l <= 1.
        num_iters: Number of quintic steps to generate coefficients for.
        safety_factor_eps: Deflation amount; 0 disables deflation.
        cushion: Fraction of u used as a floor for the lower bound passed to
            `_optimal_quintic`; 0 disables cushioning.

    Returns:
        A list of (a, b, c) tuples, one per iteration.

    Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
    Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
    """
    u = 1
    assert 0 <= l <= u
    safety_factor = 1 + safety_factor_eps
    coefficients = []
    # `step` rather than `iter`, so the builtin is not shadowed.
    for step in range(num_iters):
        a, b, c = _optimal_quintic(max(l, cushion * u), u)
        if cushion * u > l:
            # Cushion is active: recenter so p(l) and p(u) straddle 1 evenly.
            pl = a * l + b * l**3 + c * l**5
            pu = a * u + b * u**3 + c * u**5
            rescaler = 2 / (pl + pu)
            a *= rescaler
            b *= rescaler
            c *= rescaler
        if step < num_iters - 1:
            # Equivalent to evaluating p at x / safety_factor.
            a /= safety_factor
            b /= safety_factor**3
            c /= safety_factor**5
        coefficients.append((a, b, c))
        l = a * l + b * l**3 + c * l**5
        u = 2 - l
    return coefficients


# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz
# iterations. Each tuple is the minimax-optimal (Remez/equioscillation) odd quintic
# approximant to x->1 over the current singular-value interval, computed once at
# import time and reused across all optimizer steps.
#
# Contrast with the former hardcoded NS coefficients (5 fixed tuples):
#   - Former: empirically tuned to maximize slope at zero; did not converge
#     singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead
#     of the true polar factor UV^T.
#   - Polar Express: analytically optimal per step, adapting to the shrinking
#     singular-value interval [l, u] as iterations progress; converges all
#     singular values to 1, producing the exact polar factor UV^T.
_coeffs_list = _optimal_composition(
    l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
)


def _coeffs_for_steps(steps):
    """Return the list of (a, b, c) coefficient tuples for `steps` iterations.

    Takes the first `steps` entries of `_coeffs_list`; if `steps` exceeds its
    length, the final tuple is repeated to fill the remainder.

    Raises:
        ValueError: if `steps` is negative. Without this check a negative
            value would be interpreted as a slice from the end of
            `_coeffs_list` and silently run the wrong number of iterations.
    """
    if steps < 0:
        raise ValueError(f"steps must be non-negative, got {steps}")
    return _coeffs_list[:steps] + list(
        repeat(_coeffs_list[-1], steps - len(_coeffs_list))
    )


# This code is adapted from:
# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py)
# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress)
# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon)
@torch.no_grad()
def _zeropower_via_newtonschulz5(G, steps):
    """
    Compute the polar factor of G via the Polar Express method.

    Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c)
    are the Polar Express coefficients from `_coeffs_list`. Each step is the
    optimal odd quintic approximant to x -> 1 over the current singular-value
    interval, minimizing the maximum approximation error (Remez / minimax
    criterion). The composition maps singular values from [l, 1] to near 1,
    producing the polar factor (orthogonal factor in the polar decomposition
    G = UP).

    `_coeffs_list` is precomputed for 10 iterations (l=1e-3,
    safety_factor_eps=1e-2, cushion=0.02). If `steps` exceeds 10, the final
    coefficient set is repeated.

    Args:
        G: 2D tensor of dtype `COMM_DTYPE`.
        steps: Non-negative number of quintic iterations.

    Returns:
        Tensor with the same shape as G.

    Raises:
        ValueError: if `steps` is negative.

    Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
    Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
    """
    assert len(G.shape) == 2
    assert G.dtype == COMM_DTYPE
    X = G  # no manual typecast

    # Work on the wide orientation so the square buffers use the smaller dim.
    if G.size(0) > G.size(1):
        X = X.T
    X = X / (X.norm() + 1e-7)
    hs = _coeffs_for_steps(steps)
    buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
    buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
    # Perform the NS iterations
    for a, b, c in hs:
        # NOTE(review): matmul_transpose_assign presumably writes
        # `in @ in.T` into its second argument (the batched variant uses
        # bmm(X, X^T) at this step) — confirm against the kernel.
        matmul_transpose_assign(X, buf1)
        matmul_transpose_assign(buf1, buf2)
        buf1.mul_(b).add_(buf2, alpha=c)
        # X <- a * X + buf1 @ X
        X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)

    if G.size(0) > G.size(1):
        X = X.T
    return X


@torch.no_grad()
def _zeropower_via_newtonschulz5_batched(G, steps):
    """Batched polar factor computation for 3D (E, out, in) tensors.

    Same algorithm as ``_zeropower_via_newtonschulz5`` but uses
    ``torch.bmm`` / ``torch.baddbmm`` instead of the 2D Triton kernel,
    processing all E expert matrices in a single batched call.

    Args:
        G: 3D tensor of dtype `COMM_DTYPE`.
        steps: Non-negative number of quintic iterations.

    Returns:
        Tensor with the same shape as G.

    Raises:
        ValueError: if `steps` is negative.
    """
    assert len(G.shape) == 3
    assert G.dtype == COMM_DTYPE
    X = G

    if G.size(1) > G.size(2):
        X = X.transpose(-2, -1)

    # Per-expert Frobenius norm.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    hs = _coeffs_for_steps(steps)
    for a, b, c in hs:
        buf1 = torch.bmm(X, X.transpose(-2, -1))
        buf2 = torch.bmm(buf1, buf1.transpose(-2, -1))
        buf1.mul_(b).add_(buf2, alpha=c)
        # X <- a * X + buf1 @ X, per batch entry
        X = torch.baddbmm(X, buf1, X, alpha=1.0, beta=a)

    if G.size(1) > G.size(2):
        X = X.transpose(-2, -1)
    return X


# Cache of torch.compile'd kernels, keyed by (kernel function, input shape).
# The kernel is part of the key so the 2D and batched entry points can never
# pick up each other's entries.
_ns_per_shape: dict[tuple, callable] = {}
_use_compile = True


def set_ns_compile(enabled: bool):
    """Toggle torch.compile for Newton-Schulz iteration."""
    global _use_compile
    _use_compile = enabled


def _dispatch_newtonschulz(kernel, G, steps):
    """Run `kernel(G, steps)` eagerly or through a cached compiled version."""
    if not _use_compile:
        return kernel(G, steps)
    key = (kernel, G.shape)
    if key not in _ns_per_shape:
        _ns_per_shape[key] = torch.compile(
            kernel,
            options={"triton.cudagraphs": True, "shape_padding": False},
        )
    torch.compiler.cudagraph_mark_step_begin()
    # Clone so the caller gets its own copy of the compiled function's output.
    return _ns_per_shape[key](G, steps).clone()


def zeropower_via_newtonschulz5(G, steps=5):
    """Compile-cached Newton-Schulz (Polar Express) for a 2D tensor.

    Calls `_zeropower_via_newtonschulz5` directly when compilation is disabled
    via `set_ns_compile(False)`; otherwise runs a `torch.compile`'d version
    cached per input shape and returns a clone of its output.
    """
    return _dispatch_newtonschulz(_zeropower_via_newtonschulz5, G, steps)


def zeropower_via_newtonschulz5_batched(G, steps=5):
    """Compile-cached batched Newton-Schulz for 3D expert tensors."""
    return _dispatch_newtonschulz(_zeropower_via_newtonschulz5_batched, G, steps)