import logging
import types
from collections import defaultdict
from typing import Any

import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.profiler import record_function

from .adamw import _placement_cache, _tensor_cache, step_adamw
from .async_utils import run_pipeline
from .core import (_muon_state, adjust_lr_for_muon, batch_pre_ortho,
                   get_default_muon_param_groups, is_expert_param, update_p)
from .cpu_offload import CPUOffloadPool
from .distributed.utils import (_is_shard, construct_shard_mesh,
                                get_slices_of_dtensor)
from .newton_schulz import (COMM_DTYPE, DEFAULT_CHUNK_SIZE_RATIO,
                            _zeropower_via_newtonschulz5,
                            zeropower_via_newtonschulz5,
                            zeropower_via_newtonschulz5_batched)
from .pipeline import muon_chunk_pipeline, prelaunch_first_gather
from .qk_clip import compute_scales, get_qk_clip_info, qk_clip

logger = logging.getLogger(__name__)


def _expand_expert_params(names, params, expert_keys):
    """Expand expert params by splitting on dim 0 (expert dimension).

    Params whose name matches any key in ``expert_keys`` are treated as
    expert-parallel tensors. Their outermost dimension is the expert
    dimension: an ``(E, out, in)`` tensor becomes ``E`` separate 2D
    ``nn.Parameter`` views so that in-place updates propagate back to
    the original storage.

    Non-expert params with ``ndim > 2`` trigger an ``AssertionError`` —
    if they are expert params, their key must be added to ``expert_keys``.

    The grad must already be set on each expert param (e.g. after momentum).

    For DTensor expert params, placements that shard on dim 0 (expert dim)
    are consumed by the split. Non-dim-0 shard placements (e.g. TP) are
    preserved: each 2D slice is wrapped as a DTensor on the corresponding
    submesh so the parallel pipeline handles the TP communication.

    Returns:
        ``(expanded_names, expanded_params)``. Non-expert params are passed
        through unchanged; expert slices are named ``"{name}[{i}]"``. The
        ``grad`` of every expanded expert param is reset to ``None``.
    """
    expanded_names = []
    expanded_params = []

    for n, p in zip(names, params):
        is_expert = is_expert_param(n, expert_keys)
        is_dtensor = isinstance(p.data, DTensor)

        if is_expert:
            if is_dtensor:
                logger.debug(
                    "[expand_expert] %s: expert DTensor, shape=%s, "
                    "placements=%s, mesh=%s, local_shape=%s", n, p.shape,
                    p.placements, p.device_mesh.mesh_dim_names,
                    p.to_local().shape)
            else:
                logger.debug(
                    "[expand_expert] %s: expert plain tensor, shape=%s", n,
                    p.data.shape)

        if not is_expert:
            assert p.data.ndim <= 2, (
                f"Param {n} has ndim={p.data.ndim} but does not match "
                f"expert_keys={expert_keys}. If this is an expert param, "
                f"add its key to expert_keys.")
            expanded_names.append(n)
            expanded_params.append(p)
            continue

        g = p.grad
        assert g is not None, (
            f"Expert param {n} must have grad set before expansion")

        tp_mesh = None
        tp_placements_2d = None

        if is_dtensor:
            local_data = p.to_local()
            local_grad = g.to_local() if isinstance(g, DTensor) else g

            # Find non-dim-0 shard placements (e.g. TP sharding).
            # After splitting on dim 0, Shard(k) becomes Shard(k-1).
            tp_dim_indices = []
            tp_placements_2d = []
            for i, pl in enumerate(p.placements):
                if _is_shard(pl) and pl.dim != 0:
                    tp_dim_indices.append(i)
                    tp_placements_2d.append(Shard(pl.dim - 1))

            if tp_dim_indices:
                tp_dim_names = tuple(p.device_mesh.mesh_dim_names[i]
                                     for i in tp_dim_indices)
                if len(tp_dim_names) == 1:
                    tp_mesh = p.device_mesh[tp_dim_names[0]]
                else:
                    tp_mesh = p.device_mesh[tp_dim_names]
        else:
            local_data = p.data
            local_grad = g

        # Expand: split dim 0, reshape each slice to 2D.
        num_local_experts = local_data.shape[0]
        for i in range(num_local_experts):
            slice_data = local_data[i]
            slice_grad = local_grad[i]

            if tp_mesh is not None:
                # Wrap as DTensor on TP submesh so the pipeline handles
                # TP communication (gather/scatter across TP ranks).
                dt_data = DTensor.from_local(slice_data,
                                             device_mesh=tp_mesh,
                                             placements=tp_placements_2d)
                dt_grad = DTensor.from_local(slice_grad,
                                             device_mesh=tp_mesh,
                                             placements=tp_placements_2d)
                expert_param = torch.nn.Parameter(dt_data,
                                                  requires_grad=False)
                expert_param.grad = dt_grad
            else:
                expert_param = torch.nn.Parameter(slice_data,
                                                  requires_grad=False)
                expert_param.grad = slice_grad

            expanded_names.append(f"{n}[{i}]")
            expanded_params.append(expert_param)

        p.grad = None  # allow expert grad storage to be freed after pipeline

    return expanded_names, expanded_params


class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    Muon internally runs standard SGD-momentum, and then performs an
    orthogonalization post-processing step, in which each 2D parameter's
    update is replaced with the nearest orthogonal matrix. To efficiently
    orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - We believe this optimizer is unlikely to work well for training with
      small batch size.
    - We believe it may not work well for finetuning pretrained models, but
      we haven't tested this.

    Arguments:
        params: Parameter groups to optimize. Must be a re-iterable
            collection of dicts (a generator is rejected) and every group
            must set the ``use_muon`` key; groups with a truthy
            ``use_muon`` are stepped by Muon, all others by the internal
            AdamW. Muon groups must also provide a ``names`` entry aligned
            with ``params``.
        lr: The learning rate. The updates will have spectral norm of `lr`.
            (0.02 is a good default)
        momentum: The momentum used by the internal SGD.
            (0.95 is a good default)
        nesterov: Whether to use Nesterov-style momentum in the internal
            SGD. (recommended)
        ns_steps: The number of Newton-Schulz iterations to run.
            (6 is probably always enough)
        weight_decay: The weight decay for Muon and AdamW.
            Parameters that are {0, 1}-D or are detected as being the embed
            or lm_head will be optimized by AdamW instead.
        adamw_betas: The betas for the internal AdamW.
        adamw_eps: The epsilon for the internal AdamW.
        none_grad: Whether to set p.grad to None after gathering the
            gradients. This can save memory.
        debug: Whether to print debug information.
        clip_config: Configuration for QK clipping. Expected keys:
            - "q_indices" (list[int]): Indices of query heads to consider.
            - "k_indices" (list[int]): Indices of key heads to consider.
            - "head_dim" (int): Dimensionality of each attention head.
            - "threshold" (float): Threshold value; heads whose QK logits
              exceed this value will be scaled down.
            Default is:
            {
                "q_indices": [],
                "k_indices": [],
                "head_dim": 128,
                "threshold": 100
            }
        warmup_step: How many all2all gather, compute operations are
            launched in advance before the corresponding all2all scatter
            steps begin. A higher warmup_step increases memory usage but can
            improve performance by overlapping communication.
            Parallel muon only.
        chunk_size: Batch size of parameters to process in each all2all
            gather/compute/scatter step. Use
            shard ranks * DEFAULT_CHUNK_SIZE_RATIO when -1 is specified.
        use_distributed_muon: Use distributed muon by Liu et al. (2024).
            For testing purpose only.
        expert_keys: List of strings to identify expert-parallel parameters.
            If any key appears in a parameter's name, its outermost
            dimension is treated as the expert dimension and expanded into
            per-expert 2D params for Muon. For example,
            ``expert_keys=["experts"]`` matches any param whose name
            contains "experts". 3D+ params not matched by any key will
            raise an error.
    """

    def __init__(self,
                 params,
                 lr=1e-3,
                 momentum=0.95,
                 nesterov=True,
                 ns_steps=5,
                 weight_decay=0.1,
                 adamw_betas=(0.9, 0.95),
                 adamw_eps=1e-8,
                 none_grad=True,
                 debug=False,
                 clip_config=None,
                 warmup_step=5,
                 chunk_size=-1,
                 use_distributed_muon=False,
                 expert_keys=None):
        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum,
            nesterov=nesterov,
            ns_steps=ns_steps,
            adamw_betas=adamw_betas,
            adamw_eps=adamw_eps,
            none_grad=none_grad,
            use_muon=True,
        )
        error_message = "The key 'use_muon' is not set in parameter group {idx}. Assuming all parameters in the group will use muon optimization, which may lead to unexpected behavior."
        instruction_code = "\n\n please follow this code snippet \n```optimizer = get_kernel('motif-technologies/optimizer')\n\n\nparams = optimizer.muon.get_default_muon_param_groups(model)\n\noptim = optimizer.Muon(params, ...)```"

        # A generator cannot be inspected here and then handed to the base
        # class again, so it is rejected up front.
        if isinstance(params, types.GeneratorType):
            raise ValueError(error_message.format(idx=0) + instruction_code)
        for _idx, param_group in enumerate(params):
            if param_group.get("use_muon", None) is None:
                raise ValueError(
                    error_message.format(idx=_idx) + instruction_code)

        super().__init__(params, defaults)

        self.debug = debug
        self.clip_config = clip_config if clip_config is not None else {
            "q_indices": [],
            "k_indices": [],
            "head_dim": 128,
            "threshold": 100,
        }
        self.warmup_step = warmup_step
        self.chunk_size = chunk_size
        self.use_distributed_muon = use_distributed_muon
        self.expert_keys = expert_keys
        self.cpu_offload = False
        self._cpu_offload_pool: CPUOffloadPool | None = None
        self._offload_initialized = False
        # Keyed by the tuple of param names of a placement group.
        self._parallel_cache: dict[tuple[str, ...], dict] = {}
        # Keyed by the tuple of id(p) of the params entering expansion.
        self._expert_expand_cache: dict[tuple[int, ...], dict] = {}

    def _calc_flops(self, G, steps):
        """Return the cost estimate used to order params by NS workload.

        With ``M`` the smaller and ``N`` the larger dimension of the 2D
        tensor ``G``, the estimate is
        ``steps * (2*M**3 + 4*M**2*N + 2*M*N + 3*M**2)``.
        """
        assert len(G.shape) == 2
        M, N = G.shape
        if M > N:
            M, N = N, M

        return steps * ((M**3) * 2 + (M**2 * N) * 4 + M * N * 2 + M**2 * 3)

    def get_shard_mesh(self, p):
        """
        Get the shard mesh for a parameter p on the given rank.

        Returns:
            ``(shard_mesh, shard_pg, shard_placements)`` as produced by
            ``construct_shard_mesh`` for ``p``'s placements and mesh.
        """
        assert isinstance(
            p, DTensor), "Parallel Muon only supports DTensor parameters."

        shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
            p.placements, p.device_mesh)

        return shard_mesh, shard_pg, shard_placements

    def init_state_and_assign_params(self, names, params, group, qk_logits):
        """Build per-param pipeline state and a workload-sorted param order.

        Params are sorted by descending ``_calc_flops`` estimate and
        assigned a worker rank round-robin over the flattened shard mesh.
        For every rank the index slices owned by that rank and the number
        of elements they cover are precomputed for the all-to-all.

        All params must share the same device mesh and placements, and
        every param is expected to have a 2D grad.

        Returns:
            ``(param_to_state, ordered_params)`` where ``param_to_state``
            maps ``id(p)`` to its ``_muon_state``.
        """
        param_to_state = {}
        param_to_flops = {}

        total_flops = 0
        for p in params:
            g = p.grad
            if g is None:
                continue
            assert g.ndim == 2, "Muon only supports 2D parameters."

            flops = self._calc_flops(g, group["ns_steps"])
            param_to_flops[id(p)] = flops
            total_flops += flops

        if self.debug:
            logger.debug("Total TFLOPs for Muon: %.2f TFLOPs",
                         total_flops / 1e12)

        paired = list(zip(names, params))

        paired_sorted = sorted(paired,
                               key=lambda x: param_to_flops[id(x[1])],
                               reverse=True)

        names_sorted, params_sorted = zip(*paired_sorted)
        ordered_names = list(names_sorted)
        ordered_params = list(params_sorted)

        round_robin = 0
        mesh = ordered_params[0].device_mesh
        placements = ordered_params[0].placements

        shard_mesh, shard_pg, shard_placements = self.get_shard_mesh(
            ordered_params[0])
        shard_mesh_flattened = shard_mesh.mesh.flatten()
        num_ranks = dist.get_world_size(group=shard_pg)

        for n, p in zip(ordered_names, ordered_params):
            if mesh != p.device_mesh:
                raise ValueError("All parameters must be on the same mesh.")
            if placements != p.placements:
                raise ValueError("All parameters must have same placements.")

            worker_rank = shard_mesh_flattened[round_robin].item() % num_ranks
            round_robin = (round_robin + 1) % len(shard_mesh_flattened)
            qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)

            # Precompute per-rank indices and numels for all-to-all.
            rank_indices: dict[int, tuple] = {}
            rank_numels: dict[int, int] = {}
            for r in range(num_ranks):
                indices = get_slices_of_dtensor(p, r, shard_mesh,
                                                shard_placements)
                rank_indices[r] = indices
                numel = 1
                for idx, dim_size in zip(indices, p.shape):
                    if isinstance(idx, slice):
                        # Ceil-divide the span by the step to count elements.
                        start, stop, step = idx.indices(dim_size)
                        numel *= max(0, (stop - start + (step - 1)) // step)
                    else:
                        numel *= len(idx)
                rank_numels[r] = numel

            param_to_state[id(p)] = _muon_state(
                worker_rank=worker_rank,
                process_group=shard_pg,
                rank_indices=rank_indices,
                rank_numels=rank_numels,
                name=n,
                qk_clip_state=qk_clip_state,
            )

        return param_to_state, ordered_params

    def base(self, names, params, group, lr, weight_decay, qk_logits):
        """Muon step for params that need no communication.

        For each param with a grad: orthogonalize the grad with
        Newton-Schulz, apply the update via ``update_p`` and, when QK-clip
        info exists for the param, apply the QK-clip scales.
        """
        # Momentum is already applied by _step_muon before this method.
        for n, p in zip(names, params):
            g = p.grad
            if g is None:
                continue

            u = zeropower_via_newtonschulz5(g.to(COMM_DTYPE),
                                            steps=group["ns_steps"])

            adjusted_lr = adjust_lr_for_muon(lr, p.shape)
            update_p(p, u, lr, adjusted_lr, weight_decay)

            qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)

            scales_full = compute_scales(
                p, qk_clip_state) if qk_clip_state is not None else None
            if scales_full is not None:
                qk_clip(p, scales_full, qk_clip_state)

    def _apply_sharded_update(self, n, p, u_full, rank, shard_mesh,
                              shard_placements, lr, weight_decay, qk_logits):
        """Apply a full-size update ``u_full`` to the local shard of ``p``.

        Applies weight decay to the local tensor, subtracts this rank's
        slice of ``u_full`` scaled by the Muon-adjusted lr, then applies
        per-row QK-clip scales when clip info exists for param ``n``.
        """
        with record_function("distributed_muon::update"):
            adjusted_lr = adjust_lr_for_muon(lr, p.shape)
            p._local_tensor.mul_(1 - lr * weight_decay)
            local_indices = get_slices_of_dtensor(p, rank, shard_mesh,
                                                  shard_placements)
            u_local = u_full[local_indices]
            p._local_tensor.add_(u_local, alpha=-adjusted_lr)

        qk_clip_state = get_qk_clip_info(self.clip_config, n, qk_logits)
        scales_full = compute_scales(
            p, qk_clip_state) if qk_clip_state is not None else None
        if scales_full is not None:
            # Map each locally-owned global row to its scale entry.
            ratio = p.shape[0] // scales_full.shape[0]
            idx0 = local_indices[0]
            if isinstance(idx0, slice):
                start = idx0.start or 0
                idx0 = torch.arange(start,
                                    idx0.stop,
                                    device=scales_full.device)
            row_scales = scales_full[idx0 // ratio]
            p._local_tensor.mul_(row_scales.view(-1, 1))

    def distributed_muon(
        self,
        names: list[str],
        params: list[torch.nn.Parameter],
        group: dict[str, Any],
        lr: float,
        weight_decay: float,
        qk_logits: list[torch.Tensor | DTensor] | None,
    ):
        """Batched Distributed Muon — for testing/correctness verification only.

        Uses all-gather to reconstruct full tensors, computes Newton-Schulz on
        the full grad, then slices back to local shards. This is simpler but
        slower than the parallel pipeline (all2all) path, so it serves as a
        reference implementation for verifying correctness.
        """
        with record_function("distributed_muon"):
            # Momentum is already applied by _step_muon before this method.
            ns_steps = group["ns_steps"]

            # Separate plain tensors (no communication) from DTensors.
            plain_names, plain_params = [], []
            dtensor_names, dtensor_params = [], []
            for n, p in zip(names, params):
                if p.grad is None:
                    continue
                if isinstance(p.data, DTensor):
                    dtensor_names.append(n)
                    dtensor_params.append(p)
                else:
                    plain_names.append(n)
                    plain_params.append(p)

            # Process plain tensors per-param (no communication).
            for n, p in zip(plain_names, plain_params):
                u = _zeropower_via_newtonschulz5(p.grad.to(COMM_DTYPE),
                                                 steps=ns_steps)
                adjusted_lr = adjust_lr_for_muon(lr, p.shape)
                update_p(p, u, lr, adjusted_lr, weight_decay)

                qk_clip_state = get_qk_clip_info(self.clip_config, n,
                                                 qk_logits)
                scales_full = compute_scales(
                    p, qk_clip_state) if qk_clip_state is not None else None
                if scales_full is not None:
                    qk_clip(p, scales_full, qk_clip_state)

            if not dtensor_params:
                return

            # Group DTensors by (placements, mesh) for batched all-gather.
            placement_groups: dict[tuple,
                                   tuple[list,
                                         list]] = defaultdict(lambda: ([], []))
            for n, p in zip(dtensor_names, dtensor_params):
                key = (p.placements, p.device_mesh)
                placement_groups[key][0].append(n)
                placement_groups[key][1].append(p)

            logger.info(
                "distributed_muon: %d placement groups, %d total dtensors",
                len(placement_groups), len(dtensor_params))

            for (placements, mesh), (grp_names,
                                     grp_params) in placement_groups.items():
                shard_mesh, shard_pg, shard_placements = construct_shard_mesh(
                    placements, mesh)
                rank = dist.get_rank(shard_pg)
                world_size = dist.get_world_size(shard_pg)

                logger.info("  group: %d params, placements=%s, world_size=%d",
                            len(grp_params), placements, world_size)

                # Separate params that can be batched (all shard dims evenly
                # divisible) from those needing per-param full_tensor
                # (e.g. MoE gate weights with fewer rows than shard ranks).
                # all_gather_into_tensor requires equal buffer sizes across
                # ranks, so uneven splits must use DTensor full_tensor().
                batch_names, batch_params = [], []
                single_names, single_params = [], []
                for n, p in zip(grp_names, grp_params):
                    even = all(p.shape[pl.dim] %
                               shard_mesh.mesh.shape[dim_idx] == 0
                               for dim_idx, pl in enumerate(shard_placements))
                    if even:
                        batch_names.append(n)
                        batch_params.append(p)
                    else:
                        single_names.append(n)
                        single_params.append(p)

                # Process uneven-split params per-param via full_tensor().
                for n, p in zip(single_names, single_params):
                    with record_function("distributed_muon::newton_schulz"):
                        g_full = p.grad.full_tensor().to(COMM_DTYPE)
                        u_full = _zeropower_via_newtonschulz5(g_full,
                                                              steps=ns_steps)
                        del g_full
                    self._apply_sharded_update(n, p, u_full, rank, shard_mesh,
                                               shard_placements, lr,
                                               weight_decay, qk_logits)
                    del u_full

                if not batch_params:
                    continue

                logger.info("  batched=%d, single=%d", len(batch_params),
                            len(single_params))

                # Concat all local grad shards into a single flat buffer.
                with record_function("distributed_muon::gather"):
                    grad_locals = [
                        p.grad.to_local().to(COMM_DTYPE).flatten()
                        for p in batch_params
                    ]
                    numels = [g.numel() for g in grad_locals]
                    grad_concat = torch.cat(grad_locals)
                    del grad_locals

                    # Allocate on the device that actually holds the shards
                    # instead of whatever the current CUDA device is.
                    device = grad_concat.device

                    # Single all-gather (replaces N separate full_tensor).
                    grad_gathered = torch.empty(
                        grad_concat.numel() * world_size,
                        dtype=COMM_DTYPE,
                        device=device,
                    )
                    dist.all_gather_into_tensor(grad_gathered,
                                                grad_concat,
                                                group=shard_pg)

                    total_numel = grad_concat.numel()
                    del grad_concat

                # Precompute per-param offsets within the concat buffer.
                offsets = []
                off = 0
                for ne in numels:
                    offsets.append(off)
                    off += ne

                # Per-param: reconstruct full grad → NS → local update.
                for i, (n, p) in enumerate(zip(batch_names, batch_params)):
                    with record_function("distributed_muon::newton_schulz"):
                        g_full = torch.empty(p.shape,
                                             dtype=COMM_DTYPE,
                                             device=device)
                        for r in range(world_size):
                            # Rank r's contribution starts at r * total_numel
                            # in the gathered buffer.
                            r_start = r * total_numel + offsets[i]
                            shard = grad_gathered[r_start:r_start + numels[i]]
                            indices = get_slices_of_dtensor(
                                p, r, shard_mesh, shard_placements)
                            g_full[indices] = shard.reshape(
                                g_full[indices].shape)

                        u_full = _zeropower_via_newtonschulz5(g_full,
                                                              steps=ns_steps)
                        del g_full
                    self._apply_sharded_update(n, p, u_full, rank, shard_mesh,
                                               shard_placements, lr,
                                               weight_decay, qk_logits)
                    del u_full

    def _setup_parallel(self, names, params, group, qk_logits):
        """Compute (or retrieve cached) parallel pipeline metadata.

        The cache is keyed by the tuple of ``names``. On a cache hit the
        per-param state is rebuilt for the current param objects and the
        QK-clip info is recomputed from the current ``qk_logits``.

        Returns:
            (ordered_params, param_to_state, rank, chunk_size)

        Raises:
            ValueError: if ``self.chunk_size`` is neither -1 nor positive.
        """
        cache_key = tuple(names)

        if cache_key not in self._parallel_cache:
            # First call: compute metadata and populate cache.
            param_to_state, ordered_params = self.init_state_and_assign_params(
                names, params, group, qk_logits)

            shard_pg = param_to_state[id(ordered_params[0])].process_group
            rank = dist.get_rank(group=shard_pg)

            if self.chunk_size == -1:
                shard_ranks = dist.get_world_size(shard_pg)
                chunk_size = shard_ranks * DEFAULT_CHUNK_SIZE_RATIO
            elif self.chunk_size > 0:
                chunk_size = self.chunk_size
            else:
                raise ValueError(
                    "chunk_size must be -1 or a positive integer.")

            ordered_names = [
                param_to_state[id(p)].name for p in ordered_params
            ]
            name_to_state = {
                param_to_state[id(p)].name: param_to_state[id(p)]
                for p in ordered_params
            }
            self._parallel_cache[cache_key] = {
                'ordered_names': ordered_names,
                'name_to_state': name_to_state,
                'rank': rank,
                'chunk_size': chunk_size,
            }
        else:
            # Cached path: rebuild param_to_state with current id(p) keys.
            cache = self._parallel_cache[cache_key]
            rank = cache['rank']
            chunk_size = cache['chunk_size']

            name_to_param = dict(zip(names, params))
            ordered_params = [name_to_param[n] for n in cache['ordered_names']]

            param_to_state = {}
            for p, n in zip(ordered_params, cache['ordered_names']):
                cached_state = cache['name_to_state'][n]
                param_to_state[id(p)] = _muon_state(
                    worker_rank=cached_state.worker_rank,
                    process_group=cached_state.process_group,
                    rank_indices=cached_state.rank_indices,
                    rank_numels=cached_state.rank_numels,
                    name=n,
                    qk_clip_state=get_qk_clip_info(self.clip_config, n,
                                                   qk_logits),
                )

        return ordered_params, param_to_state, rank, chunk_size

    def parallel(self,
                 names,
                 params,
                 group,
                 lr,
                 weight_decay,
                 qk_logits,
                 prelaunch_gather=None):
        """
        Perform a parallel optimization step using Muon.

        Parameters are chunked and each chunk is processed by a
        :func:`muon_chunk_pipeline` generator.  :func:`run_pipeline`
        interleaves multiple chunks so that communication and computation
        overlap across chunks (the same overlap previously achieved by the
        warmup + main-loop index scheduling).

        If ``prelaunch_gather`` is provided, it is passed to the first
        chunk's generator to skip re-launching the already in-flight
        A2A gather.
        """
        # Momentum is already applied by _step_muon before this method.
        ordered_params, param_to_state, rank, chunk_size = (
            self._setup_parallel(names, params, group, qk_logits))

        def pipelines():
            first = True
            for start in range(0, len(ordered_params), chunk_size):
                chunk = ordered_params[start:start + chunk_size]
                if chunk:
                    kwargs = dict(
                        params=chunk,
                        param_to_state=param_to_state,
                        rank=rank,
                        ns_steps=group["ns_steps"],
                        lr=lr,
                        weight_decay=weight_decay,
                        none_grad=group["none_grad"],
                    )
                    if first and prelaunch_gather is not None:
                        kwargs['prelaunch_gather'] = prelaunch_gather
                    first = False
                    yield muon_chunk_pipeline(**kwargs)

        with record_function("muon::pipeline"):
            run_pipeline(pipelines(), max_concurrent=self.warmup_step + 1)

    def _step_muon(self, group, qk_logits=None):
        """Run one Muon step for a single ``use_muon`` param group.

        Order of operations: apply momentum to every param with a grad,
        set aside expert params eligible for batched Newton-Schulz, expand
        the remaining expert params into per-expert 2D params, then route
        each param to the reference (``distributed_muon``), parallel
        (sharded DTensor) or base (plain / fully replicated) path.
        """
        params = group["params"]
        lr = group["lr"]
        weight_decay = group["weight_decay"]
        momentum = group["momentum"]
        names = group["names"]

        # Apply momentum to all params before routing/expansion.
        # Batched using _foreach_* ops (compiled, fullgraph=True).
        with record_function("muon::momentum"):
            active_params = [p for p in params if p.grad is not None]
            if active_params:
                # Ensure momentum buffers exist (avoid zeros_like when already present).
                for p in active_params:
                    if "momentum_buffer" not in self.state[p]:
                        self.state[p]["momentum_buffer"] = torch.zeros_like(
                            p.grad)

                # Extract local tensors for compiled batch function.
                local_grads = [
                    p.grad._local_tensor
                    if isinstance(p.grad, DTensor) else p.grad
                    for p in active_params
                ]
                local_bufs = [
                    self.state[p]["momentum_buffer"]._local_tensor
                    if isinstance(self.state[p]["momentum_buffer"], DTensor)
                    else self.state[p]["momentum_buffer"]
                    for p in active_params
                ]

                # Wrap momentum as tensor for torch.compile.
                batch_pre_ortho(local_grads, local_bufs,
                                torch.tensor(momentum), group["nesterov"])

                # For non-nesterov, the result is the momentum buffer.
                if not group["nesterov"]:
                    for p in active_params:
                        p.grad = self.state[p]["momentum_buffer"]

        # Identify batched experts for deferred NS.
        # Detection is cheap (condition checks only); actual NS compute is
        # deferred so it can overlap with the first chunk's A2A gather.
        deferred_expert_work = []
        if self.expert_keys:
            batched_expert_indices = []
            for i, (n, p) in enumerate(zip(names, params)):
                if not (is_expert_param(n, self.expert_keys)
                        and p.grad is not None):
                    continue
                # Eligible: plain tensor, or DTensor with no non-dim-0 shards.
                if isinstance(p.data, DTensor):
                    has_tp = any(
                        _is_shard(pl) and pl.dim != 0 for pl in p.placements)
                    if has_tp:
                        continue
                batched_expert_indices.append(i)

            if batched_expert_indices:
                # Save refs for deferred NS; free grads from param list.
                for i in batched_expert_indices:
                    p = params[i]
                    g = p.grad
                    local_g = (g._local_tensor
                               if isinstance(g, DTensor) else g)
                    local_data = (p.data._local_tensor if isinstance(
                        p.data, DTensor) else p.data)
                    deferred_expert_work.append((local_data, local_g))
                    p.grad = None

                # Remove batched experts from lists before expansion.
                keep = sorted(
                    set(range(len(params))) - set(batched_expert_indices))
                names = [names[i] for i in keep]
                params = [params[i] for i in keep]

        def _run_deferred_expert_ns():
            """Execute deferred batched expert NS."""
            if not deferred_expert_work:
                return
            with record_function("muon::batched_expert_ns"):
                ns_steps = group["ns_steps"]
                for local_data, local_g in deferred_expert_work:
                    u = zeropower_via_newtonschulz5_batched(
                        local_g.to(COMM_DTYPE), steps=ns_steps)
                    # The lr adjustment uses the per-expert 2D shape.
                    adjusted_lr = adjust_lr_for_muon(lr, local_g.shape[1:])
                    local_data.mul_(1 - lr * weight_decay)
                    local_data.add_(u, alpha=-adjusted_lr)

        # Expand expert params by splitting on dim 0.
        logger.debug("[_step_muon] before expand: %d params, expert_keys=%s",
                     len(params), self.expert_keys)
        if self.expert_keys:
            cache_key = tuple(id(p) for p in params)
            cache = self._expert_expand_cache.get(cache_key)

            if cache is None:
                # Cold path: full expansion + build cache metadata.
                exp_names, exp_params = _expand_expert_params(
                    names, params, self.expert_keys)

                # Build per-expert-group info for hot-path grad updates.
                grad_info = []
                exp_idx = 0
                for orig_idx, (n, p) in enumerate(zip(names, params)):
                    if not is_expert_param(n, self.expert_keys):
                        exp_idx += 1
                        continue

                    is_dt = isinstance(p.data, DTensor)
                    num_experts = (p.to_local() if is_dt else p.data).shape[0]

                    # Detect TP mesh from the first expanded expert param.
                    tp_mesh = None
                    tp_pls = None
                    sample = exp_params[exp_idx]
                    if isinstance(sample.data, DTensor):
                        tp_mesh = sample.data.device_mesh
                        tp_pls = list(sample.data.placements)

                    grad_info.append((orig_idx, num_experts, exp_idx, is_dt,
                                      tp_mesh, tp_pls))
                    exp_idx += num_experts

                self._expert_expand_cache[cache_key] = {
                    'names': exp_names,
                    'params': exp_params,
                    'grad_info': grad_info,
                }
                names, params = exp_names, exp_params
            else:
                # Hot path: reuse cached params, only update expert grads.
                for (orig_idx, num_experts, exp_start, is_dt, tp_mesh,
                     tp_pls) in cache['grad_info']:
                    p = params[orig_idx]
                    g = p.grad
                    # Same precondition (and message) as the cold path,
                    # rather than an opaque failure when indexing None.
                    assert g is not None, (
                        f"Expert param {names[orig_idx]} must have grad set "
                        f"before expansion")
                    local_grad = (g.to_local()
                                  if is_dt and isinstance(g, DTensor) else g)
                    for i in range(num_experts):
                        expert_p = cache['params'][exp_start + i]
                        sg = local_grad[i]
                        if tp_mesh is not None:
                            expert_p.grad = DTensor.from_local(
                                sg, device_mesh=tp_mesh, placements=tp_pls)
                        else:
                            expert_p.grad = sg
                    p.grad = None

                names = cache['names']
                params = cache['params']
        else:
            names, params = _expand_expert_params(names, params,
                                                  self.expert_keys)
        logger.debug("[_step_muon] after expand: %d params", len(params))

        param_dtensors = []
        name_dtensors = []

        param_tensors = []
        name_tensors = []

        # distributed_muon is a reference implementation for testing only.
        # The parallel pipeline (all2all) path below is the production path.
        if self.use_distributed_muon:
            _run_deferred_expert_ns()
            self.distributed_muon(names=names,
                                  params=params,
                                  group=group,
                                  lr=lr,
                                  weight_decay=weight_decay,
                                  qk_logits=qk_logits)
            return

        for n, p in zip(names, params):
            if p is None or p.grad is None:
                continue
            if isinstance(p.data, DTensor):
                if all(
                        isinstance(placement, Replicate)
                        for placement in p.placements):
                    logger.debug(
                        "[route] %s → base (DTensor all-Replicate), "
                        "shape=%s, placements=%s", n, p.shape, p.placements)
                    param_tensors.append(p)
                    name_tensors.append(n)
                else:
                    logger.debug(
                        "[route] %s → parallel (DTensor), shape=%s, "
                        "placements=%s, mesh=%s", n, p.shape, p.placements,
                        p.device_mesh.mesh_dim_names)
                    param_dtensors.append(p)
                    name_dtensors.append(n)
            elif isinstance(p.data, torch.Tensor):
                logger.debug("[route] %s → base (plain tensor), shape=%s", n,
                             p.data.shape)
                param_tensors.append(p)
                name_tensors.append(n)
            else:
                raise TypeError(f"Unsupported parameter type: {type(p.data)}")

        logger.debug("[Muon] %d DTensors → parallel, %d Tensors → base",
                     len(param_dtensors), len(param_tensors))

        def group_dtensors(dtensors, names):
            # To support different placements, we group parameters by placements
            # and run parallel Muon on each group.

            placement_to_params = defaultdict(lambda: ([], []))

            assert len(dtensors) == len(names)
            for p, n in zip(dtensors, names):
                key = (p.placements, p.device_mesh)
                placement_to_params[key][0].append(n)
                placement_to_params[key][1].append(p)
            return placement_to_params

        if len(param_dtensors) > 0:
            if not dist.is_initialized():
                raise RuntimeError(
                    "Parallel Muon requires torch.distributed to be initialized."
                )

            dtensor_group = group_dtensors(param_dtensors, name_dtensors)

            # Pre-launch the first chunk's A2A gather so that the NCCL
            # communication overlaps with the (deferred) batched expert NS
            # compute on the default CUDA stream.
            prelaunch = None
            if deferred_expert_work:
                first_names, first_params = next(iter(dtensor_group.values()))
                ordered, pts, rnk, csz = self._setup_parallel(
                    first_names, first_params, group, qk_logits)
                first_chunk = ordered[:csz]
                if first_chunk:
                    prelaunch = prelaunch_first_gather(first_chunk, pts, rnk,
                                                       group["none_grad"])

            _run_deferred_expert_ns()

            # Only the first placement group can consume the prelaunched
            # gather.
            first_group = True
            for names, params in dtensor_group.values():
                pg = prelaunch if first_group else None
                first_group = False
                self.parallel(
                    names,
                    params,
                    group,
                    lr=lr,
                    weight_decay=weight_decay,
                    qk_logits=qk_logits,
                    prelaunch_gather=pg,
                )
        else:
            _run_deferred_expert_ns()

        if len(param_tensors) > 0:
            self.base(
                name_tensors,
                param_tensors,
                group,
                lr=lr,
                weight_decay=weight_decay,
                qk_logits=qk_logits,
            )

    def _register_states_for_offload(self):
        """Register all optimizer state tensors with the CPU offload pool.

        Called once after the first step when states have been lazily created.
        Offloads all param states (momentum buffers for Muon, moment1/moment2
        for AdamW) to free GPU memory between steps.
        """
        pool = self._cpu_offload_pool
        tracked = 0
        for group in self.param_groups:
            for p in group["params"]:
                if p not in self.state:
                    continue
                state = self.state[p]
                if group.get("use_muon", False):
                    if "momentum_buffer" in state:
                        pool.track(state["momentum_buffer"])
                        tracked += 1
                else:
                    if "moment1" in state:
                        pool.track(state["moment1"])
                    if "moment2" in state:
                        pool.track(state["moment2"])
                    tracked += 1

        logger.info("[CPUOffload] Registered %d param states for offload",
                    tracked)

    @torch.no_grad
    def step(self, closure=None, qk_logits=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
            qk_logits (dict[int, Tensor], optional): A dictionary mapping layer indices
                to 1D tensors of shape (num_heads,), representing the maximum
                QK logits across all tokens, computed as
                (1 / sqrt(head_dim)) * (Q @ K^T).

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # H2D: reload optimizer states from CPU before computation.
        if self.cpu_offload and self._offload_initialized:
            self._cpu_offload_pool.reload()

        logger.debug("[Muon.step] expert_keys=%s, %d param groups",
                     self.expert_keys, len(self.param_groups))

        for i, group in enumerate(self.param_groups):
            if group["use_muon"]:
                logger.debug("[Muon.step] group %d: use_muon=True, %d params",
                             i, len(group["params"]))
                self._step_muon(group, qk_logits=qk_logits)
            else:
                logger.debug(
                    "[Muon.step] group %d: use_muon=False (AdamW), %d params",
                    i, len(group["params"]))
                step_adamw(self.state, group)

        # D2H: offload optimizer states to CPU after computation.
        if self.cpu_offload:
            if not self._offload_initialized:
                if self._cpu_offload_pool is None:
                    self._cpu_offload_pool = CPUOffloadPool()
                self._register_states_for_offload()
                self._offload_initialized = True
            self._cpu_offload_pool.offload()

        return loss

    # ------------------------------------------------------------------
    # CPU offload public helpers
    # ------------------------------------------------------------------

    def turn_on_cpu_offload(self):
        """Enable CPU offload for optimizer states.

        If optimizer state already exists it is registered and offloaded
        immediately; otherwise registration is deferred to the end of the
        next ``step()``.
        """
        if self.cpu_offload:
            return

        logger.info("[Muon] turn_on_cpu_offload")
        self.cpu_offload = True

        if not self.state:
            return

        self._cpu_offload_pool = CPUOffloadPool()
        self._offload_initialized = False
        self._register_states_for_offload()
        self._offload_initialized = True
        self._cpu_offload_pool.offload()

    def turn_off_cpu_offload(self):
        """Disable CPU offload and keep optimizer states resident on GPU."""
        if not self.cpu_offload:
            return

        logger.info("[Muon] turn_off_cpu_offload")
        if self._offload_initialized:
            self._cpu_offload_pool.reload()
            torch.cuda.current_stream().synchronize()

        self._cpu_offload_pool = None
        self._offload_initialized = False
        self.cpu_offload = False

    # ------------------------------------------------------------------
    # Checkpoint support for cpu_offload
    # ------------------------------------------------------------------

    def state_dict(self) -> dict:
        """Return the optimizer state dict.

        Raises:
            RuntimeError: if CPU offload is currently enabled.
        """
        if self.cpu_offload:
            raise RuntimeError(
                "Muon.state_dict() requires turn_off_cpu_offload() before checkpoint save."
            )
        return super().state_dict()

    def load_state_dict(self, state_dict: dict) -> None:
        """Load optimizer state and reset the AdamW module-level caches.

        Raises:
            RuntimeError: if CPU offload is currently enabled.
        """
        if self.cpu_offload:
            raise RuntimeError(
                "Muon.load_state_dict() requires turn_off_cpu_offload() before checkpoint load."
            )
        super().load_state_dict(state_dict)

        # Invalidate adamw.py's module-level tensor caches so that
        # the next step rebuilds them with the newly loaded state tensors.
        _placement_cache.clear()
        _tensor_cache.clear()