Instructions to use Motif-Technologies/optimizer with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use Motif-Technologies/optimizer with Kernels:
# !pip install kernels
from kernels import get_kernel
kernel = get_kernel("Motif-Technologies/optimizer")
- Notebooks
- Google Colab
- Kaggle
File size: 9,049 Bytes
from itertools import repeat
from math import inf, sqrt
import numpy as np
import torch
from .matmul_transpose_triton import matmul_transpose_assign
# Dtype the Newton-Schulz routines in this file require for their input: both
# `_zeropower_via_newtonschulz5` and its batched variant assert
# `G.dtype == COMM_DTYPE`.
COMM_DTYPE = torch.bfloat16
# NOTE(review): not referenced in the visible part of this file; presumably a
# default chunk-size ratio consumed elsewhere -- confirm against its users.
DEFAULT_CHUNK_SIZE_RATIO = 4
def _optimal_quintic(l, u, max_iter=1000):
"""
Use the simplified Remez algorithm to find the optimal odd quintic approximant
to the constant function x -> 1 over the interval [l, u].
Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum
approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the
two interior equioscillation nodes q, r until convergence. Returns the
closed-form equioscillating solution when l ≈ u.
Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite
(NaN or inf). Raises RuntimeError if convergence is not reached within
max_iter iterations.
"""
assert 0 <= l <= u
if 1 - 5e-6 <= l / u:
return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
q = (3 * l + u) / 4
r = (l + 3 * u) / 4
E = inf
for _ in range(max_iter):
old_E = E
LHS = np.array(
[
[l, l**3, l**5, 1],
[q, q**3, q**5, -1],
[r, r**3, r**5, 1],
[u, u**3, u**5, -1],
]
)
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
if not np.all(np.isfinite([a, b, c, E])):
raise ValueError(
f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}"
)
q, r = np.sqrt(
(-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c)
)
if not np.all(np.isfinite([q, r])):
raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}")
if abs(old_E - E) <= 1e-15:
break
else:
raise RuntimeError(
f"_optimal_quintic: did not converge after {max_iter} iterations"
)
return float(a), float(b), float(c)
def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
"""
Compute the Polar Express coefficient series for `num_iters` quintic iterations.
Builds a sequence of per-step optimal odd quintic coefficients (a, b, c) that
compose to map singular values from [l, 1] toward 1. At each step:
1. Solves `_optimal_quintic` on [max(l, cushion*u), u]. The `cushion`
prevents near-zero singular values from stalling by raising the effective
lower bound; if it is active (cushion*u > l), the coefficients are
rescaled so that p(l) and p(u) are centered around 1 w.r.t. the true [l, u].
2. Deflates the coefficients by (1 + safety_factor_eps)^degree for all but the
last iteration, providing numerical headroom at the cost of a slightly slower
final convergence step.
3. Advances the interval: l <- p(l), u <- 2 - p(l) (by symmetry of p around 1).
Returns a list of (a, b, c) tuples, one per iteration.
Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
"""
u = 1
assert 0 <= l <= u
safety_factor = 1 + safety_factor_eps
coefficients = []
for iter in range(num_iters):
a, b, c = _optimal_quintic(max(l, cushion * u), u)
if cushion * u > l:
pl = a * l + b * l**3 + c * l**5
pu = a * u + b * u**3 + c * u**5
rescaler = 2 / (pl + pu)
a *= rescaler
b *= rescaler
c *= rescaler
if iter < num_iters - 1:
a /= safety_factor
b /= safety_factor**3
c /= safety_factor**5
coefficients.append((a, b, c))
l = a * l + b * l**3 + c * l**5
u = 2 - l
return coefficients
# Precomputed Polar Express coefficients (a, b, c) for 10 quintic Newton-Schulz
# iterations. Each tuple is the minimax-optimal (Remez / equioscillation) odd
# quintic approximant to x -> 1 over the current singular-value interval. The
# list is computed once, when this module is imported, and reused by every
# call to the Newton-Schulz routines below.
#
# Contrast with the former hardcoded NS coefficients (5 fixed tuples):
# - Former: empirically tuned to maximize slope at zero; did not converge
#   singular values to 1, yielding US'V^T with S' ~ Uniform(0.5, 1.5) instead
#   of the true polar factor UV^T.
# - Polar Express: analytically optimal per step, adapting to the shrinking
#   singular-value interval [l, u] as iterations progress; drives all
#   singular values toward 1, approximating the true polar factor UV^T.
#
# Parameters: initial lower bound l=1e-3, 1% safety deflation on all but the
# last step, and a cushion of 0.02 * u on the fitted lower bound.
_coeffs_list = _optimal_composition(
l=1e-3, num_iters=10, safety_factor_eps=1e-2, cushion=0.02
)
# This code is adapted from:
# KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py)
# NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress)
# matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon)
@torch.no_grad()
def _zeropower_via_newtonschulz5(G, steps):
    """
    Approximate the polar factor of a 2D matrix with Polar Express iterations.

    Runs `steps` quintic updates X <- a*X + b*X^3 + c*X^5 (odd matrix powers,
    i.e. X <- a*X + (b*A + c*A^2) @ X with A = X @ X^T), taking (a, b, c) for
    each step from the precomputed `_coeffs_list`. When `steps` is larger than
    the length of `_coeffs_list`, the last coefficient tuple is reused for the
    remaining steps.

    The input is first normalised by its norm (plus 1e-7). Matrices with more
    rows than columns are processed transposed and transposed back on return,
    so the square work buffers have the size of the smaller dimension.

    Args:
        G: 2D tensor of dtype `COMM_DTYPE`.
        steps: Number of quintic iterations to apply.

    Returns:
        Tensor with the same shape as `G`.

    Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
    Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
    """
    assert G.ndim == 2
    assert G.dtype == COMM_DTYPE

    tall = G.size(0) > G.size(1)
    X = G.T if tall else G  # no manual typecast

    X = X / (X.norm() + 1e-7)

    # Per-step coefficients, padded with the last tuple if steps is larger.
    schedule = list(_coeffs_list[:steps])
    schedule.extend(repeat(_coeffs_list[-1], steps - len(_coeffs_list)))

    n = X.size(0)
    gram = X.new_empty((n, n))
    gram_sq = X.new_empty((n, n))

    for a, b, c in schedule:
        # presumably writes X @ X^T into `gram` (flash-muon kernel) -- the
        # batched variant in this file computes the same product with bmm.
        matmul_transpose_assign(X, gram)
        matmul_transpose_assign(gram, gram_sq)
        # gram <- b * A + c * A^2
        gram.mul_(b).add_(gram_sq, alpha=c)
        # X <- a * X + gram @ X
        X = torch.addmm(X, gram, X, alpha=1.0, beta=a)

    return X.T if tall else X
@torch.no_grad()
def _zeropower_via_newtonschulz5_batched(G, steps):
    """Approximate the polar factor of every matrix in a 3D batch.

    Batched counterpart of ``_zeropower_via_newtonschulz5``: the same quintic
    iteration with coefficients from ``_coeffs_list``, but implemented with
    ``torch.bmm`` / ``torch.baddbmm`` so that all matrices along the leading
    dimension are processed in one batched call. Each matrix is normalised by
    its own Frobenius norm (plus 1e-7) before iterating.

    Args:
        G: 3D tensor of dtype ``COMM_DTYPE``; the last two dimensions are the
            matrix dimensions.
        steps: Number of quintic iterations to apply. If it exceeds the length
            of ``_coeffs_list`` the last coefficient tuple is repeated.

    Returns:
        Tensor with the same shape as ``G``.
    """
    assert G.ndim == 3
    assert G.dtype == COMM_DTYPE

    # Work on the orientation whose row count is the smaller dimension.
    tall = G.size(1) > G.size(2)
    X = G.transpose(-2, -1) if tall else G

    # Normalise each matrix of the batch independently.
    per_matrix_norm = X.norm(dim=(-2, -1), keepdim=True)
    X = X / (per_matrix_norm + 1e-7)

    schedule = list(_coeffs_list[:steps])
    schedule.extend(repeat(_coeffs_list[-1], steps - len(_coeffs_list)))

    for a, b, c in schedule:
        gram = torch.bmm(X, X.transpose(-2, -1))
        gram_sq = torch.bmm(gram, gram.transpose(-2, -1))
        # gram <- b * A + c * A^2
        gram.mul_(b).add_(gram_sq, alpha=c)
        # X <- a * X + gram @ X
        X = torch.baddbmm(X, gram, X, alpha=1.0, beta=a)

    return X.transpose(-2, -1) if tall else X
# Cache of torch.compile-wrapped Newton-Schulz callables, keyed by the shape
# of the input tensor. Filled lazily by the public zeropower_* wrappers in
# this file.
_ns_per_shape: dict[tuple[int, ...], callable] = {}
# Module-wide switch read by the public zeropower_* wrappers: True selects the
# compiled path, False calls the eager implementations directly.
_use_compile = True
def set_ns_compile(enabled: bool) -> None:
    """Toggle torch.compile for Newton-Schulz iteration.

    Args:
        enabled: New value for the module-level `_use_compile` flag.
    """
    global _use_compile
    _use_compile = enabled
def zeropower_via_newtonschulz5(G, steps=5):
    """Compile-cached Newton-Schulz polar-factor approximation for 2D tensors.

    When compilation is disabled (see `set_ns_compile`) this simply calls
    `_zeropower_via_newtonschulz5`. Otherwise a `torch.compile`d version of
    that function is created once per input shape, stored in `_ns_per_shape`,
    and reused on later calls with the same shape.

    Args:
        G: Input tensor, forwarded unchanged to the underlying implementation.
        steps: Number of quintic iterations (default 5).

    Returns:
        The result of the underlying implementation; on the compiled path a
        clone of the compiled function's output is returned.
    """
    if _use_compile:
        shape_key = G.shape
        compiled_fn = _ns_per_shape.get(shape_key)
        if compiled_fn is None:
            compile_options = {
                "triton.cudagraphs": True,
                "shape_padding": False
            }
            compiled_fn = torch.compile(_zeropower_via_newtonschulz5,
                                        options=compile_options)
            _ns_per_shape[shape_key] = compiled_fn

        # Mark a new cudagraph step before invoking the compiled function.
        torch.compiler.cudagraph_mark_step_begin()
        # NOTE(review): the clone presumably detaches the result from
        # cudagraph-owned output memory that a later replay may overwrite.
        return compiled_fn(G, steps).clone()

    return _zeropower_via_newtonschulz5(G, steps)
def zeropower_via_newtonschulz5_batched(G, steps=5):
    """Compile-cached batched Newton-Schulz for 3D expert tensors.

    When compilation is disabled (see `set_ns_compile`) this simply calls
    `_zeropower_via_newtonschulz5_batched`. Otherwise a `torch.compile`d
    version of that function is created once per input shape, stored in
    `_ns_per_shape`, and reused on later calls with the same shape.

    Args:
        G: Input tensor, forwarded unchanged to the underlying implementation.
        steps: Number of quintic iterations (default 5).

    Returns:
        The result of the underlying implementation; on the compiled path a
        clone of the compiled function's output is returned.
    """
    if _use_compile:
        shape_key = G.shape
        compiled_fn = _ns_per_shape.get(shape_key)
        if compiled_fn is None:
            compile_options = {
                "triton.cudagraphs": True,
                "shape_padding": False
            }
            compiled_fn = torch.compile(
                _zeropower_via_newtonschulz5_batched,
                options=compile_options)
            _ns_per_shape[shape_key] = compiled_fn

        # Mark a new cudagraph step before invoking the compiled function.
        torch.compiler.cudagraph_mark_step_begin()
        # NOTE(review): the clone presumably detaches the result from
        # cudagraph-owned output memory that a later replay may overwrite.
        return compiled_fn(G, steps).clone()

    return _zeropower_via_newtonschulz5_batched(G, steps)
|