Instructions to use Motif-Technologies/optimizer with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use Motif-Technologies/optimizer with Kernels:
```python
# !pip install kernels
from kernels import get_kernel

kernel = get_kernel("Motif-Technologies/optimizer")
```
- Notebooks
- Google Colab
- Kaggle
| import logging | |
| import types | |
| from collections import defaultdict | |
| from typing import Any | |
| import torch | |
| import torch.distributed as dist | |
| from torch.distributed.tensor import DTensor, Replicate, Shard | |
| from torch.profiler import record_function | |
| from .adamw import _placement_cache, _tensor_cache, step_adamw | |
| from .async_utils import run_pipeline | |
| from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho, | |
| get_default_muon_param_groups, is_expert_param, update_p) | |
| from .cpu_offload import CPUOffloadPool | |
| from .distributed.utils import (_is_shard, construct_shard_mesh, | |
| get_slices_of_dtensor) | |
| from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO, | |
| _zeropower_via_newtonschulz5, | |
| zeropower_via_newtonschulz5, | |
| zeropower_via_newtonschulz5_batched) | |
| from .pipeline import muon_chunk_pipeline, prelaunch_first_gather | |
| from .qk_clip import compute_scales, get_qk_clip_info, qk_clip | |
# Module-level logger used by ``_expand_expert_params`` and the ``Muon``
# optimizer below for routing/expansion debug output and the info-level
# summaries emitted by the distributed reference path.
logger = logging.getLogger(__name__)
def _expand_expert_params(names, params, expert_keys):
    """Split expert-parallel params along dim 0 into per-expert 2D params.

    A param counts as an expert param when ``is_expert_param(name,
    expert_keys)`` is true. Its leading dimension is treated as the expert
    dimension, so an ``(E, out, in)`` tensor yields ``E`` new
    ``nn.Parameter`` objects built from the slices ``data[i]`` (views, so
    in-place updates reach the original storage). Each new param's
    ``.grad`` is the matching slice of the original grad. The expanded
    names are ``f"{name}[{i}]"``.

    Non-expert params pass through untouched, but must have ``ndim <= 2``.
    Otherwise an ``AssertionError`` asks the caller to add the missing key
    to ``expert_keys``.

    For DTensor experts, shard placements on dim 0 are consumed by the
    split. Any shard placement on another dim (e.g. tensor parallel) is
    kept. Its dim is shifted down by one, and every slice is wrapped as a
    DTensor on the sub-mesh formed by those mesh dims, so the parallel
    pipeline still handles that communication.

    After expansion the original expert param's ``.grad`` is set to
    ``None`` so its storage can be released once the slices are consumed.

    Returns:
        ``(expanded_names, expanded_params)`` as two parallel lists.
    """
    out_names = []
    out_params = []
    for name, param in zip(names, params):
        data_is_dtensor = isinstance(param.data, DTensor)

        # Guard clause: non-expert params are forwarded as-is.
        if not is_expert_param(name, expert_keys):
            assert param.data.ndim <= 2, (
                f"Param {name} has ndim={param.data.ndim} but does not match "
                f"expert_keys={expert_keys}. If this is an expert param, "
                f"add its key to expert_keys.")
            out_names.append(name)
            out_params.append(param)
            continue

        if data_is_dtensor:
            logger.debug(
                "[expand_expert] %s: expert DTensor, shape=%s, "
                "placements=%s, mesh=%s, local_shape=%s", name, param.shape,
                param.placements, param.device_mesh.mesh_dim_names,
                param.to_local().shape)
        else:
            logger.debug(
                "[expand_expert] %s: expert plain tensor, shape=%s", name,
                param.data.shape)

        grad = param.grad
        assert grad is not None, (
            f"Expert param {name} must have grad set before expansion")

        sub_mesh = None
        sub_placements = None
        if data_is_dtensor:
            data_local = param.to_local()
            grad_local = grad.to_local() if isinstance(grad, DTensor) else grad
            # Keep only shardings on non-expert dims. Removing dim 0 shifts
            # each remaining tensor dim index down by one.
            kept = [(mesh_dim, Shard(pl.dim - 1))
                    for mesh_dim, pl in enumerate(param.placements)
                    if _is_shard(pl) and pl.dim != 0]
            sub_placements = [placement for _, placement in kept]
            if kept:
                dim_names = tuple(param.device_mesh.mesh_dim_names[mesh_dim]
                                  for mesh_dim, _ in kept)
                mesh_key = dim_names[0] if len(dim_names) == 1 else dim_names
                sub_mesh = param.device_mesh[mesh_key]
        else:
            data_local = param.data
            grad_local = grad

        for expert_idx in range(data_local.shape[0]):
            data_slice = data_local[expert_idx]
            grad_slice = grad_local[expert_idx]
            if sub_mesh is None:
                expert_param = torch.nn.Parameter(data_slice,
                                                  requires_grad=False)
                expert_param.grad = grad_slice
            else:
                # Re-wrap on the TP sub-mesh so the pipeline performs the
                # cross-rank gather/scatter for this slice.
                wrapped_data = DTensor.from_local(data_slice,
                                                  device_mesh=sub_mesh,
                                                  placements=sub_placements)
                wrapped_grad = DTensor.from_local(grad_slice,
                                                  device_mesh=sub_mesh,
                                                  placements=sub_placements)
                expert_param = torch.nn.Parameter(wrapped_data,
                                                  requires_grad=False)
                expert_param.grad = wrapped_grad
            out_names.append(f"{name}[{expert_idx}]")
            out_params.append(expert_param)

        # Drop the stacked grad so its storage can be freed after the pipeline.
        param.grad = None
    return out_names, out_params
class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - We believe this optimizer is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.

    Arguments:
        params: A list of param-group dicts. Every group must set the
            "use_muon" key explicitly; the constructor raises ValueError
            otherwise, and also rejects generators such as
            ``model.parameters()``. Build the groups with
            ``get_default_muon_param_groups(model)``.
            Whether a group is optimized by Muon or by the internal AdamW is
            decided by that group's "use_muon" flag.
        lr: The learning rate (default 1e-3). For Muon parameters the value is
            passed to ``adjust_lr_for_muon`` together with the parameter shape,
            so the effective per-matrix step may be rescaled.
        momentum: The momentum used by the internal SGD. (0.95 is a good default)
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        ns_steps: The number of Newton-Schulz iterations to run (default 5).
        weight_decay: The weight decay for Muon and AdamW (default 0.1).
        adamw_betas: The betas for the internal AdamW.
        adamw_eps: The epsilon for the internal AdamW.
        none_grad: Whether to set p.grad to None after gathering the gradients. This can save memory.
        debug: Whether to log extra debug information (e.g. the total
            Newton-Schulz TFLOPs of a parallel group).
        clip_config: Configuration for QK clipping. Expected keys:
            - "q_indices" (list[int]): Indices of query heads to consider.
            - "k_indices" (list[int]): Indices of key heads to consider.
            - "head_dim" (int): Dimensionality of each attention head.
            - "threshold" (float): Threshold value; heads whose QK logits exceed
              this value will be scaled down.
            Default (used when None is passed) is:
                {
                    "q_indices": [],
                    "k_indices": [],
                    "head_dim": 128,
                    "threshold": 100
                }
        warmup_step: How many all2all gather, compute operations are launched in advance
            before the corresponding all2all scatter steps begin
            (the chunk pipeline runs with ``max_concurrent = warmup_step + 1``).
            A higher warmup_step increases memory usage but can improve
            performance by overlapping communication.
            Parallel muon only.
        chunk_size: Batch size of parameters to process in each
            all2all gather/compute/scatter step.
            Use shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified;
            any other non-positive value raises ValueError at the first step.
        use_distributed_muon: Use distributed muon by Liu et al. (2024).
            For testing purpose only.
        expert_keys: List of strings to identify expert-parallel parameters.
            If any key appears in a parameter's name, its outermost
            dimension is treated as the expert dimension and expanded
            into per-expert 2D params for Muon. For example,
            ``expert_keys=["experts"]`` matches any param whose name
            contains "experts". 3D+ params not matched by any key
            will raise an error. Expert params without tensor-parallel
            sharding are instead orthogonalized in one batched
            Newton-Schulz call.
    """
| def __init__(self, | |
| params, | |
| lr=1e-3, | |
| momentum=0.95, | |
| nesterov=True, | |
| ns_steps=5, | |
| weight_decay=0.1, | |
| adamw_betas=(0.9, 0.95), | |
| adamw_eps=1e-8, | |
| none_grad=True, | |
| debug=False, | |
| clip_config=None, | |
| warmup_step=5, | |
| chunk_size=-1, | |
| use_distributed_muon=False, | |
| expert_keys=None): | |
| defaults = dict( | |
| lr=lr, | |
| weight_decay=weight_decay, | |
| momentum=momentum, | |
| nesterov=nesterov, | |
| ns_steps=ns_steps, | |
| adamw_betas=adamw_betas, | |
| adamw_eps=adamw_eps, | |
| none_grad=none_grad, | |
| use_muon=True, | |
| ) | |
| error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior." | |
| instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```" | |
| if isinstance(params, types.GeneratorType): | |
| raise ValueError(error_message.format(idx=0) + instruction_code) | |
| for _idx, param_group in enumerate(params): | |
| if param_group.get("use_muon", None) is None: | |
| raise ValueError( | |
| error_message.format(idx=_idx) + instruction_code) | |
| super().__init__(params, defaults) | |
| self.debug = debug | |
| self.clip_config = clip_config if clip_config is not None else { | |
| "q_indices": [], | |
| "k_indices": [], | |
| "head_dim": 128, | |
| "threshold": 100, | |
| } | |
| self.warmup_step = warmup_step | |
| self.chunk_size = chunk_size | |
| self.use_distributed_muon = use_distributed_muon | |
| self.expert_keys = expert_keys | |
| self.cpu_offload = False | |
| self._cpu_offload_pool: CPUOffloadPool | None = None | |
| self._offload_initialized = False | |
| self._parallel_cache: dict[tuple[str, ...], dict] = {} | |
| self._expert_expand_cache: dict[tuple[int, ...], dict] = {} | |
| def _calc_flops(self, G, steps): | |
| assert len(G.shape) == 2 | |
| M, N = G.shape | |
| if M > N: | |
| M, N = N, M | |
| return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3) | |
def get_shard_mesh(self, p):
    """Build the shard mesh for DTensor ``p`` from its placements and mesh.

    Returns:
        ``(shard_mesh, shard_pg, shard_placements)`` as produced by
        ``construct_shard_mesh(p.placements, p.device_mesh)``.

    Raises:
        AssertionError: if ``p`` is not a DTensor.
    """
    assert isinstance(
        p, DTensor), "Parallel Muon only supports DTensor parameters."
    mesh, process_group, placements = construct_shard_mesh(
        p.placements, p.device_mesh)
    return mesh, process_group, placements
def init_state_and_assign_params(self, names, params, group, qk_logits):
    """Order one DTensor group by NS cost and build per-param pipeline state.

    Params are sorted by estimated Newton-Schulz FLOPs, largest first (a
    stable sort, so ties keep input order). A ``worker_rank`` is then
    assigned round-robin over the flattened shard mesh. All params must
    share the same device mesh and placements.

    For every param, the indices each shard-group rank owns
    (``get_slices_of_dtensor``) and the resulting element count are
    precomputed for the all-to-all exchange.

    Returns:
        ``(param_to_state, ordered_params)``. ``param_to_state`` maps
        ``id(p)`` to a ``_muon_state``; ``ordered_params`` is the
        FLOPs-sorted list.

    Raises:
        ValueError: if params differ in device mesh or placements.
    """
    param_to_state = {}
    param_to_flops = {}

    total_flops = 0
    for p in params:
        g = p.grad
        if g is None:
            continue
        assert g.ndim == 2, "Muon only supports 2D parameters."

        flops = self._calc_flops(g, group["ns_steps"])
        param_to_flops[id(p)] = flops
        total_flops += flops

    if self.debug:
        logger.debug("Total TFLOPs for Muon: %.2f TFLOPs",
                     total_flops / 1e12)

    paired = list(zip(names, params))

    # NOTE(review): a param whose grad is None has no entry in
    # param_to_flops, so this key would raise KeyError. Current callers
    # (_setup_parallel via _step_muon) only pass params with grads.
    paired_sorted = sorted(paired,
                           key=lambda x: param_to_flops[id(x[1])],
                           reverse=True)

    names_sorted, params_sorted = zip(*paired_sorted)
    ordered_names = list(names_sorted)
    ordered_params = list(params_sorted)

    round_robin = 0
    mesh = ordered_params[0].device_mesh
    placements = ordered_params[0].placements

    shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
        ordered_params[0])
    shard_mesh_flattened = shard_mesh.mesh.flatten()
    num_ranks = dist.get_world_size(group=shard_pg)

    for n, p in zip(ordered_names, ordered_params):
        if mesh != p.device_mesh:
            raise ValueError("All parameters must be on the same mesh.")
        if placements != p.placements:
            raise ValueError("All parameters must have same placements.")

        # Round-robin over the ranks listed in the shard mesh, reduced
        # modulo the shard-group size. NOTE(review): presumably this is the
        # rank that runs Newton-Schulz for p in the pipeline. The modulo
        # seems to assume the mesh holds a contiguous block of ranks;
        # confirm.
        worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
        round_robin = (round_robin + 1) % len(shard_mesh_flattened)
        qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)

        # Precompute per-rank indices and numels for all-to-all.
        rank_indices: dict[int, tuple] = {}
        rank_numels: dict[int, int] = {}
        for r in range(num_ranks):
            indices = get_slices_of_dtensor(p, r, shard_mesh,
                                            shard_placements)
            rank_indices[r] = indices
            numel = 1
            for idx, dim_size in zip(indices, p.shape):
                if isinstance(idx, slice):
                    # Length of the normalised slice (ceil division; valid
                    # for positive steps).
                    start, stop, step = idx.indices(dim_size)
                    numel *= max(0, (stop - start + (step - 1)) // step)
                else:
                    numel *= len(idx)
            rank_numels[r] = numel

        param_to_state[id(p)] = _muon_state(
            worker_rank=worker_rank,
            process_group=shard_pg,
            rank_indices=rank_indices,
            rank_numels=rank_numels,
            name=n,
            qk_clip_state=qk_clip_state,
        )

    return param_to_state, ordered_params
def base(self, names, params, group, lr, weight_decay, qk_logits):
    """Apply a local (communication-free) Muon update to each param.

    ``_step_muon`` routes plain tensors and all-Replicate DTensors here,
    after momentum has already been folded into ``p.grad``. Params without
    a grad are skipped.

    For each remaining param, the grad is orthogonalised with Newton-Schulz
    in ``COMM_DTYPE`` and applied through ``update_p`` with the
    shape-adjusted learning rate. QK clipping follows whenever
    ``get_qk_clip_info`` yields a clip state for the param.
    """
    for name, param in zip(names, params):
        grad = param.grad
        if grad is None:
            continue

        ortho = zeropower_via_newtonschulz5(grad.to(COMM_DTYPE),
                                            steps=group["ns_steps"])
        update_p(param, ortho, lr, adjust_lr_for_muon(lr, param.shape),
                 weight_decay)

        clip_state = get_qk_clip_info(self.clip_config, name, qk_logits)
        if clip_state is None:
            continue
        scales = compute_scales(param, clip_state)
        if scales is not None:
            qk_clip(param, scales, clip_state)
def distributed_muon(
    self,
    names: list[str],
    params: list[torch.nn.Parameter],
    group: dict[str, Any],
    lr: float,
    weight_decay: float,
    qk_logits: list[torch.Tensor | DTensor] | None,
):
    """Batched Distributed Muon — for testing/correctness verification only.

    Uses all-gather to reconstruct full tensors, computes Newton-Schulz on
    the full grad, then slices back to local shards. This is simpler but
    slower than the parallel pipeline (all2all) path, so it serves as a
    reference implementation for verifying correctness.

    Momentum must already be folded into ``p.grad`` (``_step_muon`` does
    this). Params whose grad is ``None`` are skipped.

    Args:
        names: Param names, parallel to ``params`` (used for QK-clip lookup).
        params: Params to update in place (plain tensors or DTensors).
        group: The param group; ``group["ns_steps"]`` is read.
        lr: Learning rate.
        weight_decay: Weight-decay coefficient (applied as ``1 - lr * wd``).
        qk_logits: Forwarded to ``get_qk_clip_info``; may be ``None``.
    """
    with record_function("distributed_muon"):
        # Momentum is already applied by _step_muon before this method.
        ns_steps = group["ns_steps"]

        # Separate plain tensors (no communication) from DTensors.
        plain_names, plain_params = [], []
        dtensor_names, dtensor_params = [], []
        for n, p in zip(names, params):
            if p.grad is None:
                continue
            if isinstance(p.data, DTensor):
                dtensor_names.append(n)
                dtensor_params.append(p)
            else:
                plain_names.append(n)
                plain_params.append(p)

        # Process plain tensors per-param (no communication).
        for n, p in zip(plain_names, plain_params):
            u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE),
                                             steps=ns_steps)
            adjusted_lr = adjust_lr_for_muon(lr, p.shape)
            update_p(p, u, lr, adjusted_lr, weight_decay)
            qk_clip_state = get_qk_clip_info(self.clip_config, n,
                                             qk_logits)
            scales_full = compute_scales(
                p, qk_clip_state) if qk_clip_state is not None else None
            if scales_full is not None:
                qk_clip(p, scales_full, qk_clip_state)

        if not dtensor_params:
            return

        # Group DTensors by (placements, mesh) for batched all-gather.
        placement_groups: dict[tuple, tuple[list, list]] = defaultdict(
            lambda: ([], []))
        for n, p in zip(dtensor_names, dtensor_params):
            key = (p.placements, p.device_mesh)
            placement_groups[key][0].append(n)
            placement_groups[key][1].append(p)

        logger.info(
            "distributed_muon: %d placement groups, %d total dtensors",
            len(placement_groups), len(dtensor_params))

        for (placements, mesh), (grp_names,
                                 grp_params) in placement_groups.items():
            shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
                placements, mesh)
            rank = dist.get_rank(shard_pg)
            world_size = dist.get_world_size(shard_pg)

            logger.info(" group: %d params, placements=%s, world_size=%d",
                        len(grp_params), placements, world_size)

            # Separate params that can be batched (all shard dims evenly
            # divisible) from those needing per-param full_tensor
            # (e.g. MoE gate weights with fewer rows than shard ranks).
            # all_gather_into_tensor requires equal buffer sizes across
            # ranks, so uneven splits must use DTensor full_tensor().
            batch_names, batch_params = [], []
            single_names, single_params = [], []
            for n, p in zip(grp_names, grp_params):
                even = all(p.shape[pl.dim] %
                           shard_mesh.mesh.shape[dim_idx] == 0
                           for dim_idx, pl in enumerate(shard_placements))
                if even:
                    batch_names.append(n)
                    batch_params.append(p)
                else:
                    single_names.append(n)
                    single_params.append(p)

            # Process uneven-split params per-param via full_tensor().
            for n, p in zip(single_names, single_params):
                with record_function("distributed_muon::newton_schulz"):
                    g_full = p.grad.full_tensor().to(COMM_DTYPE)
                    u_full = _zeropower_via_newtonschulz5(g_full,
                                                          steps=ns_steps)
                    del g_full
                self._distributed_apply_update(n, p, u_full, rank,
                                               shard_mesh, shard_placements,
                                               lr, weight_decay, qk_logits)
                del u_full

            if not batch_params:
                continue

            logger.info(" batched=%d, single=%d", len(batch_params),
                        len(single_params))

            # Concat all local grad shards into a single flat buffer.
            with record_function("distributed_muon::gather"):
                grad_locals = [
                    p.grad.to_local().to(COMM_DTYPE).flatten()
                    for p in batch_params
                ]
                numels = [g.numel() for g in grad_locals]
                grad_concat = torch.cat(grad_locals)
                del grad_locals

                # Single all-gather (replaces N separate full_tensor).
                # Allocate next to the local shards instead of on a
                # hard-coded "cuda" device. That only matches when the
                # current CUDA device happens to be the shards' device, and
                # it breaks non-CUDA backends.
                grad_gathered = torch.empty(
                    grad_concat.numel() * world_size,
                    dtype=COMM_DTYPE,
                    device=grad_concat.device,
                )
                dist.all_gather_into_tensor(grad_gathered,
                                            grad_concat,
                                            group=shard_pg)

                total_numel = grad_concat.numel()
                del grad_concat

            # Precompute per-param offsets within the concat buffer.
            offsets = []
            off = 0
            for ne in numels:
                offsets.append(off)
                off += ne

            # Per-param: reconstruct full grad → NS → local update.
            for i, (n, p) in enumerate(zip(batch_names, batch_params)):
                with record_function("distributed_muon::newton_schulz"):
                    g_full = torch.empty(p.shape,
                                         dtype=COMM_DTYPE,
                                         device=grad_gathered.device)
                    # Rank r's contribution sits at
                    # [r * total_numel + offsets[i], + numels[i]).
                    for r in range(world_size):
                        r_start = r * total_numel + offsets[i]
                        shard = grad_gathered[r_start:r_start + numels[i]]
                        indices = get_slices_of_dtensor(
                            p, r, shard_mesh, shard_placements)
                        g_full[indices] = shard.reshape(
                            g_full[indices].shape)

                    u_full = _zeropower_via_newtonschulz5(g_full,
                                                          steps=ns_steps)
                    del g_full

                self._distributed_apply_update(n, p, u_full, rank,
                                               shard_mesh, shard_placements,
                                               lr, weight_decay, qk_logits)
                del u_full


def _distributed_apply_update(self, n, p, u_full, rank, shard_mesh,
                              shard_placements, lr, weight_decay,
                              qk_logits):
    """Apply this rank's slice of ``u_full`` to DTensor ``p``, then QK-clip.

    Used by ``distributed_muon`` for both the batched and the per-param
    path. Weight decay (``1 - lr * weight_decay``) and the orthogonalised
    update (scaled by the shape-adjusted LR) are applied to
    ``p._local_tensor`` in place.

    If ``get_qk_clip_info`` yields a clip state and ``compute_scales``
    returns scales, the local rows are rescaled as well. Each scale covers
    ``p.shape[0] // len(scales_full)`` consecutive rows of the full matrix.
    """
    with record_function("distributed_muon::update"):
        adjusted_lr = adjust_lr_for_muon(lr, p.shape)
        p._local_tensor.mul_(1 - lr * weight_decay)
        local_indices = get_slices_of_dtensor(p, rank, shard_mesh,
                                              shard_placements)
        u_local = u_full[local_indices]
        p._local_tensor.add_(u_local, alpha=-adjusted_lr)

    qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
    if qk_clip_state is None:
        return
    scales_full = compute_scales(p, qk_clip_state)
    if scales_full is None:
        return

    ratio = p.shape[0] // scales_full.shape[0]
    idx0 = local_indices[0]
    if isinstance(idx0, slice):
        # Normalise against the full dim-0 size so a None start/stop or an
        # explicit step is handled. Before, a None stop crashed arange and
        # the step was silently ignored.
        idx0 = torch.arange(*idx0.indices(p.shape[0]),
                            device=scales_full.device)
    row_scales = scales_full[idx0 // ratio]
    p._local_tensor.mul_(row_scales.view(-1, 1))
def _setup_parallel(self, names, params, group, qk_logits):
    """Compute (or retrieve cached) parallel pipeline metadata.

    The metadata is cached under the tuple of param names. On a cache hit,
    the ordering, worker assignment, per-rank indices/numels, rank and
    chunk size are reused. A fresh ``_muon_state`` is still built for each
    param, keyed by the current ``id(p)``, with QK-clip info recomputed
    from the current ``qk_logits``.

    Returns:
        (ordered_params, param_to_state, rank, chunk_size)

    Raises:
        ValueError: on a cache miss when ``self.chunk_size`` is neither -1
            nor positive.
    """
    cache_key = tuple(names)
    cached = self._parallel_cache.get(cache_key)

    if cached is not None:
        # Hot path: rebuild states with the current param identities.
        by_name = dict(zip(names, params))
        ordered_names = cached['ordered_names']
        ordered_params = [by_name[n] for n in ordered_names]
        param_to_state = {}
        for n, p in zip(ordered_params and ordered_names, ordered_params):
            previous = cached['name_to_state'][n]
            param_to_state[id(p)] = _muon_state(
                worker_rank=previous.worker_rank,
                process_group=previous.process_group,
                rank_indices=previous.rank_indices,
                rank_numels=previous.rank_numels,
                name=n,
                qk_clip_state=get_qk_clip_info(self.clip_config, n,
                                               qk_logits),
            )
        return (ordered_params, param_to_state, cached['rank'],
                cached['chunk_size'])

    # Cold path: compute everything once and populate the cache.
    param_to_state, ordered_params = self.init_state_and_assign_params(
        names, params, group, qk_logits)
    shard_pg = param_to_state[id(ordered_params[0])].process_group
    rank = dist.get_rank(group=shard_pg)

    if self.chunk_size == -1:
        chunk_size = dist.get_world_size(shard_pg) * DEFAULT_CHUNK_SIZE_RATIO
    elif self.chunk_size > 0:
        chunk_size = self.chunk_size
    else:
        raise ValueError("chunk_size must be -1 or a positive integer.")

    states_in_order = [param_to_state[id(p)] for p in ordered_params]
    self._parallel_cache[cache_key] = {
        'ordered_names': [state.name for state in states_in_order],
        'name_to_state': {state.name: state for state in states_in_order},
        'rank': rank,
        'chunk_size': chunk_size,
    }
    return ordered_params, param_to_state, rank, chunk_size
def parallel(self,
             names,
             params,
             group,
             lr,
             weight_decay,
             qk_logits,
             prelaunch_gather=None):
    """Run one parallel (all2all) Muon step over a group of DTensor params.

    The FLOPs-ordered params from ``_setup_parallel`` are split into
    consecutive chunks of ``chunk_size``. Each chunk gets its own
    :func:`muon_chunk_pipeline` generator, and :func:`run_pipeline` keeps up
    to ``warmup_step + 1`` of them in flight, so communication of one chunk
    overlaps computation of another.

    When ``prelaunch_gather`` is given, it is handed to the first chunk only,
    letting that chunk reuse an A2A gather that is already in flight.
    Momentum must already be applied to the grads (``_step_muon`` does it).
    """
    ordered_params, param_to_state, rank, chunk_size = (
        self._setup_parallel(names, params, group, qk_logits))

    def _chunk_generators():
        starts = range(0, len(ordered_params), chunk_size)
        for chunk_no, begin in enumerate(starts):
            # Every chunk is non-empty because begin < len(ordered_params).
            pipeline_kwargs = {
                "params": ordered_params[begin:begin + chunk_size],
                "param_to_state": param_to_state,
                "rank": rank,
                "ns_steps": group["ns_steps"],
                "lr": lr,
                "weight_decay": weight_decay,
                "none_grad": group["none_grad"],
            }
            if chunk_no == 0 and prelaunch_gather is not None:
                pipeline_kwargs["prelaunch_gather"] = prelaunch_gather
            yield muon_chunk_pipeline(**pipeline_kwargs)

    with record_function("muon::pipeline"):
        run_pipeline(_chunk_generators(),
                     max_concurrent=self.warmup_step + 1)
def _step_muon(self, group, qk_logits=None):
    """Run one Muon step for a single param group.

    Stages, in order:
      1. Momentum: fold momentum into every param that has a grad. This
         uses ``batch_pre_ortho`` on local tensors. For non-Nesterov
         momentum, ``p.grad`` is then rebound to the momentum buffer.
      2. Batched experts: expert params that have a grad and no non-dim-0
         sharding are pulled out. Their Newton-Schulz update is deferred
         (``_run_deferred_expert_ns``) so it can overlap with the first
         A2A gather.
      3. Expansion: the remaining expert params are split into per-expert
         2D params. The result is cached by ``id`` of the inputs; on later
         steps only the grads are refreshed.
      4. Routing: DTensors with any non-Replicate placement go to
         ``parallel`` (grouped by placements and mesh). Plain tensors and
         all-Replicate DTensors go to ``base``. With
         ``use_distributed_muon``, everything goes to ``distributed_muon``
         instead.

    Args:
        group: The param group. Reads "params", "names", "lr",
            "weight_decay", "momentum", "nesterov", "ns_steps" and
            "none_grad".
        qk_logits: Forwarded to the QK-clip helpers; may be None.

    Raises:
        RuntimeError: if DTensors need the parallel path but
            torch.distributed is not initialized.
        TypeError: if a param's ``.data`` is not a tensor.
    """
    params = group["params"]
    lr = group["lr"]
    weight_decay = group["weight_decay"]
    momentum = group["momentum"]
    names = group["names"]

    # Apply momentum to all params before routing/expansion.
    # Batched using _foreach_* ops (compiled, fullgraph=True).
    with record_function("muon::momentum"):
        active_params = [p for p in params if p.grad is not None]
        if active_params:
            # Ensure momentum buffers exist (avoid zeros_like when already present).
            for p in active_params:
                if "momentum_buffer" not in self.state[p]:
                    self.state[p]["momentum_buffer"] = torch.zeros_like(
                        p.grad)

            # Extract local tensors for compiled batch function.
            local_grads = [
                p.grad._local_tensor
                if isinstance(p.grad, DTensor) else p.grad
                for p in active_params
            ]
            local_bufs = [
                self.state[p]["momentum_buffer"]._local_tensor
                if isinstance(self.state[p]["momentum_buffer"], DTensor)
                else self.state[p]["momentum_buffer"]
                for p in active_params
            ]

            # Wrap momentum as tensor for torch.compile.
            batch_pre_ortho(local_grads, local_bufs,
                            torch.tensor(momentum), group["nesterov"])

            # For non-nesterov, the result is the momentum buffer.
            # NOTE(review): p.grad now aliases the momentum buffer, so any
            # in-place change a later stage makes to the grad would also
            # change the buffer. Confirm that no downstream stage does so.
            if not group["nesterov"]:
                for p in active_params:
                    p.grad = self.state[p]["momentum_buffer"]

    # Identify batched experts for deferred NS.
    # Detection is cheap (condition checks only); actual NS compute is
    # deferred so it can overlap with the first chunk's A2A gather.
    deferred_expert_work = []
    if self.expert_keys:
        batched_expert_indices = []
        for i, (n, p) in enumerate(zip(names, params)):
            if not (is_expert_param(n, self.expert_keys)
                    and p.grad is not None):
                continue
            # Eligible: plain tensor, or DTensor with no non-dim-0 shards.
            if isinstance(p.data, DTensor):
                has_tp = any(
                    _is_shard(pl) and pl.dim != 0 for pl in p.placements)
                if has_tp:
                    continue
            batched_expert_indices.append(i)

        if batched_expert_indices:
            # Save refs for deferred NS; free grads from param list.
            for i in batched_expert_indices:
                p = params[i]
                g = p.grad
                local_g = (g._local_tensor
                           if isinstance(g, DTensor) else g)
                local_data = (p.data._local_tensor if isinstance(
                    p.data, DTensor) else p.data)
                deferred_expert_work.append((local_data, local_g))
                p.grad = None

            # Remove batched experts from lists before expansion.
            keep = sorted(
                set(range(len(params))) - set(batched_expert_indices))
            names = [names[i] for i in keep]
            params = [params[i] for i in keep]

    def _run_deferred_expert_ns():
        """Execute deferred batched expert NS."""
        # Each entry is a stacked (leading expert dim) local data/grad pair.
        # One batched NS call covers all experts; the LR adjustment uses the
        # per-expert shape local_g.shape[1:].
        if not deferred_expert_work:
            return
        with record_function("muon::batched_expert_ns"):
            ns_steps = group["ns_steps"]
            for local_data, local_g in deferred_expert_work:
                u = zeropower_via_newtonschulz5_batched(
                    local_g.to(COMM_DTYPE), steps=ns_steps)
                adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:])
                local_data.mul_(1 - lr * weight_decay)
                local_data.add_(u, alpha=-adjusted_lr)

    # Expand expert params by splitting on dim 0.
    logger.debug("[_step_muon] before expand: %d params, expert_keys=%s",
                 len(params), self.expert_keys)
    if self.expert_keys:
        # Keyed by object identity of the (post-filter) params, so the same
        # expanded Parameter objects are reused across steps.
        cache_key = tuple(id(p) for p in params)
        cache = self._expert_expand_cache.get(cache_key)

        if cache is None:
            # Cold path: full expansion + build cache metadata.
            exp_names, exp_params = _expand_expert_params(
                names, params, self.expert_keys)

            # Build per-expert-group info for hot-path grad updates.
            # exp_idx tracks where each original param lands in exp_params:
            # non-experts occupy one slot, experts occupy num_experts slots.
            grad_info = []
            exp_idx = 0
            for orig_idx, (n, p) in enumerate(zip(names, params)):
                if not is_expert_param(n, self.expert_keys):
                    exp_idx += 1
                    continue

                is_dt = isinstance(p.data, DTensor)
                num_experts = (p.to_local() if is_dt else p.data).shape[0]

                # Detect TP mesh from the first expanded expert param.
                tp_mesh = None
                tp_pls = None
                sample = exp_params[exp_idx]
                if isinstance(sample.data, DTensor):
                    tp_mesh = sample.data.device_mesh
                    tp_pls = list(sample.data.placements)

                grad_info.append((orig_idx, num_experts, exp_idx, is_dt,
                                  tp_mesh, tp_pls))
                exp_idx += num_experts

            self._expert_expand_cache[cache_key] = {
                'names': exp_names,
                'params': exp_params,
                'grad_info': grad_info,
            }
            names, params = exp_names, exp_params
        else:
            # Hot path: reuse cached params, only update expert grads.
            # NOTE(review): unlike the cold path there is no assertion that
            # p.grad is set; a None grad here fails at local_grad[i].
            for (orig_idx, num_experts, exp_start, is_dt, tp_mesh,
                 tp_pls) in cache['grad_info']:
                p = params[orig_idx]
                g = p.grad
                local_grad = (g.to_local()
                              if is_dt and isinstance(g, DTensor) else g)
                for i in range(num_experts):
                    expert_p = cache['params'][exp_start + i]
                    sg = local_grad[i]
                    if tp_mesh is not None:
                        expert_p.grad = DTensor.from_local(
                            sg, device_mesh=tp_mesh, placements=tp_pls)
                    else:
                        expert_p.grad = sg
                p.grad = None

            names = cache['names']
            params = cache['params']
    else:
        names, params = _expand_expert_params(names, params,
                                              self.expert_keys)
    logger.debug("[_step_muon] after expand: %d params", len(params))

    param_dtensors = []
    name_dtensors = []

    param_tensors = []
    name_tensors = []

    # distributed_muon is a reference implementation for testing only.
    # The parallel pipeline (all2all) path below is the production path.
    if self.use_distributed_muon:
        _run_deferred_expert_ns()
        self.distributed_muon(names=names,
                              params=params,
                              group=group,
                              lr=lr,
                              weight_decay=weight_decay,
                              qk_logits=qk_logits)
        return

    # Route: sharded DTensors -> parallel; plain tensors and fully
    # replicated DTensors -> base (no communication needed).
    for n, p in zip(names, params):
        if p is None or p.grad is None:
            continue
        if isinstance(p.data, DTensor):
            if all(
                    isinstance(placement, Replicate)
                    for placement in p.placements):
                logger.debug(
                    "[route] %s → base (DTensor all-Replicate), "
                    "shape=%s, placements=%s", n, p.shape, p.placements)
                param_tensors.append(p)
                name_tensors.append(n)
            else:
                logger.debug(
                    "[route] %s → parallel (DTensor), shape=%s, "
                    "placements=%s, mesh=%s", n, p.shape, p.placements,
                    p.device_mesh.mesh_dim_names)
                param_dtensors.append(p)
                name_dtensors.append(n)
        elif isinstance(p.data, torch.Tensor):
            logger.debug("[route] %s → base (plain tensor), shape=%s", n,
                         p.data.shape)
            param_tensors.append(p)
            name_tensors.append(n)
        else:
            raise TypeError(f"Unsupported parameter type: {type(p.data)}")

    logger.debug(f"[Muon] {len(param_dtensors)} DTensors → parallel, "
                 f"{len(param_tensors)} Tensors → base")

    def group_dtensors(dtensors, names):
        # To support different placements, we group parameters by placements
        # and run parallel Muon on each group.
        # Returns {(placements, device_mesh): ([names], [params])}.

        placement_to_params = defaultdict(lambda: ([], []))

        assert len(dtensors) == len(names)
        for p, n in zip(dtensors, names):
            placement_to_params[tuple([p.placements,
                                       p.device_mesh])][0].append(n)
            placement_to_params[tuple([p.placements,
                                       p.device_mesh])][1].append(p)
        return placement_to_params

    if len(param_dtensors) > 0:
        if not dist.is_initialized():
            raise RuntimeError(
                "Parallel Muon requires torch.distributed to be initialized."
            )

        dtensor_group = group_dtensors(param_dtensors, name_dtensors)

        # Pre-launch the first chunk's A2A gather so that the NCCL
        # communication overlaps with the (deferred) batched expert NS
        # compute on the default CUDA stream.
        # The first group taken here is also the first one iterated in the
        # loop below, which is where `prelaunch` is consumed.
        prelaunch = None
        if deferred_expert_work:
            first_names, first_params = next(iter(dtensor_group.values()))
            ordered, pts, rnk, csz = self._setup_parallel(
                first_names, first_params, group, qk_logits)
            first_chunk = ordered[:csz]
            if first_chunk:
                prelaunch = prelaunch_first_gather(first_chunk, pts, rnk,
                                                   group["none_grad"])

        _run_deferred_expert_ns()

        first_group = True
        for _, (names, params) in dtensor_group.items():
            pg = prelaunch if first_group else None
            first_group = False
            self.parallel(
                names,
                params,
                group,
                lr=lr,
                weight_decay=weight_decay,
                qk_logits=qk_logits,
                prelaunch_gather=pg,
            )
    else:
        # No parallel work to overlap with; run deferred expert NS now.
        _run_deferred_expert_ns()

    if len(param_tensors) > 0:
        self.base(
            name_tensors,
            param_tensors,
            group,
            lr=lr,
            weight_decay=weight_decay,
            qk_logits=qk_logits,
        )
def _register_states_for_offload(self):
    """Hand the already-created optimizer state tensors to the offload pool.

    Every parameter group is walked and the relevant state tensors are passed
    to ``self._cpu_offload_pool.track``:

    * groups whose ``use_muon`` entry is truthy contribute ``momentum_buffer``;
    * every other (AdamW) group contributes ``moment1`` and ``moment2``.

    Parameters that have no entry in ``self.state`` yet are skipped, so only
    states that already exist (they are created lazily) get registered.
    ``step`` calls this the first time offload is initialised, and
    ``turn_on_cpu_offload`` calls it when states already exist.
    """
    pool = self._cpu_offload_pool
    n_registered = 0
    for param_group in self.param_groups:
        # The group kind does not change while we iterate its params.
        is_muon = param_group.get("use_muon", False)
        for param in param_group["params"]:
            # Skip parameters whose optimizer state has not been created yet.
            if param not in self.state:
                continue
            param_state = self.state[param]
            if is_muon:
                if "momentum_buffer" not in param_state:
                    continue
                pool.track(param_state["momentum_buffer"])
            else:
                for key in ("moment1", "moment2"):
                    if key in param_state:
                        pool.track(param_state[key])
            n_registered += 1
    logger.info("[CPUOffload] Registered %d param states for offload",
                n_registered)
def step(self, closure=None, qk_logits=None):
    """Perform a single optimization step over every parameter group.

    Groups with ``use_muon`` set are updated by ``self._step_muon``; all
    other groups are updated by ``step_adamw``. When CPU offload is enabled,
    states are reloaded before the updates (once offload has been
    initialised) and offloaded again afterwards.

    Args:
        closure (Callable, optional): A closure that reevaluates the model
            and returns the loss. It is run under ``torch.enable_grad()``.
        qk_logits (dict[int, Tensor], optional): A dictionary mapping layer
            indices to 1D tensors of shape (num_heads,), representing the
            maximum QK logits across all tokens, computed as
            (1 / sqrt(head_dim)) * (Q @ K^T). Forwarded to the Muon groups.

    Returns:
        The value returned by ``closure``, or ``None`` when no closure is
        given.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    # H2D: bring offloaded optimizer states back before they are used.
    offload_ready = self.cpu_offload and self._offload_initialized
    if offload_ready:
        self._cpu_offload_pool.reload()

    logger.debug("[Muon.step] expert_keys=%s, %d param groups",
                 self.expert_keys, len(self.param_groups))
    for group_idx, param_group in enumerate(self.param_groups):
        use_muon = param_group["use_muon"]
        num_params = len(param_group["params"])
        if not use_muon:
            logger.debug(
                "[Muon.step] group %d: use_muon=False (AdamW), %d params",
                group_idx, num_params)
            step_adamw(self.state, param_group)
            continue
        logger.debug("[Muon.step] group %d: use_muon=True, %d params",
                     group_idx, num_params)
        self._step_muon(param_group, qk_logits=qk_logits)

    # D2H: move optimizer states back to CPU once the update is done.
    if not self.cpu_offload:
        return loss
    if not self._offload_initialized:
        # States are created lazily during the update, so the pool can only
        # be populated once a step has run with offload enabled.
        if self._cpu_offload_pool is None:
            self._cpu_offload_pool = CPUOffloadPool()
        self._register_states_for_offload()
        self._offload_initialized = True
    self._cpu_offload_pool.offload()
    return loss
| # ------------------------------------------------------------------ | |
| # CPU offload public helpers | |
| # ------------------------------------------------------------------ | |
def turn_on_cpu_offload(self):
    """Enable CPU offload for optimizer states.

    This does nothing if offload is already enabled. If optimizer states
    already exist, they are registered with a fresh ``CPUOffloadPool`` and
    offloaded immediately. Otherwise registration is left to ``step``.
    """
    if not self.cpu_offload:
        logger.info("[Muon] turn_on_cpu_offload")
        self.cpu_offload = True
        if self.state:
            self._cpu_offload_pool = CPUOffloadPool()
            # Stays False until registration has completed.
            self._offload_initialized = False
            self._register_states_for_offload()
            self._offload_initialized = True
            self._cpu_offload_pool.offload()
def turn_off_cpu_offload(self):
    """Disable CPU offload and keep optimizer states resident on GPU.

    This does nothing if offload is already disabled. If the pool has been
    initialised, states are reloaded and the current CUDA stream is
    synchronized before the pool is dropped.
    """
    if self.cpu_offload:
        logger.info("[Muon] turn_off_cpu_offload")
        if self._offload_initialized:
            self._cpu_offload_pool.reload()
            # Wait for the reload work queued on the current stream to finish.
            torch.cuda.current_stream().synchronize()
        self._cpu_offload_pool = None
        self._offload_initialized = False
        self.cpu_offload = False
| # ------------------------------------------------------------------ | |
| # Checkpoint support for cpu_offload | |
| # ------------------------------------------------------------------ | |
def state_dict(self) -> dict:
    """Return the optimizer state dict from the parent implementation.

    Raises:
        RuntimeError: if CPU offload is enabled; call
            ``turn_off_cpu_offload()`` before saving a checkpoint.
    """
    if not self.cpu_offload:
        return super().state_dict()
    raise RuntimeError(
        "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
    )
def load_state_dict(self, state_dict: dict) -> None:
    """Load optimizer state via the parent implementation.

    After loading, the module-level ``_placement_cache`` and ``_tensor_cache``
    imported from ``adamw`` are cleared. The next step therefore rebuilds
    them with the newly loaded state tensors.

    Raises:
        RuntimeError: if CPU offload is enabled; call
            ``turn_off_cpu_offload()`` before loading a checkpoint.
    """
    if self.cpu_offload:
        raise RuntimeError(
            "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
        )
    super().load_state_dict(state_dict)
    # These caches may still reference the old state tensors.
    for stale_cache in (_placement_cache, _tensor_cache):
        stale_cache.clear()