import logging import math from dataclasses import dataclass from typing import List import torch from torch.distributed import ProcessGroup from torch.distributed.tensor import DTensor # torch.compile wraps modules as OptimizedModule, inserting "_orig_mod" into # parameter FQNs. Activation checkpointing similarly inserts # "_checkpoint_wrapped_module". Strip these so name-based matching (skip_keys, # expert_keys, QK layer parsing) works regardless of wrapper nesting. _WRAPPER_PARTS = frozenset({"_orig_mod", "_checkpoint_wrapped_module"}) logger = logging.getLogger(__name__) def normalize_fqn(name: str) -> str: """Strip torch.compile / checkpoint wrapper components from a parameter FQN.""" return ".".join(p for p in name.split(".") if p not in _WRAPPER_PARTS) @dataclass class _muon_state: worker_rank: int process_group: ProcessGroup rank_indices: dict[int, tuple] # local_rank -> per-dim indices rank_numels: dict[int, int] # local_rank -> numel name: str qk_clip_state: torch.Tensor | None = None def _batch_momentum( grads: List[torch.Tensor], momentum_bufs: List[torch.Tensor], momentum: torch.Tensor, ) -> None: """Batched momentum update (no nesterov).""" torch._foreach_mul_(momentum_bufs, momentum) torch._foreach_add_(momentum_bufs, grads) def _batch_momentum_nesterov( grads: List[torch.Tensor], momentum_bufs: List[torch.Tensor], momentum: torch.Tensor, ) -> None: """Batched momentum update with nesterov correction.""" torch._foreach_mul_(momentum_bufs, momentum) torch._foreach_add_(momentum_bufs, grads) nesterov_terms = torch._foreach_mul(momentum_bufs, momentum) torch._foreach_add_(grads, nesterov_terms) _compiled_momentum: dict[bool, callable] = {} _use_momentum_compile = True def set_momentum_compile(enabled: bool): """Toggle torch.compile for batched momentum.""" global _use_momentum_compile _use_momentum_compile = enabled def batch_pre_ortho( grads: List[torch.Tensor], momentum_bufs: List[torch.Tensor], momentum: torch.Tensor, nesterov: bool, ) -> None: """Batched momentum update on lists of plain tensors. Mirrors dion's ``muon_update_pre_orthogonalize``. Inputs must be plain CUDA tensors (not DTensor). Modifies ``momentum_bufs`` and (for nesterov) ``grads`` in-place. When compile is enabled, uses separately compiled functions for nesterov=True/False to avoid graph breaks from the branch. """ fn = _batch_momentum_nesterov if nesterov else _batch_momentum if _use_momentum_compile: if nesterov not in _compiled_momentum: _compiled_momentum[nesterov] = torch.compile(fn) fn = _compiled_momentum[nesterov] fn(grads, momentum_bufs, momentum) def _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay): """Weight-decay + update on plain tensors. Not compiled: per-param @torch.compile caused ~0.25ms TorchDynamo cache lookup per call × 256+ params = massive overhead. The pipeline path uses batched _foreach_* ops instead; this function remains for base() and distributed_muon(). """ p_data.mul_(1 - lr * weight_decay) p_data.add_(u_data, alpha=-adjusted_lr) def update_p(p, u, lr, adjusted_lr, weight_decay): """Apply weight decay and orthogonalized update to parameter. Args: p: Parameter (torch.nn.Parameter or DTensor). u: Orthogonalized update tensor. lr: Base learning rate. adjusted_lr: Size-adjusted learning rate. weight_decay: Weight decay coefficient. """ # Unwrap Parameter -> underlying data tensor. p_data = p.data if isinstance(p, torch.nn.Parameter) else p # Unwrap DTensor -> local CUDA tensor for compiled kernel. if isinstance(p_data, DTensor): p_data = p_data._local_tensor u_data = u._local_tensor if isinstance(u, DTensor) else u _update_p_impl(p_data, u_data, lr, adjusted_lr, weight_decay) def adjust_lr_for_muon(lr, param_shape): """Scale learning rate based on parameter matrix dimensions. Args: lr: Base learning rate. param_shape: Shape of the parameter tensor. Returns: Adjusted learning rate. """ A, B = param_shape[:2] # We adjust the learning rate and weight decay based on the size of the parameter matrix # as described in the paper adjusted_ratio = 0.2 * math.sqrt(max(A, B)) adjusted_lr = lr * adjusted_ratio return adjusted_lr def _match_key(parts, key): """Check if key matches as contiguous components in parts. Single-component keys (e.g. "experts") match any single component. Multi-component keys (e.g. "experts.w1") match as a contiguous subsequence. """ key_parts = key.split(".") key_len = len(key_parts) if key_len == 1: return key in parts return any(parts[i:i + key_len] == key_parts for i in range(len(parts) - key_len + 1)) def is_expert_param(name, expert_keys): """Check if a parameter name matches any expert key (component-level).""" if not expert_keys: return False parts = normalize_fqn(name).split(".") return any(_match_key(parts, key) for key in expert_keys) def default_is_muon(name, x, expert_keys=None): normalized = normalize_fqn(name) parts = normalized.split(".") skip_keys = [ "embed_tokens", "lm_head", "tok_embeddings", "output", "mhc_attn", "mhc_ffn", "lambda_proj", ] if any(key in parts for key in skip_keys): logger.info( "[is_muon] %s (orig: %s): skip (matched skip_key), ndim=%d", normalized, name, x.ndim) return False effective_ndim = x.ndim is_expert = is_expert_param(name, expert_keys) if is_expert: effective_ndim -= 1 result = effective_ndim >= 2 logger.info( "[is_muon] %s (orig: %s): ndim=%d, expert=%s, effective_ndim=%d → %s", normalized, name, x.ndim, is_expert, effective_ndim, "Muon" if result else "AdamW") return result def get_default_muon_param_groups(model, is_muon_func=None, expert_keys=None): if is_muon_func is None: is_muon_func = lambda n, x: default_is_muon(n, x, expert_keys) muon_params, muon_names = [], [] non_muon_params, non_muon_names = [], [] for n, p in model.named_parameters(): if not p.requires_grad: continue if is_muon_func(n, p): muon_params.append(p) muon_names.append(n) else: non_muon_params.append(p) non_muon_names.append(n) logger.info("[param_groups] expert_keys=%s, Muon=%d, AdamW=%d", expert_keys, len(muon_names), len(non_muon_names)) return [ { "params": muon_params, "names": muon_names, "use_muon": True, }, { "params": non_muon_params, "use_muon": False, }, ]