# Kernels
# Commit e8e2c81 (unverified) by dongseokmotif:
#   feat: extend QK-Clip to support MLA (MuonClip Algorithm 1) [skip-build] (#28)
from collections.abc import Callable
from itertools import repeat
from math import inf, sqrt

import numpy as np
import torch

from .matmul_transpose_triton import matmul_transpose_assign
# Required dtype of every Newton-Schulz input: both `_zeropower_via_newtonschulz5`
# variants assert `G.dtype == COMM_DTYPE`.
# NOTE(review): the name suggests it is also the dtype used for communication
# elsewhere in the optimizer — confirm before changing.
COMM_DTYPE = torch.bfloat16
# NOTE(review): not referenced by the code visible in this module; presumably
# consumed by importers (e.g. to derive a chunk size) — confirm before changing.
DEFAULT_CHUNK_SIZE_RATIO = 4
def _optimal_quintic(l, u, max_iter=1000):
"""
Use the simplified Remez algorithm to find the optimal odd quintic approximant
to the constant function x -> 1 over the interval [l, u].
Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum
approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the
two interior equioscillation nodes q, r until convergence. Returns the
closed-form equioscillating solution when l ≈ u.
Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite
(NaN or inf). Raises RuntimeError if convergence is not reached within
max_iter iterations.
"""
assert 0 <= l <= u
if 1 - 5e-6 <= l / u:
return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
q = (3 * l + u) / 4
r = (l + 3 * u) / 4
E = inf
for _ in range(max_iter):
old_E = E
LHS = np.array(
[
[l, l**3, l**5, 1],
[q, q**3, q**5, -1],
[r, r**3, r**5, 1],
[u, u**3, u**5, -1],
]
)
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
if not np.all(np.isfinite([a, b, c, E])):
raise ValueError(
f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
)
q, r = np.sqrt(
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
)
if not np.all(np.isfinite([q, r])):
raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
if abs(old_E - E) <= 1e-15:
break
else:
raise RuntimeError(
f"_optimal_quintic: did not converge after {max_iter} iterations"
)
return float(a), float(b), float(c)
def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
    """
    Compute the Polar Express coefficient series for `num_iters` quintic iterations.

    Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that
    compose to map singular values from [l, 1] toward 1. At each step:
      1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion`
         prevents near-zero singular values from stalling by raising the
         effective lower bound. If it is active (cushion*u > l), the
         coefficients are rescaled so that p(l) + p(u) == 2, i.e. p(l) and p(u)
         are centered around 1 on the true interval [l, u].
      2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but
         the last iteration. Dividing (a, b, c) by (s, s^3, s^5) is the same as
         evaluating p at x / s, which provides numerical headroom at the cost of
         a slightly slower final convergence step.
      3. Advances the interval: l <- p(l), u <- 2 - p(l). The new interval is
         taken to be symmetric around 1, on the premise that |p(l) - 1| is
         (approximately) the largest deviation of p from 1 on [l, u].

    Args:
        l: assumed lower bound on the initial singular values (the upper bound
            is 1); must satisfy 0 <= l <= 1.
        num_iters: number of coefficient tuples to produce (0 yields []).
        safety_factor_eps: s - 1, where s is the deflation factor in step 2.
        cushion: fraction of u used as a floor for the lower bound in step 1.

    Returns:
        A list of (a, b, c) tuples, one per iteration.

    Raises:
        ValueError: if l is outside [0, 1]. Errors from `_optimal_quintic`
            propagate unchanged.

    Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
    Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
    """
    u = 1
    # Explicit check instead of `assert` so validation survives `python -O`.
    if not 0 <= l <= u:
        raise ValueError(
            f"_optimal_composition: expected 0 <= l <= 1, got l={l}"
        )
    safety_factor = 1 + safety_factor_eps
    coefficients = []
    # `step` rather than `iter`, to avoid shadowing the builtin.
    for step in range(num_iters):
        a, b, c = _optimal_quintic(max(l, cushion * u), u)
        if cushion * u > l:
            # The cushion moved the solve interval's lower end above l; rescale
            # so p(l) and p(u) are centered around 1 on the true [l, u].
            pl = a * l + b * l**3 + c * l**5
            pu = a * u + b * u**3 + c * u**5
            rescaler = 2 / (pl + pu)
            a *= rescaler
            b *= rescaler
            c *= rescaler
        if step < num_iters - 1:
            # Equivalent to evaluating p at x / safety_factor; the final step
            # is left undeflated.
            a /= safety_factor
            b /= safety_factor**3
            c /= safety_factor**5
        coefficients.append((a, b, c))
        # Interval for the next step: [p(l), 2 - p(l)].
        l = a * l + b * l**3 + c * l**5
        u = 2 - l
    return coefficients
# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz
# iterations, computed once at import time by `_optimal_composition` and reused
# by every Newton-Schulz call. Each tuple is the minimax-optimal
# (Remez/equioscillation) odd quintic approximant to x -> 1 over the
# singular-value interval expected at that step. Every tuple except the last is
# deflated by the safety factor.
#
# Parameters passed below:
#   l=1e-3                  lower end of the assumed initial singular-value
#                           interval [l, 1]
#   num_iters=10            number of tuples. Callers requesting fewer steps
#                           use a prefix (the public wrappers default to
#                           steps=5); requesting more repeats the last tuple.
#   safety_factor_eps=1e-2  per-step deflation headroom (see
#                           `_optimal_composition`)
#   cushion=0.02            floor, relative to u, for the lower bound used in
#                           each Remez solve
#
# Contrast with the former hardcoded NS coefficients (5 fixed tuples):
# - Former: empirically tuned to maximize slope at zero; did not converge
#   singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead
#   of the true polar factor UV^T.
# - Polar Express: optimal per step, adapting to the shrinking singular-value
#   interval [l, u] as iterations progress. It drives singular values toward 1,
#   so the result approximates the polar factor UV^T (it remains a
#   finite-iteration approximation, computed in COMM_DTYPE).
_coeffs_list = _optimal_composition(
    l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
)
# This code is adapted from:
# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py)
# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress)
# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon)
@torch.no_grad()
def _zeropower_via_newtonschulz5(G, steps):
    """
    Compute the polar factor of G via the Polar Express method.

    Applies `steps` quintic iterations X <- aX + bX^3 + cX^5, where (a, b, c)
    are the Polar Express coefficients from `_coeffs_list`. Each step is the
    optimal odd quintic approximant to x -> 1 over the current singular-value
    interval, minimizing the maximum approximation error (Remez / minimax
    criterion). The composition maps singular values from [l, 1] to near 1,
    approximating the polar factor (the orthogonal factor in the polar
    decomposition G = UP).

    `_coeffs_list` is precomputed for 10 iterations (l=1e-3, safety_factor_eps=1e-2,
    cushion=0.02). If `steps` exceeds 10, the final coefficient set is repeated.

    Args:
        G: 2D tensor whose dtype must equal ``COMM_DTYPE`` (both asserted below).
            NOTE(review): `matmul_transpose_assign` comes from a Triton module,
            so G presumably must live on a CUDA device — confirm.
        steps: number of quintic iterations. It is used directly as a slice
            bound on `_coeffs_list`, so it is expected to be >= 0.
            NOTE(review): a negative value silently drops coefficient sets from
            the end (Python slice semantics) instead of raising.

    Returns:
        Tensor with the same shape as G, computed in G's dtype (no cast is
        performed).

    Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
    Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
    """
    assert len(G.shape) == 2
    assert G.dtype == COMM_DTYPE
    X = G # no manual typecast
    # Iterate in the "wide" orientation (rows <= cols) so the square work
    # buffers below are min(rows, cols) x min(rows, cols).
    if G.size(0) > G.size(1):
        X = X.T
    # Frobenius normalization: the Frobenius norm bounds the spectral norm, so
    # all singular values fall in [0, 1], the interval the coefficients were
    # designed for. The 1e-7 guards against division by zero.
    X = X / (X.norm() + 1e-7)
    # First `steps` coefficient sets, padded by repeating the last set when
    # steps > len(_coeffs_list).
    hs = _coeffs_list[:steps] + list(
        repeat(_coeffs_list[-1], steps - len(_coeffs_list))
    )
    # Preallocated square work buffers, reused by every iteration.
    buf1 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
    buf2 = torch.empty(X.size(0), X.size(0), dtype=X.dtype, device=X.device)
    # Perform the NS iterations
    for a, b, c in hs:
        # buf1 <- A = X @ X^T. Presumed kernel semantics, matching the batched
        # variant's torch.bmm(X, X^T) — confirm against matmul_transpose_triton.
        matmul_transpose_assign(X, buf1)
        # buf2 <- A @ A^T, which equals A^2 because A is symmetric.
        matmul_transpose_assign(buf1, buf2)
        # buf1 <- b*A + c*A^2, in place (buf1 no longer holds A after this).
        buf1.mul_(b).add_(buf2, alpha=c)
        # X <- a*X + (b*A + c*A^2) @ X, i.e. aX + b(XX^T)X + c(XX^T)^2 X.
        X = torch.addmm(X, buf1, X, alpha=1.0, beta=a)
    # Undo the orientation change so the result matches G's shape.
    if G.size(0) > G.size(1):
        X = X.T
    return X
@torch.no_grad()
def _zeropower_via_newtonschulz5_batched(G, steps):
    """Polar Express orthogonalization of a stack of expert matrices (E, out, in).

    Runs the same quintic coefficient schedule as ``_zeropower_via_newtonschulz5``,
    but on all E expert matrices at once via ``torch.bmm`` / ``torch.baddbmm``
    instead of the 2D Triton kernel. Each expert is normalized by its own
    Frobenius norm. The result has the same shape as ``G``.
    """
    assert len(G.shape) == 3
    assert G.dtype == COMM_DTYPE
    # Work in the "wide" orientation (out <= in) so the per-expert Gram
    # matrices are the smaller of the two possible sizes.
    tall = G.size(1) > G.size(2)
    X = G.transpose(-2, -1) if tall else G
    # Per-expert Frobenius normalization puts every singular value in [0, 1].
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Prefix of the precomputed schedule, padded with its last entry; a
    # non-positive pad count multiplies to an empty list.
    schedule = _coeffs_list[:steps] + [_coeffs_list[-1]] * (steps - len(_coeffs_list))
    for a, b, c in schedule:
        gram = torch.bmm(X, X.transpose(-2, -1))
        # gram is symmetric, so gram @ gram^T is its square.
        gram_sq = torch.bmm(gram, gram.transpose(-2, -1))
        # gram <- b*A + c*A^2 (in place), then X <- a*X + gram @ X.
        gram.mul_(b).add_(gram_sq, alpha=c)
        X = torch.baddbmm(X, gram, X, alpha=1.0, beta=a)
    return X.transpose(-2, -1) if tall else X
# Compiled Newton-Schulz callables, keyed by (uncompiled implementation, input
# shape). Including the implementation in the key keeps the 2D and batched
# entry points from ever handing out each other's compiled function. One entry
# is added per distinct (implementation, shape) pair; entries are never evicted.
_ns_per_shape: dict[tuple[Callable, tuple[int, ...]], Callable] = {}
# Read on every call by both public entry points; toggled via set_ns_compile().
_use_compile = True


def set_ns_compile(enabled: bool):
    """Toggle torch.compile for Newton-Schulz iteration.

    When disabled, ``zeropower_via_newtonschulz5`` and
    ``zeropower_via_newtonschulz5_batched`` call the eager implementations
    directly. Entries already in ``_ns_per_shape`` are kept and reused if
    compilation is re-enabled.
    """
    global _use_compile
    _use_compile = enabled


def _run_compiled_ns(fn, G, steps):
    """Run ``fn(G, steps)`` through a ``torch.compile`` wrapper cached per (fn, shape).

    The result is cloned because, with ``triton.cudagraphs`` enabled, outputs
    of the compiled function may live in memory that a later graph replay
    overwrites. The clone gives the caller an independent tensor.
    """
    key = (fn, tuple(G.shape))
    compiled = _ns_per_shape.get(key)
    if compiled is None:
        compiled = torch.compile(
            fn,
            options={
                "triton.cudagraphs": True,
                "shape_padding": False,
            },
        )
        _ns_per_shape[key] = compiled
    # Begin a new CUDA-graph step before invoking the compiled function.
    torch.compiler.cudagraph_mark_step_begin()
    return compiled(G, steps).clone()


def zeropower_via_newtonschulz5(G, steps=5):
    """Polar factor of a 2D tensor via Polar Express Newton-Schulz iterations.

    Dispatches to ``_zeropower_via_newtonschulz5``. Unless compilation was
    disabled with ``set_ns_compile(False)``, the call goes through a
    ``torch.compile`` wrapper cached per input shape.

    Args:
        G: 2D tensor; its dtype must be ``COMM_DTYPE`` (asserted by the
            implementation).
        steps: number of quintic iterations (default 5).

    Returns:
        Tensor with G's shape; on the compiled path it is a fresh clone.
    """
    if not _use_compile:
        return _zeropower_via_newtonschulz5(G, steps)
    return _run_compiled_ns(_zeropower_via_newtonschulz5, G, steps)


def zeropower_via_newtonschulz5_batched(G, steps=5):
    """Compile-cached batched Newton-Schulz for 3D expert tensors.

    Dispatches to ``_zeropower_via_newtonschulz5_batched``. Unless compilation
    was disabled with ``set_ns_compile(False)``, the call goes through a
    ``torch.compile`` wrapper cached per input shape.

    Args:
        G: 3D tensor (E, out, in); its dtype must be ``COMM_DTYPE`` (asserted
            by the implementation).
        steps: number of quintic iterations (default 5).

    Returns:
        Tensor with G's shape; on the compiled path it is a fresh clone.
    """
    if not _use_compile:
        return _zeropower_via_newtonschulz5_batched(G, steps)
    return _run_compiled_ns(_zeropower_via_newtonschulz5_batched, G, steps)