import logging
from typing import Generator

import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor
from torch.profiler import record_function

from .core import _muon_state, adjust_lr_for_muon
from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5
from .qk_clip import compute_scales

logger = logging.getLogger(__name__)

# ======================================================================
# Stage helpers
# ======================================================================


def _launch_gather(
    params: list[DTensor],
    owned_params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    num_ranks: int,
    process_group: dist.ProcessGroup,
) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
    """Allocate gather buffers, build send/recv, and launch async all-to-all.

    Every rank sends its local gradient shard of each param in ``params`` to
    that param's ``worker_rank``; each worker receives, from every rank, the
    shards of the params it owns (``owned_params``).

    Returns:
        work: Async operation handle.
        recv_buf: Flat receive buffer (needed by ``_complete_gather``).
        gathered_grads: ``{id(p): empty_tensor}`` for owned params,
            ``None`` for non-owned.
        recv_counts: Per-source-rank element counts.
    """
    # Allocate full-shape gathered-grad buffers only for params owned here.
    gathered_grads: dict[int, torch.Tensor | None] = {}
    for p in params:
        state = param_to_state[id(p)]
        if rank == state.worker_rank:
            gathered_grads[id(p)] = torch.empty(
                p.shape, dtype=COMM_DTYPE, device="cuda")
        else:
            gathered_grads[id(p)] = None

    # Build send buffer – batch grad copies via torch.cat
    # (1-2 fused kernels vs N individual narrow().copy_() calls).
    send_counts = [0] * num_ranks
    for p in params:
        state = param_to_state[id(p)]
        send_counts[state.worker_rank] += state.rank_numels[rank]

    total_send = sum(send_counts)
    if total_send > 0:
        # Group grad slices by destination rank in a single pass.
        dst_to_grads = [[] for _ in range(num_ranks)]
        for p in params:
            state = param_to_state[id(p)]
            n = state.rank_numels[rank]
            if n > 0:
                g = p.grad.to_local()
                dst_to_grads[state.worker_rank].append(g.reshape(-1))

        # Flatten in dst order and cat once.
        all_slices = []
        for dst in range(num_ranks):
            all_slices.extend(dst_to_grads[dst])
        send_buf = torch.cat(all_slices)
        if send_buf.dtype != COMM_DTYPE:
            send_buf = send_buf.to(COMM_DTYPE)
    else:
        send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")

    # Build recv buffer: from each source rank we receive that rank's shard
    # of every param we own.  Single pass instead of num_ranks passes.
    recv_counts = [0] * num_ranks
    for p in owned_params:
        state = param_to_state[id(p)]
        assert state.worker_rank == rank
        for src in range(num_ranks):
            recv_counts[src] += state.rank_numels[src]

    recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda")

    # Launch async all-to-all.  Lazy %-style args: the message (including
    # str(process_group)) is only formatted when DEBUG logging is enabled.
    logger.debug(
        "send_buf size: %s, recv_buf size: %s, recv_counts: %s, "
        "send_counts: %s, process_group: %s",
        send_buf.numel(),
        recv_buf.numel(),
        recv_counts,
        send_counts,
        process_group,
    )
    work = dist.all_to_all_single(
        recv_buf,
        send_buf,
        output_split_sizes=recv_counts,
        input_split_sizes=send_counts,
        group=process_group,
        async_op=True,
    )

    return work, recv_buf, gathered_grads, recv_counts


def _complete_gather(
    recv_buf: torch.Tensor,
    recv_counts: list[int],
    owned_params: list[DTensor],
    gathered_grads: dict[int, torch.Tensor | None],
    param_to_state: dict[int, _muon_state],
    rank: int,
) -> None:
    """Reconstruct gathered grads from the recv buffer (in-place).

    ``recv_buf`` holds one contiguous block per source rank (sizes given by
    ``recv_counts``).  Within a block, shards appear in ``owned_params``
    order, mirroring the per-destination order used when building the send
    buffer in ``_launch_gather`` (assumes every rank passes ``params`` in
    the same order).
    """
    off = 0
    for src, block in enumerate(recv_counts):
        if block == 0:
            continue

        inner_off = 0
        for p in owned_params:
            state = param_to_state[id(p)]
            assert state.worker_rank == rank
            indices = state.rank_indices[src]
            shard_view = gathered_grads[id(p)][indices]
            n = shard_view.numel()
            if n == 0:
                continue

            sg = recv_buf.narrow(0, off + inner_off, n)
            sg = sg.reshape(shard_view.shape)
            # Index assignment copies src's shard into the full-shape grad.
            gathered_grads[id(p)][indices] = sg
            inner_off += n

        assert inner_off == block
        off += block


def _compute_ns(
    owned_params: list[DTensor],
    gathered_grads: dict[int, torch.Tensor | None],
    ns_steps: int,
) -> dict[int, torch.Tensor | None]:
    """Run Newton-Schulz orthogonalization on owned parameters.

    Each gathered grad is released (set to ``None``) as soon as its update
    has been computed, to keep peak memory down.

    Returns:
        computed_us: ``{id(p): orthogonalized_update}`` for owned params.
    """
    computed_us: dict[int, torch.Tensor | None] = {}
    for p in owned_params:
        u = zeropower_via_newtonschulz5(gathered_grads[id(p)], ns_steps)
        gathered_grads[id(p)] = None  # free gathered grad
        computed_us[id(p)] = u
    return computed_us


def _launch_scatter(
    params: list[DTensor],
    owned_params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    num_ranks: int,
    process_group: dist.ProcessGroup,
    computed_us: dict[int, torch.Tensor | None],
) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]:
    """Allocate scatter buffers, build send/recv, and launch async all-to-all.

    Each owner sends, to every rank, that rank's shard of the update for
    each param it owns; each rank receives its local shard of every param.

    Returns:
        work: Async operation handle.
        recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
        scattered_us: Empty dict, populated by ``_complete_scatter`` with
            zero-copy views into ``recv_buf``.
        recv_counts: Per-source-rank element counts.
    """
    # scattered_us is populated by _complete_scatter with zero-copy views
    # into recv_buf, avoiding N empty_like allocations + N copy_ calls.
    # Pre-seed entries for params whose local shard is empty (rank_numels == 0)
    # so _update_params can iterate all params without KeyError.
    scattered_us: dict[int, torch.Tensor] = {}
    for p in params:
        if param_to_state[id(p)].rank_numels[rank] == 0:
            scattered_us[id(p)] = torch.empty_like(p.to_local(),
                                                   dtype=COMM_DTYPE)

    # Build send buffer – batch via torch.cat
    # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls).
    send_counts = [0] * num_ranks
    for p in owned_params:
        state = param_to_state[id(p)]
        for dst_rank in range(num_ranks):
            send_counts[dst_rank] += state.rank_numels[dst_rank]

    total_send = sum(send_counts)
    if total_send > 0:
        # Cache u_full conversions to avoid redundant .to() per dst_rank.
        u_fulls = {
            id(p): computed_us[id(p)].to(COMM_DTYPE).contiguous()
            for p in owned_params
        }

        # Collect slices in dst order (matches all-to-all send layout).
        all_slices = []
        for dst_rank in range(num_ranks):
            for p in owned_params:
                state = param_to_state[id(p)]
                su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten()
                if su.numel() > 0:
                    all_slices.append(su)

        send_buf = torch.cat(all_slices) if all_slices else torch.empty(
            0, dtype=COMM_DTYPE, device="cuda")
    else:
        send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")

    # Build recv buffer: from each owner rank we receive our local shard of
    # every param it owns.  Single pass instead of num_ranks passes.
    recv_counts = [0] * num_ranks
    for p in params:
        state = param_to_state[id(p)]
        recv_counts[state.worker_rank] += state.rank_numels[rank]

    recv_total = sum(recv_counts)
    recv_buf = torch.empty(recv_total, dtype=COMM_DTYPE, device="cuda")

    # Launch async all-to-all
    work = dist.all_to_all_single(
        recv_buf,
        send_buf,
        output_split_sizes=recv_counts,
        input_split_sizes=send_counts,
        group=process_group,
        async_op=True,
    )

    return work, recv_buf, scattered_us, recv_counts


def _complete_scatter(
    recv_buf: torch.Tensor,
    recv_counts: list[int],
    params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    scattered_us: dict[int, torch.Tensor],
) -> None:
    """Populate scattered_us with zero-copy views into recv_buf.

    Instead of pre-allocating tensors and copying, we assign views directly
    from ``recv_buf``.  This eliminates N ``empty_like`` + N ``copy_`` calls.
    The underlying storage of ``recv_buf`` is kept alive through the views
    until ``scattered_us`` is cleared after ``_update_params``.
    """
    off = 0
    for src, block in enumerate(recv_counts):
        if block == 0:
            continue

        inner_off = 0
        for p in params:
            state = param_to_state[id(p)]
            if state.worker_rank != src:
                continue
            n = state.rank_numels[rank]
            if n == 0:
                continue

            scattered_us[id(p)] = recv_buf.narrow(
                0, off + inner_off, n).view_as(p.to_local())
            inner_off += n

        assert inner_off == block
        off += block


def _update_params(
    params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    scattered_us: dict[int, torch.Tensor],
    lr: float,
    weight_decay: float,
) -> None:
    """Apply weight decay, Muon update, and optional QK clipping.

    Uses batched ``_foreach_mul_`` for weight decay and batched
    ``_foreach_add_`` for the Muon update, grouping parameters by
    adjusted_lr to minimize kernel launches while preserving float32
    precision for the alpha scaling.

    Operates directly on each DTensor's local shard (``_local_tensor``).
    """
    if not params:
        return

    # Batched weight decay: p *= (1 - lr * wd) — single fused kernel.
    # Multiplying by exactly 1.0 (e.g. weight_decay == 0) is a no-op, so
    # skip the kernel launch in that case.
    decay = 1.0 - lr * weight_decay
    if decay != 1.0:
        p_locals = [p._local_tensor for p in params]
        torch._foreach_mul_(p_locals, decay)

    # Group params by adjusted_lr so _foreach_add_ can use a single
    # alpha per group (preserves float32 precision for alpha scaling).
    lr_groups: dict[float, tuple[list, list]] = {}
    for p in params:
        adjusted_lr = adjust_lr_for_muon(lr, p.shape)
        p_group, u_group = lr_groups.setdefault(adjusted_lr, ([], []))
        p_group.append(p._local_tensor)
        u_group.append(scattered_us[id(p)])

    for adjusted_lr, (p_group, u_group) in lr_groups.items():
        torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr)

    # QK clipping – applied directly on the local tensor to
    # avoid DTensor sharding-propagation issues with _StridedShard.
    for p in params:
        state = param_to_state[id(p)]
        if state.qk_clip_state is None:
            continue

        scales_full = compute_scales(p, state.qk_clip_state)
        if scales_full is not None:
            ratio = p.shape[0] // scales_full.shape[0]
            # Dim-0 index of this rank's shard within the full-shape param.
            idx0 = state.rank_indices[rank][0]
            if isinstance(idx0, slice):
                # slice.indices() resolves None / negative bounds and the
                # step against the full dim-0 length, so an unsharded
                # dim 0 (slice(None)) or a stepped slice is handled too.
                start, stop, step = idx0.indices(p.shape[0])
                idx0 = torch.arange(start, stop, step,
                                    device=scales_full.device)
            row_scales = scales_full[idx0 // ratio]
            p._local_tensor.mul_(row_scales.view(-1, 1))


# ======================================================================
# Pre-launch helper for overlapping first chunk's gather with other work.
# ======================================================================


@torch.no_grad()
def prelaunch_first_gather(
    params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    none_grad: bool,
) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
    """Launch the first chunk's A2A gather early for overlap with other compute.

    Call this *before* expensive GPU work (e.g. batched expert NS) so that
    the NCCL all-to-all runs concurrently on the NCCL stream while the
    default stream executes compute.

    If ``none_grad`` is true, every param's ``.grad`` is set to ``None``
    after the gather is launched (the grads have already been copied into
    the send buffer by ``torch.cat`` inside ``_launch_gather``).

    Returns the same 4-tuple that ``_launch_gather`` produces, which should
    be passed as ``prelaunch_gather`` to :func:`muon_chunk_pipeline`.
    """
    process_group = param_to_state[id(params[0])].process_group
    num_ranks = dist.get_world_size(group=process_group)
    owned_params = [
        p for p in params if param_to_state[id(p)].worker_rank == rank
    ]

    with record_function("muon::prelaunch_gather"):
        work, recv_buf, gathered_grads, recv_counts = _launch_gather(
            params, owned_params, param_to_state, rank, num_ranks,
            process_group)

    if none_grad:
        for p in params:
            p.grad = None

    return work, recv_buf, gathered_grads, recv_counts


# ======================================================================
# Main generator – thin orchestrator that wires stages together.
# ======================================================================


@torch.no_grad()
def muon_chunk_pipeline(
    params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    ns_steps: int,
    lr: float,
    weight_decay: float,
    none_grad: bool,
    prelaunch_gather: tuple | None = None,
) -> Generator[None, None, None]:
    """Process one chunk of parameters through the full Muon pipeline.

    Stages: gather -> compute (Newton-Schulz) -> scatter -> update.

    Each ``yield`` lets :func:`run_pipeline` interleave other chunks so
    that communication and computation overlap across chunks.  Async
    communication is launched via ``async_op=True`` and completed after
    the yield with ``work.wait()``.

    Overlap happens because :func:`run_pipeline` admits one new chunk
    per iteration (staggered admission).  While chunk *N* does NS
    compute on the default CUDA stream, chunk *N+1*'s async all-to-all
    runs concurrently on the NCCL stream — no separate ``comm_stream``
    is required.

    If ``prelaunch_gather`` is provided, the gather was already launched
    by :func:`prelaunch_first_gather` and we skip launching it again.

    Args:
        params: Parameters in this chunk.  The first param's state supplies
            the process group used for all communication.
        param_to_state: ``{id(p): _muon_state}`` for every param in
            ``params``.
        rank: This process's rank within that process group.
        ns_steps: Number of Newton-Schulz iterations, forwarded to
            ``zeropower_via_newtonschulz5``.
        lr: Base learning rate (per-param value comes from
            ``adjust_lr_for_muon``).
        weight_decay: Decay coefficient; params are scaled by
            ``1 - lr * weight_decay`` before the update.
        none_grad: If true, set ``p.grad = None`` for every param right
            after the gather is launched.  Ignored when ``prelaunch_gather``
            is given (the caller is expected to have done it).
        prelaunch_gather: Optional 4-tuple returned by
            :func:`prelaunch_first_gather`.

    Yields exactly **2** times:
      1. After launching async all-to-all gather (or immediately if
         pre-launched).
      2. After launching async all-to-all scatter.
    """
    # NOTE(review): an empty ``params`` raises IndexError here.
    process_group = param_to_state[id(params[0])].process_group
    num_ranks = dist.get_world_size(group=process_group)
    # Params whose Newton-Schulz step is computed on this rank.
    owned_params = [
        p for p in params if param_to_state[id(p)].worker_rank == rank
    ]

    if prelaunch_gather is not None:
        # Gather was pre-launched; none_grad already handled by caller.
        work, recv_buf, gathered_grads, recv_counts = prelaunch_gather
    else:
        # Normal path: launch async gather.
        with record_function("muon::launch_gather"):
            work, recv_buf, gathered_grads, recv_counts = _launch_gather(
                params, owned_params, param_to_state, rank, num_ranks,
                process_group)
        if none_grad:
            # _launch_gather has already copied the grads into its send
            # buffer (torch.cat), so they can be released now.
            for p in params:
                p.grad = None

    yield  # --- YIELD 1: other chunks can launch their gather ---

    with record_function("muon::wait_gather"):
        work.wait()
        _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads,
                         param_to_state, rank)
    # Received shards were copied into gathered_grads; drop this frame's
    # reference to the flat gather buffer.
    del recv_buf

    # Stage 3: Newton-Schulz orthogonalization.
    with record_function("muon::newton_schulz"):
        computed_us = _compute_ns(owned_params, gathered_grads, ns_steps)
    gathered_grads.clear()

    # Stages 4-5: launch async scatter.
    # Rebinding ``work`` / ``recv_buf`` here also drops this frame's
    # references to the gather-stage handle.
    with record_function("muon::launch_scatter"):
        work, recv_buf, scattered_us, recv_counts = _launch_scatter(
            params, owned_params, param_to_state, rank, num_ranks,
            process_group, computed_us)
    # Updates have been packed into the scatter send buffer; the NS outputs
    # are no longer needed by this chunk.
    computed_us.clear()

    yield  # --- YIELD 2: other chunks can launch their scatter ---

    with record_function("muon::wait_scatter"):
        work.wait()
        _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank,
                          scattered_us)
    # scattered_us now holds views into recv_buf, so its storage stays alive
    # until scattered_us.clear() below; this only drops the local name.
    del recv_buf

    # Stage 6: apply parameter updates.
    with record_function("muon::update_params"):
        _update_params(params, param_to_state, rank, scattered_us, lr,
                       weight_decay)
    # Releases the views (and thus the scatter receive buffer).
    scattered_us.clear()