Instructions to use Motif-Technologies/optimizer with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use Motif-Technologies/optimizer with Kernels:
```python
# !pip install kernels
from kernels import get_kernel

kernel = get_kernel("Motif-Technologies/optimizer")
```
- Notebooks
- Google Colab
- Kaggle
| import logging | |
| from typing import Generator | |
| import torch | |
| import torch.distributed as dist | |
| from torch.distributed.tensor import DTensor | |
| from torch.profiler import record_function | |
| from .core import _muon_state, adjust_lr_for_muon | |
| from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5 | |
| from .qk_clip import compute_scales | |
| logger = logging.getLogger(__name__) | |
| # ====================================================================== | |
| # Stage helpers | |
| # ====================================================================== | |
def _launch_gather(
    params: list[DTensor],
    owned_params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    num_ranks: int,
    process_group: dist.ProcessGroup,
) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
    """Allocate gather buffers, build send/recv, and launch async all-to-all.

    Each rank sends its local gradient shard of every param in ``params``
    to that param's ``worker_rank``. From every source rank, the worker
    receives the shards of the params it owns.

    Args:
        params: All parameters of this chunk.
        owned_params: The params of ``params`` whose ``worker_rank`` is
            ``rank``, in the same relative order.
        param_to_state: Per-param Muon state keyed by ``id(p)``.
        rank: This process's rank in ``process_group``.
        num_ranks: Size of ``process_group``.
        process_group: Group on which the all-to-all is launched.

    Returns:
        work: Async operation handle.
        recv_buf: Flat receive buffer (needed by ``_complete_gather``).
        gathered_grads: ``{id(p): empty_tensor}`` for owned params,
            ``None`` for non-owned.
        recv_counts: Per-source-rank element counts.
    """
    # Uninitialized full-shape buffers for owned params; _complete_gather
    # writes the received shards into them.
    gathered_grads: dict[int, torch.Tensor | None] = {}
    for p in params:
        if param_to_state[id(p)].worker_rank == rank:
            gathered_grads[id(p)] = torch.empty(p.shape,
                                                dtype=COMM_DTYPE,
                                                device="cuda")
        else:
            gathered_grads[id(p)] = None

    # Single pass over params: count elements per destination rank and
    # group the flattened local grad slices by destination.
    send_counts = [0] * num_ranks
    dst_to_grads: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
    for p in params:
        state = param_to_state[id(p)]
        n = state.rank_numels[rank]
        if n > 0:
            send_counts[state.worker_rank] += n
            dst_to_grads[state.worker_rank].append(
                p.grad.to_local().reshape(-1))

    if sum(send_counts) > 0:
        # Concatenate in destination-rank order, which is the send layout
        # all_to_all_single splits by input_split_sizes (1-2 fused kernels
        # instead of N narrow().copy_() calls).
        send_buf = torch.cat(
            [g for dst_grads in dst_to_grads for g in dst_grads])
        if send_buf.dtype != COMM_DTYPE:
            send_buf = send_buf.to(COMM_DTYPE)
    else:
        send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")

    # From each source rank we receive its shard of every owned param.
    recv_counts = [0] * num_ranks
    for p in owned_params:
        state = param_to_state[id(p)]
        assert state.worker_rank == rank
        for src in range(num_ranks):
            recv_counts[src] += state.rank_numels[src]
    recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda")

    # Lazy %-style args: the message (including str(process_group)) is only
    # formatted when DEBUG logging is actually enabled.
    logger.debug(
        "send_buf size: %s, recv_buf size: %s, recv_counts: %s, "
        "send_counts: %s, process_group: %s",
        send_buf.numel(), recv_buf.numel(), recv_counts, send_counts,
        process_group)
    work = dist.all_to_all_single(
        recv_buf,
        send_buf,
        output_split_sizes=recv_counts,
        input_split_sizes=send_counts,
        group=process_group,
        async_op=True,
    )
    return work, recv_buf, gathered_grads, recv_counts
def _complete_gather(
    recv_buf: torch.Tensor,
    recv_counts: list[int],
    owned_params: list[DTensor],
    gathered_grads: dict[int, torch.Tensor | None],
    param_to_state: dict[int, _muon_state],
    rank: int,
) -> None:
    """Reconstruct gathered grads from the recv buffer (in-place).

    ``recv_buf`` is read as consecutive per-source-rank blocks whose sizes
    are ``recv_counts``. Inside the block of source rank ``src``, shards
    follow ``owned_params`` order. Each shard is reshaped and written into
    ``gathered_grads[id(p)]`` at ``param_to_state[id(p)].rank_indices[src]``.
    Empty blocks and empty shards are skipped.
    """
    block_start = 0
    for src, block_len in enumerate(recv_counts):
        if block_len == 0:
            continue
        cursor = block_start
        for param in owned_params:
            state = param_to_state[id(param)]
            assert state.worker_rank == rank
            full_grad = gathered_grads[id(param)]
            where = state.rank_indices[src]
            # Indexing once gives both the shard's element count and shape.
            shard = full_grad[where]
            count = shard.numel()
            if count == 0:
                continue
            full_grad[where] = recv_buf.narrow(0, cursor,
                                               count).reshape(shard.shape)
            cursor += count
        # The shards of this source rank must exactly fill its block.
        assert cursor - block_start == block_len
        block_start += block_len
def _compute_ns(
    owned_params: list[DTensor],
    gathered_grads: dict[int, torch.Tensor | None],
    ns_steps: int,
) -> dict[int, torch.Tensor | None]:
    """Orthogonalize each owned parameter's gathered gradient.

    For every owned param, the result of ``zeropower_via_newtonschulz5``
    on its gathered gradient is recorded. The param's entry in
    ``gathered_grads`` is then set to ``None``, so this dict no longer
    references the gathered buffer.

    Args:
        owned_params: Params whose full gradients were gathered on this rank.
        gathered_grads: ``{id(p): full_grad}``; mutated in place.
        ns_steps: Number of Newton-Schulz iterations.

    Returns:
        computed_us: ``{id(p): orthogonalized_update}`` for owned params.
    """
    results: dict[int, torch.Tensor | None] = {}
    for param in owned_params:
        key = id(param)
        results[key] = zeropower_via_newtonschulz5(gathered_grads[key],
                                                   ns_steps)
        gathered_grads[key] = None  # drop the gathered-grad reference
    return results
def _launch_scatter(
    params: list[DTensor],
    owned_params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    num_ranks: int,
    process_group: dist.ProcessGroup,
    computed_us: dict[int, torch.Tensor | None],
) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]:
    """Allocate scatter buffers, build send/recv, and launch async all-to-all.

    The worker of each param sends slice ``rank_indices[dst]`` of its
    orthogonalized update to every rank ``dst``. Each rank receives, from
    each param's worker, its own slice of that update.

    Args:
        params: All parameters of this chunk.
        owned_params: The params of ``params`` whose ``worker_rank`` is
            ``rank``, in the same relative order.
        param_to_state: Per-param Muon state keyed by ``id(p)``.
        rank: This process's rank in ``process_group``.
        num_ranks: Size of ``process_group``.
        process_group: Group on which the all-to-all is launched.
        computed_us: ``{id(p): update}`` for every param in ``owned_params``.

    Returns:
        work: Async operation handle.
        recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
        scattered_us: Holds only empty ``COMM_DTYPE`` tensors for params
            whose local shard has no elements. ``_complete_scatter`` adds
            zero-copy views into ``recv_buf`` for the rest.
        recv_counts: Per-source-rank element counts.
    """
    # One pass over params does two jobs:
    #  * Pre-seed empty-shard params (rank_numels == 0) so _update_params
    #    can iterate all params without KeyError.
    #  * Accumulate recv_counts. Our slice of p arrives from p's worker
    #    rank. This is O(len(params)) rather than one params scan per
    #    source rank.
    scattered_us: dict[int, torch.Tensor] = {}
    recv_counts = [0] * num_ranks
    for p in params:
        state = param_to_state[id(p)]
        n = state.rank_numels[rank]
        if n == 0:
            scattered_us[id(p)] = torch.empty_like(p.to_local(),
                                                   dtype=COMM_DTYPE)
        recv_counts[state.worker_rank] += n

    # Every owned param contributes rank_numels[dst] elements to rank dst.
    send_counts = [0] * num_ranks
    for p in owned_params:
        numels = param_to_state[id(p)].rank_numels
        for dst in range(num_ranks):
            send_counts[dst] += numels[dst]

    if sum(send_counts) > 0:
        # Convert each update to COMM_DTYPE once, not once per destination.
        u_fulls = {
            id(p): computed_us[id(p)].to(COMM_DTYPE).contiguous()
            for p in owned_params
        }
        # Destination-major, then params order. This matches the
        # all-to-all send layout and the receiver's walk in
        # _complete_scatter.
        all_slices = []
        for dst in range(num_ranks):
            for p in owned_params:
                indices = param_to_state[id(p)].rank_indices[dst]
                su = u_fulls[id(p)][indices].flatten()
                if su.numel() > 0:
                    all_slices.append(su)
        send_buf = torch.cat(all_slices) if all_slices else torch.empty(
            0, dtype=COMM_DTYPE, device="cuda")
    else:
        send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")

    recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda")
    work = dist.all_to_all_single(
        recv_buf,
        send_buf,
        output_split_sizes=recv_counts,
        input_split_sizes=send_counts,
        group=process_group,
        async_op=True,
    )
    return work, recv_buf, scattered_us, recv_counts
def _complete_scatter(
    recv_buf: torch.Tensor,
    recv_counts: list[int],
    params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    scattered_us: dict[int, torch.Tensor],
) -> None:
    """Populate scattered_us with zero-copy views into recv_buf.

    ``recv_buf`` consists of per-source-rank blocks of sizes
    ``recv_counts``. The block from rank ``src`` holds this rank's slice
    of every param whose ``worker_rank`` is ``src``, in ``params`` order.
    Each slice is exposed as a ``narrow`` view reshaped like the param's
    local tensor, so no extra allocation or copy is made. The views keep
    ``recv_buf``'s storage alive until ``scattered_us`` is cleared.
    """
    block_start = 0
    for src, block_len in enumerate(recv_counts):
        if not block_len:
            continue
        cursor = block_start
        senders = (p for p in params
                   if param_to_state[id(p)].worker_rank == src)
        for param in senders:
            count = param_to_state[id(param)].rank_numels[rank]
            if count == 0:
                continue
            local = param.to_local()
            scattered_us[id(param)] = recv_buf.narrow(0, cursor,
                                                      count).view_as(local)
            cursor += count
        # This source rank's slices must exactly fill its block.
        assert cursor - block_start == block_len
        block_start += block_len
def _update_params(
    params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    scattered_us: dict[int, torch.Tensor],
    lr: float,
    weight_decay: float,
) -> None:
    """Apply weight decay, Muon update, and optional QK clipping.

    Weight decay is one batched ``_foreach_mul_`` over all local tensors.
    The Muon update is one batched ``_foreach_add_`` per distinct
    ``adjust_lr_for_muon(lr, p.shape)`` value. Grouping by adjusted lr
    keeps the alpha scaling as a Python float while minimizing kernel
    launches.

    Params with a ``qk_clip_state`` are then rescaled row-wise by
    ``compute_scales``. This works on the local tensor, using
    ``rank_indices[rank][0]`` to map local rows back to global rows.

    Args:
        params: Params of this chunk (DTensors; their local tensors are
            updated in place).
        param_to_state: Per-param Muon state keyed by ``id(p)``.
        rank: This process's rank.
        scattered_us: ``{id(p): local update}`` for every param in ``params``.
        lr: Base learning rate.
        weight_decay: Weight-decay coefficient.
    """
    if not params:
        return

    # p *= (1 - lr * wd) on every local tensor in a single fused call.
    torch._foreach_mul_([p._local_tensor for p in params],
                        1.0 - lr * weight_decay)

    # Group by adjusted lr so each _foreach_add_ uses one scalar alpha.
    lr_groups: dict[float, tuple[list, list]] = {}
    for p in params:
        adjusted_lr = adjust_lr_for_muon(lr, p.shape)
        p_group, u_group = lr_groups.setdefault(adjusted_lr, ([], []))
        p_group.append(p._local_tensor)
        u_group.append(scattered_us[id(p)])
    for adjusted_lr, (p_group, u_group) in lr_groups.items():
        torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr)

    # QK clipping, applied directly on the local tensor to avoid DTensor
    # sharding-propagation issues with _StridedShard.
    for p in params:
        state = param_to_state[id(p)]
        if state.qk_clip_state is None:
            continue
        scales_full = compute_scales(p, state.qk_clip_state)
        if scales_full is None:
            continue
        # Number of consecutive global rows that share one scale.
        ratio = p.shape[0] // scales_full.shape[0]
        idx0 = state.rank_indices[rank][0]
        if isinstance(idx0, slice):
            # slice.indices resolves a None start/stop (e.g. slice(None)
            # when dim 0 is not split) and keeps any step. The previous
            # arange(start, idx0.stop) crashed on stop=None.
            idx0 = torch.arange(*idx0.indices(p.shape[0]),
                                device=scales_full.device)
        row_scales = scales_full[idx0 // ratio]
        p._local_tensor.mul_(row_scales.view(-1, 1))
| # ====================================================================== | |
| # Pre-launch helper for overlapping first chunk's gather with other work. | |
| # ====================================================================== | |
def prelaunch_first_gather(
    params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    none_grad: bool,
) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
    """Launch the first chunk's A2A gather early for overlap with other compute.

    Call this before expensive GPU work (e.g. batched expert NS) so the
    NCCL all-to-all runs concurrently with the compute on the default
    stream. The process group is taken from the first param's state, so
    ``params`` must be non-empty.

    Args:
        params: Parameters of the first chunk.
        param_to_state: Per-param Muon state keyed by ``id(p)``.
        rank: This process's rank.
        none_grad: If true, set ``p.grad = None`` for every param once the
            gather has been launched.

    Returns:
        The ``(work, recv_buf, gathered_grads, recv_counts)`` tuple from
        ``_launch_gather``. Pass it as ``prelaunch_gather`` to
        :func:`muon_chunk_pipeline`.
    """
    group = param_to_state[id(params[0])].process_group
    world = dist.get_world_size(group=group)
    mine = [p for p in params if param_to_state[id(p)].worker_rank == rank]

    with record_function("muon::prelaunch_gather"):
        launched = _launch_gather(params, mine, param_to_state, rank, world,
                                  group)

    if none_grad:
        for p in params:
            p.grad = None

    return launched
| # ====================================================================== | |
| # Main generator – thin orchestrator that wires stages together. | |
| # ====================================================================== | |
def muon_chunk_pipeline(
    params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    ns_steps: int,
    lr: float,
    weight_decay: float,
    none_grad: bool,
    prelaunch_gather: tuple | None = None,
) -> Generator[None, None, None]:
    """Drive one chunk of parameters through the Muon update pipeline.

    Stages: all-to-all gather -> Newton-Schulz on owned params ->
    all-to-all scatter -> in-place parameter update.

    Both collectives are issued with ``async_op=True`` and only waited on
    after the generator resumes. This lets :func:`run_pipeline`
    interleave other chunks. With staggered admission (one new chunk per
    iteration), chunk *N*'s Newton-Schulz on the default CUDA stream
    overlaps chunk *N+1*'s all-to-all on the NCCL stream. No separate
    ``comm_stream`` is required.

    Args:
        params: Parameters of this chunk (must be non-empty; the process
            group is read from the first one).
        param_to_state: Per-param Muon state keyed by ``id(p)``.
        rank: This process's rank.
        ns_steps: Newton-Schulz iteration count.
        lr: Base learning rate.
        weight_decay: Weight-decay coefficient.
        none_grad: If true, clear ``p.grad`` once the gather is launched.
        prelaunch_gather: Tuple returned by
            :func:`prelaunch_first_gather`. When given, the gather is not
            launched again and ``none_grad`` is not re-applied here.

    Yields:
        Exactly twice:

        1. After launching the async all-to-all gather (or immediately if
           it was pre-launched).
        2. After launching the async all-to-all scatter.
    """
    group = param_to_state[id(params[0])].process_group
    world = dist.get_world_size(group=group)
    mine = [p for p in params if param_to_state[id(p)].worker_rank == rank]

    # Stages 1-2: move gradient shards to their worker ranks.
    if prelaunch_gather is None:
        with record_function("muon::launch_gather"):
            work, recv_buf, gathered_grads, recv_counts = _launch_gather(
                params, mine, param_to_state, rank, world, group)
        if none_grad:
            for p in params:
                p.grad = None
    else:
        # Launched earlier by prelaunch_first_gather, which also handled
        # none_grad.
        work, recv_buf, gathered_grads, recv_counts = prelaunch_gather

    yield  # (1) gather in flight; other chunks may launch theirs.

    with record_function("muon::wait_gather"):
        work.wait()
        _complete_gather(recv_buf, recv_counts, mine, gathered_grads,
                         param_to_state, rank)
    del recv_buf

    # Stage 3: orthogonalize the gathered full gradients of owned params.
    with record_function("muon::newton_schulz"):
        computed_us = _compute_ns(mine, gathered_grads, ns_steps)
    gathered_grads.clear()

    # Stages 4-5: send every rank its slice of the updates.
    with record_function("muon::launch_scatter"):
        work, recv_buf, scattered_us, recv_counts = _launch_scatter(
            params, mine, param_to_state, rank, world, group, computed_us)
    computed_us.clear()

    yield  # (2) scatter in flight; other chunks may launch theirs.

    with record_function("muon::wait_scatter"):
        work.wait()
        _complete_scatter(recv_buf, recv_counts, params, param_to_state,
                          rank, scattered_us)
    del recv_buf

    # Stage 6: weight decay, Muon step and optional QK clipping.
    with record_function("muon::update_params"):
        _update_params(params, param_to_state, rank, scattered_us, lr,
                       weight_decay)
    scattered_us.clear()