# Kernels
# TaehyunKim
# draft commit for cpu_offload (#23)
# 10848ab unverified
import logging
from typing import Generator
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor
from torch.profiler import record_function
from .core import _muon_state, adjust_lr_for_muon
from .newton_schulz import COMM_DTYPE, zeropower_via_newtonschulz5
from .qk_clip import compute_scales
# Module-level logger; _launch_gather reports its all-to-all buffer sizes
# through it at DEBUG level.
logger = logging.getLogger(__name__)
# ======================================================================
# Stage helpers
# ======================================================================
def _launch_gather(
    params: list[DTensor],
    owned_params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    num_ranks: int,
    process_group: dist.ProcessGroup,
) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
    """Allocate gather buffers, build send/recv, and launch async all-to-all.

    Every rank sends its local gradient shard of each param in ``params`` to
    that param's ``worker_rank``. Every worker receives, for each param it
    owns, the shard held by each rank in the group.

    Args:
        params: All parameters of the chunk. Send slices are packed in this
            order for each destination. Presumably every rank must therefore
            pass the same params in the same order (TODO confirm with callers).
        owned_params: The params whose state has ``worker_rank == rank``.
        param_to_state: ``id(p) -> _muon_state`` lookup.
        rank: This rank's index in ``process_group``.
        num_ranks: Number of ranks in ``process_group``.
        process_group: Group the all-to-all runs on.

    Returns:
        work: Async operation handle.
        recv_buf: Flat receive buffer (needed by ``_complete_gather``).
        gathered_grads: ``{id(p): empty_tensor}`` (shape ``p.shape``,
            dtype ``COMM_DTYPE``) for owned params, ``None`` for non-owned.
        recv_counts: Per-source-rank element counts.
    """
    # Global-shape destination buffers for the params this rank owns.
    # _complete_gather fills them in place once the all-to-all finishes.
    gathered_grads: dict[int, torch.Tensor | None] = {}
    for p in params:
        if param_to_state[id(p)].worker_rank == rank:
            gathered_grads[id(p)] = torch.empty(p.shape,
                                                dtype=COMM_DTYPE,
                                                device="cuda")
        else:
            gathered_grads[id(p)] = None

    # Single pass over params:
    # - count elements per destination rank;
    # - bucket the flattened local grad shards by destination.
    # The send buffer can then be built with one torch.cat
    # (1-2 fused kernels instead of N individual narrow().copy_() calls).
    send_counts = [0] * num_ranks
    dst_to_grads: list[list[torch.Tensor]] = [[] for _ in range(num_ranks)]
    for p in params:
        state = param_to_state[id(p)]
        n = state.rank_numels[rank]
        send_counts[state.worker_rank] += n
        if n > 0:
            dst_to_grads[state.worker_rank].append(
                p.grad.to_local().reshape(-1))

    # Flatten in destination-rank order, which is the all-to-all input layout.
    all_slices = [g for dst_grads in dst_to_grads for g in dst_grads]
    if all_slices:
        send_buf = torch.cat(all_slices)
        if send_buf.dtype != COMM_DTYPE:
            send_buf = send_buf.to(COMM_DTYPE)
    else:
        send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")

    # Each owned param receives rank_numels[src] elements from every src.
    recv_counts = [0] * num_ranks
    for p in owned_params:
        state = param_to_state[id(p)]
        assert state.worker_rank == rank
        for src in range(num_ranks):
            recv_counts[src] += state.rank_numels[src]
    recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda")

    # Lazy %-style arguments: the message, including str(process_group), is
    # only formatted when DEBUG logging is actually enabled.
    logger.debug(
        "send_buf size: %d, recv_buf size: %d, recv_counts: %s, "
        "send_counts: %s, process_group: %s", send_buf.numel(),
        recv_buf.numel(), recv_counts, send_counts, process_group)

    # Launch async all-to-all.
    work = dist.all_to_all_single(
        recv_buf,
        send_buf,
        output_split_sizes=recv_counts,
        input_split_sizes=send_counts,
        group=process_group,
        async_op=True,
    )
    return work, recv_buf, gathered_grads, recv_counts
def _complete_gather(
    recv_buf: torch.Tensor,
    recv_counts: list[int],
    owned_params: list[DTensor],
    gathered_grads: dict[int, torch.Tensor | None],
    param_to_state: dict[int, _muon_state],
    rank: int,
) -> None:
    """Write the received shards from the flat buffer into the gathered grads.

    ``recv_buf`` holds one contiguous block per source rank, in rank order,
    with ``recv_counts[src]`` elements each. Inside a block, the shards of
    ``owned_params`` follow in list order. Each shard is reshaped and stored
    in place into ``gathered_grads[id(p)]`` at ``rank_indices[src]``.
    Shards with no elements are skipped. The per-block asserts check that
    the shard sizes add up to the advertised block length.
    """
    block_start = 0
    for src, block_len in enumerate(recv_counts):
        if not block_len:
            continue
        cursor = block_start
        for param in owned_params:
            st = param_to_state[id(param)]
            assert st.worker_rank == rank
            full = gathered_grads[id(param)]
            where = st.rank_indices[src]
            target_shape = full[where].shape
            numel = target_shape.numel()
            if numel == 0:
                continue
            full[where] = recv_buf.narrow(0, cursor, numel).reshape(
                target_shape)
            cursor += numel
        assert cursor - block_start == block_len
        block_start += block_len
def _compute_ns(
    owned_params: list[DTensor],
    gathered_grads: dict[int, torch.Tensor | None],
    ns_steps: int,
) -> dict[int, torch.Tensor | None]:
    """Orthogonalize each owned param's gathered gradient with Newton-Schulz.

    As soon as a param's update has been computed, its entry in
    ``gathered_grads`` is overwritten with ``None``. This drops the dict's
    reference to the full-size gradient before the next param is processed.

    Args:
        owned_params: Params this rank is the worker for.
        gathered_grads: ``{id(p): full_grad}``; mutated in place.
        ns_steps: Forwarded to ``zeropower_via_newtonschulz5``.

    Returns:
        ``{id(p): orthogonalized_update}`` in ``owned_params`` order.
    """
    updates: dict[int, torch.Tensor | None] = {}
    for param in owned_params:
        key = id(param)
        updates[key] = zeropower_via_newtonschulz5(gathered_grads[key],
                                                   ns_steps)
        gathered_grads[key] = None
    return updates
def _launch_scatter(
    params: list[DTensor],
    owned_params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    num_ranks: int,
    process_group: dist.ProcessGroup,
    computed_us: dict[int, torch.Tensor | None],
) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor], list[int]]:
    """Allocate scatter buffers, build send/recv, and launch async all-to-all.

    Each worker sends, for every param it owns, the slice of the full
    orthogonalized update that belongs to each destination rank. Every rank
    receives its local shard of each param from that param's worker.

    Args:
        params: All parameters of the chunk.
        owned_params: The params whose state has ``worker_rank == rank``.
        param_to_state: ``id(p) -> _muon_state`` lookup.
        rank: This rank's index in ``process_group``.
        num_ranks: Number of ranks in ``process_group``.
        process_group: Group the all-to-all runs on.
        computed_us: ``{id(p): full_update}`` for owned params.

    Returns:
        work: Async operation handle.
        recv_buf: Flat receive buffer (needed by ``_complete_scatter``).
        scattered_us: Contains only the entries for params whose local shard
            is empty. ``_complete_scatter`` adds the rest as zero-copy views
            into ``recv_buf``.
        recv_counts: Per-source-rank element counts.
    """
    # scattered_us is populated by _complete_scatter with zero-copy views
    # into recv_buf, avoiding N empty_like allocations + N copy_ calls.
    # Pre-seed entries for params whose local shard is empty (rank_numels == 0)
    # so _update_params can iterate all params without KeyError.
    scattered_us: dict[int, torch.Tensor] = {}
    for p in params:
        if param_to_state[id(p)].rank_numels[rank] == 0:
            scattered_us[id(p)] = torch.empty_like(p.to_local(),
                                                   dtype=COMM_DTYPE)

    # Each owned param contributes rank_numels[dst] elements to every dst.
    send_counts = [0] * num_ranks
    for p in owned_params:
        state = param_to_state[id(p)]
        for dst_rank in range(num_ranks):
            send_counts[dst_rank] += state.rank_numels[dst_rank]

    # Build send buffer – batch via torch.cat
    # (1 fused kernel vs N*num_ranks individual narrow().copy_() calls).
    if sum(send_counts) > 0:
        # Convert each full update once, not once per destination rank.
        u_fulls = {
            id(p): computed_us[id(p)].to(COMM_DTYPE).contiguous()
            for p in owned_params
        }
        # Collect slices in dst order (matches all-to-all send layout).
        all_slices = []
        for dst_rank in range(num_ranks):
            for p in owned_params:
                state = param_to_state[id(p)]
                su = u_fulls[id(p)][state.rank_indices[dst_rank]].flatten()
                if su.numel() > 0:
                    all_slices.append(su)
        send_buf = torch.cat(all_slices) if all_slices else torch.empty(
            0, dtype=COMM_DTYPE, device="cuda")
    else:
        send_buf = torch.empty(0, dtype=COMM_DTYPE, device="cuda")

    # This rank receives its local shard of every param from that param's
    # worker. A single O(len(params)) pass replaces the former
    # O(num_ranks * len(params)) scan that filtered on worker_rank per source.
    recv_counts = [0] * num_ranks
    for p in params:
        state = param_to_state[id(p)]
        recv_counts[state.worker_rank] += state.rank_numels[rank]
    recv_buf = torch.empty(sum(recv_counts), dtype=COMM_DTYPE, device="cuda")

    # Launch async all-to-all.
    work = dist.all_to_all_single(
        recv_buf,
        send_buf,
        output_split_sizes=recv_counts,
        input_split_sizes=send_counts,
        group=process_group,
        async_op=True,
    )
    return work, recv_buf, scattered_us, recv_counts
def _complete_scatter(
    recv_buf: torch.Tensor,
    recv_counts: list[int],
    params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    scattered_us: dict[int, torch.Tensor],
) -> None:
    """Fill ``scattered_us`` with zero-copy views into ``recv_buf``.

    The buffer holds one contiguous block per source (worker) rank, in rank
    order. Inside a block, the shards of the params owned by that worker
    appear in ``params`` order. Each non-empty shard becomes
    ``recv_buf.narrow(...)`` viewed with the shape of ``p.to_local()``.
    No allocation or copy happens here. The views keep ``recv_buf``'s storage
    alive for as long as ``scattered_us`` references them.
    """
    base = 0
    for src, block_len in enumerate(recv_counts):
        if block_len == 0:
            continue
        cursor = base
        from_src = (p for p in params
                    if param_to_state[id(p)].worker_rank == src)
        for p in from_src:
            numel = param_to_state[id(p)].rank_numels[rank]
            if numel == 0:
                continue
            flat = recv_buf.narrow(0, cursor, numel)
            scattered_us[id(p)] = flat.view_as(p.to_local())
            cursor += numel
        assert cursor - base == block_len
        base += block_len
def _update_params(
    params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    scattered_us: dict[int, torch.Tensor],
    lr: float,
    weight_decay: float,
) -> None:
    """Apply weight decay, Muon update, and optional QK clipping.

    Uses batched ``_foreach_mul_`` for weight decay and batched
    ``_foreach_add_`` for the Muon update. Parameters are grouped by
    adjusted_lr to minimize kernel launches while preserving float32
    precision for the alpha scaling. All writes go directly to each
    DTensor's ``_local_tensor``.

    Args:
        params: Parameters of the chunk (no-op when empty).
        param_to_state: ``id(p) -> _muon_state`` lookup.
        rank: This rank's index; selects ``rank_indices[rank]`` for QK clip.
        scattered_us: ``{id(p): local_update}`` with one entry per param.
        lr: Base learning rate, adjusted per param by ``adjust_lr_for_muon``.
        weight_decay: Params are scaled by ``1 - lr * weight_decay``.
    """
    if not params:
        return

    # Batched weight decay: p *= (1 - lr * wd) — single fused kernel.
    p_locals = [p._local_tensor for p in params]
    torch._foreach_mul_(p_locals, 1.0 - lr * weight_decay)

    # Group params by adjusted_lr so _foreach_add_ can use a single
    # alpha per group (preserves float32 precision for alpha scaling).
    lr_groups: dict[float, tuple[list, list]] = {}
    for p in params:
        adjusted_lr = adjust_lr_for_muon(lr, p.shape)
        p_group, u_group = lr_groups.setdefault(adjusted_lr, ([], []))
        p_group.append(p._local_tensor)
        u_group.append(scattered_us[id(p)])
    for adjusted_lr, (p_group, u_group) in lr_groups.items():
        torch._foreach_add_(p_group, u_group, alpha=-adjusted_lr)

    # QK clipping – applied directly on the local tensor to
    # avoid DTensor sharding-propagation issues with _StridedShard.
    for p in params:
        state = param_to_state[id(p)]
        if state.qk_clip_state is None:
            continue
        scales_full = compute_scales(p, state.qk_clip_state)
        if scales_full is None:
            continue
        # Each scale covers `ratio` consecutive global rows of dim 0.
        ratio = p.shape[0] // scales_full.shape[0]
        idx0 = state.rank_indices[rank][0]
        if isinstance(idx0, slice):
            # Materialize the slice against the global dim-0 extent. This
            # handles open-ended bounds (start/stop None, where the previous
            # torch.arange(start, None) raised) and a non-unit step.
            idx0 = torch.arange(p.shape[0], device=scales_full.device)[idx0]
        row_scales = scales_full[idx0 // ratio]
        p._local_tensor.mul_(row_scales.view(-1, 1))
# ======================================================================
# Pre-launch helper for overlapping first chunk's gather with other work.
# ======================================================================
@torch.no_grad()
def prelaunch_first_gather(
    params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    none_grad: bool,
) -> tuple[dist.Work, torch.Tensor, dict[int, torch.Tensor | None], list[int]]:
    """Start the first chunk's all-to-all gather ahead of other GPU work.

    Call this *before* expensive GPU work (e.g. batched expert NS). The NCCL
    all-to-all can then run on the NCCL stream while the default stream
    executes that compute.

    Args:
        params: Parameters of the first chunk (must be non-empty; the process
            group is read from ``params[0]``'s state).
        param_to_state: ``id(p) -> _muon_state`` lookup.
        rank: This rank's index within the process group.
        none_grad: If True, ``p.grad`` is set to ``None`` for every param
            once the gather has been launched.

    Returns:
        The ``(work, recv_buf, gathered_grads, recv_counts)`` tuple produced
        by ``_launch_gather``. Pass it as ``prelaunch_gather`` to
        :func:`muon_chunk_pipeline`.
    """
    group = param_to_state[id(params[0])].process_group
    world_size = dist.get_world_size(group=group)
    mine = [p for p in params if param_to_state[id(p)].worker_rank == rank]

    with record_function("muon::prelaunch_gather"):
        work, recv_buf, grads_by_id, counts = _launch_gather(
            params, mine, param_to_state, rank, world_size, group)

    if none_grad:
        for p in params:
            p.grad = None

    return work, recv_buf, grads_by_id, counts
# ======================================================================
# Main generator – thin orchestrator that wires stages together.
# ======================================================================
@torch.no_grad()
def muon_chunk_pipeline(
    params: list[DTensor],
    param_to_state: dict[int, _muon_state],
    rank: int,
    ns_steps: int,
    lr: float,
    weight_decay: float,
    none_grad: bool,
    prelaunch_gather: tuple | None = None,
) -> Generator[None, None, None]:
    """Process one chunk of parameters through the full Muon pipeline.

    Stages: gather -> compute (Newton-Schulz) -> scatter -> update.

    Each ``yield`` lets :func:`run_pipeline` interleave other chunks so
    that communication and computation overlap across chunks. Async
    communication is launched via ``async_op=True`` and completed after
    the yield with ``work.wait()``.

    Overlap happens because :func:`run_pipeline` admits one new chunk
    per iteration (staggered admission). While chunk *N* does NS
    compute on the default CUDA stream, chunk *N+1*'s async all-to-all
    runs concurrently on the NCCL stream — no separate ``comm_stream``
    is required.

    If ``prelaunch_gather`` is provided, the gather was already launched
    by :func:`prelaunch_first_gather` and we skip launching it again.

    Args:
        params: Parameters of this chunk. Must be non-empty, because the
            process group is read from ``params[0]``'s state (presumably
            shared by every param in the chunk — TODO confirm).
        param_to_state: ``id(p) -> _muon_state`` lookup.
        rank: This rank's index within the process group. Params whose
            ``worker_rank`` equals it are orthogonalized on this rank.
        ns_steps: Forwarded to ``zeropower_via_newtonschulz5``.
        lr: Base learning rate, adjusted per param inside ``_update_params``.
        weight_decay: Params are scaled by ``1 - lr * weight_decay`` before
            the update is applied.
        none_grad: If True, ``p.grad`` is set to ``None`` for every param
            right after the gather is launched. Ignored when
            ``prelaunch_gather`` is given; the caller already handled it.
        prelaunch_gather: Optional ``(work, recv_buf, gathered_grads,
            recv_counts)`` tuple from :func:`prelaunch_first_gather`.

    Yields exactly **2** times:
        1. After launching async all-to-all gather (or immediately if
           pre-launched).
        2. After launching async all-to-all scatter.
    """
    process_group = param_to_state[id(params[0])].process_group
    num_ranks = dist.get_world_size(group=process_group)
    # Params this rank is the worker for (it runs Newton-Schulz on them).
    owned_params = [
        p for p in params if param_to_state[id(p)].worker_rank == rank
    ]
    # Stages 1-2: launch (or adopt) the async gather.
    if prelaunch_gather is not None:
        # Gather was pre-launched; none_grad already handled by caller.
        work, recv_buf, gathered_grads, recv_counts = prelaunch_gather
    else:
        # Normal path: launch async gather.
        with record_function("muon::launch_gather"):
            work, recv_buf, gathered_grads, recv_counts = _launch_gather(
                params, owned_params, param_to_state, rank, num_ranks,
                process_group)
        if none_grad:
            for p in params:
                p.grad = None
    yield  # --- YIELD 1: other chunks can launch their gather ---
    with record_function("muon::wait_gather"):
        work.wait()
        _complete_gather(recv_buf, recv_counts, owned_params, gathered_grads,
                         param_to_state, rank)
    # Shards were copied into gathered_grads; drop the flat receive buffer.
    del recv_buf
    # Stage 3: Newton-Schulz orthogonalization.
    with record_function("muon::newton_schulz"):
        computed_us = _compute_ns(owned_params, gathered_grads, ns_steps)
        # _compute_ns already replaced each entry with None.
        gathered_grads.clear()
    # Stages 4-5: launch async scatter.
    with record_function("muon::launch_scatter"):
        work, recv_buf, scattered_us, recv_counts = _launch_scatter(
            params, owned_params, param_to_state, rank, num_ranks,
            process_group, computed_us)
        # The full updates are not needed once the send buffer is built.
        computed_us.clear()
    yield  # --- YIELD 2: other chunks can launch their scatter ---
    with record_function("muon::wait_scatter"):
        work.wait()
        _complete_scatter(recv_buf, recv_counts, params, param_to_state, rank,
                          scattered_us)
    # Only the local name is dropped: scattered_us holds views into this
    # buffer, so its storage stays alive until scattered_us.clear() below.
    del recv_buf
    # Stage 6: apply parameter updates.
    with record_function("muon::update_params"):
        _update_params(params, param_to_state, rank, scattered_us, lr,
                       weight_decay)
        scattered_us.clear()