Instructions to use Motif-Technologies/optimizer with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use Motif-Technologies/optimizer with Kernels:
```python
# !pip install kernels
from kernels import get_kernel

kernel = get_kernel("Motif-Technologies/optimizer")
```
- Notebooks
- Google Colab
- Kaggle
from collections.abc import Callable
from itertools import repeat
from math import inf, sqrt

import numpy as np
import torch

from .matmul_transpose_triton import matmul_transpose_assign
# Dtype that inputs to the Newton-Schulz routines in this file must have
# (asserted there). The name suggests it is also the dtype used for
# communication — TODO confirm against the optimizer code.
COMM_DTYPE: torch.dtype = torch.bfloat16
# Not referenced in this part of the file; presumably a default chunk-size
# ratio consumed elsewhere — verify against callers.
DEFAULT_CHUNK_SIZE_RATIO: int = 4
| def _optimal_quintic(l, u, max_iter=1000): | |
| """ | |
| Use the simplified Remez algorithm to find the optimal odd quintic approximant | |
| to the constant function x -> 1 over the interval [l, u]. | |
| Returns (a, b, c) for p(x) = ax + bx^3 + cx^5 that minimizes the maximum | |
| approximation error max_{x in [l,u]} |p(x) - 1|. Iterates by updating the | |
| two interior equioscillation nodes q, r until convergence. Returns the | |
| closed-form equioscillating solution when l ≈ u. | |
| Raises ValueError if any intermediate value (a, b, c, E, q, r) is non-finite | |
| (NaN or inf). Raises RuntimeError if convergence is not reached within | |
| max_iter iterations. | |
| """ | |
| assert 0 <= l <= u | |
| if 1 - 5e-6 <= l / u: | |
| return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5) | |
| q = (3 * l + u) / 4 | |
| r = (l + 3 * u) / 4 | |
| E = inf | |
| for _ in range(max_iter): | |
| old_E = E | |
| LHS = np.array( | |
| [ | |
| [l, l**3, l**5, 1], | |
| [q, q**3, q**5, -1], | |
| [r, r**3, r**5, 1], | |
| [u, u**3, u**5, -1], | |
| ] | |
| ) | |
| a, b, c, E = np.linalg.solve(LHS, np.ones(4)) | |
| if not np.all(np.isfinite([a, b, c, E])): | |
| raise ValueError( | |
| f"_optimal_quintic: non-finite solve result a={a}, b={b}, c={c}, E={E}" | |
| ) | |
| q, r = np.sqrt( | |
| (-3 * b + np.array([-1, 1]) * sqrt(9 * b**2 - 20 * a * c)) / (10 * c) | |
| ) | |
| if not np.all(np.isfinite([q, r])): | |
| raise ValueError(f"_optimal_quintic: non-finite node update q={q}, r={r}") | |
| if abs(old_E - E) <= 1e-15: | |
| break | |
| else: | |
| raise RuntimeError( | |
| f"_optimal_quintic: did not converge after {max_iter} iterations" | |
| ) | |
| return float(a), float(b), float(c) | |
def _optimal_composition(l, num_iters, safety_factor_eps=0, cushion=0):
    """Build the Polar Express schedule of per-step odd quintic coefficients.

    Starting from the singular-value interval [l, 1], each of the `num_iters`
    steps does the following:

    1. Fits the minimax odd quintic p(x) = ax + bx^3 + cx^5 to x -> 1 on
       [max(l, cushion*u), u] via `_optimal_quintic`. When the cushion is
       active (cushion*u > l), it rescales (a, b, c) so that p(l) + p(u) == 2,
       centering the true endpoints around 1. The cushion keeps tiny singular
       values from stalling the fit.
    2. On every step except the last, deflates the coefficients by
       (1 + safety_factor_eps)^k for the degree-k term. This gives numerical
       headroom at the cost of slightly slower convergence.
    3. Advances the interval to [p(l), 2 - p(l)], treating the image as
       (approximately) symmetric about 1.

    Args:
        l: initial lower bound on the singular values, 0 <= l <= 1.
        num_iters: number of quintic steps to generate.
        safety_factor_eps: relative deflation applied to all but the last step.
        cushion: fraction of u used as a floor on the fitting interval.

    Returns:
        List of `num_iters` (a, b, c) tuples, in application order.

    Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
    Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
    """
    u = 1
    assert 0 <= l <= u
    safety_factor = 1 + safety_factor_eps

    def quintic(x, a, b, c):
        # p(x) = a*x + b*x^3 + c*x^5
        return a * x + b * x**3 + c * x**5

    schedule = []
    for step in range(num_iters):
        floor = cushion * u
        a, b, c = _optimal_quintic(max(l, floor), u)
        if floor > l:
            rescale = 2 / (quintic(l, a, b, c) + quintic(u, a, b, c))
            a, b, c = a * rescale, b * rescale, c * rescale
        if step != num_iters - 1:
            a, b, c = a / safety_factor, b / safety_factor**3, c / safety_factor**5
        schedule.append((a, b, c))
        l = quintic(l, a, b, c)
        u = 2 - l
    return schedule
# Polar Express coefficient schedule: ten (a, b, c) triples, one per quintic
# Newton-Schulz step. It is computed once at import time and shared by every
# optimizer step.
#
# How each triple is built:
# - It starts from the Remez (equioscillation) fit of an odd quintic to x -> 1
#   on the current singular-value interval.
# - For all but the last step it is deflated by a 1% safety factor.
# - While the 0.02 cushion exceeds l it is also recentered.
# As a result the stored triples are near-, not exactly, minimax-optimal.
#
# Why not the old coefficients: the previous five hard-coded Newton-Schulz
# tuples were tuned empirically for a large slope at zero. They left singular
# values spread roughly over (0.5, 1.5), i.e. produced US'V^T rather than UV^T.
# This schedule instead adapts to the shrinking interval [l, u] and drives
# singular values toward 1, approximating the polar factor UV^T.
#
# Callers that request fewer than ten steps use only a prefix of this list;
# the wrappers below default to steps=5.
_coeffs_list = _optimal_composition(
    num_iters=10,
    l=1e-3,
    cushion=0.02,
    safety_factor_eps=1e-2,
)
# This code is adapted from:
#   KellerJordan/Muon (https://github.com/KellerJordan/Muon/blob/master/muon.py)
#   NoahAmsel/PolarExpress (https://github.com/NoahAmsel/PolarExpress)
#   matmul_transpose_assign kernel from nil0x9/flash-muon (https://github.com/nil0x9/flash-muon)
def _zeropower_via_newtonschulz5(G, steps):
    """Approximate the polar factor of a 2D matrix with Polar Express iterations.

    Runs `steps` odd-quintic updates X <- aX + bX^3 + cX^5, where X^3 means
    X X^T X, taking (a, b, c) from the precomputed `_coeffs_list` schedule.
    Each triple approximates x -> 1 on the current singular-value interval.
    The composition therefore pushes the singular values of the normalized
    input toward 1, approximating the polar factor (the orthogonal factor in
    the polar decomposition G = UP).

    `_coeffs_list` holds 10 triples (l=1e-3, safety_factor_eps=1e-2,
    cushion=0.02). If `steps` < 10, only the first `steps` triples are used.
    If `steps` > 10, the final triple is repeated.

    Args:
        G: 2D tensor of dtype COMM_DTYPE (asserted).
        steps: number of quintic iterations.

    Returns:
        Tensor with the same shape as G.

    Reference: Amsel et al., "The Polar Express: Optimal Matrix Sign Methods and
    Their Application to the Muon Algorithm", https://arxiv.org/abs/2505.16932
    """
    assert len(G.shape) == 2
    assert G.dtype == COMM_DTYPE
    # Iterate on the wide orientation (rows <= cols) so the Gram buffers below
    # are the smaller of the two possible sizes. No manual typecast.
    tall = G.size(0) > G.size(1)
    X = G.T if tall else G
    # The Frobenius norm bounds the spectral norm, so after this division the
    # singular values lie in [0, 1] (up to bf16 rounding).
    X = X / (X.norm() + 1e-7)

    schedule = _coeffs_list[:steps] + list(
        repeat(_coeffs_list[-1], steps - len(_coeffs_list))
    )

    rows = X.size(0)
    gram = torch.empty(rows, rows, dtype=X.dtype, device=X.device)
    gram_sq = torch.empty(rows, rows, dtype=X.dtype, device=X.device)
    for a, b, c in schedule:
        # matmul_transpose_assign presumably writes M @ M.T into its second
        # argument (flash-muon kernel) — confirm against the Triton source.
        matmul_transpose_assign(X, gram)  # A = X X^T
        matmul_transpose_assign(gram, gram_sq)  # A A^T (~= A^2, A symmetric)
        gram.mul_(b).add_(gram_sq, alpha=c)  # B = bA + cA^2
        X = torch.addmm(X, gram, X, alpha=1.0, beta=a)  # X = aX + B X
    return X.T if tall else X
def _zeropower_via_newtonschulz5_batched(G, steps):
    """Approximate per-matrix polar factors of a 3D (E, out, in) tensor.

    Batched counterpart of ``_zeropower_via_newtonschulz5``: same update rule
    and coefficient schedule, but built from ``torch.bmm`` / ``torch.baddbmm``
    instead of the 2D Triton kernel, so all E matrices are processed in one
    batched call. Each matrix is normalized by its own Frobenius norm.

    Args:
        G: 3D tensor of dtype COMM_DTYPE (asserted).
        steps: number of quintic iterations.

    Returns:
        Tensor with the same shape as G.
    """
    assert len(G.shape) == 3
    assert G.dtype == COMM_DTYPE
    # Iterate on the orientation with fewer rows so the per-matrix Gram
    # products X X^T are as small as possible.
    tall = G.size(1) > G.size(2)
    X = G.transpose(-2, -1) if tall else G
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    schedule = _coeffs_list[:steps] + list(
        repeat(_coeffs_list[-1], steps - len(_coeffs_list))
    )
    for a, b, c in schedule:
        gram = torch.bmm(X, X.transpose(-2, -1))  # A = X X^T
        gram_sq = torch.bmm(gram, gram.transpose(-2, -1))  # A A^T (~= A^2)
        gram.mul_(b).add_(gram_sq, alpha=c)  # B = bA + cA^2
        X = torch.baddbmm(X, gram, X, alpha=1.0, beta=a)  # X = aX + B X
    return X.transpose(-2, -1) if tall else X
# Cache of torch.compile-d Newton-Schulz callables, keyed by input shape.
# The 2D and batched wrappers share this dict. Their entries are
# distinguished only by key length (2 vs. 3 dimensions).
_ns_per_shape: dict[tuple[int, ...], Callable[..., torch.Tensor]] = {}
# Global switch read by the wrappers; when False they run eagerly.
_use_compile = True


def set_ns_compile(enabled: bool) -> None:
    """Toggle torch.compile for Newton-Schulz iteration.

    Only flips the module-level switch. Callables already stored in
    ``_ns_per_shape`` are kept and reused if compilation is re-enabled.
    """
    global _use_compile
    _use_compile = enabled
def zeropower_via_newtonschulz5(G, steps=5):
    """Compile-cached polar factor of a 2D matrix via Polar Express.

    When compilation is disabled (see ``set_ns_compile``), calls
    ``_zeropower_via_newtonschulz5`` eagerly. Otherwise it compiles that
    function once per input shape with ``torch.compile``, caches it in
    ``_ns_per_shape``, and runs it with CUDA graphs enabled.

    Args:
        G: 2D tensor of dtype COMM_DTYPE (dtype is checked by the inner routine).
        steps: number of quintic iterations (default 5).

    Returns:
        Tensor with the same shape as G. On the compiled path it is a clone of
        the compiled output.
    """
    if not _use_compile:
        return _zeropower_via_newtonschulz5(G, steps)
    # Check the rank *before* the shape-keyed cache lookup. `_ns_per_shape`
    # is shared with the batched wrapper, so a 3D input reaching this point
    # could otherwise hit a cached batched callable and silently skip the
    # 2D-only assertion that the eager path enforces.
    assert len(G.shape) == 2
    key = G.shape
    if key not in _ns_per_shape:
        _ns_per_shape[key] = torch.compile(
            _zeropower_via_newtonschulz5,
            options={"triton.cudagraphs": True, "shape_padding": False},
        )
    torch.compiler.cudagraph_mark_step_begin()
    # Clone, presumably so the result is not overwritten by a later CUDA-graph
    # replay of the same compiled callable.
    return _ns_per_shape[key](G, steps).clone()
def zeropower_via_newtonschulz5_batched(G, steps=5):
    """Compile-cached batched Newton-Schulz for 3D expert tensors.

    When compilation is disabled (see ``set_ns_compile``), calls
    ``_zeropower_via_newtonschulz5_batched`` eagerly. Otherwise it compiles
    that function once per input shape with ``torch.compile``, caches it in
    ``_ns_per_shape``, and runs it with CUDA graphs enabled.

    Args:
        G: 3D (E, out, in) tensor of dtype COMM_DTYPE (dtype is checked by the
            inner routine).
        steps: number of quintic iterations (default 5).

    Returns:
        Tensor with the same shape as G. On the compiled path it is a clone of
        the compiled output.
    """
    if not _use_compile:
        return _zeropower_via_newtonschulz5_batched(G, steps)
    # Check the rank *before* the shape-keyed cache lookup. `_ns_per_shape`
    # is shared with the 2D wrapper, so a 2D input reaching this point could
    # otherwise hit a cached 2D callable and silently skip the 3D-only
    # assertion that the eager path enforces.
    assert len(G.shape) == 3
    key = G.shape
    if key not in _ns_per_shape:
        _ns_per_shape[key] = torch.compile(
            _zeropower_via_newtonschulz5_batched,
            options={"triton.cudagraphs": True, "shape_padding": False},
        )
    torch.compiler.cudagraph_mark_step_begin()
    # Clone, presumably so the result is not overwritten by a later CUDA-graph
    # replay of the same compiled callable.
    return _ns_per_shape[key](G, steps).clone()