Kernels
File size: 9,049 Bytes
10848ab
 
 
 
33929c0
 
 
 
 
 
 
 
10848ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8e2c81
 
 
 
 
 
 
 
10848ab
 
e8e2c81
 
 
10848ab
e8e2c81
 
10848ab
e8e2c81
10848ab
 
 
 
e8e2c81
 
10848ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8e2c81
 
 
10848ab
 
 
 
 
 
33929c0
 
 
10848ab
 
 
 
 
 
 
 
 
 
 
 
 
 
33929c0
 
 
 
 
 
 
10848ab
33929c0
10848ab
e8e2c81
 
33929c0
 
 
10848ab
33929c0
 
 
 
 
 
 
10848ab
33929c0
10848ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8e2c81
 
10848ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313d56a
10848ab
 
313d56a
10848ab
 
 
 
 
 
 
 
 
 
 
 
313d56a
10848ab
 
313d56a
10848ab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
from itertools import repeat
from math import inf, sqrt

import numpy as np
import torch

from .matmul_transpose_triton import matmul_transpose_assign

# Dtype that the Newton-Schulz routines in this file require of their input
# tensor (each of them asserts `G.dtype == COMM_DTYPE`).
COMM_DTYPE = torch.bfloat16
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably a default consumed by importing modules -- confirm against callers.
DEFAULT_CHUNK_SIZE_RATIO = 4


def _optimal_quintic(l, u, max_iter=1000):
    """
    Use the simplified Remez algorithm to find the optimal odd quintic approximant
    to the constant function x -> 1 over the interval [l, u].

    Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum
    approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the
    two interior equioscillation nodes q, r until convergence. Returns the
    closed-form equioscillating solution when l ≈ u.

    Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite
    (NaN or inf). Raises RuntimeError if convergence is not reached within
    max_iter iterations.
    """
    assert 0 <= l <= u
    if 1 - 5e-6 <= l / u:
        return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
    q = (3 * l + u) / 4
    r = (l + 3 * u) / 4
    E = inf
    for _ in range(max_iter):
        old_E = E
        LHS = np.array(
            [
                [l, l**3, l**5, 1],
                [q, q**3, q**5, -1],
                [r, r**3, r**5, 1],
                [u, u**3, u**5, -1],
            ]
        )
        a, b, c, E = np.linalg.solve(LHS, np.ones(4))
        if not np.all(np.isfinite([a, b, c, E])):
            raise ValueError(
                f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
            )
        q, r = np.sqrt(
            (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
        )
        if not np.all(np.isfinite([q, r])):
            raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
        if abs(old_E - E) <= 1e-15:
            break
    else:
        raise RuntimeError(
            f"_optimal_quintic: did not converge after {max_iter} iterations"
        )
    return float(a), float(b), float(c)


def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
    """
    Compute the Polar Express coefficient series for `num_iters` quintic iterations.

    Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that
    compose to map singular values from [l, 1] toward 1. At each step:
      1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion`
         prevents near-zero singular values from stalling by raising the effective
         lower bound; if it is active (cushion*u > l), the coefficients are
         rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u].
      2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the
         last iteration, providing numerical headroom at the cost of a slightly slower
         final convergence step.
      3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1).

    Returns a list of (a, b, c) tuples, one per iteration.

    Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
    Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
    """
    u = 1
    assert 0 <= l <= u
    safety_factor = 1 + safety_factor_eps
    coefficients = []
    for iter in range(num_iters):
        a, b, c = _optimal_quintic(max(l, cushion * u), u)
        if cushion * u > l:
            pl = a * l + b * l**3 + c * l**5
            pu = a * u + b * u**3 + c * u**5
            rescaler = 2 / (pl + pu)
            a *= rescaler
            b *= rescaler
            c *= rescaler
        if iter < num_iters - 1:
            a /= safety_factor
            b /= safety_factor**3
            c /= safety_factor**5
        coefficients.append((a, b, c))
        l = a * l + b * l**3 + c * l**5
        u = 2 - l
    return coefficients


# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz
# iterations. Each tuple comes from `_optimal_composition`, which fits a
# minimax-optimal (Remez/equioscillation) odd quintic approximant to x->1 over
# the current singular-value interval. The list is computed once at import
# time and reused by every call to the Newton-Schulz routines in this module.
#
# Contrast with the former hardcoded NS coefficients (5 fixed tuples):
#   - Former: empirically tuned to maximize slope at zero; did not converge
#     singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead
#     of the true polar factor UV^T.
#   - Polar Express: analytically optimal per step, adapting to the shrinking
#     singular-value interval [l, u] as iterations progress; drives the
#     singular values toward 1, so the result approaches the polar factor UV^T.
#
# Parameters: initial lower bound l=1e-3, a 1% safety deflation on all but the
# last step, and a cushion of 0.02 on the fitting interval.
_coeffs_list = _optimal_composition(
    l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
)


# This code is adapted from:
#   KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py)
#   NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress)
#   matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon)
@torch.no_grad()
def _zeropower_via_newtonschulz5(G, steps):
    """
    Approximate the polar factor of a 2D matrix G with the Polar Express method.

    Runs `steps` quintic iterations X <- a*X + b*X^3 + c*X^5 (evaluated as
    a*X + (b*A + c*A^2) @ X with A = X @ X^T), taking (a, b, c) for each step
    from the precomputed `_coeffs_list`. Every coefficient set is the optimal
    odd quintic approximant to x -> 1 on the current singular-value interval,
    so the composition pushes the singular values toward 1 and X toward the
    orthogonal factor of the polar decomposition G = UP.

    `_coeffs_list` holds 10 coefficient sets (l=1e-3, safety_factor_eps=1e-2,
    cushion=0.02). When `steps` is larger than that, the last set is reused
    for the extra iterations.

    Args:
        G: 2D tensor of dtype COMM_DTYPE. No dtype conversion is performed.
        steps: Number of quintic iterations to run.

    Returns:
        A tensor with the same shape as G holding the approximate polar factor.

    Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
    Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
    """
    assert G.ndim == 2
    assert G.dtype == COMM_DTYPE

    # Iterate on the wide orientation so the square work buffers use the
    # smaller of the two dimensions; the result is transposed back at the end.
    is_tall = G.size(0) > G.size(1)
    X = G.T if is_tall else G

    # Scale by the Frobenius norm (plus a small epsilon against division by zero).
    X = X / (X.norm() + 1e-7)

    scheduled = _coeffs_list[:steps]
    extra = list(repeat(_coeffs_list[-1], steps - len(_coeffs_list)))

    side = X.size(0)
    gram = torch.empty(side, side, dtype=X.dtype, device=X.device)
    gram_sq = torch.empty(side, side, dtype=X.dtype, device=X.device)
    for a, b, c in scheduled + extra:
        # NOTE(review): matmul_transpose_assign(src, dst) presumably writes
        # src @ src^T into dst -- confirm in matmul_transpose_triton.
        matmul_transpose_assign(X, gram)
        matmul_transpose_assign(gram, gram_sq)
        # gram <- b * gram + c * gram_sq, computed in place.
        gram.mul_(b).add_(gram_sq, alpha=c)
        # X <- a * X + gram @ X
        X = torch.addmm(X, gram, X, alpha=1.0, beta=a)

    return X.T if is_tall else X


@torch.no_grad()
def _zeropower_via_newtonschulz5_batched(G, steps):
    """Approximate the polar factor of every matrix in a 3D (E, out, in) batch.

    Follows the same iteration as ``_zeropower_via_newtonschulz5``, but relies
    on ``torch.bmm`` / ``torch.baddbmm`` rather than the 2D Triton kernel so
    that all E matrices are handled by batched calls.

    Args:
        G: 3D tensor of dtype COMM_DTYPE.
        steps: Number of quintic iterations. Coefficients come from
            ``_coeffs_list``; beyond its length the last entry is reused.

    Returns:
        A tensor with the same shape as G.
    """
    assert G.ndim == 3
    assert G.dtype == COMM_DTYPE

    # Work on the wide orientation of the trailing two dims; undo at the end.
    is_tall = G.size(1) > G.size(2)
    X = G.transpose(-2, -1) if is_tall else G

    # Normalise each matrix in the batch by its own Frobenius norm.
    per_matrix_norm = X.norm(dim=(-2, -1), keepdim=True)
    X = X / (per_matrix_norm + 1e-7)

    scheduled = _coeffs_list[:steps]
    extra = list(repeat(_coeffs_list[-1], steps - len(_coeffs_list)))
    for a, b, c in scheduled + extra:
        gram = torch.bmm(X, X.transpose(-2, -1))
        gram_sq = torch.bmm(gram, gram.transpose(-2, -1))
        # gram <- b * gram + c * gram_sq, computed in place.
        gram.mul_(b).add_(gram_sq, alpha=c)
        # X <- a * X + gram @ X
        X = torch.baddbmm(X, gram, X, alpha=1.0, beta=a)

    return X.transpose(-2, -1) if is_tall else X


# Cache of torch.compile-wrapped Newton-Schulz callables, keyed by the shape of
# the input tensor. Shared by the 2D and the batched (3D) public wrappers.
_ns_per_shape: dict[tuple[int, ...], callable] = {}
# When True the public wrappers go through torch.compile; when False they call
# the eager implementations directly. Changed via set_ns_compile().
_use_compile = True


def set_ns_compile(enabled: bool) -> None:
    """Toggle torch.compile for Newton-Schulz iteration.

    Sets the module-level ``_use_compile`` flag that the public
    ``zeropower_via_newtonschulz5*`` wrappers consult on every call. Only the
    flag is changed; callables already stored in ``_ns_per_shape`` are left
    in place.

    Args:
        enabled: True to route calls through torch.compile, False to run the
            eager implementations.
    """
    global _use_compile
    _use_compile = enabled


def zeropower_via_newtonschulz5(G, steps=5):
    """Newton-Schulz polar factor of a 2D tensor, optionally via torch.compile.

    With compilation disabled (see ``set_ns_compile``) this simply forwards to
    ``_zeropower_via_newtonschulz5``. Otherwise a compiled version of that
    function is created once per input shape, stored in ``_ns_per_shape``,
    and reused on later calls with the same shape.

    Args:
        G: Input tensor, forwarded unchanged to the implementation.
        steps: Number of quintic iterations to run.

    Returns:
        The implementation's result. On the compiled path a clone is returned
        (presumably so the caller does not alias memory owned by the CUDA
        graph -- confirm against the torch.compile cudagraph documentation).
    """
    if _use_compile:
        shape = G.shape
        compiled = _ns_per_shape.get(shape)
        if compiled is None:
            inductor_options = {
                "triton.cudagraphs": True,
                "shape_padding": False
            }
            compiled = torch.compile(_zeropower_via_newtonschulz5,
                                     options=inductor_options)
            _ns_per_shape[shape] = compiled
        torch.compiler.cudagraph_mark_step_begin()
        return compiled(G, steps).clone()
    return _zeropower_via_newtonschulz5(G, steps)


def zeropower_via_newtonschulz5_batched(G, steps=5):
    """Compile-cached batched Newton-Schulz for 3D expert tensors.

    With compilation disabled (see ``set_ns_compile``) this forwards to
    ``_zeropower_via_newtonschulz5_batched``. Otherwise a compiled version of
    that function is created once per input shape, stored in
    ``_ns_per_shape``, and reused on later calls with the same shape.

    Args:
        G: Input tensor, forwarded unchanged to the implementation.
        steps: Number of quintic iterations to run.

    Returns:
        The implementation's result. On the compiled path a clone is returned
        (presumably so the caller does not alias memory owned by the CUDA
        graph -- confirm against the torch.compile cudagraph documentation).
    """
    if _use_compile:
        shape = G.shape
        compiled = _ns_per_shape.get(shape)
        if compiled is None:
            inductor_options = {
                "triton.cudagraphs": True,
                "shape_padding": False
            }
            compiled = torch.compile(_zeropower_via_newtonschulz5_batched,
                                     options=inductor_options)
            _ns_per_shape[shape] = compiled
        torch.compiler.cudagraph_mark_step_begin()
        return compiled(G, steps).clone()
    return _zeropower_via_newtonschulz5_batched(G, steps)